use crate::Df64;
use crate::col_piv_qr::ColPivQR;
use crate::numeric::CustomNumeric;
use mdarray::DTensor;
use nalgebra::{ComplexField, DMatrix, DVector, RealField};
use num_traits::{One, ToPrimitive, Zero};
/// Rank-truncated singular value decomposition, `A ≈ U · diag(s) · Vᵀ`.
///
/// Only the leading `rank` singular triplets are stored: `u` and `v` each have
/// `rank` columns and `s` has `rank` entries.
#[derive(Clone, Debug)]
pub struct SVDResult<T> {
    /// Left singular vectors, one per column.
    pub u: DMatrix<T>,
    /// Retained singular values.
    pub s: DVector<T>,
    /// Right singular vectors, one per column (this is `V`, not `Vᵀ`).
    pub v: DMatrix<T>,
    /// Number of retained singular triplets.
    pub rank: usize,
}
/// Options for [`tsvd`].
#[derive(Clone, Debug)]
pub struct TSVDConfig<T> {
    /// Relative truncation tolerance used by [`tsvd`]; values outside `(0, 1)`
    /// are rejected there.
    pub rtol: T,
}
impl<T> TSVDConfig<T> {
pub fn new(rtol: T) -> Self {
Self { rtol }
}
}
#[derive(Debug, thiserror::Error)]
pub enum TSVDError {
#[error("Matrix is empty")]
EmptyMatrix,
#[error("Invalid tolerance: {0}")]
InvalidTolerance(String),
}
/// Machine epsilon used as the SVD convergence tolerance and as the basis of
/// the QR rank cut-off.
///
/// Returns `f64::EPSILON` when `T` is `f64`, `Df64::EPSILON` when `T` is
/// [`Df64`](crate::Df64), and otherwise `1e-15` converted to `T` (or `T::one()`
/// if that conversion fails).
#[inline]
fn get_epsilon_for_svd<T: RealField + Copy>() -> T {
    use std::any::Any;

    // `RealField` requires `T: 'static`, so a checked `Any` downcast can
    // return the exact constant for known types without any unsafe casts.
    let f64_eps: &dyn Any = &f64::EPSILON;
    if let Some(&eps) = f64_eps.downcast_ref::<T>() {
        return eps;
    }
    let df64_eps: &dyn Any = &crate::Df64::EPSILON;
    if let Some(&eps) = df64_eps.downcast_ref::<T>() {
        return eps;
    }
    T::from_f64(1e-15).unwrap_or(T::one())
}
/// Computes the SVD of `matrix` and truncates it to its numerical rank.
///
/// The rank is the number of leading singular values strictly greater than
/// `rtol` times the first singular value. NOTE(review): this assumes nalgebra's
/// `try_svd` returns singular values in descending order; confirm this for the
/// nalgebra version in use. The returned factors keep only those triplets, and
/// `v` holds the right singular vectors column-wise.
///
/// # Panics
/// Panics if nalgebra's `try_svd` returns `None`.
pub fn svd_decompose<T>(matrix: &DMatrix<T>, rtol: f64) -> SVDResult<T>
where
    T: ComplexField + RealField + Copy + nalgebra::RealField + ToPrimitive,
{
    let decomposition = matrix
        .clone()
        .try_svd(true, true, get_epsilon_for_svd::<T>(), 0)
        .expect("SVD computation failed");

    // Both factors were requested above (`true, true`), so they are present.
    let left = decomposition.u.unwrap();
    let sigma = decomposition.singular_values;
    let right_t = decomposition.v_t.unwrap();

    let rank = calculate_rank_from_vector(&sigma, rtol);

    SVDResult {
        u: left.columns(0, rank).into_owned(),
        s: sigma.rows(0, rank).into_owned(),
        // Transposing the leading rows of Vᵀ yields V with `rank` columns.
        v: right_t.rows(0, rank).transpose(),
        rank,
    }
}
/// Counts the leading singular values that are strictly greater than
/// `rtol * singular_values[0]`.
///
/// Counting stops at the first value that fails the test, so the input is
/// expected to be sorted in descending order. An empty vector yields `0`. If
/// `rtol` cannot be converted to `T`, the cut-off falls back to zero.
fn calculate_rank_from_vector<T>(singular_values: &DVector<T>, rtol: f64) -> usize
where
    T: RealField + Copy + ToPrimitive,
{
    let leading = match singular_values.iter().next() {
        Some(&value) => value,
        None => return 0,
    };
    let cutoff = leading * T::from_f64(rtol).unwrap_or(T::zero());
    singular_values
        .iter()
        .take_while(|&&sv| sv > cutoff)
        .count()
}
/// Estimates a numerical rank from the diagonal of `r_matrix`.
///
/// The result is the index of the first diagonal entry whose magnitude is
/// below `rtol` times the largest diagonal magnitude, or
/// `min(nrows, ncols)` if there is no such entry. It is `0` when every
/// diagonal entry is zero. In this file it is called on the `R` factor of a
/// column-pivoted QR.
fn calculate_rank_from_r<T: RealField>(r_matrix: &DMatrix<T>, rtol: T) -> usize
where
    T: ComplexField + RealField + Copy,
{
    let dim = r_matrix.nrows().min(r_matrix.ncols());
    let diag_abs = |k: usize| ComplexField::abs(r_matrix[(k, k)]);

    let largest: T = (0..dim)
        .map(diag_abs)
        .fold(Zero::zero(), |best, d| if d > best { d } else { best });
    if largest == Zero::zero() {
        return 0;
    }

    let cutoff = rtol * largest;
    (0..dim)
        .position(|k| diag_abs(k) < cutoff)
        .unwrap_or(dim)
}
pub fn tsvd<T>(matrix: &DMatrix<T>, config: TSVDConfig<T>) -> Result<SVDResult<T>, TSVDError>
where
T: ComplexField
+ RealField
+ Copy
+ nalgebra::RealField
+ std::fmt::Debug
+ ToPrimitive
+ CustomNumeric,
{
let (m, n) = matrix.shape();
if m == 0 || n == 0 {
return Err(TSVDError::EmptyMatrix);
}
if config.rtol <= Zero::zero() || config.rtol >= One::one() {
return Err(TSVDError::InvalidTolerance(format!(
"Tolerance must be in (0, 1), got {:?}",
config.rtol
)));
}
let qr_rtol = Some(config.rtol.clone().modulus());
let qr = ColPivQR::new_with_rtol(matrix.clone(), qr_rtol);
let q_matrix = qr.q();
let r_matrix = qr.r();
let permutation = qr.p();
let qr_rank = calculate_rank_from_r(
&r_matrix,
T::from_f64_unchecked(2.0) * get_epsilon_for_svd::<T>(),
);
if qr_rank == 0 {
return Ok(SVDResult {
u: DMatrix::zeros(m, 0),
s: DVector::zeros(0),
v: DMatrix::zeros(n, 0),
rank: 0,
});
}
let r_truncated: DMatrix<T> = r_matrix.rows(0, qr_rank).into();
let rtol_t = config.rtol;
let rtol_f64 = rtol_t.to_f64();
let svd_result = svd_decompose(&r_truncated, rtol_f64);
if svd_result.rank == 0 {
return Ok(SVDResult {
u: DMatrix::zeros(m, 0),
s: DVector::zeros(0),
v: DMatrix::zeros(n, 0),
rank: 0,
});
}
let q_truncated: DMatrix<T> = q_matrix.columns(0, qr_rank).into();
let u_full = &q_truncated * &svd_result.u;
let mut v_full = svd_result.v.clone();
permutation.inv_permute_rows(&mut v_full);
let s_full = svd_result.s.clone();
Ok(SVDResult {
u: u_full,
s: s_full,
v: v_full,
rank: svd_result.rank,
})
}
pub fn tsvd_f64(matrix: &DMatrix<f64>, rtol: f64) -> Result<SVDResult<f64>, TSVDError> {
tsvd(matrix, TSVDConfig::new(rtol))
}
pub fn tsvd_df64(matrix: &DMatrix<Df64>, rtol: Df64) -> Result<SVDResult<Df64>, TSVDError> {
tsvd(matrix, TSVDConfig::new(rtol))
}
/// Promotes an `f64` matrix and tolerance to [`Df64`] and runs [`tsvd`] in
/// that precision.
///
/// # Errors
/// Same as [`tsvd`].
pub fn tsvd_df64_from_f64(
    matrix: &DMatrix<f64>,
    rtol: f64,
) -> Result<SVDResult<Df64>, TSVDError> {
    let promoted: DMatrix<Df64> = matrix.map(Df64::from);
    tsvd(&promoted, TSVDConfig::new(Df64::from(rtol)))
}
pub fn compute_svd_dtensor<T: CustomNumeric + 'static>(
matrix: &DTensor<T, 2>,
) -> (DTensor<T, 2>, Vec<T>, DTensor<T, 2>) {
use nalgebra::DMatrix;
use std::any::TypeId;
if TypeId::of::<T>() == TypeId::of::<f64>() {
let matrix_f64 = DMatrix::from_fn(matrix.shape().0, matrix.shape().1, |i, j| {
CustomNumeric::to_f64(matrix[[i, j]])
});
let rtol = 2.0 * f64::EPSILON;
let result = tsvd(&matrix_f64, TSVDConfig::new(rtol)).expect("TSVD computation failed");
let u = DTensor::<T, 2>::from_fn([result.u.nrows(), result.u.ncols()], |idx| {
let [i, j] = [idx[0], idx[1]];
T::from_f64_unchecked(result.u[(i, j)])
});
let s: Vec<T> = result.s.iter().map(|x| T::from_f64_unchecked(*x)).collect();
let v = DTensor::<T, 2>::from_fn([result.v.nrows(), result.v.ncols()], |idx| {
let [i, j] = [idx[0], idx[1]];
T::from_f64_unchecked(result.v[(i, j)])
});
(u, s, v)
} else if TypeId::of::<T>() == TypeId::of::<Df64>() {
let matrix_df64: DMatrix<Df64> =
DMatrix::from_fn(matrix.shape().0, matrix.shape().1, |i, j| {
unsafe { std::mem::transmute_copy(&matrix[[i, j]]) }
});
let rtol = Df64::from(2.0) * Df64::epsilon();
let result = tsvd_df64(&matrix_df64, rtol).expect("TSVD computation failed");
let u = DTensor::<T, 2>::from_fn([result.u.nrows(), result.u.ncols()], |idx| {
let [i, j] = [idx[0], idx[1]];
T::convert_from(result.u[(i, j)])
});
let s: Vec<T> = result.s.iter().map(|x| T::convert_from(*x)).collect();
let v = DTensor::<T, 2>::from_fn([result.v.nrows(), result.v.ncols()], |idx| {
let [i, j] = [idx[0], idx[1]];
T::convert_from(result.v[(i, j)])
});
(u, s, v)
} else {
panic!("SVD is only implemented for f64 and Df64");
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DMatrix;
    use num_traits::cast::ToPrimitive;

    #[test]
    fn test_svd_identity_matrix() {
        let matrix = DMatrix::<f64>::identity(3, 3);
        let result = svd_decompose(&matrix, 1e-12);
        assert_eq!(result.rank, 3);
        assert_eq!(result.s.len(), 3);
        assert_eq!(result.u.nrows(), 3);
        assert_eq!(result.u.ncols(), 3);
        assert_eq!(result.v.nrows(), 3);
        assert_eq!(result.v.ncols(), 3);
    }

    #[test]
    fn test_tsvd_identity_matrix() {
        let matrix = DMatrix::<f64>::identity(3, 3);
        let result = tsvd_f64(&matrix, 1e-12).unwrap();
        assert_eq!(result.rank, 3);
        assert_eq!(result.s.len(), 3);
    }

    #[test]
    fn test_tsvd_rank_one() {
        let matrix = DMatrix::<f64>::from_fn(3, 3, |i, j| (i + 1) as f64 * (j + 1) as f64);
        let result = tsvd_f64(&matrix, 1e-12).unwrap();
        assert_eq!(result.rank, 1);
    }

    #[test]
    fn test_tsvd_empty_matrix() {
        let matrix = DMatrix::<f64>::zeros(0, 0);
        let result = tsvd_f64(&matrix, 1e-12);
        assert!(matches!(result, Err(TSVDError::EmptyMatrix)));
    }

    /// Tolerances outside the open interval (0, 1) must be rejected.
    #[test]
    fn test_tsvd_invalid_tolerance() {
        let matrix = DMatrix::<f64>::identity(2, 2);
        for rtol in [0.0, 1.0, -0.5, 2.0] {
            let result = tsvd_f64(&matrix, rtol);
            assert!(
                matches!(result, Err(TSVDError::InvalidTolerance(_))),
                "rtol = {} should be rejected",
                rtol
            );
        }
    }

    #[test]
    fn test_tsvd_df64_from_f64_identity() {
        let matrix = DMatrix::<f64>::identity(3, 3);
        let result = tsvd_df64_from_f64(&matrix, 1e-20).unwrap();
        assert_eq!(result.rank, 3);
        assert_eq!(result.s.len(), 3);
    }

    /// `compute_svd_dtensor` must return factors that reproduce the input.
    #[test]
    fn test_compute_svd_dtensor_f64_reconstruction() {
        let data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let matrix = DTensor::<f64, 2>::from_fn([2, 3], |idx| data[idx[0]][idx[1]]);
        let (u, s, v) = compute_svd_dtensor(&matrix);
        assert_eq!(s.len(), 2);
        for i in 0..2 {
            for j in 0..3 {
                let value: f64 = (0..s.len()).map(|k| u[[i, k]] * s[k] * v[[j, k]]).sum();
                assert!(
                    (value - data[i][j]).abs() < 1e-12,
                    "entry ({}, {}) reconstructed as {} instead of {}",
                    i,
                    j,
                    value,
                    data[i][j]
                );
            }
        }
    }

    /// Builds the `n × n` Hilbert matrix `H[i][j] = 1 / (i + j + 1)`.
    fn create_hilbert_matrix_generic<T>(n: usize) -> DMatrix<T>
    where
        T: nalgebra::RealField + From<f64> + Copy + std::ops::Div<Output = T>,
    {
        DMatrix::from_fn(n, n, |i, j| T::one() / T::from((i + j + 1) as f64))
    }

    /// Returns `U · diag(s) · Vᵀ`.
    fn reconstruct_matrix_generic<T>(
        u: &DMatrix<T>,
        s: &nalgebra::DVector<T>,
        v: &DMatrix<T>,
    ) -> DMatrix<T>
    where
        T: nalgebra::RealField + Copy,
    {
        u * &DMatrix::from_diagonal(s) * &v.transpose()
    }

    /// Frobenius norm computed in `f64`. Entries that cannot be converted
    /// count as zero.
    fn frobenius_norm_generic<T>(matrix: &DMatrix<T>) -> f64
    where
        T: nalgebra::RealField + Copy + ToPrimitive,
    {
        let mut sum = 0.0;
        for i in 0..matrix.nrows() {
            for j in 0..matrix.ncols() {
                let val = matrix[(i, j)].to_f64().unwrap_or(0.0);
                sum += val * val;
            }
        }
        sum.sqrt()
    }

    /// Runs `tsvd` on an `n × n` Hilbert matrix and asserts that the relative
    /// Frobenius reconstruction error is at most `expected_max_error`.
    fn test_hilbert_reconstruction_generic<T>(n: usize, rtol: f64, expected_max_error: f64)
    where
        T: nalgebra::RealField
            + From<f64>
            + Copy
            + ToPrimitive
            + std::fmt::Debug
            + crate::numeric::CustomNumeric,
    {
        let h = create_hilbert_matrix_generic::<T>(n);
        let config = TSVDConfig::new(T::from(rtol));
        let result = tsvd(&h, config).unwrap();
        let h_reconstructed = reconstruct_matrix_generic(&result.u, &result.s, &result.v);
        let error_matrix = &h - &h_reconstructed;
        let error_norm = frobenius_norm_generic(&error_matrix);
        let relative_error = error_norm / frobenius_norm_generic(&h);
        assert!(
            relative_error <= expected_max_error,
            "Relative reconstruction error {} exceeds expected maximum {}",
            relative_error,
            expected_max_error
        );
    }

    #[test]
    fn test_hilbert_5x5_f64_reconstruction() {
        test_hilbert_reconstruction_generic::<f64>(5, 1e-12, 1e-14);
    }

    #[test]
    fn test_hilbert_5x5_df64_reconstruction() {
        test_hilbert_reconstruction_generic::<Df64>(5, 1e-28, 1e-28);
    }

    #[test]
    fn test_hilbert_10x10_f64_reconstruction() {
        test_hilbert_reconstruction_generic::<f64>(10, 1e-12, 1e-12);
    }

    #[test]
    fn test_hilbert_10x10_df64_reconstruction() {
        test_hilbert_reconstruction_generic::<Df64>(10, 1e-28, 1e-30);
    }

    #[test]
    fn test_hilbert_100x100_f64_reconstruction() {
        test_hilbert_reconstruction_generic::<f64>(100, 1e-12, 1e-12);
    }

    #[test]
    fn test_hilbert_100x100_df64_reconstruction() {
        test_hilbert_reconstruction_generic::<Df64>(100, 1e-28, 1e-28);
    }
}