use crate::HisabError;
/// Factors of a singular value decomposition as produced by `svd`.
///
/// Both vector collections are stored "one singular vector per entry", so
/// entry `i` of `u`, `sigma` and `vt` belong together.
#[derive(Debug, Clone)]
#[must_use]
pub struct Svd {
/// Left singular vectors: `u[i]` is the `i`-th vector (a column of `U`),
/// with one component per row of the decomposed matrix.
pub u: Vec<Vec<f64>>,
/// Singular values; `svd` returns them sorted in descending order.
pub sigma: Vec<f64>,
/// Right singular vectors: `vt[i]` is the `i`-th vector (a row of `V^T`),
/// with one component per column of the decomposed matrix.
pub vt: Vec<Vec<f64>>,
}
/// Computes the singular value decomposition of a row-major matrix.
///
/// `a` holds one `Vec` per row. Matrices with at least as many rows as
/// columns are handed to `svd_tall` directly; wider matrices are transposed
/// first and the roles of the two factors are swapped afterwards.
///
/// In the returned [`Svd`], `u[i]` is the `i`-th left singular vector,
/// `vt[i]` the `i`-th right singular vector and `sigma` has
/// `min(rows, cols)` entries.
///
/// # Errors
///
/// Returns `HisabError::InvalidInput` when `a` has no rows, when its first
/// row is empty, or when the rows differ in length. Any error reported by
/// `svd_tall` is passed through unchanged.
#[must_use = "contains the SVD factors or an error"]
pub fn svd(a: &[Vec<f64>]) -> Result<Svd, HisabError> {
    let m = a.len();
    if m == 0 {
        return Err(HisabError::InvalidInput("empty matrix".into()));
    }
    let n = a[0].len();
    if n == 0 {
        return Err(HisabError::InvalidInput("empty matrix".into()));
    }
    if a.iter().any(|row| row.len() != n) {
        return Err(HisabError::InvalidInput("inconsistent row lengths".into()));
    }

    if m >= n {
        // Tall or square input can be factored straight from the borrowed
        // rows; no private copy of the matrix is needed.
        return svd_tall(a, m, n);
    }

    // Wide input: factor the n x m transpose instead.
    let mut at = vec![vec![0.0; m]; n];
    for (i, row) in a.iter().enumerate() {
        for (j, &value) in row.iter().enumerate() {
            at[j][i] = value;
        }
    }
    let Svd {
        u: at_u,
        sigma,
        vt: at_vt,
    } = svd_tall(&at, n, m)?;

    // The left vectors of the transpose (length n) are the right vectors of
    // `a`. Pad with zero rows should fewer than n have been produced, so the
    // result always has exactly n rows.
    let mut vt = at_u;
    vt.resize(n, vec![0.0; n]);

    // `svd_tall` returns a square `vt` sized by its column count (m here),
    // so it already holds all m left singular vectors of `a`.
    Ok(Svd {
        u: at_vt,
        sigma,
        vt,
    })
}
/// Appends unit vectors to `u` until it holds `m` vectors of length `m`.
///
/// Each standard basis vector is tried once, in order. It is made orthogonal
/// to every vector currently in `u` (projections are removed one vector at a
/// time, each computed from the already-updated candidate) and is kept, after
/// normalisation, only when its remaining norm exceeds `crate::EPSILON_F64`.
/// The function may therefore leave `u` with fewer than `m` vectors.
fn extend_orthonormal_basis(u: &mut Vec<Vec<f64>>, m: usize) {
    for axis in 0..m {
        if u.len() >= m {
            return;
        }

        // Start from the `axis`-th standard basis vector.
        let mut candidate: Vec<f64> = (0..m)
            .map(|k| if k == axis { 1.0 } else { 0.0 })
            .collect();

        for existing in u.iter() {
            let head = &existing[..m];
            let projection: f64 = head
                .iter()
                .zip(candidate.iter())
                .map(|(e, c)| e * c)
                .sum();
            for (c, e) in candidate.iter_mut().zip(head) {
                *c -= projection * e;
            }
        }

        let length = candidate.iter().map(|c| c * c).sum::<f64>().sqrt();
        if length <= crate::EPSILON_F64 {
            // The axis is (numerically) already spanned; try the next one.
            continue;
        }
        let scale = 1.0 / length;
        candidate.iter_mut().for_each(|c| *c *= scale);
        u.push(candidate);
    }
}
/// Computes the SVD of `a` and keeps only its first `k` singular triplets.
///
/// The full decomposition is obtained from `svd`; since that returns the
/// singular values in descending order, the retained triplets are the `k`
/// largest.
///
/// # Errors
///
/// Returns `HisabError::InvalidInput` when `k` is zero or exceeds the number
/// of singular values, and forwards any error from `svd`.
#[must_use = "contains the truncated SVD or an error"]
pub fn truncated_svd(a: &[Vec<f64>], k: usize) -> Result<Svd, HisabError> {
    if k == 0 {
        return Err(HisabError::InvalidInput("k must be positive".into()));
    }

    let Svd {
        mut u,
        mut sigma,
        mut vt,
    } = svd(a)?;

    let available = sigma.len();
    if k > available {
        return Err(HisabError::InvalidInput(format!(
            "k={k} > number of singular values {}",
            available
        )));
    }

    // Drop everything past the k-th triplet in place.
    u.truncate(k);
    sigma.truncate(k);
    vt.truncate(k);
    Ok(Svd { u, sigma, vt })
}
/// Factors an `m x n` row-major matrix with one-sided Jacobi rotations.
///
/// The columns of `a` are rotated pairwise until every pair is orthogonal to
/// within a relative tolerance; the same rotations are accumulated into an
/// `n x n` matrix that starts as the identity. The column norms then become
/// the singular values and the normalised columns the left singular vectors.
/// `svd` only calls this with `m >= n`.
///
/// The result lists singular values in descending order, with `u[i]` and
/// `vt[i]` the matching left and right singular vectors. Columns whose norm
/// does not exceed `crate::EPSILON_F64` cannot be normalised; their slot in
/// `u` is filled with a unit vector orthogonal to all other left vectors, so
/// `u` stays orthonormal for rank-deficient input. Any further completion
/// vectors are appended after the first `n` entries of `u`.
///
/// # Errors
///
/// Returns `HisabError::NoConvergence` carrying the sweep budget when no
/// sweep finishes without applying a rotation.
#[allow(clippy::needless_range_loop)] // the sweep addresses column pairs (p, q) by index
fn svd_tall(a: &[Vec<f64>], m: usize, n: usize) -> Result<Svd, HisabError> {
    // cols[j] is column j of `a`, so each rotation works on two contiguous rows.
    let mut cols: Vec<Vec<f64>> = vec![vec![0.0; m]; n];
    for i in 0..m {
        for j in 0..n {
            cols[j][i] = a[i][j];
        }
    }
    // Accumulated rotations; row j ends up as the j-th right singular vector.
    let mut v: Vec<Vec<f64>> = vec![vec![0.0; n]; n];
    for i in 0..n {
        v[i][i] = 1.0;
    }

    let max_sweeps = 100 * n.max(m);
    // Relative orthogonality threshold: the square of the crate epsilon.
    // NOTE(review): confirm this sits above f64 rounding noise for the crate's
    // epsilon value; otherwise sweeps may never report convergence.
    let tol = crate::EPSILON_F64 * crate::EPSILON_F64;
    for sweep in 0..max_sweeps {
        let mut converged = true;
        for p in 0..n {
            for q in (p + 1)..n {
                let mut app: f64 = 0.0;
                let mut aqq: f64 = 0.0;
                let mut apq: f64 = 0.0;
                for k in 0..m {
                    app += cols[p][k] * cols[p][k];
                    aqq += cols[q][k] * cols[q][k];
                    apq += cols[p][k] * cols[q][k];
                }
                if apq.abs() <= tol * (app * aqq).sqrt() {
                    continue;
                }
                converged = false;
                // Rotation that zeroes the (p, q) inner product; the branch
                // picks the smaller-magnitude tangent root.
                let tau = (aqq - app) / (2.0 * apq);
                let t = if tau >= 0.0 {
                    1.0 / (tau + (1.0 + tau * tau).sqrt())
                } else {
                    -1.0 / (-tau + (1.0 + tau * tau).sqrt())
                };
                let cos = 1.0 / (1.0 + t * t).sqrt();
                let sin = t * cos;
                for k in 0..m {
                    let bp = cols[p][k];
                    let bq = cols[q][k];
                    cols[p][k] = cos * bp - sin * bq;
                    cols[q][k] = sin * bp + cos * bq;
                }
                for k in 0..n {
                    let vp = v[p][k];
                    let vq = v[q][k];
                    v[p][k] = cos * vp - sin * vq;
                    v[q][k] = sin * vp + cos * vq;
                }
            }
        }
        if converged {
            break;
        }
        if sweep == max_sweeps - 1 {
            return Err(HisabError::NoConvergence(max_sweeps));
        }
    }

    // Column norms are the (unsorted) singular values.
    let sigma: Vec<f64> = cols
        .iter()
        .map(|col| col.iter().map(|x| x * x).sum::<f64>().sqrt())
        .collect();

    // Normalise only the columns that have a usable norm; `slot[j]` records
    // where column j's unit vector lives in `basis`.
    let mut basis: Vec<Vec<f64>> = Vec::with_capacity(m);
    let mut slot: Vec<Option<usize>> = vec![None; n];
    for j in 0..n {
        if sigma[j] > crate::EPSILON_F64 {
            let inv = 1.0 / sigma[j];
            slot[j] = Some(basis.len());
            basis.push(cols[j].iter().map(|x| x * inv).collect());
        }
    }

    // Complete the basis. The extra vectors first replace columns that could
    // not be normalised; whatever is left becomes the trailing entries of `u`.
    let rank = basis.len();
    if rank < m {
        extend_orthonormal_basis(&mut basis, m);
    }
    let mut spare = basis.split_off(rank).into_iter();

    let mut order: Vec<usize> = (0..n).collect();
    order.sort_unstable_by(|&i, &j| {
        sigma[j]
            .partial_cmp(&sigma[i])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let sorted_sigma: Vec<f64> = order.iter().map(|&j| sigma[j]).collect();

    let mut sorted_u: Vec<Vec<f64>> = Vec::with_capacity(m);
    for &j in &order {
        let column = match slot[j] {
            Some(idx) => std::mem::take(&mut basis[idx]),
            // If the completion produced too few vectors, fall back to the
            // raw (near-zero) column.
            None => spare
                .next()
                .unwrap_or_else(|| std::mem::take(&mut cols[j])),
        };
        sorted_u.push(column);
    }
    sorted_u.extend(spare);

    // `order` is a permutation of 0..n, so every row of `v` is moved exactly once.
    let vt: Vec<Vec<f64>> = order
        .iter()
        .map(|&j| std::mem::take(&mut v[j]))
        .collect();

    Ok(Svd {
        u: sorted_u,
        sigma: sorted_sigma,
        vt,
    })
}