use crate::error::{SparseError, SparseResult};
/// Factorises the leading `n x n` block of `a` in place as `P * A = L * U` using
/// partial (row) pivoting.
///
/// On success the strictly-lower part of `a` holds the multipliers of `L`, whose unit
/// diagonal is implicit. The upper part, diagonal included, holds `U`. The returned
/// `perm` records the row permutation: row `i` of the factored matrix came from row
/// `perm[i]` of the input. Pass `a` and `perm` to `dense_lu_solve` to solve systems.
///
/// Returns `None` when the largest available pivot in some column has magnitude below
/// `1e-300`, i.e. the matrix is treated as singular. In that case `a` is left partially
/// factored.
///
/// # Panics
/// Panics if `a` has fewer than `n` rows or one of its first `n` rows has fewer than
/// `n` columns.
fn dense_lu_factor(a: &mut [Vec<f64>], n: usize) -> Option<Vec<usize>> {
    let mut perm: Vec<usize> = (0..n).collect();
    for k in 0..n {
        // Partial pivoting: choose the row at or below `k` with the largest |a[i][k]|.
        // The strict `>` keeps the earliest row on ties.
        let mut max_row = k;
        let mut max_val = a[k][k].abs();
        for (i, row) in a.iter().enumerate().take(n).skip(k + 1) {
            let cand = row[k].abs();
            if cand > max_val {
                max_val = cand;
                max_row = i;
            }
        }
        if max_val < 1e-300 {
            return None;
        }
        a.swap(k, max_row);
        perm.swap(k, max_row);

        // Split the rows so the pivot row can be read while the rows below it are updated.
        let (head, tail) = a.split_at_mut(k + 1);
        let pivot_row = &head[k];
        let piv = pivot_row[k];
        for row in tail.iter_mut().take(n - k - 1) {
            row[k] /= piv;
            let l = row[k];
            for (x, &u) in row[k + 1..n].iter_mut().zip(&pivot_row[k + 1..n]) {
                *x -= l * u;
            }
        }
    }
    Some(perm)
}
/// Solves `A x = b` using the packed LU factors produced by `dense_lu_factor`.
///
/// `lu` holds the unit-lower factor `L` below the diagonal and `U` on and above it.
/// `perm[i]` is the input row that ended up in row `i`. The result has `perm.len()`
/// entries, of which the first `n` are solved for.
fn dense_lu_solve(lu: &[Vec<f64>], perm: &[usize], b: &[f64], n: usize) -> Vec<f64> {
    // Gather `b` into pivoted order, so we solve L U x = P b.
    let mut sol: Vec<f64> = perm.iter().map(|&src| b[src]).collect();

    // Forward substitution. L has an implicit unit diagonal, so there is no division.
    for row in 0..n {
        let reduced = (0..row).fold(sol[row], |acc, col| acc - lu[row][col] * sol[col]);
        sol[row] = reduced;
    }

    // Backward substitution against U, from the last row upwards.
    for row in (0..n).rev() {
        let reduced = (row + 1..n).fold(sol[row], |acc, col| acc - lu[row][col] * sol[col]);
        sol[row] = reduced / lu[row][row];
    }
    sol
}
/// Sparse matrix-vector product `y = A x` over the first `n` rows of a CSR matrix.
///
/// Entries whose column index falls outside `x` are skipped rather than causing a
/// panic. The product therefore behaves as if `x` were zero-padded to the matrix width.
fn csr_matvec(
    row_ptr: &[usize],
    col_ind: &[usize],
    val: &[f64],
    x: &[f64],
    n: usize,
) -> Vec<f64> {
    (0..n)
        .map(|row| {
            (row_ptr[row]..row_ptr[row + 1])
                .filter(|&pos| col_ind[pos] < x.len())
                .fold(0.0f64, |acc, pos| acc + val[pos] * x[col_ind[pos]])
        })
        .collect()
}
/// Extracts the principal submatrix `A[rows, rows]` of a global CSR matrix as a new CSR
/// matrix in local numbering, where local index `k` corresponds to global index `rows[k]`.
///
/// Returns `(val, row_ptr, col_ind, m)` with `m == rows.len()`. Entries whose column
/// is not listed in `rows` are dropped.
///
/// # Panics
/// Panics if an entry of `rows` is not a valid global row index.
fn extract_submatrix_square(
    global_row_ptr: &[usize],
    global_col_ind: &[usize],
    global_val: &[f64],
    rows: &[usize],
) -> (Vec<f64>, Vec<usize>, Vec<usize>, usize) {
    let global_n = global_row_ptr.len().saturating_sub(1);
    // Global -> local lookup table. `usize::MAX` means "not part of the subset".
    let mut to_local = vec![usize::MAX; global_n];
    rows.iter()
        .enumerate()
        .for_each(|(local, &global)| to_local[global] = local);

    let mut sub_val = Vec::new();
    let mut sub_col = Vec::new();
    let mut sub_ptr = Vec::with_capacity(rows.len() + 1);
    sub_ptr.push(0usize);
    for &global_row in rows {
        for pos in global_row_ptr[global_row]..global_row_ptr[global_row + 1] {
            match to_local.get(global_col_ind[pos]) {
                Some(&local_col) if local_col != usize::MAX => {
                    sub_col.push(local_col);
                    sub_val.push(global_val[pos]);
                }
                _ => {}
            }
        }
        sub_ptr.push(sub_val.len());
    }
    (sub_val, sub_ptr, sub_col, rows.len())
}
/// Gathers the submatrix `A[rows, cols]` of a global CSR matrix into a dense row-major
/// array of `rows.len()` rows by `cols.len()` columns. Local row `i` corresponds to
/// `rows[i]` and local column `j` to `cols[j]`.
///
/// Entries in unselected columns are ignored. If a CSR row holds duplicate entries for
/// a column, the last one wins.
// No caller in this module uses it, so the dead-code lint is silenced explicitly.
#[allow(dead_code)]
fn extract_submatrix_dense(
    global_row_ptr: &[usize],
    global_col_ind: &[usize],
    global_val: &[f64],
    rows: &[usize],
    cols: &[usize],
) -> Vec<Vec<f64>> {
    let width = cols.len();
    // Column lookup table covering every column index stored in the matrix.
    let table_len = (global_col_ind.iter().copied().max().unwrap_or(0) + 1).max(1);
    let mut col_to_local = vec![usize::MAX; table_len];
    for (local, &global) in cols.iter().enumerate().filter(|&(_, &g)| g < table_len) {
        col_to_local[global] = local;
    }
    rows.iter()
        .map(|&global_row| {
            let mut line = vec![0.0f64; width];
            for pos in global_row_ptr[global_row]..global_row_ptr[global_row + 1] {
                if let Some(&local) = col_to_local
                    .get(global_col_ind[pos])
                    .filter(|&&l| l != usize::MAX)
                {
                    line[local] = global_val[pos];
                }
            }
            line
        })
        .collect()
}
/// Unpreconditioned conjugate-gradient solve of `A x = b` starting from `x = 0`.
/// The operator `A` is applied through the `matvec` closure.
///
/// Iteration stops after `max_iter` steps, once `||r||^2 <= tol^2 * ||b||^2`, or when
/// the curvature `p^T A p` is (near) zero. CG assumes `A` is symmetric positive
/// definite; this is not checked.
fn cg_internal<F>(matvec: F, b: &[f64], n: usize, max_iter: usize, tol: f64) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let dot = |u: &[f64], v: &[f64]| -> f64 { u.iter().zip(v).map(|(s, t)| s * t).sum() };

    let mut x = vec![0.0f64; n];
    let mut residual = b.to_vec();
    let mut dir = residual.clone();
    let mut rho = dot(&residual, &residual);
    let threshold = tol * tol * dot(b, b);

    for _ in 0..max_iter {
        if rho <= threshold {
            break;
        }
        let a_dir = matvec(&dir);
        let curvature = dot(&dir, &a_dir);
        if curvature.abs() < 1e-300 {
            break;
        }
        let alpha = rho / curvature;
        {
            let (dir_head, a_dir_head) = (&dir[..n], &a_dir[..n]);
            for (((xi, ri), &di), &adi) in x
                .iter_mut()
                .zip(&mut residual[..n])
                .zip(dir_head)
                .zip(a_dir_head)
            {
                *xi += alpha * di;
                *ri -= alpha * adi;
            }
        }
        let rho_next = dot(&residual, &residual);
        let beta = rho_next / rho;
        for (d, &r) in dir[..n].iter_mut().zip(&residual[..n]) {
            *d = r + beta * *d;
        }
        rho = rho_next;
    }
    x
}
/// Local direct solver for one (possibly overlapping) subdomain of the global matrix.
///
/// Stores the dense LU factorisation of the principal submatrix selected by
/// `global_indices`.
struct SubdomainSolver {
    /// Order of the local matrix, i.e. the number of DOFs in this subdomain.
    size: usize,
    /// Global DOF indices covered by this subdomain, in local order.
    global_indices: Vec<usize>,
    /// Packed `L`/`U` factors of the local matrix, as produced by `dense_lu_factor`.
    lu_factor: Vec<Vec<f64>>,
    /// Row permutation returned by `dense_lu_factor` alongside `lu_factor`.
    lu_perm: Vec<usize>,
}
impl SubdomainSolver {
    /// Builds the local solver for the subdomain covering `global_indices`.
    ///
    /// Extracts the principal submatrix `A[global_indices, global_indices]` from the
    /// global CSR matrix, densifies it, and LU-factorises it with partial pivoting.
    /// Dense storage costs O(m^2) memory and the factorisation O(m^3) time, for
    /// `m = global_indices.len()`.
    ///
    /// # Errors
    /// * `SparseError::InvalidArgument` if `global_indices` is empty.
    /// * `SparseError::SingularMatrix` if the local matrix is numerically singular.
    ///
    /// # Panics
    /// Panics if an index in `global_indices` is not a valid row of the global matrix.
    fn new(
        global_row_ptr: &[usize],
        global_col_ind: &[usize],
        global_val: &[f64],
        global_indices: Vec<usize>,
    ) -> SparseResult<Self> {
        if global_indices.is_empty() {
            return Err(SparseError::InvalidArgument(
                "subdomain must have at least one DOF".to_string(),
            ));
        }
        let (lval, lrp, lci, ln) = extract_submatrix_square(
            global_row_ptr,
            global_col_ind,
            global_val,
            &global_indices,
        );
        let mut dense = vec![vec![0.0f64; ln]; ln];
        for (row, bounds) in dense.iter_mut().zip(lrp.windows(2)) {
            for pos in bounds[0]..bounds[1] {
                row[lci[pos]] = lval[pos];
            }
        }
        let lu_perm = dense_lu_factor(&mut dense, ln).ok_or_else(|| {
            SparseError::SingularMatrix("subdomain local matrix is singular".to_string())
        })?;
        Ok(Self {
            global_indices,
            lu_factor: dense,
            lu_perm,
            size: ln,
        })
    }

    /// Adds this subdomain's additive-Schwarz correction to `result`.
    ///
    /// The method works in three steps:
    /// 1. Restricts the global residual `b - A x` to this subdomain's rows.
    /// 2. Solves the local system with the stored LU factors.
    /// 3. Scatter-adds the local solution into `result` at the subdomain's global
    ///    indices.
    ///
    /// Contributions from overlapping subdomains therefore accumulate. Columns beyond
    /// `x_global` are skipped, and rows at or beyond `global_n` use `b` alone.
    ///
    /// # Panics
    /// Panics if `b_global` or `result` is shorter than the largest global index of this
    /// subdomain.
    // The global CSR matrix is passed as raw arrays, which pushes the argument count
    // past clippy's default threshold.
    #[allow(clippy::too_many_arguments)]
    fn apply_additive(
        &self,
        x_global: &[f64],
        b_global: &[f64],
        global_row_ptr: &[usize],
        global_col_ind: &[usize],
        global_val: &[f64],
        global_n: usize,
        result: &mut [f64],
    ) {
        let r_loc: Vec<f64> = self
            .global_indices
            .iter()
            .map(|&gi| {
                let mut acc = b_global[gi];
                if gi < global_n {
                    for pos in global_row_ptr[gi]..global_row_ptr[gi + 1] {
                        let gc = global_col_ind[pos];
                        if gc < x_global.len() {
                            acc -= global_val[pos] * x_global[gc];
                        }
                    }
                }
                acc
            })
            .collect();
        let delta = dense_lu_solve(&self.lu_factor, &self.lu_perm, &r_loc, self.size);
        for (&gi, d) in self.global_indices.iter().zip(delta) {
            result[gi] += d;
        }
    }
}
/// Overlapping additive Schwarz domain-decomposition solver for a square CSR matrix.
///
/// The unknowns are split into contiguous, overlapping blocks, and each block's
/// principal submatrix is factorised once. The resulting local solves are then used
/// either for a stationary sweep (`apply`) or as a CG preconditioner (`solve`).
pub struct SchwarzSolver {
    /// Order of the global matrix.
    n: usize,
    /// Per-subdomain local factorisations.
    subdomains: Vec<SubdomainSolver>,
    /// Owned copy of the global CSR row pointer array.
    row_ptr: Vec<usize>,
    /// Owned copy of the global CSR column indices.
    col_ind: Vec<usize>,
    /// Owned copy of the global CSR values.
    val: Vec<f64>,
}
impl SchwarzSolver {
    /// Builds an overlapping additive Schwarz solver for the `n x n` CSR matrix given by
    /// `csr_val`, `csr_row_ptr` and `csr_col_ind`.
    ///
    /// The range `0..n` is split into `n_subdomains` contiguous blocks whose sizes differ
    /// by at most one. Each block is widened by `overlap` indices on both sides (clamped
    /// to `0..n`), and its principal submatrix is factorised with dense LU. The CSR
    /// arrays are copied, so the solver owns the global matrix.
    ///
    /// # Errors
    /// * `SparseError::InvalidArgument` in any of these cases:
    ///   - `n_subdomains` is not in `1..=n`;
    ///   - `csr_row_ptr` has fewer than `n + 1` entries;
    ///   - a row range runs past the end of `csr_col_ind` or `csr_val`.
    /// * `SparseError::SingularMatrix` if a subdomain matrix is numerically singular.
    pub fn new(
        csr_val: &[f64],
        csr_row_ptr: &[usize],
        csr_col_ind: &[usize],
        n: usize,
        n_subdomains: usize,
        overlap: usize,
    ) -> SparseResult<Self> {
        if n_subdomains == 0 || n_subdomains > n {
            return Err(SparseError::InvalidArgument(format!(
                "n_subdomains={n_subdomains} must be in [1, {n}]"
            )));
        }
        // Reject malformed CSR input here instead of panicking inside the extraction.
        if csr_row_ptr.len() <= n {
            return Err(SparseError::InvalidArgument(format!(
                "csr_row_ptr has length {}, but n={n} rows need n + 1 entries",
                csr_row_ptr.len()
            )));
        }
        let stored = csr_col_ind.len().min(csr_val.len());
        // A non-empty row range [w0, w1) is only safe to traverse if w1 <= stored.
        if csr_row_ptr[..=n]
            .windows(2)
            .any(|w| w[0] < w[1] && w[1] > stored)
        {
            return Err(SparseError::InvalidArgument(format!(
                "csr_row_ptr references entries beyond the {stored} stored in csr_col_ind/csr_val"
            )));
        }
        let subdomains = Self::partition(n, n_subdomains, overlap)
            .into_iter()
            .map(|indices| SubdomainSolver::new(csr_row_ptr, csr_col_ind, csr_val, indices))
            .collect::<SparseResult<Vec<_>>>()?;
        Ok(Self {
            subdomains,
            n,
            row_ptr: csr_row_ptr.to_vec(),
            col_ind: csr_col_ind.to_vec(),
            val: csr_val.to_vec(),
        })
    }

    /// Splits `0..n` into `n_subdomains` contiguous blocks and widens each by `overlap`
    /// on both sides, clamped to `0..n`. The first `n % n_subdomains` blocks get one
    /// extra index.
    fn partition(n: usize, n_subdomains: usize, overlap: usize) -> Vec<Vec<usize>> {
        let base = n / n_subdomains;
        let remainder = n % n_subdomains;
        let mut start = 0;
        (0..n_subdomains)
            .map(|s| {
                let end = (start + base + usize::from(s < remainder)).min(n);
                let lo = start.saturating_sub(overlap);
                // Saturate so a huge `overlap` cannot overflow before clamping.
                let hi = end.saturating_add(overlap).min(n);
                start = end;
                (lo..hi).collect::<Vec<usize>>()
            })
            .collect()
    }

    /// Performs one additive Schwarz sweep: `x + sum_s R_s^T A_s^{-1} R_s (b - A x)`.
    ///
    /// Every subdomain correction is computed from the same input `x`, and the
    /// corrections are summed without damping. Overlapping regions therefore receive
    /// the sum of their subdomains' corrections.
    ///
    /// # Panics
    /// Panics if `x` or `b` has fewer than `n` entries.
    pub fn apply(&self, x: &[f64], b: &[f64]) -> Vec<f64> {
        let n = self.n;
        let mut delta = vec![0.0f64; n];
        for sd in &self.subdomains {
            sd.apply_additive(
                x,
                b,
                &self.row_ptr,
                &self.col_ind,
                &self.val,
                n,
                &mut delta,
            );
        }
        let mut result = x.to_vec();
        for (r, d) in result[..n].iter_mut().zip(delta) {
            *r += d;
        }
        result
    }

    /// Solves `A x = b` by conjugate gradients preconditioned with the additive Schwarz
    /// operator, starting from `x = 0`.
    ///
    /// Only the first `n` entries of `b` are used. Iteration stops after `max_iter`
    /// steps, once `||r||^2 <= tol^2 ||b||^2`, or when `p^T A p` is (near) zero. PCG
    /// assumes `A` is symmetric positive definite; this is not checked.
    ///
    /// # Panics
    /// Panics if `b` has fewer than `n` entries.
    pub fn solve(&self, b: &[f64], max_iter: usize, tol: f64) -> Vec<f64> {
        let n = self.n;
        // Trailing entries beyond `n` would never be reduced and would block convergence.
        let b = &b[..n];
        let dot = |u: &[f64], v: &[f64]| -> f64 { u.iter().zip(v).map(|(s, t)| s * t).sum() };

        let mut x = vec![0.0f64; n];
        let mut r: Vec<f64> = b.to_vec();
        let mut z = self.precondition(&r);
        let mut p = z.clone();
        let mut rz = dot(&r, &z);
        let tol_sq = tol * tol * dot(b, b);

        for _ in 0..max_iter {
            if dot(&r, &r) <= tol_sq {
                break;
            }
            let ap = csr_matvec(&self.row_ptr, &self.col_ind, &self.val, &p, n);
            let pap = dot(&p, &ap);
            if pap.abs() < 1e-300 {
                break;
            }
            let alpha = rz / pap;
            for (((xi, ri), &pi), &api) in x.iter_mut().zip(r.iter_mut()).zip(&p).zip(&ap) {
                *xi += alpha * pi;
                *ri -= alpha * api;
            }
            z = self.precondition(&r);
            let rz_new = dot(&r, &z);
            let beta = rz_new / rz;
            for (pi, &zi) in p.iter_mut().zip(&z) {
                *pi = zi + beta * *pi;
            }
            rz = rz_new;
        }
        x
    }

    /// Applies the additive Schwarz preconditioner `z = sum_s R_s^T A_s^{-1} R_s r`.
    fn precondition(&self, r: &[f64]) -> Vec<f64> {
        let mut z = vec![0.0f64; self.n];
        for sd in &self.subdomains {
            let r_loc: Vec<f64> = sd.global_indices.iter().map(|&gi| r[gi]).collect();
            let z_loc = dense_lu_solve(&sd.lu_factor, &sd.lu_perm, &r_loc, sd.size);
            for (&gi, v) in sd.global_indices.iter().zip(z_loc) {
                z[gi] += v;
            }
        }
        z
    }

    /// Order `n` of the global matrix.
    pub fn size(&self) -> usize {
        self.n
    }

    /// Number of subdomains the unknowns were split into.
    pub fn num_subdomains(&self) -> usize {
        self.subdomains.len()
    }
}
/// Block 2x2 solver for `[A B; C D] [x; y] = [f; g]` based on the Schur complement
/// `S = D - C A^{-1} B`.
///
/// The solve itself is the associated function `SchurComplementSolver::solve`. An
/// instance only records the block sizes `n1` (order of `A`) and `n2` (order of `D`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SchurComplementSolver {
    /// Order of the leading block `A`.
    n1: usize,
    /// Order of the trailing block `D`.
    n2: usize,
}
impl SchurComplementSolver {
    /// Largest order of block `A` that is factorised densely. Larger blocks are solved
    /// with inner conjugate-gradient iterations instead.
    const DENSE_A_LIMIT: usize = 512;

    /// Checks that the first `rows` rows of a CSR block can be traversed without
    /// indexing past `col_ind`/`val`. A block with zero rows is never read, so it
    /// always passes.
    fn check_csr(
        name: &str,
        row_ptr: &[usize],
        col_ind: &[usize],
        val: &[f64],
        rows: usize,
    ) -> SparseResult<()> {
        if rows == 0 {
            return Ok(());
        }
        if row_ptr.len() <= rows {
            return Err(SparseError::InvalidArgument(format!(
                "block {name}: row_ptr has length {}, but {rows} rows need {rows} + 1 entries",
                row_ptr.len()
            )));
        }
        let stored = col_ind.len().min(val.len());
        if row_ptr[..=rows]
            .windows(2)
            .any(|w| w[0] < w[1] && w[1] > stored)
        {
            return Err(SparseError::InvalidArgument(format!(
                "block {name}: row_ptr references entries beyond the {stored} stored in col_ind/val"
            )));
        }
        Ok(())
    }

    /// Solves the block system `[A B; C D] [x; y] = [f; g]`. Returns `x` followed by `y`
    /// in one vector of length `n1 + n2`.
    ///
    /// The solve proceeds in two steps using the Schur complement `S = D - C A^{-1} B`:
    /// 1. Solve `S y = g - C A^{-1} f` for `y` with matrix-free CG.
    /// 2. Recover `x = A^{-1} (f - B y)`.
    ///
    /// The blocks are given in CSR form with these shapes:
    /// * `A`: `n1 x n1`
    /// * `B`: `n1 x n2`
    /// * `C`: `n2 x n1`
    /// * `D`: `n2 x n2`
    ///
    /// Column indices beyond a block's width are ignored. Only the first `n1` entries
    /// of `f` and the first `n2` entries of `g` are used.
    ///
    /// When `n1 <= 512`, `A` is LU-factorised densely. Otherwise every application of
    /// `A^{-1}` is an inner CG solve with the same `max_iter`/`tol`, so the result is
    /// only approximate. CG assumes `S` (and `A` on the iterative path) is symmetric
    /// positive definite, as holds when the whole block matrix is SPD; this is not
    /// checked.
    ///
    /// # Errors
    /// * `SparseError::InvalidArgument` if `f` or `g` is too short, or if a block's
    ///   `row_ptr` is too short or points past its `col_ind`/`val` storage.
    /// * `SparseError::SingularMatrix` if `A` is singular on the dense path.
    #[allow(clippy::too_many_arguments)]
    pub fn solve(
        a_val: &[f64],
        a_row_ptr: &[usize],
        a_col_ind: &[usize],
        n1: usize,
        b_val: &[f64],
        b_row_ptr: &[usize],
        b_col_ind: &[usize],
        c_val: &[f64],
        c_row_ptr: &[usize],
        c_col_ind: &[usize],
        d_val: &[f64],
        d_row_ptr: &[usize],
        d_col_ind: &[usize],
        n2: usize,
        f: &[f64],
        g: &[f64],
        max_iter: usize,
        tol: f64,
    ) -> SparseResult<Vec<f64>> {
        Self::check_csr("A", a_row_ptr, a_col_ind, a_val, n1)?;
        Self::check_csr("B", b_row_ptr, b_col_ind, b_val, n1)?;
        Self::check_csr("C", c_row_ptr, c_col_ind, c_val, n2)?;
        Self::check_csr("D", d_row_ptr, d_col_ind, d_val, n2)?;
        if f.len() < n1 || g.len() < n2 {
            return Err(SparseError::InvalidArgument(format!(
                "right-hand sides have lengths f={}, g={}; need at least n1={n1}, n2={n2}",
                f.len(),
                g.len()
            )));
        }
        // Restrict `f` to the first block so the CG path sees a length-n1 right-hand
        // side, just like the dense path, which already ignores trailing entries.
        let f = &f[..n1];

        let a_lu = if n1 <= Self::DENSE_A_LIMIT {
            let mut a_dense = vec![vec![0.0f64; n1]; n1];
            for (i, row) in a_dense.iter_mut().enumerate() {
                for pos in a_row_ptr[i]..a_row_ptr[i + 1] {
                    let col = a_col_ind[pos];
                    if col < n1 {
                        row[col] = a_val[pos];
                    }
                }
            }
            let perm = dense_lu_factor(&mut a_dense, n1).ok_or_else(|| {
                SparseError::SingularMatrix("block A is singular".to_string())
            })?;
            Some((a_dense, perm))
        } else {
            None
        };

        // Compute A^{-1} v: a direct solve when A was factorised, otherwise an inner CG solve.
        let solve_a = |rhs: &[f64]| -> Vec<f64> {
            match &a_lu {
                Some((lu, perm)) => dense_lu_solve(lu, perm, rhs, n1),
                None => cg_internal(
                    |v| csr_matvec(a_row_ptr, a_col_ind, a_val, v, n1),
                    rhs,
                    n1,
                    max_iter,
                    tol,
                ),
            }
        };

        // Reduced right-hand side: g - C A^{-1} f.
        let x_hat = solve_a(f);
        let cx_hat = csr_matvec(c_row_ptr, c_col_ind, c_val, &x_hat, n2);
        let g_tilde: Vec<f64> = g.iter().zip(&cx_hat).map(|(gi, ci)| gi - ci).collect();

        // Matrix-free Schur complement product: S y = D y - C A^{-1} B y.
        let schur_mv = |y: &[f64]| -> Vec<f64> {
            let by = csr_matvec(b_row_ptr, b_col_ind, b_val, y, n1);
            let a_inv_by = solve_a(&by);
            let dy = csr_matvec(d_row_ptr, d_col_ind, d_val, y, n2);
            let ca_inv_by = csr_matvec(c_row_ptr, c_col_ind, c_val, &a_inv_by, n2);
            dy.iter().zip(&ca_inv_by).map(|(d, c)| d - c).collect()
        };
        let y = cg_internal(schur_mv, &g_tilde, n2, max_iter, tol);

        // Back-substitute for the first block: x = A^{-1} (f - B y).
        let by = csr_matvec(b_row_ptr, b_col_ind, b_val, &y, n1);
        let f_minus_by: Vec<f64> = f.iter().zip(&by).map(|(fi, bi)| fi - bi).collect();
        let mut solution = solve_a(&f_minus_by);
        solution.extend_from_slice(&y);
        Ok(solution)
    }

    /// Order of the leading block `A`.
    pub fn n1(&self) -> usize {
        self.n1
    }

    /// Order of the trailing block `D`.
    pub fn n2(&self) -> usize {
        self.n2
    }

    /// Records the block sizes `n1` (order of `A`) and `n2` (order of `D`).
    pub fn new(n1: usize, n2: usize) -> Self {
        Self { n1, n2 }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n x n` tridiagonal CSR matrix with `diag` on the main diagonal and `off`
    /// on both neighbouring diagonals. Returns `(val, row_ptr, col_ind)`.
    fn tridiag_csr(n: usize, diag: f64, off: f64) -> (Vec<f64>, Vec<usize>, Vec<usize>) {
        let mut values = Vec::with_capacity(3 * n);
        let mut cols = Vec::with_capacity(3 * n);
        let mut ptr = Vec::with_capacity(n + 1);
        ptr.push(0usize);
        for row in 0..n {
            for col in row.saturating_sub(1)..=(row + 1).min(n - 1) {
                cols.push(col);
                values.push(if col == row { diag } else { off });
            }
            ptr.push(values.len());
        }
        (values, ptr, cols)
    }

    /// Builds a diagonal CSR matrix from `d`. Returns `(val, row_ptr, col_ind)`.
    fn diag_csr(d: &[f64]) -> (Vec<f64>, Vec<usize>, Vec<usize>) {
        (d.to_vec(), (0..=d.len()).collect(), (0..d.len()).collect())
    }

    /// Euclidean norm of `b - A x` for a CSR matrix `A` of order `n`.
    fn residual_norm(
        val: &[f64],
        rp: &[usize],
        ci: &[usize],
        x: &[f64],
        b: &[f64],
        n: usize,
    ) -> f64 {
        let ax = csr_matvec(rp, ci, val, x, n);
        b.iter()
            .zip(&ax)
            .map(|(bi, axi)| (bi - axi).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    #[test]
    fn test_schwarz_solver_constructs() {
        let n = 6;
        let (val, rp, ci) = tridiag_csr(n, 4.0, -1.0);
        let solver = SchwarzSolver::new(&val, &rp, &ci, n, 3, 1).expect("SchwarzSolver::new");
        assert_eq!(solver.size(), n);
        assert_eq!(solver.num_subdomains(), 3);
    }

    #[test]
    fn test_schwarz_solver_apply() {
        let n = 4;
        let (val, rp, ci) = tridiag_csr(n, 4.0, -1.0);
        let solver = SchwarzSolver::new(&val, &rp, &ci, n, 2, 1).expect("SchwarzSolver::new");
        let updated = solver.apply(&vec![0.0f64; n], &vec![1.0f64; n]);
        assert_eq!(updated.len(), n);
        let norm = updated.iter().map(|v| v * v).sum::<f64>().sqrt();
        assert!(norm > 1e-12, "Schwarz apply should give non-zero update");
    }

    #[test]
    fn test_schwarz_solver_solve_converges() {
        let n = 8;
        let (val, rp, ci) = tridiag_csr(n, 4.0, -1.0);
        let solver = SchwarzSolver::new(&val, &rp, &ci, n, 4, 1).expect("SchwarzSolver::new");
        let b: Vec<f64> = (0..n).map(|i| if i == 0 { 1.0 } else { 0.0 }).collect();
        let x = solver.solve(&b, 200, 1e-10);
        let res = residual_norm(&val, &rp, &ci, &x, &b, n);
        assert!(res < 1e-8, "Schwarz PCG residual {res}");
    }

    #[test]
    fn test_schwarz_invalid_args() {
        let n = 4;
        let (val, rp, ci) = tridiag_csr(n, 2.0, -0.5);
        for bad_count in [0, n + 1] {
            assert!(SchwarzSolver::new(&val, &rp, &ci, n, bad_count, 0).is_err());
        }
    }

    #[test]
    fn test_schur_complement_2x2() {
        let (n1, n2) = (2, 2);
        let (a_v, a_rp, a_ci) = diag_csr(&[3.0, 3.0]);
        let (b_v, b_rp, b_ci) = diag_csr(&[-1.0, -1.0]);
        let (c_v, c_rp, c_ci) = diag_csr(&[-1.0, -1.0]);
        let (d_v, d_rp, d_ci) = diag_csr(&[2.0, 2.0]);
        let f = [1.0, 0.0];
        let g = [0.0, 1.0];
        let sol = SchurComplementSolver::solve(
            &a_v, &a_rp, &a_ci, n1, &b_v, &b_rp, &b_ci, &c_v, &c_rp, &c_ci, &d_v, &d_rp,
            &d_ci, n2, &f, &g, 100, 1e-12,
        )
        .expect("Schur solve");
        assert_eq!(sol.len(), n1 + n2);
        let residuals = [
            3.0 * sol[0] - sol[2] - 1.0,
            3.0 * sol[1] - sol[3] - 0.0,
            -sol[0] + 2.0 * sol[2] - 0.0,
            -sol[1] + 2.0 * sol[3] - 1.0,
        ];
        for (i, r) in residuals.iter().enumerate() {
            assert!(r.abs() < 1e-8, "residual[{i}] = {r}");
        }
    }
}