use crate::error::{SparseError, SparseResult};
use super::types::{estimate_condition, LUUpdateResult, LowRankUpdateConfig};
/// Dense matrix-vector product `A * x` for row-major nested-`Vec` storage.
///
/// Each output entry is the dot product of one row of `a` with `x`; `zip`
/// truncates to the shorter operand, so ragged inputs do not panic.
fn mat_vec_mul(a: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
    let mut product = Vec::with_capacity(a.len());
    for row in a {
        let mut dot = 0.0;
        for (&entry, &coeff) in row.iter().zip(x.iter()) {
            dot += entry * coeff;
        }
        product.push(dot);
    }
    product
}
/// Dense matrix-matrix product `A * B` for row-major nested-`Vec` storage.
///
/// The inner dimension is taken from `b.len()` and the output width from
/// `b`'s first row; an empty `a` yields an empty result.
fn mat_mat_mul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if a.is_empty() {
        return Vec::new();
    }
    let width = b.first().map_or(0, |row| row.len());
    let inner = b.len();
    a.iter()
        .map(|a_row| {
            (0..width)
                .map(|j| (0..inner).map(|t| a_row[t] * b[t][j]).sum())
                .collect()
        })
        .collect()
}
/// Gathers `x` by index list: `out[i] = x[p[i]]`.
/// Panics if any entry of `p` is out of bounds for `x`.
fn permute_vector(x: &[f64], p: &[usize]) -> Vec<f64> {
    let mut gathered = Vec::with_capacity(p.len());
    for &src in p {
        gathered.push(x[src]);
    }
    gathered
}
/// Solves `L y = b` by forward substitution for a lower-triangular `l`.
///
/// # Errors
/// Returns `SingularMatrix` when a diagonal entry's magnitude is below the
/// 1e-15 singularity threshold.
fn forward_solve(l: &[Vec<f64>], b: &[f64]) -> SparseResult<Vec<f64>> {
    let n = l.len();
    let mut y = vec![0.0; n];
    for i in 0..n {
        // Subtract the already-solved terms from b[i], left to right.
        let rhs = (0..i).fold(b[i], |acc, j| acc - l[i][j] * y[j]);
        if l[i][i].abs() < 1e-15 {
            return Err(SparseError::SingularMatrix(format!(
                "Zero diagonal at position {} during forward solve",
                i
            )));
        }
        y[i] = rhs / l[i][i];
    }
    Ok(y)
}
/// Solves `U x = b` by back substitution for an upper-triangular `u`.
///
/// # Errors
/// Returns `SingularMatrix` when a diagonal entry's magnitude is below the
/// 1e-15 singularity threshold.
fn back_solve(u: &[Vec<f64>], b: &[f64]) -> SparseResult<Vec<f64>> {
    let n = u.len();
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        // Subtract the already-solved terms from b[i], left to right.
        let rhs = ((i + 1)..n).fold(b[i], |acc, j| acc - u[i][j] * x[j]);
        if u[i][i].abs() < 1e-15 {
            return Err(SparseError::SingularMatrix(format!(
                "Zero diagonal at position {} during back solve",
                i
            )));
        }
        x[i] = rhs / u[i][i];
    }
    Ok(x)
}
/// Solves `U^T x = b` by forward substitution, reading `u` column-wise so
/// the transpose is never materialized.
///
/// # Errors
/// Returns `SingularMatrix` when a diagonal entry's magnitude is below the
/// 1e-15 singularity threshold.
fn back_solve_transpose(u: &[Vec<f64>], b: &[f64]) -> SparseResult<Vec<f64>> {
    let n = u.len();
    let mut x = vec![0.0; n];
    for i in 0..n {
        // u[j][i] is entry (i, j) of U^T.
        let rhs = (0..i).fold(b[i], |acc, j| acc - u[j][i] * x[j]);
        if u[i][i].abs() < 1e-15 {
            return Err(SparseError::SingularMatrix(format!(
                "Zero diagonal at position {} during transpose back solve",
                i
            )));
        }
        x[i] = rhs / u[i][i];
    }
    Ok(x)
}
/// Inverts a dense square matrix by Gauss-Jordan elimination on the
/// augmented system `[A | I]` with partial pivoting.
///
/// Returns an empty matrix for `n == 0`.
///
/// # Errors
/// Returns `SingularMatrix` when a pivot column has no entry with magnitude
/// above the 1e-15 threshold.
fn invert_dense(a: &[Vec<f64>]) -> SparseResult<Vec<Vec<f64>>> {
    let n = a.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    // Assemble the augmented matrix [A | I].
    let mut work: Vec<Vec<f64>> = (0..n)
        .map(|i| {
            let mut row = vec![0.0; 2 * n];
            row[..n].copy_from_slice(&a[i][..n]);
            row[n + i] = 1.0;
            row
        })
        .collect();
    for col in 0..n {
        // Partial pivoting: pick the largest-magnitude entry on or below the
        // diagonal (first occurrence among ties, matching strict `>`).
        let mut pivot_row = col;
        for candidate in (col + 1)..n {
            if work[candidate][col].abs() > work[pivot_row][col].abs() {
                pivot_row = candidate;
            }
        }
        if work[pivot_row][col].abs() < 1e-15 {
            return Err(SparseError::SingularMatrix(
                "Matrix is singular, cannot invert".into(),
            ));
        }
        // swap(col, col) is a no-op, so no guard is needed.
        work.swap(col, pivot_row);
        // Normalize the pivot row so the diagonal becomes 1.
        let pivot = work[col][col];
        for entry in work[col].iter_mut() {
            *entry /= pivot;
        }
        // Eliminate this column from every other row.
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = work[row][col];
            for j in 0..(2 * n) {
                work[row][j] -= factor * work[col][j];
            }
        }
    }
    // The right half of the augmented matrix now holds A^{-1}.
    Ok(work.into_iter().map(|row| row[n..].to_vec()).collect())
}
/// Updates a known inverse via the Sherman-Morrison-Woodbury identity:
///
/// `(A + U C V)^{-1} = A^{-1} - A^{-1} U (C^{-1} + V A^{-1} U)^{-1} V A^{-1}`
///
/// * `a_inv` - the `n x n` inverse of the original matrix `A`.
/// * `u_mat` - `n x k` left factor of the low-rank update.
/// * `c_mat` - `k x k` middle factor; must be invertible.
/// * `v_mat` - `k x n` right factor of the low-rank update.
///
/// Returns the updated `n x n` inverse. When `n == 0` or `k == 0` the update
/// is a no-op and a copy of `a_inv` is returned.
///
/// # Errors
/// Returns `ComputationError`/`DimensionMismatch` on shape violations, and
/// `SingularMatrix` when `C` or the capacitance matrix cannot be inverted.
pub fn sherman_morrison_woodbury(
    a_inv: &[Vec<f64>],
    u_mat: &[Vec<f64>],
    c_mat: &[Vec<f64>],
    v_mat: &[Vec<f64>],
) -> SparseResult<Vec<Vec<f64>>> {
    let n = a_inv.len();
    // A^{-1} must be square n x n.
    for (i, row) in a_inv.iter().enumerate() {
        if row.len() != n {
            return Err(SparseError::ComputationError(format!(
                "A_inv row {} has length {} but expected {}",
                i,
                row.len(),
                n
            )));
        }
    }
    // U must have n rows; its column count defines the update rank k.
    if u_mat.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: u_mat.len(),
        });
    }
    let k = u_mat.first().map_or(0, |r| r.len());
    // C must be k x k.
    if c_mat.len() != k {
        return Err(SparseError::DimensionMismatch {
            expected: k,
            found: c_mat.len(),
        });
    }
    for (i, row) in c_mat.iter().enumerate() {
        if row.len() != k {
            return Err(SparseError::ComputationError(format!(
                "C row {} has length {} but expected {}",
                i,
                row.len(),
                k
            )));
        }
    }
    // V must be k x n.
    if v_mat.len() != k {
        return Err(SparseError::DimensionMismatch {
            expected: k,
            found: v_mat.len(),
        });
    }
    for (i, row) in v_mat.iter().enumerate() {
        if row.len() != n {
            return Err(SparseError::ComputationError(format!(
                "V row {} has length {} but expected {}",
                i,
                row.len(),
                n
            )));
        }
    }
    // Degenerate sizes: nothing to update.
    if n == 0 || k == 0 {
        return Ok(a_inv.to_vec());
    }
    let ainv_u = mat_mat_mul(a_inv, u_mat);
    let v_ainv = mat_mat_mul(v_mat, a_inv);
    let v_ainv_u = mat_mat_mul(v_mat, &ainv_u);
    let c_inv = invert_dense(c_mat)?;
    // Capacitance matrix: C^{-1} + V A^{-1} U.
    let capacitance: Vec<Vec<f64>> = (0..k)
        .map(|i| (0..k).map(|j| c_inv[i][j] + v_ainv_u[i][j]).collect())
        .collect();
    let capacitance_inv = invert_dense(&capacitance)?;
    // Correction term: A^{-1} U (C^{-1} + V A^{-1} U)^{-1} V A^{-1}.
    let half = mat_mat_mul(&ainv_u, &capacitance_inv);
    let correction = mat_mat_mul(&half, &v_ainv);
    // Subtract the correction from the original inverse.
    let updated = (0..n)
        .map(|i| (0..n).map(|j| a_inv[i][j] - correction[i][j]).collect())
        .collect();
    Ok(updated)
}
/// Recomputes the LU factorization after the rank-1 update `A' = A + u v^T`,
/// using the default [`LowRankUpdateConfig`].
///
/// `l`, `u`, `p` are the existing factors of `A`; `u_vec` and `v_vec` are the
/// rank-1 update vectors. See `lu_rank1_update_with_config` for the full
/// contract and error conditions.
pub fn lu_rank1_update(
    l: &[Vec<f64>],
    u: &[Vec<f64>],
    p: &[usize],
    u_vec: &[f64],
    v_vec: &[f64],
) -> SparseResult<LUUpdateResult> {
    lu_rank1_update_with_config(l, u, p, u_vec, v_vec, &LowRankUpdateConfig::default())
}
/// Recomputes an LU factorization after the rank-1 update `A' = A + u v^T`.
///
/// `l`, `u`, `p` describe the existing factorization of `A`, where `p` maps
/// factored row positions back to rows of `A` (row `j` of `A` equals row
/// `p_inv[j]` of `L*U`). The updated matrix is formed densely and refactored
/// from scratch with partial pivoting, using `config.tolerance` as the
/// near-zero pivot threshold.
///
/// On success the result holds the new `L`, `U`, and permutation plus a
/// condition estimate. `success` is `false` when a pivot fell below
/// `config.tolerance`: elimination for that column is skipped and the
/// returned factors do not form a complete LU decomposition.
///
/// # Errors
/// Returns `DimensionMismatch` when `u`, `p`, `u_vec`, or `v_vec` disagree
/// in length with `l`.
pub fn lu_rank1_update_with_config(
    l: &[Vec<f64>],
    u: &[Vec<f64>],
    p: &[usize],
    u_vec: &[f64],
    v_vec: &[f64],
    config: &LowRankUpdateConfig,
) -> SparseResult<LUUpdateResult> {
    let n = l.len();
    if u.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: u.len(),
        });
    }
    if p.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: p.len(),
        });
    }
    if u_vec.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: u_vec.len(),
        });
    }
    if v_vec.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: v_vec.len(),
        });
    }
    // Empty system: trivially factored.
    if n == 0 {
        return Ok(LUUpdateResult {
            l: Vec::new(),
            u: Vec::new(),
            p: Vec::new(),
            success: true,
            condition_estimate: 1.0,
        });
    }
    // Reassemble A from the existing factors, then apply the rank-1 term:
    // A'[j][col] = (P^{-1} L U)[j][col] + u_vec[j] * v_vec[col].
    let lu = mat_mat_mul(l, u);
    let mut p_inv = vec![0usize; n];
    for (i, &pi) in p.iter().enumerate() {
        // Out-of-range permutation entries are silently skipped (slot stays
        // 0); callers are expected to pass a valid permutation of 0..n.
        if pi < n {
            p_inv[pi] = i;
        }
    }
    let mut a_prime = vec![vec![0.0; n]; n];
    for j in 0..n {
        for col in 0..n {
            a_prime[j][col] = lu[p_inv[j]][col] + u_vec[j] * v_vec[col];
        }
    }
    // Refactor A' with partial pivoting; a_prime is eliminated in place into
    // the new U.
    let mut u_new = a_prime;
    let mut l_new = vec![vec![0.0; n]; n];
    let mut p_new: Vec<usize> = (0..n).collect();
    for i in 0..n {
        l_new[i][i] = 1.0;
    }
    let mut success = true;
    for col in 0..n {
        // Find the largest-magnitude pivot candidate on or below the diagonal.
        let mut max_val = u_new[col][col].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            if u_new[row][col].abs() > max_val {
                max_val = u_new[row][col].abs();
                max_row = row;
            }
        }
        if max_row != col {
            u_new.swap(col, max_row);
            p_new.swap(col, max_row);
            // Swap only the already-computed multiplier columns of L; the
            // unit diagonal and untouched columns must stay in place.
            for j in 0..col {
                let tmp = l_new[col][j];
                l_new[col][j] = l_new[max_row][j];
                l_new[max_row][j] = tmp;
            }
        }
        if u_new[col][col].abs() < config.tolerance {
            // Near-singular pivot: elimination for this column is skipped, so
            // the factors are incomplete. Flag the result instead of silently
            // reporting success (previously `success` was always `true`).
            success = false;
            continue;
        }
        // Eliminate entries below the pivot, recording multipliers in L.
        for row in (col + 1)..n {
            let factor = u_new[row][col] / u_new[col][col];
            l_new[row][col] = factor;
            u_new[row][col] = 0.0;
            for j in (col + 1)..n {
                u_new[row][j] -= factor * u_new[col][j];
            }
        }
    }
    let cond = estimate_condition(&u_new, config.tolerance);
    Ok(LUUpdateResult {
        l: l_new,
        u: u_new,
        p: p_new,
        success,
        condition_estimate: cond,
    })
}
/// Rebuilds the LU factorization after replacing column `col_idx` of the
/// factored matrix with `new_col`.
///
/// The replacement is phrased as the rank-1 update
/// `A' = A + (new_col - old_col) e_k^T` and delegated to [`lu_rank1_update`].
///
/// # Errors
/// Returns `IndexOutOfBounds` when `col_idx >= n`, `DimensionMismatch` when
/// `new_col` does not have `n` entries, and whatever the rank-1 update
/// itself reports.
pub fn lu_column_replace(
    l: &[Vec<f64>],
    u: &[Vec<f64>],
    p: &[usize],
    col_idx: usize,
    new_col: &[f64],
) -> SparseResult<LUUpdateResult> {
    let n = l.len();
    if col_idx >= n {
        return Err(SparseError::IndexOutOfBounds {
            index: (0, col_idx),
            shape: (n, n),
        });
    }
    if new_col.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: new_col.len(),
        });
    }
    // Column `col_idx` of L*U, still in the factorization's row order.
    let mut lu_col = vec![0.0; n];
    for (i, slot) in lu_col.iter_mut().enumerate() {
        let mut acc = 0.0;
        for t in 0..n {
            acc += l[i][t] * u[t][col_idx];
        }
        *slot = acc;
    }
    // Invert the permutation so the column can be read in A's row order.
    let mut p_inv = vec![0usize; n];
    for (i, &pi) in p.iter().enumerate() {
        if pi < n {
            p_inv[pi] = i;
        }
    }
    let old_col: Vec<f64> = (0..n).map(|j| lu_col[p_inv[j]]).collect();
    // Difference between the incoming column and the current one.
    let delta: Vec<f64> = new_col
        .iter()
        .zip(old_col.iter())
        .map(|(nv, ov)| nv - ov)
        .collect();
    // Standard basis vector selecting the replaced column.
    let mut basis = vec![0.0; n];
    basis[col_idx] = 1.0;
    lu_rank1_update(l, u, p, &delta, &basis)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unpivoted Gaussian-elimination LU used to seed the update tests:
    /// returns (L, U, identity permutation), erroring on a zero pivot since
    /// no row exchanges are performed.
    #[allow(clippy::type_complexity)]
    fn simple_lu(a: &[Vec<f64>]) -> SparseResult<(Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<usize>)> {
        let n = a.len();
        let mut l = vec![vec![0.0; n]; n];
        let mut u_mat = vec![vec![0.0; n]; n];
        let p: Vec<usize> = (0..n).collect();
        for i in 0..n {
            for j in 0..n {
                u_mat[i][j] = a[i][j];
            }
        }
        for i in 0..n {
            l[i][i] = 1.0;
        }
        for col in 0..n {
            if u_mat[col][col].abs() < 1e-15 {
                return Err(SparseError::SingularMatrix("Zero pivot".into()));
            }
            for row in (col + 1)..n {
                let factor = u_mat[row][col] / u_mat[col][col];
                l[row][col] = factor;
                for j in col..n {
                    u_mat[row][j] -= factor * u_mat[col][j];
                }
            }
        }
        Ok((l, u_mat, p))
    }

    /// Rebuilds the dense matrix from factors: row j of the result is row
    /// p_inv[j] of L*U, where p_inv inverts the permutation p.
    fn reconstruct(l: &[Vec<f64>], u: &[Vec<f64>], p: &[usize]) -> Vec<Vec<f64>> {
        let n = l.len();
        let lu = mat_mat_mul(l, u);
        let mut p_inv = vec![0usize; n];
        for (i, &pi) in p.iter().enumerate() {
            p_inv[pi] = i;
        }
        let mut a = vec![vec![0.0; n]; n];
        for j in 0..n {
            for col in 0..n {
                a[j][col] = lu[p_inv[j]][col];
            }
        }
        a
    }

    // Rank-1 case: (I + e1 * 1 * e2^T)^{-1} = [[1, -1], [0, 1]].
    #[test]
    fn test_sherman_morrison_woodbury_identity() {
        let a_inv = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let u_mat = vec![vec![1.0], vec![0.0]];
        let c_mat = vec![vec![1.0]];
        let v_mat = vec![vec![0.0, 1.0]];
        let result = sherman_morrison_woodbury(&a_inv, &u_mat, &c_mat, &v_mat)
            .expect("woodbury should succeed");
        assert!((result[0][0] - 1.0).abs() < 1e-10);
        assert!((result[0][1] - (-1.0)).abs() < 1e-10);
        assert!((result[1][0]).abs() < 1e-10);
        assert!((result[1][1] - 1.0).abs() < 1e-10);
    }

    // Rank-2 update of diag(2, 3, 4): adding 1 to the first two diagonal
    // entries gives diag(3, 4, 4), whose inverse is diag(1/3, 1/4, 1/4).
    #[test]
    fn test_sherman_morrison_woodbury_3x3() {
        let a_inv = vec![
            vec![0.5, 0.0, 0.0],
            vec![0.0, 1.0 / 3.0, 0.0],
            vec![0.0, 0.0, 0.25],
        ];
        let u_mat = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]];
        let c_mat = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let v_mat = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let result = sherman_morrison_woodbury(&a_inv, &u_mat, &c_mat, &v_mat)
            .expect("woodbury 3x3 should succeed");
        assert!((result[0][0] - 1.0 / 3.0).abs() < 1e-10);
        assert!((result[1][1] - 0.25).abs() < 1e-10);
        assert!((result[2][2] - 0.25).abs() < 1e-10);
        assert!(result[0][1].abs() < 1e-10);
    }

    // The updated factors must reconstruct I + e1 e2^T = [[1, 1], [0, 1]].
    #[test]
    fn test_lu_rank1_update_identity() {
        let l = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let u_mat = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let p = vec![0, 1];
        let u_vec = vec![1.0, 0.0];
        let v_vec = vec![0.0, 1.0];
        let result =
            lu_rank1_update(&l, &u_mat, &p, &u_vec, &v_vec).expect("lu update on identity");
        let a_prime = reconstruct(&result.l, &result.u, &result.p);
        assert!((a_prime[0][0] - 1.0).abs() < 1e-10);
        assert!((a_prime[0][1] - 1.0).abs() < 1e-10);
        assert!((a_prime[1][0]).abs() < 1e-10);
        assert!((a_prime[1][1] - 1.0).abs() < 1e-10);
        assert!(result.success);
    }

    // P^{-1} L U from the updated factors must equal A + u v^T entrywise.
    #[test]
    fn test_lu_rank1_update_factors_correctly() {
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let (l, u_mat, p) = simple_lu(&a).expect("initial LU");
        let u_vec = vec![1.0, 0.5];
        let v_vec = vec![0.5, 1.0];
        let result = lu_rank1_update(&l, &u_mat, &p, &u_vec, &v_vec).expect("rank-1 update");
        let a_prime = reconstruct(&result.l, &result.u, &result.p);
        for i in 0..2 {
            for j in 0..2 {
                let expected = a[i][j] + u_vec[i] * v_vec[j];
                assert!(
                    (a_prime[i][j] - expected).abs() < 1e-9,
                    "Mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    a_prime[i][j],
                    expected
                );
            }
        }
    }

    // Replacing column 0 must install the new column and leave column 1
    // untouched.
    #[test]
    fn test_lu_column_replace() {
        let a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
        let (l, u_mat, p) = simple_lu(&a).expect("initial LU");
        let new_col = vec![5.0, 2.0];
        let result = lu_column_replace(&l, &u_mat, &p, 0, &new_col).expect("column replace");
        let a_prime = reconstruct(&result.l, &result.u, &result.p);
        assert!((a_prime[0][0] - 5.0).abs() < 1e-9);
        assert!((a_prime[1][0] - 2.0).abs() < 1e-9);
        assert!((a_prime[0][1] - 1.0).abs() < 1e-9);
        assert!((a_prime[1][1] - 3.0).abs() < 1e-9);
    }

    // A wrong-length update vector must be rejected with an error.
    #[test]
    fn test_lu_dimension_mismatch() {
        let l = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let u_mat = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let p = vec![0, 1];
        let u_vec = vec![1.0];
        let v_vec = vec![0.0, 1.0];
        let result = lu_rank1_update(&l, &u_mat, &p, &u_vec, &v_vec);
        assert!(result.is_err());
    }

    // U with the wrong number of rows (1 instead of n = 2) must be rejected.
    #[test]
    fn test_smw_dimension_mismatch() {
        let a_inv = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let u_mat = vec![vec![1.0]];
        let c_mat = vec![vec![1.0]];
        let v_mat = vec![vec![1.0, 0.0]];
        let result = sherman_morrison_woodbury(&a_inv, &u_mat, &c_mat, &v_mat);
        assert!(result.is_err());
    }
}