1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7
8use rstsr_native_impl::prelude_dev::*;
9use std::slice::from_raw_parts_mut;
10
11#[duplicate_item(
12 T func_ ;
13 [f32] [sgesdd_];
14 [f64] [dgesdd_];
15)]
16impl GESDDDriverAPI<T> for DeviceBLAS {
17 unsafe fn driver_gesdd(
18 order: FlagOrder,
19 jobz: char,
20 m: usize,
21 n: usize,
22 a: *mut T,
23 lda: usize,
24 s: *mut T,
25 u: *mut T,
26 ldu: usize,
27 vt: *mut T,
28 ldvt: usize,
29 ) -> blas_int {
30 use lapack_ffi::lapack::func_;
31
32 let liwork = 8 * m.min(n);
34 let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
35 Ok(iwork) => iwork,
36 Err(_) => return -1010,
37 };
38
39 let mut info = 0;
41 let lwork = -1;
42 let mut work_query = 0.0;
43 func_(
44 &(jobz as _),
45 &(m as _),
46 &(n as _),
47 a,
48 &(m.max(n) as _),
49 s,
50 u,
51 &(m.max(n) as _),
52 vt,
53 &(m.max(n) as _),
54 &mut work_query,
55 &lwork,
56 iwork.as_mut_ptr(),
57 &mut info,
58 );
59 if info != 0 {
60 return info;
61 }
62 let lwork = work_query as usize;
63
64 let mut work: Vec<T> = match uninitialized_vec(lwork) {
66 Ok(work) => work,
67 Err(_) => return -1010,
68 };
69
70 if order == ColMajor {
71 func_(
73 &(jobz as _),
74 &(m as _),
75 &(n as _),
76 a,
77 &(lda as _),
78 s,
79 u,
80 &(ldu as _),
81 vt,
82 &(ldvt as _),
83 work.as_mut_ptr(),
84 &(lwork as _),
85 iwork.as_mut_ptr(),
86 &mut info,
87 );
88 if info != 0 {
89 return info;
90 }
91 } else {
92 let lda_t = m.max(1);
93 let nrows_u = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { m } else { 1 };
94 let ncols_u = if jobz == 'A' || (jobz == 'O' && m < n) {
95 m
96 } else if jobz == 'S' {
97 m.min(n)
98 } else {
99 1
100 };
101 let nrows_vt = if jobz == 'A' || (jobz == 'O' && m >= n) {
102 n
103 } else if jobz == 'S' {
104 m.min(n)
105 } else {
106 1
107 };
108 let ldu_t = nrows_u.max(1);
109 let ldvt_t = nrows_vt.max(1);
110
111 let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
113 Ok(a_t) => a_t,
114 Err(_) => return -1011,
115 };
116 let a_slice = from_raw_parts_mut(a, m * lda);
117 let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
118 let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
119 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
120
121 let mut u_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) {
122 match uninitialized_vec(nrows_u * ncols_u) {
123 Ok(u_t) => Some(u_t),
124 Err(_) => return -1011,
125 }
126 } else {
127 None
128 };
129
130 let mut vt_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m >= n) {
131 match uninitialized_vec(nrows_vt * n) {
132 Ok(vt_t) => Some(vt_t),
133 Err(_) => return -1011,
134 }
135 } else {
136 None
137 };
138
139 func_(
141 &(jobz as _),
142 &(m as _),
143 &(n as _),
144 a_t.as_mut_ptr(),
145 &(lda_t as _),
146 s,
147 u_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()),
148 &(ldu_t as _),
149 vt_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()),
150 &(ldvt_t as _),
151 work.as_mut_ptr(),
152 &(lwork as _),
153 iwork.as_mut_ptr(),
154 &mut info,
155 );
156 if info != 0 {
157 return info;
158 }
159
160 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
162
163 if let Some(u_t) = u_t {
164 let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
165 let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
166 let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
167 orderchange_out_c2r_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
168 }
169
170 if let Some(vt_t) = vt_t {
171 let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
172 let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
173 let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
174 orderchange_out_c2r_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
175 }
176 }
177 return info;
178 }
179}
180
181#[duplicate_item(
182 T func_ ;
183 [Complex<f32>] [cgesdd_];
184 [Complex<f64>] [zgesdd_];
185)]
186impl GESDDDriverAPI<T> for DeviceBLAS {
187 unsafe fn driver_gesdd(
188 order: FlagOrder,
189 jobz: char,
190 m: usize,
191 n: usize,
192 a: *mut T,
193 lda: usize,
194 s: *mut <T as ComplexFloat>::Real,
195 u: *mut T,
196 ldu: usize,
197 vt: *mut T,
198 ldvt: usize,
199 ) -> blas_int {
200 use lapack_ffi::lapack::func_;
201
202 let liwork = 8 * m.min(n);
204 let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
205 Ok(iwork) => iwork,
206 Err(_) => return -1010,
207 };
208
209 let mut info = 0;
211 let lwork = -1;
212 let lrwork =
213 if jobz == 'N' { 7 * m.min(n) } else { m.min(n) * (5 * m.min(n) + 7).max(2 * m.max(n) + 2 * m.min(n) + 1) };
214 let mut work_query = Complex::new(0.0, 0.0);
215 let mut rwork_query = 0.0;
216 func_(
217 &(jobz as _),
218 &(m as _),
219 &(n as _),
220 a as *mut _,
221 &(m.max(n) as _),
222 s as *mut _,
223 u as *mut _,
224 &(m.max(n) as _),
225 vt as *mut _,
226 &(m.max(n) as _),
227 &mut work_query as *mut _ as *mut _,
228 &lwork,
229 &mut rwork_query as *mut _ as *mut _,
230 iwork.as_mut_ptr(),
231 &mut info,
232 );
233 if info != 0 {
234 return info;
235 }
236 let lwork = work_query.re as usize;
237
238 let mut work: Vec<T> = match uninitialized_vec(lwork) {
240 Ok(work) => work,
241 Err(_) => return -1010,
242 };
243 let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(lrwork) {
244 Ok(rwork) => rwork,
245 Err(_) => return -1010,
246 };
247
248 if order == ColMajor {
249 func_(
251 &(jobz as _),
252 &(m as _),
253 &(n as _),
254 a as *mut _,
255 &(lda as _),
256 s as *mut _,
257 u as *mut _,
258 &(ldu as _),
259 vt as *mut _,
260 &(ldvt as _),
261 work.as_mut_ptr() as *mut _,
262 &(lwork as _),
263 rwork.as_mut_ptr() as *mut _,
264 iwork.as_mut_ptr(),
265 &mut info,
266 );
267 if info != 0 {
268 return info;
269 }
270 } else {
271 let lda_t = m.max(1);
272 let nrows_u = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { m } else { 1 };
273 let ncols_u = if jobz == 'A' || (jobz == 'O' && m < n) {
274 m
275 } else if jobz == 'S' {
276 m.min(n)
277 } else {
278 1
279 };
280 let nrows_vt = if jobz == 'A' || (jobz == 'O' && m >= n) {
281 n
282 } else if jobz == 'S' {
283 m.min(n)
284 } else {
285 1
286 };
287 let ldu_t = nrows_u.max(1);
288 let ldvt_t = nrows_vt.max(1);
289
290 let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
292 Ok(a_t) => a_t,
293 Err(_) => return -1011,
294 };
295 let a_slice = from_raw_parts_mut(a, m * lda);
296 let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
297 let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
298 orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
299
300 let mut u_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) {
301 match uninitialized_vec(nrows_u * ncols_u) {
302 Ok(u_t) => Some(u_t),
303 Err(_) => return -1011,
304 }
305 } else {
306 None
307 };
308
309 let mut vt_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m >= n) {
310 match uninitialized_vec(nrows_vt * n) {
311 Ok(vt_t) => Some(vt_t),
312 Err(_) => return -1011,
313 }
314 } else {
315 None
316 };
317
318 func_(
320 &(jobz as _),
321 &(m as _),
322 &(n as _),
323 a_t.as_mut_ptr() as *mut _,
324 &(lda_t as _),
325 s as *mut _,
326 u_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()) as *mut _,
327 &(ldu_t as _),
328 vt_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()) as *mut _,
329 &(ldvt_t as _),
330 work.as_mut_ptr() as *mut _,
331 &(lwork as _),
332 rwork.as_mut_ptr() as *mut _,
333 iwork.as_mut_ptr(),
334 &mut info,
335 );
336 if info != 0 {
337 return info;
338 }
339
340 orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
342
343 if let Some(u_t) = u_t {
344 let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
345 let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
346 let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
347 orderchange_out_c2r_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
348 }
349
350 if let Some(vt_t) = vt_t {
351 let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
352 let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
353 let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
354 orderchange_out_c2r_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
355 }
356 }
357 return info;
358 }
359}