rstsr_openblas/driver_impl/lapack/solve/
getri.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11 T func_ ;
12 [f32] [sgetri_];
13 [f64] [dgetri_];
14)]
15impl GETRIDriverAPI<T> for DeviceBLAS {
16 unsafe fn driver_getri(order: FlagOrder, n: usize, a: *mut T, lda: usize, ipiv: *mut blas_int) -> blas_int {
17 use lapack_ffi::lapack::func_;
18
19 let mut info = 0;
21 let lwork = -1;
22 let mut work_query = 0.0;
23 func_(&(n as _), a, &(lda as _), ipiv, &mut work_query, &lwork, &mut info);
24 if info != 0 {
25 return info;
26 }
27 let lwork = work_query as usize;
28
29 let mut work: Vec<T> = match uninitialized_vec(lwork) {
31 Ok(work) => work,
32 Err(_) => return -1010,
33 };
34
35 if order == ColMajor {
36 func_(&(n as _), a, &(lda as _), ipiv, work.as_mut_ptr(), &(lwork as _), &mut info);
38 if info != 0 {
39 return info;
40 }
41 } else {
42 let lda_t = n.max(1);
43 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
45 Ok(a_t) => a_t,
46 Err(_) => return -1011,
47 };
48 let a_slice = from_raw_parts_mut(a, n * lda);
49 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
50 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
51 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
52 func_(&(n as _), a_t.as_mut_ptr(), &(lda_t as _), ipiv, work.as_mut_ptr(), &(lwork as _), &mut info);
54 if info != 0 {
55 return info;
56 }
57 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
59 }
60 return info;
61 }
62}
63
64#[duplicate_item(
65 T func_ ;
66 [Complex<f32>] [cgetri_];
67 [Complex<f64>] [zgetri_];
68)]
69impl GETRIDriverAPI<T> for DeviceBLAS {
70 unsafe fn driver_getri(order: FlagOrder, n: usize, a: *mut T, lda: usize, ipiv: *mut blas_int) -> blas_int {
71 use lapack_ffi::lapack::func_;
72
73 let mut info = 0;
75 let lwork = -1;
76 let mut work_query: T = num::zero();
77 func_(&(n as _), a as *mut _, &(lda as _), ipiv, &mut work_query as *mut _ as *mut _, &lwork, &mut info);
78 if info != 0 {
79 return info;
80 }
81 let lwork = work_query.re as usize;
82
83 let mut work: Vec<T> = match uninitialized_vec(lwork) {
85 Ok(work) => work,
86 Err(_) => return -1010,
87 };
88
89 if order == ColMajor {
90 func_(&(n as _), a as *mut _, &(lda as _), ipiv, work.as_mut_ptr() as *mut _, &(lwork as _), &mut info);
92 if info != 0 {
93 return info;
94 }
95 } else {
96 let lda_t = n.max(1);
97 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
99 Ok(a_t) => a_t,
100 Err(_) => return -1011,
101 };
102 let a_slice = from_raw_parts_mut(a, n * lda);
103 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
104 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
105 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
106 func_(
108 &(n as _),
109 a_t.as_mut_ptr() as *mut _,
110 &(lda_t as _),
111 ipiv,
112 work.as_mut_ptr() as *mut _,
113 &(lwork as _),
114 &mut info,
115 );
116 if info != 0 {
117 return info;
118 }
119 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
121 }
122 return info;
123 }
124}