#[cfg(feature = "complex")]
use crate::algebra::bridge::BridgeScratch;
use crate::algebra::prelude::*;
use crate::error::KError;
use crate::matrix::convert::csr_from_linop;
use crate::matrix::op::{LinOp, StructureId, ValuesId};
use crate::matrix::sparse::CsrMatrix;
#[cfg(feature = "complex")]
use crate::ops::kpc::KPreconditioner;
#[cfg(feature = "complex")]
use crate::preconditioner::bridge::{
apply_pc_mut_s as bridge_apply_pc_mut_s, apply_pc_s as bridge_apply_pc_s,
};
use crate::preconditioner::{PcSide, Preconditioner};
use faer::Mat;
use faer::linalg::solvers::SolveLstsq;
/// Which approximate-inverse construction a configuration describes.
#[derive(Clone, Copy, Debug)]
pub enum ApproxInvKind {
    /// Factored sparse approximate inverse: a lower-triangular factor `G`, applied as `G·Gᵀ`.
    FSAI,
    /// Sparse approximate inverse: an explicit sparse matrix `M`, applied as `M·x`.
    SPAI,
}

/// Tuning knobs shared by the FSAI and SPAI constructions.
#[derive(Clone, Copy, Debug)]
pub struct ApproxInvParams {
    /// Construction kind recorded by the builder.
    pub kind: ApproxInvKind,
    /// Number of graph-expansion steps used to grow each column's sparsity pattern.
    pub levels: usize,
    /// Upper bound on the number of entries in each column pattern.
    pub max_per_col: usize,
    /// Relative drop tolerance: entries below `drop_tol · ‖column‖₂` are discarded.
    pub drop_tol: f64,
    /// Diagonal shift added to every local system before it is solved.
    pub reg: f64,
    /// Condition-number limit; stored but not read by the constructions in this module.
    pub max_cond: f64,
    /// Solve columns in parallel (only effective with the `rayon` feature).
    pub parallel: bool,
}

impl Default for ApproxInvParams {
    /// FSAI, one expansion level, 20 entries per column, `drop_tol = 1e-3`, `reg = 1e-12`,
    /// `max_cond = 1e12`, parallel exactly when the `rayon` feature is compiled in.
    fn default() -> Self {
        let parallel = cfg!(feature = "rayon");
        Self {
            kind: ApproxInvKind::FSAI,
            levels: 1,
            max_per_col: 20,
            drop_tol: 1e-3,
            reg: 1e-12,
            max_cond: 1e12,
            parallel,
        }
    }
}
pub struct ApproxInvBuilder {
p: ApproxInvParams,
}
impl ApproxInvBuilder {
pub fn new(kind: ApproxInvKind) -> Self {
let mut p = ApproxInvParams::default();
p.kind = kind;
Self { p }
}
pub fn levels(mut self, l: usize) -> Self {
self.p.levels = l;
self
}
pub fn max_per_col(mut self, s: usize) -> Self {
self.p.max_per_col = s.max(1);
self
}
pub fn drop_tol(mut self, t: f64) -> Self {
self.p.drop_tol = t.max(S::zero().real());
self
}
pub fn reg(mut self, r: f64) -> Self {
self.p.reg = if r >= S::zero().real() {
r
} else {
S::zero().real()
};
self
}
pub fn max_cond(mut self, c: f64) -> Self {
self.p.max_cond = if c > S::zero().real() { c } else { 1e12 };
self
}
pub fn parallel(mut self, on: bool) -> Self {
self.p.parallel = on;
self
}
pub fn build_fsai(self, a: &CsrMatrix<f64>) -> Result<FsaiCsr, KError> {
FsaiCsr::build_from_csr(a.clone(), self.p)
}
pub fn build_spai(self, a: &CsrMatrix<f64>) -> Result<SpaiCsr, KError> {
SpaiCsr::build_from_csr(a.clone(), self.p)
}
}
/// Factored sparse approximate inverse (FSAI) preconditioner.
///
/// Holds a lower-triangular factor `G` (column `c` only has rows `r >= c`); `apply` computes
/// `y = G·(Gᵀ·x)`.
pub struct FsaiCsr {
    /// The factor `G`.
    pub(crate) g: CsrMatrix<f64>,
    /// Row indices retained in each column of `G`; reused by `update_numeric`.
    pub(crate) pat: Vec<Vec<usize>>,
    /// Parameters used for (re)construction.
    pub(crate) params: ApproxInvParams,
    /// Structure id of the operator seen by the last full `setup`.
    last_sid: Option<StructureId>,
    /// Values id of the operator seen by the last `setup`.
    last_vid: Option<ValuesId>,
}
/// Sparse approximate inverse (SPAI) preconditioner.
///
/// Holds an explicit sparse matrix `M`, built column by column; `apply` computes `y = M·x`.
pub struct SpaiCsr {
    /// The approximate inverse `M`.
    pub(crate) m: CsrMatrix<f64>,
    /// Row indices retained in each column of `M`; reused by `update_numeric`.
    pub(crate) pat: Vec<Vec<usize>>,
    /// Parameters used for (re)construction.
    pub(crate) params: ApproxInvParams,
    /// Structure id of the operator seen by the last full `setup`.
    last_sid: Option<StructureId>,
    /// Values id of the operator seen by the last `setup`.
    last_vid: Option<ValuesId>,
}
/// Returns the stored value `a[row, col]`, or zero when `(row, col)` is outside the pattern.
///
/// The lookup binary-searches the row, so the column indices of every CSR row are assumed to
/// be sorted in increasing order.
///
/// # Panics
/// Panics if `row >= a.nrows()`.
#[inline]
fn csr_find(a: &CsrMatrix<f64>, row: usize, col: usize) -> f64 {
    let start = a.row_ptr()[row];
    let end = a.row_ptr()[row + 1];
    a.col_idx()[start..end]
        .binary_search(&col)
        .map_or(R::default(), |offset| a.values()[start + offset])
}
/// Computes `y = A·x` for a real CSR matrix `A` and scalar vectors `x`, `y`.
///
/// # Panics
/// Panics if `x.len() != a.ncols()` or `y.len() != a.nrows()`.
#[inline]
fn spmv_csr(a: &CsrMatrix<f64>, x: &[S], y: &mut [S]) {
    assert_eq!(x.len(), a.ncols());
    assert_eq!(y.len(), a.nrows());
    let row_ptr = a.row_ptr();
    let cols = a.col_idx();
    let vals = a.values();
    // Every output entry is written exactly once, so no separate zeroing pass is needed.
    for (row, out) in y.iter_mut().enumerate() {
        let span = row_ptr[row]..row_ptr[row + 1];
        let mut acc = S::zero();
        for (&c, &v) in cols[span.clone()].iter().zip(&vals[span]) {
            acc += S::from_real(v) * x[c];
        }
        *out = acc;
    }
}
/// Computes `y = Aᵀ·x` for a real CSR matrix `A` by scattering each row into `y`.
///
/// # Panics
/// Panics if `x.len() != a.nrows()` or `y.len() != a.ncols()`.
#[inline]
fn spmv_csr_transpose(a: &CsrMatrix<f64>, x: &[S], y: &mut [S]) {
    assert_eq!(x.len(), a.nrows());
    assert_eq!(y.len(), a.ncols());
    y.fill(S::zero());
    let (row_ptr, cols, vals) = (a.row_ptr(), a.col_idx(), a.values());
    for (row, &xr) in x.iter().enumerate() {
        let span = row_ptr[row]..row_ptr[row + 1];
        for (&c, &v) in cols[span.clone()].iter().zip(&vals[span]) {
            y[c] += S::from_real(v) * xr;
        }
    }
}
/// Collects the graph neighbourhood of vertex `i` in the sparsity structure of `a`.
///
/// Starting from `{i}`, the set is expanded up to `levels` times. Each expansion adds every
/// column index stored in the CSR rows of the current set. The result is sorted, deduplicated
/// and holds at most `cap` entries.
///
/// When the set must be truncated, the smallest indices are kept. The exception is `i` itself,
/// which is always retained (unless `cap == 0`), so callers receive a pattern that contains the
/// diagonal. Expansion stops early once the cap is reached.
///
/// # Panics
/// Panics if `i >= a.nrows()`, or if a column index of `a` is not a valid row index (callers
/// are expected to pass a square matrix).
fn grow_pattern_row_graph(a: &CsrMatrix<f64>, i: usize, levels: usize, cap: usize) -> Vec<usize> {
    // Truncates the sorted, deduplicated `set` to at most `cap` entries, keeping the smallest
    // indices but never dropping `keep` (unless `cap == 0`).
    fn truncate_keeping(set: &mut Vec<usize>, keep: usize, cap: usize) {
        if set.len() <= cap {
            return;
        }
        let keep_was_cut = set[cap..].binary_search(&keep).is_ok();
        set.truncate(cap);
        if keep_was_cut {
            // `keep` is larger than every retained index, so it can take the last slot
            // without breaking the sort order.
            if let Some(last) = set.last_mut() {
                *last = keep;
            }
        }
    }

    let n = a.nrows();
    assert!(i < n);
    let rp = a.row_ptr();
    let ci = a.col_idx();
    let mut acc: Vec<usize> = vec![i];
    for _ in 0..levels {
        let mut next: Vec<usize> = Vec::new();
        for &u in &acc {
            next.extend_from_slice(&ci[rp[u]..rp[u + 1]]);
        }
        acc.extend(next);
        acc.sort_unstable();
        acc.dedup();
        truncate_keeping(&mut acc, i, cap);
        if acc.len() >= cap {
            break;
        }
    }
    // Also covers `levels == 0`, where the loop above never truncates.
    truncate_keeping(&mut acc, i, cap);
    acc
}
impl FsaiCsr {
    /// Builds the factored sparse approximate inverse (FSAI) of the square matrix `a`.
    ///
    /// For every column `i`, a sparsity pattern `S_i ⊆ {r : r >= i}` that contains `i` is grown
    /// from the graph of `a` (at most `cfg.max_per_col` entries). The column is the scaled
    /// solution of `(A[S_i, S_i] + reg·I) g = e_i` (see [`FsaiCsr::solve_local`]).
    ///
    /// Entries whose magnitude is below `cfg.drop_tol` times the column's 2-norm are dropped;
    /// the diagonal is always kept. The resulting lower-triangular `G` is intended to satisfy
    /// `G·Gᵀ ≈ A⁻¹`, which is what `apply` evaluates.
    ///
    /// Columns are solved in parallel when `cfg.parallel` is set and the `rayon` feature is
    /// enabled. The result does not depend on that choice.
    ///
    /// # Errors
    /// Returns [`KError::InvalidInput`] if `a` is not square.
    fn build_from_csr(a: CsrMatrix<f64>, cfg: ApproxInvParams) -> Result<Self, KError> {
        if a.nrows() != a.ncols() {
            return Err(KError::InvalidInput(format!(
                "FsaiCsr::build_from_csr requires a square matrix, got {}x{}",
                a.nrows(),
                a.ncols()
            )));
        }
        let n = a.nrows();
        let mut pat: Vec<Vec<usize>> = Vec::with_capacity(n);
        for i in 0..n {
            let mut s = grow_pattern_row_graph(&a, i, cfg.levels, cfg.max_per_col);
            // Lower-triangular factor: column `i` may only use rows `>= i` and must contain `i`.
            s.retain(|&r| r >= i);
            if !s.contains(&i) {
                s.insert(0, i);
            }
            s.sort_unstable();
            s.dedup();
            // `i` is now the smallest entry, so truncation keeps it (unless the cap is 0).
            s.truncate(cfg.max_per_col);
            pat.push(s);
        }
        let solve_column = |s: &[usize], i: usize| -> (Vec<(usize, usize, R)>, Vec<usize>) {
            let Some(g) = Self::solve_local(&a, s, i, cfg.reg) else {
                return (Vec::new(), Vec::new());
            };
            let mut norm2: R = R::default();
            for &v in &g {
                norm2 += v * v;
            }
            let thr = cfg.drop_tol * norm2.sqrt().max(1e-32);
            let mut col_trips: Vec<(usize, usize, R)> = Vec::with_capacity(g.len());
            let mut kept: Vec<usize> = Vec::with_capacity(g.len());
            // `s` is sorted, so `kept` comes out sorted as well.
            for (&row, &val) in s.iter().zip(&g) {
                // The diagonal is always kept: dropping it can make `G` singular and would
                // make `update_numeric` skip the column entirely.
                if row == i || val.abs() >= thr {
                    col_trips.push((row, i, val));
                    kept.push(row);
                }
            }
            (col_trips, kept)
        };
        #[cfg(feature = "rayon")]
        let columns: Vec<(Vec<(usize, usize, R)>, Vec<usize>)> = if cfg.parallel {
            use rayon::prelude::*;
            // Indexed parallel collect preserves column order.
            pat.par_iter()
                .enumerate()
                .map(|(i, s)| solve_column(s, i))
                .collect()
        } else {
            pat.iter()
                .enumerate()
                .map(|(i, s)| solve_column(s, i))
                .collect()
        };
        #[cfg(not(feature = "rayon"))]
        let columns: Vec<(Vec<(usize, usize, R)>, Vec<usize>)> = pat
            .iter()
            .enumerate()
            .map(|(i, s)| solve_column(s, i))
            .collect();
        let mut trips: Vec<(usize, usize, R)> = Vec::new();
        for (slot, (mut col_trips, kept)) in pat.iter_mut().zip(columns) {
            trips.append(&mut col_trips);
            *slot = kept;
        }
        let g = assemble_csr(n, n, &mut trips);
        Ok(Self {
            g,
            pat,
            params: cfg,
            last_sid: None,
            last_vid: None,
        })
    }

    /// Solves the local FSAI system for column `i` on the sorted pattern `s`.
    ///
    /// Forms `A[s, s] + reg·I` and solves it against the unit vector that selects `i`. With
    /// exact patterns the solutions `g_i` satisfy `Gᵀ·A·G = diag(g_ii)`, so
    /// `A⁻¹ = G·diag(g_ii)⁻¹·Gᵀ`. The returned vector is therefore scaled by `1/sqrt(|g_ii|)`,
    /// which makes `G·Gᵀ` itself approximate `A⁻¹`.
    ///
    /// If `g_ii` is zero or non-finite, the column is returned unscaled. The absolute value only
    /// matters for indefinite local systems; for SPD `A[s, s] + reg·I` the diagonal is positive.
    ///
    /// Returns the values in the order of `s`, or `None` if `s` is empty or lacks `i`.
    fn solve_local(a: &CsrMatrix<f64>, s: &[usize], i: usize, reg: f64) -> Option<Vec<R>> {
        let m = s.len();
        let pos = s.binary_search(&i).ok()?;
        let mut a_ss = Mat::<R>::from_fn(m, m, |_, _| R::default());
        for p in 0..m {
            for q in 0..=p {
                let v = csr_find(a, s[p], s[q]);
                a_ss[(p, q)] = v;
                a_ss[(q, p)] = v;
            }
        }
        for d in 0..m {
            a_ss[(d, d)] += reg;
        }
        let rhs = Mat::<R>::from_fn(m, 1, |r, _| {
            if r == pos {
                S::one().real()
            } else {
                R::default()
            }
        });
        let sol = faer::linalg::solvers::Qr::new(a_ss.as_ref()).solve_lstsq(rhs);
        let diag = sol[(pos, 0)].abs();
        let scale = if diag > 0.0 && diag.is_finite() {
            1.0 / diag.sqrt()
        } else {
            1.0
        };
        Some((0..m).map(|k| sol[(k, 0)] * scale).collect())
    }
}
impl Preconditioner for FsaiCsr {
    /// (Re)builds the preconditioner for `a`.
    ///
    /// If `a` reports the same structure id as the previous setup but a different values id,
    /// only the numeric values are recomputed on the stored pattern via `update_numeric`.
    /// Otherwise the factor is rebuilt from scratch with the stored parameters.
    fn setup(&mut self, a: &dyn LinOp<S = S>) -> Result<(), KError> {
        let sid = a.structure_id();
        let vid = a.values_id();
        if let (Some(ls), Some(lv)) = (self.last_sid, self.last_vid)
            && ls == sid
            && lv != vid
        {
            // `update_numeric` performs its own CSR conversion; converting here as well
            // would duplicate that work.
            self.update_numeric(a)?;
            self.last_vid = Some(vid);
            return Ok(());
        }
        let csr = csr_from_linop(a, R::default())?;
        let rebuilt = FsaiCsr::build_from_csr((*csr).clone(), self.params)?;
        *self = rebuilt;
        self.last_sid = Some(sid);
        self.last_vid = Some(vid);
        Ok(())
    }

    /// Applies the preconditioner: `y = G·(Gᵀ·x)`. The `side` argument is ignored.
    ///
    /// # Errors
    /// Returns [`KError::InvalidInput`] if `x` or `y` does not have length `G.nrows()`.
    fn apply(&self, _side: PcSide, x: &[S], y: &mut [S]) -> Result<(), KError> {
        if x.len() != self.g.nrows() || y.len() != self.g.nrows() {
            return Err(KError::InvalidInput(format!(
                "FsaiCsr::apply dimension mismatch: n={}, x.len()={}, y.len()={}",
                self.g.nrows(),
                x.len(),
                y.len()
            )));
        }
        let n = x.len();
        let mut t = vec![S::zero(); n];
        spmv_csr_transpose(&self.g, x, &mut t);
        spmv_csr(&self.g, &t, y);
        Ok(())
    }

    fn supports_numeric_update(&self) -> bool {
        true
    }

    /// Recomputes the values of `G` for a new matrix with the same sparsity structure, reusing
    /// the stored column patterns (no pattern growth or dropping is redone).
    ///
    /// # Errors
    /// Returns [`KError::InvalidInput`] if `a` does not match the dimensions of the current
    /// factor, and propagates CSR conversion errors.
    fn update_numeric(&mut self, a: &dyn LinOp<S = S>) -> Result<(), KError> {
        let csr = csr_from_linop(a, R::default())?;
        let n = self.g.nrows().min(self.g.ncols());
        if csr.nrows() != n || csr.ncols() != n {
            return Err(KError::InvalidInput(format!(
                "FsaiCsr::update_numeric dimension mismatch: expected {n}x{n}, got {}x{}",
                csr.nrows(),
                csr.ncols()
            )));
        }
        let mut trips: Vec<(usize, usize, R)> = Vec::new();
        for (i, s) in self.pat.iter().enumerate().take(n) {
            if let Some(g) = FsaiCsr::solve_local(&csr, s, i, self.params.reg) {
                trips.extend(s.iter().zip(g).map(|(&row, val)| (row, i, val)));
            }
        }
        self.g = assemble_csr(n, n, &mut trips);
        Ok(())
    }

    fn required_format(&self) -> crate::matrix::format::OpFormat {
        crate::matrix::format::OpFormat::Csr
    }
}
#[cfg(feature = "complex")]
impl KPreconditioner for FsaiCsr {
    type Scalar = S;

    /// Shape of the stored factor `G` as `(rows, cols)`.
    #[inline]
    fn dims(&self) -> (usize, usize) {
        let factor = &self.g;
        (factor.nrows(), factor.ncols())
    }

    /// Scalar-typed application, forwarded to the shared preconditioner bridge.
    fn apply_s(
        &self,
        side: PcSide,
        x: &[S],
        y: &mut [S],
        scratch: &mut BridgeScratch,
    ) -> Result<(), KError> {
        let outcome = bridge_apply_pc_s(self, side, x, y, scratch);
        outcome
    }

    /// Mutable scalar-typed application, forwarded to the shared preconditioner bridge.
    fn apply_mut_s(
        &mut self,
        side: PcSide,
        x: &[S],
        y: &mut [S],
        scratch: &mut BridgeScratch,
    ) -> Result<(), KError> {
        let outcome = bridge_apply_pc_mut_s(self, side, x, y, scratch);
        outcome
    }
}
impl SpaiCsr {
fn build_from_csr(a: CsrMatrix<f64>, cfg: ApproxInvParams) -> Result<Self, KError> {
let n = a.nrows().min(a.ncols());
let mut pat: Vec<Vec<usize>> = Vec::with_capacity(n);
for j in 0..n {
let mut s = grow_pattern_row_graph(&a, j, cfg.levels, cfg.max_per_col);
if !s.contains(&j) {
s.push(j);
}
s.sort_unstable();
s.dedup();
if s.len() > cfg.max_per_col {
s.truncate(cfg.max_per_col);
}
pat.push(s);
}
let mut trips: Vec<(usize, usize, R)> = Vec::new();
let rp = a.row_ptr();
let ci = a.col_idx();
let vv = a.values();
let solve_column = |s: &[usize], j: usize| -> (Vec<(usize, usize, R)>, Vec<usize>) {
let m = s.len();
if m == 0 {
return (Vec::new(), Vec::new());
}
let mut idx_in_s: Vec<i32> = vec![-1; n]; for (pos, &g) in s.iter().enumerate() {
idx_in_s[g] = pos as i32;
}
let mut nmat = Mat::<R>::from_fn(m, m, |_, _| R::default());
let mut cvec = Mat::<R>::from_fn(m, 1, |_, _| R::default());
let (rj, rj2) = (rp[j], rp[j + 1]);
for pidx in rj..rj2 {
let col = ci[pidx];
let val = vv[pidx];
let pos = idx_in_s[col];
if pos >= 0 {
cvec[(pos as usize, 0)] = val;
}
}
for i in 0..n {
let (rs, re) = (rp[i], rp[i + 1]);
let mut pos_tmp: smallvec::SmallVec<[(usize, R); 32]> = smallvec::SmallVec::new();
for p in rs..re {
let col = ci[p];
let pos = idx_in_s[col];
if pos >= 0 {
pos_tmp.push((pos as usize, vv[p]));
}
}
for ix in 0..pos_tmp.len() {
let (px, vx) = pos_tmp[ix];
for iy in 0..=ix {
let (py, vy) = pos_tmp[iy];
let v = vx * vy;
nmat[(px, py)] += v;
if px != py {
nmat[(py, px)] += v;
}
}
}
}
for d in 0..m {
nmat[(d, d)] += cfg.reg;
}
let sol = faer::linalg::solvers::Qr::new(nmat.as_ref()).solve_lstsq(cvec);
let mut norm2: R = R::default();
for r in 0..m {
norm2 += sol[(r, 0)] * sol[(r, 0)];
}
norm2 = norm2.sqrt();
let thr = cfg.drop_tol * norm2.max(1e-32);
let mut col_trips: Vec<(usize, usize, R)> = Vec::with_capacity(m);
let mut kept: Vec<usize> = Vec::with_capacity(m);
for (k, &row) in s.iter().enumerate() {
let val = sol[(k, 0)];
if val.abs() >= thr {
col_trips.push((row, j, val));
kept.push(row);
}
}
kept.sort_unstable();
(col_trips, kept)
};
if cfg.parallel {
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
let pat_snapshot = pat.clone();
let mut results: Vec<(usize, Vec<(usize, usize, R)>, Vec<usize>)> = pat_snapshot
.par_iter()
.enumerate()
.map(|(j, s)| {
let (col_trips, kept) = solve_column(s, j);
(j, col_trips, kept)
})
.collect();
results.sort_by_key(|(j, _, _)| *j);
for (j, mut col_trips, kept) in results {
trips.append(&mut col_trips);
pat[j] = kept;
}
}
#[cfg(not(feature = "rayon"))]
{
for j in 0..n {
let (mut col_trips, kept) = solve_column(&pat[j], j);
trips.append(&mut col_trips);
pat[j] = kept;
}
}
} else {
for j in 0..n {
let (mut col_trips, kept) = solve_column(&pat[j], j);
trips.append(&mut col_trips);
pat[j] = kept;
}
}
let m = assemble_csr(n, n, &mut trips);
Ok(Self {
m,
pat,
params: cfg,
last_sid: None,
last_vid: None,
})
}
}
impl Preconditioner for SpaiCsr {
    /// (Re)builds the preconditioner for `a`.
    ///
    /// If `a` reports the same structure id as the previous setup but a different values id,
    /// only the numeric values are recomputed on the stored pattern via `update_numeric`.
    /// Otherwise `M` is rebuilt from scratch with the stored parameters.
    fn setup(&mut self, a: &dyn LinOp<S = S>) -> Result<(), KError> {
        let sid = a.structure_id();
        let vid = a.values_id();
        if let (Some(ls), Some(lv)) = (self.last_sid, self.last_vid)
            && ls == sid
            && lv != vid
        {
            // `update_numeric` performs its own CSR conversion; converting here as well
            // would duplicate that work.
            self.update_numeric(a)?;
            self.last_vid = Some(vid);
            return Ok(());
        }
        let csr = csr_from_linop(a, R::default())?;
        let rebuilt = SpaiCsr::build_from_csr((*csr).clone(), self.params)?;
        *self = rebuilt;
        self.last_sid = Some(sid);
        self.last_vid = Some(vid);
        Ok(())
    }

    /// Applies the preconditioner: `y = M·x`. The `side` argument is ignored.
    ///
    /// # Errors
    /// Returns [`KError::InvalidInput`] if `x.len() != M.ncols()` or `y.len() != M.nrows()`.
    fn apply(&self, _side: PcSide, x: &[S], y: &mut [S]) -> Result<(), KError> {
        if x.len() != self.m.ncols() || y.len() != self.m.nrows() {
            return Err(KError::InvalidInput(format!(
                "SpaiCsr::apply dimension mismatch: A={}x{}, x.len()={}, y.len()={}",
                self.m.nrows(),
                self.m.ncols(),
                x.len(),
                y.len()
            )));
        }
        spmv_csr(&self.m, x, y);
        Ok(())
    }

    fn supports_numeric_update(&self) -> bool {
        true
    }

    /// Recomputes the values of `M` for a new matrix with the same sparsity structure, reusing
    /// the stored column patterns (no pattern growth or dropping is redone).
    ///
    /// # Errors
    /// Returns [`KError::InvalidInput`] if `a` does not match the dimensions of the current
    /// `M`, and propagates CSR conversion errors.
    fn update_numeric(&mut self, a: &dyn LinOp<S = S>) -> Result<(), KError> {
        let csr = csr_from_linop(a, R::default())?;
        let n = self.m.nrows().min(self.m.ncols());
        if csr.nrows() != n || csr.ncols() != n {
            return Err(KError::InvalidInput(format!(
                "SpaiCsr::update_numeric dimension mismatch: expected {n}x{n}, got {}x{}",
                csr.nrows(),
                csr.ncols()
            )));
        }
        let rp = csr.row_ptr();
        let ci = csr.col_idx();
        let vv = csr.values();
        // Column -> rows adjacency (increasing row order). Each column then visits only the
        // rows that touch its pattern instead of scanning all `n` rows.
        let mut col_rows: Vec<Vec<usize>> = vec![Vec::new(); n];
        for row in 0..n {
            for &col in &ci[rp[row]..rp[row + 1]] {
                col_rows[col].push(row);
            }
        }
        // Global column index -> local slot in the current pattern; reset after each column.
        let mut idx_in_s: Vec<Option<usize>> = vec![None; n];
        let mut rows: Vec<usize> = Vec::new();
        let mut trips: Vec<(usize, usize, R)> = Vec::new();
        for (j, s) in self.pat.iter().enumerate().take(n) {
            let m = s.len();
            if m == 0 {
                continue;
            }
            for (pos, &g) in s.iter().enumerate() {
                idx_in_s[g] = Some(pos);
            }
            let mut nmat = Mat::<R>::from_fn(m, m, |_, _| R::default());
            let mut cvec = Mat::<R>::from_fn(m, 1, |_, _| R::default());
            // Right-hand side A[j, S_j]ᵀ.
            for p in rp[j]..rp[j + 1] {
                if let Some(pos) = idx_in_s[ci[p]] {
                    cvec[(pos, 0)] = vv[p];
                }
            }
            // Rows touching S_j, in increasing order, so the accumulation order matches a
            // full row scan.
            rows.clear();
            for &c in s.iter() {
                rows.extend_from_slice(&col_rows[c]);
            }
            rows.sort_unstable();
            rows.dedup();
            for &i in &rows {
                let mut pos_tmp: smallvec::SmallVec<[(usize, f64); 32]> = smallvec::SmallVec::new();
                for p in rp[i]..rp[i + 1] {
                    if let Some(pos) = idx_in_s[ci[p]] {
                        pos_tmp.push((pos, vv[p]));
                    }
                }
                for (ix, &(px, vx)) in pos_tmp.iter().enumerate() {
                    for &(py, vy) in &pos_tmp[..=ix] {
                        let v = vx * vy;
                        nmat[(px, py)] += v;
                        if px != py {
                            nmat[(py, px)] += v;
                        }
                    }
                }
            }
            for d in 0..m {
                nmat[(d, d)] += self.params.reg;
            }
            let sol = faer::linalg::solvers::Qr::new(nmat.as_ref()).solve_lstsq(cvec);
            for (k, &row) in s.iter().enumerate() {
                trips.push((row, j, sol[(k, 0)]));
            }
            for &g in s.iter() {
                idx_in_s[g] = None;
            }
        }
        self.m = assemble_csr(n, n, &mut trips);
        Ok(())
    }

    fn required_format(&self) -> crate::matrix::format::OpFormat {
        crate::matrix::format::OpFormat::Csr
    }
}
#[cfg(feature = "complex")]
impl KPreconditioner for SpaiCsr {
    type Scalar = S;

    /// Shape of the stored approximate inverse `M` as `(rows, cols)`.
    #[inline]
    fn dims(&self) -> (usize, usize) {
        let inverse = &self.m;
        (inverse.nrows(), inverse.ncols())
    }

    /// Scalar-typed application, forwarded to the shared preconditioner bridge.
    fn apply_s(
        &self,
        side: PcSide,
        x: &[S],
        y: &mut [S],
        scratch: &mut BridgeScratch,
    ) -> Result<(), KError> {
        let outcome = bridge_apply_pc_s(self, side, x, y, scratch);
        outcome
    }

    /// Mutable scalar-typed application, forwarded to the shared preconditioner bridge.
    fn apply_mut_s(
        &mut self,
        side: PcSide,
        x: &[S],
        y: &mut [S],
        scratch: &mut BridgeScratch,
    ) -> Result<(), KError> {
        let outcome = bridge_apply_pc_mut_s(self, side, x, y, scratch);
        outcome
    }
}
/// Assembles an `nrows × ncols` CSR matrix from `(row, col, value)` triplets.
///
/// The triplets are sorted in place by `(row, col)`, and duplicates of the same coordinate are
/// summed into a single entry. Rows without triplets become empty rows.
///
/// # Panics
/// Panics if a row index exceeds `nrows`.
fn assemble_csr(nrows: usize, ncols: usize, trips: &mut Vec<(usize, usize, R)>) -> CsrMatrix<f64> {
    trips.sort_unstable_by_key(|&(r, c, _)| (r, c));
    let mut row_ptr = vec![0usize; nrows + 1];
    let mut col_idx: Vec<usize> = Vec::with_capacity(trips.len());
    let mut vals: Vec<R> = Vec::with_capacity(trips.len());
    // Number of leading rows whose end offset has already been written into `row_ptr`.
    let mut closed_rows = 0usize;
    for group in trips.chunk_by(|x, y| x.0 == y.0 && x.1 == y.1) {
        let (row, col, mut total) = group[0];
        // Every row before `row` is finished: record its end offset.
        while closed_rows < row {
            closed_rows += 1;
            row_ptr[closed_rows] = col_idx.len();
        }
        for dup in &group[1..] {
            total += dup.2;
        }
        col_idx.push(col);
        vals.push(total);
    }
    while closed_rows < nrows {
        closed_rows += 1;
        row_ptr[closed_rows] = col_idx.len();
    }
    CsrMatrix::from_csr(nrows, ncols, row_ptr, col_idx, vals)
}
impl FsaiCsr {
    /// Creates an empty (0×0) FSAI preconditioner with `params`.
    ///
    /// No matrix has been seen yet, so the first `setup` call performs a full build.
    pub fn new_with_params(params: ApproxInvParams) -> Self {
        let empty = CsrMatrix::from_csr(0, 0, vec![0], Vec::new(), Vec::new());
        Self {
            params,
            g: empty,
            pat: Vec::new(),
            last_sid: None,
            last_vid: None,
        }
    }
}
impl SpaiCsr {
    /// Creates an empty (0×0) SPAI preconditioner with `params`.
    ///
    /// No matrix has been seen yet, so the first `setup` call performs a full build.
    pub fn new_with_params(params: ApproxInvParams) -> Self {
        let empty = CsrMatrix::from_csr(0, 0, vec![0], Vec::new(), Vec::new());
        Self {
            params,
            m: empty,
            pat: Vec::new(),
            last_sid: None,
            last_vid: None,
        }
    }
}
#[cfg(all(test, feature = "complex"))]
mod tests {
    use super::*;
    use crate::algebra::bridge::BridgeScratch;
    use crate::ops::kpc::KPreconditioner;

    /// 3×3 tridiagonal `[-1, 2, -1]` matrix (1-D Poisson stencil) in CSR form.
    fn poisson_1d_matrix() -> CsrMatrix<f64> {
        CsrMatrix::from_csr(
            3,
            3,
            vec![0, 2, 5, 7],
            vec![0, 1, 0, 1, 2, 1, 2],
            vec![2.0, -1.0, -1.0, 2.0, -1.0, -1.0, 2.0],
        )
    }

    /// Lifts real values into the scalar type `S`.
    fn lift(values: &[f64]) -> Vec<S> {
        values.iter().map(|&v| S::from_real(v)).collect()
    }

    /// Asserts component-wise agreement (real and imaginary parts) to within 1e-10.
    fn assert_scalars_close(direct: &[S], bridged: &[S]) {
        for (d, b) in direct.iter().zip(bridged) {
            assert!((d.real() - b.real()).abs() < 1e-10);
            assert!((d.imag() - b.imag()).abs() < 1e-10);
        }
    }

    #[test]
    fn fsai_apply_s_matches_real_path() {
        let a = poisson_1d_matrix();
        let pc = ApproxInvBuilder::new(ApproxInvKind::FSAI)
            .levels(1)
            .build_fsai(&a)
            .expect("fsai build");
        let rhs = lift(&[1.0, 2.0, 3.0]);
        let mut direct = vec![S::zero(); rhs.len()];
        pc.apply(PcSide::Left, &rhs, &mut direct)
            .expect("fsai apply direct");
        let mut bridged = vec![S::zero(); rhs.len()];
        let mut scratch = BridgeScratch::default();
        pc.apply_s(PcSide::Left, &rhs, &mut bridged, &mut scratch)
            .expect("fsai apply_s");
        assert_scalars_close(&direct, &bridged);
    }

    #[test]
    fn spai_apply_s_matches_real_path() {
        let a = poisson_1d_matrix();
        let pc = ApproxInvBuilder::new(ApproxInvKind::SPAI)
            .levels(1)
            .build_spai(&a)
            .expect("spai build");
        let rhs = lift(&[1.5, -0.5, 0.25]);
        let mut direct = vec![S::zero(); rhs.len()];
        pc.apply(PcSide::Left, &rhs, &mut direct)
            .expect("spai apply direct");
        let mut bridged = vec![S::zero(); rhs.len()];
        let mut scratch = BridgeScratch::default();
        pc.apply_s(PcSide::Left, &rhs, &mut bridged, &mut scratch)
            .expect("spai apply_s");
        assert_scalars_close(&direct, &bridged);
    }

    /// Asserts identical structure and values within 1e-12.
    fn assert_csr_close(a: &CsrMatrix<f64>, b: &CsrMatrix<f64>) {
        assert_eq!(a.nrows(), b.nrows());
        assert_eq!(a.ncols(), b.ncols());
        assert_eq!(a.row_ptr(), b.row_ptr());
        assert_eq!(a.col_idx(), b.col_idx());
        assert_eq!(a.values().len(), b.values().len());
        for (va, vb) in a.values().iter().zip(b.values()) {
            assert!((va - vb).abs() < 1e-12);
        }
    }

    #[test]
    fn approxinv_parallel_toggle_matches() {
        let a = poisson_1d_matrix();
        let fsai_with = |parallel: bool| {
            ApproxInvBuilder::new(ApproxInvKind::FSAI)
                .levels(1)
                .parallel(parallel)
                .build_fsai(&a)
        };
        let fsai_seq = fsai_with(false).expect("fsai seq build");
        let fsai_par = fsai_with(true).expect("fsai par build");
        assert_csr_close(&fsai_seq.g, &fsai_par.g);

        let spai_with = |parallel: bool| {
            ApproxInvBuilder::new(ApproxInvKind::SPAI)
                .levels(1)
                .parallel(parallel)
                .build_spai(&a)
        };
        let spai_seq = spai_with(false).expect("spai seq build");
        let spai_par = spai_with(true).expect("spai par build");
        assert_csr_close(&spai_seq.m, &spai_par.m);
    }
}