rstsr_openblas/driver_impl/lapack/solve/gesv.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6use rstsr_native_impl::prelude_dev::*;
7use std::slice::from_raw_parts_mut;
8
9#[duplicate_item(
10    T     func_   ;
11   [f32] [sgesv_];
12   [f64] [dgesv_];
13)]
14impl GESVDriverAPI<T> for DeviceBLAS {
15    unsafe fn driver_gesv(
16        order: FlagOrder,
17        n: usize,
18        nrhs: usize,
19        a: *mut T,
20        lda: usize,
21        ipiv: *mut blas_int,
22        b: *mut T,
23        ldb: usize,
24    ) -> blas_int {
25        use lapack_ffi::lapack::func_;
26
27        let mut info = 0;
28
29        if order == ColMajor {
30            func_(&(n as _), &(nrhs as _), a, &(lda as _), ipiv, b, &(ldb as _), &mut info);
31            if info != 0 {
32                return info;
33            }
34        } else {
35            let lda_t = n.max(1);
36            let ldb_t = n.max(1);
37            // Transpose input matrices
38            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
39                Ok(a_t) => a_t,
40                Err(_) => return -1011,
41            };
42            let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
43                Ok(b_t) => b_t,
44                Err(_) => return -1011,
45            };
46            let a_slice = from_raw_parts_mut(a, n * lda);
47            let b_slice = from_raw_parts_mut(b, n * ldb);
48            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
49            let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
50            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
51            let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
52            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
53            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
54            // Call LAPACK function
55            func_(
56                &(n as _),
57                &(nrhs as _),
58                a_t.as_mut_ptr(),
59                &(lda_t as _),
60                ipiv,
61                b_t.as_mut_ptr(),
62                &(ldb_t as _),
63                &mut info,
64            );
65            if info != 0 {
66                return info;
67            }
68            // Transpose output matrices
69            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
70            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
71        }
72        return info;
73    }
74}
75
76#[duplicate_item(
77    T              func_   ;
78   [Complex<f32>] [cgesv_];
79   [Complex<f64>] [zgesv_];
80)]
81impl GESVDriverAPI<T> for DeviceBLAS {
82    unsafe fn driver_gesv(
83        order: FlagOrder,
84        n: usize,
85        nrhs: usize,
86        a: *mut T,
87        lda: usize,
88        ipiv: *mut blas_int,
89        b: *mut T,
90        ldb: usize,
91    ) -> blas_int {
92        use lapack_ffi::lapack::func_;
93
94        let mut info = 0;
95
96        if order == ColMajor {
97            func_(&(n as _), &(nrhs as _), a as *mut _, &(lda as _), ipiv, b as *mut _, &(ldb as _), &mut info);
98            if info != 0 {
99                return info;
100            }
101        } else {
102            let lda_t = n.max(1);
103            let ldb_t = n.max(1);
104            // Transpose input matrices
105            let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
106                Ok(a_t) => a_t,
107                Err(_) => return -1011,
108            };
109            let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
110                Ok(b_t) => b_t,
111                Err(_) => return -1011,
112            };
113            let a_slice = from_raw_parts_mut(a, n * lda);
114            let b_slice = from_raw_parts_mut(b, n * ldb);
115            let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
116            let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
117            let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
118            let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
119            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
120            orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
121            // Call LAPACK function
122            func_(
123                &(n as _),
124                &(nrhs as _),
125                a_t.as_mut_ptr() as *mut _,
126                &(lda_t as _),
127                ipiv,
128                b_t.as_mut_ptr() as *mut _,
129                &(ldb_t as _),
130                &mut info,
131            );
132            if info != 0 {
133                return info;
134            }
135            // Transpose output matrices
136            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
137            orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
138        }
139        return info;
140    }
141}