1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10pub struct Pca {
11 pub n_components: usize,
13}
14
15impl Pca {
16 pub fn new(n_components: usize) -> Self {
18 Self { n_components }
19 }
20}
21
22#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
24#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
25pub struct FittedPca<F: Float> {
26 components: Array2<F>,
29 explained_variance: Array1<F>,
31 mean: Array1<F>,
33}
34
/// Upper bound on the number of power-iteration steps spent on each
/// principal component; the iteration may stop earlier once it converges.
const POWER_ITER_STEPS: usize = 200;
37
38impl<F: Float> FitUnsupervised<F> for Pca {
39 type Fitted = FittedPca<F>;
40
41 fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
42 let (n_samples, n_features) = x.dim();
43
44 if n_samples == 0 || n_features == 0 {
45 return Err(RustMlError::EmptyInput("input array is empty".into()));
46 }
47
48 if self.n_components == 0 {
49 return Err(RustMlError::InvalidParameter(
50 "n_components must be at least 1".into(),
51 ));
52 }
53
54 if self.n_components > n_features {
55 return Err(RustMlError::InvalidParameter(format!(
56 "n_components ({}) must be <= n_features ({})",
57 self.n_components, n_features
58 )));
59 }
60
61 if n_samples < 2 {
62 return Err(RustMlError::InvalidParameter(
63 "PCA requires at least 2 samples to compute covariance".into(),
64 ));
65 }
66
67 let n_f = F::from_usize(n_samples).unwrap();
68
69 let mean = x.sum_axis(Axis(0)) / n_f;
71
72 let x_centered = x - &mean;
74
75 let n_minus_1 = F::from_usize(n_samples - 1).unwrap();
77 let mut cov = x_centered.t().dot(&x_centered);
78 cov.mapv_inplace(|v| v / n_minus_1);
79
80 let mut components = Array2::<F>::zeros((self.n_components, n_features));
82 let mut explained_variance = Array1::<F>::zeros(self.n_components);
83 let eps = F::from_f64(1e-12).unwrap();
84
85 for k in 0..self.n_components {
86 let mut v = Array1::<F>::zeros(n_features);
88 for i in 0..n_features {
89 v[i] = F::from_usize(i + 1).unwrap();
90 }
91 for prev in 0..k {
93 let prev_comp = components.row(prev);
94 let dot = v.dot(&prev_comp);
95 v.scaled_add(-dot, &prev_comp);
96 }
97 let norm = v.dot(&v).sqrt();
99 if norm < eps {
100 explained_variance[k] = F::zero();
103 for basis_idx in 0..n_features {
105 v = Array1::<F>::zeros(n_features);
106 v[basis_idx] = F::one();
107 for prev in 0..k {
108 let prev_comp = components.row(prev);
109 let dot = v.dot(&prev_comp);
110 v.scaled_add(-dot, &prev_comp);
111 }
112 let n2 = v.dot(&v).sqrt();
113 if n2 > eps {
114 v.mapv_inplace(|vi| vi / n2);
115 break;
116 }
117 }
118 components.row_mut(k).assign(&v);
119 continue;
120 }
121 v.mapv_inplace(|vi| vi / norm);
122
123 let convergence_tol = F::from_f64(1e-12).unwrap();
125 for _ in 0..POWER_ITER_STEPS {
126 let mut w = cov.dot(&v);
128 for prev in 0..k {
131 let prev_comp = components.row(prev);
132 let dot = w.dot(&prev_comp);
133 w.scaled_add(-dot, &prev_comp);
134 }
135 let w_norm = w.dot(&w).sqrt();
137 if w_norm < F::from_f64(1e-30).unwrap() {
138 break;
140 }
141 let v_new = w.mapv(|wi| wi / w_norm);
142 let diff_vec = &v_new - &v;
144 let diff = diff_vec.dot(&diff_vec);
145 v = v_new;
146 if diff < convergence_tol {
147 break;
148 }
149 }
150
151 let cv = cov.dot(&v);
153 let eigenvalue = v.dot(&cv);
154 let eigenvalue = if eigenvalue < F::zero() {
155 F::zero()
156 } else {
157 eigenvalue
158 };
159
160 let v_col = v.view().insert_axis(Axis(1));
162 let v_row = v.view().insert_axis(Axis(0));
163 cov -= &(v_col.dot(&v_row) * eigenvalue);
164
165 components.row_mut(k).assign(&v);
167 explained_variance[k] = eigenvalue;
168 }
169
170 Ok(FittedPca {
171 components,
172 explained_variance,
173 mean,
174 })
175 }
176}
177
178impl<F: Float> Transform<F> for FittedPca<F> {
179 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
180 let n_features = self.mean.len();
181 if x.ncols() != n_features {
182 return Err(RustMlError::ShapeMismatch(format!(
183 "expected {} features, got {}",
184 n_features,
185 x.ncols()
186 )));
187 }
188
189 let centered = x - &self.mean;
191 Ok(centered.dot(&self.components.t()))
192 }
193}
194
195impl<F: Float> InverseTransform<F> for FittedPca<F> {
196 fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
197 let n_components = self.components.nrows();
198 if x.ncols() != n_components {
199 return Err(RustMlError::ShapeMismatch(format!(
200 "expected {} components, got {}",
201 n_components,
202 x.ncols()
203 )));
204 }
205
206 Ok(x.dot(&self.components) + &self.mean)
208 }
209}
210
211impl<F: Float> FittedPca<F> {
212 pub fn components(&self) -> &Array2<F> {
214 &self.components
215 }
216
217 pub fn explained_variance(&self) -> &Array1<F> {
219 &self.explained_variance
220 }
221
222 pub fn mean(&self) -> &Array1<F> {
224 &self.mean
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fits a `Pca` keeping `n_components` components on `x` in `f64`.
    fn fit_f64(n_components: usize, x: &Array2<f64>) -> Result<FittedPca<f64>> {
        FitUnsupervised::<f64>::fit(&Pca::new(n_components), x)
    }

    /// Two strongly correlated features: the first axis must dominate.
    #[test]
    fn test_first_component_captures_most_variance() {
        let x = array![
            [1.0, 1.0],
            [2.0, 2.1],
            [3.0, 2.9],
            [4.0, 4.0],
            [5.0, 5.1],
            [6.0, 5.9],
            [7.0, 7.0],
            [8.0, 8.1],
        ];

        let fitted = fit_f64(2, &x).unwrap();
        let var = fitted.explained_variance();

        let total: f64 = var.iter().copied().sum();
        let ratio = var[0] / total;
        assert!(
            ratio > 0.95,
            "first component should capture >95% variance, got {:.4}",
            ratio
        );
    }

    /// Keeping every component must reproduce the input after a round trip.
    #[test]
    fn test_transform_inverse_transform_roundtrip() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];

        let fitted = fit_f64(3, &x).unwrap();
        let recovered = fitted
            .transform(&x)
            .and_then(|scores| fitted.inverse_transform(&scores))
            .unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-8);
        }
    }

    /// A single component is enough to approximate rank-one data.
    #[test]
    fn test_transform_inverse_transform_lossy() {
        let x = array![
            [1.0, 2.0, 0.5],
            [2.0, 4.0, 1.0],
            [3.0, 6.0, 1.5],
            [4.0, 8.0, 2.0],
            [5.0, 10.0, 2.5],
        ];

        let fitted = fit_f64(1, &x).unwrap();
        let recovered = fitted
            .transform(&x)
            .and_then(|scores| fitted.inverse_transform(&scores))
            .unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 0.1);
        }
    }

    /// Variances must be non-negative and ordered from largest to smallest.
    #[test]
    fn test_explained_variance_sorted_descending() {
        let x = array![
            [1.0, 0.5, 0.1],
            [2.0, 1.0, 0.3],
            [3.0, 1.4, 0.2],
            [4.0, 2.1, 0.5],
            [5.0, 2.5, 0.8],
            [6.0, 3.2, 0.4],
            [7.0, 3.6, 0.9],
        ];

        let fitted = fit_f64(3, &x).unwrap();
        let var = fitted.explained_variance();

        for (i, &v) in var.iter().enumerate() {
            assert!(v >= 0.0, "explained_variance[{}] = {} is negative", i, v);
        }

        // Compare each entry with its successor.
        for (i, (&current, &next)) in var.iter().zip(var.iter().skip(1)).enumerate() {
            assert!(
                current >= next,
                "explained_variance not sorted descending: var[{}]={} < var[{}]={}",
                i,
                current,
                i + 1,
                next
            );
        }
    }

    /// Asking for more components than features is an `InvalidParameter`.
    #[test]
    fn test_n_components_exceeds_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let result = fit_f64(5, &x);
        assert!(result.is_err());

        match result.unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(
                    msg.contains("n_components"),
                    "error should mention n_components: {}",
                    msg
                );
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    /// Every returned axis must have unit Euclidean length.
    #[test]
    fn test_components_are_unit_vectors() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];

        let fitted = fit_f64(2, &x).unwrap();

        for row in fitted.components().rows() {
            let squared_sum: f64 = row.iter().map(|&v| v * v).sum();
            let norm = squared_sum.sqrt();
            assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
        }
    }

    /// The stored mean is the column-wise average of the training data.
    #[test]
    fn test_mean_is_correct() {
        let x = array![[1.0, 4.0], [3.0, 6.0]];

        let fitted = fit_f64(2, &x).unwrap();

        assert_abs_diff_eq!(fitted.mean()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.mean()[1], 5.0, epsilon = 1e-10);
    }

    /// Transforming data with the wrong number of columns must fail.
    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = fit_f64(1, &x).unwrap();

        let wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&wrong).is_err());
    }

    /// An input without rows is rejected.
    #[test]
    fn test_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));

        let result = fit_f64(1, &x);
        assert!(result.is_err());
    }

    /// A single sample cannot define a covariance matrix.
    #[test]
    fn test_single_sample_error() {
        let x = array![[1.0, 2.0, 3.0]];

        let result = fit_f64(1, &x);
        assert!(result.is_err());
    }

    /// Constant columns carry no variance at all.
    #[test]
    fn test_constant_features() {
        let x = array![[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]];

        let fitted = fit_f64(2, &x).unwrap();

        for &v in fitted.explained_variance().iter() {
            assert!(v.abs() < 1e-10, "expected near-zero variance, got {}", v);
        }
    }

    /// Large magnitudes must not overflow into non-finite results.
    #[test]
    fn test_large_values() {
        let x = array![[1e10, 2e10], [3e10, 4e10], [5e10, 6e10], [7e10, 8e10],];

        let fitted = fit_f64(2, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(
                v.is_finite(),
                "PCA on large values produced non-finite: {}",
                v
            );
        }
        for &v in fitted.explained_variance().iter() {
            assert!(
                v.is_finite() && v >= 0.0,
                "variance should be finite and non-negative: {}",
                v
            );
        }
    }

    /// A column that barely varies must not destabilise the decomposition.
    #[test]
    fn test_near_zero_variance_column() {
        let x = array![
            [1.0, 5.0],
            [2.0, 5.0 + 1e-14],
            [3.0, 5.0 - 1e-14],
            [4.0, 5.0],
        ];

        let fitted = fit_f64(2, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(
                v.is_finite(),
                "near-zero variance column produced non-finite: {}",
                v
            );
        }

        let var = fitted.explained_variance();
        assert!(var[0] > var[1] * 1e6, "first component should dominate");
    }

    /// Perfectly collinear columns yield at most a couple of non-zero
    /// eigenvalues.
    #[test]
    fn test_collinear_features() {
        let x = array![
            [1.0, 2.0, 0.5],
            [2.0, 4.0, 1.0],
            [3.0, 6.0, 1.5],
            [4.0, 8.0, 2.0],
            [5.0, 10.0, 2.5],
        ];

        let fitted = fit_f64(3, &x).unwrap();
        let var = fitted.explained_variance();

        for &v in var.iter() {
            assert!(
                v.is_finite() && v >= -1e-10,
                "variance should be finite and non-negative: {}",
                v
            );
        }

        let nonzero_count = var.iter().filter(|&&v| v > 1e-8).count();
        assert!(
            nonzero_count <= 2,
            "collinear data should have rank <= 2, got {} non-zero eigenvalues",
            nonzero_count
        );
    }
}