use anofox_ml_core::{FitUnsupervised, InverseTransform, Result, RustMlError, Transform};
use faer::linalg::solvers::Svd;
use faer::Mat;
use ndarray::{Array1, Array2};
/// Unfitted configuration for a truncated singular value decomposition.
///
/// Fitting keeps the leading right singular vectors of the input matrix.
/// No mean-centering is performed.
#[derive(Clone, Debug)]
pub struct TruncatedSvd {
    /// Requested number of components to retain.
    ///
    /// During fitting this is clamped to `min(n_samples, n_features)`.
    /// A value of 0 is rejected with an error.
    pub n_components: usize,
}

impl TruncatedSvd {
    /// Builds a configuration that will keep `n_components` components.
    pub fn new(n_components: usize) -> Self {
        TruncatedSvd { n_components }
    }
}
/// Result of fitting a [`TruncatedSvd`]: the learned projection and its spectrum.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FittedTruncatedSvd {
    /// Right singular vectors stored as columns, shape `(n_features, n_components)`.
    pub components: Array2<f64>,
    /// Retained singular values, one per component.
    pub singular_values: Array1<f64>,
    /// Per-component `singular_value^2 / max(n_samples - 1, 1)`, as computed at fit time.
    pub explained_variance: Array1<f64>,
    /// Number of features seen during fitting; `transform` rejects inputs that differ.
    n_features: usize,
}

impl FittedTruncatedSvd {
    /// Number of retained components, i.e. the column count of `components`.
    pub fn n_components(&self) -> usize {
        let (_rows, cols) = self.components.dim();
        cols
    }
}
impl FitUnsupervised<f64> for TruncatedSvd {
    type Fitted = FittedTruncatedSvd;

    /// Fits a rank-`k` truncated SVD of `x` (rows = samples, columns = features).
    ///
    /// `k` is `self.n_components` clamped to `min(n_samples, n_features)`.
    /// Requests above that bound are silently reduced. The data is **not**
    /// centered before decomposition.
    ///
    /// # Errors
    ///
    /// * [`RustMlError::EmptyInput`] if `x` has no rows or no columns.
    /// * [`RustMlError::InvalidParameter`] if `n_components` is 0, if `x`
    ///   contains NaN or infinite values, or if the SVD itself fails.
    fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
        let (n, d) = x.dim();
        if n == 0 || d == 0 {
            return Err(RustMlError::EmptyInput("empty input".into()));
        }
        // Requests beyond the largest possible rank are clamped rather than rejected.
        let k = self.n_components.min(n.min(d));
        if k == 0 {
            return Err(RustMlError::InvalidParameter(
                "n_components must be at least 1".into(),
            ));
        }
        // Non-finite entries make the decomposition meaningless. Fail with a clear
        // message instead of an opaque convergence error or NaN-filled components.
        if x.iter().any(|value| !value.is_finite()) {
            return Err(RustMlError::InvalidParameter(
                "input contains NaN or infinite values".into(),
            ));
        }

        let m = Mat::<f64>::from_fn(n, d, |i, j| x[[i, j]]);
        // Thin SVD: U is n x min(n, d) instead of n x n. Only V and S are used
        // below, and a full U would cost O(n^2) memory for many-sample inputs.
        let svd = Svd::new_thin(m.as_ref())
            .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
        let v = svd.V();
        let s = svd.S().column_vector();

        // The thin SVD yields min(n, d) >= k right singular vectors and values,
        // so every index used below is in range.
        let components = Array2::from_shape_fn((d, k), |(i, j)| v[(i, j)]);
        let singular_values = Array1::from_shape_fn(k, |j| s[j]);

        // s^2 / (n - 1), guarded so a single sample divides by 1. Because `x` is not
        // centered, this equals the variance of the projected data only when the
        // input columns already have zero mean.
        let denom = (n as f64 - 1.0).max(1.0);
        let explained_variance = singular_values.mapv(|sv| sv * sv / denom);

        Ok(FittedTruncatedSvd {
            components,
            singular_values,
            explained_variance,
            n_features: d,
        })
    }
}
impl Transform<f64> for FittedTruncatedSvd {
    /// Projects `x` onto the fitted components.
    ///
    /// Returns `x · components`, with one row per input row and one column per
    /// retained component.
    ///
    /// # Errors
    ///
    /// [`RustMlError::ShapeMismatch`] when `x` does not have the number of
    /// columns seen during fitting.
    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
        let expected = self.n_features;
        let got = x.ncols();
        if got == expected {
            Ok(x.dot(&self.components))
        } else {
            Err(RustMlError::ShapeMismatch(format!(
                "expected {} features, got {}",
                expected, got
            )))
        }
    }
}
impl InverseTransform<f64> for FittedTruncatedSvd {
    /// Maps component-space data back to feature space.
    ///
    /// Returns `t · componentsᵀ`.
    ///
    /// # Errors
    ///
    /// [`RustMlError::ShapeMismatch`] when `t` does not have one column per
    /// fitted component.
    fn inverse_transform(&self, t: &Array2<f64>) -> Result<Array2<f64>> {
        let expected = self.components.ncols();
        let got = t.ncols();
        if got == expected {
            Ok(t.dot(&self.components.t()))
        } else {
            Err(RustMlError::ShapeMismatch(format!(
                "expected {} components, got {}",
                expected, got
            )))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Shared 4x3 full-column-rank fixture (more samples than features).
    fn sample() -> Array2<f64> {
        array![
            [1.0_f64, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [2.0, 3.0, 5.0]
        ]
    }

    #[test]
    fn test_truncated_svd_reduces_dim() {
        let x = sample();
        let svd = TruncatedSvd::new(2).fit(&x).unwrap();
        let t = svd.transform(&x).unwrap();
        assert_eq!(t.shape(), &[4, 2]);
        assert!(svd.singular_values[0] > svd.singular_values[1]);
    }

    #[test]
    fn test_inverse_transform_reconstructs_full_rank() {
        let x = sample();
        let svd = TruncatedSvd::new(3).fit(&x).unwrap();
        let t = svd.transform(&x).unwrap();
        let back = svd.inverse_transform(&t).unwrap();
        for ((i, j), &v) in x.indexed_iter() {
            assert!(
                (back[[i, j]] - v).abs() < 1e-9,
                "[{},{}]: {} vs {}",
                i,
                j,
                back[[i, j]],
                v
            );
        }
    }

    #[test]
    fn test_n_components_is_clamped_to_min_dimension() {
        let svd = TruncatedSvd::new(10).fit(&sample()).unwrap();
        assert_eq!(svd.n_components(), 3);
        assert_eq!(svd.components.dim(), (3, 3));
        assert_eq!(svd.singular_values.len(), 3);
        assert_eq!(svd.explained_variance.len(), 3);
    }

    #[test]
    fn test_zero_components_is_rejected() {
        let err = TruncatedSvd::new(0).fit(&sample()).unwrap_err();
        assert!(matches!(err, RustMlError::InvalidParameter(_)));
    }

    #[test]
    fn test_empty_input_is_rejected() {
        let empty = Array2::<f64>::zeros((0, 3));
        let err = TruncatedSvd::new(1).fit(&empty).unwrap_err();
        assert!(matches!(err, RustMlError::EmptyInput(_)));
    }

    #[test]
    fn test_shape_mismatch_is_rejected() {
        let svd = TruncatedSvd::new(2).fit(&sample()).unwrap();
        let wrong_features = Array2::<f64>::zeros((2, 4));
        assert!(matches!(
            svd.transform(&wrong_features),
            Err(RustMlError::ShapeMismatch(_))
        ));
        let wrong_components = Array2::<f64>::zeros((2, 3));
        assert!(matches!(
            svd.inverse_transform(&wrong_components),
            Err(RustMlError::ShapeMismatch(_))
        ));
    }

    #[test]
    fn test_explained_variance_matches_singular_values() {
        let x = sample();
        let svd = TruncatedSvd::new(2).fit(&x).unwrap();
        let denom = (x.nrows() - 1) as f64;
        for (&ev, &sv) in svd
            .explained_variance
            .iter()
            .zip(svd.singular_values.iter())
        {
            assert!((ev - sv * sv / denom).abs() < 1e-12);
        }
    }

    #[test]
    fn test_tall_input_reconstructs_with_full_rank() {
        // Many more samples than features: exercises the n > d path.
        let x = array![
            [1.0_f64, 0.5],
            [2.0, -1.0],
            [0.0, 3.0],
            [4.0, 4.0],
            [-2.0, 1.0],
            [3.0, 0.0]
        ];
        let svd = TruncatedSvd::new(2).fit(&x).unwrap();
        let back = svd
            .inverse_transform(&svd.transform(&x).unwrap())
            .unwrap();
        for (a, b) in x.iter().zip(back.iter()) {
            assert!((a - b).abs() < 1e-9, "{} vs {}", a, b);
        }
    }
}