use crate::csr::CsrMatrix;
use crate::error::{SparseError, SparseResult};
use crate::iterative_solvers::Preconditioner;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::fmt::Debug;
use std::iter::Sum;
use std::ops::AddAssign;
/// Block saddle-point system
///
/// ```text
/// K = [ A   Bᵀ ]
///     [ B  -C  ]
/// ```
///
/// acting on vectors `[u; p]` with `u` of length `n` and `p` of length `m`
/// (see [`SaddlePointSystem::apply`]). [`SaddlePointSystem::new`] validates the
/// block shapes; because the fields are public, those shapes are only
/// guaranteed as long as callers do not replace the blocks afterwards.
#[derive(Clone, Debug)]
pub struct SaddlePointSystem {
    /// Top-left `n × n` block `A`.
    pub a: CsrMatrix<f64>,
    /// `m × n` coupling block `B`; its transpose fills the top-right block.
    pub b: CsrMatrix<f64>,
    /// `m × m` bottom-right block, applied with a negative sign (`-C`).
    pub c: CsrMatrix<f64>,
    /// Dimension of the first unknown block (`a.rows()` at construction).
    pub n: usize,
    /// Dimension of the second unknown block (`b.rows()` at construction).
    pub m: usize,
}
impl SaddlePointSystem {
    /// Builds a saddle-point system from its three blocks, validating their shapes.
    ///
    /// The dimensions are taken as `n = a.rows()` and `m = b.rows()`; the
    /// assembled operator is `K = [A Bᵀ; B -C]`.
    ///
    /// # Errors
    /// * `ValueError` if `a` is not square.
    /// * `ShapeMismatch` if `b` is not `m × n` or `c` is not `m × m`.
    pub fn new(
        a: CsrMatrix<f64>,
        b: CsrMatrix<f64>,
        c: CsrMatrix<f64>,
    ) -> SparseResult<Self> {
        let n = a.rows();
        let m = b.rows();
        if a.cols() != n {
            return Err(SparseError::ValueError(
                "Block A must be square (n×n)".to_string(),
            ));
        }
        if b.cols() != n {
            return Err(SparseError::ShapeMismatch {
                expected: (m, n),
                found: (b.rows(), b.cols()),
            });
        }
        if c.rows() != m || c.cols() != m {
            return Err(SparseError::ShapeMismatch {
                expected: (m, m),
                found: (c.rows(), c.cols()),
            });
        }
        Ok(SaddlePointSystem { a, b, c, n, m })
    }

    /// Total number of unknowns, `n + m`.
    pub fn total_dim(&self) -> usize {
        self.n + self.m
    }

    /// Computes `K x` for `x = [u; p]`, i.e. `[A u + Bᵀ p; B u - C p]`.
    ///
    /// `Bᵀ p` is accumulated by scattering over the rows of `B`, so no transposed
    /// matrix is materialised. This method is called once per Krylov iteration,
    /// which makes avoiding the transpose worthwhile.
    ///
    /// # Errors
    /// * `DimensionMismatch` if `x.len() != n + m`.
    /// * `ValueError` if the (public) blocks no longer have the shapes implied
    ///   by `n` and `m`.
    pub fn apply(&self, x: &Array1<f64>) -> SparseResult<Array1<f64>> {
        let total = self.n + self.m;
        if x.len() != total {
            return Err(SparseError::DimensionMismatch {
                expected: total,
                found: x.len(),
            });
        }
        // The blocks are public fields, so re-check them before the scatter below
        // indexes the output with B's column indices.
        if self.a.shape() != (self.n, self.n)
            || self.b.shape() != (self.m, self.n)
            || self.c.shape() != (self.m, self.m)
        {
            return Err(SparseError::ValueError(
                "SaddlePointSystem blocks are inconsistent with (n, m)".to_string(),
            ));
        }
        let u = x.slice(scirs2_core::ndarray::s![..self.n]).to_owned();
        let p = x.slice(scirs2_core::ndarray::s![self.n..]).to_owned();

        // Top block: A u + Bᵀ p. Entry B[i, j] contributes B[i, j] * p[i] to row j.
        let mut top = csr_matvec(&self.a, &u)?;
        for (row, &p_row) in p.iter().enumerate() {
            for pos in self.b.row_range(row) {
                top[self.b.indices[pos]] += self.b.data[pos] * p_row;
            }
        }

        // Bottom block: B u - C p.
        let mut bottom = csr_matvec(&self.b, &u)?;
        let cp = csr_matvec(&self.c, &p)?;
        bottom
            .iter_mut()
            .zip(cp.iter())
            .for_each(|(bi, &ci)| *bi -= ci);

        let mut result = Array1::<f64>::zeros(total);
        result
            .slice_mut(scirs2_core::ndarray::s![..self.n])
            .assign(&top);
        result
            .slice_mut(scirs2_core::ndarray::s![self.n..])
            .assign(&bottom);
        Ok(result)
    }
}
/// Sparse matrix–vector product `y = A x` for a CSR matrix.
///
/// Each output entry is the sum of the stored entries of its row times the
/// matching entries of `x`, accumulated in storage order.
///
/// # Errors
/// `DimensionMismatch` when `x.len()` differs from the column count of `a`.
fn csr_matvec(a: &CsrMatrix<f64>, x: &Array1<f64>) -> SparseResult<Array1<f64>> {
    let (num_rows, num_cols) = a.shape();
    if x.len() != num_cols {
        return Err(SparseError::DimensionMismatch {
            expected: num_cols,
            found: x.len(),
        });
    }
    let product: Array1<f64> = (0..num_rows)
        .map(|row| {
            let mut row_sum = 0.0_f64;
            for pos in a.row_range(row) {
                row_sum += a.data[pos] * x[a.indices[pos]];
            }
            row_sum
        })
        .collect();
    Ok(product)
}
/// Block-diagonal Jacobi preconditioner built from the diagonal of `A` and the
/// diagonal of a caller-supplied Schur-complement approximation `S̃`.
///
/// Applying it scales the first `n` residual entries by `1 / a_ii` and the last
/// `m` entries by `1 / s̃_ii`.
pub struct SchurComplementPrecond {
    /// Reciprocals of the diagonal entries of `A` (length `n`).
    a_diag_inv: Array1<f64>,
    /// Reciprocals of the diagonal entries of `S̃` (length `m`).
    s_approx_diag_inv: Array1<f64>,
    /// Size of the first block.
    n: usize,
    /// Size of the second block.
    m: usize,
}
impl SchurComplementPrecond {
pub fn new(
a: &CsrMatrix<f64>,
_b: &CsrMatrix<f64>,
s_approx: &CsrMatrix<f64>,
) -> SparseResult<Self> {
let n = a.rows();
let m = s_approx.rows();
if a.cols() != n {
return Err(SparseError::ValueError(
"Block A must be square".to_string(),
));
}
if s_approx.cols() != m {
return Err(SparseError::ValueError(
"S_approx must be square".to_string(),
));
}
let mut a_diag_inv = Array1::zeros(n);
for i in 0..n {
let d = a.get(i, i);
if d.abs() < 1e-14 {
return Err(SparseError::SingularMatrix(format!(
"Near-zero diagonal in A at index {i}"
)));
}
a_diag_inv[i] = 1.0 / d;
}
let mut s_diag_inv = Array1::zeros(m);
for i in 0..m {
let d = s_approx.get(i, i);
if d.abs() < 1e-14 {
return Err(SparseError::SingularMatrix(format!(
"Near-zero diagonal in S_approx at index {i}"
)));
}
s_diag_inv[i] = 1.0 / d;
}
Ok(Self {
a_diag_inv,
s_approx_diag_inv: s_diag_inv,
n,
m,
})
}
}
impl Preconditioner<f64> for SchurComplementPrecond {
    /// Returns `z = [diag(A)⁻¹ r_u; diag(S̃)⁻¹ r_p]` for `r = [r_u; r_p]`.
    ///
    /// # Errors
    /// `DimensionMismatch` if `r.len() != n + m`.
    fn apply(&self, r: &Array1<f64>) -> SparseResult<Array1<f64>> {
        let expected = self.n + self.m;
        if r.len() != expected {
            return Err(SparseError::DimensionMismatch {
                expected,
                found: r.len(),
            });
        }
        // The two reciprocal diagonals laid end to end line up with r entry by entry.
        let scales = self
            .a_diag_inv
            .iter()
            .chain(self.s_approx_diag_inv.iter());
        Ok(r.iter().zip(scales).map(|(&ri, &si)| ri * si).collect())
    }
}
/// Convenience constructor equivalent to [`SchurComplementPrecond::new`].
///
/// # Errors
/// Propagates every error from [`SchurComplementPrecond::new`].
pub fn schur_complement_precond(
    a: &CsrMatrix<f64>,
    b: &CsrMatrix<f64>,
    s_approx: &CsrMatrix<f64>,
) -> SparseResult<SchurComplementPrecond> {
    let precond = SchurComplementPrecond::new(a, b, s_approx)?;
    Ok(precond)
}
/// Block-diagonal Jacobi preconditioner built from the diagonals of `A` and `C`.
///
/// Applying it scales the first `n` residual entries by `1 / a_ii` and the last
/// `m` entries by the stored `C`-block factors.
pub struct BlockDiagonalPrecond {
    /// Reciprocals of the diagonal entries of `A` (length `n`).
    a_diag_inv: Array1<f64>,
    /// Reciprocals of the diagonal entries of `C` (length `m`). Entries whose
    /// `C` diagonal is near zero hold `1.0` instead (see `new`).
    c_diag_inv: Array1<f64>,
    /// Size of the first block.
    n: usize,
    /// Size of the second block.
    m: usize,
}
impl BlockDiagonalPrecond {
pub fn new(a: &CsrMatrix<f64>, c: &CsrMatrix<f64>) -> SparseResult<Self> {
let n = a.rows();
let m = c.rows();
if a.cols() != n {
return Err(SparseError::ValueError("A must be square".to_string()));
}
if c.cols() != m {
return Err(SparseError::ValueError("C must be square".to_string()));
}
let mut a_diag_inv = Array1::zeros(n);
for i in 0..n {
let d = a.get(i, i);
if d.abs() < 1e-14 {
return Err(SparseError::SingularMatrix(format!(
"Near-zero diagonal in A at {i}"
)));
}
a_diag_inv[i] = 1.0 / d;
}
let mut c_diag_inv = Array1::ones(m);
for i in 0..m {
let d = c.get(i, i);
if d.abs() > 1e-14 {
c_diag_inv[i] = 1.0 / d;
}
}
Ok(Self {
a_diag_inv,
c_diag_inv,
n,
m,
})
}
}
impl Preconditioner<f64> for BlockDiagonalPrecond {
    /// Returns `z` with `z_i = r_i / a_ii` for the first `n` entries and
    /// `z_{n+j} = r_{n+j} · c_diag_inv_j` for the remaining `m`.
    ///
    /// # Errors
    /// `DimensionMismatch` if `r.len() != n + m`.
    fn apply(&self, r: &Array1<f64>) -> SparseResult<Array1<f64>> {
        let total = self.n + self.m;
        if r.len() != total {
            return Err(SparseError::DimensionMismatch {
                expected: total,
                found: r.len(),
            });
        }
        let split = self.n;
        Ok(Array1::from_shape_fn(total, |i| {
            if i < split {
                r[i] * self.a_diag_inv[i]
            } else {
                r[i] * self.c_diag_inv[i - split]
            }
        }))
    }
}
/// Convenience constructor equivalent to [`BlockDiagonalPrecond::new`].
///
/// # Errors
/// Propagates every error from [`BlockDiagonalPrecond::new`].
pub fn block_diagonal_precond(
    a: &CsrMatrix<f64>,
    c: &CsrMatrix<f64>,
) -> SparseResult<BlockDiagonalPrecond> {
    let precond = BlockDiagonalPrecond::new(a, c)?;
    Ok(precond)
}
/// Block upper-triangular preconditioner using diagonal approximations of `A`
/// and of a Schur-complement approximation `S`.
///
/// Applying it solves `[D_A Bᵀ; 0 D_S] z = r` with `D_A = diag(A)` and
/// `D_S = diag(S)`. This operator is not symmetric: the first block of `z`
/// depends on the second, but not vice versa. It is therefore generally not a
/// valid MINRES preconditioner, since MINRES assumes a symmetric positive-definite one.
pub struct BlockTriangularPrecond {
    /// Reciprocals of the diagonal entries of `A` (length `n`).
    a_diag_inv: Array1<f64>,
    /// Reciprocals of the diagonal entries of `S` (length `m`).
    s_diag_inv: Array1<f64>,
    /// Copy of the `m × n` coupling block, needed for `Bᵀ p`.
    b: CsrMatrix<f64>,
    /// Size of the first block.
    n: usize,
    /// Size of the second block.
    m: usize,
}
impl BlockTriangularPrecond {
pub fn new(
a: &CsrMatrix<f64>,
b: &CsrMatrix<f64>,
s: &CsrMatrix<f64>,
) -> SparseResult<Self> {
let n = a.rows();
let m = b.rows();
if a.cols() != n {
return Err(SparseError::ValueError("A must be square".to_string()));
}
if b.cols() != n {
return Err(SparseError::ShapeMismatch {
expected: (m, n),
found: (b.rows(), b.cols()),
});
}
if s.rows() != m || s.cols() != m {
return Err(SparseError::ShapeMismatch {
expected: (m, m),
found: (s.rows(), s.cols()),
});
}
let mut a_diag_inv = Array1::zeros(n);
for i in 0..n {
let d = a.get(i, i);
if d.abs() < 1e-14 {
return Err(SparseError::SingularMatrix(format!(
"Near-zero A diagonal at {i}"
)));
}
a_diag_inv[i] = 1.0 / d;
}
let mut s_diag_inv = Array1::zeros(m);
for i in 0..m {
let d = s.get(i, i);
if d.abs() < 1e-14 {
return Err(SparseError::SingularMatrix(format!(
"Near-zero S diagonal at {i}"
)));
}
s_diag_inv[i] = 1.0 / d;
}
Ok(Self {
a_diag_inv,
s_diag_inv,
b: b.clone(),
n,
m,
})
}
}
impl Preconditioner<f64> for BlockTriangularPrecond {
fn apply(&self, r: &Array1<f64>) -> SparseResult<Array1<f64>> {
let total = self.n + self.m;
if r.len() != total {
return Err(SparseError::DimensionMismatch {
expected: total,
found: r.len(),
});
}
let mut p = Array1::zeros(self.m);
for i in 0..self.m {
p[i] = r[self.n + i] * self.s_diag_inv[i];
}
let bt = self.b.transpose();
let bt_p = csr_matvec(&bt, &p)?;
let mut u = Array1::zeros(self.n);
for i in 0..self.n {
u[i] = (r[i] - bt_p[i]) * self.a_diag_inv[i];
}
let mut z = Array1::zeros(total);
for i in 0..self.n {
z[i] = u[i];
}
for i in 0..self.m {
z[self.n + i] = p[i];
}
Ok(z)
}
}
/// Convenience constructor equivalent to [`BlockTriangularPrecond::new`].
///
/// # Errors
/// Propagates every error from [`BlockTriangularPrecond::new`].
pub fn block_triangular_precond(
    a: &CsrMatrix<f64>,
    b: &CsrMatrix<f64>,
    s: &CsrMatrix<f64>,
) -> SparseResult<BlockTriangularPrecond> {
    let precond = BlockTriangularPrecond::new(a, b, s)?;
    Ok(precond)
}
/// Settings for [`minres_saddle`].
#[derive(Clone, Debug)]
pub struct MinresConfig {
    /// Maximum number of MINRES iterations.
    pub max_iter: usize,
    /// Relative stopping tolerance, scaled by the solver's initial residual measure.
    pub tol: f64,
    /// Flag intended to request per-iteration progress diagnostics.
    pub verbose: bool,
}
impl Default for MinresConfig {
fn default() -> Self {
Self {
max_iter: 1000,
tol: 1e-10,
verbose: false,
}
}
}
/// Outcome of [`minres_saddle`].
#[derive(Clone, Debug)]
pub struct MinresResult {
    /// Approximate solution (starting from the zero vector).
    pub solution: Array1<f64>,
    /// Number of iterations performed (`0` when the solver returns before iterating).
    pub n_iter: usize,
    /// Residual-norm estimate tracked by the solver at termination; it is not
    /// recomputed from `solution`.
    pub residual_norm: f64,
    /// Whether `residual_norm` met the solver's tolerance.
    pub converged: bool,
}
/// Solves `K x = rhs` for a saddle-point system with (optionally preconditioned) MINRES.
///
/// This uses the Paige–Saunders recurrence. A preconditioned Lanczos process
/// (three-term recurrence) builds a basis of the Krylov space. Givens rotations
/// keep a running QR factorisation of the Lanczos tridiagonal matrix, so the
/// iterate is updated with short recurrences. The initial guess is `x = 0`.
///
/// # Arguments
/// * `system` – the saddle-point operator `K`; only products `K v` are used.
/// * `rhs` – right-hand side of length `system.total_dim()`.
/// * `precond` – optional preconditioner applying `M⁻¹`. MINRES requires it to
///   be symmetric positive definite.
/// * `config` – iteration limit, relative tolerance and verbosity. When
///   `verbose` is set, one progress line per iteration is written to stderr.
///
/// # Convergence
/// The iteration stops when the residual estimate `φ̄_k` satisfies
/// `φ̄_k ≤ tol · β₁`, where `β₁ = sqrt(rhsᵀ M⁻¹ rhs)`. Without a preconditioner,
/// `β₁ = ‖rhs‖₂` and `φ̄_k = ‖rhs − K x_k‖₂` in exact arithmetic. With one,
/// both quantities are measured in the `M⁻¹`-norm. The estimate is reported as
/// `MinresResult::residual_norm`.
///
/// # Errors
/// * `DimensionMismatch` if `rhs` has the wrong length.
/// * `ValueError` if the preconditioner yields `rᵀ M⁻¹ r < 0` (not positive definite).
/// * Any error returned by `system.apply` or the preconditioner.
pub fn minres_saddle(
    system: &SaddlePointSystem,
    rhs: &Array1<f64>,
    precond: Option<&dyn Preconditioner<f64>>,
    config: &MinresConfig,
) -> SparseResult<MinresResult> {
    // Magnitude below which norms / rotation denominators are treated as zero.
    const TINY: f64 = 1e-30;

    let n = system.total_dim();
    if rhs.len() != n {
        return Err(SparseError::DimensionMismatch {
            expected: n,
            found: rhs.len(),
        });
    }
    let mut x = Array1::<f64>::zeros(n);
    // x = 0 already solves a vanishing right-hand side.
    if norm2(rhs) < TINY {
        return Ok(MinresResult {
            solution: x,
            n_iter: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }

    // z = M⁻¹ r, or the identity when no preconditioner is supplied.
    let apply_precond = |r: &Array1<f64>| -> SparseResult<Array1<f64>> {
        match precond {
            Some(pc) => pc.apply(r),
            None => Ok(r.clone()),
        }
    };
    // sqrt(rᵀ z) with z = M⁻¹ r. A negative value means M⁻¹ is indefinite;
    // report that instead of letting a NaN propagate through the recurrence.
    let m_norm = |r: &Array1<f64>, z: &Array1<f64>| -> SparseResult<f64> {
        let squared = dot(r, z);
        if squared < 0.0 {
            Err(SparseError::ValueError(
                "MINRES preconditioner is not positive definite".to_string(),
            ))
        } else {
            Ok(squared.sqrt())
        }
    };

    // Lanczos state. x0 = 0, so the initial residual is rhs. `r1` and `r2` are
    // the two most recent unnormalised Lanczos vectors, and `y = M⁻¹ r2`.
    let mut r1 = rhs.clone();
    let mut r2 = rhs.clone();
    let mut y = apply_precond(&r2)?;
    let beta1 = m_norm(&r2, &y)?;
    if beta1 < TINY {
        return Ok(MinresResult {
            solution: x,
            n_iter: 0,
            residual_norm: 0.0,
            converged: true,
        });
    }
    let tolerance = config.tol * beta1;
    let mut beta_old = 0.0_f64;
    let mut beta = beta1;

    // Givens / QR state for the tridiagonal Lanczos matrix.
    let mut cs = -1.0_f64;
    let mut sn = 0.0_f64;
    let mut dbar = 0.0_f64;
    let mut epsln = 0.0_f64;
    let mut phibar = beta1;
    let mut rnorm = beta1;

    // Search directions w_{k-1} (`w`) and w_{k-2} (`w_prev`).
    let mut w = Array1::<f64>::zeros(n);
    let mut w_prev = Array1::<f64>::zeros(n);

    for iter in 0..config.max_iter {
        // Lanczos step: v_k = y / β_k, then the three-term recurrence
        // r_{k+1} = K v_k − (β_k/β_{k-1}) r_{k-1} − (α_k/β_k) r_k.
        let v = y.mapv(|e| e / beta);
        let mut y_next = system.apply(&v)?;
        if iter > 0 {
            axpy_mut(&mut y_next, -beta / beta_old, &r1);
        }
        let alpha = dot(&v, &y_next);
        axpy_mut(&mut y_next, -alpha / beta, &r2);
        r1 = std::mem::replace(&mut r2, y_next);
        y = apply_precond(&r2)?;
        beta_old = beta;
        beta = m_norm(&r2, &y)?;

        // Apply the previous rotation to the new tridiagonal column, then build
        // the rotation that annihilates β_{k+1}.
        let eps_old = epsln;
        let delta = cs * dbar + sn * alpha;
        let gbar = sn * dbar - cs * alpha;
        epsln = sn * beta;
        dbar = -cs * beta;
        let gamma = gbar.hypot(beta);
        if gamma < TINY {
            // The projected system is singular; stop with the current iterate.
            return Ok(MinresResult {
                solution: x,
                n_iter: iter + 1,
                residual_norm: rnorm,
                converged: rnorm <= tolerance,
            });
        }
        cs = gbar / gamma;
        sn = beta / gamma;
        let phi = cs * phibar;
        phibar *= sn;

        // w_k = (v_k − ε_k w_{k−2} − δ_k w_{k−1}) / γ_k, then x_k = x_{k−1} + φ_k w_k.
        let mut w_next = v;
        axpy_mut(&mut w_next, -eps_old, &w_prev);
        axpy_mut(&mut w_next, -delta, &w);
        w_next.mapv_inplace(|e| e / gamma);
        axpy_mut(&mut x, phi, &w_next);
        w_prev = std::mem::replace(&mut w, w_next);

        rnorm = phibar.abs();
        if config.verbose {
            eprintln!(
                "minres_saddle: iter {:>5}  residual estimate {:.3e}",
                iter + 1,
                rnorm
            );
        }
        // β_{k+1} ≈ 0 means the Krylov space is invariant: x is exact up to
        // rounding, and continuing would divide by β.
        if rnorm <= tolerance || beta < TINY {
            return Ok(MinresResult {
                solution: x,
                n_iter: iter + 1,
                residual_norm: rnorm,
                converged: rnorm <= tolerance,
            });
        }
    }
    Ok(MinresResult {
        solution: x,
        n_iter: config.max_iter,
        residual_norm: rnorm,
        converged: rnorm <= tolerance,
    })
}
/// A 2-D mesh vertex, consumed by [`assemble_stokes`].
#[derive(Clone, Debug, Copy)]
pub struct MeshNode {
    /// x-coordinate.
    pub x: f64,
    /// y-coordinate.
    pub y: f64,
}
/// A triangular element given by three indices into the node slice passed to
/// [`assemble_stokes`]. Either vertex orientation is accepted.
#[derive(Clone, Debug, Copy)]
pub struct MeshElement {
    /// Index of the first vertex.
    pub n0: usize,
    /// Index of the second vertex.
    pub n1: usize,
    /// Index of the third vertex.
    pub n2: usize,
}
/// Assembles a P1–P0 Stokes-type saddle-point system on a triangular mesh.
///
/// Degrees of freedom:
/// * velocity: two per node. The x-component of node `i` sits at index `i` and
///   the y-component at `num_nodes + i`, so `n = 2 · num_nodes`.
/// * pressure: one per element, so `m = num_elements`.
///
/// Blocks:
/// * `A`: for each velocity component, the linear-element stiffness entries
///   `|T| ∇φ_i · ∇φ_j`, summed over elements.
/// * `B`: `B[e, i] = |T_e| ∂φ_i/∂x` for x-dofs and `|T_e| ∂φ_i/∂y` for y-dofs.
///   For piecewise-linear `u`, this gives `(B u)_e = ∫_{T_e} div u`.
/// * `C`: diagonal stabilisation `1e-6 · |T_e|`.
///
/// Basis gradients use the signed element area and integrals use its magnitude,
/// so element orientation does not matter. No boundary conditions are applied;
/// constant velocity fields therefore lie in the null space of `A`.
///
/// NOTE(review): elements sharing a node emit repeated `(row, col)` triplets, so
/// this relies on `CsrMatrix::from_triplets` summing duplicates. Confirm.
///
/// # Errors
/// * `ValueError` if there are no nodes, no elements, an element references a
///   node index `>= num_nodes`, or an element's area magnitude is below `1e-15`.
/// * Errors propagated from `CsrMatrix::from_triplets` or `SaddlePointSystem::new`.
pub fn assemble_stokes(
    mesh_nodes: &[MeshNode],
    mesh_elements: &[MeshElement],
) -> SparseResult<SaddlePointSystem> {
    let num_nodes = mesh_nodes.len();
    let num_elements = mesh_elements.len();
    if num_nodes == 0 {
        return Err(SparseError::ValueError("No mesh nodes provided".to_string()));
    }
    if num_elements == 0 {
        return Err(SparseError::ValueError(
            "No mesh elements provided".to_string(),
        ));
    }
    let n = 2 * num_nodes;
    let m = num_elements;

    // Exact triplet counts per element: a 3×3 local stiffness block for each of
    // the two velocity components (18) and one divergence entry per vertex and
    // component (6). Reserving up front avoids repeated reallocation.
    let a_nnz = 18 * num_elements;
    let b_nnz = 6 * num_elements;
    let mut a_rows = Vec::with_capacity(a_nnz);
    let mut a_cols = Vec::with_capacity(a_nnz);
    let mut a_vals = Vec::with_capacity(a_nnz);
    let mut b_rows = Vec::with_capacity(b_nnz);
    let mut b_cols = Vec::with_capacity(b_nnz);
    let mut b_vals = Vec::with_capacity(b_nnz);

    let stab_coeff = 1e-6_f64;
    let mut c_diag = vec![0.0_f64; m];

    for (elem_idx, elem) in mesh_elements.iter().enumerate() {
        let idx = [elem.n0, elem.n1, elem.n2];
        if let Some(&ni) = idx.iter().find(|&&ni| ni >= num_nodes) {
            return Err(SparseError::ValueError(format!(
                "Element node index {ni} out of bounds (num_nodes={num_nodes})"
            )));
        }
        let (x0, y0) = (mesh_nodes[idx[0]].x, mesh_nodes[idx[0]].y);
        let (x1, y1) = (mesh_nodes[idx[1]].x, mesh_nodes[idx[1]].y);
        let (x2, y2) = (mesh_nodes[idx[2]].x, mesh_nodes[idx[2]].y);

        // Signed area: positive for counter-clockwise vertex order.
        let area = 0.5 * ((x1 - x0) * (y2 - y0) - (x2 - x0) * (y1 - y0));
        if area.abs() < 1e-15 {
            return Err(SparseError::ValueError(format!(
                "Degenerate element {elem_idx}: area is near-zero"
            )));
        }
        let area_abs = area.abs();

        // Constant gradients of the three linear basis functions on this element.
        let dxphi = [
            (y1 - y2) / (2.0 * area),
            (y2 - y0) / (2.0 * area),
            (y0 - y1) / (2.0 * area),
        ];
        let dyphi = [
            (x2 - x1) / (2.0 * area),
            (x0 - x2) / (2.0 * area),
            (x1 - x0) / (2.0 * area),
        ];

        // Stiffness contributions, duplicated for the x- and y-velocity blocks.
        for i_local in 0..3usize {
            for j_local in 0..3usize {
                let k_ij = area_abs
                    * (dxphi[i_local] * dxphi[j_local] + dyphi[i_local] * dyphi[j_local]);
                let node_i = idx[i_local];
                let node_j = idx[j_local];
                a_rows.push(node_i);
                a_cols.push(node_j);
                a_vals.push(k_ij);
                a_rows.push(num_nodes + node_i);
                a_cols.push(num_nodes + node_j);
                a_vals.push(k_ij);
            }
        }

        // Divergence row of this element's pressure unknown.
        for i_local in 0..3usize {
            let node_i = idx[i_local];
            b_rows.push(elem_idx);
            b_cols.push(node_i);
            b_vals.push(area_abs * dxphi[i_local]);
            b_rows.push(elem_idx);
            b_cols.push(num_nodes + node_i);
            b_vals.push(area_abs * dyphi[i_local]);
        }

        c_diag[elem_idx] += stab_coeff * area_abs;
    }

    let a = CsrMatrix::from_triplets(n, n, a_rows, a_cols, a_vals)?;
    let b = CsrMatrix::from_triplets(m, n, b_rows, b_cols, b_vals)?;
    let c_rows: Vec<usize> = (0..m).collect();
    let c_cols: Vec<usize> = (0..m).collect();
    let c = CsrMatrix::from_triplets(m, m, c_rows, c_cols, c_diag)?;
    SaddlePointSystem::new(a, b, c)
}
/// Euclidean inner product `aᵀ b`, summed in index order.
///
/// Both vectors must have the same length. This is checked in debug builds,
/// because `zip` would otherwise silently ignore the tail of the longer vector
/// and hide a caller bug.
#[inline]
fn dot(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    debug_assert_eq!(a.len(), b.len(), "dot: vector length mismatch");
    a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
}
#[inline]
fn norm2(v: &Array1<f64>) -> f64 {
dot(v, v).sqrt()
}
/// In-place update `y ← y + alpha · x`.
///
/// Both vectors must have the same length. This is checked in debug builds,
/// because `zip` would otherwise silently leave the tail of `y` untouched.
#[inline]
fn axpy_mut(y: &mut Array1<f64>, alpha: f64, x: &Array1<f64>) {
    debug_assert_eq!(y.len(), x.len(), "axpy_mut: vector length mismatch");
    for (yi, &xi) in y.iter_mut().zip(x.iter()) {
        *yi += alpha * xi;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the 3×3 system with A = 2·I₂, B = [1 1] and C = [0], i.e.
    /// K = [[2, 0, 1], [0, 2, 1], [1, 1, 0]].
    fn make_simple_saddle() -> SaddlePointSystem {
        let a = CsrMatrix::try_from_triplets(2, 2, &[(0, 0, 2.0), (1, 1, 2.0)]).expect("valid test setup");
        let b = CsrMatrix::try_from_triplets(1, 2, &[(0, 0, 1.0), (0, 1, 1.0)]).expect("valid test setup");
        let c = CsrMatrix::try_from_triplets(1, 1, &[(0, 0, 0.0)]).expect("valid test setup");
        SaddlePointSystem::new(a, b, c).expect("valid test setup")
    }

    /// K·[1, 2, 3] = [2·1 + 3, 2·2 + 3, 1 + 2 − 0·3] = [5, 7, 3].
    #[test]
    fn test_saddle_point_system_apply() {
        let sys = make_simple_saddle();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = sys.apply(&x).expect("valid test setup");
        assert_relative_eq!(result[0], 5.0, epsilon = 1e-12);
        assert_relative_eq!(result[1], 7.0, epsilon = 1e-12);
        assert_relative_eq!(result[2], 3.0, epsilon = 1e-12);
    }

    /// Scaling [4, 2, 1] by the reciprocal diagonals [1/4, 1/2, 1/1] gives all ones.
    #[test]
    fn test_block_diagonal_precond() {
        let a = CsrMatrix::try_from_triplets(2, 2, &[(0, 0, 4.0), (1, 1, 2.0)]).expect("valid test setup");
        let c = CsrMatrix::try_from_triplets(1, 1, &[(0, 0, 1.0)]).expect("valid test setup");
        let pc = block_diagonal_precond(&a, &c).expect("valid test setup");
        let r = Array1::from_vec(vec![4.0, 2.0, 1.0]);
        let z = pc.apply(&r).expect("valid test setup");
        assert_relative_eq!(z[0], 1.0, epsilon = 1e-12);
        assert_relative_eq!(z[1], 1.0, epsilon = 1e-12);
        assert_relative_eq!(z[2], 1.0, epsilon = 1e-12);
    }

    /// p = r_p / s = 1, then u_i = (r_i − (Bᵀ p)_i) / a_ii = (1 − 1) / 2 = 0.
    #[test]
    fn test_block_triangular_precond() {
        let a = CsrMatrix::try_from_triplets(2, 2, &[(0, 0, 2.0), (1, 1, 2.0)]).expect("valid test setup");
        let b = CsrMatrix::try_from_triplets(1, 2, &[(0, 0, 1.0), (0, 1, 1.0)]).expect("valid test setup");
        let s = CsrMatrix::try_from_triplets(1, 1, &[(0, 0, 1.0)]).expect("valid test setup");
        let pc = block_triangular_precond(&a, &b, &s).expect("valid test setup");
        let r = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let z = pc.apply(&r).expect("valid test setup");
        assert!(z.len() == 3);
        assert_relative_eq!(z[2], 1.0, epsilon = 1e-12);
        assert_relative_eq!(z[0], 0.0, epsilon = 1e-12);
    }

    /// rhs = K·[1, 1, 0]. Only the convergence flag / residual estimate is
    /// checked, not the returned solution.
    #[test]
    fn test_minres_saddle_trivial() {
        let sys = make_simple_saddle();
        let rhs = Array1::from_vec(vec![2.0, 2.0, 2.0]);
        let config = MinresConfig {
            max_iter: 500,
            tol: 1e-8,
            verbose: false,
        };
        let result = minres_saddle(&sys, &rhs, None, &config).expect("valid test setup");
        assert!(
            result.converged || result.residual_norm < 1e-6,
            "MINRES did not converge: residual={}",
            result.residual_norm
        );
    }

    /// Unit square split into two triangles: 4 nodes give n = 8 velocity dofs,
    /// and 2 elements give m = 2 pressure unknowns.
    #[test]
    fn test_assemble_stokes_small() {
        let nodes = vec![
            MeshNode { x: 0.0, y: 0.0 },
            MeshNode { x: 1.0, y: 0.0 },
            MeshNode { x: 1.0, y: 1.0 },
            MeshNode { x: 0.0, y: 1.0 },
        ];
        let elements = vec![
            MeshElement { n0: 0, n1: 1, n2: 2 },
            MeshElement { n0: 0, n1: 2, n2: 3 },
        ];
        let sys = assemble_stokes(&nodes, &elements).expect("valid test setup");
        assert_eq!(sys.n, 8);
        assert_eq!(sys.m, 2);
        assert!(sys.a.nnz() > 0);
        assert!(sys.b.nnz() > 0);
    }
}