#[cfg(feature = "complex")]
use crate::algebra::bridge::BridgeScratch;
use crate::algebra::prelude::*;
use crate::core::traits::MatVec;
use crate::error::KError;
use crate::matrix::convert::csr_from_linop;
use crate::matrix::op::{StructureId, ValuesId};
use crate::matrix::sparse::CsrMatrix;
#[cfg(feature = "complex")]
use crate::ops::kpc::KPreconditioner;
use crate::preconditioner::SparsityPattern;
use crate::preconditioner::legacy::Preconditioner;
#[cfg(feature = "complex")]
use crate::preconditioner::pc_bridge::apply_pc_s;
use faer::linalg::solvers::{SolveCore, SolveLstsq};
use std::any::TypeId;
use std::marker::PhantomData;
use std::sync::Arc;
/// Sparse approximate inverse (SPAI-style) preconditioner.
///
/// `setup` builds an explicit approximation `M ≈ A^{-1}` one column at a time by
/// solving a small least-squares problem restricted to that column's sparsity
/// pattern, and stores the result row-wise in `inv_rows`. `apply` then computes
/// `y = M x` as a sparse matrix-vector product.
pub struct ApproxInv<M, V, T> {
/// Candidate nonzero rows for each column of the approximate inverse:
/// `Manual` lists them explicitly per column, `Auto` derives them from the
/// operator when a row pattern can be recovered and otherwise uses every row.
pub pattern: SparsityPattern,
/// Pivot threshold in the normal-equations solve of the legacy path, and drop
/// threshold when building `inv_rows` (entries with magnitude `<= tol` are discarded).
pub tol: R,
/// Stored configuration; not read by the code in this file.
/// NOTE(review): presumably reserved for an adaptive SPAI variant — confirm.
pub max_iter: usize,
/// Stored configuration; not read by the code in this file.
pub nbsteps: usize,
/// Stored configuration; not read by the code in this file.
pub max_size: usize,
/// Stored configuration; not read by the code in this file.
pub max_new: usize,
/// Stored configuration; not read by the code in this file.
pub block_size: usize,
/// Stored configuration; not read by the code in this file.
pub cache_size: usize,
/// Stored configuration; not read by the code in this file.
pub verbose: bool,
/// Stored configuration; not read by the code in this file.
pub sp: bool,
/// Row-wise storage of the approximate inverse: `inv_rows[i]` holds the
/// retained `(j, M[i][j])` pairs of row `i`, in increasing `j`.
pub inv_rows: Vec<Vec<(usize, T)>>,
/// Optional stored operator; initialised to `None` by `new` and not written
/// by the code in this file.
pub a: Option<M>,
/// CSR copy of the operator captured by the last `LinOp`-based setup that
/// rebuilt the inverse.
pub csr: Option<Arc<CsrMatrix<T>>>,
/// Structure id of the operator used for the last rebuild; together with
/// `last_vid` it lets the `LinOp`-based setup skip unchanged operators.
pub last_sid: Option<StructureId>,
/// Values id of the operator used for the last rebuild.
pub last_vid: Option<ValuesId>,
/// Drop tolerance handed to `csr_from_linop` when converting the operator
/// (set to `1e-12` by `new`).
pub drop_tol: R,
/// Ties the vector type `V` to the struct without storing one.
_phantom: PhantomData<V>,
}
impl<M, V, T: KrystScalar<Real = R>> ApproxInv<M, V, T> {
    /// Creates an approximate-inverse preconditioner that has not been set up yet.
    ///
    /// Every tuning argument is stored verbatim. The inverse itself (`inv_rows`),
    /// the cached operator data and the structure/values ids start out empty
    /// until a `setup` call fills them; `drop_tol` starts at `1e-12`.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        pattern: SparsityPattern,
        tol: R,
        max_iter: usize,
        nbsteps: usize,
        max_size: usize,
        max_new: usize,
        block_size: usize,
        cache_size: usize,
        verbose: bool,
        sp: bool,
    ) -> Self {
        Self {
            // Caller-supplied configuration.
            pattern,
            tol,
            max_iter,
            nbsteps,
            max_size,
            max_new,
            block_size,
            cache_size,
            verbose,
            sp,
            drop_tol: 1e-12,
            // State that only `setup` populates.
            inv_rows: vec![],
            a: None,
            csr: None,
            last_sid: None,
            last_vid: None,
            _phantom: PhantomData,
        }
    }
}
impl<M: 'static + Sync, V: Sync, T> Preconditioner<M, V> for ApproxInv<M, V, T>
where
    M: MatVec<V>,
    V: From<Vec<T>> + AsRef<[T]> + AsMut<[T]> + Clone,
    T: KrystScalar<Real = R> + 'static + Send + Sync,
{
    /// Builds the approximate inverse `M ≈ A^{-1}` one column at a time.
    ///
    /// For column `j` the admissible nonzero rows come from `self.pattern`:
    /// `Manual` uses the `j`-th list, and `Auto` uses the row pattern of `a` when
    /// `get_row_pattern` can recover one, otherwise every row. The coefficients
    /// minimise `||A[:, pattern] * c - e_j||_2`:
    /// - when `T` is `f64`, with faer (full-pivot LU for a square pattern, QR least
    ///   squares otherwise);
    /// - for any other scalar, by Gaussian elimination with partial pivoting on
    ///   the normal equations, treating pivots below `self.tol` as zero.
    ///
    /// Entries with magnitude `<= self.tol` are dropped from `inv_rows`.
    ///
    /// # Errors
    /// - `KError::Unsupported` when the pattern is `Auto` and the row count of `a`
    ///   cannot be recovered.
    /// - `KError::InvalidInput` when a pattern references a row index `>= n`.
    fn setup(&mut self, a: &M) -> Result<(), KError> {
        let n = match &self.pattern {
            SparsityPattern::Manual(pat) => pat.len(),
            SparsityPattern::Auto => {
                if let Some(nrows) = get_nrows(a) {
                    nrows
                } else {
                    return Err(KError::Unsupported(
                        "SparsityPattern::Auto requires nrows() or row_indices() support",
                    ));
                }
            }
        };
        // The operator's row pattern (if recoverable) does not depend on `j`.
        let row_pattern = get_row_pattern(a);
        // Lazily computed columns `A * e_i`: each column costs at most one matvec,
        // instead of one matvec every time it appears in some column's pattern.
        let mut a_cols: Vec<Option<Vec<T>>> = vec![None; n];
        // cols[j] is column j of the approximate inverse.
        let mut cols = vec![vec![T::zero(); n]; n];
        for j in 0..n {
            let pattern: Vec<usize> = match &self.pattern {
                SparsityPattern::Auto => match row_pattern {
                    Some(rowpat) => rowpat.row_indices(j).to_vec(),
                    None => (0..n).collect(),
                },
                SparsityPattern::Manual(pat) => pat.get(j).cloned().unwrap_or_default(),
            };
            // Reject out-of-range rows here instead of panicking while building e_i.
            if let Some(&bad) = pattern.iter().find(|&&i| i >= n) {
                return Err(KError::InvalidInput(format!(
                    "ApproxInv: pattern for column {} references row {}, but the operator has {} rows",
                    j, bad, n
                )));
            }
            let m = pattern.len();
            if m == 0 {
                // No admissible nonzeros: column j of the inverse stays zero.
                continue;
            }
            for &i in &pattern {
                if a_cols[i].is_none() {
                    let mut ei = V::from(vec![T::zero(); n]);
                    ei.as_mut()[i] = T::one();
                    let mut col = V::from(vec![T::zero(); n]);
                    a.matvec(&ei, &mut col);
                    a_cols[i] = Some(col.as_ref().to_vec());
                }
            }
            // b[r] is column pattern[r] of A, i.e. B = A[:, pattern] stored by columns.
            let b: Vec<&[T]> = pattern
                .iter()
                .map(|&i| a_cols[i].as_deref().expect("column cached above"))
                .collect();
            // coeffs[k] is the entry M[pattern[k]][j].
            let coeffs: Vec<T> = if TypeId::of::<T>() == TypeId::of::<f64>() {
                use faer::linalg::solvers::{FullPivLu, Qr};
                use faer::{Mat, MatMut};
                let b_f64 = Mat::from_fn(n, m, |row, col| b[col][row].real());
                let mut e_j = vec![0.0f64; n];
                e_j[j] = 1.0;
                let sol: Vec<f64> = if m == n {
                    let lu = FullPivLu::new(b_f64.as_ref());
                    let x_mat = MatMut::from_column_major_slice_mut(&mut e_j, n, 1);
                    lu.solve_in_place_with_conj(faer::Conj::No, x_mat);
                    e_j
                } else {
                    let rhs = Mat::from_fn(n, 1, |i, _| e_j[i]);
                    let sol_mat = Qr::new(b_f64.as_ref()).solve_lstsq(rhs);
                    (0..m).map(|k| sol_mat[(k, 0)]).collect()
                };
                sol.into_iter().map(T::from_real).collect()
            } else {
                // Normal equations (B^T B) c = B^T e_j.
                // NOTE(review): for complex T this forms B^T B rather than B^H B
                // (no conjugation) — confirm whether that is intended.
                let mut bt_b = vec![vec![T::zero(); m]; m];
                for r in 0..m {
                    for c in 0..m {
                        bt_b[r][c] = b[r]
                            .iter()
                            .zip(b[c].iter())
                            .fold(T::zero(), |acc, (&p, &q)| acc + p * q);
                    }
                }
                // (B^T e_j)[r] is simply entry j of column pattern[r] of A.
                let mut coeffs: Vec<T> = b.iter().map(|col| col[j]).collect();
                // Forward elimination with partial pivoting; near-zero pivots are skipped.
                for k in 0..m {
                    let mut max_row = k;
                    for r in (k + 1)..m {
                        if bt_b[r][k].abs() > bt_b[max_row][k].abs() {
                            max_row = r;
                        }
                    }
                    if max_row != k {
                        bt_b.swap(k, max_row);
                        coeffs.swap(k, max_row);
                    }
                    let pivot = bt_b[k][k];
                    if pivot.abs() < self.tol {
                        continue;
                    }
                    for r in (k + 1)..m {
                        let f = bt_b[r][k] / pivot;
                        for c in k..m {
                            bt_b[r][c] = bt_b[r][c] - f * bt_b[k][c];
                        }
                        coeffs[r] = coeffs[r] - f * coeffs[k];
                    }
                }
                // Back substitution; unknowns with a near-zero pivot are set to zero.
                for k in (0..m).rev() {
                    let mut sum = coeffs[k];
                    for c in (k + 1)..m {
                        sum = sum - bt_b[k][c] * coeffs[c];
                    }
                    let pivot = bt_b[k][k];
                    coeffs[k] = if pivot.abs() < self.tol {
                        T::zero()
                    } else {
                        sum / pivot
                    };
                }
                coeffs
            };
            for (&row_i, &value) in pattern.iter().zip(coeffs.iter()) {
                cols[j][row_i] = value;
            }
        }
        // Transpose the column-wise result into row-wise sparse storage,
        // dropping entries at or below `tol`.
        self.inv_rows = vec![vec![]; n];
        for i in 0..n {
            for j in 0..n {
                if cols[j][i].abs() > self.tol {
                    self.inv_rows[i].push((j, cols[j][i]));
                }
            }
        }
        Ok(())
    }

    /// Computes `y = M x` using the row-wise sparse inverse built by `setup`.
    ///
    /// # Errors
    /// `KError::InvalidInput` when `x` or `y` does not have exactly one entry per
    /// stored row of the inverse (for example when `setup` has not run).
    fn apply(&self, _side: crate::preconditioner::PcSide, x: &V, y: &mut V) -> Result<(), KError> {
        let n = self.inv_rows.len();
        let x_ref = x.as_ref();
        let y_mut = y.as_mut();
        if x_ref.len() != n || y_mut.len() != n {
            return Err(KError::InvalidInput(format!(
                "ApproxInv.apply: expected vectors of length {}, got x={} y={}",
                n,
                x_ref.len(),
                y_mut.len()
            )));
        }
        // Dot product of one sparse row of M with x.
        let row_dot =
            |row: &[(usize, T)]| row.iter().fold(T::zero(), |acc, &(j, mij)| acc + mij * x_ref[j]);
        #[cfg(feature = "rayon")]
        {
            use rayon::prelude::*;
            y_mut
                .par_iter_mut()
                .zip(self.inv_rows.par_iter())
                .for_each(|(yi, row)| *yi = row_dot(row.as_slice()));
        }
        #[cfg(not(feature = "rayon"))]
        {
            for (yi, row) in y_mut.iter_mut().zip(self.inv_rows.iter()) {
                *yi = row_dot(row.as_slice());
            }
        }
        Ok(())
    }
}
#[cfg(not(feature = "complex"))]
impl<M: 'static + Send + Sync> crate::preconditioner::Preconditioner for ApproxInv<M, Vec<f64>, f64>
where
    M: MatVec<Vec<f64>>,
{
    /// Builds the approximate inverse from a `LinOp`. The previous result is
    /// reused when the operator's structure and values ids are unchanged.
    ///
    /// The operator is converted to CSR (with `self.drop_tol`) and then to a
    /// dense matrix `A`. For each column `j`, the candidate rows come from
    /// `self.pattern`; `Auto` falls back to every row when no row pattern can be
    /// recovered. The column solves `min ||A[:, pattern] * c - e_j||_2` with faer:
    /// full-pivot LU for square patterns and QR least squares otherwise. When a
    /// column's pattern is exactly `0, 1, …, n-1`, the LU of `A` is factorised
    /// once and shared by all such columns. Entries with magnitude `<= self.tol`
    /// are dropped.
    ///
    /// # Errors
    /// Propagates CSR/dense conversion errors. Returns `KError::InvalidInput` in
    /// any of these cases:
    /// - the length of a manual pattern differs from the operator size;
    /// - a pattern references a row `>= n`;
    /// - a column pattern has more than `n` entries.
    fn setup(&mut self, op: &dyn crate::matrix::op::LinOp<S = f64>) -> Result<(), KError> {
        use faer::linalg::solvers::{FullPivLu, Qr};
        use faer::{Mat, MatMut};

        let sid = op.structure_id();
        let vid = op.values_id();
        // Nothing changed since the last rebuild: skip the CSR conversion entirely.
        if self.last_sid == Some(sid) && self.last_vid == Some(vid) {
            return Ok(());
        }
        let csr = csr_from_linop(op, self.drop_tol)?;
        let a_dense = csr.to_dense()?;
        let n = a_dense.nrows();
        if let SparsityPattern::Manual(pat) = &self.pattern {
            if pat.len() != n {
                return Err(KError::InvalidInput(
                    "ApproxInv: operator size mismatch with manual pattern".into(),
                ));
            }
        }
        // Loop-invariant: recovered at most once instead of once per column.
        let row_pattern = get_row_pattern(&csr);
        // LU of the full matrix A, shared by every column whose pattern is 0..n.
        let mut full_lu: Option<FullPivLu<f64>> = None;
        // cols[j] is column j of the approximate inverse.
        let mut cols = vec![vec![0.0f64; n]; n];
        for j in 0..n {
            let pattern_idx: Vec<usize> = match &self.pattern {
                SparsityPattern::Auto => match row_pattern {
                    Some(rowpat) => rowpat.row_indices(j).to_vec(),
                    None => (0..n).collect(),
                },
                SparsityPattern::Manual(pat) => pat.get(j).cloned().unwrap_or_default(),
            };
            let m = pattern_idx.len();
            if m == 0 {
                // No admissible nonzeros: column j of the inverse stays zero.
                continue;
            }
            if let Some(&bad) = pattern_idx.iter().find(|&&c| c >= n) {
                return Err(KError::InvalidInput(format!(
                    "ApproxInv: pattern for column {} references row {}, but the operator has {} rows",
                    j, bad, n
                )));
            }
            if m > n {
                // faer's least-squares solve requires nrows >= ncols.
                return Err(KError::InvalidInput(format!(
                    "ApproxInv: pattern for column {} has {} entries, more than the {} rows",
                    j, m, n
                )));
            }
            let mut e_j = vec![0.0f64; n];
            e_j[j] = 1.0;
            let sol: Vec<f64> = if m == n {
                let is_identity = pattern_idx.iter().enumerate().all(|(k, &c)| k == c);
                let x_mat = MatMut::from_column_major_slice_mut(&mut e_j, n, 1);
                if is_identity {
                    let lu = full_lu.get_or_insert_with(|| {
                        FullPivLu::new(Mat::from_fn(n, n, |r, c| a_dense[(r, c)]).as_ref())
                    });
                    lu.solve_in_place_with_conj(faer::Conj::No, x_mat);
                } else {
                    let b_f64 = Mat::from_fn(n, m, |r, k| a_dense[(r, pattern_idx[k])]);
                    FullPivLu::new(b_f64.as_ref())
                        .solve_in_place_with_conj(faer::Conj::No, x_mat);
                }
                e_j
            } else {
                // B = A[:, pattern_idx] as an n x m faer matrix.
                let b_f64 = Mat::from_fn(n, m, |r, k| a_dense[(r, pattern_idx[k])]);
                let rhs = Mat::from_fn(n, 1, |i, _| e_j[i]);
                let sol_mat = Qr::new(b_f64.as_ref()).solve_lstsq(rhs);
                (0..m).map(|k| sol_mat[(k, 0)]).collect()
            };
            for (&row_i, &value) in pattern_idx.iter().zip(sol.iter()) {
                cols[j][row_i] = value;
            }
        }
        // Transpose into row-wise sparse storage, dropping entries at or below `tol`.
        self.inv_rows = vec![vec![]; n];
        for i in 0..n {
            for j in 0..n {
                if cols[j][i].abs() > self.tol {
                    self.inv_rows[i].push((j, cols[j][i]));
                }
            }
        }
        self.csr = Some(csr);
        self.last_sid = Some(sid);
        self.last_vid = Some(vid);
        Ok(())
    }

    /// Computes `y = M x` using the row-wise sparse inverse built by `setup`.
    ///
    /// # Errors
    /// `KError::InvalidInput` when `x` and `y` differ in length, or when their
    /// length differs from the number of stored rows (e.g. `setup` has not run).
    fn apply(
        &self,
        _side: crate::preconditioner::PcSide,
        x: &[f64],
        y: &mut [f64],
    ) -> Result<(), KError> {
        if x.len() != y.len() {
            return Err(KError::InvalidInput(format!(
                "ApproxInv.apply: x/y length mismatch: {} vs {}",
                x.len(),
                y.len()
            )));
        }
        if x.len() != self.inv_rows.len() {
            return Err(KError::InvalidInput(format!(
                "ApproxInv.apply: vector length {} does not match preconditioner size {}",
                x.len(),
                self.inv_rows.len()
            )));
        }
        // Dot product of one sparse row of M with x (sequential left fold from 0.0).
        let row_dot =
            |row: &[(usize, f64)]| row.iter().fold(0.0f64, |acc, &(j, val)| acc + val * x[j]);
        #[cfg(feature = "rayon")]
        {
            use rayon::prelude::*;
            y.par_iter_mut()
                .zip(self.inv_rows.par_iter())
                .for_each(|(yi, row)| *yi = row_dot(row.as_slice()));
        }
        #[cfg(not(feature = "rayon"))]
        {
            for (yi, row) in y.iter_mut().zip(self.inv_rows.iter()) {
                *yi = row_dot(row.as_slice());
            }
        }
        Ok(())
    }
}
#[cfg(feature = "complex")]
impl<M: 'static + Send + Sync> crate::preconditioner::Preconditioner for ApproxInv<M, Vec<f64>, f64>
where
    M: MatVec<Vec<f64>>,
{
    /// Always fails: building the approximate inverse is not implemented for
    /// complex-scalar builds.
    fn setup(&mut self, _op: &dyn crate::matrix::op::LinOp<S = S>) -> Result<(), KError> {
        Err(approx_inv_complex_unsupported())
    }

    /// Always fails, for the same reason as `setup`.
    fn apply(&self, _side: crate::preconditioner::PcSide, _x: &[S], _y: &mut [S]) -> Result<(), KError> {
        Err(approx_inv_complex_unsupported())
    }
}

/// Error returned by every complex-build entry point of `ApproxInv`.
#[cfg(feature = "complex")]
fn approx_inv_complex_unsupported() -> KError {
    KError::Unsupported("ApproxInv does not support complex scalars yet".into())
}
#[cfg(feature = "complex")]
impl<M> KPreconditioner for ApproxInv<M, Vec<f64>, f64>
where
    M: MatVec<Vec<f64>> + Send + Sync + 'static,
{
    type Scalar = S;

    /// Square dimensions of the stored inverse: one row per entry of `inv_rows`.
    #[inline]
    fn dims(&self) -> (usize, usize) {
        let rows = self.inv_rows.len();
        (rows, rows)
    }

    /// Applies the preconditioner to scalar-typed vectors by delegating to
    /// `apply_pc_s`, presumably the bridge onto this type's `Preconditioner`
    /// implementation (NOTE(review): confirm in `pc_bridge`).
    fn apply_s(
        &self,
        side: crate::preconditioner::PcSide,
        x: &[S],
        y: &mut [S],
        scratch: &mut BridgeScratch,
    ) -> Result<(), KError> {
        let outcome = apply_pc_s(self, side, x, y, scratch);
        outcome
    }
}
/// Tries to recover the row count of `a` through a runtime downcast.
///
/// `Any::downcast_ref` only succeeds when the concrete type of `a` is exactly
/// the requested type. This therefore returns `Some` only when `M` itself is
/// `&'static dyn Indexing` (checked first) or `&'static dyn MatShape`. For any
/// other concrete operator type it yields `None`, even if that type implements
/// those traits.
fn get_nrows<M: 'static>(a: &M) -> Option<usize> {
    use crate::core::traits::Indexing;
    let any: &dyn std::any::Any = a;
    any.downcast_ref::<&dyn Indexing>()
        .map(|indexed| indexed.nrows())
        .or_else(|| {
            any.downcast_ref::<&dyn crate::core::traits::MatShape>()
                .map(|shaped| shaped.nrows())
        })
}
/// Tries to view `a` as a `RowPattern` through a runtime downcast.
///
/// As with `Any::downcast_ref` in general, this succeeds only when `M` is
/// exactly `&'static dyn RowPattern`. Any other concrete type yields `None`,
/// and callers then fall back to a dense pattern.
fn get_row_pattern<M: 'static>(a: &M) -> Option<&dyn crate::core::traits::RowPattern> {
    (a as &dyn std::any::Any)
        .downcast_ref::<&dyn crate::core::traits::RowPattern>()
        .copied()
}
#[cfg(all(test, not(feature = "complex")))]
mod tests {
    use super::*;
    // NOTE(review): this module is compiled only when the `complex` feature is
    // OFF, so the `cfg(feature = "complex")` imports and test below never build.
    #[cfg(feature = "complex")]
    use crate::algebra::bridge::BridgeScratch;
    use crate::core::traits::MatVec;
    use crate::preconditioner::legacy::Preconditioner;
    use approx::assert_relative_eq;
    #[cfg(feature = "complex")]
    use std::sync::Arc;

    /// Dense row-major test matrix.
    ///
    /// `patterns[i]` caches the column indices of the nonzeros in row `i`. This
    /// lets `RowPattern::row_indices` hand out a slice borrowed from `self`
    /// without any unsafe lifetime extension.
    #[derive(Clone)]
    struct DenseMat<T> {
        data: Vec<Vec<T>>,
        patterns: Vec<Vec<usize>>,
    }

    impl<T: KrystScalar<Real = R>> DenseMat<T> {
        /// Wraps `data` (one `Vec` per row) and precomputes each row's nonzero columns.
        fn new(data: Vec<Vec<T>>) -> Self {
            let patterns: Vec<Vec<usize>> = data
                .iter()
                .map(|row| {
                    row.iter()
                        .enumerate()
                        .filter(|&(_, &val)| val != T::zero())
                        .map(|(j, _)| j)
                        .collect::<Vec<usize>>()
                })
                .collect();
            Self { data, patterns }
        }
    }

    impl<T: KrystScalar<Real = R>> MatVec<Vec<T>> for DenseMat<T> {
        fn matvec(&self, x: &Vec<T>, y: &mut Vec<T>) {
            for (i, row) in self.data.iter().enumerate() {
                y[i] = row
                    .iter()
                    .zip(x.iter())
                    .map(|(a, b)| *a * *b)
                    .fold(T::zero(), |acc, v| acc + v);
            }
        }
    }

    impl<T: KrystScalar<Real = R>> crate::core::traits::RowPattern for DenseMat<T> {
        /// Column indices of the nonzeros in row `i`, precomputed by `DenseMat::new`.
        fn row_indices(&self, i: usize) -> &[usize] {
            &self.patterns[i]
        }
    }

    /// `n x n` identity matrix.
    fn eye<T: KrystScalar<Real = R>>(n: usize) -> DenseMat<T> {
        DenseMat::new(
            (0..n)
                .map(|i| {
                    (0..n)
                        .map(|j| if i == j { T::one() } else { T::zero() })
                        .collect()
                })
                .collect(),
        )
    }

    #[test]
    fn approxinv_exact_inverse() {
        let a: DenseMat<f64> = DenseMat::new(vec![
            vec![2.0, 0.0, 0.0],
            vec![0.0, 3.0, 0.0],
            vec![0.0, 0.0, 4.0],
        ]);
        let pattern = SparsityPattern::Manual(vec![vec![0], vec![1], vec![2]]);
        let mut spai = ApproxInv::<DenseMat<f64>, Vec<f64>, f64>::new(
            pattern, 1e-12, 10, 1, 100, 8, 1, 0, false, false,
        );
        spai.setup(&a).unwrap();
        let inv = &spai.inv_rows;
        assert_relative_eq!(inv[0][0].1, 0.5, epsilon = 1e-12);
        assert_relative_eq!(inv[1][0].1, 1.0 / 3.0, epsilon = 1e-12);
        assert_relative_eq!(inv[2][0].1, 0.25, epsilon = 1e-12);
    }

    #[test]
    fn approxinv_apply_vector() {
        let a: DenseMat<f64> = DenseMat::new(vec![vec![4.0, 1.0], vec![2.0, 3.0]]);
        // A fully dense pattern makes every column problem square, so setup
        // reproduces the exact inverse and a tight tolerance is appropriate.
        let pattern = SparsityPattern::Manual(vec![vec![0, 1], vec![0, 1]]);
        let mut spai = ApproxInv::<DenseMat<f64>, Vec<f64>, f64>::new(
            pattern, 1e-12, 10, 1, 100, 8, 1, 0, false, false,
        );
        spai.setup(&a).unwrap();
        let x = vec![1.0, 2.0];
        let mut y = vec![0.0, 0.0];
        spai.apply(crate::preconditioner::PcSide::Left, &x, &mut y)
            .unwrap();
        // A^{-1} for A = [[4, 1], [2, 3]]: det = 10, so A^{-1} = [[3, -1], [-2, 4]] / 10.
        let a_inv = faer::Mat::<f64>::from_fn(2, 2, |i, j| match (i, j) {
            (0, 0) => 0.3,
            (0, 1) => -0.1,
            (1, 0) => -0.2,
            (1, 1) => 0.4,
            _ => 0.0,
        });
        let x_vec = faer::Mat::<f64>::from_fn(2, 1, |i, _| x[i]);
        let y_expected = &a_inv * &x_vec;
        assert_relative_eq!(y[0], y_expected[(0, 0)], epsilon = 1e-12);
        assert_relative_eq!(y[1], y_expected[(1, 0)], epsilon = 1e-12);
    }

    #[test]
    fn approxinv_identity() {
        let a = eye::<f64>(4);
        let pattern = SparsityPattern::Manual(vec![vec![0], vec![1], vec![2], vec![3]]);
        let mut spai = ApproxInv::<DenseMat<f64>, Vec<f64>, f64>::new(
            pattern, 1e-12, 10, 1, 100, 8, 1, 0, false, false,
        );
        spai.setup(&a).unwrap();
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let mut y = vec![0.0; 4];
        spai.apply(crate::preconditioner::PcSide::Left, &x, &mut y)
            .unwrap();
        assert_relative_eq!(x[0], y[0], epsilon = 1e-12);
        assert_relative_eq!(x[1], y[1], epsilon = 1e-12);
        assert_relative_eq!(x[2], y[2], epsilon = 1e-12);
        assert_relative_eq!(x[3], y[3], epsilon = 1e-12);
    }

    // NOTE(review): unreachable under this module's `not(feature = "complex")` cfg.
    #[cfg(feature = "complex")]
    #[test]
    fn approxinv_apply_s_matches_real_path() {
        use crate::matrix::Csr;
        use crate::matrix::op::CsrOp;
        use crate::matrix::sparse::CsrMatrix;
        let n = 3;
        let pattern = SparsityPattern::Manual((0..n).map(|i| vec![i]).collect());
        let mut spai = ApproxInv::<CsrOp<f64>, Vec<f64>, f64>::new(
            pattern, 1e-12, 10, 1, 100, 8, 1, 0, false, false,
        );
        let csr = Arc::new(CsrMatrix::<f64>::identity(n));
        let op: CsrOp<f64> = CsrOp::new(csr);
        spai.setup(&op)
            .expect("setup should succeed for identity operator");
        let rhs_real = vec![1.0, 2.0, 3.0];
        let mut out_real = vec![0.0; n];
        crate::preconditioner::Preconditioner::apply(
            &spai,
            crate::preconditioner::PcSide::Left,
            &rhs_real,
            &mut out_real,
        )
        .expect("real apply should succeed");
        let rhs_s: Vec<S> = rhs_real.iter().copied().map(S::from_real).collect();
        let mut out_s = vec![S::zero(); n];
        let mut scratch = BridgeScratch::default();
        crate::ops::kpc::KPreconditioner::apply_s(
            &spai,
            crate::preconditioner::PcSide::Left,
            &rhs_s,
            &mut out_s,
            &mut scratch,
        )
        .expect("apply_s should bridge to real implementation");
        for (expected, actual) in out_real.iter().zip(out_s.iter()) {
            assert_relative_eq!(*expected, actual.real(), epsilon = 1e-12);
            assert_relative_eq!(0.0, actual.imag(), epsilon = 1e-12);
        }
    }

    #[test]
    fn debug_faer_lu_inverse_rows() {
        use faer::linalg::solvers::FullPivLu;
        use faer::{Mat, MatMut};
        // Built transposed on purpose: `a` is A^T for A = [[4, 1], [2, 3]], so
        // solving `a x = e_i` yields row `i` of A^{-1}.
        let a = Mat::from_fn(2, 2, |j, i| match (i, j) {
            (0, 0) => 4.0,
            (0, 1) => 1.0,
            (1, 0) => 2.0,
            (1, 1) => 3.0,
            _ => 0.0,
        });
        let lu = FullPivLu::new(a.as_ref());
        let mut inv = vec![vec![0.0; 2]; 2];
        for i in 0..2 {
            let mut e = vec![0.0; 2];
            e[i] = 1.0;
            let mut x = e.clone();
            let x_mat = MatMut::from_column_major_slice_mut(&mut x, 2, 1);
            lu.solve_in_place_with_conj(faer::Conj::No, x_mat);
            for j in 0..2 {
                inv[i][j] = x[j];
            }
        }
        // Rows of A^{-1} = [[3, -1], [-2, 4]] / 10.
        let expected = [[0.3, -0.1], [-0.2, 0.4]];
        for (row, expected_row) in inv.iter().zip(expected.iter()) {
            for (&got, &want) in row.iter().zip(expected_row.iter()) {
                assert_relative_eq!(got, want, epsilon = 1e-12);
            }
        }
    }
}