1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
//! Singular value decomposition (SVD) of dense `f32`/`f64` matrices via LAPACK's `?gesvd`.

use std::cmp::min;
use lapack::c::*;
use num_traits::Zero;

use error::LapackError;

/// Singular value decomposition of a dense matrix stored in a flat buffer.
///
/// Note the argument order of [`ImplSVD::svd`]: `n` is the number of **columns** and
/// `m` the number of **rows** (implementations hand them to LAPACK as `(m, n)`).
pub trait ImplSVD: Sized {
    /// Computes the full SVD `A = U * S * V^T` of the `m x n` matrix `a`, stored in
    /// `layout` order.
    ///
    /// `a` is consumed because LAPACK overwrites it. On success returns `(u, s, vt)`:
    /// the left singular vectors, the singular values and the transposed right
    /// singular vectors.
    ///
    /// # Errors
    /// Returns a [`LapackError`] built from LAPACK's non-zero `info` code.
    fn svd(layout: Layout,
           n: usize,
           m: usize,
           a: Vec<Self>)
           -> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError>;
}

macro_rules! impl_svd {
    ($scalar:ty, $gesvd:path) => {
        impl ImplSVD for $scalar {
            /// Full SVD of an `m`-row by `n`-column matrix through LAPACK's `?gesvd`
            /// with `jobu = jobvt = 'A'`.
            ///
            /// `a` holds the matrix in `layout` order and is consumed (LAPACK overwrites
            /// it). On success returns `(u, s, vt)`:
            /// * `u`  - `m x m` left singular vectors, leading dimension `m`;
            /// * `s`  - length `n`; LAPACK fills only the first `min(m, n)` entries, the
            ///   remainder stay zero (length kept at `n` for compatibility);
            /// * `vt` - `n x n` transposed right singular vectors, leading dimension `n`.
            ///
            /// # Errors
            /// Any non-zero `info` from LAPACK is converted into a `LapackError`. Per the
            /// LAPACK `?gesvd` documentation, `info < 0` flags an illegal argument and
            /// `info > 0` means the bidiagonal QR iteration did not converge.
            ///
            /// # Panics
            /// If `m * n` overflows `usize`, or if `a` has fewer than `m * n` elements.
            fn svd(layout: Layout,
                   n: usize,
                   m: usize,
                   mut a: Vec<Self>)
                   -> Result<(Vec<Self>, Vec<Self>, Vec<Self>), LapackError> {
                // The C routine trusts the dimensions blindly: a short buffer would be
                // read out of bounds, so validate it before crossing the FFI boundary.
                let required = m.checked_mul(n).expect("svd: m * n overflows usize");
                assert!(a.len() >= required,
                        "svd: buffer has {} elements, but a {}x{} matrix needs {}",
                        a.len(),
                        m,
                        n,
                        required);

                let k = min(n, m);
                // LAPACKE writes `min(m, n) - 1` unconverged superdiagonal entries into
                // `superb`; `saturating_sub` keeps the empty-matrix case from underflowing.
                let mut superb = vec![Self::zero(); k.saturating_sub(1)];
                // Output buffers sized in `usize` to avoid `i32` multiplication overflow.
                let mut u = vec![Self::zero(); m * m];
                let mut vt = vec![Self::zero(); n * n];
                let mut s = vec![Self::zero(); n];

                let n = n as i32;
                let m = m as i32;
                // Leading dimension of `a`: stored row length for row-major (n columns),
                // stored column length for column-major (m rows).
                let lda = match layout {
                    Layout::RowMajor => n,
                    Layout::ColumnMajor => m,
                };
                // `u` is m x m and `vt` is n x n, so both are square with ld = order.
                let ldu = m;
                let ldvt = n;
                let info = $gesvd(layout,
                                  b'A',
                                  b'A',
                                  m,
                                  n,
                                  &mut a,
                                  lda,
                                  &mut s,
                                  &mut u,
                                  ldu,
                                  &mut vt,
                                  ldvt,
                                  &mut superb);
                if info == 0 {
                    Ok((u, s, vt))
                } else {
                    Err(From::from(info))
                }
            }
        }
    };
} // end macro_rules

// Instantiate `ImplSVD` for single and double precision, each bound to the
// matching LAPACKE `?gesvd` routine (s = f32, d = f64).
impl_svd!(f32, sgesvd);
impl_svd!(f64, dgesvd);