rstsr_openblas/driver_impl/lapack/solve/
getrf.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6use rstsr_native_impl::prelude_dev::*;
7use std::slice::from_raw_parts_mut;
8
9#[duplicate_item(
10 T func_ ;
11 [f32] [sgetrf_];
12 [f64] [dgetrf_];
13)]
14impl GETRFDriverAPI<T> for DeviceBLAS {
15 unsafe fn driver_getrf(
16 order: FlagOrder,
17 m: usize,
18 n: usize,
19 a: *mut T,
20 lda: usize,
21 ipiv: *mut blas_int,
22 ) -> blas_int {
23 use lapack_ffi::lapack::func_;
24
25 let mut info = 0;
26
27 if order == ColMajor {
28 func_(&(m as _), &(n as _), a, &(lda as _), ipiv, &mut info);
30 if info != 0 {
31 return info;
32 }
33 } else {
34 let lda_t = m.max(1);
35 let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
37 Ok(a_t) => a_t,
38 Err(_) => return -1011,
39 };
40 let a_slice = from_raw_parts_mut(a, m * lda);
41 let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
42 let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
43 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
44 func_(&(m as _), &(n as _), a_t.as_mut_ptr(), &(lda_t as _), ipiv, &mut info);
46 if info != 0 {
47 return info;
48 }
49 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
51 }
52 return info;
53 }
54}
55
56#[duplicate_item(
57 T func_ ;
58 [Complex<f32>] [cgetrf_];
59 [Complex<f64>] [zgetrf_];
60)]
61impl GETRFDriverAPI<T> for DeviceBLAS {
62 unsafe fn driver_getrf(
63 order: FlagOrder,
64 m: usize,
65 n: usize,
66 a: *mut T,
67 lda: usize,
68 ipiv: *mut blas_int,
69 ) -> blas_int {
70 use lapack_ffi::lapack::func_;
71
72 let mut info = 0;
73
74 if order == ColMajor {
75 func_(&(m as _), &(n as _), a as *mut _, &(lda as _), ipiv, &mut info);
77 if info != 0 {
78 return info;
79 }
80 } else {
81 let lda_t = m.max(1);
82 let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
84 Ok(a_t) => a_t,
85 Err(_) => return -1011,
86 };
87 let a_slice = from_raw_parts_mut(a, m * lda);
88 let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
89 let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
90 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
91 func_(&(m as _), &(n as _), a_t.as_mut_ptr() as *mut _, &(lda_t as _), ipiv, &mut info);
93 if info != 0 {
94 return info;
95 }
96 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
98 }
99 return info;
100 }
101}