1use ndarray::{s, Array2, Axis};
4use linfa_linalg::svd::SVD;
5use thiserror::Error;
6
7#[derive(Error, Debug)]
9pub enum PcaError {
10 #[error("Empty data")]
11 EmptyData,
12 #[error("Insufficient data: need at least {min} points, got {actual}")]
13 InsufficientData { min: usize, actual: usize },
14 #[error("Invalid number of components: {0}")]
15 InvalidComponents(String),
16 #[error("SVD decomposition failed: {0}")]
17 SvdFailed(String),
18}
19
20pub type PcaResult<T> = Result<T, PcaError>;
21
22#[derive(Debug)]
24pub struct Pca {
25 n_components: usize,
27 components: Array2<f64>,
29 explained_variance_ratio: Vec<f64>,
31 mean: ndarray::Array1<f64>,
33}
34
35impl Pca {
36 pub fn new(n_components: usize) -> Self {
41 Self {
42 n_components,
43 components: Array2::zeros((0, 0)),
44 explained_variance_ratio: Vec::new(),
45 mean: ndarray::Array1::zeros(0),
46 }
47 }
48
49 pub fn fit(mut self, data: &Array2<f64>) -> PcaResult<Self> {
57 if data.nrows() == 0 {
58 return Err(PcaError::EmptyData);
59 }
60
61 let n_samples = data.nrows();
62 let n_features = data.ncols();
63
64 if n_samples < 2 {
65 return Err(PcaError::InsufficientData {
66 min: 2,
67 actual: n_samples,
68 });
69 }
70
71 let mean = data.mean_axis(Axis(0))
73 .ok_or_else(|| PcaError::SvdFailed("Failed to calculate mean".to_string()))?;
74 let mut centered = data.clone();
75 for mut row in centered.rows_mut() {
76 row -= &mean;
77 }
78
79 let (u_opt, s, vt_opt) = centered
82 .svd(true, true)
83 .map_err(|e| PcaError::SvdFailed(format!("SVD failed: {:?}", e)))?;
84
85 let _u = u_opt.ok_or_else(|| PcaError::SvdFailed("U matrix not available".to_string()))?;
86 let vt = vt_opt.ok_or_else(|| PcaError::SvdFailed("Vt matrix not available".to_string()))?;
87
88 let components = vt;
91
92 let s_squared: Vec<f64> = s.iter().map(|&val| val * val).collect();
94 let total_variance: f64 = s_squared.iter().sum();
95 let explained_variance_ratio: Vec<f64> = s_squared
96 .iter()
97 .map(|&val| val / total_variance)
98 .collect();
99
100 let n_components = self.n_components.min(n_features);
102 let components = components.slice(s![..n_components, ..]).to_owned();
103 let explained_variance_ratio = explained_variance_ratio[..n_components].to_vec();
104
105 self.n_components = n_components;
106 self.components = components;
107 self.explained_variance_ratio = explained_variance_ratio;
108 self.mean = mean;
109
110 Ok(self)
111 }
112
113 pub fn transform(&self, data: &Array2<f64>) -> PcaResult<Array2<f64>> {
121 if data.nrows() == 0 {
122 return Err(PcaError::EmptyData);
123 }
124
125 let mut centered = data.clone();
127 for mut row in centered.rows_mut() {
128 row -= &self.mean;
129 }
130
131 let transformed = centered.dot(&self.components.t());
133
134 Ok(transformed)
135 }
136
137 pub fn components(&self) -> &Array2<f64> {
139 &self.components
140 }
141
142 pub fn explained_variance_ratio(&self) -> &[f64] {
144 &self.explained_variance_ratio
145 }
146
147 pub fn mean(&self) -> &ndarray::Array1<f64> {
149 &self.mean
150 }
151}