use super::scaling::{compute_lu_scale, LuScale};
use super::sparse_matrix::SparseColMatrix;
use super::{LuParams, LuScaling, LuSingularAction};
use crate::error::FeralError;
/// Dense LU factorization of a square `m x m` basis with row (partial/threshold) pivoting.
///
/// `factor`/`refactor` compute `P B = L U`, where `B` is the basis after the scaling
/// recorded in `scale`, and row `k` of `P B` is row `perm[k]` of `B`. Both factors are
/// stored densely in column-major order: entry `(i, j)` lives at index `i + j * m`.
#[derive(Debug, Clone)]
pub struct DenseLu {
    /// Dimension of the square basis.
    pub(super) m: usize,
    /// Unit lower-triangular factor `L` (`m * m`, column-major, ones on the diagonal).
    pub(super) l: Vec<f64>,
    /// Upper-triangular factor `U` (`m * m`, column-major, includes the diagonal).
    pub(super) u: Vec<f64>,
    /// Row permutation: row `k` of the factored matrix is row `perm[k]` of the basis.
    pub(super) perm: Vec<usize>,
    /// Inverse of `perm`: `perm_inv[perm[k]] == k`.
    pub(super) perm_inv: Vec<usize>,
    /// Column permutation; reset to the identity by `factor`/`refactor`.
    /// NOTE(review): nothing in this file permutes it — presumably basis updates
    /// elsewhere do; confirm.
    pub(super) qcol: Vec<usize>,
    /// Inverse of `qcol`; reset to the identity by `factor`/`refactor`.
    pub(super) qcol_inv: Vec<usize>,
    /// Counter reset to zero by `factor`/`refactor` (presumably incremented by
    /// update routines outside this file — TODO confirm).
    pub(super) updates_since_refactor: usize,
    /// Reset to `1.0` by `factor`/`refactor`; presumably a growth estimate maintained
    /// across updates — TODO confirm.
    pub(super) growth: f64,
    /// Largest `|u_ij|` immediately after factorization, floored at `f64::MIN_POSITIVE`.
    pub(super) u_max0: f64,
    /// Parameters the factorization was computed with.
    pub(super) params: LuParams,
    /// Scaling applied to the basis before factorization
    /// (`LuScale::identity(m)` when `params.scaling` is `LuScaling::None`).
    pub(super) scale: LuScale,
    /// Length-`m` zeroed workspace; not used in this file (presumably for solves/updates).
    pub(super) scratch_a: Vec<f64>,
    /// Length-`m` zeroed workspace; not used in this file (presumably for solves/updates).
    pub(super) scratch_b: Vec<f64>,
    /// Length-`m` zeroed workspace; not used in this file (presumably for solves/updates).
    pub(super) scratch_c: Vec<f64>,
    /// Length-`m` zeroed workspace; not used in this file (presumably for solves/updates).
    pub(super) scratch_d: Vec<f64>,
}
impl DenseLu {
pub fn factor(cols: &[Vec<f64>], m: usize, params: LuParams) -> Result<Self, FeralError> {
params.validate()?;
let (scale, scaled) = compute_scale(cols, m, params.scaling)?;
let factor_cols: &[Vec<f64>] = scaled.as_deref().unwrap_or(cols);
let mut packed = vec![0.0; m * m];
copy_columns_into(&mut packed, factor_cols, m)?;
let mut perm: Vec<usize> = (0..m).collect();
factorize_packed(&mut packed, &mut perm, m, ¶ms)?;
let (l, u) = split_packed(&packed, m);
let u_max0 = umax(&u);
let mut perm_inv = vec![0usize; m];
for (k, &p) in perm.iter().enumerate() {
perm_inv[p] = k;
}
Ok(DenseLu {
m,
l,
u,
perm,
perm_inv,
qcol: (0..m).collect(),
qcol_inv: (0..m).collect(),
updates_since_refactor: 0,
growth: 1.0,
u_max0,
params,
scale,
scratch_a: vec![0.0; m],
scratch_b: vec![0.0; m],
scratch_c: vec![0.0; m],
scratch_d: vec![0.0; m],
})
}
pub fn refactor(&mut self, cols: &[Vec<f64>]) -> Result<(), FeralError> {
self.params.validate()?;
let m = self.m;
let (scale, scaled) = compute_scale(cols, m, self.params.scaling)?;
self.scale = scale;
let factor_cols: &[Vec<f64>] = scaled.as_deref().unwrap_or(cols);
let mut packed = vec![0.0; m * m];
copy_columns_into(&mut packed, factor_cols, m)?;
for (k, p) in self.perm.iter_mut().enumerate() {
*p = k;
}
factorize_packed(&mut packed, &mut self.perm, m, &self.params)?;
let (l, u) = split_packed(&packed, m);
self.u_max0 = umax(&u);
self.l = l;
self.u = u;
for (k, &p) in self.perm.iter().enumerate() {
self.perm_inv[p] = k;
}
for (k, (q, qi)) in self
.qcol
.iter_mut()
.zip(self.qcol_inv.iter_mut())
.enumerate()
{
*q = k;
*qi = k;
}
self.updates_since_refactor = 0;
self.growth = 1.0;
Ok(())
}
#[inline]
pub fn dim(&self) -> usize {
self.m
}
#[inline]
pub fn updates_since_refactor(&self) -> usize {
self.updates_since_refactor
}
#[inline]
pub fn perm(&self) -> &[usize] {
&self.perm
}
#[inline]
pub fn qcol(&self) -> &[usize] {
&self.qcol
}
#[inline]
pub fn l(&self, i: usize, j: usize) -> f64 {
self.l[i + j * self.m]
}
#[inline]
pub fn u(&self, i: usize, j: usize) -> f64 {
self.u[i + j * self.m]
}
}
type ScaleResult = Result<(LuScale, Option<Vec<Vec<f64>>>), FeralError>;

/// Computes the scaling for the basis `cols` under `strategy`.
///
/// Returns the scale together with the scaled columns. For `LuScaling::None` it
/// returns the identity scale and no copy, so the caller can factor `cols` directly.
fn compute_scale(cols: &[Vec<f64>], m: usize, strategy: LuScaling) -> ScaleResult {
    if strategy != LuScaling::None {
        let sparse = SparseColMatrix::from_dense_columns(m, cols)?;
        let lu_scale = compute_lu_scale(&sparse, strategy)?;
        let scaled_cols = lu_scale.scaled_dense_columns(&sparse);
        return Ok((lu_scale, Some(scaled_cols)));
    }
    Ok((LuScale::identity(m), None))
}
/// Copies `m` columns of length `m` into the column-major buffer `buf` (length `m * m`).
///
/// Each column is length-checked, copied, then checked for non-finite entries, in
/// column order. On error, `buf` may already hold the columns processed so far.
fn copy_columns_into(buf: &mut [f64], cols: &[Vec<f64>], m: usize) -> Result<(), FeralError> {
    let ncols = cols.len();
    if ncols != m {
        return Err(FeralError::DimensionMismatch {
            expected: m,
            got: ncols,
        });
    }
    let mut offset = 0;
    for column in cols {
        let len = column.len();
        if len != m {
            return Err(FeralError::DimensionMismatch {
                expected: m,
                got: len,
            });
        }
        buf[offset..offset + m].copy_from_slice(column);
        let all_finite = column.iter().all(|v| v.is_finite());
        if !all_finite {
            return Err(FeralError::InvalidInput(
                "LU basis column contains non-finite entries".to_string(),
            ));
        }
        offset += m;
    }
    Ok(())
}
/// Splits an in-place (packed, column-major) LU result into separate `L` and `U` arrays.
///
/// Entries on or above the diagonal go to `U`. Strictly-lower entries go to `L`, whose
/// diagonal is set to `1.0`. Both outputs are `m * m` column-major.
fn split_packed(packed: &[f64], m: usize) -> (Vec<f64>, Vec<f64>) {
    let n = m * m;
    let mut lower = vec![0.0; n];
    let mut upper = vec![0.0; n];
    for col in 0..m {
        let base = col * m;
        let diag = base + col;
        let end = base + m;
        // Rows 0..=col (diagonal included) belong to U.
        upper[base..=diag].copy_from_slice(&packed[base..=diag]);
        // Rows col+1..m are the multipliers of L.
        lower[diag + 1..end].copy_from_slice(&packed[diag + 1..end]);
        lower[diag] = 1.0;
    }
    (lower, upper)
}
/// Largest absolute entry of `u`, floored at `f64::MIN_POSITIVE` so the result is never zero.
#[inline]
fn umax(u: &[f64]) -> f64 {
    let mut largest = 0.0_f64;
    for &value in u {
        largest = largest.max(value.abs());
    }
    largest.max(f64::MIN_POSITIVE)
}
/// In-place right-looking LU with threshold partial pivoting on a column-major `m x m`
/// buffer.
///
/// On success, `packed` holds the strict lower multipliers of `L` and the upper
/// triangle of `U`. `perm` receives the same row swaps as `packed`. A pivot counts as
/// zero when `|pivot| <= params.zero_pivot_tol * max|a_ij|`, using the initial matrix.
/// It then either fails with `FeralError::SingularBasis { column: k }` or is raised to
/// `abs_floor` in magnitude, depending on `params.on_singular`.
fn factorize_packed(
    packed: &mut [f64],
    perm: &mut [usize],
    m: usize,
    params: &LuParams,
) -> Result<(), FeralError> {
    let threshold = params.pivot_threshold;
    let largest = packed.iter().map(|x| x.abs()).fold(0.0_f64, f64::max);
    let zero_tol = params.zero_pivot_tol * largest;
    for k in 0..m {
        let col_k = k * m;

        // Largest magnitude at or below the diagonal of column k. Ties keep the topmost
        // row, and an all-zero column leaves the diagonal row selected.
        let mut best_row = k;
        let mut best_mag = 0.0_f64;
        for (offset, entry) in packed[col_k + k..col_k + m].iter().enumerate() {
            let mag = entry.abs();
            if mag > best_mag {
                best_mag = mag;
                best_row = k + offset;
            }
        }

        // Threshold pivoting: prefer the diagonal when it is within `threshold` of the
        // column maximum and not numerically zero; otherwise take the largest entry.
        let diag_mag = packed[col_k + k].abs();
        let keep_diagonal = diag_mag >= threshold * best_mag && diag_mag > zero_tol;
        let pivot_row = if keep_diagonal { k } else { best_row };
        if pivot_row != k {
            swap_rows(packed, k, pivot_row, m);
            perm.swap(k, pivot_row);
        }

        let mut pivot = packed[col_k + k];
        if pivot.abs() <= zero_tol {
            match params.on_singular {
                LuSingularAction::Fail => {
                    return Err(FeralError::SingularBasis { column: k });
                }
                LuSingularAction::PerturbToEps { abs_floor } => {
                    // -0.0 is not `< 0.0`, so it is perturbed towards +abs_floor.
                    let sign = if pivot < 0.0 { -1.0 } else { 1.0 };
                    pivot = sign * abs_floor.max(pivot.abs());
                    packed[col_k + k] = pivot;
                }
            }
        }

        // Columns 0..=k are in `done`; the trailing columns k+1..m are in `rest`.
        let (done, rest) = packed.split_at_mut(col_k + m);
        let recip = 1.0 / pivot;
        let multipliers = &mut done[col_k + k + 1..];
        for x in multipliers.iter_mut() {
            *x *= recip;
        }
        let multipliers = &*multipliers;

        // Rank-1 update of the trailing submatrix, one column at a time.
        let trailing = (m - k - 1) * m;
        for col_j in rest[..trailing].chunks_exact_mut(m) {
            let ukj = col_j[k];
            if ukj != 0.0 {
                for (dst, &lik) in col_j[k + 1..].iter_mut().zip(multipliers) {
                    *dst -= lik * ukj;
                }
            }
        }
    }
    Ok(())
}
/// Swaps rows `a` and `b` across all `m` columns of a column-major `m x m` buffer.
fn swap_rows(buf: &mut [f64], a: usize, b: usize, m: usize) {
    let mut offset = 0;
    for _ in 0..m {
        buf.swap(offset + a, offset + b);
        offset += m;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts a row-major matrix literal into the column list `DenseLu::factor` expects.
    fn cols_from_rows(rows: &[&[f64]]) -> (Vec<Vec<f64>>, usize) {
        let m = rows.len();
        let mut cols = vec![vec![0.0; m]; m];
        for (i, row) in rows.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() {
                cols[j][i] = v;
            }
        }
        (cols, m)
    }

    /// Largest `|(P B)_ij - (L U)_ij|`, where row `i` of `P B` is row `perm[i]` of `rows`.
    fn reconstruction_residual(lu: &DenseLu, rows: &[&[f64]]) -> f64 {
        let m = lu.m;
        let mut worst = 0.0_f64;
        for i in 0..m {
            for j in 0..m {
                let pb = rows[lu.perm[i]][j];
                let mut prod = 0.0;
                for k in 0..m {
                    prod += lu.l(i, k) * lu.u(k, j);
                }
                worst = worst.max((pb - prod).abs());
            }
        }
        worst
    }

    /// Default parameters with scaling disabled, so error-path tests only exercise
    /// this module's own validation.
    fn unscaled_params() -> LuParams {
        LuParams {
            scaling: LuScaling::None,
            ..LuParams::default()
        }
    }

    #[test]
    fn factor_2x2_swap_exact() {
        let (cols, m) = cols_from_rows(&[&[0.0, 2.0], &[3.0, 4.0]]);
        let lu = DenseLu::factor(&cols, m, LuParams::default()).expect("factor");
        assert_eq!(lu.perm, vec![1, 0]);
        assert!((lu.u(0, 0) - 3.0).abs() < 1e-14);
        assert!((lu.u(0, 1) - 4.0).abs() < 1e-14);
        assert!((lu.u(1, 1) - 2.0).abs() < 1e-14);
        assert!((lu.l(1, 0) - 0.0).abs() < 1e-14);
    }

    #[test]
    fn factor_1x1_keeps_sign() {
        let (cols, m) = cols_from_rows(&[&[-4.0]]);
        let lu = DenseLu::factor(&cols, m, unscaled_params()).expect("factor");
        assert_eq!(lu.perm, vec![0]);
        assert_eq!(lu.l(0, 0), 1.0);
        assert_eq!(lu.u(0, 0), -4.0);
    }

    #[test]
    fn factor_3x3_reconstruction() {
        let rows: [&[f64]; 3] = [&[2.0, 1.0, 1.0], &[4.0, 3.0, 3.0], &[8.0, 7.0, 9.0]];
        let (cols, m) = cols_from_rows(&rows);
        let lu = DenseLu::factor(&cols, m, LuParams::default()).expect("factor");
        assert!(reconstruction_residual(&lu, &rows) < 1e-12);
    }

    #[test]
    fn factor_random_reconstruction() {
        let m = 7;
        let mut cols = vec![vec![0.0; m]; m];
        let mut state = 0x1234_5678_u64;
        for j in 0..m {
            for i in 0..m {
                state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
                let r = ((state >> 33) as f64) / (1u64 << 31) as f64 - 1.0;
                cols[j][i] = r;
            }
            cols[j][j] += 5.0;
        }
        let rows_owned: Vec<Vec<f64>> = (0..m)
            .map(|i| (0..m).map(|j| cols[j][i]).collect())
            .collect();
        let rows: Vec<&[f64]> = rows_owned.iter().map(|r| r.as_slice()).collect();
        let lu = DenseLu::factor(&cols, m, LuParams::default()).expect("factor");
        assert!(reconstruction_residual(&lu, &rows) < 1e-10);
    }

    #[test]
    fn factor_singular_repeated_columns_fails() {
        let (cols, m) = cols_from_rows(&[&[1.0, 1.0], &[2.0, 2.0]]);
        let err = DenseLu::factor(&cols, m, LuParams::default());
        assert!(matches!(err, Err(FeralError::SingularBasis { .. })));
    }

    #[test]
    fn factor_singular_perturb_succeeds() {
        let (cols, m) = cols_from_rows(&[&[1.0, 1.0], &[2.0, 2.0]]);
        let params = LuParams {
            on_singular: LuSingularAction::PerturbToEps { abs_floor: 1e-10 },
            ..LuParams::default()
        };
        let lu = DenseLu::factor(&cols, m, params).expect("perturbed factor");
        assert!(lu.u(1, 1).abs() >= 1e-10);
    }

    #[test]
    fn factor_tiny_well_conditioned_basis_not_singular() {
        let s = 1e-14;
        let (cols, m) = cols_from_rows(&[&[s, 0.0], &[0.0, s]]);
        let mut lu = DenseLu::factor(&cols, m, LuParams::default())
            .expect("tiny but well-conditioned basis must factor");
        let mut rhs = vec![s, 2.0 * s];
        lu.ftran(&mut rhs).expect("ftran");
        assert!((rhs[0] - 1.0).abs() < 1e-6, "x0 = {}", rhs[0]);
        assert!((rhs[1] - 2.0).abs() < 1e-6, "x1 = {}", rhs[1]);
    }

    #[test]
    fn factor_rejects_wrong_column_count() {
        let cols = vec![vec![1.0, 0.0]];
        let err = DenseLu::factor(&cols, 2, unscaled_params());
        assert!(matches!(
            err,
            Err(FeralError::DimensionMismatch {
                expected: 2,
                got: 1
            })
        ));
    }

    #[test]
    fn factor_rejects_short_column() {
        let cols = vec![vec![1.0, 0.0], vec![1.0]];
        let err = DenseLu::factor(&cols, 2, unscaled_params());
        assert!(matches!(
            err,
            Err(FeralError::DimensionMismatch {
                expected: 2,
                got: 1
            })
        ));
    }

    #[test]
    fn factor_rejects_non_finite_entries() {
        for &bad in &[f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let (cols, m) = cols_from_rows(&[&[1.0, bad], &[0.0, 1.0]]);
            let err = DenseLu::factor(&cols, m, unscaled_params());
            assert!(
                matches!(err, Err(FeralError::InvalidInput(_))),
                "entry {}",
                bad
            );
        }
    }

    #[test]
    fn refactor_resets_state() {
        let (cols, m) = cols_from_rows(&[&[2.0, 1.0], &[1.0, 3.0]]);
        let mut lu = DenseLu::factor(&cols, m, LuParams::default()).expect("factor");
        let (cols2, _) = cols_from_rows(&[&[4.0, 0.0], &[1.0, 5.0]]);
        lu.refactor(&cols2).expect("refactor");
        assert_eq!(lu.updates_since_refactor(), 0);
        let rows2: [&[f64]; 2] = [&[4.0, 0.0], &[1.0, 5.0]];
        assert!(reconstruction_residual(&lu, &rows2) < 1e-12);
    }

    #[test]
    fn refactor_rejects_wrong_dimension() {
        let (cols, m) = cols_from_rows(&[&[2.0, 1.0], &[1.0, 3.0]]);
        let mut lu = DenseLu::factor(&cols, m, unscaled_params()).expect("factor");
        let (cols3, _) =
            cols_from_rows(&[&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0], &[0.0, 0.0, 1.0]]);
        let err = lu.refactor(&cols3);
        assert!(matches!(
            err,
            Err(FeralError::DimensionMismatch {
                expected: 2,
                got: 3
            })
        ));
        assert_eq!(lu.dim(), 2);
    }

    #[test]
    fn split_packed_separates_strict_lower_and_upper() {
        // Column-major 2x2: column 0 = [1, 2], column 1 = [3, 4].
        let (l, u) = split_packed(&[1.0, 2.0, 3.0, 4.0], 2);
        assert_eq!(l, vec![1.0, 2.0, 0.0, 1.0]);
        assert_eq!(u, vec![1.0, 0.0, 3.0, 4.0]);
    }

    #[test]
    fn swap_rows_swaps_every_column() {
        let mut buf = vec![1.0, 2.0, 3.0, 4.0];
        swap_rows(&mut buf, 0, 1, 2);
        assert_eq!(buf, vec![2.0, 1.0, 4.0, 3.0]);
    }
}