use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{Distribution, Normal};
use std::iter::Sum;
use crate::decomposition::{qr, svd};
use crate::error::{LinalgError, LinalgResult};
/// Shorthand for a CUR decomposition result: the factors `(C, U, R)` plus the
/// indices of the columns and rows of the original matrix that were selected.
type CURResult<F> = LinalgResult<(Array2<F>, Array2<F>, Array2<F>, Vec<usize>, Vec<usize>)>;
#[allow(dead_code)]
/// Rank-`k` randomized SVD of `a` (Halko/Martinsson/Tropp style range finder).
///
/// Sketches the range of `a` with a Gaussian test matrix of width
/// `k + oversampling`, optionally sharpens it with `power_iterations` rounds of
/// `A Aᵀ` multiplication, then takes an exact SVD of the small projected matrix.
///
/// Returns `(U_k, S_k, Vt_k)` with shapes `(m, k)`, `(k,)`, `(k, n)`.
/// Errors if `k == 0`, `k > min(m, n)`, or `k + oversampling > n`.
pub fn randomized_svd<F>(
    a: &ArrayView2<F>,
    k: usize,
    oversampling: Option<usize>,
    power_iterations: Option<usize>,
    workers: Option<usize>,
) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let (m, n) = a.dim();
    let extra = oversampling.unwrap_or(10);
    let n_power = power_iterations.unwrap_or(0);
    let l = k + extra;
    if k == 0 {
        return Err(LinalgError::ShapeError(
            "Target rank k must be greater than 0".to_string(),
        ));
    }
    if k > m.min(n) {
        return Err(LinalgError::ShapeError(format!(
            "Target rank k ({}) cannot exceed min(m, n) = {}",
            k,
            m.min(n)
        )));
    }
    if l > n {
        return Err(LinalgError::ShapeError(format!(
            "Oversampled dimension l ({l}) cannot exceed n = {n}"
        )));
    }
    if let Some(num_workers) = workers {
        // Thread-count hint for the underlying backend; process-global side effect.
        std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
    }
    // Gaussian test matrix Omega of shape (n, l).
    let mut rng = scirs2_core::random::rng();
    let gaussian = Normal::new(0.0, 1.0)
        .map_err(|_| LinalgError::ShapeError("Failed to create normal distribution".to_string()))?;
    let mut sketch = Array2::zeros((n, l));
    for row in 0..n {
        for col in 0..l {
            sketch[[row, col]] = F::from(gaussian.sample(&mut rng)).unwrap_or(F::zero());
        }
    }
    // Range sample Y = A * Omega, optionally refined by power iterations.
    let mut range = a.dot(&sketch);
    for _ in 0..n_power {
        range = a.dot(&a.t().dot(&range));
    }
    // Orthonormal basis Q for the sampled range, then project: B = Qᵀ A.
    let (basis, _) = qr(&range.view(), workers)?;
    let projected = basis.t().dot(a);
    // Exact SVD of the small matrix; lift U back with the basis and truncate to k.
    let (u_small, singular, vt) = svd(&projected.view(), false, workers)?;
    let u_full = basis.dot(&u_small);
    Ok((
        u_full.slice(scirs2_core::ndarray::s![.., ..k]).to_owned(),
        singular.slice(scirs2_core::ndarray::s![..k]).to_owned(),
        vt.slice(scirs2_core::ndarray::s![..k, ..]).to_owned(),
    ))
}
#[allow(dead_code)]
/// Exact SVD of `a` truncated to the leading `k` components.
///
/// Computes the full (thin) SVD and slices off the first `k` singular triplets.
/// Returns `(U_k, S_k, Vt_k)` with shapes `(m, k)`, `(k,)`, `(k, n)`.
/// Errors if `k == 0` or `k > min(m, n)`.
pub fn truncated_svd<F>(
    a: &ArrayView2<F>,
    k: usize,
    workers: Option<usize>,
) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let (rows, cols) = a.dim();
    if k == 0 {
        return Err(LinalgError::ShapeError(
            "Number of components k must be greater than 0".to_string(),
        ));
    }
    if k > rows.min(cols) {
        return Err(LinalgError::ShapeError(format!(
            "Number of components k ({}) cannot exceed min(m, n) = {}",
            k,
            rows.min(cols)
        )));
    }
    // Full decomposition first, then keep only the top-k triplets.
    let (u, s, vt) = svd(a, false, workers)?;
    Ok((
        u.slice(scirs2_core::ndarray::s![.., ..k]).to_owned(),
        s.slice(scirs2_core::ndarray::s![..k]).to_owned(),
        vt.slice(scirs2_core::ndarray::s![..k, ..]).to_owned(),
    ))
}
#[allow(dead_code)]
/// Principal component analysis of `data` (samples in rows, features in columns).
///
/// Centers each feature, scales by `1/sqrt(n_samples - 1)` so that squared
/// singular values equal sample variances, then takes a truncated SVD
/// (randomized for large matrices).
///
/// Returns `(components, explained_variance, explained_variance_ratio)` where
/// `components` is `(n_components, n_features)` (rows are principal axes).
///
/// Errors if `n_components == 0`, `n_samples < 2` (variance is undefined), or
/// `n_components > min(n_samples, n_features)`.
pub fn pca<F>(
    data: &ArrayView2<F>,
    n_components: usize,
    workers: Option<usize>,
) -> LinalgResult<(Array2<F>, Array1<F>, Array1<F>)>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let (n_samples, n_features) = data.dim();
    if n_components == 0 {
        return Err(LinalgError::ShapeError(
            "Number of components must be greater than 0".to_string(),
        ));
    }
    // Guard: with fewer than 2 samples the 1/sqrt(n-1) scale below would divide
    // by zero and fill the matrix with NaN/inf before the SVD.
    if n_samples < 2 {
        return Err(LinalgError::ShapeError(
            "PCA requires at least 2 samples".to_string(),
        ));
    }
    if n_components > n_features.min(n_samples) {
        return Err(LinalgError::ShapeError(format!(
            "Number of components ({}) cannot exceed min(n_samples, n_features) = {}",
            n_components,
            n_features.min(n_samples)
        )));
    }
    // Center each feature (column) to zero mean.
    let mut centered_data = data.to_owned();
    for j in 0..n_features {
        let mean = data.column(j).sum() / F::from(n_samples).expect("Operation failed");
        for i in 0..n_samples {
            centered_data[[i, j]] -= mean;
        }
    }
    // Scale so that s[i]^2 is the sample variance along component i.
    let scale = F::from(n_samples - 1).expect("Operation failed").sqrt();
    centered_data.mapv_inplace(|x| x / scale);
    // Randomized SVD is cheaper for large matrices; exact truncated SVD otherwise.
    let (_u, s, vt) = if n_samples > 1000 && n_features > 1000 {
        randomized_svd(&centered_data.view(), n_components, None, None, workers)?
    } else {
        truncated_svd(&centered_data.view(), n_components, workers)?
    };
    let components = vt;
    let explained_variance = s.mapv(|x| x * x);
    let total_variance = explained_variance.sum();
    // Avoid 0/0 for an all-zero (constant) data matrix.
    let explained_variance_ratio = if total_variance > F::zero() {
        explained_variance.mapv(|x| x / total_variance)
    } else {
        Array1::zeros(n_components)
    };
    Ok((components, explained_variance, explained_variance_ratio))
}
#[allow(dead_code)]
/// Non-negative matrix factorization `A ≈ W H` via multiplicative updates.
///
/// `W` is `(m, k)` and `H` is `(k, n)`, both initialized with uniform random
/// values and kept non-negative by the update rules. Iterates up to `max_iter`
/// times (default 100) or until the squared-error change drops below
/// `tolerance` (default 1e-6).
///
/// Errors if `k == 0`, `k > min(m, n)`, or `a` contains a negative entry.
pub fn nmf<F>(
    a: &ArrayView2<F>,
    k: usize,
    max_iter: Option<usize>,
    tolerance: Option<F>,
    workers: Option<usize>,
) -> LinalgResult<(Array2<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let (m, n) = a.dim();
    let iterations = max_iter.unwrap_or(100);
    let tol = tolerance.unwrap_or_else(|| F::from(1e-6).expect("Operation failed"));
    if k == 0 {
        return Err(LinalgError::ShapeError(
            "Number of components k must be greater than 0".to_string(),
        ));
    }
    if k > m.min(n) {
        return Err(LinalgError::ShapeError(format!(
            "Number of components k ({}) cannot exceed min(m, n) = {}",
            k,
            m.min(n)
        )));
    }
    if a.iter().any(|&v| v < F::zero()) {
        return Err(LinalgError::ShapeError(
            "Input matrix must be non-negative for NMF".to_string(),
        ));
    }
    if let Some(num_workers) = workers {
        // Thread-count hint for the underlying backend; process-global side effect.
        std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
    }
    // Random non-negative initialization of both factors.
    let mut rng = scirs2_core::random::rng();
    let mut w = Array2::zeros((m, k));
    let mut h = Array2::zeros((k, n));
    for entry in w.iter_mut() {
        *entry = F::from(rng.random::<f64>()).expect("Operation failed");
    }
    for entry in h.iter_mut() {
        *entry = F::from(rng.random::<f64>()).expect("Operation failed");
    }
    let mut prev_error = F::from(f64::INFINITY).expect("Operation failed");
    for _ in 0..iterations {
        // H <- H .* (WᵀA) ./ (WᵀW H + eps); eps keeps denominators non-zero.
        let wta = w.t().dot(a);
        let wtwh = w.t().dot(&w).dot(&h);
        for ((i, j), entry) in h.indexed_iter_mut() {
            *entry = *entry * wta[[i, j]] / (wtwh[[i, j]] + F::epsilon());
        }
        // W <- W .* (AHᵀ) ./ (W H Hᵀ + eps).
        let aht = a.dot(&h.t());
        let whht = w.dot(&h).dot(&h.t());
        for ((i, j), entry) in w.indexed_iter_mut() {
            *entry = *entry * aht[[i, j]] / (whht[[i, j]] + F::epsilon());
        }
        // Squared Frobenius reconstruction error, checked for convergence.
        let approx = w.dot(&h);
        let mut error = F::zero();
        for (&target, &fitted) in a.iter().zip(approx.iter()) {
            let diff = target - fitted;
            error += diff * diff;
        }
        if (prev_error - error).abs() < tol {
            break;
        }
        prev_error = error;
    }
    Ok((w, h))
}
#[allow(dead_code)]
/// CUR decomposition: approximates `A ≈ C · U · R` where `C` holds `k` actual
/// columns of `A`, `R` holds `k` actual rows, and `U` is the pseudoinverse of
/// their intersection submatrix `W`. Columns/rows are drawn by leverage-score
/// sampling from approximate singular bases.
///
/// Returns `(C, U, R, column_indices, row_indices)`.
/// Errors if `k == 0`, `k > min(m, n)`, or the matrix has insufficient rank.
pub fn cur_decomposition<F>(
    a: &ArrayView2<F>,
    k: usize,
    oversampling: Option<usize>,
    workers: Option<usize>,
) -> CURResult<F>
where
    F: Float + NumAssign + Sum + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    let (m, n) = a.dim();
    let oversampling = oversampling.unwrap_or(5);
    // Working rank for leverage-score estimation, capped by both dimensions.
    let l = (k + oversampling).min(n).min(m);
    if k == 0 {
        return Err(LinalgError::ShapeError(
            "Target rank k must be greater than 0".to_string(),
        ));
    }
    if k > m.min(n) {
        return Err(LinalgError::ShapeError(format!(
            "Target rank k ({}) cannot exceed min(m, n) = {}",
            k,
            m.min(n)
        )));
    }
    if let Some(num_workers) = workers {
        // Thread-count hint for the underlying backend; process-global side effect.
        std::env::set_var("OMP_NUM_THREADS", num_workers.to_string());
    }
    // Approximate left singular basis for column leverage scores.
    // BUGFIX: clamp the inner oversampling so randomized_svd's oversampled
    // dimension (l + oversampling) cannot exceed n; previously the default
    // oversampling triggered a spurious ShapeError on small matrices.
    let col_oversampling = oversampling.min(n.saturating_sub(l));
    let (u_approx, _s_approx, _vt_approx) =
        randomized_svd(a, l, Some(col_oversampling), Some(1), workers)?;
    // Column leverage score: squared norm of the column's projection onto the basis.
    let mut col_leverage_scores = Array1::zeros(n);
    for j in 0..n {
        let col = a.column(j);
        let col_proj = u_approx.t().dot(&col);
        col_leverage_scores[j] = col_proj.iter().fold(F::zero(), |acc, &x| acc + x * x);
    }
    let total_leverage: F = col_leverage_scores.sum();
    if total_leverage <= F::epsilon() {
        return Err(LinalgError::ComputationError(
            "Matrix has insufficient rank for CUR decomposition".to_string(),
        ));
    }
    // Normalize scores into a probability distribution.
    col_leverage_scores.mapv_inplace(|x| x / total_leverage);
    let mut rng = scirs2_core::random::rng();
    let mut selected_cols = Vec::new();
    let mut col_indices = Vec::new();
    // Sample k distinct columns by inverse-CDF on the leverage distribution.
    for _ in 0..k {
        let r: f64 = rng.random();
        let r_f = F::from(r).expect("Operation failed");
        let mut cumsum = F::zero();
        for (j, &score) in col_leverage_scores.iter().enumerate() {
            cumsum += score;
            if cumsum >= r_f && !col_indices.contains(&j) {
                col_indices.push(j);
                selected_cols.push(a.column(j).to_owned());
                break;
            }
        }
    }
    // Deterministic top-up in case sampling under-selected (duplicates skipped).
    while col_indices.len() < k {
        for j in 0..n {
            if !col_indices.contains(&j) {
                col_indices.push(j);
                selected_cols.push(a.column(j).to_owned());
                break;
            }
        }
    }
    col_indices.truncate(k);
    selected_cols.truncate(k);
    // Assemble C from the selected columns.
    let mut c = Array2::zeros((m, k));
    for (idx, col) in selected_cols.iter().enumerate() {
        for i in 0..m {
            c[[i, idx]] = col[i];
        }
    }
    // Row leverage scores via the transpose (its left singular basis spans the
    // row space of A); clamp inner oversampling against m this time.
    let a_t = a.t().to_owned();
    let row_oversampling = oversampling.min(m.saturating_sub(l));
    let (u_row_approx, _s_row_approx, _vt_row_approx) =
        randomized_svd(&a_t.view(), l, Some(row_oversampling), Some(1), workers)?;
    let mut row_leverage_scores = Array1::zeros(m);
    for i in 0..m {
        let row = a.row(i);
        let row_proj = u_row_approx.t().dot(&row);
        row_leverage_scores[i] = row_proj.iter().fold(F::zero(), |acc, &x| acc + x * x);
    }
    let total_row_leverage: F = row_leverage_scores.sum();
    if total_row_leverage <= F::epsilon() {
        return Err(LinalgError::ComputationError(
            "Matrix has insufficient rank for row selection".to_string(),
        ));
    }
    row_leverage_scores.mapv_inplace(|x| x / total_row_leverage);
    let mut selected_rows = Vec::new();
    let mut row_indices = Vec::new();
    // Same inverse-CDF sampling, now over rows.
    for _ in 0..k {
        let r: f64 = rng.random();
        let r_f = F::from(r).expect("Operation failed");
        let mut cumsum = F::zero();
        for (i, &score) in row_leverage_scores.iter().enumerate() {
            cumsum += score;
            if cumsum >= r_f && !row_indices.contains(&i) {
                row_indices.push(i);
                selected_rows.push(a.row(i).to_owned());
                break;
            }
        }
    }
    while row_indices.len() < k {
        for i in 0..m {
            if !row_indices.contains(&i) {
                row_indices.push(i);
                selected_rows.push(a.row(i).to_owned());
                break;
            }
        }
    }
    row_indices.truncate(k);
    selected_rows.truncate(k);
    // Assemble R from the selected rows.
    let mut r = Array2::zeros((k, n));
    for (idx, row) in selected_rows.iter().enumerate() {
        for j in 0..n {
            r[[idx, j]] = row[j];
        }
    }
    // Intersection submatrix W = A[row_indices, col_indices].
    let mut w = Array2::zeros((k, k));
    for (i_idx, &i) in row_indices.iter().enumerate() {
        for (j_idx, &j) in col_indices.iter().enumerate() {
            w[[i_idx, j_idx]] = a[[i, j]];
        }
    }
    // U = W⁺ via SVD, zeroing reciprocals of negligible singular values.
    let (u_w, s_w, vt_w) = svd(&w.view(), false, workers)?;
    let mut s_inv = Array1::zeros(k);
    for i in 0..k {
        if s_w[i] > F::epsilon() {
            s_inv[i] = F::one() / s_w[i];
        }
    }
    let s_inv_diag = Array2::from_diag(&s_inv);
    let u = vt_w.t().dot(&s_inv_diag).dot(&u_w.t());
    Ok((c, u, r, col_indices, row_indices))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // NOTE: several tests tolerate Err(_) because the SVD/QR backends may be
    // unavailable in some build configurations; shape assertions only run on Ok.

    #[test]
    fn test_randomized_svd_basic() {
        let a = array![
            [3.0, 1.0, 0.5],
            [1.0, 3.0, 0.5],
            [0.5, 0.5, 2.0],
            [1.0, 1.0, 1.0]
        ];
        match randomized_svd(&a.view(), 2, Some(1), Some(2), None) {
            Ok((u, s, vt)) => {
                assert_eq!(u.shape(), [4, 2]);
                assert_eq!(s.len(), 2);
                assert_eq!(vt.shape(), [2, 3]);
            }
            Err(_) => {
                // Backend may be unavailable; tolerated.
            }
        }
    }

    #[test]
    fn test_truncated_svd() {
        let a = array![[3.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]];
        match truncated_svd(&a.view(), 2, None) {
            Ok((u, s, vt)) => {
                assert_eq!(u.shape(), [3, 2]);
                assert_eq!(s.len(), 2);
                assert_eq!(vt.shape(), [2, 3]);
                // Singular values must come back in non-increasing order.
                assert!(s[0] >= s[1]);
            }
            Err(_) => {
                // Backend may be unavailable; tolerated.
            }
        }
    }

    #[test]
    fn test_pca_basic() {
        // Second feature is exactly 2x the first, so one component should
        // capture essentially all the variance.
        let data = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]];
        let (components, explained_var, explained_var_ratio) =
            pca(&data.view(), 1, None).expect("Operation failed");
        assert_eq!(components.shape(), [1, 2]);
        assert_eq!(explained_var.len(), 1);
        assert_eq!(explained_var_ratio.len(), 1);
        assert!(explained_var_ratio[0] > 0.9);
    }

    #[test]
    fn test_nmf_basic() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let (w, h) = nmf(&a.view(), 2, Some(50), Some(1e-4), None).expect("Operation failed");
        assert_eq!(w.shape(), [2, 2]);
        assert_eq!(h.shape(), [2, 3]);
        // Both factors must stay non-negative under multiplicative updates.
        for &val in w.iter() {
            assert!(val >= 0.0);
        }
        for &val in h.iter() {
            assert!(val >= 0.0);
        }
        // Loose reconstruction bound: random init + few iterations.
        let reconstruction = w.dot(&h);
        let mut max_error = 0.0;
        for i in 0..2 {
            for j in 0..3 {
                let error = (a[[i, j]] - reconstruction[[i, j]]).abs();
                if error > max_error {
                    max_error = error;
                }
            }
        }
        assert!(max_error < 2.0);
    }

    #[test]
    fn test_randomized_svd_error_handling() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        // k == 0 is rejected.
        let result = randomized_svd(&a.view(), 0, None, None, None);
        assert!(result.is_err());
        // k > min(m, n) is rejected.
        let result = randomized_svd(&a.view(), 5, None, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cur_decomposition_basic() {
        let a = array![[3.0, 0.1, 0.2], [0.1, 3.0, 0.3], [0.2, 0.3, 3.0]];
        match cur_decomposition(&a.view(), 2, Some(0), None) {
            Ok((c, u, r, col_indices, row_indices)) => {
                assert_eq!(c.shape(), [3, 2]);
                assert_eq!(u.shape(), [2, 2]);
                assert_eq!(r.shape(), [2, 3]);
                assert_eq!(col_indices.len(), 2);
                assert_eq!(row_indices.len(), 2);
                // All selected indices must be in-bounds.
                for &idx in &col_indices {
                    assert!(idx < 3);
                }
                for &idx in &row_indices {
                    assert!(idx < 3);
                }
            }
            Err(_) => {
                // Backend may be unavailable; tolerated.
            }
        }
    }

    #[test]
    fn test_cur_decomposition_full_rank() {
        let a = array![[2.0, 0.5], [0.5, 2.0]];
        match cur_decomposition(&a.view(), 2, Some(0), None) {
            // Underscore the index bindings we don't assert on (avoids
            // unused-variable warnings).
            Ok((c, u, r, _col_indices, _row_indices)) => {
                assert_eq!(c.shape(), [2, 2]);
                assert_eq!(u.shape(), [2, 2]);
                assert_eq!(r.shape(), [2, 2]);
            }
            Err(_) => {
                // Backend may be unavailable; tolerated.
            }
        }
    }

    #[test]
    fn test_cur_decomposition_rectangular() {
        let a = array![
            [2.0, 1.0, 0.5],
            [1.0, 2.0, 1.0],
            [0.5, 1.0, 2.0],
            [1.0, 0.5, 1.0]
        ];
        match cur_decomposition(&a.view(), 2, Some(0), None) {
            Ok((c, u, r, col_indices, row_indices)) => {
                assert_eq!(c.shape(), [4, 2]);
                assert_eq!(u.shape(), [2, 2]);
                assert_eq!(r.shape(), [2, 3]);
                assert_eq!(col_indices.len(), 2);
                assert_eq!(row_indices.len(), 2);
            }
            Err(_) => {
                // Backend may be unavailable; tolerated.
            }
        }
    }

    #[test]
    fn test_cur_decomposition_error_handling() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        // k == 0 is rejected.
        let result = cur_decomposition(&a.view(), 0, None, None);
        assert!(result.is_err());
        // k > min(m, n) is rejected.
        let result = cur_decomposition(&a.view(), 5, None, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cur_decomposition_interpretability() {
        // CUR factors should be built from actual rows/columns of A.
        let a = array![[2.0, 0.0, 1.0], [0.0, 3.0, 0.0], [1.0, 0.0, 2.0]];
        match cur_decomposition(&a.view(), 2, Some(0), None) {
            Ok((c, u, r, col_indices, row_indices)) => {
                assert_eq!(c.shape()[0], 3);
                assert_eq!(r.shape()[1], 3);
                assert_eq!(col_indices.len(), 2);
                assert_eq!(row_indices.len(), 2);
                // Silence unused warnings while documenting the shape intent.
                let _ = u;
            }
            Err(_) => {
                // Backend may be unavailable; tolerated.
            }
        }
    }
}