1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::{Complex, Zero};
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7
8use rstsr_native_impl::prelude_dev::*;
9use std::slice::from_raw_parts_mut;
10
11#[duplicate_item(
12 T func_ ;
13 [f32] [sgesvd_];
14 [f64] [dgesvd_];
15)]
16impl GESVDDriverAPI<T> for DeviceBLAS {
17 unsafe fn driver_gesvd(
18 order: FlagOrder,
19 jobu: char,
20 jobvt: char,
21 m: usize,
22 n: usize,
23 a: *mut T,
24 lda: usize,
25 s: *mut T,
26 u: *mut T,
27 ldu: usize,
28 vt: *mut T,
29 ldvt: usize,
30 superb: *mut T,
31 ) -> blas_int {
32 use lapack_ffi::lapack::func_;
33
34 let mut info = 0;
36 let lwork = -1;
37 let mut work_query = 0.0;
38 func_(
39 &(jobu as _),
40 &(jobvt as _),
41 &(m as _),
42 &(n as _),
43 a,
44 &(lda as _),
45 s,
46 u,
47 &(ldu as _),
48 vt,
49 &(ldvt as _),
50 &mut work_query,
51 &lwork,
52 &mut info,
53 );
54 if info != 0 {
55 return info;
56 }
57 let lwork = work_query as usize;
58
59 let mut work: Vec<T> = match uninitialized_vec(lwork) {
61 Ok(work) => work,
62 Err(_) => return -1010,
63 };
64
65 if order == ColMajor {
66 func_(
68 &(jobu as _),
69 &(jobvt as _),
70 &(m as _),
71 &(n as _),
72 a,
73 &(lda as _),
74 s,
75 u,
76 &(ldu as _),
77 vt,
78 &(ldvt as _),
79 work.as_mut_ptr(),
80 &(lwork as _),
81 &mut info,
82 );
83 if info != 0 {
84 return info;
85 }
86 } else {
87 let lda_t = m.max(1);
88 let nrows_u = if jobu == 'A' || jobu == 'S' { m } else { 1 };
89 let ncols_u = if jobu == 'A' {
90 m
91 } else if jobu == 'S' {
92 m.min(n)
93 } else {
94 1
95 };
96 let nrows_vt = if jobvt == 'A' {
97 n
98 } else if jobvt == 'S' {
99 m.min(n)
100 } else {
101 1
102 };
103 let ldu_t = nrows_u.max(1);
104 let ldvt_t = nrows_vt.max(1);
105
106 let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
108 Ok(a_t) => a_t,
109 Err(_) => return -1011,
110 };
111 let a_slice = from_raw_parts_mut(a, m * lda);
112 let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
113 let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
114 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
115
116 let mut u_t = if jobu == 'A' || jobu == 'S' {
117 match uninitialized_vec(nrows_u * ncols_u) {
118 Ok(u_t) => u_t,
119 Err(_) => return -1011,
120 }
121 } else {
122 Vec::new()
123 };
124
125 let mut vt_t = if jobvt == 'A' || jobvt == 'S' {
126 match uninitialized_vec(nrows_vt * n) {
127 Ok(vt_t) => vt_t,
128 Err(_) => return -1011,
129 }
130 } else {
131 Vec::new()
132 };
133
134 func_(
136 &(jobu as _),
137 &(jobvt as _),
138 &(m as _),
139 &(n as _),
140 a_t.as_mut_ptr(),
141 &(lda_t as _),
142 s,
143 if jobu == 'A' || jobu == 'S' { u_t.as_mut_ptr() } else { u },
144 &(ldu_t as _),
145 if jobvt == 'A' || jobvt == 'S' { vt_t.as_mut_ptr() } else { vt },
146 &(ldvt_t as _),
147 work.as_mut_ptr(),
148 &(lwork as _),
149 &mut info,
150 );
151 if info != 0 {
152 return info;
153 }
154
155 orderchange_out_r2c_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
157
158 if jobu == 'A' || jobu == 'S' {
159 let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
160 let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
161 let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
162 orderchange_out_r2c_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
163 }
164
165 if jobvt == 'A' || jobvt == 'S' {
166 let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
167 let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
168 let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
169 orderchange_out_r2c_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
170 }
171 }
172
173 let min_mn = m.min(n);
175 for i in 0..min_mn - 1 {
176 superb.add(i).write(work[i + 1]);
177 }
178
179 return info;
180 }
181}
182
183#[duplicate_item(
184 T func_ ;
185 [Complex<f32>] [cgesvd_];
186 [Complex<f64>] [zgesvd_];
187)]
188impl GESVDDriverAPI<T> for DeviceBLAS {
189 unsafe fn driver_gesvd(
190 order: FlagOrder,
191 jobu: char,
192 jobvt: char,
193 m: usize,
194 n: usize,
195 a: *mut T,
196 lda: usize,
197 s: *mut <T as ComplexFloat>::Real,
198 u: *mut T,
199 ldu: usize,
200 vt: *mut T,
201 ldvt: usize,
202 superb: *mut <T as ComplexFloat>::Real,
203 ) -> blas_int {
204 use lapack_ffi::lapack::func_;
205
206 let min_mn = m.min(n);
208 let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(5 * min_mn) {
209 Ok(rwork) => rwork,
210 Err(_) => return -1010,
211 };
212
213 let mut info = 0;
215 let lwork = -1;
216 let mut work_query = <T as Zero>::zero();
217 func_(
218 &(jobu as _),
219 &(jobvt as _),
220 &(m as _),
221 &(n as _),
222 a as *mut _,
223 &(lda as _),
224 s as *mut _,
225 u as *mut _,
226 &(ldu as _),
227 vt as *mut _,
228 &(ldvt as _),
229 &mut work_query as *mut _ as *mut _,
230 &lwork,
231 rwork.as_mut_ptr() as *mut _,
232 &mut info,
233 );
234 if info != 0 {
235 return info;
236 }
237 let lwork = work_query.re() as usize;
238
239 let mut work: Vec<T> = match uninitialized_vec(lwork) {
241 Ok(work) => work,
242 Err(_) => return -1010,
243 };
244
245 if order == ColMajor {
246 func_(
248 &(jobu as _),
249 &(jobvt as _),
250 &(m as _),
251 &(n as _),
252 a as *mut _,
253 &(lda as _),
254 s as *mut _,
255 u as *mut _,
256 &(ldu as _),
257 vt as *mut _,
258 &(ldvt as _),
259 work.as_mut_ptr() as *mut _,
260 &(lwork as _),
261 rwork.as_mut_ptr() as *mut _,
262 &mut info,
263 );
264 if info != 0 {
265 return info;
266 }
267 } else {
268 let lda_t = m.max(1);
269 let nrows_u = if jobu == 'A' || jobu == 'S' { m } else { 1 };
270 let ncols_u = if jobu == 'A' {
271 m
272 } else if jobu == 'S' {
273 m.min(n)
274 } else {
275 1
276 };
277 let nrows_vt = if jobvt == 'A' {
278 n
279 } else if jobvt == 'S' {
280 m.min(n)
281 } else {
282 1
283 };
284 let ldu_t = nrows_u.max(1);
285 let ldvt_t = nrows_vt.max(1);
286
287 let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
289 Ok(a_t) => a_t,
290 Err(_) => return -1011,
291 };
292 let a_slice = from_raw_parts_mut(a, m * lda);
293 let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
294 let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
295 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
296
297 let mut u_t = if jobu == 'A' || jobu == 'S' {
298 match uninitialized_vec(nrows_u * ncols_u) {
299 Ok(u_t) => u_t,
300 Err(_) => return -1011,
301 }
302 } else {
303 Vec::new()
304 };
305
306 let mut vt_t = if jobvt == 'A' || jobvt == 'S' {
307 match uninitialized_vec(nrows_vt * n) {
308 Ok(vt_t) => vt_t,
309 Err(_) => return -1011,
310 }
311 } else {
312 Vec::new()
313 };
314
315 func_(
317 &(jobu as _),
318 &(jobvt as _),
319 &(m as _),
320 &(n as _),
321 a_t.as_mut_ptr() as *mut _,
322 &(lda_t as _),
323 s as *mut _,
324 if jobu == 'A' || jobu == 'S' { u_t.as_mut_ptr() as *mut _ } else { u as *mut _ },
325 &(ldu_t as _),
326 if jobvt == 'A' || jobvt == 'S' { vt_t.as_mut_ptr() as *mut _ } else { vt as *mut _ },
327 &(ldvt_t as _),
328 work.as_mut_ptr() as *mut _,
329 &(lwork as _),
330 rwork.as_mut_ptr() as *mut _,
331 &mut info,
332 );
333 if info != 0 {
334 return info;
335 }
336
337 orderchange_out_r2c_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
339
340 if jobu == 'A' || jobu == 'S' {
341 let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
342 let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
343 let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
344 orderchange_out_r2c_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
345 }
346
347 if jobvt == 'A' || jobvt == 'S' {
348 let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
349 let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
350 let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
351 orderchange_out_r2c_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
352 }
353 }
354
355 #[allow(clippy::needless_range_loop)]
357 for i in 0..min_mn - 1 {
358 superb.add(i).write(rwork[i]);
359 }
360
361 return info;
362 }
363}