rstsr_openblas/driver_impl/lapack/svd/gesdd.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::Complex;
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7
8use rstsr_native_impl::prelude_dev::*;
9use std::slice::from_raw_parts_mut;
10
11#[duplicate_item(
12    T     func_   ;
13   [f32] [sgesdd_];
14   [f64] [dgesdd_];
15)]
16impl GESDDDriverAPI<T> for DeviceBLAS {
17    unsafe fn driver_gesdd(
18        order: FlagOrder,
19        jobz: char,
20        m: usize,
21        n: usize,
22        a: *mut T,
23        lda: usize,
24        s: *mut T,
25        u: *mut T,
26        ldu: usize,
27        vt: *mut T,
28        ldvt: usize,
29    ) -> blas_int {
30        use lapack_ffi::lapack::func_;
31
32        // Allocate memory for temporary array(s)
33        let liwork = 8 * m.min(n);
34        let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
35            Ok(iwork) => iwork,
36            Err(_) => return -1010,
37        };
38
39        // Query optimal working array(s) size
40        let mut info = 0;
41        let lwork = -1;
42        let mut work_query = 0.0;
43        func_(
44            &(jobz as _),
45            &(m as _),
46            &(n as _),
47            a,
48            &(m.max(n) as _),
49            s,
50            u,
51            &(m.max(n) as _),
52            vt,
53            &(m.max(n) as _),
54            &mut work_query,
55            &lwork,
56            iwork.as_mut_ptr(),
57            &mut info,
58        );
59        if info != 0 {
60            return info;
61        }
62        let lwork = work_query as usize;
63
64        // Allocate memory for temporary array(s)
65        let mut work: Vec<T> = match uninitialized_vec(lwork) {
66            Ok(work) => work,
67            Err(_) => return -1010,
68        };
69
70        if order == ColMajor {
71            // Call LAPACK function and adjust info
72            func_(
73                &(jobz as _),
74                &(m as _),
75                &(n as _),
76                a,
77                &(lda as _),
78                s,
79                u,
80                &(ldu as _),
81                vt,
82                &(ldvt as _),
83                work.as_mut_ptr(),
84                &(lwork as _),
85                iwork.as_mut_ptr(),
86                &mut info,
87            );
88            if info != 0 {
89                return info;
90            }
91        } else {
92            let lda_t = m.max(1);
93            let nrows_u = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { m } else { 1 };
94            let ncols_u = if jobz == 'A' || (jobz == 'O' && m < n) {
95                m
96            } else if jobz == 'S' {
97                m.min(n)
98            } else {
99                1
100            };
101            let nrows_vt = if jobz == 'A' || (jobz == 'O' && m >= n) {
102                n
103            } else if jobz == 'S' {
104                m.min(n)
105            } else {
106                1
107            };
108            let ldu_t = nrows_u.max(1);
109            let ldvt_t = nrows_vt.max(1);
110
111            // Transpose input matrices
112            let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
113                Ok(a_t) => a_t,
114                Err(_) => return -1011,
115            };
116            let a_slice = from_raw_parts_mut(a, m * lda);
117            let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
118            let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
119            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
120
121            let mut u_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) {
122                match uninitialized_vec(nrows_u * ncols_u) {
123                    Ok(u_t) => Some(u_t),
124                    Err(_) => return -1011,
125                }
126            } else {
127                None
128            };
129
130            let mut vt_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m >= n) {
131                match uninitialized_vec(nrows_vt * n) {
132                    Ok(vt_t) => Some(vt_t),
133                    Err(_) => return -1011,
134                }
135            } else {
136                None
137            };
138
139            // Call LAPACK function and adjust info
140            func_(
141                &(jobz as _),
142                &(m as _),
143                &(n as _),
144                a_t.as_mut_ptr(),
145                &(lda_t as _),
146                s,
147                u_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()),
148                &(ldu_t as _),
149                vt_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()),
150                &(ldvt_t as _),
151                work.as_mut_ptr(),
152                &(lwork as _),
153                iwork.as_mut_ptr(),
154                &mut info,
155            );
156            if info != 0 {
157                return info;
158            }
159
160            // Transpose output matrices
161            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
162
163            if let Some(u_t) = u_t {
164                let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
165                let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
166                let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
167                orderchange_out_c2r_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
168            }
169
170            if let Some(vt_t) = vt_t {
171                let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
172                let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
173                let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
174                orderchange_out_c2r_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
175            }
176        }
177        return info;
178    }
179}
180
181#[duplicate_item(
182    T              func_   ;
183   [Complex<f32>] [cgesdd_];
184   [Complex<f64>] [zgesdd_];
185)]
186impl GESDDDriverAPI<T> for DeviceBLAS {
187    unsafe fn driver_gesdd(
188        order: FlagOrder,
189        jobz: char,
190        m: usize,
191        n: usize,
192        a: *mut T,
193        lda: usize,
194        s: *mut <T as ComplexFloat>::Real,
195        u: *mut T,
196        ldu: usize,
197        vt: *mut T,
198        ldvt: usize,
199    ) -> blas_int {
200        use lapack_ffi::lapack::func_;
201
202        // Allocate memory for temporary array(s)
203        let liwork = 8 * m.min(n);
204        let mut iwork: Vec<blas_int> = match uninitialized_vec(liwork) {
205            Ok(iwork) => iwork,
206            Err(_) => return -1010,
207        };
208
209        // Query optimal working array(s) size
210        let mut info = 0;
211        let lwork = -1;
212        let lrwork =
213            if jobz == 'N' { 7 * m.min(n) } else { m.min(n) * (5 * m.min(n) + 7).max(2 * m.max(n) + 2 * m.min(n) + 1) };
214        let mut work_query = Complex::new(0.0, 0.0);
215        let mut rwork_query = 0.0;
216        func_(
217            &(jobz as _),
218            &(m as _),
219            &(n as _),
220            a as *mut _,
221            &(m.max(n) as _),
222            s as *mut _,
223            u as *mut _,
224            &(m.max(n) as _),
225            vt as *mut _,
226            &(m.max(n) as _),
227            &mut work_query as *mut _ as *mut _,
228            &lwork,
229            &mut rwork_query as *mut _ as *mut _,
230            iwork.as_mut_ptr(),
231            &mut info,
232        );
233        if info != 0 {
234            return info;
235        }
236        let lwork = work_query.re as usize;
237
238        // Allocate memory for temporary array(s)
239        let mut work: Vec<T> = match uninitialized_vec(lwork) {
240            Ok(work) => work,
241            Err(_) => return -1010,
242        };
243        let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(lrwork) {
244            Ok(rwork) => rwork,
245            Err(_) => return -1010,
246        };
247
248        if order == ColMajor {
249            // Call LAPACK function and adjust info
250            func_(
251                &(jobz as _),
252                &(m as _),
253                &(n as _),
254                a as *mut _,
255                &(lda as _),
256                s as *mut _,
257                u as *mut _,
258                &(ldu as _),
259                vt as *mut _,
260                &(ldvt as _),
261                work.as_mut_ptr() as *mut _,
262                &(lwork as _),
263                rwork.as_mut_ptr() as *mut _,
264                iwork.as_mut_ptr(),
265                &mut info,
266            );
267            if info != 0 {
268                return info;
269            }
270        } else {
271            let lda_t = m.max(1);
272            let nrows_u = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) { m } else { 1 };
273            let ncols_u = if jobz == 'A' || (jobz == 'O' && m < n) {
274                m
275            } else if jobz == 'S' {
276                m.min(n)
277            } else {
278                1
279            };
280            let nrows_vt = if jobz == 'A' || (jobz == 'O' && m >= n) {
281                n
282            } else if jobz == 'S' {
283                m.min(n)
284            } else {
285                1
286            };
287            let ldu_t = nrows_u.max(1);
288            let ldvt_t = nrows_vt.max(1);
289
290            // Transpose input matrices
291            let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
292                Ok(a_t) => a_t,
293                Err(_) => return -1011,
294            };
295            let a_slice = from_raw_parts_mut(a, m * lda);
296            let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
297            let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
298            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
299
300            let mut u_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m < n) {
301                match uninitialized_vec(nrows_u * ncols_u) {
302                    Ok(u_t) => Some(u_t),
303                    Err(_) => return -1011,
304                }
305            } else {
306                None
307            };
308
309            let mut vt_t = if jobz == 'A' || jobz == 'S' || (jobz == 'O' && m >= n) {
310                match uninitialized_vec(nrows_vt * n) {
311                    Ok(vt_t) => Some(vt_t),
312                    Err(_) => return -1011,
313                }
314            } else {
315                None
316            };
317
318            // Call LAPACK function and adjust info
319            func_(
320                &(jobz as _),
321                &(m as _),
322                &(n as _),
323                a_t.as_mut_ptr() as *mut _,
324                &(lda_t as _),
325                s as *mut _,
326                u_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()) as *mut _,
327                &(ldu_t as _),
328                vt_t.as_mut().map_or(std::ptr::null_mut(), |v| v.as_mut_ptr()) as *mut _,
329                &(ldvt_t as _),
330                work.as_mut_ptr() as *mut _,
331                &(lwork as _),
332                rwork.as_mut_ptr() as *mut _,
333                iwork.as_mut_ptr(),
334                &mut info,
335            );
336            if info != 0 {
337                return info;
338            }
339
340            // Transpose output matrices
341            orderchange_out_c2r_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
342
343            if let Some(u_t) = u_t {
344                let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
345                let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
346                let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
347                orderchange_out_c2r_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
348            }
349
350            if let Some(vt_t) = vt_t {
351                let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
352                let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
353                let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
354                orderchange_out_c2r_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
355            }
356        }
357        return info;
358    }
359}