rstsr_openblas/driver_impl/lapack/svd/
gesvd.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use num::complex::ComplexFloat;
4use num::{Complex, Zero};
5use rstsr_blas_traits::prelude::*;
6use rstsr_common::prelude_dev::*;
7
8use rstsr_native_impl::prelude_dev::*;
9use std::slice::from_raw_parts_mut;
10
11#[duplicate_item(
12    T     func_   ;
13   [f32] [sgesvd_];
14   [f64] [dgesvd_];
15)]
16impl GESVDDriverAPI<T> for DeviceBLAS {
17    unsafe fn driver_gesvd(
18        order: FlagOrder,
19        jobu: char,
20        jobvt: char,
21        m: usize,
22        n: usize,
23        a: *mut T,
24        lda: usize,
25        s: *mut T,
26        u: *mut T,
27        ldu: usize,
28        vt: *mut T,
29        ldvt: usize,
30        superb: *mut T,
31    ) -> blas_int {
32        use lapack_ffi::lapack::func_;
33
34        // Query optimal working array size
35        let mut info = 0;
36        let lwork = -1;
37        let mut work_query = 0.0;
38        func_(
39            &(jobu as _),
40            &(jobvt as _),
41            &(m as _),
42            &(n as _),
43            a,
44            &(lda as _),
45            s,
46            u,
47            &(ldu as _),
48            vt,
49            &(ldvt as _),
50            &mut work_query,
51            &lwork,
52            &mut info,
53        );
54        if info != 0 {
55            return info;
56        }
57        let lwork = work_query as usize;
58
59        // Allocate memory for work array
60        let mut work: Vec<T> = match uninitialized_vec(lwork) {
61            Ok(work) => work,
62            Err(_) => return -1010,
63        };
64
65        if order == ColMajor {
66            // Call LAPACK function
67            func_(
68                &(jobu as _),
69                &(jobvt as _),
70                &(m as _),
71                &(n as _),
72                a,
73                &(lda as _),
74                s,
75                u,
76                &(ldu as _),
77                vt,
78                &(ldvt as _),
79                work.as_mut_ptr(),
80                &(lwork as _),
81                &mut info,
82            );
83            if info != 0 {
84                return info;
85            }
86        } else {
87            let lda_t = m.max(1);
88            let nrows_u = if jobu == 'A' || jobu == 'S' { m } else { 1 };
89            let ncols_u = if jobu == 'A' {
90                m
91            } else if jobu == 'S' {
92                m.min(n)
93            } else {
94                1
95            };
96            let nrows_vt = if jobvt == 'A' {
97                n
98            } else if jobvt == 'S' {
99                m.min(n)
100            } else {
101                1
102            };
103            let ldu_t = nrows_u.max(1);
104            let ldvt_t = nrows_vt.max(1);
105
106            // Transpose input matrices
107            let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
108                Ok(a_t) => a_t,
109                Err(_) => return -1011,
110            };
111            let a_slice = from_raw_parts_mut(a, m * lda);
112            let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
113            let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
114            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
115
116            let mut u_t = if jobu == 'A' || jobu == 'S' {
117                match uninitialized_vec(nrows_u * ncols_u) {
118                    Ok(u_t) => u_t,
119                    Err(_) => return -1011,
120                }
121            } else {
122                Vec::new()
123            };
124
125            let mut vt_t = if jobvt == 'A' || jobvt == 'S' {
126                match uninitialized_vec(nrows_vt * n) {
127                    Ok(vt_t) => vt_t,
128                    Err(_) => return -1011,
129                }
130            } else {
131                Vec::new()
132            };
133
134            // Call LAPACK function
135            func_(
136                &(jobu as _),
137                &(jobvt as _),
138                &(m as _),
139                &(n as _),
140                a_t.as_mut_ptr(),
141                &(lda_t as _),
142                s,
143                if jobu == 'A' || jobu == 'S' { u_t.as_mut_ptr() } else { u },
144                &(ldu_t as _),
145                if jobvt == 'A' || jobvt == 'S' { vt_t.as_mut_ptr() } else { vt },
146                &(ldvt_t as _),
147                work.as_mut_ptr(),
148                &(lwork as _),
149                &mut info,
150            );
151            if info != 0 {
152                return info;
153            }
154
155            // Transpose output matrices
156            orderchange_out_r2c_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
157
158            if jobu == 'A' || jobu == 'S' {
159                let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
160                let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
161                let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
162                orderchange_out_r2c_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
163            }
164
165            if jobvt == 'A' || jobvt == 'S' {
166                let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
167                let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
168                let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
169                orderchange_out_r2c_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
170            }
171        }
172
173        // Backup superb data
174        let min_mn = m.min(n);
175        for i in 0..min_mn - 1 {
176            superb.add(i).write(work[i + 1]);
177        }
178
179        return info;
180    }
181}
182
183#[duplicate_item(
184    T              func_   ;
185   [Complex<f32>] [cgesvd_];
186   [Complex<f64>] [zgesvd_];
187)]
188impl GESVDDriverAPI<T> for DeviceBLAS {
189    unsafe fn driver_gesvd(
190        order: FlagOrder,
191        jobu: char,
192        jobvt: char,
193        m: usize,
194        n: usize,
195        a: *mut T,
196        lda: usize,
197        s: *mut <T as ComplexFloat>::Real,
198        u: *mut T,
199        ldu: usize,
200        vt: *mut T,
201        ldvt: usize,
202        superb: *mut <T as ComplexFloat>::Real,
203    ) -> blas_int {
204        use lapack_ffi::lapack::func_;
205
206        // Allocate rwork
207        let min_mn = m.min(n);
208        let mut rwork: Vec<<T as ComplexFloat>::Real> = match uninitialized_vec(5 * min_mn) {
209            Ok(rwork) => rwork,
210            Err(_) => return -1010,
211        };
212
213        // Query optimal working array size
214        let mut info = 0;
215        let lwork = -1;
216        let mut work_query = <T as Zero>::zero();
217        func_(
218            &(jobu as _),
219            &(jobvt as _),
220            &(m as _),
221            &(n as _),
222            a as *mut _,
223            &(lda as _),
224            s as *mut _,
225            u as *mut _,
226            &(ldu as _),
227            vt as *mut _,
228            &(ldvt as _),
229            &mut work_query as *mut _ as *mut _,
230            &lwork,
231            rwork.as_mut_ptr() as *mut _,
232            &mut info,
233        );
234        if info != 0 {
235            return info;
236        }
237        let lwork = work_query.re() as usize;
238
239        // Allocate memory for work array
240        let mut work: Vec<T> = match uninitialized_vec(lwork) {
241            Ok(work) => work,
242            Err(_) => return -1010,
243        };
244
245        if order == ColMajor {
246            // Call LAPACK function
247            func_(
248                &(jobu as _),
249                &(jobvt as _),
250                &(m as _),
251                &(n as _),
252                a as *mut _,
253                &(lda as _),
254                s as *mut _,
255                u as *mut _,
256                &(ldu as _),
257                vt as *mut _,
258                &(ldvt as _),
259                work.as_mut_ptr() as *mut _,
260                &(lwork as _),
261                rwork.as_mut_ptr() as *mut _,
262                &mut info,
263            );
264            if info != 0 {
265                return info;
266            }
267        } else {
268            let lda_t = m.max(1);
269            let nrows_u = if jobu == 'A' || jobu == 'S' { m } else { 1 };
270            let ncols_u = if jobu == 'A' {
271                m
272            } else if jobu == 'S' {
273                m.min(n)
274            } else {
275                1
276            };
277            let nrows_vt = if jobvt == 'A' {
278                n
279            } else if jobvt == 'S' {
280                m.min(n)
281            } else {
282                1
283            };
284            let ldu_t = nrows_u.max(1);
285            let ldvt_t = nrows_vt.max(1);
286
287            // Transpose input matrices
288            let mut a_t: Vec<T> = match uninitialized_vec(m * n) {
289                Ok(a_t) => a_t,
290                Err(_) => return -1011,
291            };
292            let a_slice = from_raw_parts_mut(a, m * lda);
293            let la = Layout::new_unchecked([m, n], [lda as isize, 1], 0);
294            let la_t = Layout::new_unchecked([m, n], [1, lda_t as isize], 0);
295            orderchange_out_r2c_ix2_cpu_serial(&mut a_t, &la_t, a_slice, &la).unwrap();
296
297            let mut u_t = if jobu == 'A' || jobu == 'S' {
298                match uninitialized_vec(nrows_u * ncols_u) {
299                    Ok(u_t) => u_t,
300                    Err(_) => return -1011,
301                }
302            } else {
303                Vec::new()
304            };
305
306            let mut vt_t = if jobvt == 'A' || jobvt == 'S' {
307                match uninitialized_vec(nrows_vt * n) {
308                    Ok(vt_t) => vt_t,
309                    Err(_) => return -1011,
310                }
311            } else {
312                Vec::new()
313            };
314
315            // Call LAPACK function
316            func_(
317                &(jobu as _),
318                &(jobvt as _),
319                &(m as _),
320                &(n as _),
321                a_t.as_mut_ptr() as *mut _,
322                &(lda_t as _),
323                s as *mut _,
324                if jobu == 'A' || jobu == 'S' { u_t.as_mut_ptr() as *mut _ } else { u as *mut _ },
325                &(ldu_t as _),
326                if jobvt == 'A' || jobvt == 'S' { vt_t.as_mut_ptr() as *mut _ } else { vt as *mut _ },
327                &(ldvt_t as _),
328                work.as_mut_ptr() as *mut _,
329                &(lwork as _),
330                rwork.as_mut_ptr() as *mut _,
331                &mut info,
332            );
333            if info != 0 {
334                return info;
335            }
336
337            // Transpose output matrices
338            orderchange_out_r2c_ix2_cpu_serial(a_slice, &la, &a_t, &la_t).unwrap();
339
340            if jobu == 'A' || jobu == 'S' {
341                let u_slice = from_raw_parts_mut(u, nrows_u * ldu);
342                let lu = Layout::new_unchecked([nrows_u, ncols_u], [ldu as isize, 1], 0);
343                let lu_t = Layout::new_unchecked([nrows_u, ncols_u], [1, ldu_t as isize], 0);
344                orderchange_out_r2c_ix2_cpu_serial(u_slice, &lu, &u_t, &lu_t).unwrap();
345            }
346
347            if jobvt == 'A' || jobvt == 'S' {
348                let vt_slice = from_raw_parts_mut(vt, nrows_vt * ldvt);
349                let lvt = Layout::new_unchecked([nrows_vt, n], [ldvt as isize, 1], 0);
350                let lvt_t = Layout::new_unchecked([nrows_vt, n], [1, ldvt_t as isize], 0);
351                orderchange_out_r2c_ix2_cpu_serial(vt_slice, &lvt, &vt_t, &lvt_t).unwrap();
352            }
353        }
354
355        // Backup superb data
356        #[allow(clippy::needless_range_loop)]
357        for i in 0..min_mn - 1 {
358            superb.add(i).write(rwork[i]);
359        }
360
361        return info;
362    }
363}