rstsr_openblas/driver_impl/lapack/solve/
sysv.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::Complex;
4use rstsr_blas_traits::prelude::*;
5use rstsr_common::prelude_dev::*;
6
7use rstsr_native_impl::prelude_dev::*;
8use std::slice::from_raw_parts_mut;
9
10#[duplicate_item(
11 T func_ ;
12 [f32] [ssysv_];
13 [f64] [dsysv_];
14)]
15impl<const HERMI: bool> SYSVDriverAPI<T, HERMI> for DeviceBLAS {
16 unsafe fn driver_sysv(
17 order: FlagOrder,
18 uplo: FlagUpLo,
19 n: usize,
20 nrhs: usize,
21 a: *mut T,
22 lda: usize,
23 ipiv: *mut blas_int,
24 b: *mut T,
25 ldb: usize,
26 ) -> blas_int {
27 use lapack_ffi::lapack::func_;
28
29 let mut info = 0;
31 let lwork = -1;
32 let mut work_query = 0.0;
33 func_(
34 &uplo.into(),
35 &(n as _),
36 &(nrhs as _),
37 a,
38 &(n as _),
39 ipiv,
40 b,
41 &(n as _),
42 &mut work_query,
43 &lwork,
44 &mut info,
45 );
46 if info != 0 {
47 return info;
48 }
49 let lwork = work_query as usize;
50
51 let mut work: Vec<T> = match uninitialized_vec(lwork) {
53 Ok(work) => work,
54 Err(_) => return -1010,
55 };
56
57 if order == ColMajor {
58 func_(
60 &uplo.into(),
61 &(n as _),
62 &(nrhs as _),
63 a,
64 &(lda as _),
65 ipiv,
66 b,
67 &(ldb as _),
68 work.as_mut_ptr(),
69 &(lwork as _),
70 &mut info,
71 );
72 if info != 0 {
73 return info;
74 }
75 } else {
76 let lda_t = n.max(1);
77 let ldb_t = n.max(1);
78 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
80 Ok(a_t) => a_t,
81 Err(_) => return -1011,
82 };
83 let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
84 Ok(b_t) => b_t,
85 Err(_) => return -1011,
86 };
87 let a_slice = from_raw_parts_mut(a, n * lda);
88 let b_slice = from_raw_parts_mut(b, n * ldb);
89 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
90 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
91 let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
92 let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
93 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
94 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
95 func_(
97 &uplo.into(),
98 &(n as _),
99 &(nrhs as _),
100 a_t.as_mut_ptr(),
101 &(lda_t as _),
102 ipiv,
103 b_t.as_mut_ptr(),
104 &(ldb_t as _),
105 work.as_mut_ptr(),
106 &(lwork as _),
107 &mut info,
108 );
109 if info != 0 {
110 return info;
111 }
112 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
114 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
115 }
116 return info;
117 }
118}
119
120#[duplicate_item(
121 T func_ HERMI ;
122 [Complex<f32>] [csysv_] [false];
123 [Complex<f32>] [chesv_] [true ];
124 [Complex<f64>] [zsysv_] [false];
125 [Complex<f64>] [zhesv_] [true ];
126)]
127impl SYSVDriverAPI<T, HERMI> for DeviceBLAS {
128 unsafe fn driver_sysv(
129 order: FlagOrder,
130 uplo: FlagUpLo,
131 n: usize,
132 nrhs: usize,
133 a: *mut T,
134 lda: usize,
135 ipiv: *mut blas_int,
136 b: *mut T,
137 ldb: usize,
138 ) -> blas_int {
139 use lapack_ffi::lapack::func_;
140
141 let mut info = 0;
143 let lwork = -1;
144 let mut work_query = 0.0;
145 func_(
146 &uplo.into(),
147 &(n as _),
148 &(nrhs as _),
149 a as *mut _,
150 &(n as _),
151 ipiv,
152 b as *mut _,
153 &(n as _),
154 &mut work_query as *mut _ as *mut _,
155 &lwork,
156 &mut info,
157 );
158 if info != 0 {
159 return info;
160 }
161 let lwork = work_query as usize;
162
163 let mut work: Vec<T> = match uninitialized_vec(lwork) {
165 Ok(work) => work,
166 Err(_) => return -1010,
167 };
168
169 if order == ColMajor {
170 func_(
172 &uplo.into(),
173 &(n as _),
174 &(nrhs as _),
175 a as *mut _,
176 &(lda as _),
177 ipiv,
178 b as *mut _,
179 &(ldb as _),
180 work.as_mut_ptr() as *mut _,
181 &(lwork as _),
182 &mut info,
183 );
184 if info != 0 {
185 return info;
186 }
187 } else {
188 let lda_t = n.max(1);
189 let ldb_t = n.max(1);
190 let mut a_t: Vec<T> = match uninitialized_vec(n * n) {
192 Ok(a_t) => a_t,
193 Err(_) => return -1011,
194 };
195 let mut b_t: Vec<T> = match uninitialized_vec(n * nrhs) {
196 Ok(b_t) => b_t,
197 Err(_) => return -1011,
198 };
199 let a_slice = from_raw_parts_mut(a, n * lda);
200 let b_slice = from_raw_parts_mut(b, n * ldb);
201 let la = Layout::new_unchecked([n, n], [lda as isize, 1], 0);
202 let la_t = Layout::new_unchecked([n, n], [1, lda_t as isize], 0);
203 let lb = Layout::new_unchecked([n, nrhs], [ldb as isize, 1], 0);
204 let lb_t = Layout::new_unchecked([n, nrhs], [1, ldb_t as isize], 0);
205 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
206 orderchange_out_r2c_ix2_cpu_serial(&mut b_t, &lb_t, b_slice, &lb).unwrap();
207 func_(
209 &uplo.into(),
210 &(n as _),
211 &(nrhs as _),
212 a_t.as_mut_ptr() as *mut _,
213 &(lda_t as _),
214 ipiv,
215 b_t.as_mut_ptr() as *mut _,
216 &(ldb_t as _),
217 work.as_mut_ptr() as *mut _,
218 &(lwork as _),
219 &mut info,
220 );
221 if info != 0 {
222 return info;
223 }
224 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
226 orderchange_out_c2r_ix2_cpu_serial(b_slice, &lb, &b_t, &lb_t).unwrap();
227 }
228 return info;
229 }
230}