use crate::matrix::DenseMatrix;
use crate::Scalar;
use faer::linalg::solvers::{Svd, ThinSvd};
use faer::{ComplexField, Conjugate, Entity, Mat, RealField, SimpleEntity};
use numra_core::LinalgError;
/// Full singular value decomposition `A = U Σ Vᴴ` of a dense matrix.
///
/// Wraps faer's [`Svd`] factorization together with the shape of the matrix
/// it was computed from. For the economy-size variant see
/// [`ThinSvdDecomposition`].
pub struct SvdDecomposition<S: Scalar + Entity> {
    /// Underlying faer factorization.
    svd: Svd<S>,
    /// Row count of the factorized matrix.
    m: usize,
    /// Column count of the factorized matrix.
    n: usize,
}

impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> SvdDecomposition<S>
where
    S::Real: RealField,
{
    /// Computes the full SVD of `matrix`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`; the `Result` signature lets
    /// callers treat it uniformly with other fallible decompositions.
    pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
        let (m, n) = (matrix.rows(), matrix.cols());
        Ok(Self {
            svd: Svd::new(matrix.as_faer()),
            m,
            n,
        })
    }

    /// Returns an owned copy of the left singular vectors `U`.
    pub fn u(&self) -> DenseMatrix<S> {
        let left = self.svd.u().to_owned();
        DenseMatrix::from_faer(left)
    }

    /// Returns an owned copy of the right singular vectors `V` (not `Vᴴ`).
    pub fn v(&self) -> DenseMatrix<S> {
        let right = self.svd.v().to_owned();
        DenseMatrix::from_faer(right)
    }

    /// Returns the singular values in the order faer stores them
    /// (non-increasing, which [`Self::cond`] relies on).
    pub fn singular_values(&self) -> Vec<S> {
        let sigma = self.svd.s_diagonal();
        let mut values = Vec::with_capacity(sigma.nrows());
        for idx in 0..sigma.nrows() {
            values.push(sigma.read(idx));
        }
        values
    }

    /// Returns the Moore–Penrose pseudoinverse `A⁺` as computed by faer.
    pub fn pseudoinverse(&self) -> DenseMatrix<S> {
        let inverse = self.svd.pseudoinverse();
        DenseMatrix::from_faer(inverse)
    }

    /// Counts the singular values whose magnitude strictly exceeds `tol`.
    ///
    /// `tol` is an absolute threshold; no scaling by `max(m, n)` or `σ_max`
    /// is applied here.
    pub fn rank(&self, tol: S) -> usize {
        let sigma = self.svd.s_diagonal();
        let mut count = 0;
        for idx in 0..sigma.nrows() {
            if sigma.read(idx).abs() > tol {
                count += 1;
            }
        }
        count
    }

    /// 2-norm condition number `σ_max / σ_min`.
    ///
    /// Returns `S::ZERO` when there are no singular values (an empty matrix)
    /// and `S::INFINITY` when the smallest singular value is exactly zero.
    pub fn cond(&self) -> S {
        let sigma = self.svd.s_diagonal();
        match sigma.nrows() {
            0 => S::ZERO,
            len => {
                // Values are ordered largest-first, so the extremes sit at the ends.
                let largest = sigma.read(0).abs();
                let smallest = sigma.read(len - 1).abs();
                if smallest == S::ZERO {
                    S::INFINITY
                } else {
                    largest / smallest
                }
            }
        }
    }

    /// Number of rows of the factorized matrix.
    pub fn nrows(&self) -> usize {
        self.m
    }

    /// Number of columns of the factorized matrix.
    pub fn ncols(&self) -> usize {
        self.n
    }
}
/// Thin (economy-size) singular value decomposition of a dense matrix.
///
/// Wraps faer's [`ThinSvd`]; the `U` and `V` factors keep only
/// `min(m, n)` columns. It is cheaper than [`SvdDecomposition`] for tall or
/// wide inputs.
pub struct ThinSvdDecomposition<S: Scalar + Entity> {
    /// Underlying faer factorization.
    svd: ThinSvd<S>,
    /// Row count of the factorized matrix.
    m: usize,
    /// Column count of the factorized matrix.
    n: usize,
}

impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> ThinSvdDecomposition<S>
where
    S::Real: RealField,
{
    /// Computes the thin SVD of `matrix`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
        let rows = matrix.rows();
        let cols = matrix.cols();
        let factorization = ThinSvd::new(matrix.as_faer());
        Ok(Self {
            svd: factorization,
            m: rows,
            n: cols,
        })
    }

    /// Returns an owned copy of the thin left singular vectors `U`.
    pub fn u(&self) -> DenseMatrix<S> {
        let owned = self.svd.u().to_owned();
        DenseMatrix::from_faer(owned)
    }

    /// Returns an owned copy of the thin right singular vectors `V` (not `Vᴴ`).
    pub fn v(&self) -> DenseMatrix<S> {
        let owned = self.svd.v().to_owned();
        DenseMatrix::from_faer(owned)
    }

    /// Returns the singular values, largest first.
    pub fn singular_values(&self) -> Vec<S> {
        let diag = self.svd.s_diagonal();
        let mut out = Vec::with_capacity(diag.nrows());
        for row in 0..diag.nrows() {
            out.push(diag.read(row));
        }
        out
    }

    /// Returns the Moore–Penrose pseudoinverse `A⁺` as computed by faer.
    pub fn pseudoinverse(&self) -> DenseMatrix<S> {
        let pinv = self.svd.pseudoinverse();
        DenseMatrix::from_faer(pinv)
    }

    /// Counts singular values with magnitude strictly greater than the
    /// absolute tolerance `tol`.
    pub fn rank(&self, tol: S) -> usize {
        let diag = self.svd.s_diagonal();
        let mut nonzero = 0;
        for row in 0..diag.nrows() {
            if diag.read(row).abs() > tol {
                nonzero += 1;
            }
        }
        nonzero
    }

    /// 2-norm condition number `σ_max / σ_min`.
    ///
    /// Returns `S::ZERO` for an empty spectrum and `S::INFINITY` when the
    /// smallest singular value is exactly zero.
    pub fn cond(&self) -> S {
        let diag = self.svd.s_diagonal();
        match diag.nrows() {
            0 => S::ZERO,
            count => {
                // Largest-first ordering puts σ_max at index 0 and σ_min last.
                let top = diag.read(0).abs();
                let bottom = diag.read(count - 1).abs();
                if bottom == S::ZERO {
                    S::INFINITY
                } else {
                    top / bottom
                }
            }
        }
    }

    /// Number of rows of the factorized matrix.
    pub fn nrows(&self) -> usize {
        self.m
    }

    /// Number of columns of the factorized matrix.
    pub fn ncols(&self) -> usize {
        self.n
    }
}
impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> DenseMatrix<S>
where
S::Real: RealField,
{
pub fn svd(&self) -> Result<SvdDecomposition<S>, LinalgError> {
SvdDecomposition::new(self)
}
pub fn thin_svd(&self) -> Result<ThinSvdDecomposition<S>, LinalgError> {
ThinSvdDecomposition::new(self)
}
pub fn singular_values(&self) -> Vec<S> {
let svd = ThinSvd::new(self.as_faer());
let s = svd.s_diagonal();
(0..s.nrows()).map(|i| s.read(i)).collect()
}
pub fn pinv(&self) -> Result<DenseMatrix<S>, LinalgError> {
let svd = ThinSvdDecomposition::new(self)?;
Ok(svd.pseudoinverse())
}
pub fn cond(&self) -> S {
let svd = ThinSvd::new(self.as_faer());
let s = svd.s_diagonal();
let k = s.nrows();
if k == 0 {
return S::ZERO;
}
let s_max = s.read(0).abs();
let s_min = s.read(k - 1).abs();
if s_min == S::ZERO {
return S::INFINITY;
}
s_max / s_min
}
pub fn rank(&self, tol: S) -> usize {
let svd = ThinSvd::new(self.as_faer());
let s = svd.s_diagonal();
(0..s.nrows()).filter(|&i| s.read(i).abs() > tol).count()
}
pub fn lstsq(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
if b.len() != self.rows() {
return Err(LinalgError::DimensionMismatch {
expected: (self.rows(), 1),
actual: (b.len(), 1),
});
}
let pinv = self.pinv()?;
let mut b_mat = Mat::zeros(self.rows(), 1);
for (i, &val) in b.iter().enumerate() {
b_mat.write(i, 0, val);
}
let pinv_ref = pinv.as_faer();
let result = pinv_ref * b_mat.as_ref();
let x: Vec<S> = (0..self.cols()).map(|i| result.read(i, 0)).collect();
Ok(x)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Builds an `f64` matrix from row slices; every row must have the same length.
    fn from_rows(rows: &[&[f64]]) -> DenseMatrix<f64> {
        let ncols = rows.first().map_or(0, |row| row.len());
        let mut out: DenseMatrix<f64> = DenseMatrix::zeros(rows.len(), ncols);
        for (i, row) in rows.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                out.set(i, j, value);
            }
        }
        out
    }

    /// The 3×2 matrix `[[1, 2], [3, 4], [5, 6]]` shared by several tests.
    fn tall_3x2() -> DenseMatrix<f64> {
        from_rows(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]])
    }

    #[test]
    fn test_svd_diagonal() {
        let a = from_rows(&[&[3.0, 0.0, 0.0], &[0.0, 1.0, 0.0], &[0.0, 0.0, 2.0]]);
        let s = SvdDecomposition::new(&a).unwrap().singular_values();
        // Singular values come back sorted largest-first.
        for (idx, &want) in [3.0, 2.0, 1.0].iter().enumerate() {
            assert!((s[idx] - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_svd_rectangular_reconstruction() {
        let a = tall_3x2();
        let svd = SvdDecomposition::new(&a).unwrap();
        let (u, v, s) = (svd.u(), svd.v(), svd.singular_values());
        for i in 0..a.rows() {
            for j in 0..a.cols() {
                // A[i][j] = Σ_p U[i][p] · σ_p · V[j][p]
                let val: f64 = s
                    .iter()
                    .enumerate()
                    .map(|(p, &sigma)| u.get(i, p) * sigma * v.get(j, p))
                    .sum();
                assert!(
                    (val - a.get(i, j)).abs() < 1e-10,
                    "Reconstruction failed at ({}, {}): {} vs {}",
                    i,
                    j,
                    val,
                    a.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_pseudoinverse() {
        let a = tall_3x2();
        let pinv = SvdDecomposition::new(&a).unwrap().pseudoinverse();
        let (m, n) = (a.rows(), a.cols());
        assert_eq!(pinv.rows(), n);
        assert_eq!(pinv.cols(), m);

        // a_pinv = A · A⁺ (m×m)
        let mut a_pinv: DenseMatrix<f64> = DenseMatrix::zeros(m, m);
        for i in 0..m {
            for j in 0..m {
                let val: f64 = (0..n).map(|k| a.get(i, k) * pinv.get(k, j)).sum();
                a_pinv.set(i, j, val);
            }
        }

        // Penrose condition: (A · A⁺) · A == A
        for i in 0..m {
            for j in 0..n {
                let val: f64 = (0..m).map(|k| a_pinv.get(i, k) * a.get(k, j)).sum();
                assert!(
                    (val - a.get(i, j)).abs() < 1e-10,
                    "A * pinv(A) * A != A at ({}, {}): {} vs {}",
                    i,
                    j,
                    val,
                    a.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_condition_number_identity() {
        let identity: DenseMatrix<f64> = DenseMatrix::identity(4);
        let cond = SvdDecomposition::new(&identity).unwrap().cond();
        assert!((cond - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rank_deficient() {
        // Second row is twice the first.
        let a = from_rows(&[&[1.0, 2.0], &[2.0, 4.0]]);
        let svd = SvdDecomposition::new(&a).unwrap();
        assert_eq!(svd.rank(1e-10), 1);
    }

    #[test]
    fn test_lstsq() {
        let a = from_rows(&[&[1.0, 1.0], &[1.0, 2.0], &[1.0, 3.0]]);
        let b = [1.0, 2.0, 2.0];
        let x = a.lstsq(&b).unwrap();
        assert!((x[0] - 2.0 / 3.0).abs() < 1e-10);
        assert!((x[1] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_thin_svd() {
        let a = from_rows(&[&[1.0, 0.0], &[0.0, 2.0], &[0.0, 0.0], &[0.0, 0.0]]);
        let svd = ThinSvdDecomposition::new(&a).unwrap();
        let (u, v, s) = (svd.u(), svd.v(), svd.singular_values());
        assert_eq!((u.rows(), u.cols()), (4, 2));
        assert_eq!((v.rows(), v.cols()), (2, 2));
        assert!((s[0] - 2.0).abs() < 1e-10);
        assert!((s[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_convenience_methods() {
        let a = from_rows(&[&[3.0, 0.0], &[0.0, 1.0]]);
        let s = a.singular_values();
        assert!((s[0] - 3.0).abs() < 1e-10);
        assert!((s[1] - 1.0).abs() < 1e-10);
        assert!((a.cond() - 3.0).abs() < 1e-10);
        assert_eq!(a.rank(1e-10), 2);
    }

    #[test]
    fn test_svd_f32() {
        let mut a: DenseMatrix<f32> = DenseMatrix::zeros(2, 2);
        for (idx, &diag) in [3.0f32, 2.0].iter().enumerate() {
            a.set(idx, idx, diag);
        }
        let s = SvdDecomposition::new(&a).unwrap().singular_values();
        assert!((s[0] - 3.0).abs() < 1e-5);
        assert!((s[1] - 2.0).abs() < 1e-5);
    }
}