use crate::error::SolverError;
/// Packed LU factorization (with partial row pivoting) of a square matrix,
/// as produced by [`lu_factorize`].
///
/// The factors satisfy `P·A = L·U`, where `A` is the factored matrix, `L` is
/// unit lower-triangular, `U` is upper-triangular and `P` is the row
/// permutation described by [`pivot`](Self::pivot).
#[derive(Debug, Clone, PartialEq)]
pub struct LuFactorization {
    /// Dimension of the square factored matrix.
    pub n: usize,
    /// `n * n` row-major buffer: entries strictly below the diagonal are the
    /// multipliers of `L` (its unit diagonal is implicit); entries on and
    /// above the diagonal are `U`.
    pub lu: Vec<f32>,
    /// Row permutation: `pivot[i]` is the index of the original row that
    /// ended up in row `i` of the factorization.
    pub pivot: Vec<usize>,
}
/// Factorizes the `n`×`n` row-major matrix `a` as `P·A = L·U` using Gaussian
/// elimination with partial (row) pivoting.
///
/// The returned [`LuFactorization`] packs both factors into one `n * n`
/// row-major buffer. Entries strictly below the diagonal are the multipliers
/// of the unit lower-triangular `L`, whose unit diagonal is implicit. The
/// diagonal and the entries above it hold the upper-triangular `U`.
/// `pivot[i]` is the index of the row of `a` that was moved to row `i`.
///
/// A column has no usable pivot when its largest remaining magnitude is at
/// most `1e3 * f32::EPSILON * s`. Here `s` is the largest finite magnitude
/// in `a`. Scaling by `s` makes the test invariant under a uniform rescaling
/// of `a`, such as a change of units, so `c * A` is accepted or rejected
/// exactly when `A` is.
///
/// # Errors
///
/// * [`SolverError::NotSquare`] if `a.len()` is not `n * n`. This includes
///   the case where `n * n` would overflow `usize`.
/// * [`SolverError::SingularMatrix`] carrying the column index `k` if the
///   chosen pivot for column `k` is at or below the tolerance above, or is
///   NaN.
pub fn lu_factorize(a: &[f32], n: usize) -> Result<LuFactorization, SolverError> {
    // `checked_mul` so a huge `n` cannot wrap around and slip past the check.
    if n.checked_mul(n) != Some(a.len()) {
        return Err(SolverError::NotSquare {
            rows: n,
            cols: a.len() / n.max(1),
        });
    }

    // Non-finite entries are skipped so that a single infinite entry does not
    // push the tolerance to infinity and reject every pivot.
    let scale = a
        .iter()
        .filter(|v| v.is_finite())
        .fold(0.0f32, |acc, &v| acc.max(v.abs()));
    let tol = f32::EPSILON * 1e3 * scale;

    let mut lu = a.to_vec();
    let mut pivot: Vec<usize> = (0..n).collect();

    for k in 0..n {
        // Partial pivoting: pick the first row at or below `k` whose entry in
        // column `k` has the largest magnitude.
        let mut max_row = k;
        let mut max_val = lu[k * n + k].abs();
        for i in (k + 1)..n {
            let val = lu[i * n + k].abs();
            if val > max_val {
                max_val = val;
                max_row = i;
            }
        }

        // `<=` rather than `<` so an all-zero column is rejected even when
        // `tol` is 0 (all-zero input). NaN compares false with everything, so
        // it needs its own check.
        if max_val.is_nan() || max_val <= tol {
            return Err(SolverError::SingularMatrix(k));
        }

        if max_row != k {
            pivot.swap(k, max_row);
            // Swap the whole rows, including multipliers already stored in
            // columns < k, so the packed `L` stays consistent with `pivot`.
            // `max_row > k`, so row `k` lies entirely in `upper`.
            let (upper, lower) = lu.split_at_mut(max_row * n);
            upper[k * n..(k + 1) * n].swap_with_slice(&mut lower[..n]);
        }

        // Eliminate column `k` below the diagonal, storing each multiplier in
        // the slot it zeroes out.
        let (upper, lower) = lu.split_at_mut((k + 1) * n);
        let pivot_row = &upper[k * n..];
        let pivot_val = pivot_row[k];
        for row in lower.chunks_exact_mut(n) {
            let factor = row[k] / pivot_val;
            row[k] = factor;
            for (dst, &src) in row[k + 1..].iter_mut().zip(&pivot_row[k + 1..]) {
                *dst -= factor * src;
            }
        }
    }

    Ok(LuFactorization { n, lu, pivot })
}
impl LuFactorization {
    /// Solves `A·x = b` for `x` using this factorization of `A`.
    ///
    /// First `b` is permuted by `pivot` (`x[i] = b[pivot[i]]`). Next,
    /// `L·y = P·b` is solved by forward substitution using the implicit unit
    /// diagonal of `L`. Finally, `U·x = y` is solved by back substitution.
    /// All three steps work in a single buffer.
    ///
    /// # Errors
    ///
    /// Returns [`SolverError::DimensionMismatch`] if `b.len() != self.n`.
    pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
        let n = self.n;
        if b.len() != n {
            return Err(SolverError::DimensionMismatch {
                matrix_n: n,
                rhs_len: b.len(),
            });
        }

        // Row permutation: x = P·b.
        let mut x: Vec<f32> = (0..n).map(|row| b[self.pivot[row]]).collect();

        // Forward substitution. Terms are subtracted left to right, one at a
        // time, to keep the same rounding as a plain accumulation loop.
        for row in 1..n {
            let start = row * n;
            let reduced = self.lu[start..start + row]
                .iter()
                .zip(&x[..row])
                .fold(x[row], |acc, (&l, &y)| acc - l * y);
            x[row] = reduced;
        }

        // Back substitution against U, from the last row upwards.
        for row in (0..n).rev() {
            let coeffs = &self.lu[row * n..(row + 1) * n];
            let reduced = coeffs[row + 1..]
                .iter()
                .zip(&x[row + 1..])
                .fold(x[row], |acc, (&u, &known)| acc - u * known);
            x[row] = reduced / coeffs[row];
        }

        Ok(x)
    }

    /// Returns `L` as a dense `n * n` row-major matrix. The multipliers from
    /// `lu` fill the entries below the diagonal, the diagonal is all ones, and
    /// the entries above it are zeros.
    pub fn extract_l(&self) -> Vec<f32> {
        let n = self.n;
        let mut l = vec![0.0f32; n * n];
        for row in 0..n {
            let start = row * n;
            l[start..start + row].copy_from_slice(&self.lu[start..start + row]);
            l[start + row] = 1.0;
        }
        l
    }

    /// Returns `U` as a dense `n * n` row-major matrix: the diagonal and
    /// upper part of `lu`, with zeros below the diagonal.
    pub fn extract_u(&self) -> Vec<f32> {
        let n = self.n;
        let mut u = vec![0.0f32; n * n];
        for row in 0..n {
            let first = row * n + row;
            let end = (row + 1) * n;
            u[first..end].copy_from_slice(&self.lu[first..end]);
        }
        u
    }

    /// Returns the permutation matrix `P` as a dense `n * n` row-major
    /// matrix. Row `i` has a single `1.0` in column `pivot[i]`, so row `i` of
    /// `P·A` is row `pivot[i]` of `A`.
    pub fn extract_p(&self) -> Vec<f32> {
        let n = self.n;
        let mut p = vec![0.0f32; n * n];
        self.pivot
            .iter()
            .enumerate()
            .for_each(|(row, &col)| p[row * n + col] = 1.0);
        p
    }
}