1use std::cmp;
2
3use itertools::izip;
4use lair::{decomposition::lu, Scalar};
5use ndarray::{s, Array1, Array2, ArrayBase, AssignElem, Axis, Data, Ix2, ScalarOperand};
6use num_traits::{real::Real, FromPrimitive};
7use rand::{Rng, RngCore, SeedableRng};
8use rand_distr::StandardNormal;
9#[cfg(target_pointer_width = "32")]
10use rand_pcg::Lcg64Xsh32 as Pcg;
11#[cfg(not(target_pointer_width = "32"))]
12use rand_pcg::Mcg128Xsl64 as Pcg;
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16use crate::linalg::{self, qr, svd, svddc, Lapack, LayoutError};
17use crate::DecompositionError;
18
19#[cfg_attr(
37 feature = "serde",
38 derive(Serialize, Deserialize),
39 serde(bound = "A: Serialize, for<'a> A: Deserialize<'a>")
40)]
41pub struct Pca<A>
42where
43 A: Scalar,
44{
45 components: Array2<A>,
46 n_samples: usize,
47 means: Array1<A>,
48 total_variance: A::Real,
49 singular: Array1<A::Real>,
50 centering: bool,
51}
52
53impl<A> Pca<A>
54where
55 A: Scalar,
56{
57 #[must_use]
59 pub fn new(n_components: usize) -> Self {
60 Self {
61 components: Array2::<A>::zeros((n_components, 0)),
62 n_samples: 0,
63 means: Array1::<A>::zeros(0),
64 total_variance: A::zero().re(),
65 singular: Array1::<A::Real>::zeros(0),
66 centering: true,
67 }
68 }
69}
70
71impl<A> Pca<A>
72where
73 A: FromPrimitive + Lapack,
74 A::Real: ScalarOperand,
75{
76 #[inline]
78 pub fn components(&self) -> &Array2<A> {
79 &self.components
80 }
81
82 #[inline]
84 pub fn mean(&self) -> &Array1<A> {
85 &self.means
86 }
87
88 #[inline]
90 pub fn n_components(&self) -> usize {
91 self.components.nrows()
92 }
93
94 #[inline]
96 pub fn singular_values(&self) -> &Array1<A::Real> {
97 &self.singular
98 }
99
100 pub fn explained_variance_ratio(&self) -> Array1<A::Real> {
102 let mut variance: Array1<A::Real> = &self.singular * &self.singular;
103 variance /= self.total_variance;
104 variance
105 }
106
107 pub fn fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<(), DecompositionError>
117 where
118 S: Data<Elem = A>,
119 {
120 self.inner_fit(input)?;
121 Ok(())
122 }
123
124 pub fn transform<S>(&self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
131 where
132 S: Data<Elem = A>,
133 {
134 transform(input, &self.components, &self.means, self.centering)
135 }
136
137 pub fn fit_transform<S>(
154 &mut self,
155 input: &ArrayBase<S, Ix2>,
156 ) -> Result<Array2<A>, DecompositionError>
157 where
158 S: Data<Elem = A>,
159 {
160 let u = self.inner_fit(input)?;
161 Ok(transform_with_u(
162 &u,
163 input,
164 self.singular_values(),
165 self.n_components(),
166 ))
167 }
168
169 pub fn inverse_transform<S>(
177 &self,
178 input: &ArrayBase<S, Ix2>,
179 ) -> Result<Array2<A>, DecompositionError>
180 where
181 S: Data<Elem = A>,
182 {
183 inverse_transform(input, &self.components, &self.means, self.centering)
184 }
185
186 fn inner_fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
196 where
197 S: Data<Elem = A>,
198 {
199 if input.shape().iter().any(|v| *v < self.n_components()) {
200 return Err(DecompositionError::InvalidInput(format!(
201 "every dimension should be at least {}",
202 self.n_components()
203 )));
204 }
205
206 let means = if self.centering {
207 if let Some(means) = input.mean_axis(Axis(0)) {
208 means
209 } else {
210 return Ok(Array2::<A>::zeros((0, input.ncols())));
211 }
212 } else {
213 Array1::zeros(input.ncols())
214 };
215
216 let (mut u, sigma, vt) = if self.centering {
217 svd(&mut (input - &means), true)?
218 } else {
219 svd(&mut input.to_owned(), true)?
220 };
221
222 let mut vt = vt.expect("`svd` should return `vt`");
223 svd_flip(&mut u, &mut vt);
224 self.total_variance = sigma.dot(&sigma);
225 self.components = vt.slice(s![0..self.n_components(), ..]).into_owned();
226 self.n_samples = input.nrows();
227 self.means = means;
228 self.singular = sigma.slice(s![0..self.n_components()]).into_owned();
229
230 Ok(u)
231 }
232}
233
/// Builder for [`Pca`] that collects the number of components and the
/// centering flag before the model is created.
// The name intentionally repeats the module name, hence the lint allowance.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone)]
pub struct PcaBuilder {
    /// Number of principal components the built model will keep.
    n_components: usize,
    /// Whether the built model subtracts column means from its input.
    centering: bool,
}
250
251impl PcaBuilder {
252 #[must_use]
254 pub fn new(n_components: usize) -> Self {
255 Self {
256 n_components,
257 centering: true,
258 }
259 }
260
261 #[must_use]
266 pub fn centering(mut self, centering: bool) -> Self {
267 self.centering = centering;
268 self
269 }
270
271 #[must_use]
273 pub fn build<A: Scalar>(self) -> Pca<A> {
274 Pca {
275 components: Array2::<A>::zeros((self.n_components, 0)),
276 n_samples: 0,
277 means: Array1::<A>::zeros(0),
278 total_variance: A::zero().re(),
279 singular: Array1::<A::Real>::zeros(0),
280 centering: self.centering,
281 }
282 }
283}
284
285#[cfg_attr(
310 feature = "serde",
311 derive(Serialize, Deserialize),
312 serde(
313 bound = "A: Serialize, for<'a> A: Deserialize<'a>, R: Serialize, for<'a> R: Deserialize<'a>"
314 )
315)]
316#[allow(clippy::module_name_repetitions)]
317pub struct RandomizedPca<A, R>
318where
319 A: Scalar,
320 R: Rng,
321{
322 rng: R,
323 components: Array2<A>,
324 n_samples: usize,
325 means: Array1<A>,
326 total_variance: A::Real,
327 singular: Array1<A::Real>,
328 centering: bool,
329}
330
331impl<A> RandomizedPca<A, Pcg>
332where
333 A: Scalar,
334{
335 #[must_use]
342 pub fn new(n_components: usize) -> Self {
343 let seed: u128 = rand::rng().random();
344 Self::with_seed(n_components, seed)
345 }
346
347 #[must_use]
356 pub fn with_seed(n_components: usize, seed: u128) -> Self {
357 let rng = Pcg::from_seed(seed.to_be_bytes());
358 Self::with_rng(n_components, rng)
359 }
360}
361
362impl<A, R> RandomizedPca<A, R>
363where
364 A: Scalar,
365 R: Rng,
366{
367 #[must_use]
371 pub fn with_rng(n_components: usize, rng: R) -> Self {
372 Self {
373 rng,
374 components: Array2::<A>::zeros((n_components, 0)),
375 n_samples: 0,
376 means: Array1::<A>::zeros(0),
377 total_variance: A::zero().re(),
378 singular: Array1::<A::Real>::zeros(0),
379 centering: true,
380 }
381 }
382}
383
384impl<A, R> RandomizedPca<A, R>
385where
386 A: Scalar + FromPrimitive + Lapack,
387 A::Real: ScalarOperand + FromPrimitive,
388 R: Rng,
389{
390 #[inline]
392 pub fn components(&self) -> &Array2<A> {
393 &self.components
394 }
395
396 #[inline]
398 pub fn mean(&self) -> &Array1<A> {
399 &self.means
400 }
401
402 #[inline]
404 pub fn n_components(&self) -> usize {
405 self.components.nrows()
406 }
407
408 #[inline]
410 pub fn singular_values(&self) -> &Array1<A::Real> {
411 &self.singular
412 }
413
414 pub fn explained_variance_ratio(&self) -> Array1<A::Real> {
416 let mut variance: Array1<A::Real> = &self.singular * &self.singular;
417 variance /= self.total_variance;
418 variance
419 }
420
421 pub fn fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<(), DecompositionError>
431 where
432 S: Data<Elem = A>,
433 {
434 self.inner_fit(input)?;
435 Ok(())
436 }
437
438 pub fn transform<S>(&self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
445 where
446 S: Data<Elem = A>,
447 {
448 transform(input, &self.components, &self.means, self.centering)
449 }
450
451 pub fn fit_transform<S>(
468 &mut self,
469 input: &ArrayBase<S, Ix2>,
470 ) -> Result<Array2<A>, DecompositionError>
471 where
472 S: Data<Elem = A>,
473 {
474 let u = self.inner_fit(input)?;
475 Ok(transform_with_u(
476 &u,
477 input,
478 self.singular_values(),
479 self.n_components(),
480 ))
481 }
482
483 pub fn inverse_transform<S>(
491 &self,
492 input: &ArrayBase<S, Ix2>,
493 ) -> Result<Array2<A>, DecompositionError>
494 where
495 S: Data<Elem = A>,
496 {
497 inverse_transform(input, &self.components, &self.means, self.centering)
498 }
499
500 fn inner_fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
510 where
511 S: Data<Elem = A>,
512 {
513 if input.shape().iter().any(|v| *v < self.n_components()) {
514 return Err(DecompositionError::InvalidInput(format!(
515 "every dimension should be at least {}",
516 self.n_components()
517 )));
518 }
519
520 let means = if self.centering {
521 if let Some(means) = input.mean_axis(Axis(0)) {
522 means
523 } else {
524 return Ok(Array2::<A>::zeros((0, input.ncols())));
525 }
526 } else {
527 Array1::zeros(input.ncols())
528 };
529
530 let (u, sigma, vt, total_variance) = if self.centering {
531 let x = input - &means;
532 let (u, sigma, vt) = randomized_svd(&x, self.n_components(), &mut self.rng)?;
533 let total_variance = x.iter().fold(A::zero().re(), |var, &e| var + e.square());
534 (u, sigma, vt, total_variance)
535 } else {
536 let (u, sigma, vt) = randomized_svd(input, self.n_components(), &mut self.rng)?;
537 let total_variance = input
538 .iter()
539 .fold(A::zero().re(), |var, &e| var + e.square());
540 (u, sigma, vt, total_variance)
541 };
542
543 self.total_variance = total_variance;
544 self.components = vt.slice(s![0..self.n_components(), ..]).into_owned();
545 self.n_samples = input.nrows();
546 self.means = means;
547 self.singular = sigma.slice(s![0..self.n_components()]).into_owned();
548
549 Ok(u)
550 }
551}
552
/// Builder for [`RandomizedPca`] that collects the number of components,
/// the random number generator, and the centering flag before the model is
/// created.
#[derive(Debug, Clone)]
pub struct RandomizedPcaBuilder<R> {
    /// Number of principal components the built model will keep.
    n_components: usize,
    /// Random number generator handed over to the built model.
    rng: R,
    /// Whether the built model subtracts column means from its input.
    centering: bool,
}
569
570impl RandomizedPcaBuilder<Pcg> {
571 #[must_use]
578 pub fn new(n_components: usize) -> Self {
579 let seed: u128 = rand::rng().random();
580 Self {
581 n_components,
582 rng: Pcg::from_seed(seed.to_be_bytes()),
583 centering: true,
584 }
585 }
586
587 #[must_use]
599 pub fn seed(mut self, seed: u128) -> Self {
600 self.rng = Pcg::from_seed(seed.to_be_bytes());
601 self
602 }
603
604 #[must_use]
619 pub fn centering(mut self, centering: bool) -> Self {
620 self.centering = centering;
621 self
622 }
623}
624
625impl<R: Rng> RandomizedPcaBuilder<R> {
626 #[must_use]
643 pub fn with_rng(rng: R, n_components: usize) -> Self {
644 Self {
645 n_components,
646 rng,
647 centering: true,
648 }
649 }
650
651 pub fn build<A: Scalar>(self) -> RandomizedPca<A, R> {
653 RandomizedPca {
654 rng: self.rng,
655 components: Array2::<A>::zeros((self.n_components, 0)),
656 n_samples: 0,
657 means: Array1::<A>::zeros(0),
658 total_variance: A::zero().re(),
659 singular: Array1::<A::Real>::zeros(0),
660 centering: self.centering,
661 }
662 }
663}
664
665type Svd<A> = (Array2<A>, Array1<<A as Scalar>::Real>, Array2<A>);
666
667fn randomized_svd<A, S, R>(
669 input: &ArrayBase<S, Ix2>,
670 n_components: usize,
671 rng: &mut R,
672) -> Result<Svd<A>, linalg::Error>
673where
674 A: Scalar + Lapack,
675 A::Real: FromPrimitive,
676 S: Data<Elem = A>,
677 R: RngCore,
678{
679 let n_random = n_components + 10; let q = randomized_range_finder(input, n_random, 7, rng)?;
681 let mut b = q.t().dot(input);
682 let (u, sigma, mut vt) = svddc(&mut b)?;
683 let mut u = q.dot(&u);
684 svd_flip(&mut u, &mut vt);
685 Ok((u, sigma, vt))
686}
687
688fn randomized_range_finder<A, S, R>(
690 input: &ArrayBase<S, Ix2>,
691 size: usize,
692 n_iter: usize,
693 rng: &mut R,
694) -> Result<Array2<A>, LayoutError>
695where
696 A: Scalar + Lapack,
697 A::Real: FromPrimitive,
698 S: Data<Elem = A>,
699 R: RngCore,
700{
701 let mut q = ArrayBase::from_shape_fn((input.ncols(), size), |_| {
702 let r = A::Real::from_f64(rng.sample(StandardNormal))
703 .expect("float to float conversion never fails");
704 r.into()
705 });
706 let mut pl = q.view();
707 q = input.dot(&pl);
708 for _ in 0..n_iter {
709 q = lu::Factorized::from(q).into_pl();
710 pl = q.slice(s![.., 0..cmp::min(q.nrows(), q.ncols())]);
711 q = input.t().dot(&pl);
712 q = lu::Factorized::from(q).into_pl();
713 pl = q.slice(s![.., 0..cmp::min(q.nrows(), q.ncols())]);
714 q = input.dot(&pl);
715 }
716 let q = qr(q)?;
717 Ok(q)
718}
719
720fn transform<A, S>(
727 input: &ArrayBase<S, Ix2>,
728 components: &Array2<A>,
729 means: &Array1<A>,
730 centering: bool,
731) -> Result<Array2<A>, DecompositionError>
732where
733 A: Scalar,
734 S: Data<Elem = A>,
735{
736 if input.ncols() != means.len() {
737 return Err(DecompositionError::InvalidInput(format!(
738 "# of columns should be {}",
739 means.len()
740 )));
741 }
742
743 let transformed = if centering {
744 let x = input - means;
745 x.dot(&components.t())
746 } else {
747 input.dot(&components.t())
748 };
749 Ok(transformed)
750}
751
752fn transform_with_u<A, S>(
759 u: &Array2<A>,
760 input: &ArrayBase<S, Ix2>,
761 singular: &Array1<A::Real>,
762 n_components: usize,
763) -> Array2<A>
764where
765 A: Scalar,
766 S: Data<Elem = A>,
767{
768 let mut y = Array2::<A>::uninit((input.nrows(), n_components));
769 for (y_row, u_row) in y
770 .lanes_mut(Axis(1))
771 .into_iter()
772 .zip(u.slice(s![.., 0..n_components]).lanes(Axis(1)))
773 {
774 for (y_v, u_v, sigma_v) in izip!(y_row.into_iter(), u_row, singular) {
775 y_v.assign_elem(*u_v * (*sigma_v).into());
776 }
777 }
778 unsafe { y.assume_init() }
779}
780
781fn inverse_transform<A, S>(
789 input: &ArrayBase<S, Ix2>,
790 components: &Array2<A>,
791 means: &Array1<A>,
792 centering: bool,
793) -> Result<Array2<A>, DecompositionError>
794where
795 A: Scalar,
796 S: Data<Elem = A>,
797{
798 if input.ncols() != components.nrows() {
799 return Err(DecompositionError::InvalidInput(format!(
800 "# of columns should be {}",
801 components.nrows()
802 )));
803 }
804
805 let inverse_transformed = if centering {
806 input.dot(components) + means
807 } else {
808 input.dot(components)
809 };
810 Ok(inverse_transformed)
811}
812
813fn svd_flip<A>(u: &mut Array2<A>, v: &mut Array2<A>)
816where
817 A: Scalar,
818{
819 for (u_col, v_row) in u.lanes_mut(Axis(0)).into_iter().zip(v.lanes_mut(Axis(1))) {
820 let mut u_col_iter = u_col.iter();
821 let e = if let Some(e) = u_col_iter.next() {
822 *e
823 } else {
824 continue;
825 };
826 let mut absmax = e.abs();
827 let mut signum = e.re().signum();
828 for e in u_col_iter {
829 let abs = e.abs();
830 if abs <= absmax {
831 continue;
832 }
833 absmax = abs;
834 signum = if e.re() == A::zero().re() {
835 e.im().signum()
836 } else {
837 e.re().signum()
838 };
839 }
840 if signum < A::zero().re() {
841 let signum = signum.into();
842 for e in u_col {
843 *e *= signum;
844 }
845 for e in v_row {
846 *e *= signum;
847 }
848 }
849 }
850}
851
#[cfg(test)]
mod test {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use ndarray::{arr2, Array2};
    use rand::Rng;
    use rand_distr::StandardNormal;
    use rand_pcg::Pcg64Mcg;

    /// Fixed seed so that the randomized tests using it are reproducible.
    const RNG_SEED: u128 = 1_234_567_891_011_121_314;

    // A model with zero components yields outputs with zero columns, both
    // for an input without rows and for a regular input.
    #[test]
    fn pca_zero_component() {
        let mut pca = super::PcaBuilder::new(0).build();

        let x = Array2::<f32>::zeros((0, 5));
        let y = pca.fit_transform(&x).unwrap();
        assert_eq!(y.nrows(), 0);
        assert_eq!(y.ncols(), 0);

        let x = arr2(&[[0_f32, 0_f32], [3_f32, 4_f32], [6_f32, 8_f32]]);
        let y = pca.fit_transform(&x).unwrap();
        assert_eq!(y.nrows(), 3);
        assert_eq!(y.ncols(), 0);
    }

    // A single sample equals its own mean, so its projection is zero.
    #[test]
    fn pca_single_sample() {
        let mut pca = super::Pca::new(1);
        let x = arr2(&[[1_f32, 1_f32]]);
        let y = pca.fit_transform(&x).unwrap();
        assert_eq!(y, arr2(&[[0.0]]));
    }

    // `fit_transform` and `fit` + `transform` agree on collinear points, and
    // `inverse_transform` restores the input.
    #[test]
    fn pca() {
        let x = arr2(&[[0_f64, 0_f64], [3_f64, 4_f64], [6_f64, 8_f64]]);
        let mut pca = super::Pca::new(1);
        assert_eq!(pca.n_components(), 1);

        let y = pca.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
        let z = pca.inverse_transform(&y).expect("valid input");
        assert!(z.abs_diff_eq(&x, 1e-10));

        let mut pca = super::Pca::new(1);
        assert!(pca.fit(&x).is_ok());
        assert_eq!(pca.n_components(), 1);
        assert!(pca.components().abs_diff_eq(&arr2(&[[-0.6, -0.8]]), 1e-10));
        let y = pca.transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
    }

    // Without centering the projection is measured from the origin.
    #[test]
    fn pca_without_centering() {
        let x = arr2(&[[0_f64, 0_f64], [3_f64, 4_f64], [6_f64, 8_f64]]);
        let mut pca = super::PcaBuilder::new(1).centering(false).build();
        let y = pca.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 10., epsilon = 1e-10);
    }

    #[test]
    fn pca_explained_variance_ratio() {
        let x = arr2(&[
            [-1_f64, -1_f64],
            [-2_f64, -1_f64],
            [-3_f64, -2_f64],
            [1_f64, 1_f64],
            [2_f64, 1_f64],
            [3_f64, 2_f64],
        ]);
        let mut pca = super::Pca::new(2);
        assert!(pca.fit(&x).is_ok());
        let ratio = pca.explained_variance_ratio();
        assert!(ratio.get(0).unwrap() > &0.99244);
        assert!(ratio.get(1).unwrap() < &0.00756);
    }

    // The fitted components and means survive a JSON round trip.
    #[test]
    #[cfg(feature = "serde")]
    fn pca_serialize() {
        let mut pca = super::Pca::new(1);
        let x = arr2(&[[1_f32, 1_f32]]);
        assert!(pca.fit(&x).is_ok());
        let serialized = serde_json::to_string(&pca).unwrap();
        let deserialized: super::Pca<f32> = serde_json::from_str(&serialized).unwrap();
        assert!(deserialized
            .components()
            .abs_diff_eq(pca.components(), 1e-12));
        // Same tight tolerance as for the components; the previous `1e12`
        // made this comparison vacuous.
        assert!(deserialized.mean().abs_diff_eq(pca.mean(), 1e-12));
    }

    #[test]
    fn randomized_pca() {
        let x = arr2(&[[0_f64, 0_f64], [3_f64, 4_f64], [6_f64, 8_f64]]);
        let mut pca = super::RandomizedPca::with_seed(1, RNG_SEED);
        assert_eq!(pca.n_components(), 1);

        let res = pca.fit(&x);
        assert!(res.is_ok());
        assert_eq!(pca.n_components(), 1);
        let y = pca.transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
        let z = pca.inverse_transform(&y).expect("valid input");
        assert!(z.abs_diff_eq(&x, 1e-10));

        let mut pca = super::RandomizedPca::with_rng(1, rand::rng());
        let y = pca.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
    }

    #[test]
    fn randomized_pca_explained_variance_ratio() {
        let x = arr2(&[
            [-1_f64, -1_f64],
            [-2_f64, -1_f64],
            [-3_f64, -2_f64],
            [1_f64, 1_f64],
            [2_f64, 1_f64],
            [3_f64, 2_f64],
        ]);
        let mut pca = super::RandomizedPca::with_rng(2, rand::rng());
        assert!(pca.fit(&x).is_ok());
        let ratio = pca.explained_variance_ratio();
        assert!(ratio.get(0).unwrap() > &0.99244);
        assert!(ratio.get(1).unwrap() < &0.00756);
    }

    // The randomized and the exact model report similar explained-variance
    // ratios on the same random matrix.
    #[test]
    fn randomized_pca_explained_variance_equivalence() {
        let mut rng = Pcg64Mcg::new(RNG_SEED);
        let x = Array2::from_shape_fn((100, 80), |_| rng.sample::<f64, _>(StandardNormal));

        let mut pca = super::Pca::new(2);
        let mut pca_rand = super::RandomizedPca::with_rng(2, rng);

        assert!(pca.fit(&x).is_ok());
        assert!(pca_rand.fit(&x).is_ok());

        for (a, b) in pca
            .explained_variance_ratio()
            .iter()
            .zip(pca_rand.explained_variance_ratio().iter())
        {
            assert_relative_eq!(a, b, max_relative = 0.05);
        }
    }

    // The randomized and the exact model report similar singular values on
    // the same random matrix.
    #[test]
    fn randomized_pca_singular_values_consistency() {
        let mut rng = Pcg64Mcg::new(RNG_SEED);
        let x = Array2::from_shape_fn((100, 80), |_| rng.sample::<f64, _>(StandardNormal));

        let mut pca = super::Pca::new(2);
        let mut pca_rand = super::RandomizedPca::with_rng(2, rng);

        assert!(pca.fit(&x).is_ok());
        assert!(pca_rand.fit(&x).is_ok());

        for (a, b) in pca
            .singular_values()
            .iter()
            .zip(pca_rand.singular_values().iter())
        {
            assert_relative_eq!(a, b, max_relative = 0.05);
        }
    }

    #[test]
    #[cfg(feature = "serde")]
    fn randomized_pca_serialize() {
        let mut pca = super::RandomizedPca::with_seed(1, RNG_SEED);
        let x = arr2(&[[1_f32, 1_f32]]);
        assert!(pca.fit(&x).is_ok());
        let serialized = serde_json::to_string(&pca).unwrap();
        // NOTE(review): this deserializes into `Pca`, not `RandomizedPca`, so
        // the `rng` field is never read back; confirm whether that is intended.
        let deserialized: super::Pca<f32> = serde_json::from_str(&serialized).unwrap();
        assert!(deserialized
            .components()
            .abs_diff_eq(pca.components(), 1e-12));
        // Same tight tolerance as for the components; the previous `1e12`
        // made this comparison vacuous.
        assert!(deserialized.mean().abs_diff_eq(pca.mean(), 1e-12));
    }

    // Only the column of `u` whose dominant entry is negative is flipped,
    // together with the matching row of `v`.
    #[test]
    fn svd_flip() {
        let mut u = arr2(&[[2., -1., 3.], [-1., -3., 2.]]);
        let mut v = arr2(&[[1., 1.], [-2., 2.], [3., -3.]]);
        super::svd_flip(&mut u, &mut v);
        assert_eq!(u, arr2(&[[2., 1., 3.], [-1., 3., 2.]]));
        assert_eq!(v, arr2(&[[1., 1.], [2., -2.], [3., -3.]]));
    }
}