Skip to main content

anofox_ml_preprocessing/
truncated_svd.rs

//! Truncated SVD (a.k.a. LSA when applied to a term-document matrix).
//!
//! Mirrors `sklearn.decomposition.TruncatedSVD`. Unlike PCA, the data is not
//! centered; that is what allows the method to be applied to sparse inputs
//! without densifying them (this implementation itself operates on a dense
//! `Array2`).
//!
//! Decomposes `X ≈ U Σ Vᵀ` and keeps the top `k` singular triplets, where
//! `k = min(n_components, n_samples, n_features)`. The transform is `X V_k`,
//! of shape `(n_samples, k)`.
8
9use anofox_ml_core::{FitUnsupervised, InverseTransform, Result, RustMlError, Transform};
10use faer::linalg::solvers::Svd;
11use faer::Mat;
12use ndarray::{Array1, Array2};
13
/// Unfitted truncated-SVD estimator.
///
/// `n_components` is the number of singular triplets to keep. `fit` clamps it
/// to `min(n_samples, n_features)` and rejects an effective value of zero.
#[derive(Debug, Clone)]
pub struct TruncatedSvd {
    /// Requested number of components.
    pub n_components: usize,
}

impl TruncatedSvd {
    /// Create an estimator that keeps (at most) `n_components` components.
    pub fn new(n_components: usize) -> Self {
        TruncatedSvd { n_components }
    }
}
24
25#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
26pub struct FittedTruncatedSvd {
27    /// Top-`k` right-singular vectors, shape (n_features, k).
28    pub components: Array2<f64>,
29    /// Top-`k` singular values.
30    pub singular_values: Array1<f64>,
31    /// Explained variance per component.
32    pub explained_variance: Array1<f64>,
33    n_features: usize,
34}
35
36impl FittedTruncatedSvd {
37    pub fn n_components(&self) -> usize {
38        self.components.ncols()
39    }
40}
41
42impl FitUnsupervised<f64> for TruncatedSvd {
43    type Fitted = FittedTruncatedSvd;
44
45    fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
46        let (n, d) = x.dim();
47        if n == 0 || d == 0 {
48            return Err(RustMlError::EmptyInput("empty input".into()));
49        }
50        let k = self.n_components.min(d.min(n));
51        if k == 0 {
52            return Err(RustMlError::InvalidParameter(
53                "n_components must be at least 1".into(),
54            ));
55        }
56
57        let m = Mat::<f64>::from_fn(n, d, |i, j| x[[i, j]]);
58        let svd = Svd::new(m.as_ref())
59            .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
60        let v = svd.V(); // d × d
61        let s = svd.S(); // diag, length min(n, d)
62        let sv_len = s.column_vector().nrows();
63
64        // sklearn returns components_ shape (n_components, n_features) —
65        // rows are right-singular vectors (V columns). For us: take first k
66        // columns of V.
67        let mut components = Array2::<f64>::zeros((d, k));
68        let mut sv = Array1::<f64>::zeros(k);
69        for j in 0..k {
70            for i in 0..d {
71                components[[i, j]] = v[(i, j)];
72            }
73            sv[j] = if j < sv_len {
74                s.column_vector()[j]
75            } else {
76                0.0
77            };
78        }
79        // Explained variance ≈ Var(X V_j) = (s_j^2) / (n - 1)
80        let mut ev = Array1::<f64>::zeros(k);
81        let denom = (n as f64 - 1.0).max(1.0);
82        for j in 0..k {
83            ev[j] = sv[j] * sv[j] / denom;
84        }
85
86        Ok(FittedTruncatedSvd {
87            components,
88            singular_values: sv,
89            explained_variance: ev,
90            n_features: d,
91        })
92    }
93}
94
95impl Transform<f64> for FittedTruncatedSvd {
96    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
97        if x.ncols() != self.n_features {
98            return Err(RustMlError::ShapeMismatch(format!(
99                "expected {} features, got {}",
100                self.n_features,
101                x.ncols()
102            )));
103        }
104        Ok(x.dot(&self.components))
105    }
106}
107
108impl InverseTransform<f64> for FittedTruncatedSvd {
109    /// Reconstruct the original-space representation from the projection
110    /// `transform(X) = X V_k`. The inverse mapping is `T @ V_kᵀ`, valid up to
111    /// the rank-`k` approximation error.
112    fn inverse_transform(&self, t: &Array2<f64>) -> Result<Array2<f64>> {
113        if t.ncols() != self.components.ncols() {
114            return Err(RustMlError::ShapeMismatch(format!(
115                "expected {} components, got {}",
116                self.components.ncols(),
117                t.ncols()
118            )));
119        }
120        Ok(t.dot(&self.components.t()))
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Small 4 × 3 full-column-rank matrix shared by the tests.
    fn sample() -> Array2<f64> {
        array![
            [1.0_f64, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [2.0, 3.0, 5.0]
        ]
    }

    #[test]
    fn test_truncated_svd_reduces_dim() {
        let x = sample();
        let svd = TruncatedSvd::new(2).fit(&x).unwrap();
        let t = svd.transform(&x).unwrap();
        assert_eq!(t.shape(), &[4, 2]);
        // First singular value should be much larger than second.
        assert!(svd.singular_values[0] > svd.singular_values[1]);
    }

    #[test]
    fn test_inverse_transform_reconstructs_full_rank() {
        let x = sample();
        // Keep all components → inverse should reconstruct exactly.
        let svd = TruncatedSvd::new(3).fit(&x).unwrap();
        let t = svd.transform(&x).unwrap();
        let back = svd.inverse_transform(&t).unwrap();
        for ((i, j), &v) in x.indexed_iter() {
            assert!(
                (back[[i, j]] - v).abs() < 1e-9,
                "[{},{}]: {} vs {}",
                i,
                j,
                back[[i, j]],
                v
            );
        }
    }

    #[test]
    fn test_n_components_clamped_to_min_dim() {
        let svd = TruncatedSvd::new(10).fit(&sample()).unwrap();
        assert_eq!(svd.n_components(), 3);
        assert_eq!(svd.components.dim(), (3, 3));
        assert_eq!(svd.singular_values.len(), 3);
        assert_eq!(svd.explained_variance.len(), 3);
    }

    #[test]
    fn test_components_are_orthonormal() {
        let svd = TruncatedSvd::new(2).fit(&sample()).unwrap();
        let gram = svd.components.t().dot(&svd.components);
        for ((i, j), &g) in gram.indexed_iter() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert!((g - expected).abs() < 1e-9, "gram[{},{}] = {}", i, j, g);
        }
    }

    #[test]
    fn test_zero_components_rejected() {
        let err = TruncatedSvd::new(0).fit(&sample()).unwrap_err();
        assert!(matches!(err, RustMlError::InvalidParameter(_)));
    }

    #[test]
    fn test_empty_input_rejected() {
        let empty = Array2::<f64>::zeros((0, 3));
        let err = TruncatedSvd::new(2).fit(&empty).unwrap_err();
        assert!(matches!(err, RustMlError::EmptyInput(_)));
    }

    #[test]
    fn test_shape_mismatch_rejected() {
        let svd = TruncatedSvd::new(2).fit(&sample()).unwrap();
        let err = svd.transform(&Array2::<f64>::zeros((2, 4))).unwrap_err();
        assert!(matches!(err, RustMlError::ShapeMismatch(_)));
        let err = svd
            .inverse_transform(&Array2::<f64>::zeros((2, 3)))
            .unwrap_err();
        assert!(matches!(err, RustMlError::ShapeMismatch(_)));
    }
}