use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
/// Output of the rank-revealing QR routines in this module.
///
/// The factors satisfy `A P = Q R` (equivalently `A = Q R P^T`), where `P` is the
/// permutation matrix built from `perm` by `perm_to_matrix`.
#[derive(Debug, Clone)]
pub struct RRQRResult<F> {
/// Orthogonal factor; `rrqr` starts it from the `m x m` identity and accumulates
/// the Householder reflections into it.
pub q: Array2<F>,
/// Triangular factor, same shape (`m x n`) as the input matrix.
pub r: Array2<F>,
/// Column permutation: `perm[j]` is the index of the original column of `A`
/// that ended up in position `j` of `A P`.
pub perm: Vec<usize>,
/// Numerical rank estimated from the magnitudes of the diagonal of `r`,
/// using the tolerance passed by the caller.
pub rank: usize,
}
pub fn rrqr<F>(a: &ArrayView2<F>, tol: F) -> LinalgResult<RRQRResult<F>>
where
F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
{
let (m, n) = a.dim();
if m == 0 || n == 0 {
return Err(LinalgError::ShapeError(
"RRQR: input matrix must be non-empty".to_string(),
));
}
for &v in a.iter() {
if !v.is_finite() {
return Err(LinalgError::InvalidInputError(
"RRQR: matrix contains non-finite values".to_string(),
));
}
}
let min_dim = m.min(n);
let mut r = a.to_owned();
let mut q = Array2::<F>::eye(m);
let mut perm: Vec<usize> = (0..n).collect();
let mut col_norms_sq: Vec<F> = (0..n)
.map(|j| {
let mut s = F::zero();
for i in 0..m {
s += r[[i, j]] * r[[i, j]];
}
s
})
.collect();
let two = F::from(2.0).unwrap_or_else(|| F::one() + F::one());
let mut rank = 0usize;
for k in 0..min_dim {
let mut best_col = k;
let mut best_norm = col_norms_sq[k];
for (j, &cnj) in col_norms_sq.iter().enumerate().take(n).skip(k + 1) {
if cnj > best_norm {
best_norm = cnj;
best_col = j;
}
}
if best_norm.sqrt() <= F::epsilon() {
break;
}
if best_col != k {
for i in 0..m {
let tmp = r[[i, k]];
r[[i, k]] = r[[i, best_col]];
r[[i, best_col]] = tmp;
}
perm.swap(k, best_col);
col_norms_sq.swap(k, best_col);
}
let mut x = Array1::<F>::zeros(m - k);
for i in k..m {
x[i - k] = r[[i, k]];
}
let x_norm = x.iter().fold(F::zero(), |acc, &v| acc + v * v).sqrt();
if x_norm <= F::epsilon() {
continue;
}
let alpha = if x[0] >= F::zero() { -x_norm } else { x_norm };
let mut v = x;
v[0] -= alpha;
let v_norm_sq = v.iter().fold(F::zero(), |acc, &val| acc + val * val);
if v_norm_sq <= F::epsilon() {
continue;
}
let beta = two / v_norm_sq;
for j in k..n {
let mut dot = F::zero();
for i in 0..(m - k) {
dot += v[i] * r[[i + k, j]];
}
for i in 0..(m - k) {
r[[i + k, j]] -= beta * v[i] * dot;
}
}
for i in 0..m {
let mut dot = F::zero();
for jj in 0..(m - k) {
dot += q[[i, jj + k]] * v[jj];
}
for jj in 0..(m - k) {
q[[i, jj + k]] -= beta * dot * v[jj];
}
}
for j in (k + 1)..n {
let rk = r[[k, j]];
col_norms_sq[j] -= rk * rk;
if col_norms_sq[j] < F::zero() {
col_norms_sq[j] = F::zero();
for i in (k + 1)..m {
col_norms_sq[j] += r[[i, j]] * r[[i, j]];
}
}
}
let r00_abs = r[[0, 0]].abs();
let rkk_abs = r[[k, k]].abs();
if r00_abs > F::epsilon() && rkk_abs / r00_abs >= tol {
rank = k + 1;
} else if r00_abs <= F::epsilon() {
if rkk_abs > tol {
rank = k + 1;
}
}
}
if rank == 0 && min_dim > 0 && r[[0, 0]].abs() > tol {
rank = 1;
}
Ok(RRQRResult { q, r, perm, rank })
}
/// Returns only the numerical rank that [`rrqr`] reports for `a` at tolerance `tol`.
///
/// # Errors
/// Fails exactly when [`rrqr`] fails (empty or non-finite input).
pub fn rrqr_rank<F>(a: &ArrayView2<F>, tol: F) -> LinalgResult<usize>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
{
    rrqr(a, tol).map(|factorization| factorization.rank)
}
/// Strong rank-revealing QR factorisation (Gu-Eisenstat column interchanges).
///
/// Runs [`rrqr`] with tolerance `tol`, keeps its numerical rank `k` fixed, and
/// then performs at most `max_swaps` column interchanges. Write
/// `R = [[R11, R12], [0, R22]]` with `R11` of size `k x k`. A leading column
/// `i < k` and a trailing column `j >= k` are exchanged while
///
/// ```text
/// rho(i, j) = (inv(R11) R12)[i, j]^2 + (norm(R22[:, j]) * norm(row i of inv(R11)))^2 > f^2
/// ```
///
/// with `f = 2`, always choosing the pair with the largest `rho`. Each
/// interchange multiplies `|det R11|` by `sqrt(rho) > f`, so the process
/// terminates. If it stops because no pair exceeds the threshold, every entry
/// of `inv(R11) R12` is bounded by `f` in magnitude.
///
/// Whenever interchanges are attempted, `r` is first made fully upper
/// triangular. `A P = Q R` holds throughout, and `rank` is the value reported
/// by [`rrqr`].
///
/// # Errors
/// Propagates the errors of [`rrqr`].
pub fn strong_rrqr<F>(a: &ArrayView2<F>, tol: F, max_swaps: usize) -> LinalgResult<RRQRResult<F>>
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + Send + Sync + 'static,
{
    let mut result = rrqr(a, tol)?;
    let (m, n) = a.dim();
    let k = result.rank;
    // Interchanges need both a non-empty leading block and a trailing block.
    if max_swaps == 0 || k == 0 || k >= n {
        return Ok(result);
    }
    let two = F::from(2.0).unwrap_or_else(|| F::one() + F::one());
    // Interchange threshold f^2 with f = 2.
    let f_sq = two * two;

    // rrqr may leave a trailing block unreduced. Bring R to exact upper
    // triangular form so that R11 can be handled by back substitution.
    retriangularize_givens(&mut result.q, &mut result.r, 0, m);

    let mut omega_sq = vec![F::zero(); k];
    let mut work = vec![F::zero(); k];
    for _ in 0..max_swaps {
        if (0..k).any(|i| result.r[[i, i]] == F::zero()) {
            // R11 is singular, so the interchange criterion is undefined.
            break;
        }

        // omega_sq[i] = squared norm of row i of inv(R11), built column by column.
        omega_sq.iter_mut().for_each(|o| *o = F::zero());
        for c in 0..k {
            work.iter_mut().for_each(|x| *x = F::zero());
            work[c] = F::one();
            solve_upper_leading(&result.r, k, &mut work);
            for (o, &x) in omega_sq.iter_mut().zip(work.iter()) {
                *o += x * x;
            }
        }

        // Locate the (leading i, trailing j) pair with the largest rho(i, j).
        let mut best_rho = F::zero();
        let mut best_pair = None;
        for j in k..n {
            let gamma_sq = (k..m).fold(F::zero(), |acc, row| {
                acc + result.r[[row, j]] * result.r[[row, j]]
            });
            for (row, x) in work.iter_mut().enumerate() {
                *x = result.r[[row, j]];
            }
            // work <- column j of inv(R11) * R12
            solve_upper_leading(&result.r, k, &mut work);
            for (i, (&w, &o)) in work.iter().zip(omega_sq.iter()).enumerate() {
                let rho = w * w + gamma_sq * o;
                if rho > best_rho {
                    best_rho = rho;
                    best_pair = Some((i, j));
                }
            }
        }
        let (i, j) = match best_pair {
            Some(pair) if best_rho > f_sq => pair,
            _ => break,
        };

        for row in 0..m {
            let tmp = result.r[[row, i]];
            result.r[[row, i]] = result.r[[row, j]];
            result.r[[row, j]] = tmp;
        }
        result.perm.swap(i, j);
        // The swapped-in column is a full "spike"; restore triangularity from column i on.
        retriangularize_givens(&mut result.q, &mut result.r, i, m);
    }
    Ok(result)
}

/// Restores upper-triangular form of `r` in columns `col..min(m, ncols)` using
/// Givens rotations, and applies each rotation to `q` so that `Q R` is unchanged.
///
/// Columns `0..col` are assumed to be upper triangular already. Every rotation
/// is still applied to whole rows, so the product `Q R` stays exact even if they
/// are not.
fn retriangularize_givens<F>(q: &mut Array2<F>, r: &mut Array2<F>, col: usize, m: usize)
where
    F: Float + NumAssign + Sum + Debug + ScalarOperand + 'static,
{
    let n = r.ncols();
    let q_rows = q.nrows();
    for p in col..m.min(n) {
        for i in (p + 1)..m {
            let a_val = r[[p, p]];
            let b_val = r[[i, p]];
            // Only exact zeros are skipped. An absolute-epsilon skip would leave
            // small-scale entries unreduced.
            if b_val == F::zero() {
                continue;
            }
            // hypot avoids overflow/underflow; rr >= |b_val| > 0.
            let rr = a_val.hypot(b_val);
            let c = a_val / rr;
            let s = b_val / rr;
            for j in 0..n {
                let r_p_j = r[[p, j]];
                let r_i_j = r[[i, j]];
                r[[p, j]] = c * r_p_j + s * r_i_j;
                r[[i, j]] = -s * r_p_j + c * r_i_j;
            }
            // The rotation annihilates this entry in exact arithmetic.
            r[[i, p]] = F::zero();
            // Q <- Q G^T keeps Q R invariant.
            for row in 0..q_rows {
                let q_p = q[[row, p]];
                let q_i = q[[row, i]];
                q[[row, p]] = c * q_p + s * q_i;
                q[[row, i]] = -s * q_p + c * q_i;
            }
        }
    }
}

/// Solves `R11 x = b` in place by back substitution. `R11` is the leading
/// `k x k` upper-triangular block of `r`, and `x` holds `b` on entry.
/// All diagonal entries of `R11` must be non-zero.
fn solve_upper_leading<F: Float>(r: &Array2<F>, k: usize, x: &mut [F]) {
    for i in (0..k).rev() {
        let mut s = x[i];
        for l in (i + 1)..k {
            s = s - r[[i, l]] * x[l];
        }
        x[i] = s / r[[i, i]];
    }
}
/// Builds the dense `n x n` permutation matrix `P` for a column permutation vector.
///
/// Column `j` of the result contains a single one, in row `perm[j]`. For the
/// factorisations in this module this gives `A P = Q R`, i.e. `A = Q R P^T`.
///
/// # Panics
/// Panics if any entry of `perm` is `>= perm.len()`.
pub fn perm_to_matrix<F>(perm: &[usize]) -> Array2<F>
where
    F: Float,
{
    let size = perm.len();
    let mut matrix = Array2::<F>::zeros((size, size));
    perm.iter()
        .copied()
        .enumerate()
        .for_each(|(col, row)| matrix[[row, col]] = F::one());
    matrix
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// Frobenius norm of `a - b`, used as the reconstruction / orthogonality error.
// Elements are paired in iteration order, so both arguments are expected to have
// the same shape (a mismatch would be silently truncated by `zip`).
fn frob_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
a.iter()
.zip(b.iter())
.map(|(&x, &y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
}
// Nonsingular 3x3 input. Checks the reported rank, the output shapes,
// orthogonality of Q, and the reconstruction A = Q R P^T.
#[test]
fn test_rrqr_full_rank_square() {
let a = array![
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 10.0] ];
let res = rrqr(&a.view(), 1e-12).expect("rrqr failed");
assert_eq!(res.rank, 3, "should detect full rank");
assert_eq!(res.q.shape(), &[3, 3]);
assert_eq!(res.r.shape(), &[3, 3]);
assert_eq!(res.perm.len(), 3);
let qtq = res.q.t().dot(&res.q);
let eye3 = Array2::<f64>::eye(3);
assert!(frob_diff(&qtq, &eye3) < 1e-10, "Q must be orthogonal");
let p = perm_to_matrix::<f64>(&res.perm);
let qr = res.q.dot(&res.r);
let qr_pt = qr.dot(&p.t());
assert!(
frob_diff(&qr_pt, &a.to_owned()) < 1e-10,
"A = Q R P^T reconstruction"
);
}
// Third row = first row + second row, so the rank is 2 and the last diagonal
// entry of R must be negligible relative to R[0,0].
#[test]
fn test_rrqr_rank_deficient() {
let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [5.0, 7.0, 9.0]];
let res = rrqr(&a.view(), 1e-10).expect("rrqr failed");
assert_eq!(res.rank, 2, "rank should be 2");
let ratio = res.r[[2, 2]].abs() / res.r[[0, 0]].abs();
assert!(ratio < 1e-10, "R[2,2]/R[0,0] = {ratio} should be tiny");
}
// Column 0 has entries of size 1e-15 while the others have unit norm, so
// column pivoting must not select column 0 as the first pivot.
#[test]
fn test_rrqr_column_pivoting_improves_conditioning() {
let a = array![[1e-15, 1.0, 0.0], [1e-15, 0.0, 1.0], [1e-15, 0.0, 0.0]];
let res = rrqr(&a.view(), 1e-12).expect("rrqr failed");
assert_ne!(
res.perm[0], 0,
"pivot should not choose the tiny column first"
);
}
// Tall 4x3 input whose third column is the sum of the first two.
// Checks the reconstruction and that the rank is 2.
#[test]
fn test_rrqr_reconstruction_rectangular() {
let a = array![
[1.0, 0.0, 1.0],
[0.0, 1.0, 1.0],
[1.0, 1.0, 2.0],
[2.0, 1.0, 3.0]
];
let res = rrqr(&a.view(), 1e-10).expect("rrqr failed");
let p = perm_to_matrix::<f64>(&res.perm);
let reconstructed = res.q.dot(&res.r).dot(&p.t());
let diff = frob_diff(&reconstructed, &a.to_owned());
assert!(diff < 1e-10, "reconstruction error = {diff}");
assert_eq!(res.rank, 2, "rectangular matrix has rank 2");
}
// Second row is twice the first, so `rrqr_rank` must report 1.
#[test]
fn test_rrqr_rank_convenience() {
let a = array![
[1.0, 2.0],
[2.0, 4.0] ];
let r = rrqr_rank(&a.view(), 1e-10).expect("rrqr_rank failed");
assert_eq!(r, 1);
}
// The 4x4 identity has full rank.
#[test]
fn test_rrqr_identity() {
let eye = Array2::<f64>::eye(4);
let res = rrqr(&eye.view(), 1e-12).expect("rrqr failed");
assert_eq!(res.rank, 4);
}
// An all-zero matrix is accepted (not an error) and has rank 0.
#[test]
fn test_rrqr_zero_matrix() {
let z = Array2::<f64>::zeros((3, 3));
let res = rrqr(&z.view(), 1e-12).expect("rrqr failed");
assert_eq!(res.rank, 0, "zero matrix has rank 0");
}
// A 0x0 input must be rejected with an error.
#[test]
fn test_rrqr_empty_matrix_error() {
let e = Array2::<f64>::zeros((0, 0));
assert!(rrqr(&e.view(), 1e-12).is_err());
}
// Same rank-2 matrix as `test_rrqr_rank_deficient`. The strong variant must
// keep rank 2 and still satisfy A = Q R P^T after any column interchanges.
#[test]
fn test_strong_rrqr_basic() {
let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [5.0, 7.0, 9.0]];
let res = strong_rrqr(&a.view(), 1e-10, 5).expect("strong_rrqr failed");
assert_eq!(res.rank, 2);
let p = perm_to_matrix::<f64>(&res.perm);
let reconstructed = res.q.dot(&res.r).dot(&p.t());
let diff = frob_diff(&reconstructed, &a.to_owned());
assert!(diff < 1e-8, "strong RRQR reconstruction error = {diff}");
}
// Third column = first + second. Despite the name, this test only checks that
// the basic and strong variants both report rank 2.
#[test]
fn test_strong_rrqr_improves_over_basic() {
let a = array![[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 2.0]];
let basic = rrqr(&a.view(), 1e-10).expect("rrqr failed");
let strong = strong_rrqr(&a.view(), 1e-10, 10).expect("strong rrqr failed");
assert_eq!(basic.rank, 2);
assert_eq!(strong.rank, 2);
}
// perm = [2, 0, 1] must place a one at (perm[j], j) for every column j.
#[test]
fn test_perm_to_matrix() {
let perm = vec![2, 0, 1];
let p = perm_to_matrix::<f64>(&perm);
assert!((p[[2, 0]] - 1.0).abs() < 1e-15);
assert!((p[[0, 1]] - 1.0).abs() < 1e-15);
assert!((p[[1, 2]] - 1.0).abs() < 1e-15);
}
// Wide 2x4 input with two independent rows. Rank 2 and reconstruction hold.
#[test]
fn test_rrqr_wide_matrix() {
let a = array![[1.0, 0.0, 1.0, 2.0], [0.0, 1.0, 1.0, 3.0]];
let res = rrqr(&a.view(), 1e-10).expect("rrqr failed");
assert_eq!(res.rank, 2);
let p = perm_to_matrix::<f64>(&res.perm);
let reconstructed = res.q.dot(&res.r).dot(&p.t());
let diff = frob_diff(&reconstructed, &a.to_owned());
assert!(diff < 1e-10, "wide matrix reconstruction error = {diff}");
}
}