anofox_ml_preprocessing/
truncated_svd.rs1use anofox_ml_core::{FitUnsupervised, InverseTransform, Result, RustMlError, Transform};
10use faer::linalg::solvers::Svd;
11use faer::Mat;
12use ndarray::{Array1, Array2};
13
/// Unfitted truncated-SVD configuration.
///
/// Call `fit` (from the `FitUnsupervised` implementation) on a data matrix to
/// obtain a [`FittedTruncatedSvd`].
#[derive(Debug, Clone)]
pub struct TruncatedSvd {
    /// Requested number of components to keep. `fit` clamps this to
    /// `min(n_rows, n_cols)` of the training matrix and rejects `0`.
    pub n_components: usize,
}
18
19impl TruncatedSvd {
20 pub fn new(n_components: usize) -> Self {
21 Self { n_components }
22 }
23}
24
25#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
26pub struct FittedTruncatedSvd {
27 pub components: Array2<f64>,
29 pub singular_values: Array1<f64>,
31 pub explained_variance: Array1<f64>,
33 n_features: usize,
34}
35
36impl FittedTruncatedSvd {
37 pub fn n_components(&self) -> usize {
38 self.components.ncols()
39 }
40}
41
42impl FitUnsupervised<f64> for TruncatedSvd {
43 type Fitted = FittedTruncatedSvd;
44
45 fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
46 let (n, d) = x.dim();
47 if n == 0 || d == 0 {
48 return Err(RustMlError::EmptyInput("empty input".into()));
49 }
50 let k = self.n_components.min(d.min(n));
51 if k == 0 {
52 return Err(RustMlError::InvalidParameter(
53 "n_components must be at least 1".into(),
54 ));
55 }
56
57 let m = Mat::<f64>::from_fn(n, d, |i, j| x[[i, j]]);
58 let svd = Svd::new(m.as_ref())
59 .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
60 let v = svd.V(); let s = svd.S(); let sv_len = s.column_vector().nrows();
63
64 let mut components = Array2::<f64>::zeros((d, k));
68 let mut sv = Array1::<f64>::zeros(k);
69 for j in 0..k {
70 for i in 0..d {
71 components[[i, j]] = v[(i, j)];
72 }
73 sv[j] = if j < sv_len {
74 s.column_vector()[j]
75 } else {
76 0.0
77 };
78 }
79 let mut ev = Array1::<f64>::zeros(k);
81 let denom = (n as f64 - 1.0).max(1.0);
82 for j in 0..k {
83 ev[j] = sv[j] * sv[j] / denom;
84 }
85
86 Ok(FittedTruncatedSvd {
87 components,
88 singular_values: sv,
89 explained_variance: ev,
90 n_features: d,
91 })
92 }
93}
94
95impl Transform<f64> for FittedTruncatedSvd {
96 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
97 if x.ncols() != self.n_features {
98 return Err(RustMlError::ShapeMismatch(format!(
99 "expected {} features, got {}",
100 self.n_features,
101 x.ncols()
102 )));
103 }
104 Ok(x.dot(&self.components))
105 }
106}
107
108impl InverseTransform<f64> for FittedTruncatedSvd {
109 fn inverse_transform(&self, t: &Array2<f64>) -> Result<Array2<f64>> {
113 if t.ncols() != self.components.ncols() {
114 return Err(RustMlError::ShapeMismatch(format!(
115 "expected {} components, got {}",
116 self.components.ncols(),
117 t.ncols()
118 )));
119 }
120 Ok(t.dot(&self.components.t()))
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// 4 x 3 fixture shared by both tests.
    fn sample() -> Array2<f64> {
        array![
            [1.0_f64, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [2.0, 3.0, 5.0]
        ]
    }

    /// Keeping two components yields a 4 x 2 projection, and the first
    /// singular value is larger than the second.
    #[test]
    fn test_truncated_svd_reduces_dim() {
        let data = sample();
        let fitted = TruncatedSvd::new(2).fit(&data).unwrap();
        let projected = fitted.transform(&data).unwrap();
        assert_eq!(projected.shape(), &[4, 2]);
        assert!(fitted.singular_values[0] > fitted.singular_values[1]);
    }

    /// With as many components as features, transform followed by
    /// inverse_transform reproduces the input to within 1e-9.
    #[test]
    fn test_inverse_transform_reconstructs_full_rank() {
        let data = sample();
        let fitted = TruncatedSvd::new(3).fit(&data).unwrap();
        let projected = fitted.transform(&data).unwrap();
        let back = fitted.inverse_transform(&projected).unwrap();
        for ((i, j), &v) in data.indexed_iter() {
            let err = (back[[i, j]] - v).abs();
            assert!(err < 1e-9, "[{},{}]: {} vs {}", i, j, back[[i, j]], v);
        }
    }
}