use crate::csr::CsrMatrix;
use crate::error::{SparseError, SparseResult};
use crate::linalg::interface::LinearOperator;
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::fmt::Debug;
use std::iter::Sum;
/// Sparse approximate inverse (SPAI) preconditioner.
///
/// Holds a CSR matrix built by `SpaiPreconditioner::new`; the `LinearOperator`
/// implementation multiplies a vector by that matrix (`matvec`) or by its
/// transpose (`rmatvec`).
pub struct SpaiPreconditioner<F> {
/// Approximate inverse of the matrix handed to `new`, stored in CSR form.
approx_inverse: CsrMatrix<F>,
}
/// Tuning knobs for building a `SpaiPreconditioner`.
///
/// All fields are plain values, so the options can be cloned and printed
/// freely (for example when logging a solver configuration).
#[derive(Debug, Clone)]
pub struct SpaiOptions {
    /// Upper bound on the number of entries kept in each column of the
    /// approximate inverse. The diagonal entry is always kept, so values of
    /// `0` and `1` both yield a purely diagonal pattern.
    pub max_nnz_per_col: usize,
    /// Tolerance for a least-squares solve.
    ///
    /// NOTE(review): `SpaiPreconditioner::new` in this file solves the normal
    /// equations directly and never reads this field — presumably reserved for
    /// an iterative least-squares solver; confirm before relying on it.
    pub ls_tolerance: f64,
    /// Iteration cap for a least-squares solve.
    ///
    /// NOTE(review): not read by `SpaiPreconditioner::new` in this file; see
    /// the note on `ls_tolerance`.
    pub max_ls_iters: usize,
}
impl Default for SpaiOptions {
fn default() -> Self {
Self {
max_nnz_per_col: 10,
ls_tolerance: 1e-10,
max_ls_iters: 100,
}
}
}
impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> SpaiPreconditioner<F> {
    /// Builds a sparse approximate inverse `M` of the square matrix `matrix`.
    ///
    /// Each column `j` of `M` is restricted to a fixed pattern (row `j` plus
    /// its neighbours within two rows, capped by `options.max_nnz_per_col`).
    /// Its values minimise `|| A * m_j - e_j ||_2` over that pattern; the
    /// minimiser is obtained from the normal equations, solved with
    /// `solve_dense_system`.
    ///
    /// Entries whose magnitude does not exceed `F::epsilon()` are dropped from
    /// the result. Only `options.max_nnz_per_col` is consulted; the other
    /// option fields are not read here.
    ///
    /// Results are accumulated in per-row sparse buckets, so the storage for
    /// `M` is proportional to its number of stored entries rather than to
    /// `n * n`.
    ///
    /// # Errors
    ///
    /// * `SparseError::DimensionMismatch` if `matrix` is not square.
    /// * Any error returned by `solve_dense_system` (for example when the
    ///   normal equations of a column are singular).
    /// * Any error returned by `CsrMatrix::from_raw_csr`.
    pub fn new(matrix: &CsrMatrix<F>, options: SpaiOptions) -> SparseResult<Self> {
        let n = matrix.rows();
        if n != matrix.cols() {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: matrix.cols(),
            });
        }

        // rows[i] collects the `(column, value)` entries of row `i` of `M`.
        // Columns are processed in increasing order, so every bucket ends up
        // sorted by column index without an explicit sort.
        let mut rows: Vec<Vec<(usize, F)>> = vec![Vec::new(); n];

        for j in 0..n {
            let pattern = Self::column_pattern(j, n, options.max_nnz_per_col);

            // Dense copies of the columns of `A` selected by the pattern.
            let columns: Vec<Vec<F>> = pattern
                .iter()
                .map(|&col| (0..n).map(|row| matrix.get(row, col)).collect())
                .collect();

            // Normal equations: (A_k^T A_k) m = A_k^T e_j.
            let ata: Vec<Vec<F>> = columns
                .iter()
                .map(|ci| columns.iter().map(|cj| Self::dot(ci, cj)).collect())
                .collect();
            // A_k^T e_j is simply row `j` of the selected columns.
            let atb: Vec<F> = columns.iter().map(|column| column[j]).collect();

            let m_k = solve_dense_system(&ata, &atb)?;

            for (&row, &value) in pattern.iter().zip(m_k.iter()) {
                if value.abs() > F::epsilon() {
                    rows[row].push((j, value));
                }
            }
        }

        // Flatten the row buckets into raw CSR arrays.
        let nnz: usize = rows.iter().map(Vec::len).sum();
        let mut data = Vec::with_capacity(nnz);
        let mut indices = Vec::with_capacity(nnz);
        let mut indptr = Vec::with_capacity(n + 1);
        indptr.push(0);
        for bucket in &rows {
            for &(col, value) in bucket {
                data.push(value);
                indices.push(col);
            }
            indptr.push(data.len());
        }

        let approx_inverse = CsrMatrix::from_raw_csr(data, indptr, indices, (n, n))?;
        Ok(Self { approx_inverse })
    }

    /// Row indices allowed to be non-zero in column `j` of the approximate
    /// inverse: `j` itself first, then its neighbours within two rows in
    /// ascending order, truncated so that at most `max(1, max_nnz)` indices
    /// are returned.
    fn column_pattern(j: usize, n: usize, max_nnz: usize) -> Vec<usize> {
        const HALF_BANDWIDTH: usize = 2;
        let lo = j.saturating_sub(HALF_BANDWIDTH);
        let hi = (j + HALF_BANDWIDTH + 1).min(n);
        std::iter::once(j)
            .chain(
                (lo..hi)
                    .filter(|&k| k != j)
                    .take(max_nnz.saturating_sub(1)),
            )
            .collect()
    }

    /// Dot product of two equally long vectors, accumulated front to back.
    fn dot(u: &[F], v: &[F]) -> F {
        u.iter()
            .zip(v)
            .fold(F::sparse_zero(), |acc, (&a, &b)| acc + a * b)
    }
}
impl<F: Float + NumAssign + Sum + Debug + SparseElement + 'static> LinearOperator<F>
    for SpaiPreconditioner<F>
{
    /// Shape `(rows, cols)` of the stored approximate inverse.
    fn shape(&self) -> (usize, usize) {
        self.approx_inverse.shape()
    }

    /// Computes `M * x`, where `M` is the stored approximate inverse.
    ///
    /// Returns `SparseError::DimensionMismatch` when `x.len()` differs from
    /// the number of columns of `M`.
    fn matvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let m = &self.approx_inverse;
        let expected = m.cols();
        if x.len() != expected {
            return Err(SparseError::DimensionMismatch {
                expected,
                found: x.len(),
            });
        }

        let row_count = m.rows();
        let mut y = Vec::with_capacity(row_count);
        for r in 0..row_count {
            // Dot product of CSR row `r` with `x`.
            let mut acc = F::sparse_zero();
            for p in m.indptr[r]..m.indptr[r + 1] {
                acc += m.data[p] * x[m.indices[p]];
            }
            y.push(acc);
        }
        Ok(y)
    }

    /// The transpose product is available through `rmatvec`.
    fn has_adjoint(&self) -> bool {
        true
    }

    /// Computes `M^T * x`, where `M` is the stored approximate inverse.
    ///
    /// Returns `SparseError::DimensionMismatch` when `x.len()` differs from
    /// the number of rows of `M`.
    fn rmatvec(&self, x: &[F]) -> SparseResult<Vec<F>> {
        let m = &self.approx_inverse;
        let expected = m.rows();
        if x.len() != expected {
            return Err(SparseError::DimensionMismatch {
                expected,
                found: x.len(),
            });
        }

        let mut y = vec![F::sparse_zero(); m.cols()];
        // Scatter each CSR row, scaled by the matching entry of `x`, into the
        // output positions given by its column indices.
        x.iter().enumerate().for_each(|(r, &weight)| {
            (m.indptr[r]..m.indptr[r + 1]).for_each(|p| {
                y[m.indices[p]] += m.data[p] * weight;
            });
        });
        Ok(y)
    }
}
/// Solves the dense square system `a * x = b` by Gaussian elimination with
/// partial (row) pivoting followed by back substitution.
///
/// `a` is given as a slice of rows; neither `a` nor `b` is modified.
///
/// # Errors
///
/// * `SparseError::DimensionMismatch` if `a` is empty, if any row of `a` does
///   not have exactly `a.len()` entries (`found` is that row's length), or if
///   `b.len() != a.len()` (`found` is `b.len()`).
/// * `SparseError::ValueError` if the largest available pivot in some column
///   is below `1e-14`, i.e. the matrix is singular or nearly singular, or if
///   that threshold cannot be represented in `F`.
fn solve_dense_system<F: Float + NumAssign + SparseElement>(
    a: &[Vec<F>],
    b: &[F],
) -> SparseResult<Vec<F>> {
    let n = a.len();
    if n == 0 {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }
    // Every row must have `n` entries; checking only the first one would let
    // a ragged matrix panic on indexing further down.
    if let Some(bad_row) = a.iter().find(|row| row.len() != n) {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: bad_row.len(),
        });
    }
    if b.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: b.len(),
        });
    }

    // Loop-invariant singularity threshold, converted once.
    let pivot_floor = F::from(1e-14).ok_or_else(|| {
        SparseError::ValueError("Failed to convert constant to float".to_string())
    })?;

    // Augmented matrix [a | b].
    let mut aug: Vec<Vec<F>> = a
        .iter()
        .zip(b)
        .map(|(row, &rhs)| {
            let mut augmented_row = Vec::with_capacity(n + 1);
            augmented_row.extend_from_slice(row);
            augmented_row.push(rhs);
            augmented_row
        })
        .collect();

    for k in 0..n {
        // Partial pivoting: pick the row at or below `k` whose entry in
        // column `k` has the largest magnitude.
        let mut max_row = k;
        let mut max_val = aug[k][k].abs();
        for (i, row) in aug.iter().enumerate().skip(k + 1) {
            let candidate = row[k].abs();
            if candidate > max_val {
                max_val = candidate;
                max_row = i;
            }
        }
        if max_val < pivot_floor {
            return Err(SparseError::ValueError(
                "Matrix is singular or nearly singular".to_string(),
            ));
        }
        if max_row != k {
            aug.swap(k, max_row);
        }

        // Eliminate column `k` from every row below the pivot row. Splitting
        // the matrix lets the pivot row be read while the rows below it are
        // updated in place.
        let (head, tail) = aug.split_at_mut(k + 1);
        let pivot_row = &head[k];
        let pivot = pivot_row[k];
        for row in tail.iter_mut() {
            let factor = row[k] / pivot;
            for (dst, &src) in row[k..].iter_mut().zip(&pivot_row[k..]) {
                *dst = *dst - factor * src;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![F::sparse_zero(); n];
    for i in (0..n).rev() {
        let mut acc = aug[i][n];
        for j in (i + 1)..n {
            acc = acc - aug[i][j] * x[j];
        }
        x[i] = acc / aug[i][i];
    }
    Ok(x)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;

    /// For a 3x3 matrix the neighbour pattern of every column covers all
    /// rows, so the approximate inverse should reproduce the exact inverse.
    #[test]
    fn test_spai_simple() {
        // Tridiagonal matrix with 4 on the diagonal and -1 next to it.
        let data = vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0];
        let indptr = vec![0, 2, 5, 7];
        let indices = vec![0, 1, 0, 1, 2, 1, 2];
        let matrix = CsrMatrix::from_raw_csr(data, indptr, indices, (3, 3))
            .expect("tridiagonal CSR input should be valid");
        let options = SpaiOptions::default();
        let preconditioner = SpaiPreconditioner::new(&matrix, options)
            .expect("SPAI construction should succeed for a non-singular matrix");

        let b = vec![1.0, 2.0, 3.0];
        let x = preconditioner
            .matvec(&b)
            .expect("matvec should accept a vector of matching length");
        assert!(x.iter().all(|&xi| xi.is_finite()));

        // A * (M * b) must give back b when M is the exact inverse.
        let residual = [
            4.0 * x[0] - x[1] - b[0],
            -x[0] + 4.0 * x[1] - x[2] - b[1],
            -x[1] + 4.0 * x[2] - b[2],
        ];
        assert!(residual.iter().all(|&r| r.abs() < 1e-10));
    }

    /// The approximate inverse of a diagonal matrix is the diagonal matrix of
    /// reciprocals.
    #[test]
    fn test_spai_diagonal() {
        let data = vec![2.0, 3.0, 4.0];
        let indptr = vec![0, 1, 2, 3];
        let indices = vec![0, 1, 2];
        let matrix = CsrMatrix::from_raw_csr(data, indptr, indices, (3, 3))
            .expect("diagonal CSR input should be valid");
        let options = SpaiOptions::default();
        let preconditioner = SpaiPreconditioner::new(&matrix, options)
            .expect("SPAI construction should succeed for a diagonal matrix");

        let e1 = vec![1.0, 0.0, 0.0];
        let e2 = vec![0.0, 1.0, 0.0];
        let e3 = vec![0.0, 0.0, 1.0];
        let x1 = preconditioner.matvec(&e1).expect("matvec on e1 should succeed");
        let x2 = preconditioner.matvec(&e2).expect("matvec on e2 should succeed");
        let x3 = preconditioner.matvec(&e3).expect("matvec on e3 should succeed");

        assert!((x1[0] - 0.5).abs() < 1e-10);
        assert!((x2[1] - 1.0 / 3.0).abs() < 1e-10);
        assert!((x3[2] - 0.25).abs() < 1e-10);

        // Off-diagonal entries of the result must vanish.
        assert!(x1[1].abs() < 1e-10 && x1[2].abs() < 1e-10);
        assert!(x2[0].abs() < 1e-10 && x2[2].abs() < 1e-10);
        assert!(x3[0].abs() < 1e-10 && x3[1].abs() < 1e-10);
    }

    /// Construction must fail for a matrix that is not square.
    #[test]
    fn test_spai_rejects_non_square() {
        let data = vec![1.0, 1.0];
        let indptr = vec![0, 1, 2];
        let indices = vec![0, 1];
        let matrix = CsrMatrix::from_raw_csr(data, indptr, indices, (2, 3))
            .expect("rectangular CSR input should be valid");
        assert!(SpaiPreconditioner::new(&matrix, SpaiOptions::default()).is_err());
    }

    /// Applying the preconditioner to a vector of the wrong length is an
    /// error in both directions.
    #[test]
    fn test_spai_matvec_dimension_mismatch() {
        let data = vec![2.0, 3.0, 4.0];
        let indptr = vec![0, 1, 2, 3];
        let indices = vec![0, 1, 2];
        let matrix = CsrMatrix::from_raw_csr(data, indptr, indices, (3, 3))
            .expect("diagonal CSR input should be valid");
        let preconditioner = SpaiPreconditioner::new(&matrix, SpaiOptions::default())
            .expect("SPAI construction should succeed for a diagonal matrix");

        let too_short = vec![1.0, 2.0];
        assert!(preconditioner.matvec(&too_short).is_err());
        assert!(preconditioner.rmatvec(&too_short).is_err());
    }
}