// ndarray-linalg 0.13.1
//
// Linear algebra package for rust-ndarray using LAPACK.
// Documentation: docs.rs/ndarray-linalg/0.13.1
use ndarray::*;
use ndarray_linalg::*;
use std::cmp::min;

/// Checks a full SVD (`svd(true, true)`) of `a`, an `n x m` matrix.
///
/// Verifies that:
/// - `u` is `n x n` and `vt` is `m x m`;
/// - there are `min(n, m)` singular values, all non-negative and in descending order;
/// - `u * Σ * vt` reconstructs `a` within a relative L2 tolerance of `1e-7`.
///
/// # Panics
/// Panics (failing the calling `#[test]`) if the decomposition errors or any check fails.
fn test<T: Scalar + Lapack>(a: &Array2<T>) {
    let (n, m) = a.dim();
    println!("a = \n{:?}", a);
    let (u, s, vt): (_, Array1<_>, _) = a.svd(true, true).unwrap();
    let u: Array2<_> = u.expect("svd(true, _) must return u");
    let vt: Array2<_> = vt.expect("svd(_, true) must return vt");
    println!("u = \n{:?}", &u);
    println!("s = \n{:?}", &s);
    println!("vt = \n{:?}", &vt);

    assert_eq!(u.dim(), (n, n));
    assert_eq!(vt.dim(), (m, m));
    assert_eq!(s.len(), min(n, m));

    // LAPACK's *gesvd returns singular values that are non-negative and
    // sorted so that s[i] >= s[i + 1].
    let zero = T::real(0.0);
    assert!(
        s.iter().all(|&x| x >= zero),
        "singular values must be non-negative: {:?}",
        s
    );
    assert!(
        s.iter().zip(s.iter().skip(1)).all(|(hi, lo)| hi >= lo),
        "singular values must be in descending order: {:?}",
        s
    );

    // Embed the singular values on the main diagonal of an n x m matrix Σ.
    let mut sm = Array::<T, _>::zeros((n, m));
    for (i, &sv) in s.iter().enumerate() {
        sm[(i, i)] = T::from(sv).unwrap();
    }
    // Compare against the input directly; no defensive copy is needed.
    assert_close_l2!(&u.dot(&sm).dot(&vt), a, T::real(1e-7));
}

fn test_no_vt<T: Scalar + Lapack>(a: &Array2<T>) {
    let (n, _m) = a.dim();
    println!("a = \n{:?}", a);
    let (u, _s, vt): (_, Array1<_>, _) = a.svd(true, false).unwrap();
    assert!(u.is_some());
    assert!(vt.is_none());
    let u = u.unwrap();
    assert_eq!(u.dim().0, n);
    assert_eq!(u.dim().1, n);
}

fn test_no_u<T: Scalar + Lapack>(a: &Array2<T>) {
    let (_n, m) = a.dim();
    println!("a = \n{:?}", a);
    let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, true).unwrap();
    assert!(u.is_none());
    assert!(vt.is_some());
    let vt = vt.unwrap();
    assert_eq!(vt.dim().0, m);
    assert_eq!(vt.dim().1, m);
}

fn test_diag_only<T: Scalar + Lapack>(a: &Array2<T>) {
    println!("a = \n{:?}", a);
    let (u, _s, vt): (_, Array1<_>, _) = a.svd(false, false).unwrap();
    assert!(u.is_none());
    assert!(vt.is_none());
}

/// Generates a pair of `#[test]` functions that run the check `$test` with
/// element type `$type` on a random `$n x $m` matrix.
///
/// The first test uses the default memory layout. The second, whose name ends
/// in `_t`, builds the matrix with `.f()` (column-major / Fortran layout) so
/// both layouts go through the SVD path.
macro_rules! test_svd_impl {
    ($type:ty, $test:ident, $n:expr, $m:expr) => {
        paste::item! {
            // Default layout input.
            #[test]
            fn [<svd_ $type _ $test _ $n x $m>]() {
                let matrix: Array2<$type> = random(($n, $m));
                $test(&matrix);
            }

            // Column-major (`.f()`) layout input.
            #[test]
            fn [<svd_ $type _ $test _ $n x $m _t>]() {
                let matrix: Array2<$type> = random(($n, $m).f());
                $test(&matrix);
            }
        }
    };
}

test_svd_impl!(f64, test, 3, 3);
test_svd_impl!(f64, test_no_vt, 3, 3);
test_svd_impl!(f64, test_no_u, 3, 3);
test_svd_impl!(f64, test_diag_only, 3, 3);
test_svd_impl!(f64, test, 4, 3);
test_svd_impl!(f64, test_no_vt, 4, 3);
test_svd_impl!(f64, test_no_u, 4, 3);
test_svd_impl!(f64, test_diag_only, 4, 3);
test_svd_impl!(f64, test, 3, 4);
test_svd_impl!(f64, test_no_vt, 3, 4);
test_svd_impl!(f64, test_no_u, 3, 4);
test_svd_impl!(f64, test_diag_only, 3, 4);
test_svd_impl!(c64, test, 3, 3);
test_svd_impl!(c64, test_no_vt, 3, 3);
test_svd_impl!(c64, test_no_u, 3, 3);
test_svd_impl!(c64, test_diag_only, 3, 3);
test_svd_impl!(c64, test, 4, 3);
test_svd_impl!(c64, test_no_vt, 4, 3);
test_svd_impl!(c64, test_no_u, 4, 3);
test_svd_impl!(c64, test_diag_only, 4, 3);
test_svd_impl!(c64, test, 3, 4);
test_svd_impl!(c64, test_no_vt, 3, 4);
test_svd_impl!(c64, test_no_u, 3, 4);
test_svd_impl!(c64, test_diag_only, 3, 4);