use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
use crate::factorization::rrqr::{perm_to_matrix, rrqr};
/// Factors of a decomposition `A = U R V^T` of an `m × n` matrix `A`, as produced by `urv`.
///
/// Note that `urv` builds `r` as the transpose of an upper-triangular QR factor, so the
/// leading `rank × rank` block of `r` is *lower* triangular (up to rounding).
#[derive(Debug, Clone)]
pub struct URVResult<F> {
/// Left orthogonal factor (`m × m`).
pub u: Array2<F>,
/// Middle factor (`m × n`); see the struct-level note on its triangular shape.
pub r: Array2<F>,
/// Right orthogonal factor (`n × n`).
pub v: Array2<F>,
/// Numerical rank reported by the rank-revealing QR step for the given tolerance.
pub rank: usize,
}
/// Computes a URV-style complete orthogonal decomposition `A = U R V^T`.
///
/// The factorization is assembled in two stages:
/// 1. a rank-revealing QR of `a` via [`rrqr`] (which also determines the numerical
///    rank using `tol`), giving `Q1`, `R1` and a column permutation `P`
///    (materialized with [`perm_to_matrix`]);
/// 2. a Householder QR of the transpose, `R1^T = Q2 R2`, so that `R1 = R2^T Q2^T`.
///
/// The returned factors are `U = Q1`, `R = R2^T` and `V = P Q2`. Because `R2` is
/// upper triangular, the leading block of the returned `r` is lower triangular.
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` has zero rows or zero columns.
/// * Any error propagated from the RRQR or Householder QR steps.
pub fn urv<F>(a: &ArrayView2<F>, tol: F) -> LinalgResult<URVResult<F>>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
{
    let (rows, cols) = a.dim();
    if rows == 0 || cols == 0 {
        return Err(LinalgError::ShapeError(
            "URV: input matrix must be non-empty".to_string(),
        ));
    }

    // Stage 1: rank-revealing QR with column pivoting.
    let pivoted = rrqr(a, tol)?;
    let permutation = perm_to_matrix::<F>(&pivoted.perm);

    // Stage 2: reduce R1 from the right by factoring its transpose.
    let (q2, r2) = householder_qr(&pivoted.r.t().to_owned())?;

    Ok(URVResult {
        u: pivoted.q,
        r: r2.t().to_owned(),
        v: permutation.dot(&q2),
        rank: pivoted.rank,
    })
}
/// Solves the linear least-squares problem `min ||A x - b||_2` via [`urv`].
///
/// With `A = U R V^T` and the substitution `y = V^T x`, the problem becomes
/// `min ||R y - U^T b||_2`. Only the leading `rank × rank` block of `R` is used:
/// it is solved for `y[..rank]`, the remaining entries of `y` are left at zero,
/// and the solution is returned as `x = V y`.
///
/// The middle factor returned by [`urv`] is the transpose of an upper-triangular
/// QR factor, so its leading block is *lower* triangular and is solved here by
/// forward substitution.
///
/// # Errors
/// * `LinalgError::DimensionError` if `b.len()` differs from the number of rows of `a`.
/// * Any error propagated from [`urv`].
/// * `LinalgError::SingularMatrixError` if a diagonal entry of the leading block is
///   negligible relative to the largest diagonal entry of that block.
pub fn urv_lstsq<F>(a: &ArrayView2<F>, b: &Array1<F>, tol: F) -> LinalgResult<Array1<F>>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
{
    let (m, _n) = a.dim();
    if b.len() != m {
        return Err(LinalgError::DimensionError(format!(
            "URV lstsq: b length ({}) must match matrix rows ({m})",
            b.len()
        )));
    }
    let res = urv(a, tol)?;
    let k = res.rank;
    let n = res.v.nrows();
    // c = U^T b
    let c = res.u.t().dot(b);

    // Reference magnitude for the singularity test, so the test does not depend
    // on the overall scale of `a` (an absolute epsilon would reject tiny-but-
    // well-conditioned matrices).
    let max_diag = (0..k)
        .map(|i| res.r[[i, i]].abs())
        .fold(F::zero(), |acc, d| acc.max(d));

    let mut y = Array1::<F>::zeros(n);
    // Forward substitution on the lower-triangular leading k×k block of R.
    // Entries y[k..] stay zero (the numerically rank-deficient part is dropped).
    for i in 0..k {
        let diag = res.r[[i, i]];
        if diag.abs() <= F::epsilon() * max_diag {
            return Err(LinalgError::SingularMatrixError(
                "URV lstsq: zero diagonal in R at detected rank boundary".to_string(),
            ));
        }
        let mut s = c[i];
        for j in 0..i {
            s -= res.r[[i, j]] * y[j];
        }
        y[i] = s / diag;
    }
    Ok(res.v.dot(&y))
}
/// Computes a full Householder QR factorization `a = Q R`.
///
/// For an `m × n` input this returns `(Q, R)`, where `Q` is the `m × m` product of
/// reflectors `H_0 H_1 ⋯` and `R` is `m × n` and upper triangular. Entries below the
/// diagonal of `R` are not explicitly zeroed; they are left at rounding-error level.
///
/// Each active sub-column is rescaled by its largest absolute entry before the
/// reflector is formed. A Householder reflector is invariant under scaling of its
/// vector, so this leaves `Q` and `R` mathematically unchanged. It keeps the squared
/// norms free of overflow and underflow, so columns of any magnitude are reduced.
/// A column is skipped only when its active sub-column is exactly zero.
///
/// # Errors
/// Currently never fails; the `Result` return type is kept for interface stability.
fn householder_qr<F>(a: &Array2<F>) -> LinalgResult<(Array2<F>, Array2<F>)>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + 'static,
{
    let (m, n) = a.dim();
    let min_dim = m.min(n);
    let mut r = a.clone();
    let mut q = Array2::<F>::eye(m);
    let two = F::from(2.0).unwrap_or_else(|| F::one() + F::one());
    for k in 0..min_dim {
        let len = m - k;
        // Largest magnitude in the active sub-column r[k.., k].
        let scale = (k..m)
            .map(|i| r[[i, k]].abs())
            .fold(F::zero(), |acc, e| acc.max(e));
        if scale == F::zero() {
            // Nothing to eliminate in this column.
            continue;
        }
        // Scaled copy of the sub-column; its largest entry now has magnitude exactly 1.
        let mut v = Array1::<F>::zeros(len);
        for i in 0..len {
            v[i] = r[[i + k, k]] / scale;
        }
        // x_norm >= 1 because one entry of the scaled column is ±1.
        let x_norm = v.iter().fold(F::zero(), |acc, &e| acc + e * e).sqrt();
        // Choose alpha with the sign opposite to v[0] so `v[0] - alpha` does not cancel.
        let alpha = if v[0] >= F::zero() { -x_norm } else { x_norm };
        v[0] -= alpha;
        // ||v||^2 = 2 * x_norm * (x_norm + |v[0]|) >= 2, so beta is always finite.
        let v_norm_sq = v.iter().fold(F::zero(), |acc, &e| acc + e * e);
        let beta = two / v_norm_sq;
        // R[k.., k..] <- H R[k.., k..] with H = I - beta v v^T.
        for j in k..n {
            let mut dot = F::zero();
            for i in 0..len {
                dot += v[i] * r[[i + k, j]];
            }
            for i in 0..len {
                r[[i + k, j]] -= beta * v[i] * dot;
            }
        }
        // Q[:, k..] <- Q[:, k..] H, accumulating Q = H_0 H_1 ⋯ so that a = Q R.
        for row in 0..m {
            let mut dot = F::zero();
            for jj in 0..len {
                dot += q[[row, jj + k]] * v[jj];
            }
            for jj in 0..len {
                q[[row, jj + k]] -= beta * dot * v[jj];
            }
        }
    }
    Ok((q, r))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Frobenius norm of the element-wise difference `a - b`.
    fn frob_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let mut total = 0.0_f64;
        for (&x, &y) in a.iter().zip(b.iter()) {
            total += (x - y).powi(2);
        }
        total.sqrt()
    }

    /// Rebuilds `U R V^T` from a URV factorization.
    fn reconstruct(res: &URVResult<f64>) -> Array2<f64> {
        res.u.dot(&res.r).dot(&res.v.t())
    }

    /// Euclidean norm of the residual `a x - b`.
    fn residual_norm(a: &Array2<f64>, x: &Array1<f64>, b: &Array1<f64>) -> f64 {
        let residual = &a.dot(x) - b;
        residual.iter().map(|&e| e * e).sum::<f64>().sqrt()
    }

    #[test]
    fn test_urv_full_rank() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]];
        let res = urv(&a.view(), 1e-12).expect("urv failed");
        assert_eq!(res.rank, 3);
        let err = frob_diff(&reconstruct(&res), &a);
        assert!(err < 1e-10, "reconstruction error = {err}");
        let eye = Array2::<f64>::eye(3);
        assert!(frob_diff(&res.u.t().dot(&res.u), &eye) < 1e-10, "U not orthogonal");
        assert!(frob_diff(&res.v.t().dot(&res.v), &eye) < 1e-10, "V not orthogonal");
    }

    #[test]
    fn test_urv_rank_deficient() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [5.0, 7.0, 9.0]];
        let res = urv(&a.view(), 1e-10).expect("urv failed");
        assert_eq!(res.rank, 2, "rank should be 2");
        let err = frob_diff(&reconstruct(&res), &a);
        assert!(err < 1e-10, "reconstruction error = {err}");
    }

    #[test]
    fn test_urv_rectangular_tall() {
        let a = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]];
        let res = urv(&a.view(), 1e-10).expect("urv failed");
        assert_eq!(res.rank, 2);
        assert_eq!(res.u.shape(), &[4, 4]);
        assert_eq!(res.r.shape(), &[4, 2]);
        assert_eq!(res.v.shape(), &[2, 2]);
        let err = frob_diff(&reconstruct(&res), &a);
        assert!(err < 1e-10, "tall reconstruction error = {err}");
    }

    #[test]
    fn test_urv_rectangular_wide() {
        let a = array![[1.0, 0.0, 1.0, 2.0], [0.0, 1.0, 1.0, 3.0]];
        let res = urv(&a.view(), 1e-10).expect("urv failed");
        assert_eq!(res.rank, 2);
        assert_eq!(res.u.shape(), &[2, 2]);
        assert_eq!(res.r.shape(), &[2, 4]);
        assert_eq!(res.v.shape(), &[4, 4]);
        let err = frob_diff(&reconstruct(&res), &a);
        assert!(err < 1e-10, "wide reconstruction error = {err}");
    }

    #[test]
    fn test_urv_zero_matrix() {
        let zero = Array2::<f64>::zeros((3, 3));
        let res = urv(&zero.view(), 1e-12).expect("urv failed");
        assert_eq!(res.rank, 0);
    }

    #[test]
    fn test_urv_identity() {
        let eye = Array2::<f64>::eye(4);
        let res = urv(&eye.view(), 1e-12).expect("urv failed");
        assert_eq!(res.rank, 4);
        let err = frob_diff(&reconstruct(&res), &eye);
        assert!(err < 1e-10);
    }

    #[test]
    fn test_urv_lstsq_overdetermined() {
        let a = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
        let b = array![1.0, 2.0, 3.0];
        let x = urv_lstsq(&a.view(), &b, 1e-10).expect("lstsq failed");
        let res_norm = residual_norm(&a, &x, &b);
        assert!(res_norm < 0.5, "residual norm = {res_norm}");
    }

    #[test]
    fn test_urv_lstsq_rank_deficient() {
        let a = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]];
        let b = array![1.0, 2.0, 3.0];
        let x = urv_lstsq(&a.view(), &b, 1e-10).expect("lstsq failed");
        let res_norm = residual_norm(&a, &x, &b);
        assert!(
            res_norm < 1e-8,
            "rank-deficient lstsq residual = {res_norm}"
        );
    }

    #[test]
    fn test_urv_lstsq_dimension_error() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![1.0, 2.0, 3.0];
        assert!(urv_lstsq(&a.view(), &b, 1e-10).is_err());
    }

    #[test]
    fn test_urv_single_element() {
        let a = array![[5.0]];
        let res = urv(&a.view(), 1e-12).expect("urv failed");
        assert_eq!(res.rank, 1);
        let recon = reconstruct(&res);
        assert!((recon[[0, 0]] - 5.0).abs() < 1e-10);
    }
}