rstsr_openblas/driver_impl/lapack/solve/
gesv.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6use rstsr_native_impl::prelude_dev::*;
7use std::slice::from_raw_parts_mut;
8
9#[duplicate_item(
10 T func_ ;
11 [f32] [sgesv_];
12 [f64] [dgesv_];
13)]
14impl GESVDriverAPI<T> for DeviceBLAS {
15 unsafe fn driver_gesv(
16 order: FlagOrder,
17 n: usize,
18 nrhs: usize,
19 a: *mut T,
20 lda: usize,
21 ipiv: *mut blas_int,
22 b: *mut T,
23 ldb: usize,
24 ) -> blas_int {
25 use lapack_ffi::lapack::func_;
26
27 let mut info = 0;
28
29 if order == ColMajor {
30 func_(&(n as _), &(nrhs as _), a, &(lda as _), ipiv, b, &(ldb as _), &mut info);
31 if info != 0 {
32 return info;
33 }
34 } else {
35 let lda_t = n.max(1);
36 let ldb_t = n.max(1);
37 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
39 Ok(a_t) => a_t,
40 Err(_) => return -1011,
41 };
42 let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
43 Ok(b_t) => b_t,
44 Err(_) => return -1011,
45 };
46 let a_slice = from_raw_parts_mut(a, n * lda);
47 let b_slice = from_raw_parts_mut(b, n * ldb);
48 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
49 let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
50 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
51 let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
52 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
53 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
54 func_(
56 &(n as _),
57 &(nrhs as _),
58 a_t.as_mut_ptr(),
59 &(lda_t as _),
60 ipiv,
61 b_t.as_mut_ptr(),
62 &(ldb_t as _),
63 &mut info,
64 );
65 if info != 0 {
66 return info;
67 }
68 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
70 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
71 }
72 return info;
73 }
74}
75
76#[duplicate_item(
77 T func_ ;
78 [Complex<f32>] [cgesv_];
79 [Complex<f64>] [zgesv_];
80)]
81impl GESVDriverAPI<T> for DeviceBLAS {
82 unsafe fn driver_gesv(
83 order: FlagOrder,
84 n: usize,
85 nrhs: usize,
86 a: *mut T,
87 lda: usize,
88 ipiv: *mut blas_int,
89 b: *mut T,
90 ldb: usize,
91 ) -> blas_int {
92 use lapack_ffi::lapack::func_;
93
94 let mut info = 0;
95
96 if order == ColMajor {
97 func_(&(n as _), &(nrhs as _), a as *mut _, &(lda as _), ipiv, b as *mut _, &(ldb as _), &mut info);
98 if info != 0 {
99 return info;
100 }
101 } else {
102 let lda_t = n.max(1);
103 let ldb_t = n.max(1);
104 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
106 Ok(a_t) => a_t,
107 Err(_) => return -1011,
108 };
109 let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
110 Ok(b_t) => b_t,
111 Err(_) => return -1011,
112 };
113 let a_slice = from_raw_parts_mut(a, n * lda);
114 let b_slice = from_raw_parts_mut(b, n * ldb);
115 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
116 let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
117 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
118 let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
119 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
120 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
121 func_(
123 &(n as _),
124 &(nrhs as _),
125 a_t.as_mut_ptr() as *mut _,
126 &(lda_t as _),
127 ipiv,
128 b_t.as_mut_ptr() as *mut _,
129 &(ldb_t as _),
130 &mut info,
131 );
132 if info != 0 {
133 return info;
134 }
135 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
137 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
138 }
139 return info;
140 }
141}