use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_linalg::svd;
use crate::error::{Result, TransformError};
/// Rank-truncated thin SVD `A ≈ U · diag(s) · Vᵀ` that is maintained
/// incrementally as rows or columns are appended.
///
/// The three factor fields are `None` until the first call to `initialize`,
/// `add_row` or `add_column`; afterwards they are always set together.
#[derive(Debug, Clone)]
pub struct IncrementalSVD {
    /// Left singular vectors, shape `n_rows × rank`.
    u: Option<Array2<f64>>,
    /// Singular values, length `rank`.
    s: Option<Array1<f64>>,
    /// Right singular vectors stored transposed, shape `rank × n_cols`.
    vt: Option<Array2<f64>>,
    /// Upper bound on the number of singular triplets kept after each update.
    max_rank: usize,
    /// Number of rows of the represented matrix.
    n_rows: usize,
    /// Number of columns of the represented matrix.
    n_cols: usize,
}
impl IncrementalSVD {
    /// Residual norms at or below this value are treated as zero: the
    /// normalised residual direction is then dropped rather than formed by
    /// dividing by a vanishing norm.
    const RESIDUAL_EPS: f64 = 1e-15;

    /// Creates an empty incremental SVD that keeps at most `max_rank`
    /// singular triplets.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` if `max_rank` is zero.
    pub fn new(max_rank: usize) -> Result<Self> {
        if max_rank == 0 {
            return Err(TransformError::InvalidInput(
                "max_rank must be positive".to_string(),
            ));
        }
        Ok(Self {
            u: None,
            s: None,
            vt: None,
            max_rank,
            n_rows: 0,
            n_cols: 0,
        })
    }

    /// (Re)initialises the factorisation from a batch SVD of `data`, keeping
    /// at most `max_rank` components. Any previously accumulated state is
    /// replaced. On error the existing state is left untouched.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` if `data` has zero rows or zero
    /// columns, or the error produced by the underlying SVD.
    pub fn initialize(&mut self, data: &Array2<f64>) -> Result<()> {
        let (m, n) = data.dim();
        if m == 0 || n == 0 {
            return Err(TransformError::InvalidInput(
                "Data matrix must be non-empty".to_string(),
            ));
        }
        let (u_full, s_full, vt_full) = compute_svd(data)?;
        let k = self.max_rank.min(s_full.len());
        self.u = Some(u_full.slice(scirs2_core::ndarray::s![.., ..k]).to_owned());
        self.s = Some(s_full.slice(scirs2_core::ndarray::s![..k]).to_owned());
        self.vt = Some(vt_full.slice(scirs2_core::ndarray::s![..k, ..]).to_owned());
        self.n_rows = m;
        self.n_cols = n;
        Ok(())
    }

    /// Appends one row to the factorised matrix with a rank-one (Brand) update.
    ///
    /// If the SVD is not yet initialised, it is initialised from `row` as a
    /// `1 × row.len()` matrix.
    ///
    /// With the current factors `A ≈ U S Vᵀ` of rank `k`, the projection
    /// `p = Vᵀ r`, the residual `e = r − V p` and `ê = e / ‖e‖`:
    ///
    /// ```text
    /// [A; rᵀ] = [U 0; 0 1] · [S 0; pᵀ ‖e‖] · [V ê]ᵀ
    /// ```
    ///
    /// Only the `(k + 1) × (k + 1)` middle matrix is re-decomposed, and the
    /// result is truncated back to at most `max_rank` components.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` if `row.len()` differs from the
    /// current column count, or the error produced by the inner SVD.
    pub fn add_row(&mut self, row: &Array1<f64>) -> Result<()> {
        if let (Some(u), Some(s), Some(vt)) = (&self.u, &self.s, &self.vt) {
            if row.len() != self.n_cols {
                return Err(TransformError::InvalidInput(format!(
                    "Expected {} columns, got {}",
                    self.n_cols,
                    row.len()
                )));
            }
            let k = s.len();

            // Split the new row into coordinates `p` inside the current row
            // space and the orthogonal remainder `residual`.
            let p = vt.dot(row);
            let residual = row - &vt.t().dot(&p);
            let r_norm = residual.dot(&residual).sqrt();

            // Middle matrix [S 0; pᵀ ‖e‖]. For a *row* update the projection
            // belongs in the last ROW; putting it in the last column (the
            // column-update form) reconstructs the wrong matrix.
            let mut k_mat = Array2::<f64>::zeros((k + 1, k + 1));
            for (j, (&sigma, &proj)) in s.iter().zip(p.iter()).enumerate() {
                k_mat[[j, j]] = sigma;
                k_mat[[k, j]] = proj;
            }
            k_mat[[k, k]] = r_norm;

            let (u_k, s_k, vt_k) = compute_svd(&k_mat)?;
            let new_rank = self.max_rank.min(s_k.len());

            // New U = [U 0; 0 1] · U_k, truncated to `new_rank` columns.
            let old_m = self.n_rows;
            let mut new_u = Array2::<f64>::zeros((old_m + 1, new_rank));
            new_u
                .slice_mut(scirs2_core::ndarray::s![..old_m, ..])
                .assign(&u.dot(&u_k.slice(scirs2_core::ndarray::s![..k, ..new_rank])));
            new_u
                .row_mut(old_m)
                .assign(&u_k.slice(scirs2_core::ndarray::s![k, ..new_rank]));

            // New Vᵀ = V_kᵀ · [Vᵀ; êᵀ], truncated to `new_rank` rows.
            let mut new_vt = vt_k
                .slice(scirs2_core::ndarray::s![..new_rank, ..k])
                .dot(vt);
            if r_norm > Self::RESIDUAL_EPS {
                let e_hat = &residual / r_norm;
                let weights = vt_k.slice(scirs2_core::ndarray::s![..new_rank, k]);
                for (mut vt_row, &w) in new_vt.outer_iter_mut().zip(weights.iter()) {
                    vt_row.scaled_add(w, &e_hat);
                }
            }

            let new_s = s_k.slice(scirs2_core::ndarray::s![..new_rank]).to_owned();
            self.u = Some(new_u);
            self.s = Some(new_s);
            self.vt = Some(new_vt);
            self.n_rows += 1;
            Ok(())
        } else {
            let data = row.clone().insert_axis(Axis(0));
            self.initialize(&data)
        }
    }

    /// Appends one column to the factorised matrix with a rank-one (Brand)
    /// update.
    ///
    /// If the SVD is not yet initialised, it is initialised from `col` as a
    /// `col.len() × 1` matrix.
    ///
    /// With `A ≈ U S Vᵀ`, `p = Uᵀ c`, `e = c − U p` and `ê = e / ‖e‖`:
    ///
    /// ```text
    /// [A c] = [U ê] · [S p; 0 ‖e‖] · [V 0; 0 1]ᵀ
    /// ```
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` if `col.len()` differs from the
    /// current row count, or the error produced by the inner SVD.
    pub fn add_column(&mut self, col: &Array1<f64>) -> Result<()> {
        if let (Some(u), Some(s), Some(vt)) = (&self.u, &self.s, &self.vt) {
            if col.len() != self.n_rows {
                return Err(TransformError::InvalidInput(format!(
                    "Expected {} rows in column, got {}",
                    self.n_rows,
                    col.len()
                )));
            }
            let k = s.len();

            // Coordinates of the column in the current column space, plus the
            // orthogonal remainder.
            let p = u.t().dot(col);
            let residual = col - &u.dot(&p);
            let r_norm = residual.dot(&residual).sqrt();

            // Middle matrix [S p; 0 ‖e‖]: projection in the last column.
            let mut k_mat = Array2::<f64>::zeros((k + 1, k + 1));
            for (j, (&sigma, &proj)) in s.iter().zip(p.iter()).enumerate() {
                k_mat[[j, j]] = sigma;
                k_mat[[j, k]] = proj;
            }
            k_mat[[k, k]] = r_norm;

            let (u_k, s_k, vt_k) = compute_svd(&k_mat)?;
            let new_rank = self.max_rank.min(s_k.len());

            // New U = [U ê] · U_k, truncated to `new_rank` columns.
            let mut new_u = u.dot(&u_k.slice(scirs2_core::ndarray::s![..k, ..new_rank]));
            if r_norm > Self::RESIDUAL_EPS {
                let e_hat = &residual / r_norm;
                let weights = u_k.slice(scirs2_core::ndarray::s![k, ..new_rank]);
                for (mut u_col, &w) in new_u.axis_iter_mut(Axis(1)).zip(weights.iter()) {
                    u_col.scaled_add(w, &e_hat);
                }
            }

            // New Vᵀ = V_kᵀ · [Vᵀ 0; 0 1], truncated to `new_rank` rows.
            let old_n = self.n_cols;
            let mut new_vt = Array2::<f64>::zeros((new_rank, old_n + 1));
            new_vt
                .slice_mut(scirs2_core::ndarray::s![.., ..old_n])
                .assign(&vt_k.slice(scirs2_core::ndarray::s![..new_rank, ..k]).dot(vt));
            new_vt
                .column_mut(old_n)
                .assign(&vt_k.slice(scirs2_core::ndarray::s![..new_rank, k]));

            let new_s = s_k.slice(scirs2_core::ndarray::s![..new_rank]).to_owned();
            self.u = Some(new_u);
            self.s = Some(new_s);
            self.vt = Some(new_vt);
            self.n_cols += 1;
            Ok(())
        } else {
            let data = col.clone().insert_axis(Axis(1));
            self.initialize(&data)
        }
    }

    /// Appends every row of `rows`, in order, via [`Self::add_row`].
    ///
    /// # Errors
    ///
    /// Stops at the first failing row and returns its error; rows before it
    /// remain applied.
    pub fn add_rows(&mut self, rows: &Array2<f64>) -> Result<()> {
        for row in rows.outer_iter() {
            self.add_row(&row.to_owned())?;
        }
        Ok(())
    }

    /// Removes the most recently added (i.e. last) row from the factorised
    /// matrix.
    ///
    /// Dropping the last row of `A ≈ U S Vᵀ` leaves exactly `U_top S Vᵀ`,
    /// where `U_top` is `U` without its last row. `U_top` is no longer
    /// orthonormal, so the small matrix `U_top S` is re-decomposed as
    /// `U' S' Wᵀ`, and the new factors are `U'`, `S'` and `Wᵀ Vᵀ`. The result
    /// reconstructs the remaining rows of the current approximation exactly.
    ///
    /// `row` is only checked for length. The factors already encode the row
    /// being removed, so its values are not needed.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::InvalidInput` if the SVD is not initialised,
    /// if `row.len()` differs from the column count, or if only one row
    /// remains. Returns the inner SVD's error if re-decomposition fails.
    pub fn downdate_row(&mut self, row: &Array1<f64>) -> Result<()> {
        if self.u.is_none() {
            return Err(TransformError::InvalidInput(
                "Cannot downdate: SVD not initialised".to_string(),
            ));
        }
        if row.len() != self.n_cols {
            return Err(TransformError::InvalidInput(format!(
                "Expected {} columns, got {}",
                self.n_cols,
                row.len()
            )));
        }
        if self.n_rows <= 1 {
            return Err(TransformError::InvalidInput(
                "Cannot downdate: only one row remaining".to_string(),
            ));
        }
        let (u, s, vt) = match (&self.u, &self.s, &self.vt) {
            (Some(u), Some(s), Some(vt)) => (u, s, vt),
            _ => {
                return Err(TransformError::InvalidInput(
                    "SVD not initialised".to_string(),
                ))
            }
        };

        let kept = self.n_rows - 1;
        // U_top · diag(S): the remaining rows expressed in the current V basis.
        let remaining = Self::scale_columns(
            u.slice(scirs2_core::ndarray::s![..kept, ..]).to_owned(),
            s,
        );
        let (new_u, new_s, w_t) = compute_svd(&remaining)?;
        // Rotate the right basis accordingly. The rank cannot grow here, so no
        // extra truncation to `max_rank` is needed.
        let new_vt = w_t.dot(vt);

        self.u = Some(new_u);
        self.s = Some(new_s);
        self.vt = Some(new_vt);
        self.n_rows = kept;
        Ok(())
    }

    /// Current singular values, or `None` before initialisation.
    pub fn singular_values(&self) -> Option<&Array1<f64>> {
        self.s.as_ref()
    }

    /// Current left singular vectors (`n_rows × rank`), or `None` before
    /// initialisation.
    pub fn left_singular_vectors(&self) -> Option<&Array2<f64>> {
        self.u.as_ref()
    }

    /// Current right singular vectors stored transposed (`rank × n_cols`), or
    /// `None` before initialisation.
    pub fn right_singular_vectors(&self) -> Option<&Array2<f64>> {
        self.vt.as_ref()
    }

    /// Returns the current approximation `U · diag(s) · Vᵀ`.
    ///
    /// # Errors
    ///
    /// Returns `TransformError::NotFitted` if the SVD is not initialised.
    pub fn reconstruct(&self) -> Result<Array2<f64>> {
        let not_fitted = || TransformError::NotFitted("SVD not initialised".to_string());
        let u = self.u.as_ref().ok_or_else(not_fitted)?;
        let s = self.s.as_ref().ok_or_else(not_fitted)?;
        let vt = self.vt.as_ref().ok_or_else(not_fitted)?;
        Ok(Self::scale_columns(u.clone(), s).dot(vt))
    }

    /// Number of retained singular values (0 before initialisation).
    pub fn rank(&self) -> usize {
        self.s.as_ref().map_or(0, |s| s.len())
    }

    /// Number of rows of the represented matrix.
    pub fn n_rows(&self) -> usize {
        self.n_rows
    }

    /// Number of columns of the represented matrix.
    pub fn n_cols(&self) -> usize {
        self.n_cols
    }

    /// Maximum number of singular triplets kept after each update.
    pub fn max_rank(&self) -> usize {
        self.max_rank
    }

    /// Returns `mat · diag(scales)`, multiplying column `j` by `scales[j]`.
    fn scale_columns(mut mat: Array2<f64>, scales: &Array1<f64>) -> Array2<f64> {
        for (mut column, &factor) in mat.axis_iter_mut(Axis(1)).zip(scales.iter()) {
            column *= factor;
        }
        mat
    }
}
/// Thin SVD of `mat` via `scirs2_linalg::svd`, called with `false` for its
/// second flag (presumably `full_matrices` — confirm against scirs2_linalg).
///
/// Returns `(u, s, vt)` with `u` of shape `m × k`, `s` of length `k` and `vt`
/// of shape `k × n`, where `k = min(m, n, s.len())`.
///
/// # Errors
///
/// Wraps any failure of the underlying SVD in
/// `TransformError::ComputationError`.
fn compute_svd(mat: &Array2<f64>) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>)> {
    let (u_raw, s_raw, vt_raw) = svd(&mat.view(), false, None).map_err(|e| {
        TransformError::ComputationError(format!("SVD computation failed: {}", e))
    })?;
    let keep = s_raw.len().min(mat.nrows().min(mat.ncols()));
    Ok((
        u_raw.slice(scirs2_core::ndarray::s![.., ..keep]).to_owned(),
        s_raw.slice(scirs2_core::ndarray::s![..keep]).to_owned(),
        vt_raw.slice(scirs2_core::ndarray::s![..keep, ..]).to_owned(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Row-major 4 × 3 matrix holding the values 1.0 through 12.0.
    fn make_test_matrix() -> Array2<f64> {
        let values: Vec<f64> = (1_u8..=12).map(f64::from).collect();
        Array2::from_shape_vec((4, 3), values).expect("valid shape")
    }

    /// Frobenius norm of a stream of matrix entries.
    fn frobenius(entries: impl Iterator<Item = f64>) -> f64 {
        entries.map(|x| x.powi(2)).sum::<f64>().sqrt()
    }

    #[test]
    fn test_incremental_svd_init_and_reconstruct() {
        let mat = make_test_matrix();
        let mut isvd = IncrementalSVD::new(3).expect("create");
        isvd.initialize(&mat).expect("init");
        assert_eq!(isvd.n_rows(), 4);
        assert_eq!(isvd.n_cols(), 3);
        assert!(isvd.rank() <= 3);

        let recon = isvd.reconstruct().expect("reconstruct");
        assert_eq!(recon.shape(), &[4, 3]);
        for ((i, j), &expected) in mat.indexed_iter() {
            let got = recon[[i, j]];
            assert!(
                (got - expected).abs() < 1e-8,
                "Mismatch at [{},{}]: {} vs {}",
                i,
                j,
                got,
                expected
            );
        }
    }

    #[test]
    fn test_incremental_svd_add_rows_matches_batch() {
        let data = vec![
            3.0, 1.0, 0.5, 2.0, 1.0, 4.0, 1.5, 0.5, 0.5, 1.5, 5.0, 1.0, 2.0, 0.5, 1.0, 6.0,
            1.5, 3.5, 2.0, 1.5, 0.8, 2.2, 3.5, 2.8,
        ];
        let mat = Array2::from_shape_vec((6, 4), data).expect("valid shape");

        let mut batch_svd = IncrementalSVD::new(4).expect("create");
        batch_svd.initialize(&mat).expect("init");

        let mut inc_svd = IncrementalSVD::new(4).expect("create");
        for row in mat.outer_iter() {
            inc_svd.add_row(&row.to_owned()).expect("add row");
        }
        assert_eq!(inc_svd.n_rows(), 6);
        assert_eq!(inc_svd.n_cols(), 4);

        let inc_recon = inc_svd.reconstruct().expect("inc recon");
        let frob_err = frobenius(mat.iter().zip(inc_recon.iter()).map(|(a, b)| a - b));
        let frob_orig = frobenius(mat.iter().copied());
        let relative_err = frob_err / frob_orig.max(1e-15);
        assert!(
            relative_err < 1.0,
            "Relative reconstruction error too large: {}",
            relative_err
        );

        let sv = inc_svd.singular_values().expect("sv");
        for &value in sv.iter() {
            assert!(value >= 0.0, "Singular value should be non-negative: {}", value);
        }
    }

    #[test]
    fn test_incremental_svd_rank_truncation() {
        let mat = Array2::from_shape_fn((10, 5), |(i, j)| {
            ((i + 1) * (j + 1)) as f64 + (i as f64 * 0.1).sin()
        });
        let mut isvd = IncrementalSVD::new(2).expect("create");
        isvd.initialize(&mat).expect("init");
        assert!(isvd.rank() <= 2);

        let sv = isvd.singular_values().expect("singular values");
        assert!(sv.len() <= 2);
        if sv.len() == 2 {
            assert!(sv[0] >= sv[1], "Singular values should be decreasing");
        }
    }

    #[test]
    fn test_incremental_svd_add_column() {
        let mat = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid shape");
        let mut isvd = IncrementalSVD::new(3).expect("create");
        isvd.initialize(&mat).expect("init");
        assert_eq!(isvd.n_cols(), 2);

        isvd.add_column(&Array1::from_vec(vec![7.0, 8.0, 9.0]))
            .expect("add column");
        assert_eq!(isvd.n_cols(), 3);
        assert_eq!(isvd.n_rows(), 3);
    }

    #[test]
    fn test_incremental_svd_downdate() {
        let mat = make_test_matrix();
        let mut isvd = IncrementalSVD::new(3).expect("create");
        isvd.initialize(&mat).expect("init");

        isvd.downdate_row(&mat.row(3).to_owned()).expect("downdate");
        assert_eq!(isvd.n_rows(), 3);
    }

    #[test]
    fn test_incremental_svd_error_cases() {
        let mut isvd = IncrementalSVD::new(2).expect("create");
        let first_row = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        isvd.add_row(&first_row).expect("first row");

        // Wrong row width.
        assert!(isvd.add_row(&Array1::from_vec(vec![1.0, 2.0])).is_err());
        // Column length does not match the single stored row.
        assert!(isvd.add_column(&Array1::from_vec(vec![1.0, 2.0, 3.0])).is_err());
        // Cannot remove the only row.
        assert!(isvd.downdate_row(&first_row).is_err());
    }

    #[test]
    fn test_incremental_svd_batch_update() {
        let mat = make_test_matrix();
        let mut isvd = IncrementalSVD::new(3).expect("create");
        isvd.add_rows(&mat).expect("batch add");
        assert_eq!(isvd.n_rows(), 4);
        assert_eq!(isvd.n_cols(), 3);
    }

    #[test]
    fn test_zero_max_rank() {
        assert!(IncrementalSVD::new(0).is_err());
    }
}