use crate::tensor::DenseTensor;
use crate::tensor::TensorBase;
use crate::tensor::TensorError;
/// Computes a (possibly truncated) singular value decomposition `A ≈ U · diag(S) · Vᵀ`
/// of a 2-D tensor.
///
/// The right singular vectors and squared singular values are obtained by diagonalising
/// the Gram matrix `AᵀA` with cyclic Jacobi rotations; the left singular vectors are then
/// recovered as `uᵢ = A·vᵢ / σᵢ`. Because `AᵀA` is formed explicitly, singular values that
/// are much smaller than the largest one lose relative precision.
///
/// # Arguments
/// * `tensor` - input matrix of shape `[m, n]`, indexed row-major (`data[row * n + col]`).
/// * `k` - number of singular triplets to keep; defaults to `min(m, n)`.
///
/// # Returns
/// `(U, S, V)` with `U` of shape `[m, k]`, `S` of shape `[k]` (non-increasing), and `V` of
/// shape `[n, k]` whose columns are the right singular vectors. Columns of `U` whose
/// singular value is below `1e-10` are returned as zero vectors (so `U` is not
/// orthonormal for rank-deficient inputs).
///
/// # Errors
/// * `TensorError::DimensionMismatch` if `tensor` is not 2-D.
/// * `TensorError::ShapeMismatch` if `k > min(m, n)`.
pub fn svd_decompose(
    tensor: &DenseTensor,
    k: Option<usize>,
) -> Result<(DenseTensor, DenseTensor, DenseTensor), TensorError> {
    let shape = tensor.shape();
    if shape.len() != 2 {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: shape.len(),
        });
    }
    let (m, n) = (shape[0], shape[1]);
    let min_dim = std::cmp::min(m, n);
    let k = k.unwrap_or(min_dim);
    if k > min_dim {
        return Err(TensorError::ShapeMismatch {
            expected: vec![min_dim],
            got: vec![k],
        });
    }
    let data = tensor.data();

    // Gram matrix AᵀA (n×n, symmetric): compute the upper triangle and mirror it.
    let mut ata = vec![0.0; n * n];
    for i in 0..n {
        for j in i..n {
            let sum = (0..m).fold(0.0, |acc, l| acc + data[l * n + i] * data[l * n + j]);
            ata[i * n + j] = sum;
            ata[j * n + i] = sum;
        }
    }

    // Afterwards the diagonal of `ata` holds the eigenvalues and the columns of `v`
    // the corresponding eigenvectors (= right singular vectors of A).
    let v = jacobi_diagonalize(&mut ata, n);

    // σᵢ = sqrt(λᵢ); non-positive eigenvalues (e.g. tiny negatives from round-off) map to 0.
    let sigma: Vec<f64> = (0..n)
        .map(|i| {
            let lambda = ata[i * n + i];
            if lambda > 0.0 {
                lambda.sqrt()
            } else {
                0.0
            }
        })
        .collect();

    // Permutation ordering the singular values from largest to smallest; keep the top k.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| {
        sigma[b]
            .partial_cmp(&sigma[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let top = &order[..k];

    let s: Vec<f64> = top.iter().map(|&i| sigma[i]).collect();

    // V[row, c] = v[row, top[c]]: select the k leading eigenvector *columns*
    // (v is row-major n×n, so the first n*k elements would be rows, not columns).
    let mut v_data = Vec::with_capacity(n * k);
    for row in 0..n {
        v_data.extend(top.iter().map(|&col| v[row * n + col]));
    }

    // uᵢ = A·vᵢ / σᵢ; columns with (numerically) zero σ stay zero.
    let mut u_data = vec![0.0; m * k];
    for (c, &sigma_c) in s.iter().enumerate() {
        if sigma_c < 1e-10 {
            continue;
        }
        for row in 0..m {
            let dot = (0..n).fold(0.0, |acc, l| acc + data[row * n + l] * v_data[l * k + c]);
            u_data[row * k + c] = dot / sigma_c;
        }
    }

    let u_tensor = DenseTensor::from_vec(u_data, vec![m, k]);
    let s_tensor = DenseTensor::from_vec(s, vec![k]);
    let v_tensor = DenseTensor::from_vec(v_data, vec![n, k]);
    Ok((u_tensor, s_tensor, v_tensor))
}

/// Diagonalises the symmetric `n × n` row-major matrix `a` in place with cyclic Jacobi
/// rotations and returns the accumulated rotation matrix (row-major `n × n`), whose
/// columns are the eigenvectors. The eigenvalues are left on the diagonal of `a`.
///
/// Runs at most 50 sweeps and stops early after a sweep that performs no rotation.
fn jacobi_diagonalize(a: &mut [f64], n: usize) -> Vec<f64> {
    const MAX_SWEEPS: usize = 50;
    const TOL: f64 = 1e-10;

    let mut v = vec![0.0; n * n];
    for i in 0..n {
        v[i * n + i] = 1.0;
    }

    for _ in 0..MAX_SWEEPS {
        let mut rotated = false;
        for p in 0..n {
            for q in (p + 1)..n {
                let app = a[p * n + p];
                let aqq = a[q * n + q];
                let apq = a[p * n + q];
                // `<=` and `abs()` make an exactly-zero off-diagonal entry count as
                // converged even when a diagonal entry is 0 or slightly negative;
                // otherwise `tau` below would become 0/0 = NaN.
                if apq.abs() <= TOL * (app.abs() * aqq.abs()).sqrt() {
                    continue;
                }
                rotated = true;

                // Rotation angle that annihilates a[p][q] (smaller root for stability).
                let tau = (aqq - app) / (2.0 * apq);
                let t = if tau >= 0.0 {
                    1.0 / (tau + (1.0 + tau * tau).sqrt())
                } else {
                    -1.0 / (-tau + (1.0 + tau * tau).sqrt())
                };
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = t * c;

                a[p * n + p] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
                a[q * n + q] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
                // The rotation zeroes this entry by construction; store an exact zero
                // instead of a round-off residue.
                a[p * n + q] = 0.0;
                a[q * n + p] = 0.0;

                for r in 0..n {
                    if r == p || r == q {
                        continue;
                    }
                    let apr = a[p * n + r];
                    let aqr = a[q * n + r];
                    a[p * n + r] = c * apr - s * aqr;
                    a[r * n + p] = a[p * n + r];
                    a[q * n + r] = s * apr + c * aqr;
                    a[r * n + q] = a[q * n + r];
                }

                // Accumulate the rotation into the eigenvector matrix (columns p and q).
                for row in 0..n {
                    let vip = v[row * n + p];
                    let viq = v[row * n + q];
                    v[row * n + p] = c * vip - s * viq;
                    v[row * n + q] = s * vip + c * viq;
                }
            }
        }
        if !rotated {
            break;
        }
    }
    v
}
/// Reconstructs a rank-`rank` approximation `U · diag(S) · Vᵀ` of a 2-D tensor from the
/// factors returned by [`svd_decompose`] with `k = rank`.
///
/// The result has the same `[m, n]` shape as `tensor`.
///
/// # Errors
/// Propagates any error from [`svd_decompose`] (non-2-D input, or `rank > min(m, n)`).
pub fn low_rank_approx(tensor: &DenseTensor, rank: usize) -> Result<DenseTensor, TensorError> {
    let (u, s, v) = svd_decompose(tensor, Some(rank))?;
    let (u_vals, sigma, v_vals) = (u.data(), s.data(), v.data());
    let rows = u.shape()[0];
    let k = u.shape()[1];
    let cols = v.shape()[0];

    // Fold the singular values into U: (U · diag(S))[i, l] = U[i, l] · S[l].
    let scaled_u: Vec<f64> = (0..rows * k)
        .map(|idx| u_vals[idx] * sigma[idx % k])
        .collect();

    // out[i, j] = Σ_l (U · diag(S))[i, l] · V[j, l]
    let mut out = Vec::with_capacity(rows * cols);
    for i in 0..rows {
        let u_row = &scaled_u[i * k..(i + 1) * k];
        for j in 0..cols {
            let entry = u_row
                .iter()
                .enumerate()
                .fold(0.0, |acc, (l, &w)| acc + w * v_vals[j * k + l]);
            out.push(entry);
        }
    }
    Ok(DenseTensor::from_vec(out, vec![rows, cols]))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplies SVD factors back together, returning `U · diag(S) · Vᵀ` as a row-major
    /// buffer of length `m * n`, where `U` is `[m, k]`, `S` is `[k]` and `V` is `[n, k]`.
    fn reconstruct(u: &DenseTensor, s: &DenseTensor, v: &DenseTensor) -> Vec<f64> {
        let u_data = u.data();
        let s_data = s.data();
        let v_data = v.data();
        let (m, k) = (u.shape()[0], u.shape()[1]);
        let n = v.shape()[0];
        let mut out = vec![0.0; m * n];
        for i in 0..m {
            for j in 0..n {
                for l in 0..k {
                    out[i * n + j] += u_data[i * k + l] * s_data[l] * v_data[j * k + l];
                }
            }
        }
        out
    }

    /// Asserts that the columns of a 2-D `[rows, k]` tensor are orthonormal to within 1e-5.
    fn assert_orthonormal_columns(t: &DenseTensor, label: &str) {
        let data = t.data();
        let (rows, k) = (t.shape()[0], t.shape()[1]);
        for i in 0..k {
            for j in 0..k {
                let dot: f64 = (0..rows).map(|l| data[l * k + i] * data[l * k + j]).sum();
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (dot - expected).abs() < 1e-5,
                    "{} orthogonality failed at ({}, {})",
                    label,
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_svd_decomposition() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let (u, s, v) = svd_decompose(&tensor, None).unwrap();
        assert_eq!(u.shape(), &[3, 2]);
        assert_eq!(s.shape(), &[2]);
        assert_eq!(v.shape(), &[2, 2]);
        let s_data = s.data();
        assert!(s_data[0] > 0.0);
        assert!(s_data[1] > 0.0);
        assert!(s_data[0] >= s_data[1]);
    }

    #[test]
    fn test_low_rank_approx() {
        let original = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![4, 2],
        );
        let approx = low_rank_approx(&original, 1).unwrap();
        assert_eq!(approx.shape(), &[4, 2]);
        let orig_data = original.data();
        let approx_data = approx.data();
        let mse: f64 = orig_data
            .iter()
            .zip(approx_data.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / orig_data.len() as f64;
        assert!(mse < 1.0);
    }

    #[test]
    fn test_svd_reconstruction() {
        let original = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let (u, s, v) = svd_decompose(&original, None).unwrap();
        let reconstructed = reconstruct(&u, &s, &v);
        let orig_data = original.data();
        for (a, b) in orig_data.iter().zip(reconstructed.iter()) {
            assert!((a - b).abs() < 1e-5, "Reconstruction failed: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_svd_orthogonality() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![4, 2],
        );
        let (u, _s, v) = svd_decompose(&tensor, None).unwrap();
        assert_orthonormal_columns(&u, "U");
        assert_orthonormal_columns(&v, "V");
    }

    #[test]
    fn test_svd_rejects_non_matrix() {
        let tensor = DenseTensor::from_vec(vec![1.0; 8], vec![2, 2, 2]);
        assert!(matches!(
            svd_decompose(&tensor, None),
            Err(TensorError::DimensionMismatch { expected: 2, got: 3 })
        ));
    }

    #[test]
    fn test_svd_rejects_rank_above_min_dim() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        assert!(matches!(
            svd_decompose(&tensor, Some(3)),
            Err(TensorError::ShapeMismatch { .. })
        ));
        assert!(low_rank_approx(&tensor, 3).is_err());
    }
}