use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::validation::*;
/// Configuration for Principal Component Analysis.
///
/// Construct with [`PCA::new`] and the `with_*` builder methods, then call
/// `fit` / `transform` / `fit_transform`.
#[derive(Debug, Clone)]
pub struct PCA {
    /// Number of components to keep; `None` keeps `min(n_samples, n_features)`.
    pub n_components: Option<usize>,
    /// Which SVD algorithm to use (`Auto` picks one based on problem size).
    pub svd_solver: SvdSolver,
    /// Subtract the per-feature mean before decomposition (default: `true`).
    pub center: bool,
    /// Divide each feature by its standard deviation (default: `false`).
    pub scale: bool,
    /// Seed for the randomized solver; `None` seeds from the system clock.
    pub random_state: Option<u64>,
}
/// SVD algorithm used by `PCA::fit`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SvdSolver {
    /// Exact SVD of the full (centered/scaled) data matrix.
    Full,
    /// Randomized SVD with power iterations; cheaper for large matrices
    /// when only a few components are requested.
    Randomized,
    /// Choose automatically: `Randomized` when both dimensions are >= 500
    /// and fewer than half the possible components are requested,
    /// otherwise `Full`.
    Auto,
}
/// Output of `PCA::fit`: the fitted components plus the statistics needed
/// to transform new data.
#[derive(Debug, Clone)]
pub struct PCAResult {
    /// Principal axes, shape `(n_components, n_features)`; rows are components.
    pub components: Array2<f64>,
    /// Variance explained by each component: singular value squared / (n - 1).
    pub explained_variance: Array1<f64>,
    /// `explained_variance` normalized by the variance of the *kept*
    /// components (so it sums to 1 even when components were truncated).
    pub explained_variance_ratio: Array1<f64>,
    /// Singular values corresponding to the kept components.
    pub singular_values: Array1<f64>,
    /// Per-feature training mean (all zeros when centering was disabled).
    pub mean: Array1<f64>,
    /// Per-feature scale factors when scaling was enabled, else `None`.
    pub scale: Option<Array1<f64>>,
    /// Number of training samples.
    // NOTE(review): the trailing underscore is inconsistent with `n_features`;
    // renaming this public field would break callers, so it is only flagged.
    pub n_samples_: usize,
    /// Number of training features.
    pub n_features: usize,
}
impl Default for PCA {
fn default() -> Self {
Self {
n_components: None,
svd_solver: SvdSolver::Auto,
center: true,
scale: false,
random_state: None,
}
}
}
impl PCA {
pub fn new() -> Self {
Self::default()
}
pub fn with_n_components(mut self, n_components: usize) -> Self {
self.n_components = Some(n_components);
self
}
pub fn with_svd_solver(mut self, solver: SvdSolver) -> Self {
self.svd_solver = solver;
self
}
pub fn with_center(mut self, center: bool) -> Self {
self.center = center;
self
}
pub fn with_scale(mut self, scale: bool) -> Self {
self.scale = scale;
self
}
pub fn with_random_state(mut self, seed: u64) -> Self {
self.random_state = Some(seed);
self
}
pub fn fit(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
checkarray_finite(&data, "data")?;
let (n_samples, n_features) = data.dim();
if n_samples < 2 {
return Err(StatsError::InvalidArgument(
"n_samples must be at least 2".to_string(),
));
}
if n_features < 1 {
return Err(StatsError::InvalidArgument(
"n_features must be at least 1".to_string(),
));
}
let max_components = n_samples.min(n_features);
let n_components = match self.n_components {
Some(k) => {
check_positive(k, "n_components")?;
if k > max_components {
return Err(StatsError::InvalidArgument(format!(
"n_components ({}) cannot be larger than min(n_samples, n_features) = {}",
k, max_components
)));
}
k
}
None => max_components,
};
let mean = if self.center {
data.mean_axis(Axis(0)).expect("Operation failed")
} else {
Array1::zeros(n_features)
};
let mut centereddata = data.to_owned();
if self.center {
for mut row in centereddata.rows_mut() {
row -= &mean;
}
}
let scale = if self.scale {
let std = centereddata.std_axis(Axis(0), 1.0);
let std = std.mapv(|s| if s > 1e-10 { s } else { 1.0 });
for (mut col, &s) in centereddata.columns_mut().into_iter().zip(std.iter()) {
col /= s;
}
Some(std)
} else {
None
};
let solver = match self.svd_solver {
SvdSolver::Auto => {
if n_samples >= 500 && n_features >= 500 && n_components < max_components / 2 {
SvdSolver::Randomized
} else {
SvdSolver::Full
}
}
solver => solver,
};
let result = match solver {
SvdSolver::Full => self.pca_svd(¢ereddata, n_components, n_samples)?,
SvdSolver::Randomized => self.pca_randomized(¢ereddata, n_components, n_samples)?,
_ => unreachable!(),
};
Ok(PCAResult {
components: result.0,
explained_variance: result.1,
explained_variance_ratio: result.2,
singular_values: result.3,
mean,
scale,
n_samples_: n_samples,
n_features,
})
}
fn pca_svd(
&self,
data: &Array2<f64>,
n_components: usize,
n_samples: usize,
) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
let (_u, s, vt) = scirs2_linalg::svd(&data.view(), true, None)
.map_err(|e| StatsError::ComputationError(format!("SVD failed: {}", e)))?;
let v = vt.t().to_owned();
let components = v
.slice(scirs2_core::ndarray::s![.., ..n_components])
.to_owned();
let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
let total_variance = explained_variance.sum();
let explained_variance_ratio = &explained_variance / total_variance;
Ok((
components.t().to_owned(),
explained_variance,
explained_variance_ratio,
singular_values,
))
}
fn pca_randomized(
&self,
data: &Array2<f64>,
n_components: usize,
n_samples: usize,
) -> Result<(Array2<f64>, Array1<f64>, Array1<f64>, Array1<f64>)> {
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use scirs2_core::random::{Distribution, Normal};
let n_features = data.ncols();
let n_oversamples = 10.min((n_features - n_components) / 2);
let n_random = n_components + n_oversamples;
let mut rng = match self.random_state {
Some(seed) => StdRng::seed_from_u64(seed),
None => {
use std::time::{SystemTime, UNIX_EPOCH};
let seed = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
StdRng::seed_from_u64(seed)
}
};
let normal = Normal::new(0.0, 1.0).map_err(|e| {
StatsError::ComputationError(format!("Failed to create normal distribution: {}", e))
})?;
let omega = Array2::from_shape_fn((n_features, n_random), |_| normal.sample(&mut rng));
let n_iter = 4;
let mut q = data.dot(&omega);
for _ in 0..n_iter {
let (q_mat, _r) = scirs2_linalg::qr(&q.view(), None).map_err(|e| {
StatsError::ComputationError(format!("QR decomposition failed: {}", e))
})?;
q = q_mat;
let z = data.t().dot(&q);
let (q_mat, _r) = scirs2_linalg::qr(&z.view(), None).map_err(|e| {
StatsError::ComputationError(format!("QR decomposition failed: {}", e))
})?;
q = data.dot(&q_mat);
}
let (q_final, _r) = scirs2_linalg::qr(&q.view(), None).map_err(|e| {
StatsError::ComputationError(format!("Final QR decomposition failed: {}", e))
})?;
let b = q_final.t().dot(data);
let (_u_small, s, vt) = scirs2_linalg::svd(&b.view(), true, None).map_err(|e| {
StatsError::ComputationError(format!("SVD of projected matrix failed: {}", e))
})?;
let v = vt.t().to_owned();
let components = v
.slice(scirs2_core::ndarray::s![.., ..n_components])
.to_owned();
let singular_values = s.slice(scirs2_core::ndarray::s![..n_components]).to_owned();
let explained_variance = &singular_values * &singular_values / (n_samples - 1) as f64;
let total_variance = explained_variance.sum();
let explained_variance_ratio = &explained_variance / total_variance;
Ok((
components.t().to_owned(),
explained_variance,
explained_variance_ratio,
singular_values,
))
}
pub fn transform(&self, data: ArrayView2<f64>, result: &PCAResult) -> Result<Array2<f64>> {
checkarray_finite(&data, "data")?;
if data.ncols() != result.n_features {
return Err(StatsError::DimensionMismatch(format!(
"data has {} features, expected {}",
data.ncols(),
result.n_features
)));
}
let mut transformed = data.to_owned();
if self.center {
for mut row in transformed.rows_mut() {
row -= &result.mean;
}
}
if let Some(ref scale) = result.scale {
for (mut col, &s) in transformed.columns_mut().into_iter().zip(scale.iter()) {
col /= s;
}
}
Ok(transformed.dot(&result.components.t()))
}
pub fn inverse_transform(
&self,
data: ArrayView2<f64>,
result: &PCAResult,
) -> Result<Array2<f64>> {
checkarray_finite(&data, "data")?;
let n_components = result.components.nrows();
if data.ncols() != n_components {
return Err(StatsError::DimensionMismatch(format!(
"data has {} components, expected {}",
data.ncols(),
n_components
)));
}
let mut reconstructed = data.dot(&result.components);
if let Some(ref scale) = result.scale {
for (mut col, &s) in reconstructed.columns_mut().into_iter().zip(scale.iter()) {
col *= s;
}
}
if self.center {
for mut row in reconstructed.rows_mut() {
row += &result.mean;
}
}
Ok(reconstructed)
}
pub fn fit_transform(&self, data: ArrayView2<f64>) -> Result<(Array2<f64>, PCAResult)> {
let result = self.fit(data)?;
let transformed = self.transform(data, &result)?;
Ok((transformed, result))
}
}
/// Estimates the number of PCA components via a maximum-likelihood-style
/// model-selection criterion: a profile log-likelihood of a
/// "k signal components + isotropic noise" model with an AIC-like penalty
/// (in the spirit of Minka's automatic dimensionality choice — TODO confirm
/// the exact formulation against the reference).
///
/// Returns the 1-based component count `k` with the best penalized score.
///
/// NOTE(review): eigenvalues <= 0 make `ln` yield NaN/-inf; NaN comparisons
/// are false, so such candidates are silently skipped, and if *every*
/// candidate scores NaN the function returns 0. Confirm that fallback is
/// intended.
#[allow(dead_code)]
pub fn mle_components(data: ArrayView2<f64>, maxcomponents: Option<usize>) -> Result<usize> {
    checkarray_finite(&data, "data")?;
    let (n_samples, n_features) = data.dim();
    // Fit a full (or capped) PCA to obtain the eigenvalue spectrum.
    let pca = PCA::new().with_n_components(maxcomponents.unwrap_or(n_features.min(n_samples)));
    let result = pca.fit(data)?;
    let eigenvalues = &result.explained_variance;
    let n = n_samples as f64;
    let p = n_features as f64;
    let mut best_k = 0;
    let mut best_ll = f64::NEG_INFINITY;
    for k in 0..eigenvalues.len() {
        let k_f64 = k as f64;
        // Noise variance estimate: average of the eigenvalues beyond index k.
        // The 1e-10 floor avoids ln(0) for the last candidate.
        // NOTE(review): the divisor is (p - k - 1) but the slice may contain
        // fewer than p - k - 1 entries when `maxcomponents` < n_features —
        // verify this is the intended estimator in that case.
        let sigma2 = if k < eigenvalues.len() - 1 {
            eigenvalues.slice(scirs2_core::ndarray::s![k + 1..]).sum() / (p - k_f64 - 1.0)
        } else {
            1e-10
        };
        // Log-likelihood of keeping eigenvalues 0..=k as signal and modelling
        // the remaining (p - k - 1) directions as isotropic noise sigma2.
        let ll = -n / 2.0
            * (eigenvalues
                .slice(scirs2_core::ndarray::s![..=k])
                .mapv(f64::ln)
                .sum()
                + (p - k_f64 - 1.0) * sigma2.ln()
                + p * (2.0 * std::f64::consts::PI).ln());
        // Penalty term scaling with the parameter count of a k-component model.
        let aic_penalty = k_f64 * (2.0 * p - k_f64 - 1.0);
        let aic = ll - aic_penalty;
        if aic > best_ll {
            best_ll = aic;
            best_k = k + 1; // report a 1-based component count
        }
    }
    Ok(best_k)
}
/// PCA fitted incrementally over mini-batches via a running SVD update.
#[derive(Debug, Clone)]
pub struct IncrementalPCA {
    /// Base PCA configuration (only `n_components` is consulted here).
    pub pca: PCA,
    /// Intended batch size.
    // NOTE(review): stored by `new` but not read by `partial_fit`/`transform`
    // in this file — confirm it is used elsewhere or by callers.
    pub batchsize: usize,
    // Running mean over all samples seen so far.
    mean: Option<Array1<f64>>,
    // Component matrix (n_components x n_features), i.e. V^T.
    components: Option<Array2<f64>>,
    // Singular values of the running SVD.
    singular_values: Option<Array1<f64>>,
    // Total number of samples consumed by `partial_fit`.
    n_samples_seen: usize,
    // Running SVD factors: U (samples x k), S (k), V (features x k).
    svd_u: Option<Array2<f64>>,
    svd_s: Option<Array1<f64>>,
    svd_v: Option<Array2<f64>>,
}
impl IncrementalPCA {
pub fn new(n_components: usize, batchsize: usize) -> Result<Self> {
check_positive(n_components, "n_components")?;
check_positive(batchsize, "batchsize")?;
Ok(Self {
pca: PCA::new().with_n_components(n_components),
batchsize,
mean: None,
components: None,
singular_values: None,
n_samples_seen: 0,
svd_u: None,
svd_s: None,
svd_v: None,
})
}
pub fn partial_fit(&mut self, batch: ArrayView2<f64>) -> Result<()> {
checkarray_finite(&batch, "batch")?;
let (batchsize, n_features) = batch.dim();
let batch_mean = batch.mean_axis(Axis(0)).expect("Operation failed");
let old_n = self.n_samples_seen;
self.n_samples_seen += batchsize;
self.mean = match &self.mean {
None => Some(batch_mean.clone()),
Some(mean) => {
let updated = (mean * old_n as f64 + &batch_mean * batchsize as f64)
/ self.n_samples_seen as f64;
Some(updated)
}
};
let mut centered_batch = batch.to_owned();
for mut row in centered_batch.rows_mut() {
row -= &batch_mean;
}
let n_components = self
.pca
.n_components
.unwrap_or(n_features.min(self.n_samples_seen));
if self.svd_u.is_none() {
let (u, s, vt) = scirs2_linalg::svd(¢ered_batch.view(), true, None)
.map_err(|e| StatsError::ComputationError(format!("Initial SVD failed: {}", e)))?;
self.svd_u = Some(
u.slice(scirs2_core::ndarray::s![.., ..n_components])
.to_owned(),
);
self.svd_s = Some(s.slice(scirs2_core::ndarray::s![..n_components]).to_owned());
self.svd_v = Some(
vt.slice(scirs2_core::ndarray::s![..n_components, ..])
.t()
.to_owned(),
);
self.components = Some(
self.svd_v
.as_ref()
.expect("Operation failed")
.t()
.to_owned(),
);
self.singular_values = Some(self.svd_s.as_ref().expect("Operation failed").clone());
} else {
let u_old = self.svd_u.as_ref().expect("Operation failed");
let s_old = self.svd_s.as_ref().expect("Operation failed");
let v_old = self.svd_v.as_ref().expect("Operation failed");
let projection = centered_batch.dot(v_old);
let residual = ¢ered_batch - &projection.dot(&v_old.t());
let (q_res, r_res) = scirs2_linalg::qr(&residual.view(), None).map_err(|e| {
StatsError::ComputationError(format!("QR decomposition failed: {}", e))
})?;
let k = s_old.len();
let p = r_res.ncols();
let mut augmented = Array2::zeros((k + p, k + p));
for i in 0..k {
augmented[[i, i]] = s_old[i];
}
for i in 0..projection.nrows() {
for j in 0..k {
augmented[[j, k + i]] = projection[[i, j]];
}
}
for i in 0..p {
for j in 0..p {
augmented[[k + i, k + j]] = r_res[[i, j]];
}
}
let (u_aug, s_aug, vt_aug) = scirs2_linalg::svd(&augmented.view(), true, None)
.map_err(|e| {
StatsError::ComputationError(format!("Augmented SVD failed: {}", e))
})?;
let mut u_new = Array2::zeros((old_n + batchsize, n_components));
let u_aug_slice = u_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
let u_old_part = u_old.dot(&u_aug_slice.t());
u_new
.slice_mut(scirs2_core::ndarray::s![..old_n, ..])
.assign(&u_old_part);
let u_batch_part =
projection.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
let u_res_part = q_res.dot(&u_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
u_new
.slice_mut(scirs2_core::ndarray::s![old_n.., ..])
.assign(&(&u_batch_part + &u_res_part));
self.svd_s = Some(
s_aug
.slice(scirs2_core::ndarray::s![..n_components])
.to_owned(),
);
let v_aug_slice =
vt_aug.slice(scirs2_core::ndarray::s![..n_components, ..n_components]);
let mut v_new = Array2::zeros((n_features, n_components));
let v_old_part = v_old.dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., ..k]).t());
let v_res_part = q_res
.t()
.dot(¢ered_batch)
.t()
.dot(&v_aug_slice.slice(scirs2_core::ndarray::s![.., k..]).t());
v_new.assign(&(&v_old_part + &v_res_part));
self.svd_u = Some(u_new);
self.svd_v = Some(v_new.clone());
self.components = Some(v_new.t().to_owned());
self.singular_values = Some(self.svd_s.as_ref().expect("Operation failed").clone());
}
Ok(())
}
pub fn transform(&self, data: ArrayView2<f64>) -> Result<Array2<f64>> {
if self.components.is_none() || self.mean.is_none() {
return Err(StatsError::ComputationError(
"IncrementalPCA must be fitted before transform".to_string(),
));
}
let mut centered = data.to_owned();
for mut row in centered.rows_mut() {
row -= self.mean.as_ref().expect("Operation failed");
}
Ok(centered.dot(&self.components.as_ref().expect("Operation failed").t()))
}
}