Skip to main content

flow_utils/pca/
mod.rs

1//! Principal Component Analysis (PCA) module
2
3use ndarray::{s, Array2, Axis};
4use linfa_linalg::svd::SVD;
5use thiserror::Error;
6
7/// Error type for PCA operations
8#[derive(Error, Debug)]
9pub enum PcaError {
10    #[error("Empty data")]
11    EmptyData,
12    #[error("Insufficient data: need at least {min} points, got {actual}")]
13    InsufficientData { min: usize, actual: usize },
14    #[error("Invalid number of components: {0}")]
15    InvalidComponents(String),
16    #[error("SVD decomposition failed: {0}")]
17    SvdFailed(String),
18}
19
20pub type PcaResult<T> = Result<T, PcaError>;
21
22/// Principal Component Analysis
23#[derive(Debug)]
24pub struct Pca {
25    /// Number of components
26    n_components: usize,
27    /// Principal components (eigenvectors)
28    components: Array2<f64>,
29    /// Explained variance ratio for each component
30    explained_variance_ratio: Vec<f64>,
31    /// Mean of the input data (for centering)
32    mean: ndarray::Array1<f64>,
33}
34
35impl Pca {
36    /// Create a new PCA instance
37    ///
38    /// # Arguments
39    /// * `n_components` - Number of components to keep
40    pub fn new(n_components: usize) -> Self {
41        Self {
42            n_components,
43            components: Array2::zeros((0, 0)),
44            explained_variance_ratio: Vec::new(),
45            mean: ndarray::Array1::zeros(0),
46        }
47    }
48
49    /// Fit PCA to data
50    ///
51    /// # Arguments
52    /// * `data` - Input data matrix (n_samples × n_features)
53    ///
54    /// # Returns
55    /// Fitted Pca instance
56    pub fn fit(mut self, data: &Array2<f64>) -> PcaResult<Self> {
57        if data.nrows() == 0 {
58            return Err(PcaError::EmptyData);
59        }
60
61        let n_samples = data.nrows();
62        let n_features = data.ncols();
63
64        if n_samples < 2 {
65            return Err(PcaError::InsufficientData {
66                min: 2,
67                actual: n_samples,
68            });
69        }
70
71        // Center the data
72        let mean = data.mean_axis(Axis(0))
73            .ok_or_else(|| PcaError::SvdFailed("Failed to calculate mean".to_string()))?;
74        let mut centered = data.clone();
75        for mut row in centered.rows_mut() {
76            row -= &mean;
77        }
78
79        // Perform SVD using linfa-linalg (compatible with ndarray 0.16)
80        // SVD returns (Option<U>, S, Option<Vt>) tuple
81        let (u_opt, s, vt_opt) = centered
82            .svd(true, true)
83            .map_err(|e| PcaError::SvdFailed(format!("SVD failed: {:?}", e)))?;
84        
85        let _u = u_opt.ok_or_else(|| PcaError::SvdFailed("U matrix not available".to_string()))?;
86        let vt = vt_opt.ok_or_else(|| PcaError::SvdFailed("Vt matrix not available".to_string()))?;
87
88        // Extract components (right singular vectors, transposed)
89        // vt is already an Array2, not an Option
90        let components = vt;
91
92        // Calculate explained variance ratio
93        let s_squared: Vec<f64> = s.iter().map(|&val| val * val).collect();
94        let total_variance: f64 = s_squared.iter().sum();
95        let explained_variance_ratio: Vec<f64> = s_squared
96            .iter()
97            .map(|&val| val / total_variance)
98            .collect();
99
100        // Limit to n_components
101        let n_components = self.n_components.min(n_features);
102        let components = components.slice(s![..n_components, ..]).to_owned();
103        let explained_variance_ratio = explained_variance_ratio[..n_components].to_vec();
104
105        self.n_components = n_components;
106        self.components = components;
107        self.explained_variance_ratio = explained_variance_ratio;
108        self.mean = mean;
109
110        Ok(self)
111    }
112
113    /// Transform data to principal component space
114    ///
115    /// # Arguments
116    /// * `data` - Input data matrix (n_samples × n_features)
117    ///
118    /// # Returns
119    /// Transformed data (n_samples × n_components)
120    pub fn transform(&self, data: &Array2<f64>) -> PcaResult<Array2<f64>> {
121        if data.nrows() == 0 {
122            return Err(PcaError::EmptyData);
123        }
124
125        // Center the data
126        let mut centered = data.clone();
127        for mut row in centered.rows_mut() {
128            row -= &self.mean;
129        }
130
131        // Project onto principal components
132        let transformed = centered.dot(&self.components.t());
133
134        Ok(transformed)
135    }
136
137    /// Get principal components
138    pub fn components(&self) -> &Array2<f64> {
139        &self.components
140    }
141
142    /// Get explained variance ratio
143    pub fn explained_variance_ratio(&self) -> &[f64] {
144        &self.explained_variance_ratio
145    }
146
147    /// Get mean of training data
148    pub fn mean(&self) -> &ndarray::Array1<f64> {
149        &self.mean
150    }
151}