use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
/// Result of [`cod`]: a complete orthogonal decomposition `A = U · T · Zᵀ`.
#[derive(Debug, Clone)]
pub struct CodResult {
    /// Orthogonal left factor, `m × m` (the `Q` of the column-pivoted QR).
    pub u: Array2<f64>,
    /// Middle factor, `m × n`; see [`cod`] for its structure.
    pub t: Array2<f64>,
    /// Orthogonal right factor, `n × n` (the column permutation combined with
    /// any right Householder reflectors).
    pub z: Array2<f64>,
    /// Numerical rank estimated by the column-pivoted QR.
    pub rank: usize,
}
/// Result of [`rrqr`]: a column-pivoted QR factorisation `A · P = Q · R`.
#[derive(Debug, Clone)]
pub struct RrqrResult {
    /// Orthogonal factor, `m × m` (product of the Householder reflectors).
    pub q: Array2<f64>,
    /// Upper-trapezoidal factor, `m × n`; sub-diagonal entries are left as
    /// computed and are not explicitly zeroed.
    pub r: Array2<f64>,
    /// Column permutation: column `j` of `A · P` is column `perm[j]` of `A`.
    pub perm: Vec<usize>,
    /// Number of leading diagonal entries of `r` whose magnitude exceeds the
    /// rank tolerance.
    pub rank: usize,
}
/// Result of [`ulv`]: `A = U · L · Vᵀ` with `L` lower triangular.
#[derive(Debug, Clone)]
pub struct UlvResult {
    /// Left orthogonal factor (the left singular vectors of `A`).
    pub u: Array2<f64>,
    /// Middle factor, `m × n`: the transpose of the `R` factor of a QR, hence
    /// lower triangular.
    pub l: Array2<f64>,
    /// Right orthogonal factor (the `Q` factor of that QR).
    pub v: Array2<f64>,
    /// Number of singular values above `f64::EPSILON · max(m, n) · σ_max`.
    pub rank: usize,
}
/// Euclidean (2-)norm of `v`; returns `0.0` for an empty slice.
///
/// Entries are divided by the largest magnitude before squaring, so the result
/// neither overflows to `inf` for entries above ~1e154 nor underflows to `0`
/// for entries below ~1e-154 (the naive `sqrt(Σ x²)` does both). NaN entries
/// propagate to a NaN result; an infinite entry yields `inf` (or NaN if a NaN
/// is also present).
fn norm2(v: &[f64]) -> f64 {
    // `f64::max` ignores NaN, so `scale` is always 0, finite, or +inf.
    let scale = v.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
    if scale == 0.0 || scale.is_infinite() {
        // All zeros / empty / all NaN, or an infinite entry: the plain formula
        // already gives the right answer and avoids `inf / inf`.
        return v.iter().map(|&x| x * x).sum::<f64>().sqrt();
    }
    scale
        * v.iter()
            .map(|&x| {
                let y = x / scale;
                y * y
            })
            .sum::<f64>()
            .sqrt()
}
/// Applies the Householder reflector `H = I − 2·v·vᵀ / (vᵀ·v)` from the left,
/// `a ← H · a`, where `H` acts only on rows `row_start..`.
///
/// `v` is expected to have length `a.nrows() - row_start` (entry `i - row_start`
/// pairs with row `i`). A zero `v` represents the identity and leaves `a`
/// unchanged.
fn apply_householder_left(a: &mut Array2<f64>, v: &[f64], row_start: usize) {
    let m = a.nrows();
    let n = a.ncols();
    let vn2: f64 = v.iter().map(|&x| x * x).sum();
    // Skip only an exactly-zero reflector. An absolute cut-off such as
    // `vn2 < f64::EPSILON` silently drops every reflector of a matrix whose
    // entries are around 1e-8 or smaller, so it never gets triangularised.
    if vn2 == 0.0 {
        return;
    }
    let scale = 2.0 / vn2;
    for j in 0..n {
        let dot: f64 = (row_start..m).map(|i| v[i - row_start] * a[[i, j]]).sum();
        let coeff = scale * dot;
        for i in row_start..m {
            a[[i, j]] -= coeff * v[i - row_start];
        }
    }
}
/// Applies the Householder reflector `H = I − 2·v·vᵀ / (vᵀ·v)` from the right,
/// `a ← a · H`, where `H` acts only on columns `col_start..`.
///
/// `v` is expected to have length `a.ncols() - col_start` (entry `j - col_start`
/// pairs with column `j`). A zero `v` represents the identity and leaves `a`
/// unchanged.
fn apply_householder_right(a: &mut Array2<f64>, v: &[f64], col_start: usize) {
    let m = a.nrows();
    let n = a.ncols();
    let vn2: f64 = v.iter().map(|&x| x * x).sum();
    // Skip only an exactly-zero reflector; an absolute `f64::EPSILON` cut-off
    // would drop valid reflectors of small-scale matrices.
    if vn2 == 0.0 {
        return;
    }
    let scale = 2.0 / vn2;
    for i in 0..m {
        let dot: f64 = (col_start..n).map(|j| v[j - col_start] * a[[i, j]]).sum();
        let coeff = scale * dot;
        for j in col_start..n {
            a[[i, j]] -= coeff * v[j - col_start];
        }
    }
}
/// Forms the explicit `m × m` product `H_0 · H_1 ⋯ H_{k-1}` of the Householder
/// reflectors listed (as `(first_row, v)` pairs) in the order they were applied.
///
/// Starting from the identity, the reflectors are applied on the left in
/// reverse order, so the first reflector ends up leftmost in the product.
fn build_q_from_reflectors(m: usize, reflectors: &[(usize, Vec<f64>)]) -> Array2<f64> {
    reflectors
        .iter()
        .rev()
        .fold(Array2::<f64>::eye(m), |mut acc, (first_row, v)| {
            apply_householder_left(&mut acc, v, *first_row);
            acc
        })
}
/// Forms an explicit `n × n` product of Householder reflectors given as
/// `(first_col, v)` pairs, by applying them on the right of the identity in
/// reverse order. The result is `H_{k-1} ⋯ H_1 · H_0`.
///
/// NOTE(review): this is the transpose of `H_0 · H_1 ⋯ H_{k-1}`. It is not
/// called from the code visible here; confirm which ordering callers need.
fn build_z_from_reflectors(n: usize, reflectors: &[(usize, Vec<f64>)]) -> Array2<f64> {
    reflectors
        .iter()
        .rev()
        .fold(Array2::<f64>::eye(n), |mut acc, (first_col, v)| {
            apply_householder_right(&mut acc, v, *first_col);
            acc
        })
}
pub fn rrqr(a: &ArrayView2<f64>, tol: Option<f64>) -> LinalgResult<RrqrResult> {
let m = a.nrows();
let n = a.ncols();
let mut work = a.to_owned();
let mut perm: Vec<usize> = (0..n).collect();
let k = m.min(n);
let mut col_norms: Vec<f64> = (0..n)
.map(|j| norm2(work.column(j).as_slice().unwrap_or(&[])))
.collect();
let mut left_reflectors: Vec<(usize, Vec<f64>)> = Vec::with_capacity(k);
for step in 0..k {
let pivot_rel = (step..n)
.max_by(|&a, &b| col_norms[a].partial_cmp(&col_norms[b]).expect("nan norm"))
.unwrap_or(step);
if col_norms[pivot_rel] < f64::EPSILON {
break;
}
if pivot_rel != step {
for i in 0..m {
let tmp = work[[i, step]];
work[[i, step]] = work[[i, pivot_rel]];
work[[i, pivot_rel]] = tmp;
}
perm.swap(step, pivot_rel);
col_norms.swap(step, pivot_rel);
}
let col_len = m - step;
let mut v: Vec<f64> = (step..m).map(|i| work[[i, step]]).collect();
let sigma = norm2(&v);
if sigma < f64::EPSILON {
continue;
}
let sign = if v[0] >= 0.0 { 1.0 } else { -1.0 };
v[0] += sign * sigma;
apply_householder_left(&mut work, &v, step);
for j in (step + 1)..n {
let r_jj = work[[step, j]];
let old_sq = col_norms[j] * col_norms[j];
let new_sq = (old_sq - r_jj * r_jj).max(0.0);
col_norms[j] = new_sq.sqrt();
}
left_reflectors.push((step, v));
let _ = col_len; }
let q = build_q_from_reflectors(m, &left_reflectors);
let r = work;
let r00 = r[[0, 0]].abs();
let threshold = tol.unwrap_or_else(|| {
let eps = f64::EPSILON;
eps * (m.max(n) as f64) * r00
});
let rank = (0..k).take_while(|&i| r[[i, i]].abs() > threshold).count();
Ok(RrqrResult { q, r, perm, rank })
}
pub fn cod(a: &ArrayView2<f64>, tol: Option<f64>) -> LinalgResult<CodResult> {
let m = a.nrows();
let n = a.ncols();
let rrqr_res = rrqr(a, tol)?;
let rank = rrqr_res.rank;
let u = rrqr_res.q.clone();
let mut t = rrqr_res.r;
let mut z = Array2::<f64>::eye(n);
let perm = rrqr_res.perm;
let z_perm: Vec<_> = perm.iter().map(|&c| z.column(c).to_owned()).collect();
for (j, col) in z_perm.iter().enumerate() {
z.column_mut(j).assign(col);
}
let mut right_reflectors: Vec<(usize, Vec<f64>)> = Vec::new();
for i in 0..rank {
if rank >= n {
break; }
let row_len = n - i;
if row_len <= 1 {
break;
}
let v_raw: Vec<f64> = (i..n).map(|j| t[[i, j]]).collect();
let sigma = norm2(&v_raw);
if sigma < f64::EPSILON {
continue;
}
let sign = if v_raw[0] >= 0.0 { 1.0 } else { -1.0 };
let mut v = v_raw;
v[0] += sign * sigma;
apply_householder_right(&mut t, &v, i);
apply_householder_right(&mut z, &v, i);
right_reflectors.push((i, v));
}
Ok(CodResult { u, t, z, rank })
}
/// ULV decomposition `A = U · L · Vᵀ` with `L` lower triangular.
///
/// The decomposition starts from the SVD `A = U · Σ · Wᵀ`
/// (`crate::decomposition::svd(a, true, None)`). The product `Σ · Wᵀ` is
/// refactored through a QR of its transpose, `(Σ · Wᵀ)ᵀ = Q' · R'`. That gives
/// `Σ · Wᵀ = R'ᵀ · Q'ᵀ`, hence `L = R'ᵀ` and `V = Q'`. `U` is the SVD's left
/// factor, unchanged.
///
/// `rank` counts singular values greater than `f64::EPSILON · max(m, n) · σ_max`.
///
/// # Errors
/// Propagates errors from `crate::decomposition::svd` and
/// `crate::decomposition::qr`.
pub fn ulv(a: &ArrayView2<f64>) -> LinalgResult<UlvResult> {
    let m = a.nrows();
    let n = a.ncols();
    let (u, s_vec, w_t) = crate::decomposition::svd(a, true, None)?;

    // Σ embedded in an m × n matrix. This assumes `svd` returns at least
    // min(m, n) singular values; indexing panics otherwise, as before.
    let mut sigma = Array2::<f64>::zeros((m, n));
    for i in 0..m.min(n) {
        sigma[[i, i]] = s_vec[i];
    }

    // Transpose of Σ·Wᵀ, refactored as Q'·R'.
    let sigma_wt_t = sigma.dot(&w_t).t().to_owned();
    let (q_prime, r_prime) = crate::decomposition::qr(&sigma_wt_t.view(), None)?;
    // L = R'ᵀ is lower triangular because R' is upper triangular.
    let l = r_prime.t().to_owned();

    let s_max = s_vec.iter().copied().fold(0.0_f64, f64::max);
    let thresh = f64::EPSILON * m.max(n) as f64 * s_max;
    let rank = s_vec.iter().filter(|&&sv| sv > thresh).count();
    Ok(UlvResult { u, l, v: q_prime, rank })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Largest absolute entry-wise difference between two equally-shaped matrices.
    fn max_abs_error(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        (a - b).iter().map(|v| v.abs()).fold(0.0_f64, f64::max)
    }

    /// `A · P`: column `j` of the result is column `perm[j]` of `a`.
    fn permute_columns(a: &Array2<f64>, perm: &[usize]) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros(a.dim());
        for (j, &c) in perm.iter().enumerate() {
            out.column_mut(j).assign(&a.column(c));
        }
        out
    }

    /// Largest magnitude strictly below the main diagonal.
    fn max_below_diagonal(r: &Array2<f64>) -> f64 {
        r.indexed_iter()
            .filter(|&((i, j), _)| i > j)
            .map(|(_, v)| v.abs())
            .fold(0.0_f64, f64::max)
    }

    /// Largest deviation of `qᵀ · q` from the identity.
    fn orthogonality_error(q: &Array2<f64>) -> f64 {
        let gram = q.t().dot(q);
        max_abs_error(&gram, &Array2::<f64>::eye(q.ncols()))
    }

    #[test]
    fn test_rrqr_full_rank() {
        let a = array![
            [1.0_f64, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 10.0],
        ];
        let res = rrqr(&a.view(), None).expect("rrqr failed");
        let a_perm = permute_columns(&a, &res.perm);
        let recon = res.q.dot(&res.r);
        assert!(
            max_abs_error(&recon, &a_perm) < 1e-10,
            "RRQR reconstruction error"
        );
        // Reconstruction alone is also satisfied by the trivial Q = I, R = A·P,
        // so pin the structure of the factors as well.
        assert!(max_below_diagonal(&res.r) < 1e-10, "R not upper triangular");
        assert!(orthogonality_error(&res.q) < 1e-10, "Q not orthogonal");
        assert_eq!(res.rank, 3, "full-rank matrix");
        // Column pivoting makes |R[i, i]| non-increasing.
        for i in 1..3 {
            assert!(
                res.r[[i, i]].abs() <= res.r[[i - 1, i - 1]].abs() + 1e-12,
                "pivoted diagonal not non-increasing"
            );
        }
    }

    #[test]
    fn test_rrqr_rank_deficient() {
        let a = array![[1.0_f64, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0],];
        let res = rrqr(&a.view(), Some(1e-8)).expect("rrqr failed");
        assert_eq!(res.rank, 1, "Rank-1 matrix should have rank 1");
    }

    #[test]
    fn test_rrqr_tall_matrix() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],];
        let res = rrqr(&a.view(), None).expect("rrqr failed");
        let a_perm = permute_columns(&a, &res.perm);
        let recon = res.q.dot(&res.r);
        assert!(
            max_abs_error(&recon.slice(s![.., ..2]).to_owned(), &a_perm) < 1e-10,
            "RRQR tall reconstruction"
        );
        assert!(max_below_diagonal(&res.r) < 1e-10, "R not upper triangular");
        assert!(orthogonality_error(&res.q) < 1e-10, "Q not orthogonal");
    }

    #[test]
    fn test_rrqr_permutation_valid() {
        let a = array![[1.0_f64, 5.0, 2.0], [0.0, 0.0, 3.0],];
        let res = rrqr(&a.view(), None).expect("rrqr failed");
        let mut sorted = res.perm.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2]);
    }

    #[test]
    fn test_rrqr_degenerate_shapes() {
        let empty = Array2::<f64>::zeros((0, 3));
        let res = rrqr(&empty.view(), None).expect("rrqr on empty matrix failed");
        assert_eq!(res.rank, 0, "empty matrix has rank 0");
        assert_eq!(res.perm, vec![0, 1, 2]);

        let zero = Array2::<f64>::zeros((2, 2));
        let res = rrqr(&zero.view(), None).expect("rrqr on zero matrix failed");
        assert_eq!(res.rank, 0, "zero matrix has rank 0");
    }

    #[test]
    fn test_cod_full_rank_reconstruction() {
        let a = array![[1.0_f64, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0],];
        let res = cod(&a.view(), None).expect("COD failed");
        let recon = res.u.dot(&res.t).dot(&res.z.t());
        assert!(max_abs_error(&recon, &a) < 1e-8, "COD reconstruction error");
        assert!(orthogonality_error(&res.u) < 1e-10, "U not orthogonal");
        assert!(orthogonality_error(&res.z) < 1e-10, "Z not orthogonal");
    }

    #[test]
    fn test_cod_rank_deficient() {
        let a = array![[1.0_f64, 2.0, 3.0], [2.0, 4.0, 6.0],];
        let res = cod(&a.view(), Some(1e-8)).expect("COD failed");
        assert_eq!(res.rank, 1, "rank-deficient COD rank");
        // The right reflectors must annihilate the block right of the rank.
        let r = res.rank;
        let off_block = res.t.slice(s![..r, r..]);
        assert!(
            off_block.iter().all(|x| x.abs() < 1e-10),
            "COD should annihilate T[..rank, rank..]"
        );
        assert!(orthogonality_error(&res.z) < 1e-10, "Z not orthogonal");
    }

    #[test]
    fn test_cod_reconstruction_rank_deficient() {
        let a = array![[1.0_f64, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0],];
        let res = cod(&a.view(), Some(1e-8)).expect("COD failed");
        let recon = res.u.dot(&res.t).dot(&res.z.t());
        let err = max_abs_error(&recon, &a);
        assert!(err < 1e-7, "COD rank-deficient reconstruction error: {err}");
    }

    #[test]
    fn test_ulv_reconstruction() {
        let a = array![[3.0_f64, 2.0, 1.0], [1.0, 2.0, 3.0],];
        let res = ulv(&a.view()).expect("ULV failed");
        let recon = res.u.dot(&res.l).dot(&res.v.t());
        assert!(
            max_abs_error(&recon, &a) < 1e-10,
            "ULV reconstruction error"
        );
    }

    #[test]
    fn test_ulv_square() {
        let a = array![[4.0_f64, 3.0], [6.0, 3.0],];
        let res = ulv(&a.view()).expect("ULV square failed");
        let recon = res.u.dot(&res.l).dot(&res.v.t());
        assert!(
            max_abs_error(&recon, &a) < 1e-10,
            "ULV square reconstruction"
        );
    }

    #[test]
    fn test_ulv_rank() {
        let a = array![[1.0_f64, 2.0, 3.0], [2.0, 4.0, 6.0],];
        let res = ulv(&a.view()).expect("ULV rank failed");
        assert_eq!(res.rank, 1, "Rank-1 ULV rank estimate");
    }

    #[test]
    fn test_ulv_orthogonality_u() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0],];
        let res = ulv(&a.view()).expect("ULV orth U");
        let utu = res.u.t().dot(&res.u);
        let eye3 = Array2::<f64>::eye(3);
        assert!(max_abs_error(&utu, &eye3) < 1e-10, "U not orthogonal");
    }

    #[test]
    fn test_ulv_orthogonality_v() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0],];
        let res = ulv(&a.view()).expect("ULV orth V");
        let vtv = res.v.t().dot(&res.v);
        let eye3 = Array2::<f64>::eye(3);
        assert!(max_abs_error(&vtv, &eye3) < 1e-10, "V not orthogonal");
    }
}