rstsr_openblas/driver_impl/lapack/solve/
getri.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11    T     func_   ;
12   [f32] [sgetri_];
13   [f64] [dgetri_];
14)]
15impl GETRIDriverAPI<T> for DeviceBLAS {
16    unsafe fn driver_getri(order: FlagOrder, n: usize, a: *mut T, lda: usize, ipiv: *mut blas_int) -> blas_int {
17        use lapack_ffi::lapack::func_;
18
19        // Query optimal working array(s) size
20        let mut info = 0;
21        let lwork = -1;
22        let mut work_query = 0.0;
23        func_(&(n as _), a, &(lda as _), ipiv, &mut work_query, &lwork, &mut info);
24        if info != 0 {
25            return info;
26        }
27        let lwork = work_query as usize;
28
29        // Allocate memory for work arrays
30        let mut work: Vec<T> = match uninitialized_vec(lwork) {
31            Ok(work) => work,
32            Err(_) => return -1010,
33        };
34
35        if order == ColMajor {
36            // Call LAPACK function and adjust info
37            func_(&(n as _), a, &(lda as _), ipiv, work.as_mut_ptr(), &(lwork as _), &mut info);
38            if info != 0 {
39                return info;
40            }
41        } else {
42            let lda_t = n.max(1);
43            // Transpose input matrices
44            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
45                Ok(a_t) => a_t,
46                Err(_) => return -1011,
47            };
48            let a_slice = from_raw_parts_mut(a, n * lda);
49            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
50            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
51            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
52            // Call LAPACK function and adjust info
53            func_(&(n as _), a_t.as_mut_ptr(), &(lda_t as _), ipiv, work.as_mut_ptr(), &(lwork as _), &mut info);
54            if info != 0 {
55                return info;
56            }
57            // Transpose output matrices
58            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
59        }
60        return info;
61    }
62}
63
64#[duplicate_item(
65    T              func_   ;
66   [Complex<f32>] [cgetri_];
67   [Complex<f64>] [zgetri_];
68)]
69impl GETRIDriverAPI<T> for DeviceBLAS {
70    unsafe fn driver_getri(order: FlagOrder, n: usize, a: *mut T, lda: usize, ipiv: *mut blas_int) -> blas_int {
71        use lapack_ffi::lapack::func_;
72
73        // Query optimal working array(s) size
74        let mut info = 0;
75        let lwork = -1;
76        let mut work_query: T = num::zero();
77        func_(&(n as _), a as *mut _, &(lda as _), ipiv, &mut work_query as *mut _ as *mut _, &lwork, &mut info);
78        if info != 0 {
79            return info;
80        }
81        let lwork = work_query.re as usize;
82
83        // Allocate memory for work arrays
84        let mut work: Vec<T> = match uninitialized_vec(lwork) {
85            Ok(work) => work,
86            Err(_) => return -1010,
87        };
88
89        if order == ColMajor {
90            // Call LAPACK function and adjust info
91            func_(&(n as _), a as *mut _, &(lda as _), ipiv, work.as_mut_ptr() as *mut _, &(lwork as _), &mut info);
92            if info != 0 {
93                return info;
94            }
95        } else {
96            let lda_t = n.max(1);
97            // Transpose input matrices
98            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
99                Ok(a_t) => a_t,
100                Err(_) => return -1011,
101            };
102            let a_slice = from_raw_parts_mut(a, n * lda);
103            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
104            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
105            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
106            // Call LAPACK function and adjust info
107            func_(
108                &(n as _),
109                a_t.as_mut_ptr() as *mut _,
110                &(lda_t as _),
111                ipiv,
112                work.as_mut_ptr() as *mut _,
113                &(lwork as _),
114                &mut info,
115            );
116            if info != 0 {
117                return info;
118            }
119            // Transpose output matrices
120            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
121        }
122        return info;
123    }
124}