use super::dense_factor::DenseLu;
use crate::error::FeralError;
impl DenseLu {
    /// Replaces the basis column stored in `leaving_slot` with `entering_col`
    /// and updates the dense factors in place, without refactorizing.
    ///
    /// All matrices are `m x m` and column-major: entry `(i, j)` lives at
    /// index `i + j * m`.
    ///
    /// 1. The entering column is row-permuted and scaled with the stored
    ///    `scale.rperm`, `scale.d_row` and `scale.d_col[leaving_slot]`, then
    ///    passed through `ftran_partial` to form the "spike".
    /// 2. The spike overwrites column `q = qcol_inv[leaving_slot]` of `U`.
    ///    Columns `q..m` are then rotated left by one so the spike becomes the
    ///    last column, and `qcol` is rotated the same way. This leaves `U`
    ///    upper Hessenberg in columns `q..m - 1`.
    /// 3. Each sub-diagonal entry `(k + 1, k)` is eliminated by subtracting
    ///    `mult * row k` from row `k + 1`. The same multiplier is applied to
    ///    `L` as `column k += mult * column k + 1`, i.e. `L` is right-multiplied
    ///    by the inverse of each row operation. This preserves `L * U`
    ///    assuming `l` stores `L` itself rather than its inverse (confirm
    ///    against `ftran_partial`).
    ///
    /// The update is all-or-nothing. `u`, `l` and `qcol` are modified on
    /// copies, and `self` is written only after every check has passed.
    ///
    /// NOTE(review): step 3 performs no row interchanges. When the new
    /// diagonal entry `(k, k)` of the Hessenberg matrix is tiny, the update is
    /// rejected with `NeedsRefactor` even if the new basis is well
    /// conditioned. This happens, for example, when the old `U` has a zero at
    /// `(q, q + 1)`, which is true for any diagonal `U` with `q < m - 1`.
    ///
    /// # Errors
    ///
    /// * `FeralError::DimensionMismatch` if `entering_col.len() != m`.
    /// * `FeralError::InvalidInput` in either of these cases:
    ///   - `leaving_slot >= m`;
    ///   - `entering_col` contains a NaN or infinite entry.
    /// * `FeralError::NeedsRefactor` in any of these cases:
    ///   - `params.max_updates` updates were already applied since the last
    ///     refactorization;
    ///   - a pivot is NaN or has magnitude at most
    ///     `params.zero_pivot_tol * u_max0`;
    ///   - the updated `U` contains a non-finite entry;
    ///   - the high-water growth ratio `max |U| / u_max0` exceeds
    ///     `params.max_growth`.
    /// * Any error returned by `ftran_partial`.
    pub fn update(&mut self, leaving_slot: usize, entering_col: &[f64]) -> Result<(), FeralError> {
        let m = self.m;
        if entering_col.len() != m {
            return Err(FeralError::DimensionMismatch {
                expected: m,
                got: entering_col.len(),
            });
        }
        if leaving_slot >= m {
            return Err(FeralError::InvalidInput(format!(
                "leaving_slot {} out of range for basis dimension {}",
                leaving_slot, m
            )));
        }
        // Every tolerance test below is a `<=` or `>` comparison, and those
        // are false for NaN. Without this guard, a NaN input would be
        // committed into the factors.
        if let Some(bad) = entering_col.iter().position(|x| !x.is_finite()) {
            return Err(FeralError::InvalidInput(format!(
                "entering_col[{}] = {} is not finite",
                bad, entering_col[bad]
            )));
        }
        if self.updates_since_refactor >= self.params.max_updates {
            return Err(FeralError::NeedsRefactor);
        }

        // Scale and row-permute the entering column, then map it into U's
        // coordinates (the "spike").
        let d_col = self.scale.d_col[leaving_slot];
        let mut spike = vec![0.0; m];
        for (i, si) in spike.iter_mut().enumerate() {
            *si = self.scale.d_row[i] * entering_col[self.scale.rperm[i]] * d_col;
        }
        self.ftran_partial(&mut spike)?;

        let q = self.qcol_inv[leaving_slot];
        // `u_max0` is presumably the largest |U| entry at the last
        // refactorization. It anchors both the pivot tolerance and the
        // growth ratio.
        let ztol = self.params.zero_pivot_tol * self.u_max0;
        let max_growth = self.params.max_growth;
        // NaN-aware "too small to pivot on" test.
        let is_tiny = |x: f64| x.is_nan() || x.abs() <= ztol;

        // Work on copies so that a rejected update leaves `self` untouched.
        let mut u = self.u.clone();
        let mut l = self.l.clone();
        let mut qcol = self.qcol.clone();

        // Drop the spike into the leaving column, then rotate it to the end.
        u[q * m..q * m + m].copy_from_slice(&spike);
        cyclic_shift_columns(&mut u, m, q);
        qcol[q..m].rotate_left(1);

        // Restore triangularity by eliminating the Hessenberg sub-diagonal.
        for k in q..m - 1 {
            let piv = u[k + k * m];
            if is_tiny(piv) {
                return Err(FeralError::NeedsRefactor);
            }
            let sub = u[k + 1 + k * m];
            if sub == 0.0 {
                continue;
            }
            let mult = sub / piv;
            for j in k..m {
                u[k + 1 + j * m] -= mult * u[k + j * m];
            }
            // Set the eliminated entry to an exact zero rather than rounding noise.
            u[k + 1 + k * m] = 0.0;
            // Fold the inverse row operation into L.
            for i in 0..m {
                l[i + k * m] += mult * l[i + (k + 1) * m];
            }
        }

        // `f64::max` ignores NaN, so reject non-finite entries explicitly
        // (e.g. inf - inf produced during elimination) while measuring growth.
        let mut umax = 0.0_f64;
        for &x in u.iter() {
            if !x.is_finite() {
                return Err(FeralError::NeedsRefactor);
            }
            umax = umax.max(x.abs());
        }
        let growth = self.growth.max(umax / self.u_max0);
        if growth > max_growth {
            return Err(FeralError::NeedsRefactor);
        }
        // The last diagonal entry is only final after the loop, so check it here.
        let last = m - 1;
        if is_tiny(u[last + last * m]) {
            return Err(FeralError::NeedsRefactor);
        }

        // All checks passed: commit.
        self.u = u;
        self.l = l;
        self.qcol = qcol;
        for (pos, &slot) in self.qcol.iter().enumerate() {
            self.qcol_inv[slot] = pos;
        }
        self.growth = growth;
        self.updates_since_refactor += 1;
        Ok(())
    }
}
/// Moves column `q` of the column-major `m x m` matrix held in `buf` to the
/// last column position, shifting columns `q + 1..m` one place to the left.
///
/// Only `buf[q * m..m * m]` is touched. When `q` is already the last column,
/// or is out of range, the buffer is left unchanged.
fn cyclic_shift_columns(buf: &mut [f64], m: usize, q: usize) {
    if q + 1 < m {
        // Rotating the tail left by one column's worth of elements (m) is
        // exactly a left cyclic shift of the columns q..m.
        buf[q * m..m * m].rotate_left(m);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lu::dense_factor::DenseLu;
    use crate::lu::LuParams;

    /// Factorizes the 2 x 2 identity basis with default parameters.
    fn identity2() -> DenseLu {
        let cols = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        DenseLu::factor(&cols, 2, LuParams::default()).expect("factor")
    }

    #[test]
    fn update_singular_last_pivot_does_not_commit() {
        let mut lu = identity2();
        let err = lu.update(1, &[1.0, 0.0]);
        assert!(
            matches!(err, Err(FeralError::NeedsRefactor)),
            "singular replacement basis must be rejected, got {err:?}"
        );
        // The rejected update must leave a usable factorization behind.
        let mut rhs = vec![1.0, 1.0];
        lu.ftran(&mut rhs).expect("ftran after rejected update");
        assert!(rhs.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn update_rejects_wrong_length_column() {
        let mut lu = identity2();
        let err = lu.update(0, &[1.0]);
        assert!(
            matches!(err, Err(FeralError::DimensionMismatch { expected: 2, got: 1 })),
            "wrong-length entering column must be rejected, got {err:?}"
        );
    }

    #[test]
    fn update_rejects_out_of_range_slot() {
        let mut lu = identity2();
        let err = lu.update(2, &[1.0, 0.0]);
        assert!(
            matches!(err, Err(FeralError::InvalidInput(_))),
            "leaving_slot >= m must be rejected, got {err:?}"
        );
    }

    #[test]
    fn cyclic_shift_columns_moves_column_to_end() {
        let base: Vec<f64> = (0..9u8).map(f64::from).collect();

        let mut buf = base.clone();
        cyclic_shift_columns(&mut buf, 3, 0);
        assert_eq!(buf, [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 0.0, 1.0, 2.0]);

        let mut buf = base.clone();
        cyclic_shift_columns(&mut buf, 3, 1);
        assert_eq!(buf, [0.0, 1.0, 2.0, 6.0, 7.0, 8.0, 3.0, 4.0, 5.0]);

        // Shifting the last column, or an out-of-range one, is a no-op.
        let mut buf = base.clone();
        cyclic_shift_columns(&mut buf, 3, 2);
        assert_eq!(buf, base);
        cyclic_shift_columns(&mut buf, 3, 3);
        assert_eq!(buf, base);
    }

    #[test]
    fn growth_monitor_tracks_compounded_element_growth() {
        let cols = vec![
            vec![4.0, 1.0, 0.0, 0.0],
            vec![1.0, 3.0, 1.0, 0.0],
            vec![0.0, 1.0, 2.0, 1.0],
            vec![0.0, 0.0, 1.0, 5.0],
        ];
        let m = 4;
        let params = LuParams {
            max_updates: 20,
            max_growth: 1e12,
            ..LuParams::default()
        };
        let mut lu = DenseLu::factor(&cols, m, params).expect("factor");
        // Largest |U| entry, read through the public accessor.
        let umax = |lu: &DenseLu| {
            let mut mx = 0.0_f64;
            for j in 0..m {
                for i in 0..m {
                    mx = mx.max(lu.u(i, j).abs());
                }
            }
            mx
        };
        let u_max0 = umax(&lu);
        // Independent high-water mark of element growth.
        let mut hw = 1.0_f64;
        let updates = [
            (3usize, vec![0.0, 0.0, 1.0, 20.0]),
            (3usize, vec![0.0, 0.0, 1.0, 60.0]),
            (3usize, vec![0.0, 0.0, 1.0, 180.0]),
        ];
        for (i, (slot, col)) in updates.iter().enumerate() {
            lu.update(*slot, col)
                .unwrap_or_else(|e| panic!("update {i} should commit: {e:?}"));
            hw = hw.max(umax(&lu) / u_max0);
            assert!(
                (lu.growth - hw).abs() <= 1e-9 * hw,
                "growth monitor {} must equal element-growth high-water {}",
                lu.growth,
                hw
            );
        }
        assert!(hw > 1.0, "test must exercise genuine element growth");
    }
}