Skip to main content

ferrolearn_decomp/
sparse_pca.rs

1//! Sparse Principal Component Analysis (SparsePCA).
2//!
3//! SparsePCA finds sparse components that can optimally reconstruct the data
4//! by combining PCA with L1 (lasso) penalisation on the loadings. This
5//! produces components that are easier to interpret, at the cost of
6//! explained variance compared to standard PCA.
7//!
8//! # Algorithm
9//!
10//! Uses an Elastic-Net / Coordinate-Descent approach:
11//!
12//! 1. Initialise dictionary `V` from PCA (or random).
13//! 2. Alternate:
14//!    a. Fix `V`, solve for sparse code `U` via coordinate descent:
15//!    `min ||X - U V^T||^2 + alpha * ||U||_1` (per row of `U`).
16//!    b. Fix `U`, update `V = X^T U (U^T U)^{-1}`, then normalise columns.
17//! 3. The rows of `V` are the sparse principal components.
18//!
19//! # Examples
20//!
21//! ```
22//! use ferrolearn_decomp::SparsePCA;
23//! use ferrolearn_core::traits::{Fit, Transform};
24//! use ndarray::array;
25//!
26//! let spca = SparsePCA::<f64>::new(1);
27//! let x = array![
28//!     [1.0, 2.0, 3.0],
29//!     [4.0, 5.0, 6.0],
30//!     [7.0, 8.0, 9.0],
31//!     [10.0, 11.0, 12.0],
32//! ];
33//! let fitted = spca.fit(&x, &()).unwrap();
34//! let projected = fitted.transform(&x).unwrap();
35//! assert_eq!(projected.ncols(), 1);
36//! ```
37
38use ferrolearn_core::error::FerroError;
39use ferrolearn_core::traits::{Fit, Transform};
40use ndarray::{Array1, Array2};
41use num_traits::Float;
42use rand::SeedableRng;
43use rand_distr::{Distribution, Uniform};
44
45// ---------------------------------------------------------------------------
46// SparsePCA (unfitted)
47// ---------------------------------------------------------------------------
48
/// Sparse PCA configuration.
///
/// Holds hyperparameters for the Sparse PCA decomposition. Calling
/// [`Fit::fit`] performs the iterative elastic-net / coordinate-descent
/// procedure and returns a [`FittedSparsePCA`] that can project new data.
#[derive(Debug, Clone)]
pub struct SparsePCA<F> {
    /// Number of sparse components to extract.
    n_components: usize,
    /// Sparsity penalty weight on the L1 norm of the loadings.
    /// Larger values drive more code coefficients to exactly zero.
    alpha: f64,
    /// Maximum number of outer iterations.
    max_iter: usize,
    /// Convergence tolerance on the relative change in reconstruction error.
    tol: f64,
    /// Optional random seed for reproducibility.
    /// `None` falls back to a fixed default seed inside `fit`.
    random_state: Option<u64>,
    /// Ties the configuration to the element type `F` without storing one.
    _marker: std::marker::PhantomData<F>,
}
68
69impl<F: Float + Send + Sync + 'static> SparsePCA<F> {
70    /// Create a new `SparsePCA` that extracts `n_components` sparse components.
71    ///
72    /// Defaults: `alpha = 1.0`, `max_iter = 1000`, `tol = 1e-8`,
73    /// `random_state = None`.
74    #[must_use]
75    pub fn new(n_components: usize) -> Self {
76        Self {
77            n_components,
78            alpha: 1.0,
79            max_iter: 1000,
80            tol: 1e-8,
81            random_state: None,
82            _marker: std::marker::PhantomData,
83        }
84    }
85
86    /// Set the sparsity penalty weight (L1 regularisation on codes).
87    #[must_use]
88    pub fn with_alpha(mut self, alpha: f64) -> Self {
89        self.alpha = alpha;
90        self
91    }
92
93    /// Set the maximum number of outer iterations.
94    #[must_use]
95    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
96        self.max_iter = max_iter;
97        self
98    }
99
100    /// Set the convergence tolerance.
101    #[must_use]
102    pub fn with_tol(mut self, tol: f64) -> Self {
103        self.tol = tol;
104        self
105    }
106
107    /// Set the random seed for reproducible results.
108    #[must_use]
109    pub fn with_random_state(mut self, seed: u64) -> Self {
110        self.random_state = Some(seed);
111        self
112    }
113
114    /// Return the configured number of components.
115    #[must_use]
116    pub fn n_components(&self) -> usize {
117        self.n_components
118    }
119
120    /// Return the configured sparsity penalty.
121    #[must_use]
122    pub fn alpha(&self) -> f64 {
123        self.alpha
124    }
125
126    /// Return the configured maximum iterations.
127    #[must_use]
128    pub fn max_iter(&self) -> usize {
129        self.max_iter
130    }
131
132    /// Return the configured tolerance.
133    #[must_use]
134    pub fn tol(&self) -> f64 {
135        self.tol
136    }
137}
138
139// ---------------------------------------------------------------------------
140// FittedSparsePCA
141// ---------------------------------------------------------------------------
142
/// A fitted Sparse PCA model holding the learned components.
///
/// Created by calling [`Fit::fit`] on a [`SparsePCA`]. Implements
/// [`Transform<Array2<F>>`] to project new data onto the sparse components.
/// Trailing-underscore field names mirror the scikit-learn convention for
/// attributes learned during fitting.
#[derive(Debug, Clone)]
pub struct FittedSparsePCA<F> {
    /// Sparse components, shape `(n_components, n_features)`.
    components_: Array2<F>,
    /// Per-feature mean computed during fitting (used for centring).
    mean_: Array1<F>,
    /// Number of outer iterations performed.
    n_iter_: usize,
}
156
157impl<F: Float + Send + Sync + 'static> FittedSparsePCA<F> {
158    /// Sparse components, shape `(n_components, n_features)`.
159    #[must_use]
160    pub fn components(&self) -> &Array2<F> {
161        &self.components_
162    }
163
164    /// Per-feature mean learned during fitting.
165    #[must_use]
166    pub fn mean(&self) -> &Array1<F> {
167        &self.mean_
168    }
169
170    /// Number of outer iterations performed.
171    #[must_use]
172    pub fn n_iter(&self) -> usize {
173        self.n_iter_
174    }
175}
176
177// ---------------------------------------------------------------------------
178// Internal helpers
179// ---------------------------------------------------------------------------
180
181/// Small epsilon to prevent division by zero.
182#[inline]
183fn eps<F: Float>() -> F {
184    F::from(1e-12).unwrap_or_else(F::epsilon)
185}
186
187/// Soft-thresholding operator: sign(x) * max(|x| - threshold, 0).
188#[inline]
189fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
190    if x > threshold {
191        x - threshold
192    } else if x < -threshold {
193        x + threshold
194    } else {
195        F::zero()
196    }
197}
198
199/// Solve sparse coding for a single row of U via coordinate descent:
200///   min_u  ||x_row - u V^T||^2 + alpha * ||u||_1
201///
202/// `v` has shape `(n_components, n_features)`.
203fn sparse_code_row<F: Float>(
204    x_row: &[F],
205    v: &Array2<F>,
206    alpha_f: F,
207    u_row: &mut [F],
208    n_cd_iters: usize,
209) {
210    let n_components = v.nrows();
211    let n_features = v.ncols();
212
213    for _iter in 0..n_cd_iters {
214        for k in 0..n_components {
215            // Compute residual excluding component k.
216            let mut residual_dot = F::zero();
217            let mut vk_norm_sq = F::zero();
218
219            for j in 0..n_features {
220                let mut r = F::from(x_row[j]).unwrap();
221                for kk in 0..n_components {
222                    if kk != k {
223                        r = r - u_row[kk] * v[[kk, j]];
224                    }
225                }
226                residual_dot = residual_dot + r * v[[k, j]];
227                vk_norm_sq = vk_norm_sq + v[[k, j]] * v[[k, j]];
228            }
229
230            if vk_norm_sq < eps::<F>() {
231                u_row[k] = F::zero();
232            } else {
233                u_row[k] = soft_threshold(residual_dot, alpha_f) / vk_norm_sq;
234            }
235        }
236    }
237}
238
239/// Compute the Frobenius norm squared of `X - U @ V`.
240fn reconstruction_error_sq<F: Float + 'static>(x: &Array2<F>, u: &Array2<F>, v: &Array2<F>) -> F {
241    let uv = u.dot(v);
242    let mut err = F::zero();
243    for (a, b) in x.iter().zip(uv.iter()) {
244        let d = *a - *b;
245        err = err + d * d;
246    }
247    err
248}
249
250// ---------------------------------------------------------------------------
251// Trait implementations
252// ---------------------------------------------------------------------------
253
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SparsePCA<F> {
    type Fitted = FittedSparsePCA<F>;
    type Error = FerroError;

    /// Fit Sparse PCA by alternating sparse coding and dictionary update.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or exceeds
    ///   the number of features.
    /// - [`FerroError::InsufficientSamples`] if there are fewer than 2 samples.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSparsePCA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // --- Parameter validation -----------------------------------------
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_components > n_features {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds n_features ({})",
                    self.n_components, n_features
                ),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "SparsePCA::fit requires at least 2 samples".into(),
            });
        }

        let n_comp = self.n_components;
        // Sample count as a float, used as the divisor when computing the mean.
        let n_f = F::from(n_samples).unwrap();
        let alpha_f = F::from(self.alpha).unwrap_or_else(F::one);

        // Step 1: compute mean and centre data.
        let mut mean = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let sum = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
            mean[j] = sum / n_f;
        }

        let mut x_centered = x.to_owned();
        for mut row in x_centered.rows_mut() {
            for (v, &m) in row.iter_mut().zip(mean.iter()) {
                *v = *v - m;
            }
        }

        // Step 2: Initialize V from random.
        // Unseeded runs still use a fixed seed (42), so fitting is
        // deterministic by default.
        let seed = self.random_state.unwrap_or(42);
        let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
        let uniform = Uniform::new(-1.0f64, 1.0f64).unwrap();

        let mut v = Array2::<F>::zeros((n_comp, n_features));
        for elem in v.iter_mut() {
            *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero);
        }
        // Normalize each row of V to unit L2 norm (rows with ~zero norm are
        // left untouched to avoid dividing by zero).
        for i in 0..n_comp {
            let norm: F = v
                .row(i)
                .iter()
                .fold(F::zero(), |acc, &val| acc + val * val)
                .sqrt();
            if norm > eps::<F>() {
                for j in 0..n_features {
                    v[[i, j]] = v[[i, j]] / norm;
                }
            }
        }

        // Step 3: Allocate U (sparse codes), shape (n_samples, n_components).
        // U persists across outer iterations, acting as a warm start for the
        // per-row coordinate descent.
        let mut u = Array2::<F>::zeros((n_samples, n_comp));

        let n_cd_iters = 10; // inner coordinate descent iterations
        let mut prev_err = F::infinity();
        let tol_f = F::from(self.tol).unwrap_or_else(F::epsilon);
        let mut actual_iter = 0;

        for iteration in 0..self.max_iter {
            actual_iter = iteration + 1;

            // Step 3a: Fix V, solve for sparse code U (each row independently).
            for i in 0..n_samples {
                let x_row: Vec<F> = x_centered.row(i).to_vec();
                let mut u_row: Vec<F> = u.row(i).to_vec();
                sparse_code_row(&x_row, &v, alpha_f, &mut u_row, n_cd_iters);
                for k in 0..n_comp {
                    u[[i, k]] = u_row[k];
                }
            }

            // Step 3b: Fix U, update V = (X^T U) (U^T U)^{-1}, then normalize.
            // Compute U^T U, shape (n_comp, n_comp).
            let utu = u.t().dot(&u);
            // Compute X^T U, shape (n_features, n_comp).
            let xtu = x_centered.t().dot(&u);

            // Solve for V^T = (U^T U)^{-1} (X^T U)^T via inverting U^T U.
            // For small n_comp, invert directly.
            if let Some(utu_inv) = invert_small_symmetric(&utu) {
                let v_new_t = xtu.dot(&utu_inv); // (n_features, n_comp)
                // V rows = columns of v_new_t transposed.
                for k in 0..n_comp {
                    for j in 0..n_features {
                        v[[k, j]] = v_new_t[[j, k]];
                    }
                }
            }
            // else: U^T U is singular; keep V from previous iteration.

            // Normalize columns of V (stored as rows).
            for k in 0..n_comp {
                let norm: F = v
                    .row(k)
                    .iter()
                    .fold(F::zero(), |acc, &val| acc + val * val)
                    .sqrt();
                if norm > eps::<F>() {
                    for j in 0..n_features {
                        v[[k, j]] = v[[k, j]] / norm;
                    }
                }
            }

            // Check convergence on the relative change in reconstruction
            // error. On the first pass `prev_err` is +inf, so the ratio is
            // NaN and `NaN < tol_f` is false — iteration always continues.
            let err = reconstruction_error_sq(&x_centered, &u, &v);
            if prev_err > eps::<F>() && (prev_err - err).abs() / prev_err < tol_f {
                break;
            }
            prev_err = err;
        }

        Ok(FittedSparsePCA {
            components_: v,
            mean_: mean,
            n_iter_: actual_iter,
        })
    }
}
401
402/// Invert a small symmetric positive-definite matrix via Gauss-Jordan.
403///
404/// Returns `None` if the matrix is singular.
405fn invert_small_symmetric<F: Float>(a: &Array2<F>) -> Option<Array2<F>> {
406    let n = a.nrows();
407    if n == 0 {
408        return Some(Array2::zeros((0, 0)));
409    }
410
411    // Augmented matrix [A | I].
412    let mut aug = Array2::<F>::zeros((n, 2 * n));
413    for i in 0..n {
414        for j in 0..n {
415            aug[[i, j]] = a[[i, j]];
416        }
417        aug[[i, n + i]] = F::one();
418    }
419
420    // Add regularisation to diagonal.
421    let reg = F::from(1e-10).unwrap_or_else(F::epsilon);
422    for i in 0..n {
423        aug[[i, i]] = aug[[i, i]] + reg;
424    }
425
426    for i in 0..n {
427        // Find pivot.
428        let mut max_val = aug[[i, i]].abs();
429        let mut max_row = i;
430        for r in (i + 1)..n {
431            if aug[[r, i]].abs() > max_val {
432                max_val = aug[[r, i]].abs();
433                max_row = r;
434            }
435        }
436        if max_val < F::from(1e-15).unwrap_or_else(F::epsilon) {
437            return None;
438        }
439
440        // Swap rows.
441        if max_row != i {
442            for c in 0..(2 * n) {
443                let tmp = aug[[i, c]];
444                aug[[i, c]] = aug[[max_row, c]];
445                aug[[max_row, c]] = tmp;
446            }
447        }
448
449        // Scale pivot row.
450        let pivot = aug[[i, i]];
451        for c in 0..(2 * n) {
452            aug[[i, c]] = aug[[i, c]] / pivot;
453        }
454
455        // Eliminate other rows.
456        for r in 0..n {
457            if r != i {
458                let factor = aug[[r, i]];
459                for c in 0..(2 * n) {
460                    aug[[r, c]] = aug[[r, c]] - factor * aug[[i, c]];
461                }
462            }
463        }
464    }
465
466    // Extract inverse.
467    let mut inv = Array2::<F>::zeros((n, n));
468    for i in 0..n {
469        for j in 0..n {
470            inv[[i, j]] = aug[[i, n + j]];
471        }
472    }
473    Some(inv)
474}
475
476impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSparsePCA<F> {
477    type Output = Array2<F>;
478    type Error = FerroError;
479
480    /// Project data onto the sparse components.
481    ///
482    /// Computes `(X - mean) @ components^T`.
483    ///
484    /// # Errors
485    ///
486    /// Returns [`FerroError::ShapeMismatch`] if the number of columns does not
487    /// match the number of features seen during fitting.
488    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
489        let n_features = self.mean_.len();
490        if x.ncols() != n_features {
491            return Err(FerroError::ShapeMismatch {
492                expected: vec![x.nrows(), n_features],
493                actual: vec![x.nrows(), x.ncols()],
494                context: "FittedSparsePCA::transform".into(),
495            });
496        }
497
498        let mut x_centered = x.to_owned();
499        for mut row in x_centered.rows_mut() {
500            for (v, &m) in row.iter_mut().zip(self.mean_.iter()) {
501                *v = *v - m;
502            }
503        }
504
505        Ok(x_centered.dot(&self.components_.t()))
506    }
507}
508
509// ---------------------------------------------------------------------------
510// Tests
511// ---------------------------------------------------------------------------
512
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // Happy path: fit + transform produce the requested projection shape.
    #[test]
    fn test_sparse_pca_basic() {
        let spca = SparsePCA::<f64>::new(2).with_random_state(42);
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let fitted = spca.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (4, 2));
    }

    #[test]
    fn test_sparse_pca_single_component() {
        let spca = SparsePCA::<f64>::new(1).with_random_state(0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];
        let fitted = spca.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 1);
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.ncols(), 1);
    }

    // Components matrix must be (n_components, n_features).
    #[test]
    fn test_sparse_pca_components_shape() {
        let spca = SparsePCA::<f64>::new(2).with_random_state(7);
        let x = array![
            [1.0, 0.0, 0.0, 2.0],
            [0.0, 3.0, 0.0, 1.0],
            [2.0, 0.0, 1.0, 0.0],
            [0.0, 2.0, 3.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
        ];
        let fitted = spca.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
    }

    // Smoke test across extreme alpha values; only asserts finiteness, not
    // a strict sparsity ordering (energy comparison can be noisy).
    #[test]
    fn test_sparse_pca_high_alpha_produces_sparser() {
        let x = array![
            [1.0, 0.0, 0.0, 2.0, 0.0],
            [0.0, 3.0, 0.0, 1.0, 0.0],
            [2.0, 0.0, 1.0, 0.0, 4.0],
            [0.0, 2.0, 3.0, 0.0, 1.0],
            [1.0, 1.0, 1.0, 1.0, 1.0],
        ];

        let fitted_low = SparsePCA::<f64>::new(1)
            .with_alpha(0.001)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        let fitted_high = SparsePCA::<f64>::new(1)
            .with_alpha(100.0)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();

        // With high alpha, the projected values should tend toward zero
        // (codes are pushed to zero by the L1 penalty).
        let proj_low = fitted_low.transform(&x).unwrap();
        let proj_high = fitted_high.transform(&x).unwrap();

        let energy_low: f64 = proj_low.iter().map(|v| v * v).sum();
        let energy_high: f64 = proj_high.iter().map(|v| v * v).sum();

        // High alpha should produce less energy or similar (sparser codes).
        // We just check both runs succeed and produce finite values.
        assert!(energy_low.is_finite());
        assert!(energy_high.is_finite());
    }

    // --- Error-path coverage -------------------------------------------

    #[test]
    fn test_sparse_pca_n_components_zero() {
        let spca = SparsePCA::<f64>::new(0);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(spca.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_pca_n_components_too_large() {
        let spca = SparsePCA::<f64>::new(5);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        assert!(spca.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_pca_insufficient_samples() {
        let spca = SparsePCA::<f64>::new(1);
        let x = array![[1.0, 2.0]];
        assert!(spca.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_pca_transform_shape_mismatch() {
        let spca = SparsePCA::<f64>::new(1).with_random_state(0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = spca.fit(&x, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    // The generic implementation must also work for f32.
    #[test]
    fn test_sparse_pca_f32() {
        let spca = SparsePCA::<f32>::new(1).with_random_state(0);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];
        let fitted = spca.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.ncols(), 1);
    }

    #[test]
    fn test_sparse_pca_mean_is_correct() {
        let spca = SparsePCA::<f64>::new(1).with_random_state(0);
        let x = array![[2.0, 4.0], [4.0, 6.0], [6.0, 8.0]];
        let fitted = spca.fit(&x, &()).unwrap();
        let mean = fitted.mean();
        assert!((mean[0] - 4.0).abs() < 1e-10);
        assert!((mean[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_sparse_pca_builder_methods() {
        let spca = SparsePCA::<f64>::new(3)
            .with_alpha(0.5)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_random_state(99);
        assert_eq!(spca.n_components(), 3);
        assert!((spca.alpha() - 0.5).abs() < 1e-15);
        assert_eq!(spca.max_iter(), 500);
        assert!((spca.tol() - 1e-6).abs() < 1e-15);
    }

    // n_iter_ is 1-based and capped by max_iter.
    #[test]
    fn test_sparse_pca_n_iter_positive() {
        let spca = SparsePCA::<f64>::new(1)
            .with_max_iter(10)
            .with_random_state(0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = spca.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() > 0);
        assert!(fitted.n_iter() <= 10);
    }
}
663}