use std::sync::Arc;
use crate::algebra::scalar::KrystScalar;
use crate::error::KError;
use crate::matrix::convert::csr_from_linop;
use crate::matrix::dist::LocalSquareCsr;
use crate::matrix::format::OpFormat;
use crate::matrix::op::{LinOp, StructureId, ValuesId};
use crate::matrix::sparse::CsrMatrix;
use crate::preconditioner::{
LocalPreconditioner, Op, PcCaps, PcDistributedSupport, PcSide, Preconditioner,
};
use crate::utils::conditioning::ConditioningOptions;
use crate::utils::permutation::Permutation;
#[cfg(feature = "complex")]
use crate::utils::permutation::{
amd_csr, permute_csr_nonsymmetric, permute_csr_symmetric, rcm_csr,
};
use crate::utils::preconditioning_pipeline::{
PreconditioningMetadata, apply_preconditioning_pipeline,
};
#[cfg(feature = "complex")]
use crate::algebra::bridge::BridgeScratch;
#[cfg(feature = "complex")]
use crate::algebra::scalar::S;
#[cfg(feature = "complex")]
use crate::ops::kpc::KPreconditioner;
use once_cell::sync::OnceCell;
type Real = f64;
mod csr_builder;
mod ilut_params;
mod pivot;
mod pos_map;
mod row_work;
mod tri_solve;
pub use ilut_params::{IlutParams, PivotPolicy, Pivoting};
pub use pivot::PivotStrategy;
use csr_builder::CsrBuilder;
use row_work::RowWork;
/// Epoch-stamped lookup from a column index to its storage position in one
/// row of U.
///
/// `prime` stamps every column of a row with a fresh epoch; `get` answers only
/// for columns carrying the current stamp, so moving to another row needs no
/// O(n) clearing pass.
#[derive(Clone, Debug)]
struct URowMap {
    /// Stamp of the row most recently primed; `0` means "nothing primed yet".
    epoch: usize,
    /// `mark[c]` is the epoch at which column `c` was last stamped.
    mark: Vec<usize>,
    /// `pos[c]` is the index into the U column/value arrays recorded for `c`.
    pos: Vec<usize>,
}

impl URowMap {
    /// Creates an empty map; call [`URowMap::ensure_size`] before priming.
    fn new() -> Self {
        Self {
            epoch: 0,
            mark: Vec::new(),
            pos: Vec::new(),
        }
    }

    /// Grows the tables so that columns `0..n` can be stamped. Never shrinks.
    fn ensure_size(&mut self, n: usize) {
        if self.mark.len() < n {
            self.mark.resize(n, 0);
            self.pos.resize(n, 0);
        }
    }

    /// Makes row `i` of the CSR structure (`u_row`, `u_col`) the current row.
    ///
    /// After this call `get(c)` returns the index into `u_col` of column `c`
    /// for every column stored in row `i`, and `None` for all others.
    fn prime(&mut self, u_row: &[usize], u_col: &[usize], i: usize) {
        self.epoch = self.epoch.wrapping_add(1);
        if self.epoch == 0 {
            // The counter wrapped: old stamps could alias the new epoch, so
            // wipe them and restart from 1 (0 is reserved for "unprimed").
            self.mark.iter_mut().for_each(|m| *m = 0);
            self.epoch = 1;
        }
        let rs = u_row[i];
        let re = u_row[i + 1];
        for (offset, &col) in u_col[rs..re].iter().enumerate() {
            if col >= self.mark.len() {
                // Tolerate callers that under-sized the map.
                self.ensure_size(col + 1);
            }
            self.mark[col] = self.epoch;
            self.pos[col] = rs + offset;
        }
    }

    /// Returns the stored position of column `j` in the current row, or
    /// `None` when `j` is not in that row, is out of range, or no row has
    /// been primed yet.
    #[inline]
    fn get(&self, j: usize) -> Option<usize> {
        if self.epoch == 0 {
            return None;
        }
        match self.mark.get(j) {
            Some(&stamp) if stamp == self.epoch => self.pos.get(j).copied(),
            _ => None,
        }
    }
}
mod symbolic;
/// Selects which incomplete-LU variant `IluCsr` builds.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum IluKind {
/// Zero fill-in ILU: the factors use the pattern of the input matrix.
Ilu0,
/// Like `Ilu0`, but updates that fall outside the pattern are subtracted
/// from the row's diagonal instead of being discarded.
Milu0,
/// Level-of-fill ILU; `k` is the largest fill level that is kept.
Iluk { k: usize },
/// Threshold ILU driven by the given `IlutParams`.
Ilut { params: IlutParams },
}
/// User-facing configuration for `IluCsr`.
#[derive(Clone, Debug)]
pub struct IluCsrConfig {
/// Factorization variant to build.
pub kind: IluKind,
/// Pivot strategy handed to the pivot handler for every diagonal entry.
pub pivot: PivotStrategy,
/// Threshold handed to the pivot handler.
pub pivot_threshold: f64,
/// Perturbation factor handed to the pivot handler.
pub diag_perturb_factor: f64,
/// When true, level-scheduling tables are built after each full factorization.
pub level_sched: bool,
/// When true, a setup with an unchanged structure id only refreshes values.
pub numeric_update_fixed: bool,
/// Logging verbosity — NOTE(review): not read anywhere in this part of the file.
pub logging: usize,
/// Reordering options forwarded to the preconditioning pipeline.
pub reordering: ReorderingOptions,
/// Conditioning options forwarded to the preconditioning pipeline.
pub conditioning: ConditioningOptions,
}
impl Default for IluCsrConfig {
/// ILU(0) with diagonal-perturbation pivoting, fixed-pattern numeric
/// updates enabled, no reordering, and level scheduling turned on exactly
/// when the crate is compiled with the `rayon` feature.
fn default() -> Self {
Self {
kind: IluKind::Ilu0,
pivot: PivotStrategy::DiagonalPerturbation,
pivot_threshold: 1e-12,
diag_perturb_factor: 1e-10,
// `cfg!` is evaluated at compile time.
level_sched: cfg!(feature = "rayon"),
numeric_update_fixed: true,
logging: 0,
reordering: ReorderingOptions::default(),
conditioning: ConditioningOptions::default(),
}
}
}
/// Reports how a complex-scalar build of `IluCsr` produced its factors.
#[cfg(feature = "complex")]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum IluComplexKernelMode {
/// Set by the complex ILU(0)/ILU(k) factorizations, which store complex factor values.
Native,
/// Set when the factors were built from the real part of the matrix only.
DegradedRealProjection,
}
/// Fill-reducing / bandwidth-reducing ordering applied before factorization.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ReorderingKind {
/// Keep the original ordering (identity permutation).
None,
/// Ordering computed by `rcm_csr`.
Rcm,
/// Ordering computed by `amd_csr`.
Amd,
}
/// Options controlling how the matrix is permuted before factorization.
#[derive(Clone, Debug)]
pub struct ReorderingOptions {
/// Which ordering algorithm to use.
pub kind: ReorderingKind,
/// When true the complex setup path applies the same permutation to rows
/// and columns; otherwise separate left/right permutations are used.
pub symmetric: bool,
/// NOTE(review): presumably requests a run-to-run reproducible ordering;
/// not read in this part of the file — confirm in the pipeline code.
pub deterministic: bool,
}
impl Default for ReorderingOptions {
/// No reordering; `symmetric` and `deterministic` both default to true.
fn default() -> Self {
Self {
kind: ReorderingKind::None,
symmetric: true,
deterministic: true,
}
}
}
/// Incomplete-LU preconditioner whose factors are stored in CSR form.
pub struct IluCsr {
/// Active configuration.
pub(crate) cfg: IluCsrConfig,
/// Structure id of the operator used for the last successful setup.
last_sid: Option<StructureId>,
/// Values id of the operator used for the last successful setup.
last_vid: Option<ValuesId>,
/// Dimension of the (square) factored matrix; 0 until a factorization ran.
n: usize,
/// CSR row pointers of L (length `n + 1`).
l_row: Vec<usize>,
/// CSR column indices of L.
l_col: Vec<usize>,
/// Real values of L.
l_val: Vec<Real>,
/// CSR row pointers of U (length `n + 1`).
u_row: Vec<usize>,
/// CSR column indices of U.
u_col: Vec<usize>,
/// Real values of U.
u_val: Vec<Real>,
/// For each row, the index into `u_col`/`u_val` of the diagonal entry.
u_diag_ix: Vec<usize>,
/// Fill level of each stored L entry (filled by ILU(0)/ILU(k)).
l_lev: Vec<usize>,
/// Fill level of each stored U entry (filled by ILU(0)/ILU(k)).
u_lev: Vec<usize>,
/// Lazily initialised (row, col, val) triple — presumably the transpose of
/// L; the initialiser is outside this part of the file.
lt: OnceCell<(Vec<usize>, Vec<usize>, Vec<Real>)>,
/// Lazily initialised (row, col, val) triple — presumably the transpose of
/// U; the initialiser is outside this part of the file.
ut: OnceCell<(Vec<usize>, Vec<usize>, Vec<Real>)>,
/// Dependency level of every row for the forward (L) solve.
levels_fwd: Vec<usize>,
/// Dependency level of every row for the backward (U) solve.
levels_bwd: Vec<usize>,
/// Rows grouped by forward level; `buckets_fwd[l]` lists rows of level `l`.
buckets_fwd: Vec<Vec<usize>>,
/// Rows grouped by backward level, listed in descending row order.
buckets_bwd: Vec<Vec<usize>>,
/// Real scratch vectors of length `n` (sized by `resize_apply_workspace`).
tmp: Vec<Real>,
tmp2: Vec<Real>,
tmp3: Vec<Real>,
/// Left permutation produced by the preconditioning pipeline.
perm: Permutation,
/// Full metadata (permutations etc.) produced by the preconditioning pipeline.
pipeline_meta: PreconditioningMetadata,
/// Complex values of L, sharing the pattern `l_row`/`l_col`.
#[cfg(feature = "complex")]
c_l_val: Vec<S>,
/// Complex values of U, sharing the pattern `u_row`/`u_col`.
#[cfg(feature = "complex")]
c_u_val: Vec<S>,
/// Complex scratch vectors of length `n`.
#[cfg(feature = "complex")]
c_tmp: Vec<S>,
#[cfg(feature = "complex")]
c_y_tmp: Vec<S>,
/// Real scratch vectors of length `n` — presumably for the split
/// real/imaginary parts of input and output in the degraded apply path.
#[cfg(feature = "complex")]
c_xr: Vec<Real>,
#[cfg(feature = "complex")]
c_xi: Vec<Real>,
#[cfg(feature = "complex")]
c_yr: Vec<Real>,
#[cfg(feature = "complex")]
c_yi: Vec<Real>,
/// True when `c_l_val`/`c_u_val` hold the factors that `apply` must use.
#[cfg(feature = "complex")]
native_complex_active: bool,
/// Mode recorded by the most recent complex factorization.
#[cfg(feature = "complex")]
complex_kernel_mode: IluComplexKernelMode,
/// When true, complex setups factor the real projection of the matrix.
#[cfg(feature = "complex")]
complex_force_degraded: bool,
}
impl IluCsr {
/// Builds an unfactored preconditioner: default configuration, dimension 0,
/// empty factor/workspace buffers and identity pipeline metadata of size 0.
pub(crate) fn empty() -> Self {
Self {
cfg: IluCsrConfig::default(),
last_sid: None,
last_vid: None,
n: 0,
l_row: Vec::new(),
l_col: Vec::new(),
l_val: Vec::new(),
u_row: Vec::new(),
u_col: Vec::new(),
u_val: Vec::new(),
u_diag_ix: Vec::new(),
l_lev: Vec::new(),
u_lev: Vec::new(),
lt: OnceCell::new(),
ut: OnceCell::new(),
levels_fwd: Vec::new(),
levels_bwd: Vec::new(),
buckets_fwd: Vec::new(),
buckets_bwd: Vec::new(),
tmp: Vec::new(),
tmp2: Vec::new(),
tmp3: Vec::new(),
perm: Permutation::identity(0),
pipeline_meta: PreconditioningMetadata::identity(0),
#[cfg(feature = "complex")]
c_l_val: Vec::new(),
#[cfg(feature = "complex")]
c_u_val: Vec::new(),
#[cfg(feature = "complex")]
c_tmp: Vec::new(),
#[cfg(feature = "complex")]
c_y_tmp: Vec::new(),
#[cfg(feature = "complex")]
c_xr: Vec::new(),
#[cfg(feature = "complex")]
c_xi: Vec::new(),
#[cfg(feature = "complex")]
c_yr: Vec::new(),
#[cfg(feature = "complex")]
c_yi: Vec::new(),
// Until a complex factorization succeeds the real (degraded) path is the default.
#[cfg(feature = "complex")]
native_complex_active: false,
#[cfg(feature = "complex")]
complex_kernel_mode: IluComplexKernelMode::DegradedRealProjection,
#[cfg(feature = "complex")]
complex_force_degraded: false,
}
}
/// Creates an unfactored preconditioner that will use `cfg` at its next setup.
pub fn new_with_config(cfg: IluCsrConfig) -> Self {
    let mut fresh = Self::empty();
    // Swap the default configuration for the caller's one.
    drop(std::mem::replace(&mut fresh.cfg, cfg));
    fresh
}
/// Drops all level-scheduling data (per-row levels and per-level buckets)
/// for both the forward and the backward solve.
fn clear_levels(&mut self) {
    // Forward (L) tables.
    self.buckets_fwd.clear();
    self.levels_fwd.clear();
    // Backward (U) tables.
    self.buckets_bwd.clear();
    self.levels_bwd.clear();
}
/// Rebuilds the level-scheduling tables from the current L and U patterns,
/// or clears them when `cfg.level_sched` is off.
///
/// A row's forward level is one more than the largest level among the rows
/// it depends on through strictly-lower entries of L (0 if there are none);
/// backward levels are defined the same way through strictly-upper entries
/// of U. `buckets_*[l]` lists the rows of level `l`.
fn build_levels_if_enabled(&mut self) {
    if !self.cfg.level_sched {
        self.clear_levels();
        return;
    }
    let n = self.n;

    // Forward pass. Only strictly-lower columns are dependencies: the ILUT
    // path stores a unit diagonal in L, and counting it would make a row read
    // its own, not-yet-computed (stale) level.
    self.levels_fwd.resize(n, 0);
    for i in 0..n {
        let deps = &self.l_col[self.l_row[i]..self.l_row[i + 1]];
        let lv = deps
            .iter()
            .filter(|&&j| j < i)
            .map(|&j| self.levels_fwd[j] + 1)
            .max()
            .unwrap_or(0);
        self.levels_fwd[i] = lv;
    }
    let max_lv_fwd = self.levels_fwd.iter().copied().max().unwrap_or(0);
    self.buckets_fwd.clear();
    self.buckets_fwd.resize(max_lv_fwd + 1, Vec::new());
    for (i, &lv) in self.levels_fwd.iter().enumerate() {
        self.buckets_fwd[lv].push(i);
    }

    // Backward pass, rows visited from last to first so that every
    // dependency (a later row) already has its level.
    self.levels_bwd.resize(n, 0);
    for i in (0..n).rev() {
        let deps = &self.u_col[self.u_row[i]..self.u_row[i + 1]];
        let lv = deps
            .iter()
            .filter(|&&j| j > i)
            .map(|&j| self.levels_bwd[j] + 1)
            .max()
            .unwrap_or(0);
        self.levels_bwd[i] = lv;
    }
    let max_lv_bwd = self.levels_bwd.iter().copied().max().unwrap_or(0);
    self.buckets_bwd.clear();
    self.buckets_bwd.resize(max_lv_bwd + 1, Vec::new());
    for (i, &lv) in self.levels_bwd.iter().enumerate().rev() {
        self.buckets_bwd[lv].push(i);
    }
}
/// Runs a full (pattern + values) factorization of `a` with the variant
/// selected by `cfg.kind`.
///
/// # Errors
/// Propagates whatever the selected factorization returns.
fn factor_symbolic_and_numeric(&mut self, a: &CsrMatrix<f64>) -> Result<(), KError> {
    match self.cfg.kind {
        IluKind::Ilu0 | IluKind::Milu0 => self.factor_ilu0(a),
        IluKind::Iluk { k } => self.factor_iluk(a, k),
        // `params` is bound by value (the enum is `Copy`); pass it by reference.
        IluKind::Ilut { params } => self.factor_ilut(a, &params),
    }
}
/// Refreshes the factor values for `a` while keeping the existing pattern,
/// dispatching on the configured ILU variant.
fn factor_numeric_only(&mut self, a: &CsrMatrix<f64>) -> Result<(), KError> {
    let kind = self.cfg.kind;
    match kind {
        IluKind::Iluk { k } => self.factor_iluk_numeric_only(a, k),
        IluKind::Ilut { .. } => self.factor_ilut_numeric_only(a),
        IluKind::Ilu0 | IluKind::Milu0 => self.factor_ilu0_numeric_only(a),
    }
}
/// Complex ILU(0)/MILU(0): builds the pattern through the real routine and
/// then computes complex factor values into `c_l_val` / `c_u_val`.
///
/// On success the native complex kernel is marked active.
///
/// # Errors
/// Propagates errors from the pattern-building pass and returns
/// `KError::ZeroPivot(i)` when the pivot handler rejects row `i`.
#[cfg(feature = "complex")]
fn factor_ilu0_complex(&mut self, a: &CsrMatrix<S>) -> Result<(), KError> {
    // The real routine is used only for its pattern, so it is fed zeros.
    // NOTE(review): that pass also runs the real numeric kernel on an
    // all-zero matrix; whether it succeeds depends on how the pivot handler
    // treats zero pivots — confirm for non-perturbing strategies.
    let a_zero = CsrMatrix::from_csr(
        a.nrows(),
        a.ncols(),
        a.row_ptr().to_vec(),
        a.col_idx().to_vec(),
        vec![0.0; a.values().len()],
    );
    self.factor_ilu0(&a_zero)?;

    // `resize` alone keeps values from a previous factorization; start from
    // zero so pattern entries absent from `a` do not carry stale data.
    self.c_l_val.clear();
    self.c_l_val.resize(self.l_val.len(), S::zero());
    self.c_u_val.clear();
    self.c_u_val.resize(self.u_val.len(), S::zero());

    let n = a.nrows();
    let rp = a.row_ptr();
    let cj = a.col_idx();
    let vv = a.values();
    let mut map = URowMap::new();
    map.ensure_size(n);
    let milu = matches!(self.cfg.kind, IluKind::Milu0);

    // Largest diagonal magnitude of the input, used to scale pivot fixes.
    let mut max_diag_abs = 0.0f64;
    for i in 0..n {
        let mut di = S::zero();
        for p in rp[i]..rp[i + 1] {
            if cj[p] == i {
                di = vv[p];
                break;
            }
        }
        max_diag_abs = max_diag_abs.max(di.abs());
    }

    for i in 0..n {
        map.prime(&self.u_row, &self.u_col, i);
        let ls = self.l_row[i];
        let le = self.l_row[i + 1];
        let di_pos = self.u_diag_ix[i];

        // Scatter row i of A into the L and U slots of the pattern.
        for p in rp[i]..rp[i + 1] {
            let j = cj[p];
            let val = vv[p];
            if j < i {
                if let Ok(off) = self.l_col[ls..le].binary_search(&j) {
                    self.c_l_val[ls + off] = val;
                }
            } else if let Some(pos) = map.get(j) {
                self.c_u_val[pos] = val;
            }
        }

        // Eliminate with every earlier pivot row j, in ascending column order.
        for p in ls..le {
            let j = self.l_col[p];
            let wij = self.c_l_val[p];
            if wij == S::zero() {
                continue;
            }
            let djj = self.c_u_val[self.u_diag_ix[j]];
            let mult = wij / djj;
            self.c_l_val[p] = mult;
            for q in self.u_row[j]..self.u_row[j + 1] {
                let k = self.u_col[q];
                if k <= j {
                    continue;
                }
                let delta = mult * self.c_u_val[q];
                if k < i {
                    // Target is a not-yet-eliminated L(i,k); the L row is
                    // sorted, so it can only sit after position p.
                    if let Ok(off) = self.l_col[p + 1..le].binary_search(&k) {
                        self.c_l_val[p + 1 + off] -= delta;
                    } else if milu {
                        self.c_u_val[di_pos] -= delta;
                    }
                } else if let Some(pos_ik) = map.get(k) {
                    self.c_u_val[pos_ik] -= delta;
                } else if milu {
                    // Outside the pattern: MILU folds it into the diagonal.
                    self.c_u_val[di_pos] -= delta;
                }
            }
        }

        let fixed = pivot::handle_pivot_scalar(
            self.c_u_val[di_pos],
            self.cfg.pivot,
            self.cfg.pivot_threshold,
            self.cfg.diag_perturb_factor,
            max_diag_abs,
        )
        .map_err(|_| KError::ZeroPivot(i))?;
        self.c_u_val[di_pos] = fixed;
    }
    self.native_complex_active = true;
    self.complex_kernel_mode = IluComplexKernelMode::Native;
    Ok(())
}
/// Complex ILU(k): builds the pattern through the real ILU(k) routine and
/// then computes complex factor values into `c_l_val` / `c_u_val`.
///
/// On success the native complex kernel is marked active.
///
/// # Errors
/// Propagates errors from the pattern-building pass and returns
/// `KError::ZeroPivot(i)` when the pivot handler rejects row `i`.
#[cfg(feature = "complex")]
fn factor_iluk_complex(&mut self, a: &CsrMatrix<S>, k: usize) -> Result<(), KError> {
    // NOTE(review): the pattern pass is fed an all-zero matrix. The real
    // ILU(k) routine skips zero-valued lower entries, so this may yield no
    // fill beyond the pattern of `a` — confirm whether that is intended.
    let a_zero = CsrMatrix::from_csr(
        a.nrows(),
        a.ncols(),
        a.row_ptr().to_vec(),
        a.col_idx().to_vec(),
        vec![0.0; a.values().len()],
    );
    self.factor_iluk(&a_zero, k)?;

    // `resize` alone keeps values from a previous factorization; fill
    // positions are never scattered from `a`, so they must start at zero.
    self.c_l_val.clear();
    self.c_l_val.resize(self.l_val.len(), S::zero());
    self.c_u_val.clear();
    self.c_u_val.resize(self.u_val.len(), S::zero());

    let n = a.nrows();
    let rp = a.row_ptr();
    let cj = a.col_idx();
    let vv = a.values();
    let mut map = URowMap::new();
    map.ensure_size(n);

    // Largest diagonal magnitude of the input, used to scale pivot fixes.
    let mut max_diag_abs = 0.0f64;
    for i in 0..n {
        let mut di = S::zero();
        for p in rp[i]..rp[i + 1] {
            if cj[p] == i {
                di = vv[p];
                break;
            }
        }
        max_diag_abs = max_diag_abs.max(di.abs());
    }

    for i in 0..n {
        map.prime(&self.u_row, &self.u_col, i);
        let ls = self.l_row[i];
        let le = self.l_row[i + 1];

        // Scatter row i of A into the L and U slots of the pattern.
        for p in rp[i]..rp[i + 1] {
            let j = cj[p];
            let val = vv[p];
            if j < i {
                if let Ok(off) = self.l_col[ls..le].binary_search(&j) {
                    self.c_l_val[ls + off] = val;
                }
            } else if let Some(pos) = map.get(j) {
                self.c_u_val[pos] = val;
            }
        }

        // Eliminate with every earlier pivot row j, in ascending column order.
        for p in ls..le {
            let j = self.l_col[p];
            let wij = self.c_l_val[p];
            if wij == S::zero() {
                continue;
            }
            let djj = self.c_u_val[self.u_diag_ix[j]];
            let mult = wij / djj;
            self.c_l_val[p] = mult;
            for q in self.u_row[j]..self.u_row[j + 1] {
                let kcol = self.u_col[q];
                if kcol <= j {
                    continue;
                }
                let delta = mult * self.c_u_val[q];
                if kcol < i {
                    // Target is a pending L(i,kcol); the L row is sorted, so
                    // it can only sit after position p.
                    if let Ok(off) = self.l_col[p + 1..le].binary_search(&kcol) {
                        self.c_l_val[p + 1 + off] -= delta;
                    }
                } else if let Some(pos_ik) = map.get(kcol) {
                    self.c_u_val[pos_ik] -= delta;
                }
            }
        }

        let di_pos = self.u_diag_ix[i];
        let fixed = pivot::handle_pivot_scalar(
            self.c_u_val[di_pos],
            self.cfg.pivot,
            self.cfg.pivot_threshold,
            self.cfg.diag_perturb_factor,
            max_diag_abs,
        )
        .map_err(|_| KError::ZeroPivot(i))?;
        self.c_u_val[di_pos] = fixed;
    }
    self.native_complex_active = true;
    self.complex_kernel_mode = IluComplexKernelMode::Native;
    Ok(())
}
/// Complex entry point for ILUT: factors the real part of `a` with the real
/// ILUT routine and leaves the native complex kernel inactive.
///
/// NOTE(review): the recorded mode is `Native` unless degradation is forced,
/// although the factors come from the real projection and
/// `native_complex_active` is false — confirm this is the intended report.
#[cfg(feature = "complex")]
fn factor_ilut_complex(&mut self, a: &CsrMatrix<S>, params: &IlutParams) -> Result<(), KError> {
    // Same sparsity pattern, values replaced by their real parts.
    let projected = CsrMatrix::from_csr(
        a.nrows(),
        a.ncols(),
        a.row_ptr().to_vec(),
        a.col_idx().to_vec(),
        a.values().iter().map(|z| z.real()).collect(),
    );
    self.factor_ilut(&projected, params)?;

    self.native_complex_active = false;
    self.complex_kernel_mode = match self.complex_force_degraded {
        true => IluComplexKernelMode::DegradedRealProjection,
        false => IluComplexKernelMode::Native,
    };
    Ok(())
}
/// Full ILU(0)/MILU(0) factorization: derives the L/U patterns from `a`
/// (adding a diagonal slot to U where `a` has none), zeroes the value and
/// level arrays, then runs the numeric kernel.
///
/// # Errors
/// `KError::InvalidInput` for a non-square matrix; otherwise whatever the
/// numeric kernel returns.
fn factor_ilu0(&mut self, a: &CsrMatrix<f64>) -> Result<(), KError> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(KError::InvalidInput("ILU requires square matrix".into()));
    }
    self.n = n;

    // Reset the pattern storage.
    self.l_col.clear();
    self.u_col.clear();
    self.l_row.clear();
    self.l_row.resize(n + 1, 0);
    self.u_row.clear();
    self.u_row.resize(n + 1, 0);
    self.u_diag_ix.clear();
    self.u_diag_ix.resize(n, 0);

    let row_ptr = a.row_ptr();
    let col_idx = a.col_idx();
    let mut lower: Vec<usize> = Vec::new();
    let mut upper: Vec<usize> = Vec::new();
    for i in 0..n {
        lower.clear();
        upper.clear();
        // Split the row into strictly-lower columns and diagonal-or-upper ones.
        for &j in &col_idx[row_ptr[i]..row_ptr[i + 1]] {
            if j < i {
                lower.push(j);
            } else {
                upper.push(j);
            }
        }
        // U always carries a diagonal slot, even if A does not store one.
        if !upper.contains(&i) {
            upper.push(i);
        }
        lower.sort_unstable();
        upper.sort_unstable();

        self.l_row[i + 1] = self.l_row[i] + lower.len();
        self.u_row[i + 1] = self.u_row[i] + upper.len();
        self.l_col.extend_from_slice(&lower);
        let u_start = self.u_col.len();
        self.u_col.extend_from_slice(&upper);
        let d_rel = upper
            .iter()
            .position(|&c| c == i)
            .expect("diagonal present");
        self.u_diag_ix[i] = u_start + d_rel;
    }

    // Fresh, zeroed value and level arrays matching the new pattern.
    let l_nnz = self.l_col.len();
    let u_nnz = self.u_col.len();
    self.l_val.clear();
    self.l_val.resize(l_nnz, Real::zero());
    self.u_val.clear();
    self.u_val.resize(u_nnz, Real::zero());
    self.l_lev.clear();
    self.l_lev.resize(l_nnz, 0);
    self.u_lev.clear();
    self.u_lev.resize(u_nnz, 0);

    let milu = matches!(self.cfg.kind, IluKind::Milu0);
    self.ilu0_numeric(a, milu)
}
/// Recomputes ILU(0)/MILU(0) values on the existing pattern.
///
/// Falls back to a full factorization when nothing has been factored yet
/// (`n == 0`).
///
/// # Errors
/// `KError::InvalidInput` when `a` is not square or its size differs from
/// the stored dimension.
fn factor_ilu0_numeric_only(&mut self, a: &CsrMatrix<f64>) -> Result<(), KError> {
    if self.n == 0 {
        return self.factor_ilu0(a);
    }
    let rows = a.nrows();
    let shape_ok = rows == self.n && rows == a.ncols();
    if !shape_ok {
        return Err(KError::InvalidInput(
            "ILU0 numeric update: size/shape mismatch".into(),
        ));
    }
    self.ilu0_numeric(a, matches!(self.cfg.kind, IluKind::Milu0))
}
/// Numeric ILU(0) kernel (IKJ ordering) on the pattern already stored in
/// `l_row`/`l_col` and `u_row`/`u_col`.
///
/// For each row `i` the entries of `a` are scattered into the pattern, the
/// row is eliminated against every earlier pivot row in ascending column
/// order, and the resulting diagonal is passed through the pivot handler.
/// With `milu` set, updates that fall outside the pattern are subtracted
/// from the diagonal instead of being dropped.
///
/// # Errors
/// `KError::FactorError` if an earlier pivot is exactly zero,
/// `KError::ZeroPivot(i)` if the pivot handler rejects row `i`.
fn ilu0_numeric(&mut self, a: &CsrMatrix<f64>, milu: bool) -> Result<(), KError> {
    let n = self.n;
    let rp = a.row_ptr();
    let cj = a.col_idx();
    let vv = a.values();
    let mut map = URowMap::new();
    map.ensure_size(n);

    // Largest diagonal magnitude of the input, used to scale pivot fixes.
    let mut max_diag_abs = 0.0f64;
    for i in 0..n {
        let mut di = 0.0;
        for p in rp[i]..rp[i + 1] {
            if cj[p] == i {
                di = vv[p];
                break;
            }
        }
        max_diag_abs = max_diag_abs.max(di.abs());
    }

    for i in 0..n {
        map.prime(&self.u_row, &self.u_col, i);
        let ls = self.l_row[i];
        let le = self.l_row[i + 1];
        let us = self.u_row[i];
        let ue = self.u_row[i + 1];
        let di_pos = self.u_diag_ix[i];

        // On a numeric-only refresh the arrays still hold the previous
        // factors; clear the row so slots absent from `a` (e.g. an added
        // diagonal) do not keep stale values.
        for v in &mut self.l_val[ls..le] {
            *v = Real::zero();
        }
        for v in &mut self.u_val[us..ue] {
            *v = Real::zero();
        }

        // Scatter row i of A into the pattern.
        for p in rp[i]..rp[i + 1] {
            let j = cj[p];
            let val = vv[p];
            if j < i {
                if let Ok(off) = self.l_col[ls..le].binary_search(&j) {
                    self.l_val[ls + off] = Real::from_real(val);
                }
            } else if let Some(pos) = map.get(j) {
                self.u_val[pos] = Real::from_real(val);
            }
        }

        // Eliminate with pivot rows k < i in ascending order.
        for pos in ls..le {
            let k = self.l_col[pos];
            let ukk = self.u_val[self.u_diag_ix[k]];
            if ukk == Real::zero() {
                return Err(KError::FactorError(format!(
                    "zero U(j,j) encountered at row {k}"
                )));
            }
            let mult = self.l_val[pos] / ukk;
            self.l_val[pos] = mult;
            for q in self.u_row[k]..self.u_row[k + 1] {
                let j = self.u_col[q];
                if j <= k {
                    continue;
                }
                let delta = mult * self.u_val[q];
                if j < i {
                    // The target is a not-yet-eliminated L(i,j). The L row is
                    // sorted, so it can only sit after `pos`. Skipping this
                    // update would make later multipliers use raw A values.
                    if let Ok(off) = self.l_col[pos + 1..le].binary_search(&j) {
                        self.l_val[pos + 1 + off] -= delta;
                    } else if milu {
                        self.u_val[di_pos] -= delta;
                    }
                } else if let Some(pos_ij) = map.get(j) {
                    self.u_val[pos_ij] -= delta;
                } else if milu {
                    // Outside the pattern: MILU folds it into the diagonal.
                    self.u_val[di_pos] -= delta;
                }
            }
        }

        let fixed = pivot::handle_pivot(
            self.u_val[di_pos],
            self.cfg.pivot,
            self.cfg.pivot_threshold,
            self.cfg.diag_perturb_factor,
            max_diag_abs,
        )
        .map_err(|_| KError::ZeroPivot(i))?;
        self.u_val[di_pos] = fixed;
    }
    Ok(())
}
/// Full ILU(k) factorization: a row-by-row level-of-fill pass builds the
/// L/U patterns and per-entry levels, then `iluk_numeric_only` computes the
/// final values on that pattern.
///
/// The values stored during the pattern pass are provisional; the numeric
/// pass that follows overwrites every L and U value.
///
/// # Errors
/// `KError::InvalidInput` for a non-square matrix; otherwise whatever the
/// numeric pass returns.
fn factor_iluk(&mut self, a: &CsrMatrix<f64>, k_limit: usize) -> Result<(), KError> {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;
    use symbolic::RowWork;

    let n = a.nrows();
    if n != a.ncols() {
        return Err(KError::InvalidInput("ILUK requires square matrix".into()));
    }
    self.n = n;
    self.l_row.clear();
    self.u_row.clear();
    self.l_col.clear();
    self.u_col.clear();
    self.l_val.clear();
    self.u_val.clear();
    self.l_lev.clear();
    self.u_lev.clear();
    self.u_diag_ix.clear();
    self.l_row.resize(n + 1, 0);
    self.u_row.resize(n + 1, 0);
    self.u_diag_ix.resize(n, 0);

    let rp = a.row_ptr();
    let cj = a.col_idx();
    let vv = a.values();
    let mut w = RowWork {
        mark: Vec::new(),
        idx: Vec::new(),
        val: Vec::new(),
    };
    // Level of each working-row entry, indexed like `w.idx` / `w.val`.
    let mut wlev: Vec<usize> = Vec::new();
    // Lower-triangular entries still to be eliminated, smallest column first,
    // stored as (column, position in the working row).
    let mut pending: BinaryHeap<Reverse<(usize, usize)>> = BinaryHeap::new();
    symbolic::ensure_rowwork(&mut w, n);

    // Largest diagonal magnitude of the input, used to scale pivot fixes.
    let mut max_diag_abs = 0.0f64;
    for i in 0..n {
        let mut di = 0.0;
        for p in rp[i]..rp[i + 1] {
            if cj[p] == i {
                di = vv[p];
                break;
            }
        }
        max_diag_abs = max_diag_abs.max(di.abs());
    }

    for i in 0..n {
        symbolic::ensure_rowwork(&mut w, n);
        wlev.clear();
        pending.clear();

        // Load row i of A; original entries have level 0.
        for p in rp[i]..rp[i + 1] {
            let j = cj[p];
            let pos = symbolic::find_or_insert(&mut w, j);
            if pos == wlev.len() {
                wlev.push(0);
                if j < i {
                    pending.push(Reverse((j, pos)));
                }
            } else {
                wlev[pos] = 0;
            }
            w.val[pos] = Real::from_real(vv[p]);
        }

        // Eliminate lower entries in ascending column order. Fill created
        // left of the diagonal is queued as well; a one-off snapshot of the
        // lower columns would miss it, and with it any fill of level >= 2.
        while let Some(Reverse((j, pos))) = pending.pop() {
            let lij_level = wlev[pos];
            if lij_level > k_limit {
                continue;
            }
            let wij = w.val[pos];
            if wij == Real::zero() {
                continue;
            }
            let lij = wij / self.u_val[self.u_diag_ix[j]];
            for q in self.u_row[j]..self.u_row[j + 1] {
                let kcol = self.u_col[q];
                if kcol <= j {
                    continue;
                }
                let new_level = lij_level + self.u_lev[q] + 1;
                if new_level > k_limit {
                    continue;
                }
                let kpos = symbolic::find_or_insert(&mut w, kcol);
                if kpos == wlev.len() {
                    wlev.push(new_level);
                    if kcol < i {
                        // kcol > j, so it is popped after the current column.
                        pending.push(Reverse((kcol, kpos)));
                    }
                } else if new_level < wlev[kpos] {
                    wlev[kpos] = new_level;
                }
                w.val[kpos] -= lij * self.u_val[q];
            }
        }

        // Split the working row into L (col < i) and U (col >= i) parts.
        let mut l_pairs: Vec<(usize, Real, usize)> = w
            .idx
            .iter()
            .enumerate()
            .filter_map(|(pos, &col)| {
                if col < i && wlev[pos] <= k_limit {
                    Some((col, w.val[pos], wlev[pos]))
                } else {
                    None
                }
            })
            .collect();
        l_pairs.sort_by_key(|x| x.0);
        let mut u_pairs: Vec<(usize, Real, usize)> = w
            .idx
            .iter()
            .enumerate()
            .filter_map(|(pos, &col)| {
                if col >= i && wlev[pos] <= k_limit {
                    Some((col, w.val[pos], wlev[pos]))
                } else {
                    None
                }
            })
            .collect();
        // U always carries a diagonal slot.
        if !u_pairs.iter().any(|(c, _, _)| *c == i) {
            u_pairs.push((i, Real::zero(), 0));
        }
        u_pairs.sort_by_key(|x| x.0);

        self.l_row[i + 1] = self.l_row[i] + l_pairs.len();
        for (c, v, lev) in l_pairs {
            self.l_col.push(c);
            self.l_val.push(v);
            self.l_lev.push(lev);
        }
        let u_start = self.u_col.len();
        self.u_row[i + 1] = self.u_row[i] + u_pairs.len();
        for (c, v, lev) in &u_pairs {
            self.u_col.push(*c);
            self.u_val.push(*v);
            self.u_lev.push(*lev);
        }
        let d_rel = u_pairs.iter().position(|(c, _, _)| *c == i).unwrap();
        self.u_diag_ix[i] = u_start + d_rel;
        symbolic::clear_rowwork(&mut w);
    }
    self.iluk_numeric_only(a, k_limit, max_diag_abs)
}
/// Numeric ILU(k) pass on the pattern already stored in L and U.
///
/// For each row, the entries of `a` are loaded into a working row, the row
/// is eliminated against earlier pivot rows following the stored L pattern,
/// and the results are written back into `l_val` / `u_val`. Updates whose
/// target column is not present in the working row are dropped.
/// `max_diag_abs` is forwarded to the pivot handler; `_k_limit` is unused.
///
/// Returns `KError::ZeroPivot(i)` when the pivot handler rejects row `i`.
fn iluk_numeric_only(
&mut self,
a: &CsrMatrix<f64>,
_k_limit: usize,
max_diag_abs: f64,
) -> Result<(), KError> {
use symbolic::RowWork;
let n = self.n;
let rp = a.row_ptr();
let cj = a.col_idx();
let vv = a.values();
let mut w = RowWork {
mark: Vec::new(),
idx: Vec::new(),
val: Vec::new(),
};
symbolic::ensure_rowwork(&mut w, n);
for i in 0..n {
symbolic::ensure_rowwork(&mut w, n);
// Load row i of A into the working row.
for p in rp[i]..rp[i + 1] {
let j = cj[p];
let pos = symbolic::find_or_insert(&mut w, j);
w.val[pos] = Real::from_real(vv[p]);
}
let ls = self.l_row[i];
let le = self.l_row[i + 1];
for pos in ls..le {
let j = self.l_col[pos];
// A non-negative mark is the column's slot in the working row;
// columns without a slot contribute zero.
let wij = if w.mark[j] >= 0 {
w.val[w.mark[j] as usize]
} else {
Real::zero()
};
let djj = self.u_val[self.u_diag_ix[j]];
// A zero pivot yields a zero multiplier rather than a division by zero.
let lij = if djj == Real::zero() {
Real::zero()
} else {
wij / djj
};
self.l_val[pos] = lij;
let urs = self.u_row[j];
let ure = self.u_row[j + 1];
for q in urs..ure {
let kcol = self.u_col[q];
if kcol <= j {
continue;
}
// Only columns already present in the working row are updated.
let mk = w.mark.get(kcol).copied().unwrap_or(-1);
if mk >= 0 {
w.val[mk as usize] -= lij * self.u_val[q];
}
}
}
// Gather the U part of the working row back into the stored pattern.
let us = self.u_row[i];
let ue = self.u_row[i + 1];
let mut diag = Real::zero();
for q in us..ue {
let k = self.u_col[q];
let v = if w.mark.get(k).copied().unwrap_or(-1) >= 0 {
w.val[w.mark[k] as usize]
} else {
Real::zero()
};
if k == i {
diag = v;
}
self.u_val[q] = v;
}
let fixed = pivot::handle_pivot(
diag,
self.cfg.pivot,
self.cfg.pivot_threshold,
self.cfg.diag_perturb_factor,
max_diag_abs,
)
.map_err(|_| KError::ZeroPivot(i))?;
let dix = self.u_diag_ix[i];
self.u_val[dix] = fixed;
symbolic::clear_rowwork(&mut w);
}
Ok(())
}
fn factor_iluk_numeric_only(
&mut self,
a: &CsrMatrix<f64>,
k_limit: usize,
) -> Result<(), KError> {
let mut max_diag_abs = 0.0f64;
let rp = a.row_ptr();
let cj = a.col_idx();
let vv = a.values();
for i in 0..self.n {
let mut di = 0.0;
for p in rp[i]..rp[i + 1] {
if cj[p] == i {
di = vv[p];
break;
}
}
max_diag_abs = max_diag_abs.max(di.abs());
}
self.iluk_numeric_only(a, k_limit, max_diag_abs)
}
fn factor_ilut(&mut self, a: &CsrMatrix<f64>, params: &IlutParams) -> Result<(), KError> {
let n = a.nrows();
if n != a.ncols() {
return Err(KError::InvalidInput("ILUT requires square matrix".into()));
}
self.n = n;
let mut l_build = CsrBuilder::new(n);
let mut u_build = CsrBuilder::new(n);
let mut inv_diag_u = vec![Real::zero(); n];
let mut w = RowWork::new();
w.ensure_size(n);
let mut l_tmp: Vec<(usize, Real)> = Vec::new();
let mut u_tmp: Vec<(usize, Real)> = Vec::new();
let mut max_diag_abs = 0.0f64;
for i in 0..n {
w.clear_row();
l_tmp.clear();
u_tmp.clear();
let (a_cols, a_vals) = a.row(i);
let mut row_inf: f64 = 0.0;
for (&j, &v) in a_cols.iter().zip(a_vals.iter()) {
if v != 0.0 {
w.set(j, Real::from_real(v));
row_inf = row_inf.max(v.abs());
}
}
let tau = params.droptol_abs + params.droptol_rel * row_inf;
let mut lowers: Vec<usize> = w.iter().filter(|&(j, _)| j < i).map(|(j, _)| j).collect();
lowers.sort_unstable();
for &k in &lowers {
let wk = w.get(k);
if wk == Real::zero() {
continue;
}
let lik = wk * inv_diag_u[k];
if params.early_drop && lik.abs() < tau {
w.set(k, Real::zero());
continue;
}
l_tmp.push((k, lik));
w.set(k, Real::zero());
let (u_cols_k, u_vals_k) = u_build.row(k);
for (&j, &ukj) in u_cols_k.iter().zip(u_vals_k.iter()) {
if j <= k {
continue;
}
let newv: Real = w.get(j) - lik * ukj;
if params.early_drop && newv.abs() < tau {
w.set(j, Real::zero());
} else {
w.set(j, newv);
}
}
}
for (j, v) in w.iter() {
if j >= i && (j == i || v.abs() >= tau) {
u_tmp.push((j, v));
}
}
if !u_tmp.iter().any(|(j, _)| *j == i) {
u_tmp.push((i, Real::zero()));
}
if params.p_l > 0 && l_tmp.len() > params.p_l {
l_tmp.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap());
l_tmp.truncate(params.p_l);
}
let mut diag = Real::zero();
if let Some(pos) = u_tmp.iter().position(|(j, _)| *j == i) {
diag = u_tmp[pos].1;
u_tmp.remove(pos);
}
if params.p_u > 0 && u_tmp.len() > params.p_u {
u_tmp.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap());
u_tmp.truncate(params.p_u);
}
u_tmp.push((i, diag));
if params.reproducible_order {
l_tmp.sort_by(|a, b| a.0.cmp(&b.0));
u_tmp.sort_by(|a, b| a.0.cmp(&b.0));
} else {
l_tmp.sort_unstable_by(|a, b| a.0.cmp(&b.0));
u_tmp.sort_unstable_by(|a, b| a.0.cmp(&b.0));
}
let diag_pos = u_tmp.iter().position(|(j, _)| *j == i).unwrap();
let mut uii = u_tmp[diag_pos].1;
max_diag_abs = max_diag_abs.max(uii.abs());
let tau = params.pivot_tau;
match params.pivot {
PivotPolicy::Strict => {
if uii.abs() < tau {
return Err(KError::ZeroPivot(i));
}
}
PivotPolicy::Threshold => {
if uii.abs() < tau {
if uii == Real::zero() {
uii = Real::from_real(tau);
} else {
uii = uii * Real::from_real(tau / uii.abs());
}
}
}
PivotPolicy::DiagonalPerturbation => {
if uii.abs() < tau {
let direction = if uii == Real::zero() {
Real::one()
} else {
uii / Real::from_real(uii.abs())
};
uii += direction * Real::from_real(tau);
}
}
}
u_tmp[diag_pos].1 = uii;
inv_diag_u[i] = uii.inv();
for &(k, v) in &l_tmp {
l_build.push(i, k, v);
}
l_build.push(i, i, Real::one());
for &(j, v) in &u_tmp {
u_build.push(i, j, v);
}
}
let (l_row, l_col, l_val) = l_build.finalize_sorted_unique(params.reproducible_order);
let (u_row, u_col, u_val) = u_build.finalize_sorted_unique(params.reproducible_order);
self.l_row = l_row;
self.l_col = l_col;
self.l_val = l_val;
self.u_row = u_row;
self.u_col = u_col;
self.u_val = u_val;
self.u_diag_ix.clear();
self.u_diag_ix.resize(n, 0);
for i in 0..n {
let rs = self.u_row[i];
let re = self.u_row[i + 1];
if let Some(pos) = self.u_col[rs..re].iter().position(|&c| c == i) {
self.u_diag_ix[i] = rs + pos;
} else {
return Err(KError::InvalidInput("missing diagonal".into()));
}
}
self.resize_apply_workspace(n);
self.ilut_numeric_only(a, max_diag_abs)
}
/// Numeric pass on the pattern produced by `factor_ilut`.
///
/// For each row, the entries of `a` are loaded into a working row, the row
/// is eliminated against earlier pivot rows following the stored L pattern,
/// and the results are written back into `l_val` / `u_val`. Updates whose
/// target column is not present in the working row are dropped.
/// `max_diag_abs` is forwarded to the pivot handler.
///
/// # Errors
/// `KError::ZeroPivot(i)` when the pivot handler rejects row `i`.
fn ilut_numeric_only(&mut self, a: &CsrMatrix<f64>, max_diag_abs: f64) -> Result<(), KError> {
    use symbolic::RowWork;
    let n = self.n;
    let rp = a.row_ptr();
    let cj = a.col_idx();
    let vv = a.values();
    let mut w = RowWork {
        mark: Vec::new(),
        idx: Vec::new(),
        val: Vec::new(),
    };
    symbolic::ensure_rowwork(&mut w, n);
    for i in 0..n {
        symbolic::ensure_rowwork(&mut w, n);
        // Load row i of A into the working row.
        for p in rp[i]..rp[i + 1] {
            let j = cj[p];
            let pos = symbolic::find_or_insert(&mut w, j);
            w.val[pos] = Real::from_real(vv[p]);
        }
        let ls = self.l_row[i];
        let le = self.l_row[i + 1];
        for pos in ls..le {
            let j = self.l_col[pos];
            if j >= i {
                // `factor_ilut` pushes an explicit unit diagonal into L. Only
                // rows j < i are valid pivot rows; using row i itself would
                // subtract its previous U values from the working row.
                continue;
            }
            // A non-negative mark is the column's slot in the working row.
            let wij = if w.mark[j] >= 0 {
                w.val[w.mark[j] as usize]
            } else {
                Real::zero()
            };
            let djj = self.u_val[self.u_diag_ix[j]];
            // A zero pivot yields a zero multiplier rather than a division by zero.
            let lij = if djj == Real::zero() {
                Real::zero()
            } else {
                wij / djj
            };
            self.l_val[pos] = lij;
            let urs = self.u_row[j];
            let ure = self.u_row[j + 1];
            for q in urs..ure {
                let kcol = self.u_col[q];
                if kcol <= j {
                    continue;
                }
                // Only columns already present in the working row are updated.
                let mk = w.mark.get(kcol).copied().unwrap_or(-1);
                if mk >= 0 {
                    w.val[mk as usize] -= lij * self.u_val[q];
                }
            }
        }
        // Gather the U part of the working row back into the stored pattern.
        let us = self.u_row[i];
        let ue = self.u_row[i + 1];
        let mut diag = Real::zero();
        for q in us..ue {
            let k = self.u_col[q];
            let v = if w.mark.get(k).copied().unwrap_or(-1) >= 0 {
                w.val[w.mark[k] as usize]
            } else {
                Real::zero()
            };
            if k == i {
                diag = v;
            }
            self.u_val[q] = v;
        }
        let fixed = pivot::handle_pivot(
            diag,
            self.cfg.pivot,
            self.cfg.pivot_threshold,
            self.cfg.diag_perturb_factor,
            max_diag_abs,
        )
        .map_err(|_| KError::ZeroPivot(i))?;
        self.u_val[self.u_diag_ix[i]] = fixed;
        symbolic::clear_rowwork(&mut w);
    }
    Ok(())
}
fn factor_ilut_numeric_only(&mut self, a: &CsrMatrix<f64>) -> Result<(), KError> {
let rp = a.row_ptr();
let cj = a.col_idx();
let vv = a.values();
let mut max_diag_abs = 0.0f64;
for i in 0..self.n {
let mut di = 0.0;
for p in rp[i]..rp[i + 1] {
if cj[p] == i {
di = vv[p];
break;
}
}
max_diag_abs = max_diag_abs.max(di.abs());
}
self.ilut_numeric_only(a, max_diag_abs)
}
/// Sets up (or refreshes) the factorization of the local square matrix `a`,
/// identified by structure id `sid` and values id `vid`.
///
/// * Structure changed, or fixed-pattern updates disabled: run the
///   preconditioning pipeline, refactor from scratch, rebuild level tables
///   and workspaces.
/// * Only values changed: rerun the pipeline and refresh the values on the
///   existing pattern.
/// * Nothing changed: return immediately without touching the pipeline.
///
/// # Errors
/// Propagates pipeline and factorization errors. The cached ids are
/// invalidated before any factor data is overwritten, so a failed attempt is
/// never mistaken for an up-to-date factorization by a later call.
fn setup_from_local_square_ids(
    &mut self,
    a: &CsrMatrix<f64>,
    sid: StructureId,
    vid: ValuesId,
) -> Result<(), KError> {
    let structure_changed = self.last_sid != Some(sid);
    let values_changed = self.last_vid != Some(vid);
    let full_refactor = structure_changed || !self.cfg.numeric_update_fixed;
    if !full_refactor && !values_changed {
        // Up to date: skip the (matrix-sized) pipeline work entirely.
        return Ok(());
    }

    let pipeline =
        apply_preconditioning_pipeline(a, &self.cfg.conditioning, &self.cfg.reordering)?;
    let a = &pipeline.matrix;

    if full_refactor {
        self.last_sid = None;
        self.last_vid = None;
        self.perm = pipeline.metadata.left_perm.clone();
        self.pipeline_meta = pipeline.metadata.clone();
        self.factor_symbolic_and_numeric(a)?;
        self.build_levels_if_enabled();
        self.last_sid = Some(sid);
        self.last_vid = Some(vid);
        self.resize_apply_workspace(a.nrows());
    } else {
        self.last_vid = None;
        self.pipeline_meta = pipeline.metadata.clone();
        self.factor_numeric_only(a)?;
        self.last_vid = Some(vid);
    }
    Ok(())
}
/// Sets up the preconditioner from an already-extracted local square block,
/// using the structure/values ids carried by its CSR matrix.
pub fn setup_local_square(&mut self, local: &LocalSquareCsr<f64>) -> Result<(), KError> {
    let csr = local.as_csr();
    let sid = csr.structure_id();
    let vid = csr.values_id();
    self.setup_from_local_square_ids(csr, sid, vid)
}
/// Resizes every scratch vector used when applying the preconditioner to
/// length `n`; newly added slots are zero.
fn resize_apply_workspace(&mut self, n: usize) {
    let real_zero = Real::zero();
    self.tmp.resize(n, real_zero);
    self.tmp2.resize(n, real_zero);
    self.tmp3.resize(n, real_zero);
    #[cfg(feature = "complex")]
    {
        // Complex scratch vectors.
        let complex_zero = S::zero();
        self.c_tmp.resize(n, complex_zero);
        self.c_y_tmp.resize(n, complex_zero);
        // Real scratch vectors that exist only in complex builds.
        self.c_xr.resize(n, real_zero);
        self.c_xi.resize(n, real_zero);
        self.c_yr.resize(n, real_zero);
        self.c_yi.resize(n, real_zero);
    }
}
}
/// Real-scalar `Preconditioner` implementation (builds without the
/// `complex` feature).
#[cfg(not(feature = "complex"))]
impl Preconditioner for IluCsr {
    /// The preconditioner is square: `(n, n)`.
    fn dims(&self) -> (usize, usize) {
        (self.n, self.n)
    }

    /// Converts `op` to CSR, extracts the local square block and factors it,
    /// reusing the existing pattern when only the values changed.
    fn setup(&mut self, op: &dyn LinOp<S = f64>) -> Result<(), KError> {
        let drop = 0.0;
        let a: Arc<CsrMatrix<f64>> = csr_from_linop(op, drop)?;
        let local = LocalSquareCsr::try_from(a.as_ref().clone())?;
        self.setup_from_local_square_ids(local.as_csr(), op.structure_id(), op.values_id())
    }

    /// Applies the preconditioner to `x`, writing into `y`; the side is ignored.
    ///
    /// # Errors
    /// `KError::InvalidInput` when `x` or `y` does not have length `n`.
    fn apply(&self, _side: PcSide, x: &[f64], y: &mut [f64]) -> Result<(), KError> {
        // Route through `apply_op` so the dimension check is shared.
        self.apply_op(Op::NoTrans, x, y)
    }

    /// Only the local block is preconditioned.
    fn distributed_support(&self) -> PcDistributedSupport {
        PcDistributedSupport::LocalOnly
    }

    /// Applies the preconditioner (or the variant selected by `op`) to `x`.
    ///
    /// # Errors
    /// `KError::InvalidInput` when `x` or `y` does not have length `n`.
    fn apply_op(&self, op: Op, x: &[f64], y: &mut [f64]) -> Result<(), KError> {
        if x.len() != self.n || y.len() != self.n {
            return Err(KError::InvalidInput(format!(
                "IluCsr::apply dimension mismatch: n={}, x.len()={}, y.len()={}",
                self.n,
                x.len(),
                y.len()
            )));
        }
        self.apply_op_scalar(op, x, y)
    }

    /// Mutable-receiver variant of `apply`; the side is ignored.
    ///
    /// # Errors
    /// `KError::InvalidInput` when `x` or `y` does not have length `n`.
    fn apply_mut(&mut self, _side: PcSide, x: &[f64], y: &mut [f64]) -> Result<(), KError> {
        if x.len() != self.n || y.len() != self.n {
            return Err(KError::InvalidInput(format!(
                "IluCsr::apply dimension mismatch: n={}, x.len()={}, y.len()={}",
                self.n,
                x.len(),
                y.len()
            )));
        }
        self.apply_op_scalar_mut(Op::NoTrans, x, y)
    }

    /// Numeric-only updates are available when the pattern is kept fixed.
    fn supports_numeric_update(&self) -> bool {
        self.cfg.numeric_update_fixed
    }

    /// Refreshes the factor values for an operator with an unchanged pattern.
    ///
    /// The matrix goes through the same conditioning/reordering pipeline as
    /// in `setup`; factoring the raw matrix on a pattern that was built from
    /// the pipelined one would be inconsistent.
    ///
    /// # Errors
    /// `KError::Unsupported` when fixed-pattern updates are disabled or the
    /// structure id differs from the one last factored.
    fn update_numeric(&mut self, op: &dyn LinOp<S = f64>) -> Result<(), KError> {
        if !self.cfg.numeric_update_fixed {
            return Err(KError::Unsupported("numeric update requires fixed pattern"));
        }
        if Some(op.structure_id()) != self.last_sid {
            return Err(KError::Unsupported("pattern changed; call update_symbolic"));
        }
        let a: Arc<CsrMatrix<f64>> = csr_from_linop(op, 0.0)?;
        let pipeline = apply_preconditioning_pipeline(
            a.as_ref(),
            &self.cfg.conditioning,
            &self.cfg.reordering,
        )?;
        // Invalidate first so a failed refresh is not taken for up to date.
        self.last_vid = None;
        self.pipeline_meta = pipeline.metadata.clone();
        self.factor_numeric_only(&pipeline.matrix)?;
        self.last_vid = Some(op.values_id());
        Ok(())
    }

    /// Rebuilds pattern and values from scratch for `op`, including the
    /// conditioning/reordering pipeline, level tables and apply workspaces.
    fn update_symbolic(&mut self, op: &dyn LinOp<S = f64>) -> Result<(), KError> {
        let a: Arc<CsrMatrix<f64>> = csr_from_linop(op, 0.0)?;
        let pipeline = apply_preconditioning_pipeline(
            a.as_ref(),
            &self.cfg.conditioning,
            &self.cfg.reordering,
        )?;
        // Invalidate first so a failed refactor is not taken for up to date.
        self.last_sid = None;
        self.last_vid = None;
        self.perm = pipeline.metadata.left_perm.clone();
        self.pipeline_meta = pipeline.metadata.clone();
        self.factor_symbolic_and_numeric(&pipeline.matrix)?;
        self.build_levels_if_enabled();
        self.resize_apply_workspace(pipeline.matrix.nrows());
        self.last_sid = Some(op.structure_id());
        self.last_vid = Some(op.values_id());
        Ok(())
    }

    /// The operator must be convertible to CSR.
    fn required_format(&self) -> OpFormat {
        OpFormat::Csr
    }

    /// Transpose application is supported, conjugate transpose is not; the
    /// preconditioner is restricted to the left side.
    fn capabilities(&self) -> PcCaps {
        PcCaps {
            supports_transpose: true,
            supports_conj_trans: false,
            is_spd: false,
            side_restriction: Some(PcSide::Left),
        }
    }
}
#[cfg(feature = "complex")]
impl Preconditioner for IluCsr {
/// The preconditioner is square: `(n, n)`.
fn dims(&self) -> (usize, usize) {
    let n = self.n;
    (n, n)
}
/// Complex setup: permutes the CSR operator as configured and factors it
/// with the selected ILU variant.
///
/// A full factorization runs when the structure id changed or fixed-pattern
/// updates are disabled; a values-only change is delegated to
/// `update_numeric`; otherwise nothing is done.
///
/// # Errors
/// `KError::Unsupported` when `op` is not a `CsrMatrix<S>`; otherwise errors
/// from block extraction, the reordering pipeline or the factorization.
fn setup(&mut self, op: &dyn LinOp<S = S>) -> Result<(), KError> {
    /// Copy of `m` with the same pattern and only the real parts of its values.
    fn real_projection(m: &CsrMatrix<S>) -> CsrMatrix<f64> {
        CsrMatrix::from_csr(
            m.nrows(),
            m.ncols(),
            m.row_ptr().to_vec(),
            m.col_idx().to_vec(),
            m.values().iter().map(|v| v.real()).collect(),
        )
    }

    let csr = op
        .as_any()
        .downcast_ref::<CsrMatrix<S>>()
        .ok_or_else(|| {
            KError::Unsupported(
                "IluCsr complex setup currently requires a CSR operator; non-CSR LinOp paths have no lossless complex ILU fallback".into(),
            )
        })?;
    let local = LocalSquareCsr::try_from(csr.clone())?;
    let csr = local.as_csr();
    let sid = op.structure_id();
    let vid = op.values_id();
    let structure_changed = self.last_sid != Some(sid);
    let values_changed = self.last_vid != Some(vid);

    if structure_changed || !self.cfg.numeric_update_fixed {
        let a_perm = if self.cfg.reordering.symmetric {
            // One permutation for rows and columns, computed on the real
            // projection. It is built only in this branch, because the
            // non-symmetric branch takes its permutations from the pipeline.
            let perm = match self.cfg.reordering.kind {
                ReorderingKind::None => Permutation::identity(csr.nrows()),
                ReorderingKind::Rcm => rcm_csr(&real_projection(csr)),
                ReorderingKind::Amd => amd_csr(&real_projection(csr)),
            };
            self.perm = perm.clone();
            self.pipeline_meta = PreconditioningMetadata::identity(csr.nrows());
            self.pipeline_meta.left_perm = perm.clone();
            self.pipeline_meta.right_perm = perm;
            permute_csr_symmetric(csr, &self.pipeline_meta.left_perm)
        } else {
            // Separate left/right permutations come from the pipeline run on
            // the real projection (default conditioning).
            let a_real = real_projection(csr);
            let pipeline = apply_preconditioning_pipeline(
                &a_real,
                &ConditioningOptions::default(),
                &self.cfg.reordering,
            )?;
            self.perm = pipeline.metadata.left_perm.clone();
            self.pipeline_meta = PreconditioningMetadata::identity(csr.nrows());
            self.pipeline_meta.left_perm = pipeline.metadata.left_perm.clone();
            self.pipeline_meta.right_perm = pipeline.metadata.right_perm;
            permute_csr_nonsymmetric(
                csr,
                &self.pipeline_meta.left_perm,
                &self.pipeline_meta.right_perm,
            )
        };

        // Invalidate first so a failed factorization is not taken for up to date.
        self.last_sid = None;
        self.last_vid = None;

        match self.cfg.kind {
            IluKind::Ilu0 | IluKind::Milu0 => {
                if self.complex_force_degraded {
                    self.factor_ilu0(&real_projection(&a_perm))?;
                    self.native_complex_active = false;
                    self.complex_kernel_mode = IluComplexKernelMode::DegradedRealProjection;
                } else {
                    self.factor_ilu0_complex(&a_perm)?;
                }
            }
            IluKind::Iluk { k } => {
                if self.complex_force_degraded {
                    self.factor_iluk(&real_projection(&a_perm), k)?;
                    self.native_complex_active = false;
                    self.complex_kernel_mode = IluComplexKernelMode::DegradedRealProjection;
                } else {
                    self.factor_iluk_complex(&a_perm, k)?;
                }
            }
            // `params` is bound by value (the enum is `Copy`); pass it by reference.
            IluKind::Ilut { params } => self.factor_ilut_complex(&a_perm, &params)?,
        }
        self.build_levels_if_enabled();
        self.last_sid = Some(sid);
        self.last_vid = Some(vid);
        self.resize_apply_workspace(csr.nrows());
        Ok(())
    } else if values_changed {
        self.update_numeric(op)
    } else {
        Ok(())
    }
}
fn apply(&self, _side: PcSide, x: &[S], y: &mut [S]) -> Result<(), KError> {
let n = self.n;
if x.len() != n || y.len() != n {
return Err(KError::InvalidInput(format!(
"IluCsr::apply dimension mismatch: n={}, x.len()={}, y.len()={}",
self.n,
x.len(),
y.len()
)));
}
if self.native_complex_active {
let mut w = vec![S::zero(); n];
let mut y_perm = vec![S::zero(); n];
self.pipeline_meta.left_perm.apply_vec(x, &mut w);
for i in 0..n {
let mut s = w[i];
for p in self.l_row[i]..self.l_row[i + 1] {
s -= self.c_l_val[p] * w[self.l_col[p]];
}
w[i] = s;
}
for i in (0..n).rev() {
let mut s = w[i];
for p in self.u_row[i]..self.u_row[i + 1] {
let j = self.u_col[p];
if j > i {
s -= self.c_u_val[p] * w[j];
}
}
w[i] = s / self.c_u_val[self.u_diag_ix[i]];
}
self.pipeline_meta.right_perm.apply_vec_t(&w, &mut y_perm);
y.copy_from_slice(&y_perm);
Ok(())
} else {
let mut xr = vec![Real::zero(); n];
let mut xi = vec![Real::zero(); n];
let mut yr = vec![Real::zero(); n];
let mut yi = vec![Real::zero(); n];
for i in 0..n {
xr[i] = x[i].real();
xi[i] = x[i].imag();
}
self.apply_op_scalar(Op::NoTrans, &xr, &mut yr)?;
self.apply_op_scalar(Op::NoTrans, &xi, &mut yi)?;
for i in 0..n {
y[i] = S::from_parts(yr[i], yi[i]);
}
Ok(())
}
}
/// Mutable-receiver variant of `apply`; `_side` is ignored and the call is
/// forwarded unchanged to `apply_complex_mut`.
fn apply_mut(&mut self, _side: PcSide, x: &[S], y: &mut [S]) -> Result<(), KError> {
self.apply_complex_mut(x, y)
}
/// Returns the `numeric_update_fixed` configuration flag; `update_numeric`
/// rejects calls when this flag is `false`.
fn supports_numeric_update(&self) -> bool {
self.cfg.numeric_update_fixed
}
fn update_numeric(&mut self, op: &dyn LinOp<S = S>) -> Result<(), KError> {
if !self.cfg.numeric_update_fixed {
return Err(KError::Unsupported("numeric update requires fixed pattern"));
}
if Some(op.structure_id()) != self.last_sid {
return Err(KError::Unsupported("pattern changed; call update_symbolic"));
}
let csr = op.as_any().downcast_ref::<CsrMatrix<S>>().ok_or_else(|| {
KError::Unsupported("IluCsr complex numeric update requires CSR".into())
})?;
let a_perm = if self.cfg.reordering.symmetric {
permute_csr_symmetric(csr, &self.pipeline_meta.left_perm)
} else {
permute_csr_nonsymmetric(
csr,
&self.pipeline_meta.left_perm,
&self.pipeline_meta.right_perm,
)
};
match self.cfg.kind {
IluKind::Ilu0 | IluKind::Milu0 => {
if self.complex_force_degraded {
let a_real = CsrMatrix::from_csr(
a_perm.nrows(),
a_perm.ncols(),
a_perm.row_ptr().to_vec(),
a_perm.col_idx().to_vec(),
a_perm.values().iter().map(|v| v.real()).collect(),
);
self.factor_ilu0(&a_real)?;
self.native_complex_active = false;
self.complex_kernel_mode = IluComplexKernelMode::DegradedRealProjection;
} else {
self.factor_ilu0_complex(&a_perm)?;
}
}
IluKind::Iluk { k } => {
if self.complex_force_degraded {
let a_real = CsrMatrix::from_csr(
a_perm.nrows(),
a_perm.ncols(),
a_perm.row_ptr().to_vec(),
a_perm.col_idx().to_vec(),
a_perm.values().iter().map(|v| v.real()).collect(),
);
self.factor_iluk(&a_real, k)?;
self.native_complex_active = false;
self.complex_kernel_mode = IluComplexKernelMode::DegradedRealProjection;
} else {
self.factor_iluk_complex(&a_perm, k)?;
}
}
IluKind::Ilut { params } => self.factor_ilut_complex(&a_perm, ¶ms)?,
}
self.last_vid = Some(op.values_id());
Ok(())
}
/// Forces a full rebuild: the cached structure and values ids are cleared
/// before delegating to `setup`, so `setup` cannot treat `op` as unchanged.
fn update_symbolic(&mut self, op: &dyn LinOp<S = S>) -> Result<(), KError> {
self.last_sid = None;
self.last_vid = None;
self.setup(op)
}
/// This preconditioner always asks for the operator in CSR format.
fn required_format(&self) -> OpFormat {
OpFormat::Csr
}
/// Static capability report for the complex implementation: no transpose or
/// conjugate-transpose application, not advertised as SPD, and restricted to
/// left preconditioning.
fn capabilities(&self) -> PcCaps {
PcCaps {
supports_transpose: false,
supports_conj_trans: false,
is_spd: false,
side_restriction: Some(PcSide::Left),
}
}
/// Applies the preconditioner for the requested operator variant.
///
/// Only `Op::NoTrans` is implemented; it forwards to `apply` with
/// `PcSide::Left`. The transposed variants return `KError::Unsupported`.
fn apply_op(&self, op: Op, x: &[S], y: &mut [S]) -> Result<(), KError> {
    match op {
        Op::NoTrans => self.apply(PcSide::Left, x, y),
        Op::Trans | Op::ConjTrans => Err(KError::Unsupported(
            "IluCsr complex transpose kernels are not yet implemented".into(),
        )),
    }
}
}
impl LocalPreconditioner<f64> for IluCsr {
    /// The local operator is square with side length `n`.
    fn dims(&self) -> (usize, usize) {
        let n = self.n;
        (n, n)
    }

    /// Applies the real kernel (`Op::NoTrans`) to a local vector.
    ///
    /// # Errors
    /// `KError::InvalidInput` when `x` or `y` does not have length `n`;
    /// otherwise whatever `apply_op_scalar` returns.
    fn apply_local(&self, x: &[f64], y: &mut [f64]) -> Result<(), KError> {
        let n = self.n;
        if x.len() == n && y.len() == n {
            return self.apply_op_scalar(Op::NoTrans, x, y);
        }
        Err(KError::InvalidInput(format!(
            "IluCsr::apply_local dimension mismatch: n={}, x.len()={}, y.len()={}",
            n,
            x.len(),
            y.len()
        )))
    }
}
#[cfg(feature = "complex")]
impl KPreconditioner for IluCsr {
    type Scalar = crate::algebra::prelude::S;

    /// Same dimensions as reported through the `Preconditioner` impl.
    #[inline]
    fn dims(&self) -> (usize, usize) {
        Preconditioner::dims(self)
    }

    /// Shared-receiver application; the bridge scratch is not needed because
    /// the call goes straight to `apply`.
    fn apply_s(
        &self,
        side: PcSide,
        x: &[Self::Scalar],
        y: &mut [Self::Scalar],
        scratch: &mut BridgeScratch,
    ) -> Result<(), KError> {
        // Scratch space is intentionally unused here.
        let _ = scratch;
        self.apply(side, x, y)
    }

    /// Mutable-receiver application; forwards to `Preconditioner::apply_mut`
    /// and likewise leaves the bridge scratch untouched.
    fn apply_mut_s(
        &mut self,
        side: PcSide,
        x: &[Self::Scalar],
        y: &mut [Self::Scalar],
        scratch: &mut BridgeScratch,
    ) -> Result<(), KError> {
        // Scratch space is intentionally unused here.
        let _ = scratch;
        Preconditioner::apply_mut(self, side, x, y)
    }
}
impl IluCsr {
/// Returns the currently recorded complex kernel mode (for example
/// `DegradedRealProjection` after a forced real-projection factorization).
#[cfg(feature = "complex")]
pub fn complex_kernel_mode(&self) -> IluComplexKernelMode {
self.complex_kernel_mode
}
/// Sets the `complex_force_degraded` flag. The flag is consulted by the
/// ILU(0)/MILU(0) and ILU(k) factorization paths, so it takes effect on the
/// next factorization rather than immediately.
#[cfg(feature = "complex")]
pub fn set_complex_force_degraded(&mut self, on: bool) {
self.complex_force_degraded = on;
}
/// Complex application that reuses the scratch buffers owned by `self`
/// instead of allocating per call.
///
/// The buffers are moved out of `self` for the duration of the solve and
/// moved back afterwards, including when the real kernel reports an error.
///
/// # Errors
/// `KError::InvalidInput` when `x` or `y` does not have length `n`; in the
/// degraded path, any error from `apply_op_scalar_mut`.
#[cfg(feature = "complex")]
fn apply_complex_mut(&mut self, x: &[S], y: &mut [S]) -> Result<(), KError> {
    let n = self.n;
    if !(x.len() == n && y.len() == n) {
        return Err(KError::InvalidInput(format!(
            "IluCsr::apply dimension mismatch: n={}, x.len()={}, y.len()={}",
            n,
            x.len(),
            y.len()
        )));
    }

    // Every scratch buffer must be exactly `n` long before it is sliced below.
    let complex_ready = self.c_tmp.len() == n && self.c_y_tmp.len() == n;
    let split_ready = [
        self.c_xr.len(),
        self.c_xi.len(),
        self.c_yr.len(),
        self.c_yi.len(),
    ]
    .iter()
    .all(|&len| len == n);
    let real_ready = [self.tmp.len(), self.tmp2.len(), self.tmp3.len()]
        .iter()
        .all(|&len| len == n);
    if !(complex_ready && split_ready && real_ready) {
        self.resize_apply_workspace(n);
    }

    if self.native_complex_active {
        let mut work = std::mem::take(&mut self.c_tmp);
        let mut permuted = std::mem::take(&mut self.c_y_tmp);
        {
            let w = &mut work[..n];
            let out = &mut permuted[..n];
            self.pipeline_meta.left_perm.apply_vec(x, w);
            // Forward sweep with L.
            for i in 0..n {
                let mut acc = w[i];
                for p in self.l_row[i]..self.l_row[i + 1] {
                    acc -= self.c_l_val[p] * w[self.l_col[p]];
                }
                w[i] = acc;
            }
            // Backward sweep with U, skipping entries at or left of the diagonal.
            for i in (0..n).rev() {
                let mut acc = w[i];
                for p in self.u_row[i]..self.u_row[i + 1] {
                    let j = self.u_col[p];
                    if j > i {
                        acc -= self.c_u_val[p] * w[j];
                    }
                }
                w[i] = acc / self.c_u_val[self.u_diag_ix[i]];
            }
            self.pipeline_meta.right_perm.apply_vec_t(w, out);
            y.copy_from_slice(out);
        }
        self.c_tmp = work;
        self.c_y_tmp = permuted;
        Ok(())
    } else {
        let mut re_in = std::mem::take(&mut self.c_xr);
        let mut im_in = std::mem::take(&mut self.c_xi);
        let mut re_out = std::mem::take(&mut self.c_yr);
        let mut im_out = std::mem::take(&mut self.c_yi);

        for (i, v) in x.iter().enumerate() {
            re_in[i] = v.real();
            im_in[i] = v.imag();
        }
        // Real part first; the imaginary part is only solved if that succeeded.
        let outcome = self
            .apply_op_scalar_mut(Op::NoTrans, &re_in[..n], &mut re_out[..n])
            .and_then(|()| self.apply_op_scalar_mut(Op::NoTrans, &im_in[..n], &mut im_out[..n]));
        if outcome.is_ok() {
            for (i, out) in y.iter_mut().enumerate() {
                *out = S::from_parts(re_out[i], im_out[i]);
            }
        }

        self.c_xr = re_in;
        self.c_xi = im_in;
        self.c_yr = re_out;
        self.c_yi = im_out;
        outcome
    }
}
/// Dimension `n` of the (square) factored operator.
#[inline]
pub(crate) fn n(&self) -> usize {
self.n
}
/// Row-pointer array of the L factor: row `i` occupies positions
/// `l_row[i]..l_row[i + 1]` of the column/value arrays.
#[inline]
pub(crate) fn l_row(&self) -> &[usize] {
    &self.l_row[..]
}
/// Column indices of the stored L entries, addressed through `l_row`.
#[inline]
pub(crate) fn l_col(&self) -> &[usize] {
    &self.l_col[..]
}
/// Real-valued entries of the L factor, aligned with `l_col`.
#[inline]
pub(crate) fn l_val(&self) -> &[Real] {
    &self.l_val[..]
}
/// Row-pointer array of the U factor: row `i` occupies positions
/// `u_row[i]..u_row[i + 1]` of the column/value arrays.
#[inline]
pub(crate) fn u_row(&self) -> &[usize] {
    &self.u_row[..]
}
/// Column indices of the stored U entries, addressed through `u_row`.
#[inline]
pub(crate) fn u_col(&self) -> &[usize] {
    &self.u_col[..]
}
/// Real-valued entries of the U factor, aligned with `u_col`.
#[inline]
pub(crate) fn u_val(&self) -> &[Real] {
    &self.u_val[..]
}
/// For each row `i`, the position inside the U value array that holds the
/// diagonal entry of that row.
#[inline]
pub(crate) fn u_diag_ix(&self) -> &[usize] {
    &self.u_diag_ix[..]
}
/// Read-only view of the first real scratch buffer (`tmp`), the one the
/// mutable real kernel borrows as its permuted-input workspace.
#[allow(dead_code)]
#[inline]
pub(crate) fn tmp(&self) -> &[Real] {
    &self.tmp[..]
}
/// Mutable view of the first real scratch buffer (`tmp`).
#[allow(dead_code)]
#[inline]
pub(crate) fn tmp_mut(&mut self) -> &mut [Real] {
    &mut self.tmp[..]
}
/// Returns the `buckets_fwd` lists.
// NOTE(review): presumably the per-level row buckets consumed by the
// level-scheduled forward solve; confirm against `tri_solve`.
#[inline]
pub(crate) fn buckets_fwd(&self) -> &[Vec<usize>] {
    &self.buckets_fwd[..]
}
/// Returns the `buckets_bwd` lists.
// NOTE(review): presumably the per-level row buckets consumed by the
// level-scheduled backward solve; confirm against `tri_solve`.
#[inline]
pub(crate) fn buckets_bwd(&self) -> &[Vec<usize>] {
    &self.buckets_bwd[..]
}
/// Left half of the preconditioning pipeline: permutes `x` into `y` with the
/// left permutation, then divides each entry by the matching row-scaling
/// factor when row scaling is present.
fn pipeline_apply_left(&self, x: &[Real], y: &mut [Real]) {
    let meta = &self.pipeline_meta;
    meta.left_perm.apply_vec(x, y);
    if let Some(scale) = &meta.row_scaling {
        for (i, entry) in y.iter_mut().enumerate() {
            *entry /= scale[i];
        }
    }
}
/// Right half of the preconditioning pipeline: applies `right_perm.apply_vec_t`
/// from `x` into `tmp`, then writes `tmp` to `y`, dividing by the column
/// scaling entry by entry when column scaling is present.
fn pipeline_apply_right_inverse_with_tmp(&self, x: &[Real], tmp: &mut [Real], y: &mut [Real]) {
    let meta = &self.pipeline_meta;
    meta.right_perm.apply_vec_t(x, tmp);
    match &meta.col_scaling {
        Some(scale) => {
            for (i, out) in y.iter_mut().enumerate() {
                *out = tmp[i] / scale[i];
            }
        }
        None => y.copy_from_slice(tmp),
    }
}
/// Real kernel entry point for shared receivers: allocates three fresh
/// length-`n` scratch vectors and runs `apply_op_scalar_with_workspace`.
///
/// # Errors
/// `KError::InvalidInput` when `x` or `y` does not have length `n`, plus any
/// error from the workspace kernel.
fn apply_op_scalar(&self, op: Op, x: &[Real], y: &mut [Real]) -> Result<(), KError> {
    let n = self.n;
    if !(x.len() == n && y.len() == n) {
        return Err(KError::InvalidInput(format!(
            "IluCsr::apply dimension mismatch: n={}, x.len()={}, y.len()={}",
            n,
            x.len(),
            y.len()
        )));
    }
    let fresh = || vec![Real::zero(); n];
    let (mut permuted_in, mut permuted_out, mut right_scratch) = (fresh(), fresh(), fresh());
    self.apply_op_scalar_with_workspace(
        op,
        x,
        y,
        &mut permuted_in,
        &mut permuted_out,
        &mut right_scratch,
    )
}
/// Real kernel entry point for mutable receivers: lends the owned scratch
/// buffers `tmp`, `tmp2` and `tmp3` to `apply_op_scalar_with_workspace` and
/// puts them back afterwards, whether or not the kernel succeeded.
fn apply_op_scalar_mut(&mut self, op: Op, x: &[Real], y: &mut [Real]) -> Result<(), KError> {
    let workspace_ready = [self.tmp.len(), self.tmp2.len(), self.tmp3.len()]
        .iter()
        .all(|&len| len == self.n);
    if !workspace_ready {
        self.resize_apply_workspace(self.n);
    }

    let mut permuted_in = std::mem::take(&mut self.tmp);
    let mut permuted_out = std::mem::take(&mut self.tmp2);
    let mut right_scratch = std::mem::take(&mut self.tmp3);

    let n = self.n;
    let outcome = self.apply_op_scalar_with_workspace(
        op,
        x,
        y,
        &mut permuted_in[..n],
        &mut permuted_out[..n],
        &mut right_scratch[..n],
    );

    // Hand the buffers back before reporting the result.
    self.tmp = permuted_in;
    self.tmp2 = permuted_out;
    self.tmp3 = right_scratch;
    outcome
}
/// Core real kernel: left pipeline step, triangular solves, right pipeline step.
///
/// * `x`, `y` are the input and output vectors, both of length `n`.
/// * `x_perm`, `y_perm`, `right_tmp` are caller-provided scratch slices; the
///   callers in this file pass slices of length `n`.
///
/// For `Op::NoTrans` the level-scheduled or serial solver is chosen by
/// `cfg.level_sched`. `Op::Trans` and `Op::ConjTrans` share the transpose
/// solver (the factors are real here), using transposed copies of U and L
/// that are built on first use and cached in `ut` / `lt`.
///
/// # Errors
/// `KError::InvalidInput` when `x` or `y` does not have length `n`, plus any
/// error returned by the triangular solver.
// NOTE(review): the cached transposes are initialized once; confirm that
// refactorization resets `ut`/`lt`, otherwise they would go stale.
fn apply_op_scalar_with_workspace(
    &self,
    op: Op,
    x: &[Real],
    y: &mut [Real],
    x_perm: &mut [Real],
    y_perm: &mut [Real],
    right_tmp: &mut [Real],
) -> Result<(), KError> {
    if x.len() != self.n || y.len() != self.n {
        return Err(KError::InvalidInput(format!(
            "IluCsr::apply dimension mismatch: n={}, x.len()={}, y.len()={}",
            self.n,
            x.len(),
            y.len()
        )));
    }
    self.pipeline_apply_left(x, x_perm);
    match op {
        Op::NoTrans => {
            if self.cfg.level_sched {
                tri_solve::tri_solve_level_scheduled(self, x_perm, y_perm)
            } else {
                tri_solve::tri_solve_serial(self, x_perm, y_perm)
            }
        }
        Op::Trans | Op::ConjTrans => {
            let ut = self
                .ut
                .get_or_init(|| transpose_csr(self.n, &self.u_row, &self.u_col, &self.u_val));
            let lt = self
                .lt
                .get_or_init(|| transpose_csr(self.n, &self.l_row, &self.l_col, &self.l_val));
            tri_solve::tri_solve_transpose_serial(
                self, &ut.0, &ut.1, &ut.2, &lt.0, &lt.1, &lt.2, x_perm, y_perm,
            )
        }
    }?;
    self.pipeline_apply_right_inverse_with_tmp(y_perm, right_tmp, y);
    Ok(())
}
}
/// Transposes an `n x n` CSR matrix given as raw `(row, col, val)` arrays.
///
/// Returns the row pointers, column indices and values of the transpose.
/// Entries within each transposed row appear in increasing order of their
/// original row index, because the input is scanned row by row.
fn transpose_csr(
    n: usize,
    row: &[usize],
    col: &[usize],
    val: &[Real],
) -> (Vec<usize>, Vec<usize>, Vec<Real>) {
    let nnz = col.len();

    // Count the entries of every column, stored one slot to the right...
    let mut t_row = vec![0usize; n + 1];
    col.iter().for_each(|&c| t_row[c + 1] += 1);
    // ...then turn the counts into row pointers with a running sum.
    let mut running = 0usize;
    for slot in t_row.iter_mut() {
        running += *slot;
        *slot = running;
    }

    let mut t_col = vec![0usize; nnz];
    let mut t_val = vec![Real::zero(); nnz];
    // Next free position in each transposed row.
    let mut cursor = t_row.clone();
    for i in 0..n {
        for p in row[i]..row[i + 1] {
            let target_row = col[p];
            let slot = cursor[target_row];
            cursor[target_row] = slot + 1;
            t_col[slot] = i;
            t_val[slot] = val[p];
        }
    }
    (t_row, t_col, t_val)
}
#[cfg(all(test, feature = "complex"))]
mod complex_pivot_tests {
    use super::*;

    /// 2x2 anti-diagonal matrix of ones: no diagonal entry is stored, so every
    /// ILU(0) pivot starts out as zero.
    fn checkerboard_zero_diag() -> CsrMatrix<S> {
        let row_ptr = vec![0, 1, 2];
        let col_idx = vec![1, 0];
        let values = vec![S::from_real(1.0), S::from_real(1.0)];
        CsrMatrix::from_csr(2, 2, row_ptr, col_idx, values)
    }

    /// ILU(0) preconditioner that differs between cases only in its pivot strategy.
    fn ilu0_with_pivot(pivot: PivotStrategy) -> IluCsr {
        IluCsr::new_with_config(IluCsrConfig {
            kind: IluKind::Ilu0,
            pivot,
            pivot_threshold: 1e-12,
            diag_perturb_factor: 1e-10,
            level_sched: false,
            numeric_update_fixed: true,
            logging: 0,
            reordering: ReorderingOptions::default(),
            conditioning: ConditioningOptions::default(),
        })
    }

    #[test]
    fn ilu0_complex_zero_pivot_obeys_strategy() {
        let a = checkerboard_zero_diag();

        // Strict: a zero pivot is an error.
        let mut strict = ilu0_with_pivot(PivotStrategy::Strict);
        assert!(strict.factor_ilu0_complex(&a).is_err());

        // Threshold: pivots are floored at the configured threshold.
        let mut threshold = ilu0_with_pivot(PivotStrategy::Threshold);
        threshold
            .factor_ilu0_complex(&a)
            .expect("threshold pivot policy should floor tiny complex pivots");
        for i in 0..threshold.n {
            let diag_pos = threshold.u_diag_ix[i];
            let d = threshold.c_u_val[diag_pos];
            assert!(d.abs() >= 1e-12, "row {i} pivot should be floored");
        }

        // Diagonal perturbation: pivots end up nonzero.
        let mut perturb = ilu0_with_pivot(PivotStrategy::DiagonalPerturbation);
        perturb
            .factor_ilu0_complex(&a)
            .expect("diag perturbation should repair tiny complex pivots");
        for i in 0..perturb.n {
            let diag_pos = perturb.u_diag_ix[i];
            let d = perturb.c_u_val[diag_pos];
            assert!(d.abs() > 0.0, "row {i} pivot should be nonzero");
        }
    }
}