use crate::matrix::Matrix;
use crate::vector::Vector;
use crate::{to_i32, StrError, SVD_CODE_A};
// FFI declaration of the C wrapper that `mat_svd` uses to reach LAPACK `dgesvd`.
// NOTE(review): the wrapper itself lives outside this file (presumably a bundled
// C source that forwards to LAPACK/LAPACKE) — confirm against the build script.
extern "C" {
/// Singular value decomposition of a general m-by-n matrix (LAPACK `dgesvd`).
///
/// Integer inputs are passed by pointer, Fortran-style. Per the LAPACK
/// `dgesvd` contract, calling with `*lwork == -1` is a workspace query: nothing
/// is factorized and the optimal workspace length is written to `work[0]`.
fn c_dgesvd(
// job selector for U; `mat_svd` passes `SVD_CODE_A` (presumably LAPACK JOBU = 'A', all m columns — confirm)
jobu_code: i32,
// job selector for VT; `mat_svd` passes `SVD_CODE_A` (presumably LAPACK JOBVT = 'A', all n rows — confirm)
jobvt_code: i32,
// number of rows of A
m: *const i32,
// number of columns of A
n: *const i32,
// matrix data (LAPACK expects column-major storage); may be overwritten by the routine
a: *mut f64,
// leading dimension of `a` (`mat_svd` passes m)
lda: *const i32,
// output: singular values, min(m,n) entries
s: *mut f64,
// output: left singular vectors
u: *mut f64,
// leading dimension of `u` (`mat_svd` passes m)
ldu: *const i32,
// output: transposed right singular vectors
vt: *mut f64,
// leading dimension of `vt` (`mat_svd` passes n)
ldvt: *const i32,
// workspace buffer; after a query call, work[0] holds the optimal length
work: *mut f64,
// length of `work`, or -1 to request a workspace query
lwork: *const i32,
// output status: 0 = success, < 0 = argument #(-info) illegal, > 0 = no convergence
info: *mut i32,
);
}
/// Computes the singular value decomposition (SVD) of a general m-by-n matrix.
///
/// On success the factors satisfy `a = u · diag(s) · vt`. Only the first
/// `min(m,n)` columns of `u` and rows of `vt` take part in that product.
///
/// The computation is delegated to LAPACK `dgesvd` through the `c_dgesvd`
/// wrapper, using the standard two-call pattern: a workspace-size query
/// (`lwork = -1`) followed by the actual factorization.
///
/// # Output
///
/// * `s` -- vector of length `min(m,n)` receiving the singular values
///   (LAPACK documents them as sorted in descending order).
/// * `u` -- m-by-m matrix receiving the left singular vectors.
/// * `vt` -- n-by-n matrix receiving the transposed right singular vectors.
///
/// # Input
///
/// * `a` -- the m-by-n matrix to decompose. It is handed to LAPACK as mutable
///   storage, so treat its contents as clobbered on return. Keep a copy if the
///   original values are still needed.
///
/// # Errors
///
/// * `s`, `u` or `vt` do not have the dimensions listed above;
/// * LAPACK reports an illegal argument (`info < 0`), either during the
///   workspace query or during the factorization;
/// * the algorithm did not converge (`info > 0`).
///
/// LAPACK failures are additionally reported on standard output via `println!`.
pub fn mat_svd(
    s: &mut Vector,
    u: &mut Matrix,
    vt: &mut Matrix,
    a: &mut Matrix,
) -> Result<(), StrError> {
    let (m, n) = a.dims();
    let min_mn = usize::min(m, n);
    if s.dim() != min_mn {
        return Err("[s] must be a min(m,n) vector");
    }
    if u.nrow() != m || u.ncol() != m {
        return Err("[u] must be an m-by-m square matrix");
    }
    if vt.nrow() != n || vt.ncol() != n {
        return Err("[vt] must be an n-by-n square matrix");
    }
    let m_i32 = to_i32(m);
    let n_i32 = to_i32(n);
    let lda = m_i32;
    let ldu = m_i32;
    let ldvt = n_i32;
    let mut info = 0;

    // Workspace query: with lwork = -1, dgesvd factorizes nothing. It either stores
    // the optimal workspace length in work[0] or flags an illegal argument in info.
    let query_lwork: i32 = -1;
    let mut query_work = vec![0.0];
    // SAFETY: the shapes of `s`, `u` and `vt` were validated against `a` above.
    // Assuming Matrix/Vector keep their entries contiguous (column-major for
    // matrices, as LAPACK expects), every array is as large as dgesvd requires
    // for the given m, n and leading dimensions. `query_work` holds the single
    // element a query writes, and all scalar arguments are locals that outlive
    // the call.
    unsafe {
        c_dgesvd(
            SVD_CODE_A,
            SVD_CODE_A,
            &m_i32,
            &n_i32,
            a.as_mut_data().as_mut_ptr(),
            &lda,
            s.as_mut_data().as_mut_ptr(),
            u.as_mut_data().as_mut_ptr(),
            &ldu,
            vt.as_mut_data().as_mut_ptr(),
            &ldvt,
            query_work.as_mut_ptr(),
            &query_lwork,
            &mut info,
        );
    }

    // Only factorize if the query succeeded. On failure work[0] carries no valid
    // size, so allocating from it would hand LAPACK a zero-length buffer and
    // repeat the same failing call.
    if info == 0 {
        // LAPACK requires lwork >= 1. The `as` cast saturates (NaN -> 0), so clamp.
        let lwork = i32::max(1, query_work[0] as i32);
        let mut work = vec![0.0; lwork as usize];
        // SAFETY: same invariants as the query above. In addition, `work` holds
        // exactly `lwork` elements, which is the length passed to LAPACK.
        unsafe {
            c_dgesvd(
                SVD_CODE_A,
                SVD_CODE_A,
                &m_i32,
                &n_i32,
                a.as_mut_data().as_mut_ptr(),
                &lda,
                s.as_mut_data().as_mut_ptr(),
                u.as_mut_data().as_mut_ptr(),
                &ldu,
                vt.as_mut_data().as_mut_ptr(),
                &ldvt,
                work.as_mut_ptr(),
                &lwork,
                &mut info,
            );
        }
    }

    if info < 0 {
        println!("LAPACK ERROR (dgesvd): Argument #{} had an illegal value", -info);
        return Err("LAPACK ERROR (dgesvd): An argument had an illegal value");
    } else if info > 0 {
        println!("LAPACK ERROR (dgesvd): {} is the number of super-diagonals of an intermediate bi-diagonal form B which did not converge to zero", info);
        return Err("LAPACK ERROR (dgesvd): Algorithm did not converge");
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{mat_svd, Matrix, Vector};
    use crate::{mat_approx_eq, vec_approx_eq};

    /// Multiplies the SVD factors back together. Returns the m-by-n matrix
    /// `u · diag(s) · vt`, summing over the first `min(m,n)` singular triplets.
    fn recompose(u: &Matrix, s: &Vector, vt: &Matrix, m: usize, n: usize) -> Matrix {
        let rank = usize::min(m, n);
        let mut product = Matrix::new(m, n);
        for row in 0..m {
            for col in 0..n {
                for k in 0..rank {
                    product.add(row, col, u.get(row, k) * s[k] * vt.get(k, col));
                }
            }
        }
        product
    }

    #[test]
    fn mat_svd_fails_on_wrong_dims() {
        // `a` is 3x2, so the valid shapes are: s -> 2, u -> 3x3, vt -> 2x2
        let mut a = Matrix::new(3, 2);
        let mut s = Vector::new(2);
        let mut u = Matrix::new(3, 3);
        let mut vt = Matrix::new(2, 2);

        // s has the wrong length
        let mut s_3 = Vector::new(3);
        assert_eq!(
            mat_svd(&mut s_3, &mut u, &mut vt, &mut a),
            Err("[s] must be a min(m,n) vector")
        );

        // u must be m-by-m: a too-small square and a non-square matrix are both rejected
        let mut u_2x2 = Matrix::new(2, 2);
        let mut u_3x2 = Matrix::new(3, 2);
        assert_eq!(
            mat_svd(&mut s, &mut u_2x2, &mut vt, &mut a),
            Err("[u] must be an m-by-m square matrix")
        );
        assert_eq!(
            mat_svd(&mut s, &mut u_3x2, &mut vt, &mut a),
            Err("[u] must be an m-by-m square matrix")
        );

        // vt must be n-by-n: a too-large square and a non-square matrix are both rejected
        let mut vt_3x3 = Matrix::new(3, 3);
        let mut vt_2x3 = Matrix::new(2, 3);
        assert_eq!(
            mat_svd(&mut s, &mut u, &mut vt_3x3, &mut a),
            Err("[vt] must be an n-by-n square matrix")
        );
        assert_eq!(
            mat_svd(&mut s, &mut u, &mut vt_2x3, &mut a),
            Err("[vt] must be an n-by-n square matrix")
        );
    }

    #[test]
    fn mat_svd_4x3_works() {
        let s33 = f64::sqrt(3.0) / 3.0;
        #[rustfmt::skip]
        let data = [
            [-s33, -s33, 1.0],
            [ s33, -s33, 1.0],
            [-s33,  s33, 1.0],
            [ s33,  s33, 1.0],
        ];
        // `a` is clobbered by the decomposition, so keep a pristine copy
        let mut a = Matrix::from(&data);
        let a_original = Matrix::from(&data);
        let (m, n) = a.dims();
        let mut s = Vector::new(usize::min(m, n));
        let mut u = Matrix::new(m, m);
        let mut vt = Matrix::new(n, n);
        mat_svd(&mut s, &mut u, &mut vt, &mut a).unwrap();

        let two_over_sqrt3 = 2.0 / f64::sqrt(3.0);
        vec_approx_eq(&s, &[2.0, two_over_sqrt3, two_over_sqrt3], 1e-14);
        mat_approx_eq(&recompose(&u, &s, &vt, m, n), &a_original, 1e-14);
    }

    #[test]
    fn mat_svd_2x4_works() {
        #[rustfmt::skip]
        let data = [
            [1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0],
        ];
        // `a` is clobbered by the decomposition, so keep a pristine copy
        let mut a = Matrix::from(&data);
        let a_original = Matrix::from(&data);
        let (m, n) = a.dims();
        let mut s = Vector::new(usize::min(m, n));
        let mut u = Matrix::new(m, m);
        let mut vt = Matrix::new(n, n);
        mat_svd(&mut s, &mut u, &mut vt, &mut a).unwrap();

        let sqrt2 = std::f64::consts::SQRT_2;
        vec_approx_eq(&s, &[sqrt2, sqrt2], 1e-14);
        mat_approx_eq(&recompose(&u, &s, &vt, m, n), &a_original, 1e-14);
    }

    #[test]
    fn mat_svd_1x4_works() {
        #[rustfmt::skip]
        let data = [
            [0.25, 0.25, 0.25, 0.25],
        ];
        // `a` is clobbered by the decomposition, so keep a pristine copy
        let mut a = Matrix::from(&data);
        let a_original = Matrix::from(&data);
        let (m, n) = a.dims();
        let mut s = Vector::new(usize::min(m, n));
        let mut u = Matrix::new(m, m);
        let mut vt = Matrix::new(n, n);
        mat_svd(&mut s, &mut u, &mut vt, &mut a).unwrap();

        vec_approx_eq(&s, &[0.5], 1e-14);
        mat_approx_eq(&recompose(&u, &s, &vt, m, n), &a_original, 1e-14);
    }
}