Skip to main content

anofox_ml_preprocessing/
pca.rs

1use anofox_ml_core::{FitUnsupervised, Float, InverseTransform, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2, Axis};
3
/// Parameters for PCA (unfitted state).
///
/// Principal Component Analysis reduces dimensionality by projecting data
/// onto the directions of maximum variance. Eigendecomposition is performed
/// via power iteration with deflation, requiring no external LAPACK dependency.
///
/// This type only stores configuration. `n_components` is checked against the
/// data when `fit` is called, not when the value is constructed.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Pca {
    /// Number of principal components to keep.
    ///
    /// `fit` rejects `0` and any value larger than the number of input features.
    pub n_components: usize,
}
14
impl Pca {
    /// Create a new `Pca` with the given number of components.
    ///
    /// No validation happens here: an `n_components` of `0`, or one larger than
    /// the number of features in the data, is reported as an error by `fit`.
    pub fn new(n_components: usize) -> Self {
        Self { n_components }
    }
}
21
/// Fitted PCA — holds learned principal components, explained variance, and mean.
///
/// All fields are private and are exposed read-only through the accessor
/// methods of the same names.
///
/// NOTE(review): the explicit serde `bound` presumably exists because serde
/// cannot infer a usable `Deserialize` bound for `F` through the project's
/// `Float` trait — confirm before changing it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound(deserialize = "F: serde::de::DeserializeOwned"))]
pub struct FittedPca<F: Float> {
    /// Principal component directions, shape (n_components, n_features).
    /// Each row is a unit eigenvector of the covariance matrix.
    /// Row `k` pairs with `explained_variance[k]`.
    components: Array2<F>,
    /// Variance explained by each component (eigenvalues), length n_components.
    /// Values are clamped to be non-negative during fitting.
    explained_variance: Array1<F>,
    /// Per-feature mean used for centering, length n_features.
    /// Subtracted in `transform` and added back in `inverse_transform`.
    mean: Array1<F>,
}
34
/// Maximum number of power-iteration steps per component.
///
/// This is an upper bound: the loop in `fit` exits earlier once successive
/// iterates agree to within the convergence tolerance, or when the iterate
/// collapses to (numerically) zero.
const POWER_ITER_STEPS: usize = 200;
37
38impl<F: Float> FitUnsupervised<F> for Pca {
39    type Fitted = FittedPca<F>;
40
41    fn fit(&self, x: &Array2<F>) -> Result<Self::Fitted> {
42        let (n_samples, n_features) = x.dim();
43
44        if n_samples == 0 || n_features == 0 {
45            return Err(RustMlError::EmptyInput("input array is empty".into()));
46        }
47
48        if self.n_components == 0 {
49            return Err(RustMlError::InvalidParameter(
50                "n_components must be at least 1".into(),
51            ));
52        }
53
54        if self.n_components > n_features {
55            return Err(RustMlError::InvalidParameter(format!(
56                "n_components ({}) must be <= n_features ({})",
57                self.n_components, n_features
58            )));
59        }
60
61        if n_samples < 2 {
62            return Err(RustMlError::InvalidParameter(
63                "PCA requires at least 2 samples to compute covariance".into(),
64            ));
65        }
66
67        let n_f = F::from_usize(n_samples).unwrap();
68
69        // 1. Compute per-feature mean.
70        let mean = x.sum_axis(Axis(0)) / n_f;
71
72        // 2. Center the data: X_centered = X - mean (broadcasting).
73        let x_centered = x - &mean;
74
75        // 3. Covariance matrix via BLAS: C = X_centered.T @ X_centered / (n-1).
76        let n_minus_1 = F::from_usize(n_samples - 1).unwrap();
77        let mut cov = x_centered.t().dot(&x_centered);
78        cov.mapv_inplace(|v| v / n_minus_1);
79
80        // 4. Power iteration with deflation to extract top-k eigenpairs.
81        let mut components = Array2::<F>::zeros((self.n_components, n_features));
82        let mut explained_variance = Array1::<F>::zeros(self.n_components);
83        let eps = F::from_f64(1e-12).unwrap();
84
85        for k in 0..self.n_components {
86            // (a) Deterministic initial vector: v[i] = (i+1).
87            let mut v = Array1::<F>::zeros(n_features);
88            for i in 0..n_features {
89                v[i] = F::from_usize(i + 1).unwrap();
90            }
91            // Orthogonalize against previously found components.
92            for prev in 0..k {
93                let prev_comp = components.row(prev);
94                let dot = v.dot(&prev_comp);
95                v.scaled_add(-dot, &prev_comp);
96            }
97            // Normalize.
98            let norm = v.dot(&v).sqrt();
99            if norm < eps {
100                // All directions exhausted; store zero eigenvalue with arbitrary
101                // orthogonal direction (already zeroed out).
102                explained_variance[k] = F::zero();
103                // Build an orthogonal vector via standard basis probing.
104                for basis_idx in 0..n_features {
105                    v = Array1::<F>::zeros(n_features);
106                    v[basis_idx] = F::one();
107                    for prev in 0..k {
108                        let prev_comp = components.row(prev);
109                        let dot = v.dot(&prev_comp);
110                        v.scaled_add(-dot, &prev_comp);
111                    }
112                    let n2 = v.dot(&v).sqrt();
113                    if n2 > eps {
114                        v.mapv_inplace(|vi| vi / n2);
115                        break;
116                    }
117                }
118                components.row_mut(k).assign(&v);
119                continue;
120            }
121            v.mapv_inplace(|vi| vi / norm);
122
123            // (b) Power iteration with convergence check.
124            let convergence_tol = F::from_f64(1e-12).unwrap();
125            for _ in 0..POWER_ITER_STEPS {
126                // w = C @ v
127                let mut w = cov.dot(&v);
128                // Re-orthogonalize against previously found components
129                // for numerical stability.
130                for prev in 0..k {
131                    let prev_comp = components.row(prev);
132                    let dot = w.dot(&prev_comp);
133                    w.scaled_add(-dot, &prev_comp);
134                }
135                // norm(w)
136                let w_norm = w.dot(&w).sqrt();
137                if w_norm < F::from_f64(1e-30).unwrap() {
138                    // Degenerate -- remaining eigenvalues are essentially zero.
139                    break;
140                }
141                let v_new = w.mapv(|wi| wi / w_norm);
142                // Check convergence: |v_new - v| < tol
143                let diff_vec = &v_new - &v;
144                let diff = diff_vec.dot(&diff_vec);
145                v = v_new;
146                if diff < convergence_tol {
147                    break;
148                }
149            }
150
151            // (c) Eigenvalue = v^T C v. Clamp to zero if negative (numerical noise).
152            let cv = cov.dot(&v);
153            let eigenvalue = v.dot(&cv);
154            let eigenvalue = if eigenvalue < F::zero() {
155                F::zero()
156            } else {
157                eigenvalue
158            };
159
160            // (d) Deflate: C = C - eigenvalue * v v^T (outer product).
161            let v_col = v.view().insert_axis(Axis(1));
162            let v_row = v.view().insert_axis(Axis(0));
163            cov -= &(v_col.dot(&v_row) * eigenvalue);
164
165            // (e) Store results.
166            components.row_mut(k).assign(&v);
167            explained_variance[k] = eigenvalue;
168        }
169
170        Ok(FittedPca {
171            components,
172            explained_variance,
173            mean,
174        })
175    }
176}
177
178impl<F: Float> Transform<F> for FittedPca<F> {
179    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
180        let n_features = self.mean.len();
181        if x.ncols() != n_features {
182            return Err(RustMlError::ShapeMismatch(format!(
183                "expected {} features, got {}",
184                n_features,
185                x.ncols()
186            )));
187        }
188
189        // Center and project: (X - mean) @ components.T
190        let centered = x - &self.mean;
191        Ok(centered.dot(&self.components.t()))
192    }
193}
194
195impl<F: Float> InverseTransform<F> for FittedPca<F> {
196    fn inverse_transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
197        let n_components = self.components.nrows();
198        if x.ncols() != n_components {
199            return Err(RustMlError::ShapeMismatch(format!(
200                "expected {} components, got {}",
201                n_components,
202                x.ncols()
203            )));
204        }
205
206        // Reconstruct: X_reduced @ components + mean
207        Ok(x.dot(&self.components) + &self.mean)
208    }
209}
210
impl<F: Float> FittedPca<F> {
    /// Principal component directions, shape (n_components, n_features).
    ///
    /// Row `k` is the direction whose variance is `explained_variance()[k]`.
    pub fn components(&self) -> &Array2<F> {
        &self.components
    }

    /// Variance explained by each component.
    ///
    /// These are the eigenvalue estimates of the sample covariance matrix
    /// (denominator `n_samples - 1`), clamped to be non-negative during fitting.
    pub fn explained_variance(&self) -> &Array1<F> {
        &self.explained_variance
    }

    /// Per-feature mean used for centering.
    ///
    /// This is the column-wise arithmetic mean of the data passed to `fit`.
    pub fn mean(&self) -> &Array1<F> {
        &self.mean
    }
}
227
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_first_component_captures_most_variance() {
        // 2D data with a clear principal axis along (1, 1).
        // Variance along (1,1) is much larger than along (1,-1).
        let x = array![
            [1.0, 1.0],
            [2.0, 2.1],
            [3.0, 2.9],
            [4.0, 4.0],
            [5.0, 5.1],
            [6.0, 5.9],
            [7.0, 7.0],
            [8.0, 8.1],
        ];

        let pca = Pca { n_components: 2 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();

        let var = fitted.explained_variance();

        // First component should capture the vast majority of variance.
        let total: f64 = var.iter().copied().sum();
        let ratio = var[0] / total;
        assert!(
            ratio > 0.95,
            "first component should capture >95% variance, got {:.4}",
            ratio
        );
    }

    #[test]
    fn test_transform_inverse_transform_roundtrip() {
        // With n_components == n_features, roundtrip should be exact.
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];

        let pca = Pca { n_components: 3 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_transform_inverse_transform_lossy() {
        // With fewer components, roundtrip is approximate in general.
        let x = array![
            [1.0, 2.0, 0.5],
            [2.0, 4.0, 1.0],
            [3.0, 6.0, 1.5],
            [4.0, 8.0, 2.0],
            [5.0, 10.0, 2.5],
        ];

        let pca = Pca { n_components: 1 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();
        let recovered = fitted.inverse_transform(&transformed).unwrap();

        // This data is exactly rank-1 (cols 2 and 3 are 2x and 0.5x col 1),
        // so a single component should reconstruct it well within tolerance.
        for (a, b) in x.iter().zip(recovered.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 0.1);
        }
    }

    #[test]
    fn test_explained_variance_sorted_descending() {
        // Data with three genuinely distinct variance directions.
        let x = array![
            [1.0, 0.5, 0.1],
            [2.0, 1.0, 0.3],
            [3.0, 1.4, 0.2],
            [4.0, 2.1, 0.5],
            [5.0, 2.5, 0.8],
            [6.0, 3.2, 0.4],
            [7.0, 3.6, 0.9],
        ];

        let pca = Pca { n_components: 3 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();
        let var = fitted.explained_variance();

        // All eigenvalues should be non-negative.
        for (i, &v) in var.iter().enumerate() {
            assert!(v >= 0.0, "explained_variance[{}] = {} is negative", i, v);
        }

        for i in 1..var.len() {
            assert!(
                var[i - 1] >= var[i],
                "explained_variance not sorted descending: var[{}]={} < var[{}]={}",
                i - 1,
                var[i - 1],
                i,
                var[i]
            );
        }
    }

    #[test]
    fn test_n_components_exceeds_n_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let pca = Pca { n_components: 5 };
        let result = FitUnsupervised::<f64>::fit(&pca, &x);
        assert!(result.is_err());

        let err = result.unwrap_err();
        match err {
            RustMlError::InvalidParameter(msg) => {
                assert!(
                    msg.contains("n_components"),
                    "error should mention n_components: {}",
                    msg
                );
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_components_are_unit_vectors() {
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];

        let pca = Pca { n_components: 2 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();

        for row in fitted.components().rows() {
            let norm: f64 = row.iter().map(|&v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_mean_is_correct() {
        let x = array![[1.0, 4.0], [3.0, 6.0]];

        let pca = Pca { n_components: 2 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();

        assert_abs_diff_eq!(fitted.mean()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(fitted.mean()[1], 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let pca = Pca { n_components: 1 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();

        let wrong = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&wrong).is_err());
    }

    #[test]
    fn test_empty_input() {
        let x = Array2::<f64>::zeros((0, 3));

        let pca = Pca { n_components: 1 };
        let result = FitUnsupervised::<f64>::fit(&pca, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_single_sample_error() {
        let x = array![[1.0, 2.0, 3.0]];

        let pca = Pca { n_components: 1 };
        let result = FitUnsupervised::<f64>::fit(&pca, &x);
        assert!(result.is_err());
    }

    #[test]
    fn test_constant_features() {
        // Every row is identical, so each feature has zero variance.
        let x = array![[1.0, 2.0], [1.0, 2.0], [1.0, 2.0], [1.0, 2.0]];

        let pca = Pca { n_components: 2 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();

        // All eigenvalues should be zero (or near-zero).
        for &v in fitted.explained_variance().iter() {
            assert!(v.abs() < 1e-10, "expected near-zero variance, got {}", v);
        }
    }

    #[test]
    fn test_large_values() {
        // Large feature values should not produce NaN/Inf
        let x = array![[1e10, 2e10], [3e10, 4e10], [5e10, 6e10], [7e10, 8e10],];

        let pca = Pca { n_components: 2 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(
                v.is_finite(),
                "PCA on large values produced non-finite: {}",
                v
            );
        }
        for &v in fitted.explained_variance().iter() {
            assert!(
                v.is_finite() && v >= 0.0,
                "variance should be finite and non-negative: {}",
                v
            );
        }
    }

    #[test]
    fn test_near_zero_variance_column() {
        // One column has near-zero variance, other column has real variance
        let x = array![
            [1.0, 5.0],
            [2.0, 5.0 + 1e-14],
            [3.0, 5.0 - 1e-14],
            [4.0, 5.0],
        ];

        let pca = Pca { n_components: 2 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        for &v in transformed.iter() {
            assert!(
                v.is_finite(),
                "near-zero variance column produced non-finite: {}",
                v
            );
        }
        // First component should capture nearly all variance
        let var = fitted.explained_variance();
        assert!(var[0] > var[1] * 1e6, "first component should dominate");
    }

    #[test]
    fn test_collinear_features() {
        // Every column is an exact multiple of column 1
        // (col2 = 2*col1, col3 = 0.5*col1), so the centered data has rank 1.
        // PCA should handle this gracefully.
        let x = array![
            [1.0, 2.0, 0.5],
            [2.0, 4.0, 1.0],
            [3.0, 6.0, 1.5],
            [4.0, 8.0, 2.0],
            [5.0, 10.0, 2.5],
        ];

        let pca = Pca { n_components: 3 };
        let fitted = FitUnsupervised::<f64>::fit(&pca, &x).unwrap();
        let var = fitted.explained_variance();

        // fit clamps eigenvalues at zero, so they must be finite and >= 0.
        for &v in var.iter() {
            assert!(
                v.is_finite() && v >= 0.0,
                "variance should be finite and non-negative: {}",
                v
            );
        }
        // Rank-1 data has exactly one direction of variance; the remaining
        // eigenvalues must be numerically zero.
        let nonzero_count = var.iter().filter(|&&v| v > 1e-8).count();
        assert!(
            nonzero_count <= 1,
            "rank-1 data should have at most 1 non-zero eigenvalue, got {}",
            nonzero_count
        );
    }
}
513            "collinear data should have rank <= 2, got {} non-zero eigenvalues",
514            nonzero_count
515        );
516    }
517}