pub mod preconditioning;
pub mod sketching;
pub mod rand_nla;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{Distribution, Normal};
use std::fmt::Debug;
use std::iter::Sum;
use crate::decomposition::{qr, svd};
use crate::error::{LinalgError, LinalgResult};
#[derive(Debug, Clone)]
pub struct RandomizedConfig {
pub rank: usize,
pub oversampling: usize,
pub power_iterations: usize,
pub seed: Option<u64>,
}
impl RandomizedConfig {
pub fn new(rank: usize) -> Self {
Self {
rank,
oversampling: 10,
power_iterations: 2,
seed: None,
}
}
pub fn with_oversampling(mut self, oversampling: usize) -> Self {
self.oversampling = oversampling;
self
}
pub fn with_power_iterations(mut self, power_iterations: usize) -> Self {
self.power_iterations = power_iterations;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
}
#[derive(Debug, Clone)]
pub struct RandomizedPcaResult<F> {
pub components: Array2<F>,
pub explained_variance: Array1<F>,
pub explained_variance_ratio: Array1<F>,
pub singular_values: Array1<F>,
pub mean: Array1<F>,
}
fn gaussian_random_matrix<F>(rows: usize, cols: usize) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + 'static,
{
let mut rng = scirs2_core::random::rng();
let normal = Normal::new(0.0, 1.0).map_err(|e| {
LinalgError::ComputationError(format!("Failed to create normal distribution: {e}"))
})?;
let mut omega = Array2::zeros((rows, cols));
for i in 0..rows {
for j in 0..cols {
omega[[i, j]] = F::from(normal.sample(&mut rng)).unwrap_or(F::zero());
}
}
Ok(omega)
}
fn thin_orthogonalize<F>(y: &ArrayView2<F>, max_cols: usize) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (m, n_cols) = y.dim();
let target = max_cols.min(n_cols).min(m);
if m >= n_cols {
let (q_full, _) = qr(y, None)?;
let actual = target.min(q_full.ncols());
Ok(q_full.slice(s![.., ..actual]).to_owned())
} else {
let (u, _, _) = svd(y, false, None)?;
let actual = target.min(u.ncols());
Ok(u.slice(s![.., ..actual]).to_owned())
}
}
pub fn randomized_range_finder<F>(
a: &ArrayView2<F>,
rank: usize,
oversampling: Option<usize>,
power_iterations: Option<usize>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (m, n) = a.dim();
let p = oversampling.unwrap_or(10);
let q_iters = power_iterations.unwrap_or(0);
let l = (rank + p).min(m).min(n);
if rank == 0 {
return Err(LinalgError::InvalidInput(
"Target rank must be greater than 0".to_string(),
));
}
if rank > m.min(n) {
return Err(LinalgError::InvalidInput(format!(
"Target rank ({rank}) exceeds min(m, n) = {}",
m.min(n)
)));
}
let omega = gaussian_random_matrix::<F>(n, l)?;
let mut y = a.dot(&omega);
for _ in 0..q_iters {
let q_y = thin_orthogonalize(&y.view(), l)?;
let z = a.t().dot(&q_y);
let q_z = thin_orthogonalize(&z.view(), l)?;
y = a.dot(&q_z);
}
let q_trunc = thin_orthogonalize(&y.view(), l)?;
Ok(q_trunc)
}
pub fn adaptive_range_finder<F>(
a: &ArrayView2<F>,
tolerance: F,
max_rank: Option<usize>,
block_size: Option<usize>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (m, n) = a.dim();
let max_r = max_rank.unwrap_or(m.min(n));
let bs = block_size.unwrap_or(5);
if tolerance <= F::zero() {
return Err(LinalgError::InvalidInput(
"Tolerance must be positive".to_string(),
));
}
let mut q_cols: Vec<Array1<F>> = Vec::new();
let mut current_rank = 0;
while current_rank < max_r {
let add_count = bs.min(max_r - current_rank);
let omega = gaussian_random_matrix::<F>(n, add_count)?;
let mut y_block = a.dot(&omega);
for q_col in &q_cols {
for j in 0..add_count {
let mut y_col = y_block.column(j).to_owned();
let dot: F = y_col
.iter()
.zip(q_col.iter())
.fold(F::zero(), |acc, (&yi, &qi)| acc + yi * qi);
for i in 0..m {
y_col[i] -= dot * q_col[i];
}
for i in 0..m {
y_block[[i, j]] = y_col[i];
}
}
}
let q_new = if y_block.nrows() >= y_block.ncols() {
let (q_tmp, _) = qr(&y_block.view(), None)?;
q_tmp
} else {
let (u_tmp, _, _) = svd(&y_block.view(), false, None)?;
u_tmp
};
let mut all_below_tol = true;
let cols_to_add = add_count.min(q_new.ncols());
for j in 0..cols_to_add {
let col = q_new.column(j);
let norm: F = col.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
if norm > tolerance {
all_below_tol = false;
q_cols.push(col.to_owned());
current_rank += 1;
}
}
if all_below_tol {
break;
}
}
if q_cols.is_empty() {
return Err(LinalgError::ComputationError(
"Adaptive range finder found no significant directions".to_string(),
));
}
let k = q_cols.len();
let mut q = Array2::zeros((m, k));
for (j, col) in q_cols.iter().enumerate() {
for i in 0..m {
q[[i, j]] = col[i];
}
}
if q.nrows() >= q.ncols() {
let (q_final, _) = qr(&q.view(), None)?;
let k_final = k.min(q_final.ncols());
Ok(q_final.slice(s![.., ..k_final]).to_owned())
} else {
let (u_final, _, _) = svd(&q.view(), false, None)?;
let k_final = k.min(u_final.ncols());
Ok(u_final.slice(s![.., ..k_final]).to_owned())
}
}
pub fn randomized_svd<F>(
a: &ArrayView2<F>,
config: &RandomizedConfig,
) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let k = config.rank;
let (m, n) = a.dim();
if k == 0 {
return Err(LinalgError::InvalidInput(
"Target rank must be greater than 0".to_string(),
));
}
if k > m.min(n) {
return Err(LinalgError::InvalidInput(format!(
"Target rank ({k}) exceeds min(m, n) = {}",
m.min(n)
)));
}
let q = randomized_range_finder(
a,
k,
Some(config.oversampling),
Some(config.power_iterations),
)?;
let b = q.t().dot(a);
let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
let u = q.dot(&u_b);
let k_actual = k.min(sigma.len()).min(u.ncols()).min(vt.nrows());
let u_k = u.slice(s![.., ..k_actual]).to_owned();
let s_k = sigma.slice(s![..k_actual]).to_owned();
let vt_k = vt.slice(s![..k_actual, ..]).to_owned();
Ok((u_k, s_k, vt_k))
}
pub fn single_pass_svd<F>(
a: &ArrayView2<F>,
rank: usize,
oversampling: Option<usize>,
) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (m, n) = a.dim();
let p = oversampling.unwrap_or(10);
let l = (rank + p).min(m).min(n);
if rank == 0 {
return Err(LinalgError::InvalidInput(
"Target rank must be greater than 0".to_string(),
));
}
if rank > m.min(n) {
return Err(LinalgError::InvalidInput(format!(
"Target rank ({rank}) exceeds min(m, n) = {}",
m.min(n)
)));
}
let omega = gaussian_random_matrix::<F>(n, l)?;
let psi = gaussian_random_matrix::<F>(m, l)?;
let y = a.dot(&omega);
let z = a.t().dot(&psi);
let q = if y.nrows() >= y.ncols() {
let (q_tmp, _) = qr(&y.view(), None)?;
let l_a = l.min(q_tmp.ncols());
q_tmp.slice(s![.., ..l_a]).to_owned()
} else {
let (u_tmp, _, _) = svd(&y.view(), false, None)?;
let l_a = l.min(u_tmp.ncols()).min(m);
u_tmp.slice(s![.., ..l_a]).to_owned()
};
let b = q.t().dot(a);
let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
let u = q.dot(&u_b);
let k = rank.min(sigma.len()).min(u.ncols()).min(vt.nrows());
let u_k = u.slice(s![.., ..k]).to_owned();
let s_k = sigma.slice(s![..k]).to_owned();
let vt_k = vt.slice(s![..k, ..]).to_owned();
Ok((u_k, s_k, vt_k))
}
pub fn randomized_low_rank<F>(
a: &ArrayView2<F>,
rank: usize,
oversampling: Option<usize>,
power_iterations: Option<usize>,
) -> LinalgResult<(Array2<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (m, n) = a.dim();
if rank == 0 {
return Err(LinalgError::InvalidInput(
"Target rank must be greater than 0".to_string(),
));
}
if rank > m.min(n) {
return Err(LinalgError::InvalidInput(format!(
"Target rank ({rank}) exceeds min(m, n) = {}",
m.min(n)
)));
}
let q = randomized_range_finder(a, rank, oversampling, power_iterations)?;
let b = q.t().dot(a);
let (u_b, sigma, vt) = svd(&b.view(), false, None)?;
let k = rank.min(sigma.len()).min(u_b.ncols()).min(vt.nrows());
let u_bk = u_b.slice(s![.., ..k]).to_owned();
let mut l = q.dot(&u_bk);
for j in 0..k {
let sj = sigma[j];
for i in 0..m {
l[[i, j]] *= sj;
}
}
let r = vt.slice(s![..k, ..]).to_owned();
Ok((l, r))
}
pub fn approximation_error<F>(a: &ArrayView2<F>, q: &ArrayView2<F>) -> LinalgResult<F>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (m, n) = a.dim();
if q.nrows() != m {
return Err(LinalgError::DimensionError(format!(
"Q has {} rows but A has {} rows",
q.nrows(),
m
)));
}
let qt_a = q.t().dot(a);
let q_qt_a = q.dot(&qt_a);
let mut frobenius_sq = F::zero();
for i in 0..m {
for j in 0..n {
let diff = a[[i, j]] - q_qt_a[[i, j]];
frobenius_sq += diff * diff;
}
}
Ok(frobenius_sq.sqrt())
}
pub fn randomized_pca<F>(
data: &ArrayView2<F>,
n_components: usize,
whiten: bool,
power_iterations: Option<usize>,
) -> LinalgResult<RandomizedPcaResult<F>>
where
F: Float
+ NumAssign
+ Sum
+ Debug
+ scirs2_core::ndarray::ScalarOperand
+ Send
+ Sync
+ 'static,
{
let (n_samples, n_features) = data.dim();
if n_components == 0 {
return Err(LinalgError::InvalidInput(
"Number of components must be greater than 0".to_string(),
));
}
if n_components > n_features.min(n_samples) {
return Err(LinalgError::InvalidInput(format!(
"n_components ({n_components}) exceeds min(n_samples, n_features) = {}",
n_features.min(n_samples)
)));
}
let mut mean = Array1::zeros(n_features);
let n_f = F::from(n_samples)
.ok_or_else(|| LinalgError::ComputationError("Failed to convert n_samples".to_string()))?;
for j in 0..n_features {
let col_sum: F = data.column(j).sum();
mean[j] = col_sum / n_f;
}
let mut centered = data.to_owned();
for i in 0..n_samples {
for j in 0..n_features {
centered[[i, j]] -= mean[j];
}
}
let config = RandomizedConfig::new(n_components)
.with_oversampling(10)
.with_power_iterations(power_iterations.unwrap_or(2));
let (u, sigma, vt) = randomized_svd(¢ered.view(), &config)?;
let k = sigma.len();
let denom = F::from(n_samples.saturating_sub(1).max(1)).ok_or_else(|| {
LinalgError::ComputationError("Failed to convert denominator".to_string())
})?;
let explained_variance = sigma.mapv(|s| s * s / denom);
let total_var = {
let mut total = F::zero();
for j in 0..n_features {
let col = centered.column(j);
let col_var: F = col.iter().fold(F::zero(), |acc, &x| acc + x * x) / denom;
total += col_var;
}
total
};
let explained_variance_ratio = if total_var > F::zero() {
explained_variance.mapv(|v| v / total_var)
} else {
Array1::zeros(k)
};
let components = if whiten {
let mut whitened = vt.slice(s![..k, ..]).to_owned();
for i in 0..k {
if sigma[i] > F::epsilon() {
let scale = F::one() / sigma[i];
for j in 0..n_features {
whitened[[i, j]] *= scale;
}
}
}
whitened
} else {
vt.slice(s![..k, ..]).to_owned()
};
Ok(RandomizedPcaResult {
components,
explained_variance,
explained_variance_ratio,
singular_values: sigma.slice(s![..k]).to_owned(),
mean,
})
}
pub fn pca_transform<F>(
data: &ArrayView2<F>,
pca_result: &RandomizedPcaResult<F>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (n_samples, n_features) = data.dim();
if n_features != pca_result.mean.len() {
return Err(LinalgError::DimensionError(format!(
"Data has {} features but PCA was fitted with {} features",
n_features,
pca_result.mean.len()
)));
}
let mut centered = data.to_owned();
for i in 0..n_samples {
for j in 0..n_features {
centered[[i, j]] -= pca_result.mean[j];
}
}
let transformed = centered.dot(&pca_result.components.t());
Ok(transformed)
}
pub fn pca_inverse_transform<F>(
transformed: &ArrayView2<F>,
pca_result: &RandomizedPcaResult<F>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
let (n_samples, n_components) = transformed.dim();
let n_features = pca_result.mean.len();
if n_components != pca_result.components.nrows() {
return Err(LinalgError::DimensionError(format!(
"Transformed data has {} components but PCA has {} components",
n_components,
pca_result.components.nrows()
)));
}
let mut reconstructed = transformed.dot(&pca_result.components);
for i in 0..n_samples {
for j in 0..n_features {
reconstructed[[i, j]] += pca_result.mean[j];
}
}
Ok(reconstructed)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
fn make_low_rank_matrix(m: usize, n: usize, rank: usize) -> Array2<f64> {
let mut rng = scirs2_core::random::rng();
let normal =
Normal::new(0.0, 1.0).unwrap_or_else(|_| panic!("Failed to create distribution"));
let mut a_left = Array2::zeros((m, rank));
let mut a_right = Array2::zeros((rank, n));
for i in 0..m {
for j in 0..rank {
a_left[[i, j]] = normal.sample(&mut rng);
}
}
for i in 0..rank {
for j in 0..n {
a_right[[i, j]] = normal.sample(&mut rng);
}
}
a_left.dot(&a_right)
}
#[test]
fn test_randomized_range_finder_basic() {
let a = array![
[3.0, 1.0, 0.5],
[1.0, 3.0, 0.5],
[0.5, 0.5, 2.0],
[1.0, 1.0, 1.0]
];
let q = randomized_range_finder(&a.view(), 2, Some(1), Some(1));
assert!(q.is_ok());
let q = q.expect("range finder failed");
assert_eq!(q.nrows(), 4);
assert!(q.ncols() >= 2);
let qtq = q.t().dot(&q);
for i in 0..qtq.nrows() {
for j in 0..qtq.ncols() {
if i == j {
assert!(
(qtq[[i, j]] - 1.0).abs() < 1e-6,
"Q^TQ not identity on diagonal"
);
} else {
assert!(qtq[[i, j]].abs() < 1e-6, "Q^TQ not identity off-diagonal");
}
}
}
}
#[test]
fn test_randomized_range_finder_error_cases() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
assert!(randomized_range_finder(&a.view(), 0, None, None).is_err());
assert!(randomized_range_finder(&a.view(), 5, None, None).is_err());
}
#[test]
fn test_adaptive_range_finder() {
let a = make_low_rank_matrix(20, 15, 3);
let q = adaptive_range_finder(&a.view(), 1e-6, Some(10), Some(2));
assert!(q.is_ok());
let q = q.expect("adaptive range finder failed");
assert!(q.ncols() >= 3, "Should detect at least rank 3");
}
#[test]
fn test_randomized_svd_basic() {
let a = array![
[3.0, 1.0, 0.5],
[1.0, 3.0, 0.5],
[0.5, 0.5, 2.0],
[1.0, 1.0, 1.0]
];
let config = RandomizedConfig::new(2)
.with_oversampling(1)
.with_power_iterations(2);
let result = randomized_svd(&a.view(), &config);
assert!(result.is_ok());
let (u, s, vt) = result.expect("randomized SVD failed");
assert_eq!(u.nrows(), 4);
assert_eq!(u.ncols(), 2);
assert_eq!(s.len(), 2);
assert_eq!(vt.nrows(), 2);
assert_eq!(vt.ncols(), 3);
assert!(s[0] > 0.0);
assert!(s[0] >= s[1]);
}
#[test]
fn test_randomized_svd_low_rank() {
let a = make_low_rank_matrix(30, 20, 3);
let config = RandomizedConfig::new(3).with_power_iterations(3);
let result = randomized_svd(&a.view(), &config);
assert!(result.is_ok());
let (u, s, vt) = result.expect("randomized SVD failed");
let mut reconstructed = Array2::zeros((30, 20));
for i in 0..30 {
for j in 0..20 {
let mut val = 0.0;
for k in 0..3 {
val += u[[i, k]] * s[k] * vt[[k, j]];
}
reconstructed[[i, j]] = val;
}
}
let mut error = 0.0;
let mut total = 0.0;
for i in 0..30 {
for j in 0..20 {
let diff = a[[i, j]] - reconstructed[[i, j]];
error += diff * diff;
total += a[[i, j]] * a[[i, j]];
}
}
let rel_error = (error / total).sqrt();
assert!(
rel_error < 0.1,
"Reconstruction error too large: {rel_error}"
);
}
#[test]
fn test_randomized_svd_error_cases() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let config0 = RandomizedConfig::new(0);
assert!(randomized_svd(&a.view(), &config0).is_err());
let config5 = RandomizedConfig::new(5);
assert!(randomized_svd(&a.view(), &config5).is_err());
}
#[test]
fn test_single_pass_svd() {
let a = array![
[3.0, 1.0, 0.5],
[1.0, 3.0, 0.5],
[0.5, 0.5, 2.0],
[1.0, 1.0, 1.0]
];
let result = single_pass_svd(&a.view(), 2, Some(1));
assert!(result.is_ok());
let (u, s, vt) = result.expect("single pass SVD failed");
assert_eq!(u.nrows(), 4);
assert_eq!(u.ncols(), 2);
assert_eq!(s.len(), 2);
assert_eq!(vt.nrows(), 2);
assert_eq!(vt.ncols(), 3);
}
#[test]
fn test_single_pass_svd_errors() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
assert!(single_pass_svd(&a.view(), 0, None).is_err());
assert!(single_pass_svd(&a.view(), 5, None).is_err());
}
#[test]
fn test_randomized_low_rank() {
let a = make_low_rank_matrix(20, 15, 3);
let result = randomized_low_rank(&a.view(), 3, Some(5), Some(2));
assert!(result.is_ok());
let (l, r) = result.expect("low rank failed");
assert_eq!(l.nrows(), 20);
assert_eq!(l.ncols(), 3);
assert_eq!(r.nrows(), 3);
assert_eq!(r.ncols(), 15);
let approx = l.dot(&r);
let mut error = 0.0;
let mut total = 0.0;
for i in 0..20 {
for j in 0..15 {
let diff = a[[i, j]] - approx[[i, j]];
error += diff * diff;
total += a[[i, j]] * a[[i, j]];
}
}
let rel_error = if total > 0.0 {
(error / total).sqrt()
} else {
0.0
};
assert!(
rel_error < 0.2,
"Low-rank approximation error too large: {rel_error}"
);
}
#[test]
fn test_randomized_low_rank_errors() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
assert!(randomized_low_rank(&a.view(), 0, None, None).is_err());
assert!(randomized_low_rank(&a.view(), 5, None, None).is_err());
}
#[test]
fn test_approximation_error() {
let a = array![[3.0, 1.0], [1.0, 3.0], [0.5, 0.5]];
let q =
randomized_range_finder(&a.view(), 2, Some(0), Some(1)).expect("range finder failed");
let err = approximation_error(&a.view(), &q.view());
assert!(err.is_ok());
let err_val = err.expect("approx error failed");
assert!(
err_val < 1e-6,
"Full-rank approximation error should be small"
);
}
#[test]
fn test_approximation_error_dimension_mismatch() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let q = array![[1.0], [0.0], [0.0]]; assert!(approximation_error(&a.view(), &q.view()).is_err());
}
#[test]
fn test_randomized_pca_basic() {
let mut data = Array2::zeros((50, 5));
let mut rng = scirs2_core::random::rng();
let normal =
Normal::new(0.0, 1.0).unwrap_or_else(|_| panic!("Failed to create distribution"));
for i in 0..50 {
let c1 = normal.sample(&mut rng);
let c2 = normal.sample(&mut rng);
data[[i, 0]] = c1 * 3.0;
data[[i, 1]] = c1 * 3.0 + normal.sample(&mut rng) * 0.1;
data[[i, 2]] = c2 * 2.0;
data[[i, 3]] = c2 * 2.0 + normal.sample(&mut rng) * 0.1;
data[[i, 4]] = normal.sample(&mut rng) * 0.01;
}
let result = randomized_pca(&data.view(), 2, false, Some(3));
assert!(result.is_ok());
let pca = result.expect("PCA failed");
assert_eq!(pca.components.nrows(), 2);
assert_eq!(pca.components.ncols(), 5);
assert_eq!(pca.explained_variance.len(), 2);
assert_eq!(pca.explained_variance_ratio.len(), 2);
assert_eq!(pca.singular_values.len(), 2);
assert_eq!(pca.mean.len(), 5);
let total_explained: f64 = pca.explained_variance_ratio.sum();
assert!(
total_explained > 0.8,
"Top 2 components should explain >80% variance, got {total_explained}"
);
}
#[test]
fn test_randomized_pca_whiten() {
let data = array![
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0],
[13.0, 14.0, 15.0]
];
let result = randomized_pca(&data.view(), 2, true, Some(1));
assert!(result.is_ok());
let pca = result.expect("whitened PCA failed");
assert_eq!(pca.components.nrows(), 2);
}
#[test]
fn test_randomized_pca_error_cases() {
let data = array![[1.0, 2.0], [3.0, 4.0]];
assert!(randomized_pca(&data.view(), 0, false, None).is_err());
assert!(randomized_pca(&data.view(), 5, false, None).is_err());
}
#[test]
fn test_pca_transform_and_inverse() {
let data = array![
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0]
];
let pca = randomized_pca(&data.view(), 2, false, Some(2)).expect("PCA failed");
let transformed = pca_transform(&data.view(), &pca).expect("transform failed");
assert_eq!(transformed.nrows(), 4);
assert_eq!(transformed.ncols(), 2);
let reconstructed =
pca_inverse_transform(&transformed.view(), &pca).expect("inverse transform failed");
assert_eq!(reconstructed.nrows(), 4);
assert_eq!(reconstructed.ncols(), 3);
for i in 0..4 {
for j in 0..3 {
assert!(
(data[[i, j]] - reconstructed[[i, j]]).abs() < 1.0,
"Reconstruction error too large at [{i}, {j}]"
);
}
}
}
#[test]
fn test_pca_transform_dimension_mismatch() {
let data = array![[1.0, 2.0], [3.0, 4.0]];
let pca = randomized_pca(&data.view(), 1, false, Some(1)).expect("PCA failed");
let wrong_data = array![[1.0, 2.0, 3.0]]; assert!(pca_transform(&wrong_data.view(), &pca).is_err());
}
#[test]
fn test_config_builder() {
let config = RandomizedConfig::new(5)
.with_oversampling(20)
.with_power_iterations(3)
.with_seed(42);
assert_eq!(config.rank, 5);
assert_eq!(config.oversampling, 20);
assert_eq!(config.power_iterations, 3);
assert_eq!(config.seed, Some(42));
}
#[test]
fn test_randomized_svd_identity_like() {
let a = array![
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[0.0, 0.0, 0.0]
];
let config = RandomizedConfig::new(3)
.with_oversampling(0)
.with_power_iterations(1);
let result = randomized_svd(&a.view(), &config);
assert!(result.is_ok());
let (_u, s, _vt) = result.expect("SVD of identity-like failed");
for i in 0..s.len() {
assert!(
(s[i] - 1.0).abs() < 0.1,
"Singular value {} = {}, expected ~1.0",
i,
s[i]
);
}
}
}