petal_decomposition/
pca.rs

1use std::cmp;
2
3use itertools::izip;
4use lair::{decomposition::lu, Scalar};
5use ndarray::{s, Array1, Array2, ArrayBase, AssignElem, Axis, Data, Ix2, ScalarOperand};
6use num_traits::{real::Real, FromPrimitive};
7use rand::{Rng, RngCore, SeedableRng};
8use rand_distr::StandardNormal;
9#[cfg(target_pointer_width = "32")]
10use rand_pcg::Lcg64Xsh32 as Pcg;
11#[cfg(not(target_pointer_width = "32"))]
12use rand_pcg::Mcg128Xsl64 as Pcg;
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16use crate::linalg::{self, qr, svd, svddc, Lapack, LayoutError};
17use crate::DecompositionError;
18
19/// Principal component analysis.
20///
21/// This reduces the dimensionality of the input data using Singular Value
22/// Decomposition (SVD). The data is centered for each feature before applying
23/// SVD.
24///
25/// # Examples
26///
27/// ```
28/// use petal_decomposition::PcaBuilder;
29///
30/// let x = ndarray::arr2(&[[0_f64, 0_f64], [1_f64, 1_f64], [2_f64, 2_f64]]);
31/// let y = PcaBuilder::new(1).build().fit_transform(&x).unwrap();  // [-2_f64.sqrt(), 0_f64, 2_f64.sqrt()]
32/// assert!((y[(0, 0)].abs() - 2_f64.sqrt()).abs() < 1e-8);
33/// assert!(y[(1, 0)].abs() < 1e-8);
34/// assert!((y[(2, 0)].abs() - 2_f64.sqrt()).abs() < 1e-8);
35/// ```
36#[cfg_attr(
37    feature = "serde",
38    derive(Serialize, Deserialize),
39    serde(bound = "A: Serialize, for<'a> A: Deserialize<'a>")
40)]
41pub struct Pca<A>
42where
43    A: Scalar,
44{
45    components: Array2<A>,
46    n_samples: usize,
47    means: Array1<A>,
48    total_variance: A::Real,
49    singular: Array1<A::Real>,
50    centering: bool,
51}
52
53impl<A> Pca<A>
54where
55    A: Scalar,
56{
57    /// Creates a PCA model with the given number of components.
58    #[must_use]
59    pub fn new(n_components: usize) -> Self {
60        Self {
61            components: Array2::<A>::zeros((n_components, 0)),
62            n_samples: 0,
63            means: Array1::<A>::zeros(0),
64            total_variance: A::zero().re(),
65            singular: Array1::<A::Real>::zeros(0),
66            centering: true,
67        }
68    }
69}
70
71impl<A> Pca<A>
72where
73    A: FromPrimitive + Lapack,
74    A::Real: ScalarOperand,
75{
76    /// Returns the principal axes in feature space.
77    #[inline]
78    pub fn components(&self) -> &Array2<A> {
79        &self.components
80    }
81
82    /// Returns the per-feature empirical mean.
83    #[inline]
84    pub fn mean(&self) -> &Array1<A> {
85        &self.means
86    }
87
88    /// Returns the number of components.
89    #[inline]
90    pub fn n_components(&self) -> usize {
91        self.components.nrows()
92    }
93
94    /// Returns sigular values.
95    #[inline]
96    pub fn singular_values(&self) -> &Array1<A::Real> {
97        &self.singular
98    }
99
100    /// Returns the ratio of explained variance for each component.
101    pub fn explained_variance_ratio(&self) -> Array1<A::Real> {
102        let mut variance: Array1<A::Real> = &self.singular * &self.singular;
103        variance /= self.total_variance;
104        variance
105    }
106
107    /// Fits the model with `input`.
108    ///
109    /// # Errors
110    ///
111    /// * [`DecompositionError::InvalidInput`] if any of the dimensions of
112    ///   `input` is less than the number of components, or the layout of
113    ///   `input` is incompatible with LAPACK.
114    /// * [`DecompositionError::LinalgError`] if the underlying Singular Vector
115    ///   Decomposition routine fails.
116    pub fn fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<(), DecompositionError>
117    where
118        S: Data<Elem = A>,
119    {
120        self.inner_fit(input)?;
121        Ok(())
122    }
123
124    /// Applies dimensionality reduction to `input`.
125    ///
126    /// # Errors
127    ///
128    /// * `DecompositionError::InvalidInput` if the number of features in
129    ///   `input` does not match that of the training data.
130    pub fn transform<S>(&self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
131    where
132        S: Data<Elem = A>,
133    {
134        transform(input, &self.components, &self.means, self.centering)
135    }
136
137    /// Fits the model with `input` and apply the dimensionality reduction on
138    /// `input`.
139    ///
140    /// This is equivalent to calling both [`fit`] and [`transform`] for the
141    /// same input, but more efficient.
142    ///
143    /// [`fit`]: #method.fit
144    /// [`transform`]: #method.transform
145    ///
146    /// # Errors
147    ///
148    /// * [`DecompositionError::InvalidInput`] if any of the dimensions of
149    ///   `input` is less than the number of components, or the layout of
150    ///   `input` is incompatible with LAPACK.
151    /// * [`DecompositionError::LinalgError`] if the underlying Singular Vector
152    ///   Decomposition routine fails.
153    pub fn fit_transform<S>(
154        &mut self,
155        input: &ArrayBase<S, Ix2>,
156    ) -> Result<Array2<A>, DecompositionError>
157    where
158        S: Data<Elem = A>,
159    {
160        let u = self.inner_fit(input)?;
161        Ok(transform_with_u(
162            &u,
163            input,
164            self.singular_values(),
165            self.n_components(),
166        ))
167    }
168
169    /// Transforms data back to its original space.
170    ///
171    /// # Errors
172    ///
173    /// Returns [`DecompositionError::InvalidInput`] if the number of rows of
174    /// `input` is different from that of the training data, or the number of
175    /// columns of `input` is different from the number of components.
176    pub fn inverse_transform<S>(
177        &self,
178        input: &ArrayBase<S, Ix2>,
179    ) -> Result<Array2<A>, DecompositionError>
180    where
181        S: Data<Elem = A>,
182    {
183        inverse_transform(input, &self.components, &self.means, self.centering)
184    }
185
186    /// Fits the model with `input`.
187    ///
188    /// # Errors
189    ///
190    /// * [`DecompositionError::InvalidInput`] if any of the dimensions of
191    ///   `input` is less than the number of components, or the layout of
192    ///   `input` is incompatible with LAPACK.
193    /// * [`DecompositionError::LinalgError`] if the underlying Singular Vector
194    ///   Decomposition routine fails.
195    fn inner_fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
196    where
197        S: Data<Elem = A>,
198    {
199        if input.shape().iter().any(|v| *v < self.n_components()) {
200            return Err(DecompositionError::InvalidInput(format!(
201                "every dimension should be at least {}",
202                self.n_components()
203            )));
204        }
205
206        let means = if self.centering {
207            if let Some(means) = input.mean_axis(Axis(0)) {
208                means
209            } else {
210                return Ok(Array2::<A>::zeros((0, input.ncols())));
211            }
212        } else {
213            Array1::zeros(input.ncols())
214        };
215
216        let (mut u, sigma, vt) = if self.centering {
217            svd(&mut (input - &means), true)?
218        } else {
219            svd(&mut input.to_owned(), true)?
220        };
221
222        let mut vt = vt.expect("`svd` should return `vt`");
223        svd_flip(&mut u, &mut vt);
224        self.total_variance = sigma.dot(&sigma);
225        self.components = vt.slice(s![0..self.n_components(), ..]).into_owned();
226        self.n_samples = input.nrows();
227        self.means = means;
228        self.singular = sigma.slice(s![0..self.n_components()]).into_owned();
229
230        Ok(u)
231    }
232}
233
/// Builder for [`Pca`].
///
/// Holds the number of components and the centering option, and creates a
/// [`Pca`] model with [`PcaBuilder::build`].
///
/// # Examples
///
/// ```
/// use petal_decomposition::PcaBuilder;
///
/// let x = ndarray::arr2(&[[0_f64, 0_f64], [1_f64, 1_f64]]);
/// let mut pca = PcaBuilder::new(1).build();
/// pca.fit(&x);
/// ```
// The name deliberately repeats the module name (`pca`) so it pairs with `Pca`.
#[allow(clippy::module_name_repetitions)]
pub struct PcaBuilder {
    /// Number of principal components the built model keeps.
    n_components: usize,
    /// Whether the built model mean-centers its inputs.
    centering: bool,
}
250
251impl PcaBuilder {
252    /// Sets the number of components for PCA.
253    #[must_use]
254    pub fn new(n_components: usize) -> Self {
255        Self {
256            n_components,
257            centering: true,
258        }
259    }
260
261    /// Indicates whether or not to perform mean-centering on input data. It is
262    /// enabled by default. If the inputs are already centered, set `centering`
263    /// to `false`. *Note* [`Pca::mean()`] will return an [`Array1`] of 0's if
264    /// `centering` is `false`.
265    #[must_use]
266    pub fn centering(mut self, centering: bool) -> Self {
267        self.centering = centering;
268        self
269    }
270
271    /// Creates an instance of [`Pca`].
272    #[must_use]
273    pub fn build<A: Scalar>(self) -> Pca<A> {
274        Pca {
275            components: Array2::<A>::zeros((self.n_components, 0)),
276            n_samples: 0,
277            means: Array1::<A>::zeros(0),
278            total_variance: A::zero().re(),
279            singular: Array1::<A::Real>::zeros(0),
280            centering: self.centering,
281        }
282    }
283}
284
285/// Principal component analysis using randomized singular value decomposition.
286///
287/// This uses randomized SVD (singular value decomposition) proposed by Halko et
288/// al. \[1\] to reduce the dimensionality of the input data. The data is
289/// centered for each feature before applying randomized SVD.
290///
291/// # Examples
292///
293/// ```
294/// use petal_decomposition::RandomizedPcaBuilder;
295///
296/// let x = ndarray::arr2(&[[0_f64, 0_f64], [1_f64, 1_f64], [2_f64, 2_f64]]);
297/// let mut pca = RandomizedPcaBuilder::new(1).build();
298/// let y = pca.fit_transform(&x).unwrap();  // [-2_f64.sqrt(), 0_f64, 2_f64.sqrt()]
299/// assert!((y[(0, 0)].abs() - 2_f64.sqrt()).abs() < 1e-8);
300/// assert!(y[(1, 0)].abs() < 1e-8);
301/// assert!((y[(2, 0)].abs() - 2_f64.sqrt()).abs() < 1e-8);
302/// ```
303///
304/// # References
305///
306/// 1. N. Halko, P. G. Martinsson, and J. A. Tropp. Finding Structure with
307///    Randomness: Probabilistic Algorithms for Constructing Approximate Matrix
308///    Decompositions. _SIAM Review,_ 53(2), 217–288, 2011.
309#[cfg_attr(
310    feature = "serde",
311    derive(Serialize, Deserialize),
312    serde(
313        bound = "A: Serialize, for<'a> A: Deserialize<'a>, R: Serialize, for<'a> R: Deserialize<'a>"
314    )
315)]
316#[allow(clippy::module_name_repetitions)]
317pub struct RandomizedPca<A, R>
318where
319    A: Scalar,
320    R: Rng,
321{
322    rng: R,
323    components: Array2<A>,
324    n_samples: usize,
325    means: Array1<A>,
326    total_variance: A::Real,
327    singular: Array1<A::Real>,
328    centering: bool,
329}
330
331impl<A> RandomizedPca<A, Pcg>
332where
333    A: Scalar,
334{
335    /// Creates a PCA model based on randomized SVD.
336    ///
337    /// The random matrix for randomized SVD is created from a PCG random number
338    /// generator (the XSL 128/64 (MCG) variant on a 64-bit CPU and the XSH RR
339    /// 64/32 (LCG) variant on a 32-bit CPU), initialized with a
340    /// randomly-generated seed.
341    #[must_use]
342    pub fn new(n_components: usize) -> Self {
343        let seed: u128 = rand::rng().random();
344        Self::with_seed(n_components, seed)
345    }
346
347    /// Creates a PCA model based on randomized SVD, with a PCG random number
348    /// generator initialized with the given seed.
349    ///
350    /// It uses a PCG random number generator (the XSL 128/64 (MCG) variant on a
351    /// 64-bit CPU and the XSH RR 64/32 (LCG) variant on a 32-bit CPU). Use
352    /// [`with_rng`] for a different random number generator.
353    ///
354    /// [`with_rng`]: #method.with_rng
355    #[must_use]
356    pub fn with_seed(n_components: usize, seed: u128) -> Self {
357        let rng = Pcg::from_seed(seed.to_be_bytes());
358        Self::with_rng(n_components, rng)
359    }
360}
361
362impl<A, R> RandomizedPca<A, R>
363where
364    A: Scalar,
365    R: Rng,
366{
367    /// Creates a PCA model with the given number of components and random
368    /// number generator. The random number generator is used to create a random
369    /// matrix for randomized SVD.
370    #[must_use]
371    pub fn with_rng(n_components: usize, rng: R) -> Self {
372        Self {
373            rng,
374            components: Array2::<A>::zeros((n_components, 0)),
375            n_samples: 0,
376            means: Array1::<A>::zeros(0),
377            total_variance: A::zero().re(),
378            singular: Array1::<A::Real>::zeros(0),
379            centering: true,
380        }
381    }
382}
383
384impl<A, R> RandomizedPca<A, R>
385where
386    A: Scalar + FromPrimitive + Lapack,
387    A::Real: ScalarOperand + FromPrimitive,
388    R: Rng,
389{
390    /// Returns the principal axes in feature space.
391    #[inline]
392    pub fn components(&self) -> &Array2<A> {
393        &self.components
394    }
395
396    /// Returns the per-feature empirical mean.
397    #[inline]
398    pub fn mean(&self) -> &Array1<A> {
399        &self.means
400    }
401
402    /// Returns the number of components.
403    #[inline]
404    pub fn n_components(&self) -> usize {
405        self.components.nrows()
406    }
407
408    /// Returns sigular values.
409    #[inline]
410    pub fn singular_values(&self) -> &Array1<A::Real> {
411        &self.singular
412    }
413
414    /// Returns the ratio of explained variance for each component.
415    pub fn explained_variance_ratio(&self) -> Array1<A::Real> {
416        let mut variance: Array1<A::Real> = &self.singular * &self.singular;
417        variance /= self.total_variance;
418        variance
419    }
420
421    /// Fits the model with `input`.
422    ///
423    /// # Errors
424    ///
425    /// * [`DecompositionError::InvalidInput`] if any of the dimensions of
426    ///   `input` is less than the number of components, or the layout of
427    ///   `input` is incompatible with LAPACK.
428    /// * [`DecompositionError::LinalgError`] if the underlying Singular Vector
429    ///   Decomposition routine fails.
430    pub fn fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<(), DecompositionError>
431    where
432        S: Data<Elem = A>,
433    {
434        self.inner_fit(input)?;
435        Ok(())
436    }
437
438    /// Applies dimensionality reduction to `input`.
439    ///
440    /// # Errors
441    ///
442    /// * [`DecompositionError::InvalidInput`] if the number of features in
443    ///   `input` does not match that of the training data.
444    pub fn transform<S>(&self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
445    where
446        S: Data<Elem = A>,
447    {
448        transform(input, &self.components, &self.means, self.centering)
449    }
450
451    /// Fits the model with `input` and apply the dimensionality reduction on
452    /// `input`.
453    ///
454    /// This is equivalent to calling both [`fit`] and [`transform`] for the
455    /// same input.
456    ///
457    /// [`fit`]: #method.fit
458    /// [`transform`]: #method.transform
459    ///
460    /// # Errors
461    ///
462    /// * [`DecompositionError::InvalidInput`] if any of the dimensions of
463    ///   `input` is less than the number of components, or the layout of
464    ///   `input` is incompatible with LAPACK.
465    /// * [`DecompositionError::LinalgError`] if the underlying Singular Vector
466    ///   Decomposition routine fails.
467    pub fn fit_transform<S>(
468        &mut self,
469        input: &ArrayBase<S, Ix2>,
470    ) -> Result<Array2<A>, DecompositionError>
471    where
472        S: Data<Elem = A>,
473    {
474        let u = self.inner_fit(input)?;
475        Ok(transform_with_u(
476            &u,
477            input,
478            self.singular_values(),
479            self.n_components(),
480        ))
481    }
482
483    /// Transforms data back to its original space.
484    ///
485    /// # Errors
486    ///
487    /// Returns [`DecompositionError::InvalidInput`] if the number of rows of
488    /// `input` is different from that of the training data, or the number of
489    /// columns of `input` is different from the number of components.
490    pub fn inverse_transform<S>(
491        &self,
492        input: &ArrayBase<S, Ix2>,
493    ) -> Result<Array2<A>, DecompositionError>
494    where
495        S: Data<Elem = A>,
496    {
497        inverse_transform(input, &self.components, &self.means, self.centering)
498    }
499
500    /// Fits the model with `input`.
501    ///
502    /// # Errors
503    ///
504    /// * [`DecompositionError::InvalidInput`] if any of the dimensions of
505    ///   `input` is less than the number of components, or the layout of
506    ///   `input` is incompatible with LAPACK.
507    /// * [`DecompositionError::LinalgError`] if the underlying Singular Vector
508    ///   Decomposition routine fails.
509    fn inner_fit<S>(&mut self, input: &ArrayBase<S, Ix2>) -> Result<Array2<A>, DecompositionError>
510    where
511        S: Data<Elem = A>,
512    {
513        if input.shape().iter().any(|v| *v < self.n_components()) {
514            return Err(DecompositionError::InvalidInput(format!(
515                "every dimension should be at least {}",
516                self.n_components()
517            )));
518        }
519
520        let means = if self.centering {
521            if let Some(means) = input.mean_axis(Axis(0)) {
522                means
523            } else {
524                return Ok(Array2::<A>::zeros((0, input.ncols())));
525            }
526        } else {
527            Array1::zeros(input.ncols())
528        };
529
530        let (u, sigma, vt, total_variance) = if self.centering {
531            let x = input - &means;
532            let (u, sigma, vt) = randomized_svd(&x, self.n_components(), &mut self.rng)?;
533            let total_variance = x.iter().fold(A::zero().re(), |var, &e| var + e.square());
534            (u, sigma, vt, total_variance)
535        } else {
536            let (u, sigma, vt) = randomized_svd(input, self.n_components(), &mut self.rng)?;
537            let total_variance = input
538                .iter()
539                .fold(A::zero().re(), |var, &e| var + e.square());
540            (u, sigma, vt, total_variance)
541        };
542
543        self.total_variance = total_variance;
544        self.components = vt.slice(s![0..self.n_components(), ..]).into_owned();
545        self.n_samples = input.nrows();
546        self.means = means;
547        self.singular = sigma.slice(s![0..self.n_components()]).into_owned();
548
549        Ok(u)
550    }
551}
552
/// Builder for [`RandomizedPca`].
///
/// Holds the number of components, the random number generator and the
/// centering option, and creates a [`RandomizedPca`] model with `build`.
///
/// # Examples
///
/// ```
/// use petal_decomposition::RandomizedPcaBuilder;
///
/// let x = ndarray::arr2(&[[0_f64, 0_f64], [1_f64, 1_f64]]);
/// let mut pca = RandomizedPcaBuilder::new(1).build();
/// pca.fit(&x);
/// ```
pub struct RandomizedPcaBuilder<R> {
    /// Number of principal components the built model keeps.
    n_components: usize,
    /// Generator handed to the built model for its random projection matrix.
    rng: R,
    /// Whether the built model mean-centers its inputs.
    centering: bool,
}
569
570impl RandomizedPcaBuilder<Pcg> {
571    /// Sets the number of components for PCA.
572    ///
573    /// The random matrix for randomized SVD is created from a PCG random number
574    /// generator (the XSL 128/64 (MCG) variant on a 64-bit CPU and the XSH RR
575    /// 64/32 (LCG) variant on a 32-bit CPU), initialized with a
576    /// randomly-generated seed.
577    #[must_use]
578    pub fn new(n_components: usize) -> Self {
579        let seed: u128 = rand::rng().random();
580        Self {
581            n_components,
582            rng: Pcg::from_seed(seed.to_be_bytes()),
583            centering: true,
584        }
585    }
586
587    /// Initialized the PCG random number genernator with the given seed.
588    ///
589    /// # Examples
590    ///
591    /// ```
592    /// use petal_decomposition::RandomizedPcaBuilder;
593    ///
594    /// let x = ndarray::arr2(&[[0_f64, 0_f64], [1_f64, 1_f64]]);
595    /// let mut pca = RandomizedPcaBuilder::new(1).seed(1234567891011121314).build();
596    /// pca.fit(&x);
597    /// ```
598    #[must_use]
599    pub fn seed(mut self, seed: u128) -> Self {
600        self.rng = Pcg::from_seed(seed.to_be_bytes());
601        self
602    }
603
604    /// Indicates whether or not to perform mean-centering on input data. It is
605    /// enabled by default. If the inputs are already centered, set `centering`
606    /// to `false`. *Note* [`Pca::mean()`] will return an [`Array1`] of 0's if
607    /// `centering` is `false`.
608    ///
609    /// # Examples
610    ///
611    /// ```
612    /// use petal_decomposition::RandomizedPcaBuilder;
613    ///
614    /// let x = ndarray::arr2(&[[0_f64, 0_f64], [1_f64, 1_f64]]);
615    /// let mut pca = RandomizedPcaBuilder::new(1).centering(false).build();
616    /// pca.fit(&x);
617    /// ```
618    #[must_use]
619    pub fn centering(mut self, centering: bool) -> Self {
620        self.centering = centering;
621        self
622    }
623}
624
625impl<R: Rng> RandomizedPcaBuilder<R> {
626    /// Sets the number of components and random number generator for PCA.
627    ///
628    /// The random number generator is used to create a random matrix for
629    /// randomized SVD.
630    ///
631    /// # Examples
632    ///
633    /// ```
634    /// use petal_decomposition::RandomizedPcaBuilder;
635    /// use rand_pcg::Pcg64;
636    ///
637    /// let x = ndarray::arr2(&[[0_f64, 0_f64], [1_f64, 1_f64]]);
638    /// let rng = Pcg64::new(0xcafef00dd15ea5e5, 0xa02bdbf7bb3c0a7ac28fa16a64abf96);
639    /// let mut pca = RandomizedPcaBuilder::with_rng(rng, 1).build();
640    /// pca.fit(&x);
641    /// ```
642    #[must_use]
643    pub fn with_rng(rng: R, n_components: usize) -> Self {
644        Self {
645            n_components,
646            rng,
647            centering: true,
648        }
649    }
650
651    /// Creates an instance of [`RandomizedPca`].
652    pub fn build<A: Scalar>(self) -> RandomizedPca<A, R> {
653        RandomizedPca {
654            rng: self.rng,
655            components: Array2::<A>::zeros((self.n_components, 0)),
656            n_samples: 0,
657            means: Array1::<A>::zeros(0),
658            total_variance: A::zero().re(),
659            singular: Array1::<A::Real>::zeros(0),
660            centering: self.centering,
661        }
662    }
663}
664
665type Svd<A> = (Array2<A>, Array1<<A as Scalar>::Real>, Array2<A>);
666
667/// Computes a truncated randomized SVD
668fn randomized_svd<A, S, R>(
669    input: &ArrayBase<S, Ix2>,
670    n_components: usize,
671    rng: &mut R,
672) -> Result<Svd<A>, linalg::Error>
673where
674    A: Scalar + Lapack,
675    A::Real: FromPrimitive,
676    S: Data<Elem = A>,
677    R: RngCore,
678{
679    let n_random = n_components + 10; // oversample by 10
680    let q = randomized_range_finder(input, n_random, 7, rng)?;
681    let mut b = q.t().dot(input);
682    let (u, sigma, mut vt) = svddc(&mut b)?;
683    let mut u = q.dot(&u);
684    svd_flip(&mut u, &mut vt);
685    Ok((u, sigma, vt))
686}
687
688/// Computes an orthonormal matrix whose range approximates the range of `input`.
689fn randomized_range_finder<A, S, R>(
690    input: &ArrayBase<S, Ix2>,
691    size: usize,
692    n_iter: usize,
693    rng: &mut R,
694) -> Result<Array2<A>, LayoutError>
695where
696    A: Scalar + Lapack,
697    A::Real: FromPrimitive,
698    S: Data<Elem = A>,
699    R: RngCore,
700{
701    let mut q = ArrayBase::from_shape_fn((input.ncols(), size), |_| {
702        let r = A::Real::from_f64(rng.sample(StandardNormal))
703            .expect("float to float conversion never fails");
704        r.into()
705    });
706    let mut pl = q.view();
707    q = input.dot(&pl);
708    for _ in 0..n_iter {
709        q = lu::Factorized::from(q).into_pl();
710        pl = q.slice(s![.., 0..cmp::min(q.nrows(), q.ncols())]);
711        q = input.t().dot(&pl);
712        q = lu::Factorized::from(q).into_pl();
713        pl = q.slice(s![.., 0..cmp::min(q.nrows(), q.ncols())]);
714        q = input.dot(&pl);
715    }
716    let q = qr(q)?;
717    Ok(q)
718}
719
720/// Applies dimensionality reduction to `input`.
721///
722/// # Errors
723///
724/// * [`DecompositionError::InvalidInput`] if the number of features in `input`
725///   does not match that of the training data.
726fn transform<A, S>(
727    input: &ArrayBase<S, Ix2>,
728    components: &Array2<A>,
729    means: &Array1<A>,
730    centering: bool,
731) -> Result<Array2<A>, DecompositionError>
732where
733    A: Scalar,
734    S: Data<Elem = A>,
735{
736    if input.ncols() != means.len() {
737        return Err(DecompositionError::InvalidInput(format!(
738            "# of columns should be {}",
739            means.len()
740        )));
741    }
742
743    let transformed = if centering {
744        let x = input - means;
745        x.dot(&components.t())
746    } else {
747        input.dot(&components.t())
748    };
749    Ok(transformed)
750}
751
752/// Applies dimensionality reduction to `input`, given matrix `u`.
753///
754/// # Errors
755///
756/// Returns [`DecompositionError::LinalgError`] if the underlying Singular
757/// Vector Decomposition routine fails.
758fn transform_with_u<A, S>(
759    u: &Array2<A>,
760    input: &ArrayBase<S, Ix2>,
761    singular: &Array1<A::Real>,
762    n_components: usize,
763) -> Array2<A>
764where
765    A: Scalar,
766    S: Data<Elem = A>,
767{
768    let mut y = Array2::<A>::uninit((input.nrows(), n_components));
769    for (y_row, u_row) in y
770        .lanes_mut(Axis(1))
771        .into_iter()
772        .zip(u.slice(s![.., 0..n_components]).lanes(Axis(1)))
773    {
774        for (y_v, u_v, sigma_v) in izip!(y_row.into_iter(), u_row, singular) {
775            y_v.assign_elem(*u_v * (*sigma_v).into());
776        }
777    }
778    unsafe { y.assume_init() }
779}
780
781/// Transforms data back to its original space.
782///
783/// # Errors
784///
785/// Returns [`DecompositionError::InvalidInput`] if the number of rows of
786/// `input` is different from that of the training data, or the number of
787/// columns of `input` is different from the number of components.
788fn inverse_transform<A, S>(
789    input: &ArrayBase<S, Ix2>,
790    components: &Array2<A>,
791    means: &Array1<A>,
792    centering: bool,
793) -> Result<Array2<A>, DecompositionError>
794where
795    A: Scalar,
796    S: Data<Elem = A>,
797{
798    if input.ncols() != components.nrows() {
799        return Err(DecompositionError::InvalidInput(format!(
800            "# of columns should be {}",
801            components.nrows()
802        )));
803    }
804
805    let inverse_transformed = if centering {
806        input.dot(components) + means
807    } else {
808        input.dot(components)
809    };
810    Ok(inverse_transformed)
811}
812
813/// Makes `SVD`'s output deterministic using the columns of `u` as the basis for
814/// sign flipping.
815fn svd_flip<A>(u: &mut Array2<A>, v: &mut Array2<A>)
816where
817    A: Scalar,
818{
819    for (u_col, v_row) in u.lanes_mut(Axis(0)).into_iter().zip(v.lanes_mut(Axis(1))) {
820        let mut u_col_iter = u_col.iter();
821        let e = if let Some(e) = u_col_iter.next() {
822            *e
823        } else {
824            continue;
825        };
826        let mut absmax = e.abs();
827        let mut signum = e.re().signum();
828        for e in u_col_iter {
829            let abs = e.abs();
830            if abs <= absmax {
831                continue;
832            }
833            absmax = abs;
834            signum = if e.re() == A::zero().re() {
835                e.im().signum()
836            } else {
837                e.re().signum()
838            };
839        }
840        if signum < A::zero().re() {
841            let signum = signum.into();
842            for e in u_col {
843                *e *= signum;
844            }
845            for e in v_row {
846                *e *= signum;
847            }
848        }
849    }
850}
851
#[cfg(test)]
mod test {
    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use ndarray::{arr2, Array2};
    use rand::Rng;
    use rand_distr::StandardNormal;
    use rand_pcg::Pcg64Mcg;

    const RNG_SEED: u128 = 1_234_567_891_011_121_314;

    #[test]
    fn pca_zero_component() {
        let mut pca = super::PcaBuilder::new(0).build();

        let x = Array2::<f32>::zeros((0, 5));
        let y = pca.fit_transform(&x).unwrap();
        assert_eq!(y.nrows(), 0);
        assert_eq!(y.ncols(), 0);

        let x = arr2(&[[0_f32, 0_f32], [3_f32, 4_f32], [6_f32, 8_f32]]);
        let y = pca.fit_transform(&x).unwrap();
        assert_eq!(y.nrows(), 3);
        assert_eq!(y.ncols(), 0);
    }

    #[test]
    fn pca_single_sample() {
        let mut pca = super::Pca::new(1);
        let x = arr2(&[[1_f32, 1_f32]]);
        let y = pca.fit_transform(&x).unwrap();
        assert_eq!(y, arr2(&[[0.0]]));
    }

    #[test]
    fn pca() {
        let x = arr2(&[[0_f64, 0_f64], [3_f64, 4_f64], [6_f64, 8_f64]]);
        let mut pca = super::Pca::new(1);
        assert_eq!(pca.n_components(), 1);

        let y = pca.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
        let z = pca.inverse_transform(&y).expect("valid input");
        assert!(z.abs_diff_eq(&x, 1e-10));

        let mut pca = super::Pca::new(1);
        assert!(pca.fit(&x).is_ok());
        assert_eq!(pca.n_components(), 1);
        assert!(pca.components().abs_diff_eq(&arr2(&[[-0.6, -0.8]]), 1e-10));
        let y = pca.transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
    }

    #[test]
    fn pca_without_centering() {
        let x = arr2(&[[0_f64, 0_f64], [3_f64, 4_f64], [6_f64, 8_f64]]);
        let mut pca = super::PcaBuilder::new(1).centering(false).build();
        let y = pca.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 10., epsilon = 1e-10);
    }

    #[test]
    fn pca_explained_variance_ratio() {
        let x = arr2(&[
            [-1_f64, -1_f64],
            [-2_f64, -1_f64],
            [-3_f64, -2_f64],
            [1_f64, 1_f64],
            [2_f64, 1_f64],
            [3_f64, 2_f64],
        ]);
        let mut pca = super::Pca::new(2);
        assert!(pca.fit(&x).is_ok());
        let ratio = pca.explained_variance_ratio();
        assert!(ratio.get(0).unwrap() > &0.99244);
        assert!(ratio.get(1).unwrap() < &0.00756);
    }

    #[test]
    #[cfg(feature = "serde")]
    fn pca_serialize() {
        let mut pca = super::Pca::new(1);
        let x = arr2(&[[1_f32, 1_f32]]);
        assert!(pca.fit(&x).is_ok());
        let serialized = serde_json::to_string(&pca).unwrap();
        let deserialized: super::Pca<f32> = serde_json::from_str(&serialized).unwrap();
        assert!(deserialized
            .components()
            .abs_diff_eq(pca.components(), 1e-12));
        // A tiny epsilon: a huge one (e.g. `1e12`) would accept any means.
        assert!(deserialized.mean().abs_diff_eq(pca.mean(), 1e-12));
    }

    #[test]
    fn randomized_pca() {
        let x = arr2(&[[0_f64, 0_f64], [3_f64, 4_f64], [6_f64, 8_f64]]);
        let mut pca = super::RandomizedPca::with_seed(1, RNG_SEED);
        assert_eq!(pca.n_components(), 1);

        let res = pca.fit(&x);
        assert!(res.is_ok());
        assert_eq!(pca.n_components(), 1);
        let y = pca.transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
        let z = pca.inverse_transform(&y).expect("valid input");
        assert!(z.abs_diff_eq(&x, 1e-10));

        let mut pca = super::RandomizedPca::with_rng(1, rand::rng());
        let y = pca.fit_transform(&x).unwrap();
        assert_abs_diff_eq!(y[(0, 0)].abs(), 5., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(1, 0)], 0., epsilon = 1e-10);
        assert_abs_diff_eq!(y[(2, 0)].abs(), 5., epsilon = 1e-10);
    }

    #[test]
    fn randomized_pca_explained_variance_ratio() {
        let x = arr2(&[
            [-1_f64, -1_f64],
            [-2_f64, -1_f64],
            [-3_f64, -2_f64],
            [1_f64, 1_f64],
            [2_f64, 1_f64],
            [3_f64, 2_f64],
        ]);
        let mut pca = super::RandomizedPca::with_rng(2, rand::rng());
        assert!(pca.fit(&x).is_ok());
        let ratio = pca.explained_variance_ratio();
        assert!(ratio.get(0).unwrap() > &0.99244);
        assert!(ratio.get(1).unwrap() < &0.00756);
    }

    #[test]
    fn randomized_pca_explained_variance_equivalence() {
        let mut rng = Pcg64Mcg::new(RNG_SEED);
        let x = Array2::from_shape_fn((100, 80), |_| rng.sample::<f64, _>(StandardNormal));

        let mut pca = super::Pca::new(2);
        let mut pca_rand = super::RandomizedPca::with_rng(2, rng);

        assert!(pca.fit(&x).is_ok());
        assert!(pca_rand.fit(&x).is_ok());

        for (a, b) in pca
            .explained_variance_ratio()
            .iter()
            .zip(pca_rand.explained_variance_ratio().iter())
        {
            assert_relative_eq!(a, b, max_relative = 0.05);
        }
    }

    #[test]
    fn randomized_pca_singular_values_consistency() {
        let mut rng = Pcg64Mcg::new(RNG_SEED);
        let x = Array2::from_shape_fn((100, 80), |_| rng.sample::<f64, _>(StandardNormal));

        let mut pca = super::Pca::new(2);
        let mut pca_rand = super::RandomizedPca::with_rng(2, rng);

        assert!(pca.fit(&x).is_ok());
        assert!(pca_rand.fit(&x).is_ok());

        for (a, b) in pca
            .singular_values()
            .iter()
            .zip(pca_rand.singular_values().iter())
        {
            assert_relative_eq!(a, b, max_relative = 0.05);
        }
    }

    #[test]
    #[cfg(feature = "serde")]
    fn randomized_pca_serialize() {
        let mut pca = super::RandomizedPca::with_seed(1, RNG_SEED);
        let x = arr2(&[[1_f32, 1_f32]]);
        assert!(pca.fit(&x).is_ok());
        let serialized = serde_json::to_string(&pca).unwrap();
        // Round-trip into the same type that was serialized, RNG included.
        let deserialized: super::RandomizedPca<f32, super::Pcg> =
            serde_json::from_str(&serialized).unwrap();
        assert!(deserialized
            .components()
            .abs_diff_eq(pca.components(), 1e-12));
        assert!(deserialized.mean().abs_diff_eq(pca.mean(), 1e-12));
    }

    #[test]
    fn svd_flip() {
        let mut u = arr2(&[[2., -1., 3.], [-1., -3., 2.]]);
        let mut v = arr2(&[[1., 1.], [-2., 2.], [3., -3.]]);
        super::svd_flip(&mut u, &mut v);
        assert_eq!(u, arr2(&[[2., 1., 3.], [-1., 3., 2.]]));
        assert_eq!(v, arr2(&[[1., 1.], [2., -2.], [3., -3.]]));
    }
}