use pounce_common::types::Number;
/// Solver for a square linear system whose operator has already been
/// prepared (e.g. factored), mapping a right-hand side `rhs` to a solution
/// written into `lhs`.
pub trait SensBacksolver {
/// Order of the square system; the slices passed to `solve` are expected
/// to have this length.
fn dim(&self) -> usize;
/// Solves the system for `rhs`, writing the solution into `lhs`.
///
/// Returns `true` on success. Returns `false` when the solve cannot be
/// performed; for example, [`DenseLuBacksolver`] returns `false` on a slice
/// length mismatch or a zero pivot. On `false` the contents of `lhs` should
/// not be relied upon.
fn solve(&self, rhs: &[Number], lhs: &mut [Number]) -> bool;
}
/// Dense LU factorization `P A = L U` of a square matrix, computed with
/// partial (row) pivoting by [`DenseLuBacksolver::from_dense`] and reused by
/// every call to `solve`.
#[derive(Debug, Clone)]
pub struct DenseLuBacksolver {
/// Order of the square matrix.
n: usize,
/// Row-major `n * n` buffer. `U` is stored on and above the diagonal; the
/// elimination multipliers of the unit-lower factor `L` are stored strictly
/// below it. The unit diagonal of `L` is implicit and not stored.
lu: Vec<Number>,
/// Row permutation: `piv[i]` is the index of the original row that ended up
/// in row `i` after pivoting.
piv: Vec<usize>,
}
impl DenseLuBacksolver {
    /// Factors the dense `n x n` matrix `A` as `P A = L U` by Gaussian
    /// elimination with partial (row) pivoting.
    ///
    /// `a_row_major` stores `A` row by row: entry `(i, j)` lives at index
    /// `i * n + j`. The result keeps `U` on and above the diagonal and the
    /// multipliers of the unit-lower factor `L` strictly below it, in a
    /// single row-major buffer. `piv[i]` records which original row was moved
    /// into row `i`.
    ///
    /// `n == 0` with an empty slice yields a trivial, empty factorization.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when:
    /// - `a_row_major.len() != n * n`, including when `n * n` overflows `usize`;
    /// - any entry of `a_row_major` is NaN or infinite;
    /// - the best available pivot in some column is exactly zero, meaning the
    ///   matrix is singular in floating point (no tolerance is applied);
    /// - the best available pivot is non-finite because elimination overflowed.
    pub fn from_dense(n: usize, a_row_major: &[Number]) -> Result<Self, ()> {
        // `checked_mul` keeps a huge `n` from overflowing `n * n`. An overflow
        // would panic in debug builds and silently wrap in release builds.
        if n.checked_mul(n) != Some(a_row_major.len()) {
            return Err(());
        }
        // NaN defeats pivot selection because `mag > NaN` is always false. A
        // NaN diagonal would then be kept as the pivot, and the call would
        // "succeed" with poisoned factors. Reject non-finite input up front.
        if a_row_major.iter().any(|v| !v.is_finite()) {
            return Err(());
        }

        let mut lu = a_row_major.to_vec();
        let mut piv: Vec<usize> = (0..n).collect();
        for k in 0..n {
            // Partial pivoting: choose the row at or below `k` with the largest
            // |entry| in column `k`. The first such row wins ties.
            let mut best = k;
            let mut best_mag = lu[k * n + k].abs();
            for r in (k + 1)..n {
                let mag = lu[r * n + k].abs();
                if mag > best_mag {
                    best = r;
                    best_mag = mag;
                }
            }
            // Zero means singular. Non-finite means elimination overflowed.
            if best_mag == 0.0 || !best_mag.is_finite() {
                return Err(());
            }
            if best != k {
                piv.swap(k, best);
                // `best > k`, so row `k` lies entirely in `head`, and row
                // `best` is the first `n` entries of `tail`.
                let (head, tail) = lu.split_at_mut(best * n);
                head[k * n..(k + 1) * n].swap_with_slice(&mut tail[..n]);
            }
            // Multiply by the reciprocal once instead of dividing for every row.
            let inv_p = 1.0 / lu[k * n + k];
            for r in (k + 1)..n {
                let m = lu[r * n + k] * inv_p;
                // The eliminated slot now stores the `L` multiplier.
                lu[r * n + k] = m;
                for j in (k + 1)..n {
                    let upd = lu[k * n + j];
                    lu[r * n + j] -= m * upd;
                }
            }
        }
        Ok(Self { n, lu, piv })
    }
}
impl SensBacksolver for DenseLuBacksolver {
    /// Order `n` of the factored matrix.
    fn dim(&self) -> usize {
        self.n
    }

    /// Solves `A x = rhs` using the stored factors `P A = L U`, writing `x`
    /// into `lhs`.
    ///
    /// The solve makes three passes over `lhs`:
    /// 1. gather `rhs` through the row permutation;
    /// 2. forward substitution with the unit-lower `L`;
    /// 3. back substitution with `U`.
    ///
    /// Returns `false` if either slice's length differs from `n`, or if a zero
    /// diagonal entry of `U` is met. In the latter case `lhs` may already be
    /// partially overwritten. Returns `true` otherwise.
    fn solve(&self, rhs: &[Number], lhs: &mut [Number]) -> bool {
        let n = self.n;
        let sizes_match = rhs.len() == n && lhs.len() == n;
        if !sizes_match {
            return false;
        }

        // lhs <- P * rhs: row `i` of the permuted system is original row piv[i].
        for (i, slot) in lhs.iter_mut().enumerate() {
            *slot = rhs[self.piv[i]];
        }

        // Forward substitution, L y = P rhs. L's diagonal is an implicit 1,
        // so no division is needed.
        for row in 0..n {
            let (solved, pending) = lhs.split_at_mut(row);
            let l_row = &self.lu[row * n..row * n + row];
            pending[0] = l_row
                .iter()
                .zip(solved.iter())
                .fold(pending[0], |acc, (l, y)| acc - l * y);
        }

        // Back substitution, U x = y, from the last row upwards.
        for row in (0..n).rev() {
            let (through_row, solved) = lhs.split_at_mut(row + 1);
            let u_row = &self.lu[row * n..(row + 1) * n];
            let residual = u_row[row + 1..]
                .iter()
                .zip(solved.iter())
                .fold(through_row[row], |acc, (u, x)| acc - u * x);
            let diag = u_row[row];
            // `from_dense` rejects zero pivots, so this only triggers for
            // factors that were assembled some other way.
            if diag == 0.0 {
                return false;
            }
            through_row[row] = residual / diag;
        }
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `x` matches `expected` component-wise to within `1e-12`.
    fn assert_close(x: &[Number], expected: &[Number]) {
        assert_eq!(x.len(), expected.len());
        for (i, (got, want)) in x.iter().zip(expected).enumerate() {
            assert!((got - want).abs() < 1e-12, "x[{}] = {}, expected {}", i, got, want);
        }
    }

    #[test]
    fn dense_lu_solves_3x3_symmetric() {
        #[rustfmt::skip]
        let a = [
            2.0, -1.0, 0.0,
            -1.0, 2.0, -1.0,
            0.0, -1.0, 2.0,
        ];
        let solver = DenseLuBacksolver::from_dense(3, &a).expect("factor");
        let mut x = [0.0; 3];
        assert!(solver.solve(&[1.0, 0.0, 0.0], &mut x));
        assert_close(&x, &[0.75, 0.50, 0.25]);
    }

    #[test]
    fn dense_lu_handles_zero_first_pivot() {
        let solver = DenseLuBacksolver::from_dense(2, &[0.0, 1.0, 1.0, 0.0]).expect("factor");
        let mut x = [0.0; 2];
        assert!(solver.solve(&[2.0, 1.0], &mut x));
        assert_close(&x, &[1.0, 2.0]);
    }

    #[test]
    fn dense_lu_solution_satisfies_system_4x4() {
        // Needs row pivoting; det(A) = -34, so A is well away from singular.
        #[rustfmt::skip]
        let a = [
            1.0, 2.0, 0.0, 1.0,
            3.0, 1.0, 4.0, 0.0,
            0.0, 5.0, 1.0, 2.0,
            2.0, 0.0, 3.0, 1.0,
        ];
        let b = [1.0, 2.0, 3.0, 4.0];
        let solver = DenseLuBacksolver::from_dense(4, &a).expect("factor");
        let mut x = [0.0; 4];
        assert!(solver.solve(&b, &mut x));
        for (i, row) in a.chunks_exact(4).enumerate() {
            let ax: Number = row.iter().zip(&x).map(|(aij, xj)| aij * xj).sum();
            assert!((ax - b[i]).abs() < 1e-10, "(A x)[{}] = {}, b[{}] = {}", i, ax, i, b[i]);
        }
    }

    #[test]
    fn dense_lu_rejects_singular_matrix() {
        assert!(DenseLuBacksolver::from_dense(2, &[1.0, 2.0, 2.0, 4.0]).is_err());
    }

    #[test]
    fn from_dense_rejects_length_mismatch() {
        assert!(DenseLuBacksolver::from_dense(2, &[1.0, 0.0, 0.0]).is_err());
        assert!(DenseLuBacksolver::from_dense(1, &[1.0, 0.0]).is_err());
    }

    #[test]
    fn empty_system_is_trivially_solved() {
        let solver = DenseLuBacksolver::from_dense(0, &[]).expect("empty");
        assert_eq!(solver.dim(), 0);
        assert!(solver.solve(&[], &mut []));
    }

    #[test]
    fn dim_reports_order() {
        let s = DenseLuBacksolver::from_dense(2, &[1.0, 0.0, 0.0, 1.0]).expect("ok");
        assert_eq!(s.dim(), 2);
    }

    #[test]
    fn solve_rejects_wrong_dim() {
        let s = DenseLuBacksolver::from_dense(2, &[1.0, 0.0, 0.0, 1.0]).expect("ok");
        let mut x = [0.0; 2];
        assert!(!s.solve(&[1.0], &mut x));
    }

    #[test]
    fn solve_rejects_wrong_lhs_len() {
        let s = DenseLuBacksolver::from_dense(2, &[1.0, 0.0, 0.0, 1.0]).expect("ok");
        let mut x = [0.0; 3];
        assert!(!s.solve(&[1.0, 2.0], &mut x));
    }
}