// File-wide allow: the LU routines below deliberately use explicit index loops
// (`for i in 0..N`) because each step reads and writes several rows of the same
// matrix (e.g. `a[j][k] -= a[j][i] * a[i][k]`) by index; rewriting these as
// iterator chains would require split borrows and obscure the textbook algorithm.
#![allow(clippy::needless_range_loop)]
use std::error::Error;
/// Dimension of the square matrices handled in this file (all are 3x3).
const N: usize = 3;
/// Computes the LUP decomposition of a 3x3 matrix using Gaussian elimination
/// with partial (row) pivoting, so that `P * A = L * U`.
///
/// On success the returned matrix packs both factors:
/// * entries strictly below the diagonal are the multipliers of the
///   unit-lower-triangular `L` (its unit diagonal is implicit);
/// * entries on and above the diagonal form the upper-triangular `U`.
///
/// The returned array `p` encodes the permutation `P`: row `i` of the packed
/// result comes from row `p[i]` of the input. `p[3]` starts at 3 and is
/// incremented once per row swap, so `p[3] - 3` is the number of swaps (its
/// parity gives the sign of the permutation, e.g. for a determinant).
///
/// # Arguments
/// * `a` - the matrix to factorize; it is copied, the input is not modified.
/// * `tol` - pivot threshold: if the largest magnitude left in a pivot column
///   is below `tol`, the matrix is treated as degenerate.
///
/// # Errors
/// Returns [`DegenerateMatrixError`] when the largest candidate pivot of some
/// column is below `tol`, or is exactly zero. The zero check matters when
/// `tol <= 0.0`, which would otherwise let a division by a zero pivot through.
pub(crate) fn decompose(
    a: &[[f64; 3]; 3],
    tol: f64,
) -> Result<([[f64; 3]; 3], [usize; 4]), DegenerateMatrixError> {
    let mut a = *a;
    // p[..N] is the row permutation (identity initially); p[N] counts swaps, starting at N.
    let mut p = [0, 1, 2, 3];
    for i in 0..N {
        // Partial pivoting: choose the row at or below `i` with the largest |a[k][i]|.
        let mut max_a = 0.0;
        let mut i_max = i;
        for k in i..N {
            let abs_a = a[k][i].abs();
            if abs_a > max_a {
                max_a = abs_a;
                i_max = k;
            }
        }
        // `max_a == 0.0` catches an all-zero (or all-NaN) column even when
        // `tol <= 0.0`, where `max_a < tol` alone would not fire and the
        // elimination below would divide by zero.
        if max_a < tol || max_a == 0.0 {
            return Err(DegenerateMatrixError {});
        }
        if i_max != i {
            p.swap(i, i_max);
            a.swap(i, i_max);
            p[N] += 1;
        }
        // Eliminate column `i` below the pivot, storing each multiplier in place (L part).
        for j in (i + 1)..N {
            a[j][i] /= a[i][i];
            for k in (i + 1)..N {
                a[j][k] -= a[j][i] * a[i][k];
            }
        }
    }
    Ok((a, p))
}
/// Error returned by `decompose` when the matrix is (numerically) singular,
/// i.e. some pivot column has no entry large enough to pivot on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct DegenerateMatrixError {}

impl std::fmt::Display for DegenerateMatrixError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Decomposition failed, matrix is degenerate.")
    }
}

impl Error for DegenerateMatrixError {}
/// Solves `A x = b` from the packed LUP factors produced by `decompose`.
///
/// * `a` - packed factors: the strictly-lower part holds `L` (unit diagonal
///   implied), the diagonal and upper part hold `U`.
/// * `p` - row permutation; row `i` uses right-hand-side entry `b[p[i]]`.
///   Only `p[..3]` is read; the trailing swap counter is ignored.
/// * `b` - right-hand side.
///
/// Returns the solution vector. No check is made for a zero diagonal entry in
/// `U`; such an entry produces infinities/NaNs in the result.
pub(crate) fn solve(a: [[f64; 3]; 3], p: [usize; 4], b: [f64; 3]) -> [f64; 3] {
    let mut x = [0.0_f64; 3];

    // Forward substitution: L y = P b. `L` has a unit diagonal, so no division.
    for row in 0..N {
        let y = (0..row).fold(b[p[row]], |acc, col| acc - a[row][col] * x[col]);
        x[row] = y;
    }

    // Back substitution: U x = y, walking rows from the bottom up.
    for row in (0..N).rev() {
        let numer = ((row + 1)..N).fold(x[row], |acc, col| acc - a[row][col] * x[col]);
        x[row] = numer / a[row][row];
    }

    x
}