1use scirs2_core::ndarray::{Array1, Array2, Axis};
36use sklears_core::{
37 error::{Result, SklearsError},
38 traits::{Fit, Transform, Untrained},
39};
40use std::marker::PhantomData;
41
42#[cfg(feature = "serde")]
43use serde::{Deserialize, Serialize};
44
/// Configuration for Principal Component Analysis (PCA).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PCAConfig {
    /// Number of components to keep; `None` keeps `min(n_samples, n_features)`.
    pub n_components: Option<usize>,
    /// Whether to subtract the per-feature mean before decomposition.
    pub center: bool,
    /// Solver strategy used for the decomposition.
    pub solver: PcaSolver,
    /// Seed intended for stochastic solvers. NOTE(review): the current
    /// randomized solver delegates to the full solver and ignores this.
    pub random_state: Option<u64>,
    /// Convergence tolerance for iterative solvers.
    pub tolerance: f64,
    /// Iteration cap for iterative solvers.
    pub max_iterations: usize,
}
62
/// Available PCA solver strategies.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum PcaSolver {
    /// Exact solver via full eigendecomposition of the covariance matrix.
    Full,
    /// Randomized solver. NOTE(review): currently falls back to `Full`.
    Randomized,
    /// Power-iteration solver. NOTE(review): currently falls back to `Full`.
    PowerIteration,
}
74
75impl Default for PCAConfig {
76 fn default() -> Self {
77 Self {
78 n_components: None,
79 center: true,
80 solver: PcaSolver::Full,
81 random_state: None,
82 tolerance: 1e-7,
83 max_iterations: 1000,
84 }
85 }
86}
87
88impl PCAConfig {
89 pub fn new(n_components: usize) -> Self {
91 Self {
92 n_components: Some(n_components),
93 ..Default::default()
94 }
95 }
96
97 pub fn with_solver(mut self, solver: PcaSolver) -> Self {
99 self.solver = solver;
100 self
101 }
102
103 pub fn with_center(mut self, center: bool) -> Self {
105 self.center = center;
106 self
107 }
108
109 pub fn with_random_state(mut self, random_state: u64) -> Self {
111 self.random_state = Some(random_state);
112 self
113 }
114
115 pub fn with_tolerance(mut self, tolerance: f64) -> Self {
117 self.tolerance = tolerance;
118 self
119 }
120}
121
/// Principal Component Analysis estimator.
///
/// The `State` type parameter is a compile-time marker (typestate pattern):
/// only `PCA<Untrained>` exposes `fit`, which returns a [`PCAFitted`].
pub struct PCA<State = Untrained> {
    /// Configuration used when fitting.
    config: PCAConfig,
    /// Zero-sized marker carrying the typestate; never stored at runtime.
    state: PhantomData<State>,
}
127
/// A fitted PCA model holding the learned projection.
pub struct PCAFitted {
    /// Configuration the model was fitted with.
    config: PCAConfig,
    /// Principal axes, shape `(n_components, n_features)`, one component per row.
    components: Array2<f64>,
    /// Variance explained by each retained component (sorted descending).
    explained_variance: Array1<f64>,
    /// Each retained component's share of the retained variance
    /// (ratios are normalized over the retained components only).
    explained_variance_ratio: Array1<f64>,
    /// Singular values corresponding to each retained component.
    singular_values: Array1<f64>,
    /// Per-feature training mean; `None` when fitted without centering.
    mean: Option<Array1<f64>>,
    /// Number of features seen during fitting.
    n_features: usize,
    /// Number of components retained.
    n_components: usize,
}
139
140impl PCA<Untrained> {
141 pub fn new(config: PCAConfig) -> Self {
143 Self {
144 config,
145 state: PhantomData,
146 }
147 }
148
149 pub fn config(&self) -> &PCAConfig {
151 &self.config
152 }
153}
154
155impl Fit<Array2<f64>, ()> for PCA<Untrained> {
156 type Fitted = PCAFitted;
157
158 fn fit(self, x: &Array2<f64>, _y: &()) -> Result<PCAFitted> {
159 if x.is_empty() {
160 return Err(SklearsError::InvalidInput(
161 "Input array is empty".to_string(),
162 ));
163 }
164
165 let (n_samples, n_features) = x.dim();
166 if n_samples < 2 {
167 return Err(SklearsError::InvalidInput(
168 "PCA requires at least 2 samples".to_string(),
169 ));
170 }
171
172 let n_components = self
174 .config
175 .n_components
176 .unwrap_or(n_features.min(n_samples));
177 if n_components > n_features.min(n_samples) {
178 return Err(SklearsError::InvalidInput(format!(
179 "n_components={} cannot be larger than min(n_samples={}, n_features={})",
180 n_components, n_samples, n_features
181 )));
182 }
183
184 let (x_centered, mean) = if self.config.center {
186 let mean = x
187 .mean_axis(Axis(0))
188 .expect("array should have elements for mean computation");
189 let mut x_centered = x.clone();
190 for mut row in x_centered.axis_iter_mut(Axis(0)) {
191 for (j, &mean_j) in mean.iter().enumerate() {
192 row[j] -= mean_j;
193 }
194 }
195 (x_centered, Some(mean))
196 } else {
197 (x.clone(), None)
198 };
199
200 let (components, explained_variance, singular_values) = match self.config.solver {
202 PcaSolver::Full => perform_full_pca(&x_centered, n_components)?,
203 PcaSolver::Randomized => {
204 perform_randomized_pca(&x_centered, n_components, self.config.random_state)?
205 }
206 PcaSolver::PowerIteration => perform_power_iteration_pca(
207 &x_centered,
208 n_components,
209 self.config.max_iterations,
210 self.config.tolerance,
211 )?,
212 };
213
214 let total_variance = explained_variance.sum();
216 let explained_variance_ratio = if total_variance > 0.0 {
217 &explained_variance / total_variance
218 } else {
219 Array1::zeros(n_components)
220 };
221
222 Ok(PCAFitted {
223 config: self.config,
224 components,
225 explained_variance,
226 explained_variance_ratio,
227 singular_values,
228 mean,
229 n_features,
230 n_components,
231 })
232 }
233}
234
235impl PCAFitted {
236 pub fn components(&self) -> &Array2<f64> {
238 &self.components
239 }
240
241 pub fn explained_variance(&self) -> &Array1<f64> {
243 &self.explained_variance
244 }
245
246 pub fn explained_variance_ratio(&self) -> &Array1<f64> {
248 &self.explained_variance_ratio
249 }
250
251 pub fn singular_values(&self) -> &Array1<f64> {
253 &self.singular_values
254 }
255
256 pub fn mean(&self) -> Option<&Array1<f64>> {
258 self.mean.as_ref()
259 }
260
261 pub fn n_components(&self) -> usize {
263 self.n_components
264 }
265
266 pub fn n_features(&self) -> usize {
268 self.n_features
269 }
270
271 pub fn cumulative_explained_variance_ratio(&self) -> Array1<f64> {
273 let mut cumulative = Array1::zeros(self.explained_variance_ratio.len());
274 let mut sum = 0.0;
275 for (i, &ratio) in self.explained_variance_ratio.iter().enumerate() {
276 sum += ratio;
277 cumulative[i] = sum;
278 }
279 cumulative
280 }
281}
282
283impl Transform<Array2<f64>, Array2<f64>> for PCAFitted {
284 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
285 if x.is_empty() {
286 return Err(SklearsError::InvalidInput(
287 "Input array is empty".to_string(),
288 ));
289 }
290
291 let (_n_samples, n_features) = x.dim();
292 if n_features != self.n_features {
293 return Err(SklearsError::InvalidInput(format!(
294 "Feature count mismatch: expected {}, got {}",
295 self.n_features, n_features
296 )));
297 }
298
299 let x_centered = if let Some(ref mean) = self.mean {
301 let mut x_centered = x.clone();
302 for mut row in x_centered.axis_iter_mut(Axis(0)) {
303 for (j, &mean_j) in mean.iter().enumerate() {
304 row[j] -= mean_j;
305 }
306 }
307 x_centered
308 } else {
309 x.clone()
310 };
311
312 let result = x_centered.dot(&self.components.t());
314 Ok(result)
315 }
316}
317
/// Configuration for Linear Discriminant Analysis (LDA).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LDAConfig {
    /// Number of discriminant components to keep; `None` presumably keeps
    /// the maximum — no `fit` is visible here to confirm.
    pub n_components: Option<usize>,
    /// Solver strategy to use.
    pub solver: LdaSolver,
    /// Optional shrinkage parameter; semantics depend on the solver
    /// (NOTE(review): unused by any code visible in this file).
    pub shrinkage: Option<f64>,
    /// Numerical tolerance.
    pub tolerance: f64,
}
331
/// Available LDA solver strategies.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum LdaSolver {
    /// Singular value decomposition.
    Svd,
    /// Least-squares solver.
    Lsqr,
    /// Eigendecomposition-based solver.
    Eigen,
}
343
344impl Default for LDAConfig {
345 fn default() -> Self {
346 Self {
347 n_components: None,
348 solver: LdaSolver::Svd,
349 shrinkage: None,
350 tolerance: 1e-4,
351 }
352 }
353}
354
/// Linear Discriminant Analysis estimator (typestate pattern; see [`PCA`]).
pub struct LDA<State = Untrained> {
    /// Configuration used when fitting.
    config: LDAConfig,
    /// Zero-sized typestate marker.
    state: PhantomData<State>,
}
360
/// A fitted LDA model.
///
/// NOTE(review): no `Fit` implementation for [`LDA`] is visible in this file,
/// so field semantics below are inferred from names — confirm against the
/// fitting code.
pub struct LDAFitted {
    /// Configuration the model was fitted with.
    config: LDAConfig,
    /// Discriminant directions (presumably one per row).
    components: Array2<f64>,
    /// Fraction of between-class variance per component — TODO confirm.
    explained_variance_ratio: Array1<f64>,
    /// Per-class feature means — TODO confirm layout.
    means: Array2<f64>,
    /// Class prior probabilities — TODO confirm.
    priors: Array1<f64>,
    /// Class labels seen during fitting.
    classes: Array1<usize>,
    /// Number of features seen during fitting.
    n_features: usize,
    /// Number of components retained.
    n_components: usize,
}
372
373impl LDA<Untrained> {
374 pub fn new(config: LDAConfig) -> Self {
376 Self {
377 config,
378 state: PhantomData,
379 }
380 }
381}
382
/// Configuration for Independent Component Analysis (ICA).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ICAConfig {
    /// Number of independent components to extract; `None` presumably uses
    /// all features — no `fit` is visible here to confirm.
    pub n_components: Option<usize>,
    /// ICA algorithm variant.
    pub algorithm: IcaAlgorithm,
    /// Contrast (non-linearity) function used by the algorithm.
    pub fun: IcaFunction,
    /// Iteration cap.
    pub max_iterations: usize,
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Whether to whiten the data before unmixing.
    pub whiten: bool,
    /// Seed for the random initialization.
    pub random_state: Option<u64>,
}
402
/// Available ICA algorithms.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IcaAlgorithm {
    /// Fixed-point FastICA algorithm.
    FastICA,
    /// Infomax (information-maximization) algorithm.
    Infomax,
}
412
/// Contrast functions for ICA's negentropy approximation.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IcaFunction {
    /// `log(cosh(x))` — a robust general-purpose choice.
    Logcosh,
    /// Exponential contrast function.
    Exp,
    /// Cubic (`x^3`) contrast function.
    Cube,
}
424
425impl Default for ICAConfig {
426 fn default() -> Self {
427 Self {
428 n_components: None,
429 algorithm: IcaAlgorithm::FastICA,
430 fun: IcaFunction::Logcosh,
431 max_iterations: 200,
432 tolerance: 1e-4,
433 whiten: true,
434 random_state: None,
435 }
436 }
437}
438
/// Independent Component Analysis estimator (typestate pattern; see [`PCA`]).
pub struct ICA<State = Untrained> {
    /// Configuration used when fitting.
    config: ICAConfig,
    /// Zero-sized typestate marker.
    state: PhantomData<State>,
}
444
/// A fitted ICA model.
///
/// NOTE(review): no `Fit` implementation for [`ICA`] is visible in this file,
/// so field semantics below are inferred from names — confirm against the
/// fitting code.
pub struct ICAFitted {
    /// Configuration the model was fitted with.
    config: ICAConfig,
    /// Unmixing components — TODO confirm orientation.
    components: Array2<f64>,
    /// Estimated mixing matrix (pseudo-inverse of the unmixing) — TODO confirm.
    mixing_matrix: Array2<f64>,
    /// Per-feature training mean — presumably subtracted before unmixing.
    mean: Array1<f64>,
    /// Whitening transform; `None` when fitted with `whiten = false` — TODO confirm.
    whitening_matrix: Option<Array2<f64>>,
    /// Number of features seen during fitting.
    n_features: usize,
    /// Number of components extracted.
    n_components: usize,
}
455
456impl ICA<Untrained> {
457 pub fn new(config: ICAConfig) -> Self {
459 Self {
460 config,
461 state: PhantomData,
462 }
463 }
464}
465
/// Configuration for Non-negative Matrix Factorization (NMF).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NMFConfig {
    /// Number of components (factorization rank). Unlike PCA/LDA/ICA this is
    /// required, not optional.
    pub n_components: usize,
    /// Strategy for initializing the factor matrices.
    pub init: NmfInit,
    /// Optimization algorithm.
    pub solver: NmfSolver,
    /// Regularization strength.
    pub alpha: f64,
    /// Mix between L1 and L2 regularization (0 = pure L2, 1 = pure L1) —
    /// TODO confirm against the solver implementation.
    pub l1_ratio: f64,
    /// Iteration cap.
    pub max_iterations: usize,
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Seed for random initialization.
    pub random_state: Option<u64>,
}
487
/// Initialization strategies for the NMF factor matrices.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NmfInit {
    /// Random non-negative initialization.
    Random,
    /// Non-negative double SVD initialization.
    Nndsvd,
    /// Caller-provided initial factors.
    Custom,
}
499
/// Optimization algorithms for NMF.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NmfSolver {
    /// Coordinate-descent updates.
    CoordinateDescent,
    /// Multiplicative-update rule.
    MultiplicativeUpdate,
}
509
510impl Default for NMFConfig {
511 fn default() -> Self {
512 Self {
513 n_components: 2,
514 init: NmfInit::Random,
515 solver: NmfSolver::CoordinateDescent,
516 alpha: 0.0,
517 l1_ratio: 0.0,
518 max_iterations: 200,
519 tolerance: 1e-4,
520 random_state: None,
521 }
522 }
523}
524
/// Non-negative Matrix Factorization estimator (typestate pattern; see [`PCA`]).
pub struct NMF<State = Untrained> {
    /// Configuration used when fitting.
    config: NMFConfig,
    /// Zero-sized typestate marker.
    state: PhantomData<State>,
}
530
/// A fitted NMF model.
///
/// NOTE(review): no `Fit` implementation for [`NMF`] is visible in this file,
/// so field semantics below are inferred from names — confirm against the
/// fitting code.
pub struct NMFFitted {
    /// Configuration the model was fitted with.
    config: NMFConfig,
    /// Factor matrix H (components) — TODO confirm orientation.
    components: Array2<f64>,
    /// Number of features seen during fitting.
    n_features: usize,
    /// Factorization rank used.
    n_components: usize,
    /// Final reconstruction error (presumably a Frobenius-norm residual) —
    /// TODO confirm.
    reconstruction_error: f64,
}
539
540impl NMF<Untrained> {
541 pub fn new(config: NMFConfig) -> Self {
543 Self {
544 config,
545 state: PhantomData,
546 }
547 }
548}
549
550fn perform_full_pca(
554 x: &Array2<f64>,
555 n_components: usize,
556) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>)> {
557 let (n_samples, n_features) = x.dim();
558
559 let cov_matrix = if n_samples > 1 {
561 x.t().dot(x) / (n_samples - 1) as f64
562 } else {
563 return Err(SklearsError::InvalidInput(
564 "Cannot compute covariance with only 1 sample".to_string(),
565 ));
566 };
567
568 let (eigenvalues, eigenvectors) = compute_eigen_decomposition(&cov_matrix)?;
571
572 let mut eigen_pairs: Vec<(f64, Array1<f64>)> = eigenvalues
574 .iter()
575 .zip(eigenvectors.axis_iter(Axis(1)))
576 .map(|(&val, vec)| (val, vec.to_owned()))
577 .collect();
578
579 eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
580
581 let selected_pairs: Vec<_> = eigen_pairs.into_iter().take(n_components).collect();
583
584 let mut components = Array2::zeros((n_components, n_features));
586 let mut explained_variance = Array1::zeros(n_components);
587
588 for (i, (eigenval, eigenvec)) in selected_pairs.iter().enumerate() {
589 explained_variance[i] = eigenval.max(0.0);
590 for (j, &val) in eigenvec.iter().enumerate() {
591 components[[i, j]] = val;
592 }
593 }
594
595 let singular_values = explained_variance.mapv(|x| (x * (n_samples - 1) as f64).sqrt());
597
598 Ok((components, explained_variance, singular_values))
599}
600
/// Randomized PCA solver.
///
/// NOTE(review): currently a placeholder that delegates to the full
/// eigendecomposition solver; `_random_state` is accepted but unused until a
/// true randomized projection (e.g. Halko-style) is implemented.
fn perform_randomized_pca(
    x: &Array2<f64>,
    n_components: usize,
    _random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>)> {
    perform_full_pca(x, n_components)
}
611
/// Power-iteration PCA solver.
///
/// NOTE(review): currently a placeholder that delegates to the full solver;
/// `_max_iterations` and `_tolerance` are accepted but unused until a real
/// deflated power-iteration scheme is implemented.
fn perform_power_iteration_pca(
    x: &Array2<f64>,
    n_components: usize,
    _max_iterations: usize,
    _tolerance: f64,
) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>)> {
    perform_full_pca(x, n_components)
}
623
624fn compute_eigen_decomposition(matrix: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
626 let n = matrix.nrows();
627
628 let mut eigenvalues = Array1::zeros(n);
631 for i in 0..n {
632 eigenvalues[i] = matrix[[i, i]];
633 }
634
635 let eigenvectors = Array2::eye(n);
637
638 Ok((eigenvalues, eigenvectors))
639}
640
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr2;

    // Builder methods should set exactly the fields they name.
    #[test]
    fn test_pca_config() {
        let config = PCAConfig::new(2)
            .with_solver(PcaSolver::Randomized)
            .with_center(false)
            .with_random_state(42)
            .with_tolerance(1e-6);

        assert_eq!(config.n_components, Some(2));
        assert!(!config.center);
        assert_eq!(config.random_state, Some(42));
        assert_relative_eq!(config.tolerance, 1e-6);
    }

    // The untrained estimator must expose the configuration it was built with.
    #[test]
    fn test_pca_creation() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);
        assert_eq!(pca.config().n_components, Some(2));
    }

    // Fitting 4 samples x 3 features keeping 2 components: check all shapes.
    #[test]
    fn test_pca_fit_basic() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);

        // Rank-1 data: every row is a multiple of [1, 2, 3].
        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ]);

        let result = pca.fit(&data, &());
        assert!(result.is_ok());

        let fitted = result.expect("operation should succeed");
        assert_eq!(fitted.n_components(), 2);
        assert_eq!(fitted.n_features(), 3);
        assert_eq!(fitted.components().dim(), (2, 3));
        assert_eq!(fitted.explained_variance().len(), 2);
    }

    // Transforming the training data must yield (n_samples, n_components).
    #[test]
    fn test_pca_transform() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);

        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ]);

        let fitted = pca.fit(&data, &()).expect("model fitting should succeed");
        let transformed = fitted
            .transform(&data)
            .expect("transformation should succeed");

        assert_eq!(transformed.dim(), (4, 2));
    }

    // Invalid inputs: empty array, a single sample, and too many components.
    #[test]
    fn test_pca_errors() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);
        let empty_data =
            Array2::from_shape_vec((0, 0), vec![]).expect("shape and data length should match");
        assert!(pca.fit(&empty_data, &()).is_err());

        let config = PCAConfig::new(2);
        let pca = PCA::new(config);
        let single_sample = arr2(&[[1.0, 2.0, 3.0]]);
        assert!(pca.fit(&single_sample, &()).is_err());

        // n_components larger than min(n_samples, n_features) must fail.
        let config = PCAConfig::new(10);
        let pca = PCA::new(config);
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(pca.fit(&data, &()).is_err());
    }

    // Transforming data with a different feature count must be rejected.
    #[test]
    fn test_pca_transform_dimension_mismatch() {
        let config = PCAConfig::new(1);
        let pca = PCA::new(config);

        let train_data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let fitted = pca
            .fit(&train_data, &())
            .expect("model fitting should succeed");

        let wrong_data = arr2(&[[1.0, 2.0, 3.0]]);
        assert!(fitted.transform(&wrong_data).is_err());
    }

    // With centering disabled, no mean should be stored on the fitted model.
    #[test]
    fn test_pca_without_centering() {
        let config = PCAConfig::new(1).with_center(false);
        let pca = PCA::new(config);

        let data = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let fitted = pca.fit(&data, &()).expect("model fitting should succeed");

        assert!(fitted.mean().is_none());
    }

    // Cumulative ratios must be monotonically non-decreasing and capped at 1.
    #[test]
    fn test_cumulative_explained_variance_ratio() {
        let config = PCAConfig::new(2);
        let pca = PCA::new(config);

        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ]);

        let fitted = pca.fit(&data, &()).expect("model fitting should succeed");
        let cumulative = fitted.cumulative_explained_variance_ratio();

        assert_eq!(cumulative.len(), 2);
        assert!(cumulative[1] >= cumulative[0]);
        assert!(cumulative[cumulative.len() - 1] <= 1.0);
    }
}