Skip to main content

scirs2_stats/multivariate/
pca.rs

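//! Principal Component Analysis (PCA)
//!
//! PCA is a dimensionality reduction technique that finds the directions of maximum variance
//! in high-dimensional data and projects the data onto a lower-dimensional subspace.
//!
//! This module provides [`PCA`] for in-memory fitting, [`IncrementalPCA`] for fitting batch
//! by batch, and [`mle_components`] for estimating how many components to keep.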
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
8use scirs2_core::validation::*;
9
/// Principal Component Analysis configuration.
///
/// Build one with [`PCA::new`] and the `with_*` builder methods, then call
/// `fit`, `transform`, `inverse_transform` or `fit_transform`.
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of components to keep; `None` keeps `min(n_samples, n_features)`.
    pub n_components: Option<usize>,
    /// Which SVD solver to use: full, randomized, or chosen automatically.
    pub svd_solver: SvdSolver,
    /// Whether to subtract the per-feature mean before decomposing.
    pub center: bool,
    /// Whether to divide each feature by its standard deviation (unit variance).
    pub scale: bool,
    /// Seed for the randomized solver; `None` seeds from the system clock.
    pub random_state: Option<u64>,
}
24
/// Strategy used to compute the singular value decomposition in PCA.
///
/// The enum is fieldless, so it is `Copy` and also `Eq`/`Hash`. This lets it
/// be used as a map or set key and compared exhaustively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SvdSolver {
    /// Exact SVD of the full (preprocessed) data matrix.
    Full,
    /// Randomized SVD, an approximation that is cheaper for large datasets.
    Randomized,
    /// Let `PCA::fit` pick `Full` or `Randomized` from the data size.
    Auto,
}
35
/// Result of fitting a [`PCA`] model.
///
/// Pass it back to `PCA::transform` / `PCA::inverse_transform` to project data.
#[derive(Debug, Clone)]
pub struct PCAResult {
    /// Principal axes, one per row (`n_components x n_features`).
    pub components: Array2<f64>,
    /// Variance along each component: `singular_value^2 / (n_samples - 1)`.
    pub explained_variance: Array1<f64>,
    /// Proportion of variance attributed to each retained component.
    pub explained_variance_ratio: Array1<f64>,
    /// Singular values of the preprocessed (centered/scaled) data for each retained component.
    pub singular_values: Array1<f64>,
    /// Per-feature mean of the training data (all zeros when centering was disabled).
    pub mean: Array1<f64>,
    /// Per-feature standard deviation (ddof = 1) used for scaling, with values at or
    /// below 1e-10 replaced by 1.0; `None` when scaling was disabled.
    pub scale: Option<Array1<f64>>,
    /// Number of samples used for fitting (the trailing underscore is part of the public name).
    pub n_samples_: usize,
    /// Number of features in the training data.
    pub n_features: usize,
}
56
57impl Default for PCA {
58    fn default() -> Self {
59        Self {
60            n_components: None,
61            svd_solver: SvdSolver::Auto,
62            center: true,
63            scale: false,
64            random_state: None,
65        }
66    }
67}
68
69impl PCA {
70    /// Create a new PCA instance
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    /// Set the number of components to keep
76    pub fn with_n_components(mut self, n_components: usize) -> Self {
77        self.n_components = Some(n_components);
78        self
79    }
80
81    /// Set the SVD solver
82    pub fn with_svd_solver(mut self, solver: SvdSolver) -> Self {
83        self.svd_solver = solver;
84        self
85    }
86
87    /// Enable or disable centering
88    pub fn with_center(mut self, center: bool) -> Self {
89        self.center = center;
90        self
91    }
92
93    /// Enable or disable scaling
94    pub fn with_scale(mut self, scale: bool) -> Self {
95        self.scale = scale;
96        self
97    }
98
99    /// Set random state for reproducibility
100    pub fn with_random_state(mut self, seed: u64) -> Self {
101        self.random_state = Some(seed);
102        self
103    }
104
105    /// Fit the PCA model to the data
106    pub fn fit(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
107        checkarray_finite(&data, "data")?;
108        let (n_samples, n_features) = data.dim();
109        if n_samples < 2 {
110            return Err(StatsError::InvalidArgument(
111                "n_samples must be at least 2".to_string(),
112            ));
113        }
114        if n_features < 1 {
115            return Err(StatsError::InvalidArgument(
116                "n_features must be at least 1".to_string(),
117            ));
118        }
119
120        // Determine number of components
121        let max_components = n_samples.min(n_features);
122        let n_components = match self.n_components {
123            Some(k) => {
124                check_positive(k, "n_components")?;
125                if k > max_components {
126                    return Err(StatsError::InvalidArgument(format!(
127                        "n_components ({}) cannot be larger than min(n_samples, n_features) = {}",
128                        k, max_components
129                    )));
130                }
131                k
132            }
133            None => max_components,
134        };
135
136        // Center the data
137        let mean = if self.center {
138            data.mean_axis(Axis(0)).expect("Operation failed")
139        } else {
140            Array1::zeros(n_features)
141        };
142
143        let mut centereddata = data.to_owned();
144        if self.center {
145            for mut row in centereddata.rows_mut() {
146                row -= &mean;
147            }
148        }
149
150        // Scale the data
151        let scale = if self.scale {
152            let std = centereddata.std_axis(Axis(0), 1.0);
153            // Avoid division by zero
154            let std = std.mapv(|s| if s > 1e-10 { s } else { 1.0 });
155
156            for (mut col, &s) in centereddata.columns_mut().into_iter().zip(std.iter()) {
157                col /= s;
158            }
159            Some(std)
160        } else {
161            None
162        };
163
164        // Choose solver
165        let solver = match self.svd_solver {
166            SvdSolver::Auto => {
167                if n_samples >= 500 && n_features >= 500 && n_components < max_components / 2 {
168                    SvdSolver::Randomized
169                } else {
170                    SvdSolver::Full
171                }
172            }
173            solver => solver,
174        };
175
176        // Perform PCA
177        let result = match solver {
178            SvdSolver::Full => self.pca_svd(&centereddata, n_components, n_samples)?,
179            SvdSolver::Randomized => self.pca_randomized(&centereddata, n_components, n_samples)?,
180            _ => unreachable!(),
181        };
182
183        Ok(PCAResult {
184            components: result.0,
185            explained_variance: result.1,
186            explained_variance_ratio: result.2,
187            singular_values: result.3,
188            mean,
189            scale,
190            n_samples_: n_samples,
191            n_features,
192        })
193    }
194
195    /// Perform PCA using SVD
196    fn pca_svd(
197        &self,
198        data: &Array2<f64>,
199        n_components: usize,
200        n_samples: usize,
201    ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
202        // Perform SVD: X = U * S * V^T using scirs2_linalg
203        let (_u, s, vt) = scirs2_linalg::svd(&data.view(), true, None)
204            .map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
205        let v = vt.t().to_owned();
206
207        // Extract _components
208        let components = v
209            .slice(scirs2_core::ndarray::s![.., ..n_components])
210            .to_owned();
211
212        // Compute explained variance
213        let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
214        let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
215
216        // Compute explained variance ratio
217        let total_variance = explained_variance.sum();
218        let explained_variance_ratio = &explained_variance / total_variance;
219
220        Ok((
221            components.t().to_owned(),
222            explained_variance,
223            explained_variance_ratio,
224            singular_values,
225        ))
226    }
227
228    /// Perform PCA using randomized SVD
229    fn pca_randomized(
230        &self,
231        data: &Array2<f64>,
232        n_components: usize,
233        n_samples: usize,
234    ) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
235        use scirs2_core::random::{rngs::StdRng, SeedableRng};
236        use scirs2_core::random::{Distribution, Normal};
237
238        let n_features = data.ncols();
239        let n_oversamples = 10.min((n_features - n_components) / 2);
240        let n_random = n_components + n_oversamples;
241
242        // Initialize RNG
243        let mut rng = match self.random_state {
244            Some(seed) => StdRng::seed_from_u64(seed),
245            None => {
246                // Use a simple fallback seed based on current time or a fixed seed
247                use std::time::{SystemTime, UNIX_EPOCH};
248                let seed = SystemTime::now()
249                    .duration_since(UNIX_EPOCH)
250                    .unwrap_or_default()
251                    .as_secs();
252                StdRng::seed_from_u64(seed)
253            }
254        };
255
256        // Generate random matrix
257        let normal = Normal::new(0.0, 1.0).map_err(|e| {
258            StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
259        })?;
260        let omega = Array2::from_shape_fn((n_features, n_random), |_| normal.sample(&mut rng));
261
262        // Power iterations for better approximation
263        let n_iter = 4;
264        let mut q = data.dot(&omega);
265
266        for _ in 0..n_iter {
267            // QR decomposition using scirs2_linalg
268            let (q_mat, _r) = scirs2_linalg::qr(&q.view(), None).map_err(|e| {
269                StatsError::ComputationError(format!("QR decomposition failed: {}", e))
270            })?;
271            q = q_mat;
272
273            // Project back
274            let z = data.t().dot(&q);
275            let (q_mat, _r) = scirs2_linalg::qr(&z.view(), None).map_err(|e| {
276                StatsError::ComputationError(format!("QR decomposition failed: {}", e))
277            })?;
278            q = data.dot(&q_mat);
279        }
280
281        // Final QR decomposition using scirs2_linalg
282        let (q_final, _r) = scirs2_linalg::qr(&q.view(), None).map_err(|e| {
283            StatsError::ComputationError(format!("Final QR decomposition failed: {}", e))
284        })?;
285
286        // Project data onto subspace
287        let b = q_final.t().dot(data);
288
289        // SVD of small matrix B using scirs2_linalg
290        let (_u_small, s, vt) = scirs2_linalg::svd(&b.view(), true, None).map_err(|e| {
291            StatsError::ComputationError(format!("SVD of projected matrix failed: {}", e))
292        })?;
293
294        let v = vt.t().to_owned();
295
296        // Extract _components
297        let components = v
298            .slice(scirs2_core::ndarray::s![.., ..n_components])
299            .to_owned();
300
301        // Compute explained variance
302        let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
303        let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
304
305        // Compute explained variance ratio
306        let total_variance = explained_variance.sum();
307        let explained_variance_ratio = &explained_variance / total_variance;
308
309        Ok((
310            components.t().to_owned(),
311            explained_variance,
312            explained_variance_ratio,
313            singular_values,
314        ))
315    }
316
317    /// Transform data using the fitted PCA model
318    pub fn transform(&self, data: ArrayView2<f64>, result: &PCAResult) -> Result<Array2<f64>> {
319        checkarray_finite(&data, "data")?;
320        if data.ncols() != result.n_features {
321            return Err(StatsError::DimensionMismatch(format!(
322                "data has {} features, expected {}",
323                data.ncols(),
324                result.n_features
325            )));
326        }
327
328        let mut transformed = data.to_owned();
329
330        // Center
331        if self.center {
332            for mut row in transformed.rows_mut() {
333                row -= &result.mean;
334            }
335        }
336
337        // Scale
338        if let Some(ref scale) = result.scale {
339            for (mut col, &s) in transformed.columns_mut().into_iter().zip(scale.iter()) {
340                col /= s;
341            }
342        }
343
344        // Project onto components
345        Ok(transformed.dot(&result.components.t()))
346    }
347
348    /// Inverse transform from component space back to original space
349    pub fn inverse_transform(
350        &self,
351        data: ArrayView2<f64>,
352        result: &PCAResult,
353    ) -> Result<Array2<f64>> {
354        checkarray_finite(&data, "data")?;
355        let n_components = result.components.nrows();
356        if data.ncols() != n_components {
357            return Err(StatsError::DimensionMismatch(format!(
358                "data has {} components, expected {}",
359                data.ncols(),
360                n_components
361            )));
362        }
363
364        // Project back to original space
365        let mut reconstructed = data.dot(&result.components);
366
367        // Inverse scale
368        if let Some(ref scale) = result.scale {
369            for (mut col, &s) in reconstructed.columns_mut().into_iter().zip(scale.iter()) {
370                col *= s;
371            }
372        }
373
374        // Add mean back
375        if self.center {
376            for mut row in reconstructed.rows_mut() {
377                row += &result.mean;
378            }
379        }
380
381        Ok(reconstructed)
382    }
383
384    /// Fit and transform in one step
385    pub fn fit_transform(&self, data: ArrayView2<f64>) -> Result<(Array2<f64>, PCAResult)> {
386        let result = self.fit(data)?;
387        let transformed = self.transform(data, &result)?;
388        Ok((transformed, result))
389    }
390}
391
392/// Compute the optimal number of components using Minka's MLE
393#[allow(dead_code)]
394pub fn mle_components(data: ArrayView2<f64>, maxcomponents: Option<usize>) -> Result<usize> {
395    checkarray_finite(&data, "data")?;
396    let (n_samples, n_features) = data.dim();
397
398    let pca = PCA::new().with_n_components(maxcomponents.unwrap_or(n_features.min(n_samples)));
399    let result = pca.fit(data)?;
400
401    let eigenvalues = &result.explained_variance;
402    let n = n_samples as f64;
403    let p = n_features as f64;
404
405    // Minka's MLE for PCA
406    let mut best_k = 0;
407    let mut best_ll = f64::NEG_INFINITY;
408
409    for k in 0..eigenvalues.len() {
410        let k_f64 = k as f64;
411
412        // Average of remaining eigenvalues
413        let sigma2 = if k < eigenvalues.len() - 1 {
414            eigenvalues.slice(scirs2_core::ndarray::s![k + 1..]).sum() / (p - k_f64 - 1.0)
415        } else {
416            1e-10
417        };
418
419        // Log-likelihood
420        let ll = -n / 2.0
421            * (eigenvalues
422                .slice(scirs2_core::ndarray::s![..=k])
423                .mapv(f64::ln)
424                .sum()
425                + (p - k_f64 - 1.0) * sigma2.ln()
426                + p * (2.0 * std::f64::consts::PI).ln());
427
428        // AIC penalty
429        let aic_penalty = k_f64 * (2.0 * p - k_f64 - 1.0);
430        let aic = ll - aic_penalty;
431
432        if aic > best_ll {
433            best_ll = aic;
434            best_k = k + 1;
435        }
436    }
437
438    Ok(best_k)
439}
440
441/// Incremental PCA for large datasets that don't fit in memory
442#[derive(Debug, Clone)]
443pub struct IncrementalPCA {
444    /// Base PCA configuration
445    pub pca: PCA,
446    /// Batch size for incremental updates
447    pub batchsize: usize,
448    /// Running mean
449    mean: Option<Array1<f64>>,
450    /// Running components
451    components: Option<Array2<f64>>,
452    /// Singular values
453    singular_values: Option<Array1<f64>>,
454    /// Number of samples seen
455    n_samples_seen: usize,
456    /// Incremental SVD state
457    svd_u: Option<Array2<f64>>,
458    svd_s: Option<Array1<f64>>,
459    svd_v: Option<Array2<f64>>,
460}
461
462impl IncrementalPCA {
463    /// Create a new incremental PCA instance
464    pub fn new(n_components: usize, batchsize: usize) -> Result<Self> {
465        check_positive(n_components, "n_components")?;
466        check_positive(batchsize, "batchsize")?;
467
468        Ok(Self {
469            pca: PCA::new().with_n_components(n_components),
470            batchsize,
471            mean: None,
472            components: None,
473            singular_values: None,
474            n_samples_seen: 0,
475            svd_u: None,
476            svd_s: None,
477            svd_v: None,
478        })
479    }
480
481    /// Partial fit on a batch of data
482    pub fn partial_fit(&mut self, batch: ArrayView2<f64>) -> Result<()> {
483        checkarray_finite(&batch, "batch")?;
484        let (batchsize, n_features) = batch.dim();
485
486        // Update mean incrementally
487        let batch_mean = batch.mean_axis(Axis(0)).expect("Operation failed");
488        let old_n = self.n_samples_seen;
489        self.n_samples_seen += batchsize;
490
491        self.mean = match &self.mean {
492            None => Some(batch_mean.clone()),
493            Some(mean) => {
494                let updated = (mean * old_n as f64 + &batch_mean * batchsize as f64)
495                    / self.n_samples_seen as f64;
496                Some(updated)
497            }
498        };
499
500        // Center the batch
501        let mut centered_batch = batch.to_owned();
502        for mut row in centered_batch.rows_mut() {
503            row -= &batch_mean;
504        }
505
506        // Incremental SVD update using Brand's algorithm
507        let n_components = self
508            .pca
509            .n_components
510            .unwrap_or(n_features.min(self.n_samples_seen));
511
512        if self.svd_u.is_none() {
513            // First batch - initialize with standard SVD using scirs2_linalg
514            let (u, s, vt) = scirs2_linalg::svd(&centered_batch.view(), true, None)
515                .map_err(|e| StatsError::ComputationError(format!("Initial SVD failed: {}", e)))?;
516
517            // Keep only n_components
518            self.svd_u = Some(
519                u.slice(scirs2_core::ndarray::s![.., ..n_components])
520                    .to_owned(),
521            );
522            self.svd_s = Some(s.slice(scirs2_core::ndarray::s![..n_components]).to_owned());
523            self.svd_v = Some(
524                vt.slice(scirs2_core::ndarray::s![..n_components, ..])
525                    .t()
526                    .to_owned(),
527            );
528
529            self.components = Some(
530                self.svd_v
531                    .as_ref()
532                    .expect("Operation failed")
533                    .t()
534                    .to_owned(),
535            );
536            self.singular_values = Some(self.svd_s.as_ref().expect("Operation failed").clone());
537        } else {
538            // Incremental update
539            let u_old = self.svd_u.as_ref().expect("Operation failed");
540            let s_old = self.svd_s.as_ref().expect("Operation failed");
541            let v_old = self.svd_v.as_ref().expect("Operation failed");
542
543            // Project new data onto existing components
544            let projection = centered_batch.dot(v_old);
545            let residual = &centered_batch - &projection.dot(&v_old.t());
546
547            // QR decomposition of residual using scirs2_linalg
548            let (q_res, r_res) = scirs2_linalg::qr(&residual.view(), None).map_err(|e| {
549                StatsError::ComputationError(format!("QR decomposition failed: {}", e))
550            })?;
551
552            // Build augmented matrix
553            let k = s_old.len();
554            let p = r_res.ncols();
555
556            // Create block matrix [diag(s_old), projection^T; 0, r_res]
557            let mut augmented = Array2::zeros((k + p, k + p));
558            for i in 0..k {
559                augmented[[i, i]] = s_old[i];
560            }
561            for i in 0..projection.nrows() {
562                for j in 0..k {
563                    augmented[[j, k + i]] = projection[[i, j]];
564                }
565            }
566            for i in 0..p {
567                for j in 0..p {
568                    augmented[[k + i, k + j]] = r_res[[i, j]];
569                }
570            }
571
572            // SVD of augmented matrix using scirs2_linalg
573            let (u_aug, s_aug, vt_aug) = scirs2_linalg::svd(&augmented.view(), true, None)
574                .map_err(|e| {
575                    StatsError::ComputationError(format!("Augmented SVD failed: {}", e))
576                })?;
577
578            // Update U
579            let mut u_new = Array2::zeros((old_n + batchsize, n_components));
580            let u_aug_slice = u_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
581
582            // Update old samples part
583            let u_old_part = u_old.dot(&u_aug_slice.t());
584            u_new
585                .slice_mut(scirs2_core::ndarray::s![..old_n, ..])
586                .assign(&u_old_part);
587
588            // Update new samples part
589            let u_batch_part =
590                projection.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
591            let u_res_part = q_res.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
592            u_new
593                .slice_mut(scirs2_core::ndarray::s![old_n.., ..])
594                .assign(&(&u_batch_part + &u_res_part));
595
596            // Update singular values
597            self.svd_s = Some(
598                s_aug
599                    .slice(scirs2_core::ndarray::s![..n_components])
600                    .to_owned(),
601            );
602
603            // Update V
604            let v_aug_slice =
605                vt_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
606            let mut v_new = Array2::zeros((n_features, n_components));
607
608            let v_old_part = v_old.dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
609            let v_res_part = q_res
610                .t()
611                .dot(&centered_batch)
612                .t()
613                .dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
614            v_new.assign(&(&v_old_part + &v_res_part));
615
616            self.svd_u = Some(u_new);
617            self.svd_v = Some(v_new.clone());
618            self.components = Some(v_new.t().to_owned());
619            self.singular_values = Some(self.svd_s.as_ref().expect("Operation failed").clone());
620        }
621
622        Ok(())
623    }
624
625    /// Transform new data
626    pub fn transform(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
627        if self.components.is_none() || self.mean.is_none() {
628            return Err(StatsError::ComputationError(
629                "IncrementalPCA must be fitted before transform".to_string(),
630            ));
631        }
632
633        let mut centered = data.to_owned();
634        for mut row in centered.rows_mut() {
635            row -= self.mean.as_ref().expect("Operation failed");
636        }
637
638        Ok(centered.dot(&self.components.as_ref().expect("Operation failed").t()))
639    }
640}