use crate::csr::CsrMatrix;
use crate::error::SparseResult;
use crate::parallel_amg::strength::StrengthGraph;
use std::sync::Arc;
/// Computes the dense product `y = A * x` for a CSR matrix `a`.
///
/// Each output entry is the dot product of one stored row of `a` with `x`, accumulated
/// in storage order starting from `0.0`. `x` must be indexable by every column index
/// stored in `a`; the result has one entry per row of `a`.
fn spmv_internal(a: &CsrMatrix<f64>, x: &[f64]) -> Vec<f64> {
    let (row_count, _) = a.shape();
    (0..row_count)
        .map(|row| {
            let mut dot = 0.0f64;
            for k in a.row_range(row) {
                dot += a.data[k] * x[a.indices[k]];
            }
            dot
        })
        .collect()
}
/// Assembles a CSR matrix of the given `shape` from coordinate (COO) triplets.
///
/// Entry `k` of the matrix is `(rows_t[k], cols_t[k], vals_t[k])`. The vectors are
/// forwarded to `CsrMatrix::new` in its `(data, rows, cols, shape)` argument order, and
/// any error it reports is returned unchanged.
fn build_csr_from_triplets(
    rows_t: Vec<usize>,
    cols_t: Vec<usize>,
    vals_t: Vec<f64>,
    shape: (usize, usize),
) -> SparseResult<CsrMatrix<f64>> {
    let (data, row_indices, col_indices) = (vals_t, rows_t, cols_t);
    CsrMatrix::new(data, row_indices, col_indices, shape)
}
/// Produces rows `[row_start, row_end)` of the direct-interpolation operator as
/// `(fine_row, coarse_col, weight)` triplets, reading the fine matrix in raw CSR form.
///
/// Rows at or beyond `splitting.len()` are never emitted.
///
/// * C-node `i` (`splitting[i] == 1`) injects into its own coarse column `c_map[i]` with
///   weight `1.0`.
/// * F-node `i` interpolates from *every* C-neighbour `j` in its row (strength is not
///   consulted here), with weight `a_ij / sum_{k in C_i} a_ik`. The weights of a row
///   therefore sum to one. For rows with zero row sum, this coincides with the classical
///   direct-interpolation formula.
///
/// An F-row is left empty when its diagonal entry is absent or has magnitude
/// `< f64::EPSILON`, or when its C-couplings sum to less than `f64::EPSILON` in
/// magnitude.
fn direct_interp_row_range(
    indptr: &[usize],
    indices: &[usize],
    data: &[f64],
    splitting: &[u8],
    c_map: &[usize],
    row_start: usize,
    row_end: usize,
) -> Vec<(usize, usize, f64)> {
    let n_nodes = splitting.len();
    // A node counts as coarse only if the splitting actually covers it.
    let is_c_node = |node: usize| node < n_nodes && splitting[node] == 1;
    let mut out = Vec::new();
    for row in row_start..row_end.min(n_nodes) {
        if is_c_node(row) {
            out.push((row, c_map[row], 1.0f64));
            continue;
        }
        let span = indptr[row]..indptr[row + 1];
        // First stored diagonal entry, or 0.0 when the row has none.
        let diag = span
            .clone()
            .find(|&pos| indices[pos] == row)
            .map_or(0.0f64, |pos| data[pos]);
        if diag.abs() < f64::EPSILON {
            continue;
        }
        let c_neighbours: Vec<usize> = span
            .filter(|&pos| indices[pos] != row && is_c_node(indices[pos]))
            .collect();
        // Sum of the negated couplings to C-neighbours (the normalising denominator).
        let total = c_neighbours
            .iter()
            .fold(0.0f64, |acc, &pos| acc + -data[pos]);
        if total.abs() < f64::EPSILON {
            continue;
        }
        out.extend(
            c_neighbours
                .iter()
                .map(|&pos| (row, c_map[indices[pos]], -data[pos] / total)),
        );
    }
    out
}
/// Builds the direct-interpolation prolongator `P` (`n x n_coarse`) with scoped worker
/// threads.
///
/// * `a`: the fine-grid operator. Row `i` of `a` determines row `i` of `P`.
/// * `splitting`: per-node C/F marker. `1` marks a C-node and any other value an F-node.
///   Nodes at or beyond `splitting.len()` get empty rows.
/// * `n_threads`: the requested worker count. `0` is treated as `1`, and the count is
///   capped at the number of rows so no worker is spawned without work.
///
/// C-nodes are numbered consecutively in index order to form the coarse columns. Rows
/// are split into contiguous chunks, one per worker, and each chunk is handled by
/// `direct_interp_row_range`. The chunks are concatenated in row order, so the result
/// does not depend on `n_threads`.
///
/// # Errors
/// Propagates any error from `CsrMatrix::new` while assembling `P`.
///
/// # Panics
/// A panic inside a worker (e.g. from inconsistent matrix/splitting sizes) is re-raised
/// on the calling thread. Without this, the result would be a `P` with missing rows.
pub fn parallel_direct_interpolation(
    a: &CsrMatrix<f64>,
    splitting: &[u8],
    n_threads: usize,
) -> SparseResult<CsrMatrix<f64>> {
    let n = a.shape().0;
    // Coarse column of each C-node; usize::MAX marks F-nodes and is never read for them.
    let mut c_map = vec![usize::MAX; n];
    let mut n_coarse = 0usize;
    for i in 0..n {
        if i < splitting.len() && splitting[i] == 1 {
            c_map[i] = n_coarse;
            n_coarse += 1;
        }
    }
    if n_coarse == 0 {
        return CsrMatrix::new(Vec::new(), Vec::new(), Vec::new(), (n, 0));
    }
    // Scoped threads can borrow these directly, so no copies into `Arc`s are needed.
    let indptr: &[usize] = &a.indptr;
    let indices: &[usize] = &a.indices;
    let data: &[f64] = &a.data;
    let c_map_slice: &[usize] = &c_map;
    // n >= 1 here because at least one C-node exists.
    let n_threads = n_threads.clamp(1, n);
    let chunk_size = (n + n_threads - 1) / n_threads;
    let chunks: Vec<Vec<(usize, usize, f64)>> = std::thread::scope(|s| {
        // Spawn every worker before joining any, so the chunks run concurrently.
        let handles: Vec<_> = (0..n_threads)
            .filter_map(|t| {
                let row_start = t * chunk_size;
                let row_end = ((t + 1) * chunk_size).min(n);
                (row_start < row_end).then(|| {
                    s.spawn(move || {
                        direct_interp_row_range(
                            indptr,
                            indices,
                            data,
                            splitting,
                            c_map_slice,
                            row_start,
                            row_end,
                        )
                    })
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().unwrap_or_else(|payload| std::panic::resume_unwind(payload)))
            .collect()
    });
    let nnz: usize = chunks.iter().map(Vec::len).sum();
    let mut rows_t = Vec::with_capacity(nnz);
    let mut cols_t = Vec::with_capacity(nnz);
    let mut vals_t = Vec::with_capacity(nnz);
    for (r, c, v) in chunks.into_iter().flatten() {
        rows_t.push(r);
        cols_t.push(c);
        vals_t.push(v);
    }
    build_csr_from_triplets(rows_t, cols_t, vals_t, (n, n_coarse))
}
/// Builds the piecewise-constant tentative prolongator `P0` (`strength.n x n_coarse`).
///
/// Only nodes covered by both `strength` and `splitting` are visited.
///
/// * A C-node `i` (`splitting[i] == 1`) maps to coarse column `c_map[i]`.
/// * An F-node maps to the coarse column of its first C-node among
///   `strength.strong_influencers[i]`. If there is none, it uses the first C-node among
///   `strength.strong_neighbors[i]`.
/// * An F-node with neither gets an empty row.
///
/// Every stored entry equals `1.0`, so each row holds at most one entry.
fn build_tentative_prolongator(
    strength: &StrengthGraph,
    splitting: &[u8],
    c_map: &[usize],
    n_coarse: usize,
) -> SparseResult<CsrMatrix<f64>> {
    let is_c_node = |j: usize| j < splitting.len() && splitting[j] == 1;
    let visited = strength.n.min(splitting.len());
    let mut rows_t = Vec::new();
    let mut cols_t = Vec::new();
    let mut vals_t = Vec::new();
    for node in 0..visited {
        let parent = if splitting[node] == 1 {
            Some(node)
        } else {
            strength.strong_influencers[node]
                .iter()
                .copied()
                .find(|&j| is_c_node(j))
                .or_else(|| {
                    strength.strong_neighbors[node]
                        .iter()
                        .copied()
                        .find(|&j| is_c_node(j))
                })
        };
        if let Some(c_node) = parent {
            rows_t.push(node);
            cols_t.push(c_map[c_node]);
            vals_t.push(1.0f64);
        }
    }
    build_csr_from_triplets(rows_t, cols_t, vals_t, (strength.n, n_coarse))
}
/// Computes rows `[row_start, row_end)` of the smoothed prolongator
/// `P = P0 - omega * D^{-1} * (A * P0)` as `(fine_row, coarse_col, value)` triplets.
///
/// Each row is formed with dense accumulators of length `n_coarse` plus a list of the
/// columns that row touched. Only those columns are reset afterwards, so a row costs time
/// proportional to the non-zeros it involves rather than to `n_coarse` or the fine size.
///
/// Products are summed in `A`'s storage order. Entries with `|value| <= 10 * f64::EPSILON`
/// are dropped. Triplets of a row are emitted in ascending column order.
///
/// If a column repeats within a row of `P0`, the first stored value is used for the
/// `P0` term. The `P0` produced by `build_tentative_prolongator` has at most one entry
/// per row.
fn sa_smooth_row_range(
    a: &CsrMatrix<f64>,
    p0: &CsrMatrix<f64>,
    diag_inv: &[f64],
    omega: f64,
    n_coarse: usize,
    row_start: usize,
    row_end: usize,
) -> Vec<(usize, usize, f64)> {
    let mut triplets = Vec::new();
    let mut p0_row = vec![0.0f64; n_coarse];
    let mut ap0_row = vec![0.0f64; n_coarse];
    let mut seen = vec![false; n_coarse];
    let mut touched: Vec<usize> = Vec::new();
    for i in row_start..row_end {
        // Row i of P0 (read first, so `seen` here only reflects P0's own entries).
        for pos in p0.row_range(i) {
            let c = p0.indices[pos];
            if !seen[c] {
                seen[c] = true;
                touched.push(c);
                p0_row[c] = p0.data[pos];
            }
        }
        // Row i of A * P0: sum over k of a_ik * P0[k, :].
        for pos in a.row_range(i) {
            let a_ik = a.data[pos];
            for q in p0.row_range(a.indices[pos]) {
                let c = p0.indices[q];
                if !seen[c] {
                    seen[c] = true;
                    touched.push(c);
                }
                ap0_row[c] += a_ik * p0.data[q];
            }
        }
        touched.sort_unstable();
        for &c in &touched {
            let val = p0_row[c] - omega * diag_inv[i] * ap0_row[c];
            if val.abs() > f64::EPSILON * 10.0 {
                triplets.push((i, c, val));
            }
            // Reset only what this row touched, keeping the per-row cost sparse.
            p0_row[c] = 0.0;
            ap0_row[c] = 0.0;
            seen[c] = false;
        }
        touched.clear();
    }
    triplets
}

/// Builds a smoothed prolongator `P = (I - omega * D^{-1} * A) * P0` with scoped worker
/// threads.
///
/// * `a`: the fine-grid operator; `D` is its diagonal. Rows whose diagonal is missing or
///   has magnitude `<= f64::EPSILON` use `D^{-1} = 0`, so their `P0` entries pass
///   through unsmoothed.
/// * `strength`, `splitting`: passed to `build_tentative_prolongator` to form `P0`.
///   In `P0`, C-nodes (`splitting[i] == 1`) are numbered consecutively as coarse columns.
/// * `n_threads`: the requested worker count over fine rows. `0` is treated as `1`, and
///   the count is capped at the number of rows of `P0`.
/// * `omega`: the damping factor of the single Jacobi smoothing step.
///
/// Rows of `P` are computed independently by `sa_smooth_row_range`, so the result does
/// not depend on `n_threads`. Entries with `|value| <= 10 * f64::EPSILON` are dropped.
/// The result has shape `(a.shape().0, n_coarse)`; a splitting with no C-nodes yields an
/// `(n, 0)` matrix.
///
/// # Errors
/// Propagates errors from building `P0` or from assembling `P` via `CsrMatrix::new`.
///
/// # Panics
/// A panic inside a worker is re-raised on the calling thread instead of silently
/// dropping that worker's rows.
pub fn parallel_sa_interpolation(
    a: &CsrMatrix<f64>,
    strength: &StrengthGraph,
    splitting: &[u8],
    n_threads: usize,
    omega: f64,
) -> SparseResult<CsrMatrix<f64>> {
    let n = a.shape().0;
    let mut c_map = vec![usize::MAX; n];
    let mut n_coarse = 0usize;
    for i in 0..n {
        if i < splitting.len() && splitting[i] == 1 {
            c_map[i] = n_coarse;
            n_coarse += 1;
        }
    }
    if n_coarse == 0 {
        return CsrMatrix::new(Vec::new(), Vec::new(), Vec::new(), (n, 0));
    }
    let p0 = build_tentative_prolongator(strength, splitting, &c_map, n_coarse)?;
    // Inverse of the first stored diagonal entry of each row; 0.0 when it is missing or tiny.
    let mut diag_inv = vec![0.0f64; n];
    for i in 0..n {
        for pos in a.row_range(i) {
            if a.indices[pos] == i {
                let d = a.data[pos];
                if d.abs() > f64::EPSILON {
                    diag_inv[i] = 1.0 / d;
                }
                break;
            }
        }
    }
    let n_fine = p0.shape().0;
    let n_threads = n_threads.clamp(1, n_fine.max(1));
    let chunk_size = (n_fine + n_threads - 1) / n_threads;
    // Scoped threads borrow these directly; no need to clone `a` or `diag_inv`.
    let p0_ref = &p0;
    let diag_inv_ref: &[f64] = &diag_inv;
    let chunks: Vec<Vec<(usize, usize, f64)>> = std::thread::scope(|s| {
        // Spawn every worker before joining any, so the chunks run concurrently.
        let handles: Vec<_> = (0..n_threads)
            .filter_map(|t| {
                let row_start = t * chunk_size;
                let row_end = ((t + 1) * chunk_size).min(n_fine);
                (row_start < row_end).then(|| {
                    s.spawn(move || {
                        sa_smooth_row_range(
                            a,
                            p0_ref,
                            diag_inv_ref,
                            omega,
                            n_coarse,
                            row_start,
                            row_end,
                        )
                    })
                })
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().unwrap_or_else(|payload| std::panic::resume_unwind(payload)))
            .collect()
    });
    let nnz: usize = chunks.iter().map(Vec::len).sum();
    let mut rows_t = Vec::with_capacity(nnz);
    let mut cols_t = Vec::with_capacity(nnz);
    let mut vals_t = Vec::with_capacity(nnz);
    for (r, c, v) in chunks.into_iter().flatten() {
        rows_t.push(r);
        cols_t.push(c);
        vals_t.push(v);
    }
    build_csr_from_triplets(rows_t, cols_t, vals_t, (n, n_coarse))
}
/// Forms the Galerkin coarse-grid operator `A_c = P^T * A * P` with the sparse matrix
/// products provided by `CsrMatrix`.
///
/// The product is evaluated as `P^T * (A * P)`.
///
/// # Errors
/// Returns any error reported by either `matmul` call (e.g. incompatible shapes).
pub fn galerkin_coarse_operator(
    a: &CsrMatrix<f64>,
    p: &CsrMatrix<f64>,
) -> SparseResult<CsrMatrix<f64>> {
    let ap = a.matmul(p)?;
    p.transpose().matmul(&ap)
}
/// Forms the Galerkin coarse-grid operator `A_c = R * A * P` with `R = P^T`, computing
/// coarse rows in parallel on scoped worker threads.
///
/// * `a`: the fine-grid operator, indexed as `a.shape().0` fine nodes.
/// * `p`: the prolongator; its column count is the coarse size.
/// * `n_threads`: the requested worker count over coarse rows. `0` is treated as `1`,
///   and the count is capped at the number of coarse rows.
///
/// Each coarse row `ci` is formed as `((R[ci, :] * A) * P)` using sparse accumulators
/// that track which indices were touched. Work per row is therefore proportional to the
/// non-zeros involved rather than to `n_fine + n_coarse`. Intermediate and final entries
/// are visited in ascending index order, matching a dense sweep.
///
/// Intermediate `(R A)` entries with `|x| < 1e-6 * f64::EPSILON` are skipped. Final
/// entries are kept only if `|x| > 1e-8 * f64::EPSILON`. An empty coarse space yields a
/// `(0, 0)` matrix.
///
/// # Errors
/// Propagates any error from `CsrMatrix::new` while assembling the result.
///
/// # Panics
/// A panic inside a worker is re-raised on the calling thread instead of silently
/// dropping that worker's rows.
pub fn parallel_galerkin_coarse_operator(
    a: &CsrMatrix<f64>,
    p: &CsrMatrix<f64>,
    n_threads: usize,
) -> SparseResult<CsrMatrix<f64>> {
    let n_fine = a.shape().0;
    let n_coarse = p.shape().1;
    if n_coarse == 0 {
        return CsrMatrix::new(Vec::new(), Vec::new(), Vec::new(), (0, 0));
    }
    let r = p.transpose();
    let n_threads = n_threads.clamp(1, n_coarse);
    let chunk_size = (n_coarse + n_threads - 1) / n_threads;

    // Computes coarse rows [row_start, row_end) as (row, col, value) triplets.
    // Accumulators are allocated once per call and reset via their touched lists.
    let coarse_rows = |row_start: usize, row_end: usize| -> Vec<(usize, usize, f64)> {
        let mut triplets = Vec::new();
        let mut ra_row = vec![0.0f64; n_fine];
        let mut ra_seen = vec![false; n_fine];
        let mut ra_cols: Vec<usize> = Vec::new();
        let mut rac_row = vec![0.0f64; n_coarse];
        let mut rac_seen = vec![false; n_coarse];
        let mut rac_cols: Vec<usize> = Vec::new();
        for ci in row_start..row_end {
            // (R A)[ci, :], accumulated in R's and then A's storage order.
            for rpos in r.row_range(ci) {
                let fi = r.indices[rpos];
                let rval = r.data[rpos];
                for apos in a.row_range(fi) {
                    let fj = a.indices[apos];
                    if !ra_seen[fj] {
                        ra_seen[fj] = true;
                        ra_cols.push(fj);
                    }
                    ra_row[fj] += rval * a.data[apos];
                }
            }
            // (R A P)[ci, :], visiting fine columns in ascending order like a dense sweep.
            ra_cols.sort_unstable();
            for &fj in &ra_cols {
                let raval = ra_row[fj];
                ra_row[fj] = 0.0;
                ra_seen[fj] = false;
                if raval.abs() < f64::EPSILON * 1e-6 {
                    continue;
                }
                for ppos in p.row_range(fj) {
                    let cj = p.indices[ppos];
                    if !rac_seen[cj] {
                        rac_seen[cj] = true;
                        rac_cols.push(cj);
                    }
                    rac_row[cj] += raval * p.data[ppos];
                }
            }
            ra_cols.clear();
            rac_cols.sort_unstable();
            for &cj in &rac_cols {
                let val = rac_row[cj];
                rac_row[cj] = 0.0;
                rac_seen[cj] = false;
                if val.abs() > f64::EPSILON * 1e-8 {
                    triplets.push((ci, cj, val));
                }
            }
            rac_cols.clear();
        }
        triplets
    };
    let coarse_rows = &coarse_rows;

    let chunks: Vec<Vec<(usize, usize, f64)>> = std::thread::scope(|s| {
        // Spawn every worker before joining any, so the chunks run concurrently.
        let handles: Vec<_> = (0..n_threads)
            .filter_map(|t| {
                let row_start = t * chunk_size;
                let row_end = ((t + 1) * chunk_size).min(n_coarse);
                (row_start < row_end).then(|| s.spawn(move || coarse_rows(row_start, row_end)))
            })
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().unwrap_or_else(|payload| std::panic::resume_unwind(payload)))
            .collect()
    });
    let nnz: usize = chunks.iter().map(Vec::len).sum();
    let mut rows_t = Vec::with_capacity(nnz);
    let mut cols_t = Vec::with_capacity(nnz);
    let mut vals_t = Vec::with_capacity(nnz);
    for (r, c, v) in chunks.into_iter().flatten() {
        rows_t.push(r);
        cols_t.push(c);
        vals_t.push(v);
    }
    build_csr_from_triplets(rows_t, cols_t, vals_t, (n_coarse, n_coarse))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parallel_amg::parallel_rs::pmis_coarsening;
    use crate::parallel_amg::strength::serial_strength_of_connection;

    /// 1D Poisson matrix `tridiag(-1, 2, -1)` of size `n`.
    fn laplacian_1d(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(2.0f64);
        }
        for i in 0..n - 1 {
            rows.push(i);
            cols.push(i + 1);
            vals.push(-1.0f64);
            rows.push(i + 1);
            cols.push(i);
            vals.push(-1.0f64);
        }
        CsrMatrix::new(vals, rows, cols, (n, n)).expect("valid Laplacian")
    }

    /// Expands a CSR matrix into a dense row-major array, summing duplicate entries.
    fn to_dense(m: &CsrMatrix<f64>) -> Vec<Vec<f64>> {
        let (nr, nc) = m.shape();
        let mut dense = vec![vec![0.0f64; nc]; nr];
        for (i, row) in dense.iter_mut().enumerate() {
            for pos in m.row_range(i) {
                row[m.indices[pos]] += m.data[pos];
            }
        }
        dense
    }

    /// Asserts equal shapes and entrywise agreement within `tol`.
    fn assert_close(x: &CsrMatrix<f64>, y: &CsrMatrix<f64>, tol: f64) {
        assert_eq!(x.shape(), y.shape(), "shape mismatch");
        let (dx, dy) = (to_dense(x), to_dense(y));
        for (i, (rx, ry)) in dx.iter().zip(&dy).enumerate() {
            for (j, (vx, vy)) in rx.iter().zip(ry).enumerate() {
                assert!(
                    (vx - vy).abs() <= tol,
                    "entry ({i}, {j}) differs: {vx} vs {vy}"
                );
            }
        }
    }

    #[test]
    fn test_direct_interp_c_node_identity() {
        let n = 8;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let p = parallel_direct_interpolation(&a, &result.cf_splitting, 1)
            .expect("direct interpolation");
        let mut c_col = 0usize;
        for i in 0..n {
            if result.cf_splitting[i] == 1 {
                let mut found = false;
                for pos in p.row_range(i) {
                    if p.indices[pos] == c_col {
                        found = true;
                        assert!(
                            (p.data[pos] - 1.0).abs() < 1e-10,
                            "C-node {i} should map to coarse col {c_col} with weight 1.0"
                        );
                    }
                }
                assert!(found, "C-node {i} should have identity entry in P");
                c_col += 1;
            }
        }
    }

    #[test]
    fn test_direct_interp_f_node_has_c_parents() {
        let n = 12;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let p = parallel_direct_interpolation(&a, &result.cf_splitting, 1)
            .expect("direct interpolation");
        let n_coarse = result.c_nodes.len();
        for pos in 0..p.nnz() {
            assert!(
                p.indices[pos] < n_coarse,
                "P column index {} out of range [0, {})",
                p.indices[pos],
                n_coarse
            );
        }
    }

    #[test]
    fn test_direct_interp_parallel() {
        let n = 16;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let p_serial = parallel_direct_interpolation(&a, &result.cf_splitting, 1)
            .expect("serial direct interpolation");
        let p_parallel = parallel_direct_interpolation(&a, &result.cf_splitting, 4)
            .expect("parallel direct interpolation");
        assert_eq!(p_serial.shape(), p_parallel.shape());
        assert_eq!(p_serial.nnz(), p_parallel.nnz(), "NNZ should match");
        // Rows are computed independently, so values must match exactly.
        assert_close(&p_serial, &p_parallel, 0.0);
    }

    #[test]
    fn test_sa_interp_shape() {
        let n = 16;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let n_coarse = result.c_nodes.len();
        let p = parallel_sa_interpolation(&a, &g, &result.cf_splitting, 2, 4.0 / 3.0)
            .expect("SA interpolation");
        let (rows, cols) = p.shape();
        assert_eq!(rows, n, "P should have n_fine rows");
        assert_eq!(cols, n_coarse, "P should have n_coarse columns");
    }

    #[test]
    fn test_sa_interp_thread_count_invariant() {
        let n = 16;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let p1 = parallel_sa_interpolation(&a, &g, &result.cf_splitting, 1, 4.0 / 3.0)
            .expect("SA interpolation (1 thread)");
        let p3 = parallel_sa_interpolation(&a, &g, &result.cf_splitting, 3, 4.0 / 3.0)
            .expect("SA interpolation (3 threads)");
        assert_close(&p1, &p3, 0.0);
    }

    #[test]
    fn test_galerkin_operator_size() {
        let n = 12;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let n_coarse = result.c_nodes.len();
        let p = parallel_direct_interpolation(&a, &result.cf_splitting, 1)
            .expect("direct interpolation");
        let ac = galerkin_coarse_operator(&a, &p).expect("galerkin operator");
        let (rows_c, cols_c) = ac.shape();
        assert_eq!(rows_c, n_coarse, "A_c should have n_coarse rows");
        assert_eq!(cols_c, n_coarse, "A_c should have n_coarse columns");
    }

    #[test]
    fn test_parallel_galerkin_matches_serial() {
        let n = 16;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let p = parallel_direct_interpolation(&a, &result.cf_splitting, 2)
            .expect("direct interpolation");
        let serial = galerkin_coarse_operator(&a, &p).expect("serial Galerkin operator");
        for threads in [1usize, 3] {
            let parallel = parallel_galerkin_coarse_operator(&a, &p, threads)
                .expect("parallel Galerkin operator");
            assert_close(&serial, &parallel, 1e-12);
        }
    }

    #[test]
    fn test_galerkin_spd_preserved() {
        let n = 10;
        let a = laplacian_1d(n);
        let g = serial_strength_of_connection(&a, 0.25);
        let result = pmis_coarsening(&g);
        let p = parallel_direct_interpolation(&a, &result.cf_splitting, 1)
            .expect("direct interpolation");
        let ac = galerkin_coarse_operator(&a, &p).expect("galerkin operator");
        let (nc, _) = ac.shape();
        for i in 0..nc {
            let mut diag = 0.0f64;
            let mut off_diag_sum = 0.0f64;
            for pos in ac.row_range(i) {
                if ac.indices[pos] == i {
                    diag = ac.data[pos];
                } else {
                    off_diag_sum += ac.data[pos].abs();
                }
            }
            assert!(
                diag > 0.0,
                "Coarse diagonal should be positive (got {diag} at row {i})"
            );
            assert!(
                diag >= off_diag_sum - 1e-10,
                "Coarse matrix not diagonally dominant at row {i}: diag={diag}, off={off_diag_sum}"
            );
        }
    }
}