use ndarray::{s, Array2, Axis};
use linfa_linalg::svd::SVD;
use thiserror::Error;
/// Errors reported by the PCA routines in this file.
#[derive(Error, Debug)]
pub enum PcaError {
/// The input matrix contains nothing to work on (e.g. it has zero rows).
#[error("Empty data")]
EmptyData,
/// Fewer samples were supplied than the operation needs.
///
/// `min` is the required number of samples and `actual` is how many were given.
#[error("Insufficient data: need at least {min} points, got {actual}")]
InsufficientData { min: usize, actual: usize },
/// The component configuration cannot be used; the payload explains why.
#[error("Invalid number of components: {0}")]
InvalidComponents(String),
/// The decomposition step (or a prerequisite of it) failed; the payload
/// carries a description of the failure.
#[error("SVD decomposition failed: {0}")]
SvdFailed(String),
}
/// Shorthand for a `Result` whose error type is [`PcaError`].
pub type PcaResult<T> = Result<T, PcaError>;
/// Principal component analysis model.
///
/// Created empty by `Pca::new` and populated by `Pca::fit`; until then the
/// matrix, ratio list and mean are zero-sized placeholders.
#[derive(Debug)]
pub struct Pca {
// Number of components to keep. Set by the caller in `new`; `fit` may lower
// it to what the training data supports.
n_components: usize,
// Retained rows of the `Vt` factor from the SVD of the centered training
// data (one row per kept component).
components: Array2<f64>,
// For each retained component, its squared singular value divided by the
// sum of all squared singular values.
explained_variance_ratio: Vec<f64>,
// Mean of the training data taken along axis 0; subtracted from every row
// before projecting.
mean: ndarray::Array1<f64>,
}
impl Pca {
    /// Creates an unfitted model that will keep at most `n_components`
    /// principal components.
    ///
    /// The component matrix, variance ratios and mean are zero-sized
    /// placeholders until [`Pca::fit`] succeeds.
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            components: Array2::zeros((0, 0)),
            explained_variance_ratio: Vec::new(),
            mean: ndarray::Array1::zeros(0),
        }
    }

    /// Fits the model to `data` (one sample per row, one feature per column)
    /// and returns the fitted model.
    ///
    /// The data is centered on its per-column mean and decomposed with an
    /// SVD. The rows of `Vt` belonging to the largest singular values become
    /// the components. If more components were requested than the
    /// decomposition provides, the count is silently reduced.
    ///
    /// # Errors
    ///
    /// * [`PcaError::EmptyData`] if `data` has no rows or no columns.
    /// * [`PcaError::InsufficientData`] if `data` has fewer than two rows.
    /// * [`PcaError::InvalidComponents`] if zero components were requested.
    /// * [`PcaError::SvdFailed`] if the mean or the decomposition cannot be
    ///   computed.
    pub fn fit(mut self, data: &Array2<f64>) -> PcaResult<Self> {
        let (n_samples, n_features) = data.dim();
        if n_samples == 0 {
            return Err(PcaError::EmptyData);
        }
        if n_samples < 2 {
            return Err(PcaError::InsufficientData {
                min: 2,
                actual: n_samples,
            });
        }
        if n_features == 0 {
            return Err(PcaError::EmptyData);
        }
        if self.n_components == 0 {
            return Err(PcaError::InvalidComponents(
                "at least one component is required, got 0".to_string(),
            ));
        }

        let mean = data
            .mean_axis(Axis(0))
            .ok_or_else(|| PcaError::SvdFailed("Failed to calculate mean".to_string()))?;
        // Broadcasting subtracts the per-column mean from every row.
        let centered = data - &mean;

        // Only `Vt` is needed, so the left singular vectors are not requested.
        let (_, singular_values, vt_opt) = centered
            .svd(false, true)
            .map_err(|e| PcaError::SvdFailed(format!("SVD failed: {:?}", e)))?;
        let vt = vt_opt
            .ok_or_else(|| PcaError::SvdFailed("Vt matrix not available".to_string()))?;

        // The decomposition yields at most min(n_samples, n_features)
        // singular values, which can be fewer than `n_features`.
        let rank = singular_values.len().min(vt.nrows());

        // Order components by decreasing singular value explicitly, so the
        // result does not depend on the order the SVD backend returns them in.
        let mut order: Vec<usize> = (0..rank).collect();
        order.sort_by(|&a, &b| singular_values[b].total_cmp(&singular_values[a]));

        let total_variance: f64 = singular_values.iter().map(|&sv| sv * sv).sum();

        let n_components = self.n_components.min(rank);
        order.truncate(n_components);

        let explained_variance_ratio: Vec<f64> = order
            .iter()
            .map(|&idx| {
                // Constant data has zero total variance; report 0 rather
                // than the NaN that 0/0 would produce.
                if total_variance == 0.0 {
                    0.0
                } else {
                    singular_values[idx] * singular_values[idx] / total_variance
                }
            })
            .collect();
        let components = vt.select(Axis(0), &order);

        self.n_components = n_components;
        self.components = components;
        self.explained_variance_ratio = explained_variance_ratio;
        self.mean = mean;
        Ok(self)
    }

    /// Projects `data` onto the fitted components.
    ///
    /// Each row is centered with the mean stored by [`Pca::fit`] and then
    /// multiplied by the transposed component matrix, giving one output
    /// column per retained component.
    ///
    /// # Errors
    ///
    /// * [`PcaError::EmptyData`] if `data` has no rows.
    /// * [`PcaError::InvalidComponents`] if the number of columns in `data`
    ///   differs from the number of features the model holds (this is also
    ///   the outcome when the model was never fitted).
    pub fn transform(&self, data: &Array2<f64>) -> PcaResult<Array2<f64>> {
        if data.nrows() == 0 {
            return Err(PcaError::EmptyData);
        }
        // A mismatch would otherwise panic inside the array arithmetic below.
        if data.ncols() != self.mean.len() {
            return Err(PcaError::InvalidComponents(format!(
                "fitted components expect {} features, but data has {} columns",
                self.mean.len(),
                data.ncols()
            )));
        }
        let centered = data - &self.mean;
        Ok(centered.dot(&self.components.t()))
    }

    /// Returns the component matrix, one retained component per row.
    pub fn components(&self) -> &Array2<f64> {
        &self.components
    }

    /// Returns, for each retained component, its share of the total squared
    /// singular values.
    pub fn explained_variance_ratio(&self) -> &[f64] {
        &self.explained_variance_ratio
    }

    /// Returns the per-column mean of the training data.
    pub fn mean(&self) -> &ndarray::Array1<f64> {
        &self.mean
    }
}