rlst/dense/linalg/lapack/interface/
gesvd.rs

//! Implementation of `?gesvd` - singular value decomposition (SVD) of a general matrix.
2
3use lapack::{cgesvd, dgesvd, sgesvd, zgesvd};
4
5use crate::base_types::{LapackError, c32, c64};
6use crate::{base_types::LapackResult, traits::rlst_num::RlstScalar};
7
8use crate::dense::linalg::lapack::interface::lapack_return;
9
10use num::{Zero, complex::ComplexFloat};
11
/// JobU specifies the computation of the left singular vectors `U`.
///
/// The discriminant of each variant is the ASCII character that is passed as the `JOBU`
/// argument of LAPACK `?gesvd`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum JobU {
    /// Return all `m` columns of `U` (`u` holds an `ldu` x `m` matrix).
    A = b'A',
    /// Return the first `min(m, n)` columns of `U` (`u` holds an `ldu` x `min(m, n)` matrix).
    S = b'S',
    /// Do not compute `U`; the `u` array is not referenced.
    N = b'N',
}
23
/// JobVt specifies the computation of the right singular vectors, returned as the rows of
/// `V^T` (or `V^H` for complex types).
///
/// The discriminant of each variant is the ASCII character that is passed as the `JOBVT`
/// argument of LAPACK `?gesvd`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum JobVt {
    /// Return all `n` rows (`vt` holds an `ldvt` x `n` matrix with `ldvt >= max(1, n)`).
    A = b'A',
    /// Return the first `min(m, n)` rows (`vt` holds an `ldvt` x `n` matrix with
    /// `ldvt >= max(1, min(m, n))`).
    S = b'S',
    /// Do not compute `Vt`; the `vt` array is not referenced.
    N = b'N',
}
35
36/// ?gesvd - SVD factorization
37pub trait Gesvd: RlstScalar {
38    /// Perform a singular value decomposition (SVD) of a matrix `a` with dimensions `m` x `n`.
39    /// If either `jobu` or `jobvt` is `JobU::N` or `JobVt::N`, the corresponding singular vectors
40    /// are not computed, and the array u or correspondingly vt is not referenced and can be
41    /// `None`.
42    #[allow(clippy::too_many_arguments)]
43    fn gesvd(
44        jobu: JobU,
45        jobvt: JobVt,
46        m: usize,
47        n: usize,
48        a: &mut [Self],
49        lda: usize,
50        s: &mut [Self::Real],
51        u: Option<&mut [Self]>,
52        ldu: usize,
53        vt: Option<&mut [Self]>,
54        ldvt: usize,
55    ) -> LapackResult<()>;
56}
57
// Implements `Gesvd` for a real scalar type by forwarding to the matching LAPACK routine
// (`sgesvd` / `dgesvd`). The routine is called twice: once as a workspace-size query
// (`lwork = -1`) and once to compute the decomposition with the reported workspace.
macro_rules! implement_gesvd {
    ($scalar:ty, $gesvd:expr) => {
        impl Gesvd for $scalar {
            fn gesvd(
                jobu: JobU,
                jobvt: JobVt,
                m: usize,
                n: usize,
                a: &mut [Self],
                lda: usize,
                s: &mut [Self::Real],
                u: Option<&mut [Self]>,
                ldu: usize,
                vt: Option<&mut [Self]>,
                ldvt: usize,
            ) -> LapackResult<()> {
                // Convert a dimension to LAPACK's 32-bit integer, panicking instead of
                // silently truncating a value that is then handed to Fortran code.
                fn to_lapack_int(value: usize, name: &str) -> i32 {
                    i32::try_from(value).unwrap_or_else(|_| {
                        panic!("Require `{}` {} <= `i32::MAX` {}.", name, value, i32::MAX)
                    })
                }

                assert_eq!(
                    a.len(),
                    lda * n,
                    "Require `a.len()` {} == `lda * n` {}.",
                    a.len(),
                    lda * n
                );

                assert!(
                    lda >= std::cmp::max(1, m),
                    "Require `lda` {} >= `max(1, m)` {}.",
                    lda,
                    std::cmp::max(1, m)
                );

                let k = std::cmp::min(m, n);

                assert_eq!(
                    s.len(),
                    k,
                    "Require `s.len()` {} == `min(m, n)` {}.",
                    s.len(),
                    k,
                );

                // Backing storage for `u` / `vt` when they are not computed: LAPACK does not
                // reference these arrays for `JobU::N` / `JobVt::N`, so empty slices suffice.
                let mut u_temp = Vec::<$scalar>::new();
                let mut vt_temp = Vec::<$scalar>::new();

                let u = match jobu {
                    JobU::A => {
                        let u = u.expect("JobU::A requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * m,
                            "Require `u.len()` {} == `ldu * m` {}.",
                            u.len(),
                            ldu * m
                        );

                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::S => {
                        let u = u.expect("JobU::S requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * k,
                            "Require `u.len()` {} == `ldu * min(m, n)` {}.",
                            u.len(),
                            ldu * k
                        );
                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::N => {
                        assert!(ldu >= 1, "Require `ldu` {} >= 1.", ldu);
                        u_temp.as_mut_slice()
                    }
                };

                let vt = match jobvt {
                    JobVt::A => {
                        let vt = vt.expect("JobVt::A requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );

                        assert!(
                            ldvt >= std::cmp::max(1, n),
                            "Require `ldvt` {} >= `max(1, n)` {}.",
                            ldvt,
                            std::cmp::max(1, n)
                        );
                        vt
                    }
                    JobVt::S => {
                        let vt = vt.expect("JobVt::S requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );
                        assert!(
                            ldvt >= std::cmp::max(1, k),
                            "Require `ldvt` {} >= `max(1, min(m, n))` {}.",
                            ldvt,
                            std::cmp::max(1, k)
                        );
                        vt
                    }
                    JobVt::N => {
                        assert!(ldvt >= 1, "Require `ldvt` {} >= 1.", ldvt);
                        vt_temp.as_mut_slice()
                    }
                };

                let m_int = to_lapack_int(m, "m");
                let n_int = to_lapack_int(n, "n");
                let lda_int = to_lapack_int(lda, "lda");
                let ldu_int = to_lapack_int(ldu, "ldu");
                let ldvt_int = to_lapack_int(ldvt, "ldvt");

                let mut info = 0;

                // Workspace query: with `lwork = -1` LAPACK only computes the optimal
                // workspace size and stores it in `work[0]`.
                let mut work = vec![<$scalar>::zero(); 1];

                // SAFETY: the lengths and leading dimensions of `a`, `s`, `u` and `vt` were
                // validated against `m` and `n` above, all integer arguments fit in `i32`,
                // and `work` has the single entry a workspace query writes to.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_int,
                        n_int,
                        a,
                        lda_int,
                        s,
                        u,
                        ldu_int,
                        vt,
                        ldvt_int,
                        &mut work,
                        -1,
                        &mut info,
                    );
                }

                if info != 0 {
                    return Err(LapackError::LapackInfoCode(info));
                }

                // The optimal workspace size is reported as a floating point value.
                let lwork = work[0].re() as i32;

                let mut work = vec![<$scalar>::zero(); lwork as usize];

                // SAFETY: same validated buffers as the query above; `work` now has exactly
                // the `lwork` entries announced to LAPACK.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_int,
                        n_int,
                        a,
                        lda_int,
                        s,
                        u,
                        ldu_int,
                        vt,
                        ldvt_int,
                        &mut work,
                        lwork,
                        &mut info,
                    );
                }

                lapack_return(info, ())
            }
        }
    };
}
241
// Implements `Gesvd` for a complex scalar type by forwarding to the matching LAPACK routine
// (`cgesvd` / `zgesvd`). In addition to the real variant, the complex routines need a real
// `rwork` buffer of length `5 * min(m, n)`. The routine is called twice: once as a
// workspace-size query (`lwork = -1`) and once to compute the decomposition.
macro_rules! implement_gesvd_complex {
    ($scalar:ty, $gesvd:expr) => {
        impl Gesvd for $scalar {
            fn gesvd(
                jobu: JobU,
                jobvt: JobVt,
                m: usize,
                n: usize,
                a: &mut [Self],
                lda: usize,
                s: &mut [Self::Real],
                u: Option<&mut [Self]>,
                ldu: usize,
                vt: Option<&mut [Self]>,
                ldvt: usize,
            ) -> LapackResult<()> {
                // Convert a dimension to LAPACK's 32-bit integer, panicking instead of
                // silently truncating a value that is then handed to Fortran code.
                fn to_lapack_int(value: usize, name: &str) -> i32 {
                    i32::try_from(value).unwrap_or_else(|_| {
                        panic!("Require `{}` {} <= `i32::MAX` {}.", name, value, i32::MAX)
                    })
                }

                assert_eq!(
                    a.len(),
                    lda * n,
                    "Require `a.len()` {} == `lda * n` {}.",
                    a.len(),
                    lda * n
                );

                assert!(
                    lda >= std::cmp::max(1, m),
                    "Require `lda` {} >= `max(1, m)` {}.",
                    lda,
                    std::cmp::max(1, m)
                );

                let k = std::cmp::min(m, n);

                assert_eq!(
                    s.len(),
                    k,
                    "Require `s.len()` {} == `min(m, n)` {}.",
                    s.len(),
                    k,
                );

                // Real workspace required by the complex LAPACK routines.
                let mut rwork = vec![<<$scalar as RlstScalar>::Real as Zero>::zero(); 5 * k];

                // Backing storage for `u` / `vt` when they are not computed: LAPACK does not
                // reference these arrays for `JobU::N` / `JobVt::N`, so empty slices suffice.
                let mut u_temp = Vec::<$scalar>::new();
                let mut vt_temp = Vec::<$scalar>::new();

                let u = match jobu {
                    JobU::A => {
                        let u = u.expect("JobU::A requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * m,
                            "Require `u.len()` {} == `ldu * m` {}.",
                            u.len(),
                            ldu * m
                        );

                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::S => {
                        let u = u.expect("JobU::S requires u to be Some");
                        assert_eq!(
                            u.len(),
                            ldu * k,
                            "Require `u.len()` {} == `ldu * min(m, n)` {}.",
                            u.len(),
                            ldu * k
                        );
                        assert!(
                            ldu >= std::cmp::max(1, m),
                            "Require `ldu` {} >= `max(1, m)` {}.",
                            ldu,
                            std::cmp::max(1, m)
                        );
                        u
                    }
                    JobU::N => {
                        assert!(ldu >= 1, "Require `ldu` {} >= 1.", ldu);
                        u_temp.as_mut_slice()
                    }
                };

                let vt = match jobvt {
                    JobVt::A => {
                        let vt = vt.expect("JobVt::A requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );

                        assert!(
                            ldvt >= std::cmp::max(1, n),
                            "Require `ldvt` {} >= `max(1, n)` {}.",
                            ldvt,
                            std::cmp::max(1, n)
                        );
                        vt
                    }
                    JobVt::S => {
                        let vt = vt.expect("JobVt::S requires vt to be Some");
                        assert_eq!(
                            vt.len(),
                            ldvt * n,
                            "Require `vt.len()` {} == `ldvt * n` {}.",
                            vt.len(),
                            ldvt * n
                        );
                        assert!(
                            ldvt >= std::cmp::max(1, k),
                            "Require `ldvt` {} >= `max(1, min(m, n))` {}.",
                            ldvt,
                            std::cmp::max(1, k)
                        );
                        vt
                    }
                    JobVt::N => {
                        assert!(ldvt >= 1, "Require `ldvt` {} >= 1.", ldvt);
                        vt_temp.as_mut_slice()
                    }
                };

                let m_int = to_lapack_int(m, "m");
                let n_int = to_lapack_int(n, "n");
                let lda_int = to_lapack_int(lda, "lda");
                let ldu_int = to_lapack_int(ldu, "ldu");
                let ldvt_int = to_lapack_int(ldvt, "ldvt");

                let mut info = 0;

                // Workspace query: with `lwork = -1` LAPACK only computes the optimal
                // workspace size and stores it in `work[0]`.
                let mut work = vec![<$scalar>::zero(); 1];

                // SAFETY: the lengths and leading dimensions of `a`, `s`, `u` and `vt` were
                // validated against `m` and `n` above, `rwork` has the `5 * min(m, n)`
                // entries LAPACK requires, all integer arguments fit in `i32`, and `work`
                // has the single entry a workspace query writes to.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_int,
                        n_int,
                        a,
                        lda_int,
                        s,
                        u,
                        ldu_int,
                        vt,
                        ldvt_int,
                        &mut work,
                        -1,
                        &mut rwork,
                        &mut info,
                    );
                }

                if info != 0 {
                    return Err(LapackError::LapackInfoCode(info));
                }

                // The optimal workspace size is reported in the real part of `work[0]`.
                let lwork = work[0].re() as i32;

                let mut work = vec![<$scalar>::zero(); lwork as usize];

                // SAFETY: same validated buffers as the query above; `work` now has exactly
                // the `lwork` entries announced to LAPACK.
                unsafe {
                    $gesvd(
                        jobu as u8,
                        jobvt as u8,
                        m_int,
                        n_int,
                        a,
                        lda_int,
                        s,
                        u,
                        ldu_int,
                        vt,
                        ldvt_int,
                        &mut work,
                        lwork,
                        &mut rwork,
                        &mut info,
                    );
                }

                lapack_return(info, ())
            }
        }
    };
}
429
430implement_gesvd!(f32, sgesvd);
431implement_gesvd!(f64, dgesvd);
432implement_gesvd_complex!(c32, cgesvd);
433implement_gesvd_complex!(c64, zgesvd);