rstsr_openblas/driver_impl/lapack/solve/
getrf.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6use rstsr_native_impl::prelude_dev::*;
7use std::slice::from_raw_parts_mut;
8
9#[duplicate_item(
10    T     func_   ;
11   [f32] [sgetrf_];
12   [f64] [dgetrf_];
13)]
14impl GETRFDriverAPI<T> for DeviceBLAS {
15    unsafe fn driver_getrf(
16        order: FlagOrder,
17        m: usize,
18        n: usize,
19        a: *mut T,
20        lda: usize,
21        ipiv: *mut blas_int,
22    ) -> blas_int {
23        use lapack_ffi::lapack::func_;
24
25        let mut info = 0;
26
27        if order == ColMajor {
28            // Call LAPACK function and adjust info
29            func_(&(m as _), &(n as _), a, &(lda as _), ipiv, &mut info);
30            if info != 0 {
31                return info;
32            }
33        } else {
34            let lda_t = m.max(1);
35            // Transpose input matrices
36            let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
37                Ok(a_t) => a_t,
38                Err(_) => return -1011,
39            };
40            let a_slice = from_raw_parts_mut(a, m * lda);
41            let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
42            let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
43            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
44            // Call LAPACK function and adjust info
45            func_(&(m as _), &(n as _), a_t.as_mut_ptr(), &(lda_t as _), ipiv, &mut info);
46            if info != 0 {
47                return info;
48            }
49            // Transpose output matrices
50            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
51        }
52        return info;
53    }
54}
55
56#[duplicate_item(
57    T              func_   ;
58   [Complex<f32>] [cgetrf_];
59   [Complex<f64>] [zgetrf_];
60)]
61impl GETRFDriverAPI<T> for DeviceBLAS {
62    unsafe fn driver_getrf(
63        order: FlagOrder,
64        m: usize,
65        n: usize,
66        a: *mut T,
67        lda: usize,
68        ipiv: *mut blas_int,
69    ) -> blas_int {
70        use lapack_ffi::lapack::func_;
71
72        let mut info = 0;
73
74        if order == ColMajor {
75            // Call LAPACK function and adjust info
76            func_(&(m as _), &(n as _), a as *mut _, &(lda as _), ipiv, &mut info);
77            if info != 0 {
78                return info;
79            }
80        } else {
81            let lda_t = m.max(1);
82            // Transpose input matrices
83            let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
84                Ok(a_t) => a_t,
85                Err(_) => return -1011,
86            };
87            let a_slice = from_raw_parts_mut(a, m * lda);
88            let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
89            let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
90            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
91            // Call LAPACK function and adjust info
92            func_(&(m as _), &(n as _), a_t.as_mut_ptr() as *mut _, &(lda_t as _), ipiv, &mut info);
93            if info != 0 {
94                return info;
95            }
96            // Transpose output matrices
97            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
98        }
99        return info;
100    }
101}