use super::scaling::{compute_lu_scale, LuScale};
use super::sparse_symbolic::SparseLuSymbolic;
use super::{LuParams, LuScaling, LuSingularAction};
use crate::error::FeralError;
use crate::lu::sparse_matrix::SparseColMatrix;
/// One elementary operation stored in an [`FtEta`] and replayed on a dense vector.
#[derive(Debug, Clone)]
pub(super) enum FtOp {
/// Exchange the vector entries at the two given positions.
Swap(usize, usize),
/// Elementary update between two positions: the forward pass performs
/// `y[target] -= mult * y[src]`, the transposed pass performs
/// `y[src] -= mult * y[target]`.
Axpy {
target: usize,
src: usize,
mult: f64,
},
}
/// An ordered list of [`FtOp`]s that is applied as a unit, either forward
/// (in stored order) or transposed (in reverse order).
// NOTE(review): "Ft" presumably stands for Forrest-Tomlin; the update code
// that produces these records is not visible in this part of the file.
#[derive(Debug, Clone)]
pub(super) struct FtEta {
/// Operations in the order they are applied by `apply_forward`.
pub ops: Vec<FtOp>,
}
impl FtEta {
    /// Replays every stored operation on `y`, first to last.
    ///
    /// A `Swap` exchanges two entries; an `Axpy` subtracts `mult * y[src]`
    /// from `y[target]`. Panics if an operation refers to an index outside `y`.
    pub(super) fn apply_forward(&self, y: &mut [f64]) {
        self.ops.iter().for_each(|op| match op {
            FtOp::Swap(i, j) => y.swap(*i, *j),
            FtOp::Axpy { target, src, mult } => {
                let delta = *mult * y[*src];
                y[*target] -= delta;
            }
        });
    }

    /// Replays the transposed operations on `y`, last to first.
    ///
    /// A `Swap` is its own transpose; the transpose of an `Axpy` subtracts
    /// `mult * y[target]` from `y[src]`. Panics if an operation refers to an
    /// index outside `y`.
    pub(super) fn apply_transpose(&self, y: &mut [f64]) {
        self.ops.iter().rev().for_each(|op| match op {
            FtOp::Swap(i, j) => y.swap(*i, *j),
            FtOp::Axpy { target, src, mult } => {
                let delta = *mult * y[*target];
                y[*src] -= delta;
            }
        });
    }
}
/// Sparse LU factors produced by [`SparseLu::factor`], together with the
/// permutations, scaling data and reusable work buffers stored alongside them.
#[derive(Debug, Clone)]
pub struct SparseLu {
// Dimension of the factored matrix (taken from `a.m` in `factor`).
pub(super) m: usize,
// L in compressed-column form. The unit diagonal is not stored; after
// `factor`, row indices are pivot positions and sorted within each column.
pub(super) l_col_ptr: Vec<usize>,
pub(super) l_row_idx: Vec<usize>,
pub(super) l_val: Vec<f64>,
// U stored by rows as (column, value) pairs. As built by `factor`, each row
// starts with its diagonal entry followed by the off-diagonal entries.
pub(super) u_rows: Vec<Vec<(usize, f64)>>,
// perm[k] is the row chosen as pivot at step k; perm_inv is its inverse.
pub(super) perm: Vec<usize>,
pub(super) perm_inv: Vec<usize>,
// Column ordering and its inverse, cloned from the symbolic analysis.
pub(super) qcol: Vec<usize>,
pub(super) qcol_inv: Vec<usize>,
// u_above[c] lists the rows i < c of `u_rows` that hold an entry in column c.
pub(super) u_above: Vec<Vec<usize>>,
// Recorded eta updates; empty right after `factor`.
pub(super) etas: Vec<FtEta>,
// Initialised to 1.0 by `factor`.
// NOTE(review): presumably a growth estimate maintained by update code not shown here.
pub(super) growth: f64,
// Largest |entry| of U at factorization time, floored at f64::MIN_POSITIVE.
pub(super) u_max0: f64,
// Sum of the reach-set sizes over all columns processed by `factor`.
pub(super) reach_visits: usize,
// Parameters the factorization was computed with.
pub(super) params: LuParams,
// Scaling applied to the matrix before factoring (identity when scaling is off).
pub(super) scale: LuScale,
// Inverse of `scale.rperm`: scale_rperm_inv[scale.rperm[i]] == i.
pub(super) scale_rperm_inv: Vec<usize>,
// Work buffers allocated with length m by `factor` (zeros / false).
// NOTE(review): their users are outside this part of the file.
pub(super) scratch: Vec<f64>,
pub(super) scratch_mark: Vec<bool>,
pub(super) ft_work: Vec<f64>,
pub(super) scratch_b: Vec<f64>,
pub(super) scratch_c: Vec<f64>,
pub(super) scratch_d: Vec<f64>,
// Reusable containers that start out empty after `factor`.
// NOTE(review): presumably recycled by the update routines to avoid reallocation - confirm.
pub(super) pivot_scratch: Vec<(usize, f64)>,
pub(super) targets_scratch: Vec<usize>,
pub(super) row_pool: Vec<Vec<(usize, f64)>>,
pub(super) col_rows_pool: Vec<Vec<usize>>,
pub(super) saved_scratch: Vec<(usize, Vec<(usize, f64)>)>,
pub(super) saved_pool: Vec<Vec<(usize, f64)>>,
}
impl SparseLu {
/// Computes a sparse LU factorization of `a`, processing its columns in the
/// order given by `symbolic.qcol`.
///
/// For each column, a sparse triangular solve against the part of `L` built
/// so far produces the entries of `U` above the diagonal. A pivot row is then
/// chosen among the rows that have not been pivoted yet, and the remaining
/// entries divided by the pivot become the next column of `L`.
///
/// When `params.scaling` is not `LuScaling::None`, the matrix is scaled first
/// and the scaled copy is what gets factored.
///
/// # Errors
/// - Any error returned by `params.validate()`, `compute_lu_scale` or
///   `apply_sparse`.
/// - `FeralError::DimensionMismatch` when `symbolic.m` differs from `a.m`.
/// - `FeralError::SingularBasis` when no candidate pivot exceeds
///   `params.zero_pivot_tol * max|a|` and `params.on_singular` is
///   `LuSingularAction::Fail`.
pub fn factor(
a: &SparseColMatrix,
symbolic: &SparseLuSymbolic,
params: LuParams,
) -> Result<Self, FeralError> {
params.validate()?;
let m = a.m;
if symbolic.m != m {
return Err(FeralError::DimensionMismatch {
expected: m,
got: symbolic.m,
});
}
let (scale, scaled) = if params.scaling == LuScaling::None {
(LuScale::identity(m), None)
} else {
let scale = compute_lu_scale(a, params.scaling)?;
let mat = scale.apply_sparse(a)?;
(scale, Some(mat))
};
// From here on `a` refers to the scaled copy when one was produced.
let a: &SparseColMatrix = scaled.as_ref().unwrap_or(a);
let qcol = symbolic.qcol.clone();
let qcol_inv = symbolic.qcol_inv.clone();
// Dense accumulator for the current column. `mark`/`touched` record which
// entries were written so that only those are reset after each column.
let mut w = vec![0.0_f64; m];
let mut mark = vec![false; m];
let mut touched: Vec<usize> = Vec::new();
// pinv[row] is the step at which `row` was pivoted, or -1 while unpivoted.
let mut pinv: Vec<isize> = vec![-1; m];
// perm[k] is the row pivoted at step k.
let mut perm = vec![0usize; m];
// L in compressed-column form; row indices are original rows until the
// remap after the main loop.
let mut l_col_ptr = Vec::with_capacity(m + 1);
let mut l_row_idx: Vec<usize> = Vec::new(); let mut l_val: Vec<f64> = Vec::new();
l_col_ptr.push(0);
// Off-diagonal part of U in compressed-column form, indexed by pivot step;
// the diagonal is kept separately in `u_diag`.
let mut u_col_ptr = Vec::with_capacity(m + 1);
let mut u_row_idx: Vec<usize> = Vec::new();
let mut u_val: Vec<f64> = Vec::new();
u_col_ptr.push(0);
let mut u_diag = vec![0.0_f64; m];
// Work storage for the reach computation (pivot steps whose L column must
// be applied to the current column).
let mut reach_mark = vec![false; m];
let mut reach: Vec<usize> = Vec::new();
let mut dfs_stack: Vec<usize> = Vec::new();
// `utol` is the relative threshold for accepting the preferred row;
// `ztol` is the absolute zero-pivot tolerance, scaled by the largest
// magnitude in the (possibly scaled) matrix.
let utol = params.pivot_threshold;
let a_max = a.values.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
let ztol = params.zero_pivot_tol * a_max;
let mut reach_visits = 0usize;
for k in 0..m {
// Scatter column qcol[k] into the dense accumulator.
let (rows, vals) = a.column(qcol[k]);
for (&i, &v) in rows.iter().zip(vals.iter()) {
w[i] = v;
if !mark[i] {
mark[i] = true;
touched.push(i);
}
}
// Seed the reach with the pivot steps of already-pivoted rows that appear
// in this column.
reach.clear();
for &i in rows.iter() {
let pp = pinv[i];
if pp >= 0 && !reach_mark[pp as usize] {
reach_mark[pp as usize] = true;
reach.push(pp as usize);
dfs_stack.push(pp as usize);
}
}
// Follow the structure of L: applying column p can create entries in rows
// that were pivoted later, whose columns must then be applied as well.
while let Some(p) = dfs_stack.pop() {
let (ls, le) = (l_col_ptr[p], l_col_ptr[p + 1]);
for idx in ls..le {
let pp = pinv[l_row_idx[idx]];
if pp >= 0 && !reach_mark[pp as usize] {
reach_mark[pp as usize] = true;
reach.push(pp as usize);
dfs_stack.push(pp as usize);
}
}
}
// Ascending pivot order is a valid elimination order because column p of L
// only holds rows that were still unpivoted at step p.
reach.sort_unstable();
reach_visits += reach.len();
// Sparse triangular solve: each nonzero value at a pivoted row becomes an
// entry of U and is eliminated from the rows below using L's column.
let mut u_entries: Vec<(usize, f64)> = Vec::with_capacity(reach.len());
for &p in reach.iter() {
reach_mark[p] = false; let r = perm[p];
let xp = w[r];
if xp == 0.0 {
continue;
}
u_entries.push((p, xp));
let (ls, le) = (l_col_ptr[p], l_col_ptr[p + 1]);
for idx in ls..le {
let i = l_row_idx[idx];
w[i] -= xp * l_val[idx];
if !mark[i] {
mark[i] = true;
touched.push(i);
}
}
}
// Find the largest-magnitude entry among the unpivoted rows that were touched.
let mut amax = 0.0_f64;
let mut ipiv: isize = -1;
for &i in touched.iter() {
if pinv[i] < 0 {
let av = w[i].abs();
if av > amax {
amax = av;
ipiv = i as isize;
}
}
}
let mut piv;
let pivot_row: usize;
if amax <= ztol {
// No candidate exceeds the zero tolerance: fail, or substitute a pivot of
// magnitude at least `abs_floor` that keeps the sign of the original value.
match params.on_singular {
LuSingularAction::Fail => {
return Err(FeralError::SingularBasis { column: qcol[k] });
}
LuSingularAction::PerturbToEps { abs_floor } => {
// Use the largest candidate if there is one, otherwise the first
// unpivoted row.
let r = if ipiv >= 0 {
ipiv as usize
} else {
(0..m)
.find(|&i| pinv[i] < 0)
.ok_or(FeralError::SingularBasis { column: qcol[k] })?
};
pivot_row = r;
let s = if w[r] < 0.0 { -1.0 } else { 1.0 };
piv = s * abs_floor.max(w[r].abs());
}
}
} else {
// Prefer the row with the same index as the column being eliminated when
// it is unpivoted, within `utol` of the largest candidate and above the
// zero tolerance; otherwise take the largest candidate.
let diag = qcol[k];
pivot_row =
if pinv[diag] < 0 && w[diag].abs() >= utol * amax && w[diag].abs() > ztol {
diag
} else {
ipiv as usize
};
piv = w[pivot_row];
// NOTE(review): both candidates above already exceed `ztol` on this
// branch, so this looks like a defensive re-check - confirm.
if piv.abs() <= ztol {
match params.on_singular {
LuSingularAction::Fail => {
return Err(FeralError::SingularBasis { column: qcol[k] });
}
LuSingularAction::PerturbToEps { abs_floor } => {
let s = if piv < 0.0 { -1.0 } else { 1.0 };
piv = s * abs_floor.max(piv.abs());
}
}
}
}
// Record the pivot and append this step's column of U.
pinv[pivot_row] = k as isize;
perm[k] = pivot_row;
u_diag[k] = piv;
for (p, v) in u_entries.into_iter() {
u_row_idx.push(p);
u_val.push(v);
}
u_col_ptr.push(u_row_idx.len());
// Column k of L: remaining unpivoted nonzeros divided by the pivot. The
// pivot row itself is excluded because its `pinv` entry was just set.
let inv = 1.0 / piv;
for &i in touched.iter() {
if pinv[i] < 0 && w[i] != 0.0 {
l_row_idx.push(i); l_val.push(w[i] * inv);
}
}
l_col_ptr.push(l_row_idx.len());
// Reset only the entries that were written for this column.
for &i in touched.iter() {
w[i] = 0.0;
mark[i] = false;
}
touched.clear();
}
// Each of the m steps assigned a distinct unpivoted row, so no -1 is left in
// `pinv` and the casts below are lossless.
let perm_inv: Vec<usize> = pinv.iter().map(|&p| p as usize).collect();
// Rewrite L's row indices as pivot positions and sort every column by row.
remap_and_sort_l(&l_col_ptr, &mut l_row_idx, &mut l_val, &perm_inv, m);
// Convert the off-diagonal part of U to row form and prepend the diagonal
// entry to each row.
let (u_row_ptr, u_col_idx, u_val) = transpose_to_csr(&u_col_ptr, &u_row_idx, &u_val, m);
let mut u_rows: Vec<Vec<(usize, f64)>> = Vec::with_capacity(m);
for i in 0..m {
let mut row = Vec::with_capacity(1 + (u_row_ptr[i + 1] - u_row_ptr[i]));
row.push((i, u_diag[i])); for idx in u_row_ptr[i]..u_row_ptr[i + 1] {
row.push((u_col_idx[idx], u_val[idx]));
}
u_rows.push(row);
}
// Column-wise index of U's off-diagonal structure: u_above[c] lists the rows
// i < c that have an entry in column c.
let mut u_above: Vec<Vec<usize>> = vec![Vec::new(); m];
for (i, row) in u_rows.iter().enumerate() {
for &(c, _) in row.iter() {
if c > i {
u_above[c].push(i);
}
}
}
// Invert the row permutation stored in the scaling data.
let mut scale_rperm_inv = vec![0usize; m];
for (i, &o) in scale.rperm.iter().enumerate() {
scale_rperm_inv[o] = i;
}
// Largest magnitude in U, floored so that it is never exactly zero.
let mut u_max0 = 0.0_f64;
for row in u_rows.iter() {
for &(_, v) in row.iter() {
u_max0 = u_max0.max(v.abs());
}
}
let u_max0 = u_max0.max(f64::MIN_POSITIVE);
Ok(SparseLu {
m,
l_col_ptr,
l_row_idx,
l_val,
u_rows,
perm,
perm_inv,
qcol,
qcol_inv,
u_above,
etas: Vec::new(),
growth: 1.0,
u_max0,
reach_visits,
params,
scale,
scale_rperm_inv,
scratch: vec![0.0; m],
scratch_mark: vec![false; m],
ft_work: vec![0.0; m],
scratch_b: vec![0.0; m],
scratch_c: vec![0.0; m],
scratch_d: vec![0.0; m],
pivot_scratch: Vec::new(),
targets_scratch: Vec::new(),
row_pool: Vec::new(),
col_rows_pool: Vec::new(),
saved_scratch: Vec::new(),
saved_pool: Vec::new(),
})
}
pub fn factor_dense_columns(
m: usize,
cols: &[Vec<f64>],
params: LuParams,
) -> Result<Self, FeralError> {
let a = SparseColMatrix::from_dense_columns(m, cols)?;
let symbolic = SparseLuSymbolic::analyze(&a)?;
SparseLu::factor(&a, &symbolic, params)
}
/// Returns the dimension `m` of the factored matrix.
#[inline]
pub fn dim(&self) -> usize {
self.m
}
/// Returns the row permutation; as set by `factor`, entry `k` is the row
/// chosen as pivot at step `k`.
#[inline]
pub fn perm(&self) -> &[usize] {
    self.perm.as_slice()
}
/// Returns the column ordering; as set by `factor`, entry `k` is the matrix
/// column eliminated at step `k`.
#[inline]
pub fn qcol(&self) -> &[usize] {
    self.qcol.as_slice()
}
/// Returns the number of stored factor entries: the stored entries of `L`
/// plus every entry held in the rows of `U`.
pub fn factor_nnz(&self) -> usize {
    let l_entries = self.l_val.len();
    let u_entries = self.u_rows.iter().map(Vec::len).sum::<usize>();
    l_entries + u_entries
}
/// Returns the total number of operations across all stored eta records.
pub fn eta_ops(&self) -> usize {
    self.etas
        .iter()
        .fold(0, |total, eta| total + eta.ops.len())
}
/// Returns the number of operations in the most recently stored eta record,
/// or 0 when there are none.
pub fn last_eta_ops(&self) -> usize {
    match self.etas.last() {
        Some(eta) => eta.ops.len(),
        None => 0,
    }
}
/// Returns the stored `reach_visits` counter; `factor` sets it to the sum of
/// the reach-set sizes over all columns it processed.
pub fn reach_visits(&self) -> usize {
self.reach_visits
}
/// Looks up entry `(i, j)` of `L` as if it were dense.
///
/// The diagonal is implicitly 1.0. Otherwise column `j` is scanned linearly
/// for row `i`, and 0.0 is returned when no entry is stored.
///
/// # Panics
/// Panics if `i != j` and `j + 1` is outside `l_col_ptr`.
pub fn l_dense(&self, i: usize, j: usize) -> f64 {
    if i == j {
        return 1.0;
    }
    let start = self.l_col_ptr[j];
    let end = self.l_col_ptr[j + 1];
    (start..end)
        .find(|&pos| self.l_row_idx[pos] == i)
        .map_or(0.0, |pos| self.l_val[pos])
}
/// Looks up entry `(i, j)` of `U` as if it were dense.
///
/// Entries below the diagonal (`i > j`) are 0.0. Otherwise row `i` is scanned
/// linearly for column `j`, and 0.0 is returned when no entry is stored.
///
/// # Panics
/// Panics if `i <= j` and `i` is not a valid row of `u_rows`.
pub fn u_dense(&self, i: usize, j: usize) -> f64 {
    if i > j {
        return 0.0;
    }
    self.u_rows[i]
        .iter()
        .find(|&&(col, _)| col == j)
        .map_or(0.0, |&(_, value)| value)
}
}
/// Converts an `m`-column compressed-column matrix with `m` rows into
/// compressed-row form.
///
/// Returns `(row_ptr, col_idx, values)`. Within each output row the entries
/// appear in increasing column order, because the input columns are visited
/// from first to last.
///
/// # Panics
/// Panics if a row index is `>= m` or if `col_ptr` has fewer than `m + 1`
/// entries.
fn transpose_to_csr(
    col_ptr: &[usize],
    row_idx: &[usize],
    val: &[f64],
    m: usize,
) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
    // Count the entries of row r in slot r + 1, then turn the counts into
    // start offsets with a running sum.
    let mut row_ptr = vec![0usize; m + 1];
    for &row in row_idx {
        row_ptr[row + 1] += 1;
    }
    for row in 0..m {
        row_ptr[row + 1] += row_ptr[row];
    }

    let total = val.len();
    let mut col_idx = vec![0usize; total];
    let mut out_val = vec![0.0; total];
    // Next free slot for each row, advanced as entries are scattered.
    let mut cursor: Vec<usize> = row_ptr[..m].to_vec();
    for col in 0..m {
        let (start, end) = (col_ptr[col], col_ptr[col + 1]);
        for src in start..end {
            let slot = &mut cursor[row_idx[src]];
            col_idx[*slot] = col;
            out_val[*slot] = val[src];
            *slot += 1;
        }
    }
    (row_ptr, col_idx, out_val)
}
/// Rewrites every row index through `perm_inv` and then sorts the entries of
/// each of the `m` columns by their new row index, keeping values paired with
/// their rows.
///
/// The sort is stable, so entries with equal row indices keep their relative
/// order.
///
/// # Panics
/// Panics if a row index is outside `perm_inv`, or if `col_ptr` does not
/// describe valid ranges into `row_idx` and `val`.
fn remap_and_sort_l(
    col_ptr: &[usize],
    row_idx: &mut [usize],
    val: &mut [f64],
    perm_inv: &[usize],
    m: usize,
) {
    row_idx.iter_mut().for_each(|row| *row = perm_inv[*row]);

    // One pair buffer reused across columns.
    let mut entries: Vec<(usize, f64)> = Vec::new();
    for col in 0..m {
        let (start, end) = (col_ptr[col], col_ptr[col + 1]);
        entries.clear();
        entries.extend(
            row_idx[start..end]
                .iter()
                .copied()
                .zip(val[start..end].iter().copied()),
        );
        entries.sort_by_key(|&(row, _)| row);
        for (offset, &(row, value)) in entries.iter().enumerate() {
            row_idx[start + offset] = row;
            val[start + offset] = value;
        }
    }
}