1use ferrolearn_core::error::FerroError;
39use ferrolearn_core::traits::{Fit, Transform};
40use ndarray::{Array1, Array2};
41use num_traits::Float;
42use rand::SeedableRng;
43use rand_distr::{Distribution, Uniform};
44
/// Unfitted configuration for sparse principal component analysis.
///
/// Build one with [`SparsePCA::new`] and the `with_*` builder methods, then call
/// `fit` to obtain a [`FittedSparsePCA`]. The type parameter `F` is the float type
/// of the data the estimator will be fitted on.
#[derive(Clone, Debug)]
pub struct SparsePCA<F> {
    /// Number of components to extract; must be in `1..=n_features` at fit time.
    n_components: usize,
    /// L1 penalty weight used when computing the per-sample codes during fitting.
    alpha: f64,
    /// Upper bound on the number of outer alternating-minimisation iterations.
    max_iter: usize,
    /// Relative change in reconstruction error below which fitting stops early.
    tol: f64,
    /// Seed for the random initial components; `None` means the fit uses a fixed default seed.
    random_state: Option<u64>,
    /// Ties the configuration to the float type `F` without storing a value of it.
    _marker: core::marker::PhantomData<F>,
}
68
69impl<F: Float + Send + Sync + 'static> SparsePCA<F> {
70 #[must_use]
75 pub fn new(n_components: usize) -> Self {
76 Self {
77 n_components,
78 alpha: 1.0,
79 max_iter: 1000,
80 tol: 1e-8,
81 random_state: None,
82 _marker: std::marker::PhantomData,
83 }
84 }
85
86 #[must_use]
88 pub fn with_alpha(mut self, alpha: f64) -> Self {
89 self.alpha = alpha;
90 self
91 }
92
93 #[must_use]
95 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
96 self.max_iter = max_iter;
97 self
98 }
99
100 #[must_use]
102 pub fn with_tol(mut self, tol: f64) -> Self {
103 self.tol = tol;
104 self
105 }
106
107 #[must_use]
109 pub fn with_random_state(mut self, seed: u64) -> Self {
110 self.random_state = Some(seed);
111 self
112 }
113
114 #[must_use]
116 pub fn n_components(&self) -> usize {
117 self.n_components
118 }
119
120 #[must_use]
122 pub fn alpha(&self) -> f64 {
123 self.alpha
124 }
125
126 #[must_use]
128 pub fn max_iter(&self) -> usize {
129 self.max_iter
130 }
131
132 #[must_use]
134 pub fn tol(&self) -> f64 {
135 self.tol
136 }
137}
138
139#[derive(Debug, Clone)]
148pub struct FittedSparsePCA<F> {
149 components_: Array2<F>,
151 mean_: Array1<F>,
153 n_iter_: usize,
155}
156
157impl<F: Float + Send + Sync + 'static> FittedSparsePCA<F> {
158 #[must_use]
160 pub fn components(&self) -> &Array2<F> {
161 &self.components_
162 }
163
164 #[must_use]
166 pub fn mean(&self) -> &Array1<F> {
167 &self.mean_
168 }
169
170 #[must_use]
172 pub fn n_iter(&self) -> usize {
173 self.n_iter_
174 }
175}
176
177#[inline]
183fn eps<F: Float>() -> F {
184 F::from(1e-12).unwrap_or_else(F::epsilon)
185}
186
187#[inline]
189fn soft_threshold<F: Float>(x: F, threshold: F) -> F {
190 if x > threshold {
191 x - threshold
192 } else if x < -threshold {
193 x + threshold
194 } else {
195 F::zero()
196 }
197}
198
199fn sparse_code_row<F: Float>(
204 x_row: &[F],
205 v: &Array2<F>,
206 alpha_f: F,
207 u_row: &mut [F],
208 n_cd_iters: usize,
209) {
210 let n_components = v.nrows();
211 let n_features = v.ncols();
212
213 for _iter in 0..n_cd_iters {
214 for k in 0..n_components {
215 let mut residual_dot = F::zero();
217 let mut vk_norm_sq = F::zero();
218
219 for j in 0..n_features {
220 let mut r = F::from(x_row[j]).unwrap();
221 for kk in 0..n_components {
222 if kk != k {
223 r = r - u_row[kk] * v[[kk, j]];
224 }
225 }
226 residual_dot = residual_dot + r * v[[k, j]];
227 vk_norm_sq = vk_norm_sq + v[[k, j]] * v[[k, j]];
228 }
229
230 if vk_norm_sq < eps::<F>() {
231 u_row[k] = F::zero();
232 } else {
233 u_row[k] = soft_threshold(residual_dot, alpha_f) / vk_norm_sq;
234 }
235 }
236 }
237}
238
239fn reconstruction_error_sq<F: Float + 'static>(x: &Array2<F>, u: &Array2<F>, v: &Array2<F>) -> F {
241 let uv = u.dot(v);
242 let mut err = F::zero();
243 for (a, b) in x.iter().zip(uv.iter()) {
244 let d = *a - *b;
245 err = err + d * d;
246 }
247 err
248}
249
250impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SparsePCA<F> {
255 type Fitted = FittedSparsePCA<F>;
256 type Error = FerroError;
257
258 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSparsePCA<F>, FerroError> {
266 let (n_samples, n_features) = x.dim();
267
268 if self.n_components == 0 {
269 return Err(FerroError::InvalidParameter {
270 name: "n_components".into(),
271 reason: "must be at least 1".into(),
272 });
273 }
274 if self.n_components > n_features {
275 return Err(FerroError::InvalidParameter {
276 name: "n_components".into(),
277 reason: format!(
278 "n_components ({}) exceeds n_features ({})",
279 self.n_components, n_features
280 ),
281 });
282 }
283 if n_samples < 2 {
284 return Err(FerroError::InsufficientSamples {
285 required: 2,
286 actual: n_samples,
287 context: "SparsePCA::fit requires at least 2 samples".into(),
288 });
289 }
290
291 let n_comp = self.n_components;
292 let n_f = F::from(n_samples).unwrap();
293 let alpha_f = F::from(self.alpha).unwrap_or_else(F::one);
294
295 let mut mean = Array1::<F>::zeros(n_features);
297 for j in 0..n_features {
298 let sum = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
299 mean[j] = sum / n_f;
300 }
301
302 let mut x_centered = x.to_owned();
303 for mut row in x_centered.rows_mut() {
304 for (v, &m) in row.iter_mut().zip(mean.iter()) {
305 *v = *v - m;
306 }
307 }
308
309 let seed = self.random_state.unwrap_or(42);
311 let mut rng: rand::rngs::StdRng = SeedableRng::seed_from_u64(seed);
312 let uniform = Uniform::new(-1.0f64, 1.0f64).unwrap();
313
314 let mut v = Array2::<F>::zeros((n_comp, n_features));
315 for elem in v.iter_mut() {
316 *elem = F::from(uniform.sample(&mut rng)).unwrap_or_else(F::zero);
317 }
318 for i in 0..n_comp {
320 let norm: F = v
321 .row(i)
322 .iter()
323 .fold(F::zero(), |acc, &val| acc + val * val)
324 .sqrt();
325 if norm > eps::<F>() {
326 for j in 0..n_features {
327 v[[i, j]] = v[[i, j]] / norm;
328 }
329 }
330 }
331
332 let mut u = Array2::<F>::zeros((n_samples, n_comp));
334
335 let n_cd_iters = 10; let mut prev_err = F::infinity();
337 let tol_f = F::from(self.tol).unwrap_or_else(F::epsilon);
338 let mut actual_iter = 0;
339
340 for iteration in 0..self.max_iter {
341 actual_iter = iteration + 1;
342
343 for i in 0..n_samples {
345 let x_row: Vec<F> = x_centered.row(i).to_vec();
346 let mut u_row: Vec<F> = u.row(i).to_vec();
347 sparse_code_row(&x_row, &v, alpha_f, &mut u_row, n_cd_iters);
348 for k in 0..n_comp {
349 u[[i, k]] = u_row[k];
350 }
351 }
352
353 let utu = u.t().dot(&u);
356 let xtu = x_centered.t().dot(&u);
358
359 if let Some(utu_inv) = invert_small_symmetric(&utu) {
362 let v_new_t = xtu.dot(&utu_inv); for k in 0..n_comp {
365 for j in 0..n_features {
366 v[[k, j]] = v_new_t[[j, k]];
367 }
368 }
369 }
370 for k in 0..n_comp {
374 let norm: F = v
375 .row(k)
376 .iter()
377 .fold(F::zero(), |acc, &val| acc + val * val)
378 .sqrt();
379 if norm > eps::<F>() {
380 for j in 0..n_features {
381 v[[k, j]] = v[[k, j]] / norm;
382 }
383 }
384 }
385
386 let err = reconstruction_error_sq(&x_centered, &u, &v);
388 if prev_err > eps::<F>() && (prev_err - err).abs() / prev_err < tol_f {
389 break;
390 }
391 prev_err = err;
392 }
393
394 Ok(FittedSparsePCA {
395 components_: v,
396 mean_: mean,
397 n_iter_: actual_iter,
398 })
399 }
400}
401
402fn invert_small_symmetric<F: Float>(a: &Array2<F>) -> Option<Array2<F>> {
406 let n = a.nrows();
407 if n == 0 {
408 return Some(Array2::zeros((0, 0)));
409 }
410
411 let mut aug = Array2::<F>::zeros((n, 2 * n));
413 for i in 0..n {
414 for j in 0..n {
415 aug[[i, j]] = a[[i, j]];
416 }
417 aug[[i, n + i]] = F::one();
418 }
419
420 let reg = F::from(1e-10).unwrap_or_else(F::epsilon);
422 for i in 0..n {
423 aug[[i, i]] = aug[[i, i]] + reg;
424 }
425
426 for i in 0..n {
427 let mut max_val = aug[[i, i]].abs();
429 let mut max_row = i;
430 for r in (i + 1)..n {
431 if aug[[r, i]].abs() > max_val {
432 max_val = aug[[r, i]].abs();
433 max_row = r;
434 }
435 }
436 if max_val < F::from(1e-15).unwrap_or_else(F::epsilon) {
437 return None;
438 }
439
440 if max_row != i {
442 for c in 0..(2 * n) {
443 let tmp = aug[[i, c]];
444 aug[[i, c]] = aug[[max_row, c]];
445 aug[[max_row, c]] = tmp;
446 }
447 }
448
449 let pivot = aug[[i, i]];
451 for c in 0..(2 * n) {
452 aug[[i, c]] = aug[[i, c]] / pivot;
453 }
454
455 for r in 0..n {
457 if r != i {
458 let factor = aug[[r, i]];
459 for c in 0..(2 * n) {
460 aug[[r, c]] = aug[[r, c]] - factor * aug[[i, c]];
461 }
462 }
463 }
464 }
465
466 let mut inv = Array2::<F>::zeros((n, n));
468 for i in 0..n {
469 for j in 0..n {
470 inv[[i, j]] = aug[[i, n + j]];
471 }
472 }
473 Some(inv)
474}
475
476impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSparsePCA<F> {
477 type Output = Array2<F>;
478 type Error = FerroError;
479
480 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
489 let n_features = self.mean_.len();
490 if x.ncols() != n_features {
491 return Err(FerroError::ShapeMismatch {
492 expected: vec![x.nrows(), n_features],
493 actual: vec![x.nrows(), x.ncols()],
494 context: "FittedSparsePCA::transform".into(),
495 });
496 }
497
498 let mut x_centered = x.to_owned();
499 for mut row in x_centered.rows_mut() {
500 for (v, &m) in row.iter_mut().zip(self.mean_.iter()) {
501 *v = *v - m;
502 }
503 }
504
505 Ok(x_centered.dot(&self.components_.t()))
506 }
507}
508
#[cfg(test)]
mod tests {
    //! Unit tests for `SparsePCA` / `FittedSparsePCA`: shapes, parameter
    //! validation, determinism, and the effect of the L1 penalty.
    use super::*;
    use ndarray::array;

    #[test]
    fn test_sparse_pca_basic() {
        let spca = SparsePCA::<f64>::new(2).with_random_state(42);
        let x = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];
        let fitted = spca.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.dim(), (4, 2));
    }

    #[test]
    fn test_sparse_pca_single_component() {
        let spca = SparsePCA::<f64>::new(1).with_random_state(0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];
        let fitted = spca.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 1);
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.ncols(), 1);
    }

    #[test]
    fn test_sparse_pca_components_shape() {
        let spca = SparsePCA::<f64>::new(2).with_random_state(7);
        let x = array![
            [1.0, 0.0, 0.0, 2.0],
            [0.0, 3.0, 0.0, 1.0],
            [2.0, 0.0, 1.0, 0.0],
            [0.0, 2.0, 3.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
        ];
        let fitted = spca.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 4));
    }

    #[test]
    fn test_sparse_pca_high_alpha_produces_sparser() {
        let x = array![
            [1.0, 0.0, 0.0, 2.0, 0.0],
            [0.0, 3.0, 0.0, 1.0, 0.0],
            [2.0, 0.0, 1.0, 0.0, 4.0],
            [0.0, 2.0, 3.0, 0.0, 1.0],
            [1.0, 1.0, 1.0, 1.0, 1.0],
        ];

        let fitted_low = SparsePCA::<f64>::new(1)
            .with_alpha(0.001)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();
        let fitted_high = SparsePCA::<f64>::new(1)
            .with_alpha(100.0)
            .with_random_state(42)
            .fit(&x, &())
            .unwrap();

        let proj_low = fitted_low.transform(&x).unwrap();
        let proj_high = fitted_high.transform(&x).unwrap();

        let energy_low: f64 = proj_low.iter().map(|v| v * v).sum();
        let energy_high: f64 = proj_high.iter().map(|v| v * v).sum();

        assert!(energy_low.is_finite());
        assert!(energy_high.is_finite());

        // A heavy L1 penalty must not produce denser components, nor capture
        // more projected energy than a light one. With alpha = 100 every code of
        // this small dataset is thresholded to zero, so the component collapses.
        let count_zeros = |m: &Array2<f64>| m.iter().filter(|&&c| c == 0.0).count();
        assert!(count_zeros(fitted_high.components()) >= count_zeros(fitted_low.components()));
        assert!(energy_high <= energy_low);
    }

    #[test]
    fn test_sparse_pca_same_seed_is_deterministic() {
        let x = array![
            [1.0, 0.0, 2.0],
            [0.0, 3.0, 1.0],
            [2.0, 1.0, 0.0],
            [1.0, 1.0, 1.0],
        ];
        let a = SparsePCA::<f64>::new(2)
            .with_random_state(3)
            .fit(&x, &())
            .unwrap();
        let b = SparsePCA::<f64>::new(2)
            .with_random_state(3)
            .fit(&x, &())
            .unwrap();
        assert_eq!(a.components(), b.components());
        assert_eq!(a.n_iter(), b.n_iter());
    }

    #[test]
    fn test_sparse_pca_components_are_unit_or_zero() {
        let x = array![
            [1.0, 0.0, 0.0, 2.0],
            [0.0, 3.0, 0.0, 1.0],
            [2.0, 0.0, 1.0, 0.0],
            [0.0, 2.0, 3.0, 0.0],
            [1.0, 1.0, 1.0, 1.0],
        ];
        let fitted = SparsePCA::<f64>::new(2)
            .with_alpha(0.1)
            .with_random_state(7)
            .fit(&x, &())
            .unwrap();
        for row in fitted.components().rows() {
            let norm = row.iter().map(|c| c * c).sum::<f64>().sqrt();
            assert!(
                norm <= 1e-12 || (norm - 1.0).abs() < 1e-9,
                "unexpected component norm {norm}"
            );
        }
    }

    #[test]
    fn test_sparse_pca_n_components_zero() {
        let spca = SparsePCA::<f64>::new(0);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(spca.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_pca_n_components_too_large() {
        let spca = SparsePCA::<f64>::new(5);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        assert!(spca.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_pca_insufficient_samples() {
        let spca = SparsePCA::<f64>::new(1);
        let x = array![[1.0, 2.0]];
        assert!(spca.fit(&x, &()).is_err());
    }

    #[test]
    fn test_sparse_pca_transform_shape_mismatch() {
        let spca = SparsePCA::<f64>::new(1).with_random_state(0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = spca.fit(&x, &()).unwrap();
        let x_bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_sparse_pca_f32() {
        let spca = SparsePCA::<f32>::new(1).with_random_state(0);
        let x: Array2<f32> = array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];
        let fitted = spca.fit(&x, &()).unwrap();
        let projected = fitted.transform(&x).unwrap();
        assert_eq!(projected.ncols(), 1);
    }

    #[test]
    fn test_sparse_pca_mean_is_correct() {
        let spca = SparsePCA::<f64>::new(1).with_random_state(0);
        let x = array![[2.0, 4.0], [4.0, 6.0], [6.0, 8.0]];
        let fitted = spca.fit(&x, &()).unwrap();
        let mean = fitted.mean();
        assert!((mean[0] - 4.0).abs() < 1e-10);
        assert!((mean[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_sparse_pca_builder_methods() {
        let spca = SparsePCA::<f64>::new(3)
            .with_alpha(0.5)
            .with_max_iter(500)
            .with_tol(1e-6)
            .with_random_state(99);
        assert_eq!(spca.n_components(), 3);
        assert!((spca.alpha() - 0.5).abs() < 1e-15);
        assert_eq!(spca.max_iter(), 500);
        assert!((spca.tol() - 1e-6).abs() < 1e-15);
    }

    #[test]
    fn test_sparse_pca_n_iter_positive() {
        let spca = SparsePCA::<f64>::new(1)
            .with_max_iter(10)
            .with_random_state(0);
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let fitted = spca.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() > 0);
        assert!(fitted.n_iter() <= 10);
    }
}