use crate::error::ModelError;
use crate::{Deserialize, Serialize};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix2};
use ndarray_rand::rand::rngs::StdRng;
use ndarray_rand::rand::{Rng, SeedableRng};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::IntoParallelRefIterator;
/// Strategy used to compute the decomposition during PCA fitting.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)]
pub enum SVDSolver {
/// Exact full SVD of the centered data matrix (via nalgebra).
Full,
/// Randomized (approximate) SVD; the `u64` payload is the RNG seed,
/// making results reproducible for a given seed.
Randomized(u64),
/// Iterative eigen-decomposition of the covariance matrix using
/// power iteration with deflation (ARPACK-style, not actual ARPACK).
ARPACK,
}
// Row-count threshold at or above which the rayon-parallel code paths are used.
const PCA_PARALLEL_THRESHOLD: usize = 200;
/// Principal Component Analysis model.
///
/// Construct with [`PCA::new`], then call `fit` / `transform` /
/// `fit_transform`. All `Option` fields are `None` until `fit` succeeds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PCA {
// Number of principal components to keep (validated > 0 in `new`).
n_components: usize,
// Which SVD backend to use when fitting.
svd_solver: SVDSolver,
// Per-feature mean of the training data (used to center inputs).
mean: Option<Array1<f64>>,
// Principal axes, shape (n_components, n_features).
components: Option<Array2<f64>>,
// Variance explained by each component: sigma^2 / (n_samples - 1).
explained_variance: Option<Array1<f64>>,
// `explained_variance` as a fraction of the total variance.
explained_variance_ratio: Option<Array1<f64>>,
// Leading singular values of the centered training matrix.
singular_values: Option<Array1<f64>>,
// Number of rows in the training data.
n_samples: Option<usize>,
// Number of columns in the training data.
n_features: Option<usize>,
}
impl Default for PCA {
    /// A PCA keeping 2 components with the exact full-SVD solver.
    fn default() -> Self {
        // `new` only rejects n_components == 0, so this cannot fail.
        Self::new(2, SVDSolver::Full).expect("Default PCA parameters should be valid")
    }
}
impl PCA {
pub fn new(n_components: usize, svd_solver: SVDSolver) -> Result<Self, ModelError> {
if n_components == 0 {
return Err(ModelError::InputValidationError(
"n_components must be greater than 0".to_string(),
));
}
Ok(Self {
n_components,
svd_solver,
mean: None,
components: None,
explained_variance: None,
explained_variance_ratio: None,
singular_values: None,
n_samples: None,
n_features: None,
})
}
// Read-only accessors generated by project macros.
// `get_field!` returns the field by value; `get_field_as_ref!` returns
// `Option<&T>` for the lazily-populated, fit-time fields.
get_field!(get_n_components, n_components, usize);
get_field!(get_svd_solver, svd_solver, SVDSolver);
get_field!(get_n_samples, n_samples, Option<usize>);
get_field!(get_n_features, n_features, Option<usize>);
get_field_as_ref!(get_mean, mean, Option<&Array1<f64>>);
get_field_as_ref!(get_components, components, Option<&Array2<f64>>);
get_field_as_ref!(
get_explained_variance,
explained_variance,
Option<&Array1<f64>>
);
get_field_as_ref!(
get_explained_variance_ratio,
explained_variance_ratio,
Option<&Array1<f64>>
);
get_field_as_ref!(get_singular_values, singular_values, Option<&Array1<f64>>);
/// Fits the model to `x` (rows = samples, columns = features),
/// displaying a progress bar. Delegates to `fit_internal`.
///
/// # Errors
/// Propagates validation and decomposition errors from `fit_internal`.
pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<&mut Self, ModelError>
where
S: Data<Elem = f64>,
{
self.fit_internal(x, true)
}
/// Projects `x` onto the fitted principal components, displaying a
/// progress bar. Delegates to `transform_internal`.
///
/// # Errors
/// `ModelError::NotFitted` if the model has not been fitted;
/// `InputValidationError` for invalid input.
pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>, ModelError>
where
S: Data<Elem = f64>,
{
self.transform_internal(x, true)
}
pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>, ModelError>
where
S: Data<Elem = f64>,
{
let progress_bar = Self::create_progress_bar(2, "Fitting model");
self.fit_internal(x, false)?;
progress_bar.inc(1);
progress_bar.set_message("Transforming data");
let transformed = self.transform_internal(x, false)?;
progress_bar.inc(1);
progress_bar.finish_with_message("Completed");
Ok(transformed)
}
pub fn inverse_transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>, ModelError>
where
S: Data<Elem = f64>,
{
let components = self.components.as_ref().ok_or(ModelError::NotFitted)?;
let mean = self.mean.as_ref().ok_or(ModelError::NotFitted)?;
if x.is_empty() {
return Err(ModelError::InputValidationError(
"Cannot inverse transform empty dataset".to_string(),
));
}
if x.ncols() != components.nrows() {
return Err(ModelError::InputValidationError(format!(
"Number of components does not match training data, x columns: {}, expected: {}",
x.ncols(),
components.nrows()
)));
}
if x.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::InputValidationError(
"Input data contains NaN or infinite values".to_string(),
));
}
let progress_bar = Self::create_progress_bar(3, "Validating input");
progress_bar.inc(1);
progress_bar.set_message("Reconstructing data");
let reconstructed = if x.nrows() >= PCA_PARALLEL_THRESHOLD {
let x_owned = x.to_owned();
Self::reconstruct_parallel(&x_owned, components, mean)?
} else {
let mut reconstructed = x.dot(components);
reconstructed += mean;
reconstructed
};
progress_bar.inc(1);
progress_bar.set_message("Finalizing output");
progress_bar.inc(1);
progress_bar.finish_with_message("Completed");
Ok(reconstructed)
}
/// Shared implementation behind `fit` and `fit_transform`.
///
/// Validates the input, centers it, runs the configured SVD solver, and
/// stores mean / components / variance statistics on `self`.
/// `show_progress` controls whether a progress bar is rendered
/// (`fit_transform` passes `false` and drives its own bar).
///
/// # Errors
/// `InputValidationError` for empty input, zero features, non-finite
/// values, fewer than 2 samples, or `n_components > min(n_samples, n_features)`;
/// propagates `ProcessingError` from the decomposition.
fn fit_internal<S>(
&mut self,
x: &ArrayBase<S, Ix2>,
show_progress: bool,
) -> Result<&mut Self, ModelError>
where
S: Data<Elem = f64>,
{
if x.is_empty() {
return Err(ModelError::InputValidationError(
"Input data cannot be empty".to_string(),
));
}
if x.ncols() == 0 {
return Err(ModelError::InputValidationError(
"Number of features must be greater than 0".to_string(),
));
}
if x.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::InputValidationError(
"Input data contains NaN or infinite values".to_string(),
));
}
let n_samples = x.nrows();
let n_features = x.ncols();
if n_samples < 2 {
return Err(ModelError::InputValidationError(
"PCA requires at least 2 samples".to_string(),
));
}
// A rank-r matrix has at most min(n_samples, n_features) meaningful components.
let max_components = n_samples.min(n_features);
if self.n_components > max_components {
return Err(ModelError::InputValidationError(format!(
"n_components should be <= {}, got {}",
max_components, self.n_components
)));
}
let progress_bar = if show_progress {
Some(Self::create_progress_bar(5, "Validating input"))
} else {
None
};
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.set_message("Centering data");
}
// Work on an owned copy; the caller's data is not mutated.
let mut x_centered = x.to_owned();
let mean = Self::compute_mean(&x_centered);
Self::center_data(&mut x_centered, &mean);
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.set_message("Computing decomposition");
}
let (components, singular_values) = self.compute_components(&x_centered)?;
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.set_message("Computing explained variance");
}
// Variance explained per component: sigma^2 / (n - 1).
let explained_variance = singular_values.mapv(|s| (s * s) / ((n_samples - 1) as f64));
let total_variance = Self::total_variance(&x_centered, n_samples)?;
// Guard against a degenerate (all-zero or non-finite) total variance.
let explained_variance_ratio = if total_variance > 0.0 && total_variance.is_finite() {
explained_variance.mapv(|v| v / total_variance)
} else {
Array1::zeros(self.n_components)
};
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.set_message("Finalizing model state");
}
// All state is committed together, only after every step succeeded.
self.mean = Some(mean);
self.components = Some(components);
self.explained_variance = Some(explained_variance);
self.explained_variance_ratio = Some(explained_variance_ratio);
self.singular_values = Some(singular_values);
self.n_samples = Some(n_samples);
self.n_features = Some(n_features);
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.finish_with_message("Completed");
}
Ok(self)
}
fn transform_internal<S>(
&self,
x: &ArrayBase<S, Ix2>,
show_progress: bool,
) -> Result<Array2<f64>, ModelError>
where
S: Data<Elem = f64>,
{
let components = self.components.as_ref().ok_or(ModelError::NotFitted)?;
let mean = self.mean.as_ref().ok_or(ModelError::NotFitted)?;
if x.is_empty() {
return Err(ModelError::InputValidationError(
"Cannot transform empty dataset".to_string(),
));
}
if x.ncols() != components.ncols() {
return Err(ModelError::InputValidationError(format!(
"Number of features does not match training data, x columns: {}, expected: {}",
x.ncols(),
components.ncols()
)));
}
if x.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::InputValidationError(
"Input data contains NaN or infinite values".to_string(),
));
}
let progress_bar = if show_progress {
Some(Self::create_progress_bar(3, "Validating input"))
} else {
None
};
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.set_message("Centering data");
}
let mut x_centered = x.to_owned();
Self::center_data(&mut x_centered, mean);
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.set_message("Projecting data");
}
let transformed = if x_centered.nrows() >= PCA_PARALLEL_THRESHOLD {
Self::project_parallel(&x_centered, components)?
} else {
x_centered.dot(&components.t())
};
if let Some(pb) = &progress_bar {
pb.inc(1);
pb.finish_with_message("Completed");
}
Ok(transformed)
}
fn compute_mean(x: &Array2<f64>) -> Array1<f64> {
let n_samples = x.nrows();
let n_features = x.ncols();
if n_samples >= PCA_PARALLEL_THRESHOLD {
let means: Vec<f64> = (0..n_features)
.into_par_iter()
.map(|col| x.column(col).sum() / n_samples as f64)
.collect();
Array1::from_vec(means)
} else {
x.mean_axis(Axis(0)).expect("Input data must be non-empty")
}
}
fn center_data(x: &mut Array2<f64>, mean: &Array1<f64>) {
if x.nrows() >= PCA_PARALLEL_THRESHOLD {
let mean = mean.to_owned();
x.axis_iter_mut(Axis(0))
.into_par_iter()
.for_each(|mut row| {
row -= &mean;
});
} else {
for mut row in x.axis_iter_mut(Axis(0)) {
row -= mean;
}
}
}
fn total_variance(x_centered: &Array2<f64>, n_samples: usize) -> Result<f64, ModelError> {
let denom = (n_samples - 1) as f64;
if denom <= 0.0 {
return Err(ModelError::ProcessingError(
"Variance computation requires at least 2 samples".to_string(),
));
}
let sum_sq = if x_centered.nrows() >= PCA_PARALLEL_THRESHOLD {
if let Some(slice) = x_centered.as_slice() {
slice.par_iter().map(|v| v * v).sum::<f64>()
} else {
x_centered.iter().map(|v| v * v).sum::<f64>()
}
} else {
x_centered.iter().map(|v| v * v).sum::<f64>()
};
Ok(sum_sq / denom)
}
/// Dispatches to the SVD backend selected at construction time and
/// returns `(components, singular_values)` where `components` has shape
/// `(n_components, n_features)`.
fn compute_components(
&self,
x_centered: &Array2<f64>,
) -> Result<(Array2<f64>, Array1<f64>), ModelError> {
match self.svd_solver {
SVDSolver::Full => self.compute_full_svd(x_centered),
SVDSolver::Randomized(seed) => self.compute_randomized_svd(x_centered, seed),
SVDSolver::ARPACK => self.compute_arpack_svd(x_centered),
}
}
fn compute_full_svd(
&self,
x_centered: &Array2<f64>,
) -> Result<(Array2<f64>, Array1<f64>), ModelError> {
let n_samples = x_centered.nrows();
let n_features = x_centered.ncols();
let x_slice = x_centered.as_slice().ok_or_else(|| {
ModelError::ProcessingError("Failed to convert centered data to slice".to_string())
})?;
let x_mat = nalgebra::DMatrix::from_row_slice(n_samples, n_features, x_slice);
let svd = nalgebra::linalg::SVD::new(x_mat, false, true);
let v_t = svd.v_t.ok_or_else(|| {
ModelError::ProcessingError("SVD did not compute V^T matrix".to_string())
})?;
let singular_values: Vec<f64> = svd
.singular_values
.iter()
.take(self.n_components)
.cloned()
.collect();
let mut components = Array2::<f64>::zeros((self.n_components, n_features));
for i in 0..self.n_components {
for j in 0..n_features {
components[[i, j]] = v_t[(i, j)];
}
}
Ok((components, Array1::from_vec(singular_values)))
}
fn compute_randomized_svd(
&self,
x_centered: &Array2<f64>,
seed: u64,
) -> Result<(Array2<f64>, Array1<f64>), ModelError> {
let n_samples = x_centered.nrows();
let n_features = x_centered.ncols();
let max_rank = n_samples.min(n_features);
let oversampling = 5usize;
let k = (self.n_components + oversampling).min(max_rank);
let mut rng = StdRng::seed_from_u64(seed);
let mut omega = Vec::with_capacity(n_features * k);
for _ in 0..(n_features * k) {
omega.push(rng.random_range(-1.0..1.0));
}
let x_slice = x_centered.as_slice().ok_or_else(|| {
ModelError::ProcessingError("Failed to convert centered data to slice".to_string())
})?;
let x_mat = nalgebra::DMatrix::from_row_slice(n_samples, n_features, x_slice);
let omega_mat = nalgebra::DMatrix::from_row_slice(n_features, k, &omega);
let mut y_mat = &x_mat * &omega_mat;
let n_iter = 2usize;
for _ in 0..n_iter {
let y_t = x_mat.transpose() * &y_mat;
y_mat = &x_mat * y_t;
}
let qr = nalgebra::linalg::QR::new(y_mat);
let q = qr.q();
let b = q.transpose() * x_mat;
let svd = nalgebra::linalg::SVD::new(b, false, true);
let v_t = svd.v_t.ok_or_else(|| {
ModelError::ProcessingError("Randomized SVD did not compute V^T matrix".to_string())
})?;
let singular_values: Vec<f64> = svd
.singular_values
.iter()
.take(self.n_components)
.cloned()
.collect();
let mut components = Array2::<f64>::zeros((self.n_components, n_features));
for i in 0..self.n_components {
for j in 0..n_features {
components[[i, j]] = v_t[(i, j)];
}
}
Ok((components, Array1::from_vec(singular_values)))
}
fn compute_arpack_svd(
&self,
x_centered: &Array2<f64>,
) -> Result<(Array2<f64>, Array1<f64>), ModelError> {
let n_samples = x_centered.nrows();
let n_features = x_centered.ncols();
let denom = (n_samples - 1) as f64;
let mut cov = x_centered.t().dot(x_centered) / denom;
let mut components = Array2::<f64>::zeros((self.n_components, n_features));
let mut eigenvalues = Vec::with_capacity(self.n_components);
let mut rng = StdRng::seed_from_u64(0);
let max_iter = 1000usize;
let tol = 1e-6;
for idx in 0..self.n_components {
let (eigenvector, eigenvalue) = Self::power_iteration(&cov, &mut rng, max_iter, tol)?;
components.row_mut(idx).assign(&eigenvector);
eigenvalues.push(eigenvalue);
let v_col = eigenvector.view().insert_axis(Axis(1));
let v_row = eigenvector.view().insert_axis(Axis(0));
cov -= &(v_col.dot(&v_row) * eigenvalue);
}
let singular_values: Vec<f64> = eigenvalues
.into_iter()
.map(|lambda| {
let clamped = if lambda.is_finite() && lambda > 0.0 {
lambda
} else {
0.0
};
(clamped * denom).sqrt()
})
.collect();
Ok((components, Array1::from_vec(singular_values)))
}
fn power_iteration(
cov: &Array2<f64>,
rng: &mut StdRng,
max_iter: usize,
tol: f64,
) -> Result<(Array1<f64>, f64), ModelError> {
let n_features = cov.ncols();
let mut v = Array1::<f64>::from_vec(
(0..n_features)
.map(|_| rng.random_range(-1.0..1.0))
.collect(),
);
let norm = v.dot(&v).sqrt();
if norm <= f64::EPSILON {
v.fill(1.0 / (n_features as f64).sqrt());
} else {
v /= norm;
}
let mut prev_lambda = 0.0;
for _ in 0..max_iter {
let w = cov.dot(&v);
let w_norm = w.dot(&w).sqrt();
if w_norm <= f64::EPSILON || !w_norm.is_finite() {
return Err(ModelError::ProcessingError(
"Power iteration failed to converge".to_string(),
));
}
let v_next = &w / w_norm;
let lambda = v_next.dot(&cov.dot(&v_next));
if !lambda.is_finite() {
return Err(ModelError::ProcessingError(
"Power iteration produced non-finite eigenvalue".to_string(),
));
}
if (lambda - prev_lambda).abs() < tol {
return Ok((v_next, lambda));
}
prev_lambda = lambda;
v = v_next;
}
let lambda = v.dot(&cov.dot(&v));
if !lambda.is_finite() {
return Err(ModelError::ProcessingError(
"Power iteration produced non-finite eigenvalue".to_string(),
));
}
Ok((v, lambda))
}
fn project_parallel(
x_centered: &Array2<f64>,
components: &Array2<f64>,
) -> Result<Array2<f64>, ModelError> {
let n_samples = x_centered.nrows();
let n_components = components.nrows();
let rows: Vec<Vec<f64>> = x_centered
.outer_iter()
.into_par_iter()
.map(|row| {
let mut projected = vec![0.0; n_components];
for (idx, comp) in components.outer_iter().enumerate() {
projected[idx] = row.dot(&comp);
}
projected
})
.collect();
let flat: Vec<f64> = rows.into_iter().flatten().collect();
Array2::from_shape_vec((n_samples, n_components), flat).map_err(|e| {
ModelError::ProcessingError(format!("Failed to build projected matrix: {}", e))
})
}
fn reconstruct_parallel(
x: &Array2<f64>,
components: &Array2<f64>,
mean: &Array1<f64>,
) -> Result<Array2<f64>, ModelError> {
let n_samples = x.nrows();
let n_features = components.ncols();
let components_t = components.t().to_owned();
let mean_vec = mean.to_owned();
let rows: Vec<Vec<f64>> = x
.outer_iter()
.into_par_iter()
.map(|row| {
let mut reconstructed = vec![0.0; n_features];
for (j, comp_row) in components_t.outer_iter().enumerate() {
reconstructed[j] = row.dot(&comp_row) + mean_vec[j];
}
reconstructed
})
.collect();
let flat: Vec<f64> = rows.into_iter().flatten().collect();
Array2::from_shape_vec((n_samples, n_features), flat).map_err(|e| {
ModelError::ProcessingError(format!("Failed to build reconstructed matrix: {}", e))
})
}
fn create_progress_bar(len: u64, message: &str) -> ProgressBar {
let progress_bar = ProgressBar::new(len);
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Stage: {msg}")
.expect("Failed to set progress bar template")
.progress_chars("=>-"),
);
progress_bar.set_message(message.to_string());
progress_bar
}
// Project macro: generates serde-based save/load methods for this model.
model_save_and_load_methods!(PCA);
}