use crate::error::{SparseError, SparseResult};
use super::types::QRUpdateResult;
/// Dense `(Q, R)` pair produced by [`qr_factorize_dense`]: `Q` is `m x m`, `R` is `m x n`,
/// both stored row-major as `Vec<Vec<f64>>`.
type DenseMatrixPair = (Vec<Vec<f64>>, Vec<Vec<f64>>);

/// Computes a full QR factorization `A = Q R` of the dense, row-major `m x n` matrix `a`
/// using Householder reflections.
///
/// Returns `(Q, R)` where `Q` is an `m x m` orthogonal matrix and `R` is an `m x n`
/// upper-triangular matrix whose leading `min(m, n)` diagonal entries are non-negative.
///
/// Householder reflections keep `Q` orthogonal to working precision for every input,
/// including rank-deficient and wide (`n > m`) matrices. Neither case needs special-case
/// handling here.
///
/// # Panics
///
/// Panics if `a` has fewer than `m` rows or any of its first `m` rows is shorter than `n`.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` is kept so callers can use `?`.
fn qr_factorize_dense(a: &[Vec<f64>], m: usize, n: usize) -> SparseResult<DenseMatrixPair> {
    let k = m.min(n);

    // R starts as a copy of A and is reduced in place. Q starts as the identity and
    // accumulates the product of the reflections, so Q R == A holds at every step.
    let mut r: Vec<Vec<f64>> = (0..m).map(|i| a[i][..n].to_vec()).collect();
    let mut q: Vec<Vec<f64>> = (0..m)
        .map(|i| {
            let mut row = vec![0.0; m];
            row[i] = 1.0;
            row
        })
        .collect();

    for j in 0..k {
        // Norm of the sub-column R[j.., j]; `hypot` avoids overflow/underflow of squares.
        let x_norm = r[j..].iter().fold(0.0_f64, |acc, row| acc.hypot(row[j]));
        if x_norm == 0.0 {
            // Nothing to eliminate: the column is already zero on and below the diagonal.
            continue;
        }

        // Reflect x onto alpha * e_1. Choosing alpha with the sign opposite to x[0]
        // means v[0] = x[0] - alpha never suffers cancellation.
        let alpha = if r[j][j] >= 0.0 { -x_norm } else { x_norm };
        let mut v: Vec<f64> = r[j..].iter().map(|row| row[j]).collect();
        v[0] -= alpha;
        // |v[0]| >= x_norm > 0, so this normalisation is well defined.
        let v_norm = v.iter().fold(0.0_f64, |acc, &vi| acc.hypot(vi));
        for vi in &mut v {
            *vi /= v_norm;
        }

        // R <- H R with H = I - 2 v v^T, applied to rows j.. of the trailing columns.
        for col in (j + 1)..n {
            let dot: f64 = v.iter().zip(&r[j..]).map(|(&vt, row)| vt * row[col]).sum();
            for (&vt, row) in v.iter().zip(r[j..].iter_mut()) {
                row[col] -= 2.0 * dot * vt;
            }
        }
        // H maps column j to (alpha, 0, ..., 0); store that exactly, without rounding noise.
        r[j][j] = alpha;
        for row in &mut r[j + 1..] {
            row[j] = 0.0;
        }

        // Q <- Q H, which only touches columns j.. of every row of Q.
        for row in &mut q {
            let dot: f64 = v.iter().zip(&row[j..]).map(|(&vt, &qv)| vt * qv).sum();
            for (qv, &vt) in row[j..].iter_mut().zip(&v) {
                *qv -= 2.0 * dot * vt;
            }
        }
    }

    // Make diag(R) non-negative (the Gram-Schmidt convention). Negating row i of R together
    // with column i of Q leaves Q R unchanged and Q orthogonal.
    for (i, r_row) in r.iter_mut().enumerate().take(k) {
        if r_row[i] < 0.0 {
            for val in &mut r_row[i..] {
                *val = -*val;
            }
            for q_row in &mut q {
                q_row[i] = -q_row[i];
            }
        }
    }

    Ok((q, r))
}
/// Builds the Givens rotation that maps the pair `(a, b)` onto `(r, 0)`.
///
/// Returns `(c, s, r)` with `r = hypot(a, b)`, `c = a / r` and `s = b / r`, so that
/// `c * a + s * b == r` and `-s * a + c * b == 0`. When both inputs are below `1e-15`
/// in magnitude, the identity rotation `(1.0, 0.0, 0.0)` is returned instead.
pub fn givens_rotation(a: f64, b: f64) -> (f64, f64, f64) {
    match (a.abs() < 1e-15, b.abs() < 1e-15) {
        (true, true) => (1.0, 0.0, 0.0),
        _ => {
            let radius = a.hypot(b);
            (a / radius, b / radius, radius)
        }
    }
}
/// Applies the Givens rotation `(c, s)` from the left to rows `i` and `k` of `matrix`.
///
/// For every column shared by both rows, `(x_i, x_k)` becomes
/// `(c * x_i + s * x_k, -s * x_i + c * x_k)`. Only the first `min(len_i, len_k)` columns
/// are touched.
///
/// # Panics
///
/// Panics if `i` or `k` is not a valid row index of `matrix`.
pub fn apply_givens_left(matrix: &mut [Vec<f64>], i: usize, k: usize, c: f64, s: f64) {
    let width = usize::min(matrix[i].len(), matrix[k].len());
    for col in 0..width {
        // Read both entries before writing, so a row pair that aliases (i == k) reads
        // its entries exactly once.
        let (top, bottom) = (matrix[i][col], matrix[k][col]);
        matrix[i][col] = c * top + s * bottom;
        matrix[k][col] = c * bottom - s * top;
    }
}
/// Applies the Givens rotation `(c, s)` from the right to columns `i` and `k` of `matrix`.
///
/// In every row, `(x_i, x_k)` becomes `(c * x_i + s * x_k, -s * x_i + c * x_k)`.
/// Rows too short to contain both columns are left untouched.
pub fn apply_givens_right(matrix: &mut [Vec<f64>], i: usize, k: usize, c: f64, s: f64) {
    for row in matrix.iter_mut() {
        if row.len() <= i.max(k) {
            continue;
        }
        // Read both entries before writing, so an aliased column pair (i == k) reads
        // its entries exactly once.
        let (left, right) = (row[i], row[k]);
        row[i] = c * left + s * right;
        row[k] = c * right - s * left;
    }
}
/// Computes the QR factorization of `A + u vᵀ`, given the factorization `A = Q R`.
///
/// `q` must be `m x m` and `r` must be `m x n` (row-major, every row the same length).
/// `u_vec` must have length `m` and `v_vec` length `n`. The updated matrix is formed
/// explicitly and refactorized, so the cost is `O(m² n)`.
///
/// An empty `q` (`m == 0`) short-circuits to an empty, successful result without
/// inspecting the other arguments.
///
/// # Errors
///
/// Returns [`SparseError::DimensionMismatch`] when `u_vec`, `v_vec` or the number of rows
/// of `r` disagree with the shapes above. Returns [`SparseError::ComputationError`] when a
/// row of `q` or `r` has the wrong length.
pub fn qr_rank1_update(
    q: &[Vec<f64>],
    r: &[Vec<f64>],
    u_vec: &[f64],
    v_vec: &[f64],
) -> SparseResult<QRUpdateResult> {
    let m = q.len();
    if m == 0 {
        return Ok(QRUpdateResult {
            q: Vec::new(),
            r: Vec::new(),
            success: true,
        });
    }
    let n = r.first().map_or(0, |row| row.len());
    if u_vec.len() != m {
        return Err(SparseError::DimensionMismatch {
            expected: m,
            found: u_vec.len(),
        });
    }
    if v_vec.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: v_vec.len(),
        });
    }
    for (i, row) in q.iter().enumerate() {
        if row.len() != m {
            return Err(SparseError::ComputationError(format!(
                "Q row {} has length {} but expected {}",
                i,
                row.len(),
                m
            )));
        }
    }
    if r.len() != m {
        return Err(SparseError::DimensionMismatch {
            expected: m,
            found: r.len(),
        });
    }
    // `n` comes from the first row only; a ragged R would otherwise panic while forming Q R.
    for (i, row) in r.iter().enumerate() {
        if row.len() != n {
            return Err(SparseError::ComputationError(format!(
                "R row {} has length {} but expected {}",
                i,
                row.len(),
                n
            )));
        }
    }

    // A' = Q R + u vᵀ, formed entry by entry.
    let a_prime: Vec<Vec<f64>> = q
        .iter()
        .zip(u_vec)
        .map(|(q_row, &u_i)| {
            v_vec
                .iter()
                .enumerate()
                .map(|(j, &v_j)| {
                    let qr_ij: f64 = q_row
                        .iter()
                        .zip(r)
                        .map(|(&q_it, r_row)| q_it * r_row[j])
                        .sum();
                    qr_ij + u_i * v_j
                })
                .collect()
        })
        .collect();

    let (q_new, r_new) = qr_factorize_dense(&a_prime, m, n)?;
    Ok(QRUpdateResult {
        q: q_new,
        r: r_new,
        success: true,
    })
}
/// Computes the QR factorization of `A` with `new_col` inserted before column `col_idx`,
/// given the factorization `A = Q R`.
///
/// `q` must be `m x m` and `r` must be `m x n`; `new_col` must have length `m`, and
/// `col_idx` may range over `0..=n` (`n` appends the column). The result has an
/// `m x (n + 1)` `R`. The updated matrix is formed explicitly and refactorized.
///
/// An empty `q` (`m == 0`) short-circuits to an empty, successful result without
/// inspecting the other arguments.
///
/// # Errors
///
/// - [`SparseError::IndexOutOfBounds`] if `col_idx > n`.
/// - [`SparseError::DimensionMismatch`] if `new_col` or the number of rows of `r` does not
///   match `m`.
/// - [`SparseError::ComputationError`] if a row of `q` or `r` has the wrong length.
pub fn qr_column_insert(
    q: &[Vec<f64>],
    r: &[Vec<f64>],
    col_idx: usize,
    new_col: &[f64],
) -> SparseResult<QRUpdateResult> {
    let m = q.len();
    if m == 0 {
        return Ok(QRUpdateResult {
            q: Vec::new(),
            r: Vec::new(),
            success: true,
        });
    }
    let n = r.first().map_or(0, |row| row.len());
    if col_idx > n {
        return Err(SparseError::IndexOutOfBounds {
            index: (0, col_idx),
            shape: (m, n + 1),
        });
    }
    if new_col.len() != m {
        return Err(SparseError::DimensionMismatch {
            expected: m,
            found: new_col.len(),
        });
    }
    // Validate the factor shapes up front so malformed input is reported as an error
    // instead of panicking inside the Q R product below.
    for (i, row) in q.iter().enumerate() {
        if row.len() != m {
            return Err(SparseError::ComputationError(format!(
                "Q row {} has length {} but expected {}",
                i,
                row.len(),
                m
            )));
        }
    }
    if r.len() != m {
        return Err(SparseError::DimensionMismatch {
            expected: m,
            found: r.len(),
        });
    }
    for (i, row) in r.iter().enumerate() {
        if row.len() != n {
            return Err(SparseError::ComputationError(format!(
                "R row {} has length {} but expected {}",
                i,
                row.len(),
                n
            )));
        }
    }

    // Rebuild each row of A = Q R, then splice the new column's entry in at `col_idx`.
    let a_prime: Vec<Vec<f64>> = q
        .iter()
        .zip(new_col)
        .map(|(q_row, &new_val)| {
            let mut row = Vec::with_capacity(n + 1);
            row.extend((0..n).map(|j| {
                q_row
                    .iter()
                    .zip(r)
                    .map(|(&q_it, r_row)| q_it * r_row[j])
                    .sum::<f64>()
            }));
            row.insert(col_idx, new_val);
            row
        })
        .collect();

    let (q_new, r_new) = qr_factorize_dense(&a_prime, m, n + 1)?;
    Ok(QRUpdateResult {
        q: q_new,
        r: r_new,
        success: true,
    })
}
/// Computes the QR factorization of `A` with column `col_idx` removed, given the
/// factorization `A = Q R`.
///
/// `q` must be `m x m` and `r` must be `m x n` with `n >= 1`; `col_idx` must lie in
/// `0..n`. The result has an `m x (n - 1)` `R`. The reduced matrix is formed explicitly
/// and refactorized.
///
/// An empty `q` (`m == 0`) short-circuits to an empty, successful result without
/// inspecting the other arguments.
///
/// # Errors
///
/// - [`SparseError::ComputationError`] if `R` has no columns, or a row of `q` or `r` has
///   the wrong length.
/// - [`SparseError::IndexOutOfBounds`] if `col_idx >= n`.
/// - [`SparseError::DimensionMismatch`] if the number of rows of `r` does not match `m`.
pub fn qr_column_delete(
    q: &[Vec<f64>],
    r: &[Vec<f64>],
    col_idx: usize,
) -> SparseResult<QRUpdateResult> {
    let m = q.len();
    if m == 0 {
        return Ok(QRUpdateResult {
            q: Vec::new(),
            r: Vec::new(),
            success: true,
        });
    }
    let n = r.first().map_or(0, |row| row.len());
    // Must precede the bounds test: with n == 0 every col_idx is out of bounds, and this
    // is the more specific diagnosis.
    if n == 0 {
        return Err(SparseError::ComputationError(
            "Cannot delete column from empty matrix".into(),
        ));
    }
    if col_idx >= n {
        return Err(SparseError::IndexOutOfBounds {
            index: (0, col_idx),
            shape: (m, n),
        });
    }
    // Validate the factor shapes up front so malformed input is reported as an error
    // instead of panicking inside the Q R product below.
    for (i, row) in q.iter().enumerate() {
        if row.len() != m {
            return Err(SparseError::ComputationError(format!(
                "Q row {} has length {} but expected {}",
                i,
                row.len(),
                m
            )));
        }
    }
    if r.len() != m {
        return Err(SparseError::DimensionMismatch {
            expected: m,
            found: r.len(),
        });
    }
    for (i, row) in r.iter().enumerate() {
        if row.len() != n {
            return Err(SparseError::ComputationError(format!(
                "R row {} has length {} but expected {}",
                i,
                row.len(),
                n
            )));
        }
    }

    // Rebuild A = Q R, skipping the deleted column so its product is never computed.
    let a_prime: Vec<Vec<f64>> = q
        .iter()
        .map(|q_row| {
            (0..n)
                .filter(|&j| j != col_idx)
                .map(|j| {
                    q_row
                        .iter()
                        .zip(r)
                        .map(|(&q_it, r_row)| q_it * r_row[j])
                        .sum::<f64>()
                })
                .collect()
        })
        .collect();

    let (q_new, r_new) = qr_factorize_dense(&a_prime, m, n - 1)?;
    Ok(QRUpdateResult {
        q: q_new,
        r: r_new,
        success: true,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when the columns of the square matrix `q` are orthonormal within `tol`.
    fn check_orthogonal(q: &[Vec<f64>], tol: f64) -> bool {
        let m = q.len();
        for i in 0..m {
            for j in 0..m {
                let mut dot = 0.0;
                for k in 0..m {
                    dot += q[k][i] * q[k][j];
                }
                let expected = if i == j { 1.0 } else { 0.0 };
                if (dot - expected).abs() > tol {
                    return false;
                }
            }
        }
        true
    }

    /// Returns true when every entry strictly below the diagonal of `r` is within `tol` of 0.
    fn check_upper_triangular(r: &[Vec<f64>], tol: f64) -> bool {
        let n = r.first().map_or(0, |row| row.len());
        for i in 0..r.len() {
            for j in 0..i.min(n) {
                if r[i][j].abs() > tol {
                    return false;
                }
            }
        }
        true
    }

    /// Dense product `a * b` of row-major matrices.
    fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let m = a.len();
        if m == 0 {
            return Vec::new();
        }
        let p = b.first().map_or(0, |r| r.len());
        let k = b.len();
        let mut c = vec![vec![0.0; p]; m];
        for i in 0..m {
            for j in 0..p {
                for t in 0..k {
                    c[i][j] += a[i][t] * b[t][j];
                }
            }
        }
        c
    }

    #[test]
    fn test_givens_rotation_basic() {
        let (c, s, r) = givens_rotation(3.0, 4.0);
        assert!((r - 5.0).abs() < 1e-10);
        assert!((c * 3.0 + s * 4.0 - r).abs() < 1e-10);
        assert!((-s * 3.0 + c * 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_givens_rotation_zeros() {
        let (c, s, r) = givens_rotation(0.0, 0.0);
        assert!((c - 1.0).abs() < 1e-10);
        assert!(s.abs() < 1e-10);
        assert!(r.abs() < 1e-10);
    }

    #[test]
    fn test_apply_givens_left_annihilates_entry() {
        let mut mat = vec![vec![3.0, 1.0], vec![4.0, 2.0]];
        let (c, s, r) = givens_rotation(3.0, 4.0);
        apply_givens_left(&mut mat, 0, 1, c, s);
        assert!((mat[0][0] - r).abs() < 1e-12);
        assert!(mat[1][0].abs() < 1e-12);
        assert!((mat[0][1] - 2.2).abs() < 1e-12);
        assert!((mat[1][1] - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_apply_givens_right_annihilates_entry() {
        let mut mat = vec![vec![3.0, 4.0], vec![1.0, 2.0]];
        let (c, s, r) = givens_rotation(3.0, 4.0);
        apply_givens_right(&mut mat, 0, 1, c, s);
        assert!((mat[0][0] - r).abs() < 1e-12);
        assert!(mat[0][1].abs() < 1e-12);
        assert!((mat[1][0] - 2.2).abs() < 1e-12);
        assert!((mat[1][1] - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_apply_givens_right_skips_short_rows() {
        let mut mat = vec![vec![7.0], vec![3.0, 4.0]];
        apply_givens_right(&mut mat, 0, 1, 0.6, 0.8);
        assert_eq!(mat[0], vec![7.0]);
        assert!((mat[1][0] - 5.0).abs() < 1e-12);
        assert!(mat[1][1].abs() < 1e-12);
    }

    #[test]
    fn test_qr_rank1_update_product() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![3.0, 1.0], vec![0.0, 2.0]];
        let u_vec = vec![1.0, 0.5];
        let v_vec = vec![0.5, 1.0];
        let result = qr_rank1_update(&q, &r, &u_vec, &v_vec).expect("qr rank1 update");
        let qr_product = mat_mul(&result.q, &result.r);
        let a_prime = [
            vec![3.0 + 1.0 * 0.5, 1.0 + 1.0 * 1.0],
            vec![0.0 + 0.5 * 0.5, 2.0 + 0.5 * 1.0],
        ];
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (qr_product[i][j] - a_prime[i][j]).abs() < 1e-9,
                    "QR product mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    qr_product[i][j],
                    a_prime[i][j]
                );
            }
        }
    }

    #[test]
    fn test_qr_update_orthogonality() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![3.0, 1.0], vec![0.0, 2.0]];
        let u_vec = vec![1.0, 0.5];
        let v_vec = vec![0.5, 1.0];
        let result = qr_rank1_update(&q, &r, &u_vec, &v_vec).expect("qr update");
        assert!(
            check_orthogonal(&result.q, 1e-10),
            "Q should be orthogonal after update"
        );
    }

    #[test]
    fn test_qr_update_upper_triangular() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![3.0, 1.0], vec![0.0, 2.0]];
        let u_vec = vec![1.0, 0.5];
        let v_vec = vec![0.5, 1.0];
        let result = qr_rank1_update(&q, &r, &u_vec, &v_vec).expect("qr update");
        assert!(
            check_upper_triangular(&result.r, 1e-10),
            "R should be upper triangular after update"
        );
    }

    #[test]
    fn test_qr_rank1_update_tall() {
        let q = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let r = vec![vec![2.0, 1.0], vec![0.0, 3.0], vec![0.0, 0.0]];
        let result =
            qr_rank1_update(&q, &r, &[1.0, 1.0, 1.0], &[1.0, 0.0]).expect("tall rank1 update");
        assert!(check_orthogonal(&result.q, 1e-10));
        assert!(check_upper_triangular(&result.r, 1e-10));
        let product = mat_mul(&result.q, &result.r);
        let expected = [[3.0, 1.0], [1.0, 3.0], [1.0, 0.0]];
        assert_eq!(product.len(), expected.len());
        for (i, (got_row, want_row)) in product.iter().zip(expected.iter()).enumerate() {
            for (j, (&got, &want)) in got_row.iter().zip(want_row.iter()).enumerate() {
                assert!(
                    (got - want).abs() < 1e-9,
                    "QR product mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn test_qr_column_insert_dimensions() {
        let q = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let r = vec![vec![2.0, 1.0], vec![0.0, 3.0], vec![0.0, 0.0]];
        let new_col = vec![1.0, 2.0, 3.0];
        let result = qr_column_insert(&q, &r, 1, &new_col).expect("column insert");
        assert_eq!(result.r.len(), 3);
        assert_eq!(result.r[0].len(), 3);
        assert!(result.success);
    }

    #[test]
    fn test_qr_column_insert_factorization() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![3.0], vec![0.0]];
        let new_col = vec![1.0, 2.0];
        let result = qr_column_insert(&q, &r, 0, &new_col).expect("column insert");
        let qr_product = mat_mul(&result.q, &result.r);
        assert!((qr_product[0][0] - 1.0).abs() < 1e-9);
        assert!((qr_product[0][1] - 3.0).abs() < 1e-9);
        assert!((qr_product[1][0] - 2.0).abs() < 1e-9);
        assert!((qr_product[1][1]).abs() < 1e-9);
    }

    #[test]
    fn test_qr_column_insert_out_of_bounds() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![3.0], vec![0.0]];
        let result = qr_column_insert(&q, &r, 2, &[1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_qr_column_delete_dimensions() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![2.0, 1.0], vec![0.0, 3.0]];
        let result = qr_column_delete(&q, &r, 0).expect("column delete");
        assert_eq!(result.r.len(), 2);
        assert_eq!(result.r[0].len(), 1);
        assert!(result.success);
    }

    #[test]
    fn test_qr_column_delete_factorization() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![2.0, 1.0], vec![0.0, 3.0]];
        let result = qr_column_delete(&q, &r, 0).expect("column delete");
        let qr_product = mat_mul(&result.q, &result.r);
        assert!((qr_product[0][0] - 1.0).abs() < 1e-9);
        assert!((qr_product[1][0] - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_qr_column_delete_out_of_bounds() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![2.0, 1.0], vec![0.0, 3.0]];
        let result = qr_column_delete(&q, &r, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_sequential_qr_updates() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r1 = qr_rank1_update(&q, &r, &[1.0, 0.0], &[1.0, 0.0]).expect("update 1");
        let r2 = qr_rank1_update(&r1.q, &r1.r, &[0.0, 1.0], &[0.0, 1.0]).expect("update 2");
        let product = mat_mul(&r2.q, &r2.r);
        assert!((product[0][0] - 2.0).abs() < 1e-9);
        assert!((product[1][1] - 2.0).abs() < 1e-9);
        assert!(product[0][1].abs() < 1e-9);
        assert!(product[1][0].abs() < 1e-9);
    }

    #[test]
    fn test_qr_rank1_dimension_mismatch() {
        let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let r = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let u_vec = vec![1.0];
        let v_vec = vec![1.0, 0.0];
        let result = qr_rank1_update(&q, &r, &u_vec, &v_vec);
        assert!(result.is_err());
    }
}