use super::sparse_factor::{FtEta, FtOp, SparseLu};
use super::sparse_symbolic::SparseLuSymbolic;
use crate::error::FeralError;
use crate::lu::sparse_matrix::SparseColMatrix;
impl SparseLu {
/// Returns the number of eta records currently stored in `self.etas`.
///
/// `update_sparse` appends exactly one record per committed update, so this is
/// the number of updates layered on top of the current factors.
/// NOTE(review): assumes a fresh factorization starts with an empty eta list.
#[inline]
pub fn updates_since_refactor(&self) -> usize {
    let recorded = &self.etas;
    recorded.len()
}
/// Dense-column convenience wrapper around `update_sparse`.
///
/// `entering_col` must have exactly `self.m` entries; its nonzero entries are
/// gathered as `(row, value)` pairs and handed to `update_sparse`.
///
/// # Errors
/// Returns `FeralError::DimensionMismatch` when the length is wrong, otherwise
/// whatever `update_sparse` returns.
pub fn update(&mut self, leaving_slot: usize, entering_col: &[f64]) -> Result<(), FeralError> {
    let got = entering_col.len();
    if got != self.m {
        return Err(FeralError::DimensionMismatch {
            expected: self.m,
            got,
        });
    }
    // Keep only the structurally nonzero entries, in ascending row order.
    let mut nonzeros: Vec<(usize, f64)> = Vec::new();
    for (row, &value) in entering_col.iter().enumerate() {
        if value != 0.0 {
            nonzeros.push((row, value));
        }
    }
    self.update_sparse(leaving_slot, &nonzeros)
}
/// Replaces the basis column in `leaving_slot` with the sparse column
/// `entering` (pairs of `(row, value)`), rewriting the affected rows of U in
/// place and recording the row operations used as one new eta.
///
/// # Errors
/// * `FeralError::InvalidInput` if `leaving_slot` or any row index in
///   `entering` is `>= self.m`.
/// * `FeralError::NeedsRefactor` if this update would exceed
///   `params.max_updates`, if the transformed column has no nonzero at or
///   beyond position `r`, or if the growth estimate exceeds
///   `params.max_growth`.
/// * Any error returned by `eliminate_bump` is passed through unchanged.
///
/// On every error path the rows of U are left with their pre-call contents,
/// and neither `u_above`, `etas` nor `growth` is modified.
pub fn update_sparse(
&mut self,
leaving_slot: usize,
entering: &[(usize, f64)],
) -> Result<(), FeralError> {
let m = self.m;
// Validate all indices up front so nothing below can index out of range.
if leaving_slot >= m {
return Err(FeralError::InvalidInput(format!(
"leaving_slot {} out of range for basis dimension {}",
leaving_slot, m
)));
}
for &(row, _) in entering.iter() {
if row >= m {
return Err(FeralError::InvalidInput(format!(
"entering-column row {} out of range for dimension {}",
row, m
)));
}
}
// Refuse once this update would push the eta count past the configured budget.
if self.updates_since_refactor() + 1 > self.params.max_updates {
return Err(FeralError::NeedsRefactor);
}
// Move the dense work vector out of `self`; `compute_spike` fills it and lists
// in `touched` every index it wrote, so it can be zeroed again cheaply.
let mut w = std::mem::take(&mut self.ft_work); let mut touched: Vec<usize> = Vec::new(); self.compute_spike(entering, leaving_slot, &mut w, &mut touched);
// NOTE(review): `r` is presumably the column position of the leaving slot in U — confirm.
let r = self.qcol_inv[leaving_slot];
// Sorted indices at which the transformed column is actually nonzero.
let mut supp: Vec<usize> = touched.iter().copied().filter(|&k| w[k] != 0.0).collect();
supp.sort_unstable();
supp.dedup();
// `h` is the largest nonzero index; it must reach at least `r`, otherwise give
// up (after zeroing and returning the work vector).
let h = match supp.last().copied() {
Some(h) if h >= r => h,
_ => {
clear(&mut w, &touched);
self.ft_work = w;
return Err(FeralError::NeedsRefactor);
}
};
// Rows that may be modified below: the range `r..=h`, every row receiving a
// new column-`r` entry, and every row currently listed in `u_above[r]`.
let mut changed: Vec<usize> = (r..=h).collect();
changed.extend(supp.iter().copied());
changed.extend(self.u_above[r].iter().copied());
changed.push(r);
changed.sort_unstable();
changed.dedup();
// Snapshot those rows: used for rollback on failure and for re-indexing
// `u_above` on success.
let saved: Vec<(usize, Vec<(usize, f64)>)> = changed
.iter()
.map(|&i| (i, self.u_rows[i].clone()))
.collect();
self.set_column_r(r, &w, &supp);
let result = self.eliminate_bump(r, h);
// The work vector is no longer needed; zero it and hand it back before
// branching on the outcome.
clear(&mut w, &touched);
self.ft_work = w;
match result {
Ok(ops) => {
// Largest magnitude among the rewritten rows, relative to `u_max0`,
// folded into the running growth figure.
let mut changed_max = 0.0_f64;
for &i in changed.iter() {
for &(_, v) in self.u_rows[i].iter() {
changed_max = changed_max.max(v.abs());
}
}
let growth = self.growth.max(changed_max / self.u_max0);
if growth > self.params.max_growth {
// Too much growth: roll the rows back and ask for a refactorization.
for (i, row) in saved {
self.u_rows[i] = row;
}
return Err(FeralError::NeedsRefactor);
}
// Commit: bring `u_above` in line with the new rows by removing each
// changed row's old entries and adding its new ones.
for (i, old_row) in saved.iter() {
self.unindex_above(*i, old_row);
let new_row = std::mem::take(&mut self.u_rows[*i]);
self.index_above(*i, &new_row);
self.u_rows[*i] = new_row;
}
self.etas.push(FtEta { ops });
self.growth = growth;
Ok(())
}
Err(e) => {
// Elimination failed part-way: restore every row it could have touched.
for (i, row) in saved {
self.u_rows[i] = row;
}
Err(e)
}
}
}
/// Discards the current factors and replaces `self` with a fresh
/// `SparseLu::factor(a, symbolic, ..)` built with a clone of the current
/// parameters.
///
/// # Errors
/// Propagates any error from `SparseLu::factor`; in that case `self` is left
/// untouched because the assignment never happens.
pub fn refactor(
    &mut self,
    a: &SparseColMatrix,
    symbolic: &SparseLuSymbolic,
) -> Result<(), FeralError> {
    let params = self.params.clone();
    let fresh = SparseLu::factor(a, symbolic, params)?;
    *self = fresh;
    Ok(())
}
/// Fills `w` with the transformed entering column and appends to `touched`
/// every index of `w` that was written.
///
/// Steps, in order:
/// 1. each `(row, value)` of `entering` is mapped through `scale_rperm_inv`
///    and `perm_inv`, multiplied by `scale.d_row[i]` and
///    `scale.d_col[leaving_slot]`, and accumulated into `w`;
/// 2. the columns of L (`l_col_ptr` / `l_row_idx` / `l_val`) reachable from
///    those positions are applied in ascending index order;
/// 3. every op of every stored eta is replayed on `w`, oldest eta first.
///
/// NOTE(review): assumes `w` is all zeros and `touched` is empty on entry —
/// the visible caller passes a cleared work vector and a fresh `Vec`.
/// `touched` may contain indices whose final value is zero.
fn compute_spike(
&mut self,
entering: &[(usize, f64)],
leaving_slot: usize,
w: &mut [f64],
touched: &mut Vec<usize>,
) {
let dcol = self.scale.d_col[leaving_slot];
// Boolean scratch marks, moved out of `self` for the duration of the call.
let mut mark = std::mem::take(&mut self.scratch_mark);
let mut stack: Vec<usize> = Vec::new();
// Scatter the scaled, permuted entries into `w`, marking each position once.
for &(o, val) in entering.iter() {
let i = self.scale_rperm_inv[o];
let v = self.scale.d_row[i] * val * dcol;
if v == 0.0 {
continue;
}
let k = self.perm_inv[i];
w[k] += v;
if !mark[k] {
mark[k] = true;
touched.push(k);
stack.push(k);
}
}
// Depth-first walk over L's column structure: collect every index reachable
// from the scattered positions.
let mut reach: Vec<usize> = touched.clone();
while let Some(k) = stack.pop() {
let (lo, hi) = (self.l_col_ptr[k], self.l_col_ptr[k + 1]);
for idx in lo..hi {
let i = self.l_row_idx[idx];
if !mark[i] {
mark[i] = true;
touched.push(i);
reach.push(i);
stack.push(i);
}
}
}
// Process reached indices in ascending order: for each nonzero `w[k]`,
// subtract `l_val * w[k]` from the rows of L's column `k`.
// NOTE(review): presumably a unit-lower-triangular solve — confirm.
reach.sort_unstable();
for &k in reach.iter() {
let yk = w[k];
if yk == 0.0 {
continue;
}
let (lo, hi) = (self.l_col_ptr[k], self.l_col_ptr[k + 1]);
for idx in lo..hi {
w[self.l_row_idx[idx]] -= self.l_val[idx] * yk;
}
}
// Replay the recorded row operations of earlier updates. Any position that
// becomes nonzero here is added to `touched` so the caller can zero it later.
for eta in self.etas.iter() {
for op in eta.ops.iter() {
match *op {
FtOp::Swap(a, b) => {
w.swap(a, b);
for x in [a, b] {
if w[x] != 0.0 && !mark[x] {
mark[x] = true;
touched.push(x);
}
}
}
FtOp::Axpy { target, src, mult } => {
w[target] -= mult * w[src];
if w[target] != 0.0 && !mark[target] {
mark[target] = true;
touched.push(target);
}
}
}
}
}
// Reset exactly the marks that were set, then return the scratch buffer.
for &k in touched.iter() {
mark[k] = false;
}
self.scratch_mark = mark;
}
/// Overwrites column `r` of the row-stored U with the values of `w` at the
/// indices in `supp`.
///
/// The column-`r` entry is first removed from every row listed in
/// `self.u_above[r]` and from row `r` itself; then, for each `i` in `supp`,
/// `w[i]` is written into row `i` at column `r` (a zero value removes the
/// entry instead, see `insert_or_set`). `self.u_above` is not modified here.
fn set_column_r(&mut self, r: usize, w: &[f64], supp: &[usize]) {
    // `u_above` and `u_rows` are distinct fields, so the index list can be read
    // while the rows are mutated; no temporary copy of the list is required.
    for &i in self.u_above[r].iter() {
        remove_col(&mut self.u_rows[i], r);
    }
    remove_col(&mut self.u_rows[r], r);
    for &i in supp.iter() {
        insert_or_set(&mut self.u_rows[i], r, w[i]);
    }
}
/// Row-reduces rows `r..=h` of U: for each column `k` in `r..=h`, the row in
/// `k..=h` with the largest-magnitude entry in column `k` is swapped into
/// position `k`, then used to cancel column `k` from every later row in the range.
///
/// Returns the swaps and row subtractions performed, in execution order, as
/// `FtOp`s.
///
/// # Errors
/// `FeralError::NeedsRefactor` when the best available pivot magnitude is
/// `<= params.zero_pivot_tol * u_max0`. Rows already modified are NOT
/// restored here; the caller is responsible for rollback.
fn eliminate_bump(&mut self, r: usize, h: usize) -> Result<Vec<FtOp>, FeralError> {
let ztol = self.params.zero_pivot_tol * self.u_max0;
let mut ops: Vec<FtOp> = Vec::new();
let width = h - r + 1;
// `col_rows[c - r]` lists, in ascending order, the rows in `r..=h` that store
// an entry in column `c`. Rows are scanned left to right and the scan stops at
// the first column past `h`, which relies on rows being sorted by column.
let mut col_rows: Vec<Vec<usize>> = vec![Vec::new(); width];
for i in r..=h {
for &(c, _) in self.u_rows[i].iter() {
if c < r {
continue;
}
if c > h {
break;
}
col_rows[c - r].push(i);
}
}
for k in r..=h {
let kc = k - r;
// Pick the row at or below `k` with the largest |entry| in column `k`.
let mut pivot_row = k;
let mut pivot_abs = 0.0_f64;
for &i in col_rows[kc].iter() {
if i < k {
continue;
}
if let Some(v) = get_col(&self.u_rows[i], k) {
if v.abs() > pivot_abs {
pivot_abs = v.abs();
pivot_row = i;
}
}
}
if pivot_abs <= ztol {
return Err(FeralError::NeedsRefactor);
}
if pivot_row != k {
// Bring the pivot row into position `k` and fix the column index for
// both swapped rows.
self.u_rows.swap(k, pivot_row);
ops.push(FtOp::Swap(k, pivot_row));
swap_membership(&mut col_rows, r, h, k, pivot_row, &self.u_rows);
}
// Work from a copy of the pivot row so other rows can be mutated freely.
let pivot_data = self.u_rows[k].clone();
let pivot = get_col(&pivot_data, k).unwrap_or(0.0);
// Snapshot the target list: `col_rows[kc]` is edited while eliminating.
let targets: Vec<usize> = col_rows[kc].iter().copied().filter(|&i| i > k).collect();
for i in targets {
if let Some(vik) = get_col(&self.u_rows[i], k) {
let mult = vik / pivot;
let old_row = std::mem::take(&mut self.u_rows[i]);
// row_i <- row_i - mult * row_k, with column `k` dropped outright.
let new_row = row_sub(&old_row, &pivot_data, mult, k);
reindex_after_rowsub(&mut col_rows, r, h, i, &old_row, &new_row);
self.u_rows[i] = new_row;
ops.push(FtOp::Axpy {
target: i,
src: k,
mult,
});
}
}
}
Ok(ops)
}
/// Removes row `i` from `self.u_above[c]` for every column `c > i` that
/// appears in `old_row`. Columns where `i` is not listed are left alone.
fn unindex_above(&mut self, i: usize, old_row: &[(usize, f64)]) {
    let strictly_right = old_row.iter().map(|entry| entry.0).filter(|&col| col > i);
    for col in strictly_right {
        let rows = &mut self.u_above[col];
        if let Ok(at) = rows.binary_search(&i) {
            rows.remove(at);
        }
    }
}
/// Adds row `i` to `self.u_above[c]` for every column `c > i` that appears in
/// `new_row`, keeping each list sorted and free of duplicates.
fn index_above(&mut self, i: usize, new_row: &[(usize, f64)]) {
    let strictly_right = new_row.iter().map(|entry| entry.0).filter(|&col| col > i);
    for col in strictly_right {
        let rows = &mut self.u_above[col];
        if let Err(at) = rows.binary_search(&i) {
            rows.insert(at, i);
        }
    }
}
}
/// Sets `w[k]` to zero for every index `k` in `touched`; other entries are
/// left as they are. Panics if an index is out of range for `w`.
fn clear(w: &mut [f64], touched: &[usize]) {
    touched.iter().for_each(|&idx| w[idx] = 0.0);
}
/// Makes membership of `i` in the sorted list `list` equal to `present`:
/// inserts `i` at its sorted position when it should be there and is not,
/// removes it when it should not be there and is, and otherwise does nothing.
fn set_member(list: &mut Vec<usize>, i: usize, present: bool) {
    match (list.binary_search(&i), present) {
        (Ok(at), false) => {
            list.remove(at);
        }
        (Err(at), true) => list.insert(at, i),
        (Ok(_), true) | (Err(_), false) => {}
    }
}
/// Refreshes `col_rows` after rows `k` and `p` of `u_rows` have been swapped.
///
/// For every column `c` in `r..=h` that occurs in either row, the membership of
/// `k` and of `p` in `col_rows[c - r]` is reset to match what the two rows
/// currently store. Each row is scanned left to right and the scan stops at
/// the first column beyond `h`.
fn swap_membership(
    col_rows: &mut [Vec<usize>],
    r: usize,
    h: usize,
    k: usize,
    p: usize,
    u_rows: &[Vec<(usize, f64)>],
) {
    for src in [k, p] {
        for entry in u_rows[src].iter() {
            let col = entry.0;
            if col >= r {
                if col > h {
                    break;
                }
                let bucket = &mut col_rows[col - r];
                let k_has = get_col(&u_rows[k], col).is_some();
                set_member(bucket, k, k_has);
                let p_has = get_col(&u_rows[p], col).is_some();
                set_member(bucket, p, p_has);
            }
        }
    }
}
/// Refreshes `col_rows` for row `i` after it changed from `old_row` to
/// `new_row`.
///
/// Only columns in `r..=h` are considered. Row `i` is dropped from the list of
/// every such column present in `old_row` but absent from `new_row`, and
/// added to the list of every such column present in `new_row`. Each row is
/// scanned left to right and the scan stops at the first column beyond `h`.
fn reindex_after_rowsub(
    col_rows: &mut [Vec<usize>],
    r: usize,
    h: usize,
    i: usize,
    old_row: &[(usize, f64)],
    new_row: &[(usize, f64)],
) {
    // Columns the row used to occupy but no longer does.
    let vanished = old_row
        .iter()
        .map(|entry| entry.0)
        .take_while(|&col| col < r || col <= h)
        .filter(|&col| col >= r)
        .filter(|&col| get_col(new_row, col).is_none());
    for col in vanished {
        set_member(&mut col_rows[col - r], i, false);
    }

    // Columns the row occupies now.
    let occupied = new_row
        .iter()
        .map(|entry| entry.0)
        .take_while(|&col| col < r || col <= h)
        .filter(|&col| col >= r);
    for col in occupied {
        set_member(&mut col_rows[col - r], i, true);
    }
}
/// Looks up column `c` in a row stored as `(column, value)` pairs sorted by
/// column, returning the stored value or `None` when the column is absent.
fn get_col(row: &[(usize, f64)], c: usize) -> Option<f64> {
    match row.binary_search_by_key(&c, |entry| entry.0) {
        Ok(at) => Some(row[at].1),
        Err(_) => None,
    }
}
/// Deletes the entry for column `c` from a row sorted by column; does nothing
/// when the row has no such entry.
fn remove_col(row: &mut Vec<(usize, f64)>, c: usize) {
    match row.binary_search_by_key(&c, |entry| entry.0) {
        Ok(at) => {
            row.remove(at);
        }
        Err(_) => {}
    }
}
/// Stores `v` at column `c` of a row sorted by column, never keeping an
/// explicit zero: a nonzero `v` overwrites or inserts the entry, while a zero
/// `v` removes an existing entry and is otherwise ignored.
fn insert_or_set(row: &mut Vec<(usize, f64)>, c: usize, v: f64) {
    let nonzero = v != 0.0;
    match (row.binary_search_by_key(&c, |entry| entry.0), nonzero) {
        (Ok(at), true) => row[at].1 = v,
        (Ok(at), false) => {
            row.remove(at);
        }
        (Err(at), true) => row.insert(at, (c, v)),
        (Err(_), false) => {}
    }
}
/// Computes `dst - mult * src` for two rows stored as `(column, value)` pairs
/// sorted by column, omitting column `drop_col` from the result.
///
/// Entries that exist only in `dst` are copied verbatim (even if their value
/// is zero). Entries that involve `src` are kept only when the computed value
/// is nonzero. The output is sorted by column.
fn row_sub(
    dst: &[(usize, f64)],
    src: &[(usize, f64)],
    mult: f64,
    drop_col: usize,
) -> Vec<(usize, f64)> {
    use std::cmp::Ordering;

    let mut merged = Vec::with_capacity(dst.len() + src.len());
    let (mut d, mut s) = (0usize, 0usize);
    loop {
        // Next output candidate: (column, value, copied verbatim from `dst`?).
        let (col, val, verbatim) = match (dst.get(d), src.get(s)) {
            (None, None) => break,
            (Some(&(dc, dv)), None) => {
                d += 1;
                (dc, dv, true)
            }
            (None, Some(&(sc, sv))) => {
                s += 1;
                (sc, -mult * sv, false)
            }
            (Some(&(dc, dv)), Some(&(sc, sv))) => match dc.cmp(&sc) {
                Ordering::Less => {
                    d += 1;
                    (dc, dv, true)
                }
                Ordering::Greater => {
                    s += 1;
                    (sc, -mult * sv, false)
                }
                Ordering::Equal => {
                    d += 1;
                    s += 1;
                    (dc, dv - mult * sv, false)
                }
            },
        };
        if col != drop_col && (verbatim || val != 0.0) {
            merged.push((col, val));
        }
    }
    merged
}
#[cfg(test)]
mod tests {
    use crate::lu::sparse_factor::SparseLu;
    use crate::lu::LuParams;

    /// Largest absolute value over the `m x m` entries reported by `u_dense`.
    fn max_abs_u(lu: &SparseLu, m: usize) -> f64 {
        (0..m)
            .flat_map(|i| (0..m).map(move |j| (i, j)))
            .fold(0.0_f64, |mx, (i, j)| mx.max(lu.u_dense(i, j).abs()))
    }

    /// After each committed update, `lu.growth` must equal the running maximum
    /// of `max|U| / max|U at factor time|` (starting from 1.0).
    #[test]
    fn growth_monitor_tracks_compounded_element_growth() {
        let m = 4;
        let cols = vec![
            vec![4.0, 1.0, 0.0, 0.0],
            vec![1.0, 3.0, 1.0, 0.0],
            vec![0.0, 1.0, 2.0, 1.0],
            vec![0.0, 0.0, 1.0, 5.0],
        ];
        let params = LuParams {
            max_updates: 20,
            max_growth: 1e12,
            ..LuParams::default()
        };
        let mut lu = SparseLu::factor_dense_columns(m, &cols, params).expect("factor");
        let u_max0 = max_abs_u(&lu, m);

        // Repeatedly replace slot 3 with a column whose last entry keeps growing.
        let updates = [
            (3usize, vec![0.0, 0.0, 1.0, 20.0]),
            (3usize, vec![0.0, 0.0, 1.0, 60.0]),
            (3usize, vec![0.0, 0.0, 1.0, 180.0]),
        ];
        let mut hw = 1.0_f64;
        for (i, (slot, col)) in updates.iter().enumerate() {
            if let Err(e) = lu.update(*slot, col) {
                panic!("update {i} should commit: {e:?}");
            }
            let ratio = max_abs_u(&lu, m) / u_max0;
            hw = hw.max(ratio);
            assert!(
                (lu.growth - hw).abs() <= 1e-9 * hw,
                "growth monitor {} must equal element-growth high-water {}",
                lu.growth,
                hw
            );
        }
        assert!(hw > 1.0, "test must exercise genuine element growth");
    }
}