use std::collections::BinaryHeap;
use crate::parallel_solver::CsrMatrix;
/// Builds the strength-of-connection lists for every row of `a`.
///
/// For row `i`, let `m` be the largest value of `-a[i][k]` over the stored
/// off-diagonal entries (floored at zero). An off-diagonal column `j` is
/// reported as a strong connection of `i` when `-a[i][j] >= theta * m`.
/// Rows where `m` is below `1e-300` (no meaningfully negative off-diagonal
/// entry) get an empty list.
///
/// Returns one list of column indices per row, in the order the entries are
/// stored in the matrix; the diagonal is never included.
pub fn strong_connections(a: &CsrMatrix, theta: f64) -> Vec<Vec<usize>> {
    (0..a.nrows)
        .map(|i| {
            let begin = a.row_offsets[i];
            let end = a.row_offsets[i + 1];

            // Largest negated off-diagonal value in this row, never below zero.
            let largest = (begin..end)
                .filter(|&k| a.col_indices[k] != i)
                .map(|k| -a.values[k])
                .fold(0.0f64, |best, v| if v > best { v } else { best });

            // Nothing negative enough to define a scale: no strong connections.
            if largest < 1e-300 {
                return Vec::new();
            }

            let cutoff = theta * largest;
            (begin..end)
                .filter(|&k| a.col_indices[k] != i)
                .filter(|&k| -a.values[k] >= cutoff)
                .map(|k| a.col_indices[k])
                .collect()
        })
        .collect()
}
/// Splits the nodes of a strength-of-connection graph into coarse (C) and
/// fine (F) points with a greedy, weight-driven first pass.
///
/// `strong[i]` lists the nodes that node `i` strongly depends on (the shape
/// returned by `strong_connections`). The returned vector has one entry per
/// node: `true` for a C-point, `false` for an F-point.
///
/// Algorithm:
/// 1. Build the transpose `strong_of`, where `strong_of[j]` holds every node
///    that strongly depends on `j`.
/// 2. Give each node the initial weight `strong[i].len() + strong_of[i].len()`
///    and repeatedly pick the undecided node with the largest weight (ties go
///    to the larger index) as a C-point.
/// 3. Every undecided node that strongly depends on the new C-point becomes an
///    F-point, and each undecided node that such an F-point depends on gains
///    one unit of weight, making it a more attractive future C-point.
/// 4. Any F-point left without a C-point among its own strong connections is
///    promoted to C.
///
/// For graphs where the strong relation is symmetric the direction used in
/// step 3 is immaterial; for asymmetric graphs it ensures an F-point is only
/// created next to a C-point it can actually depend on.
///
/// # Panics
///
/// Panics if any index stored in `strong` is `>= strong.len()`.
pub fn cf_splitting(strong: &[Vec<usize>]) -> Vec<bool> {
    const UNDECIDED: u8 = 0;
    const COARSE: u8 = 1;
    const FINE: u8 = 2;

    let n = strong.len();

    // Transpose of the dependency lists: who depends on each node.
    let mut strong_of: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (i, row) in strong.iter().enumerate() {
        for &j in row {
            strong_of[j].push(i);
        }
    }

    let mut lambda: Vec<usize> = strong
        .iter()
        .zip(&strong_of)
        .map(|(deps, dependents)| deps.len() + dependents.len())
        .collect();
    let mut state = vec![UNDECIDED; n];
    let mut heap: BinaryHeap<(usize, usize)> = lambda
        .iter()
        .enumerate()
        .map(|(i, &weight)| (weight, i))
        .collect();

    while let Some((lam_stored, i)) = heap.pop() {
        // Weights only grow and every growth pushes a fresh entry, so an
        // outdated entry can simply be dropped: the current one is (or was)
        // ahead of it in the heap.
        if state[i] != UNDECIDED || lam_stored != lambda[i] {
            continue;
        }
        state[i] = COARSE;

        // Nodes that depend on the new C-point can interpolate from it.
        for &j in &strong_of[i] {
            if state[j] != UNDECIDED {
                continue;
            }
            state[j] = FINE;

            // The other nodes this F-point depends on become more valuable
            // as C-points.
            for &k in &strong[j] {
                if state[k] == UNDECIDED {
                    lambda[k] += 1;
                    heap.push((lambda[k], k));
                }
            }
        }
    }

    // Safety net: anything still undecided is treated as coarse.
    for s in state.iter_mut() {
        if *s == UNDECIDED {
            *s = COARSE;
        }
    }

    // Safety net: an F-point must see at least one C-point among the nodes it
    // depends on; otherwise promote it. Promotions made earlier in this sweep
    // are visible to later nodes.
    for i in 0..n {
        if state[i] == FINE && !strong[i].iter().any(|&j| state[j] == COARSE) {
            state[i] = COARSE;
        }
    }

    state.iter().map(|&s| s == COARSE).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles the `n x n` tridiagonal matrix with `2.0` on the diagonal and
    /// `-1.0` on both neighbouring diagonals, stored row by row in CSR form.
    fn make_1d_poisson(n: usize) -> CsrMatrix {
        let mut row_offsets = Vec::with_capacity(n + 1);
        let mut col_indices = Vec::new();
        let mut values = Vec::new();

        row_offsets.push(0usize);
        for row in 0..n {
            // Optional left neighbour, the diagonal, optional right neighbour.
            let lower = (row > 0).then(|| (row - 1, -1.0));
            let upper = (row + 1 < n).then(|| (row + 1, -1.0));
            let entries = lower
                .into_iter()
                .chain(std::iter::once((row, 2.0)))
                .chain(upper);
            for (col, val) in entries {
                col_indices.push(col);
                values.push(val);
            }
            row_offsets.push(col_indices.len());
        }

        CsrMatrix {
            nrows: n,
            ncols: n,
            row_offsets,
            col_indices,
            values,
        }
    }

    /// Every interior node of the 1-D Poisson matrix must be strongly
    /// connected to both neighbours and never to itself.
    #[test]
    fn test_strong_connection_1d_poisson() {
        let n = 5;
        let matrix = make_1d_poisson(n);
        let strong = strong_connections(&matrix, 0.25);

        for (i, row) in strong.iter().take(n).enumerate() {
            assert!(!row.contains(&i), "node {i} has self-connection");

            if let Some(prev) = i.checked_sub(1) {
                assert!(
                    row.contains(&prev),
                    "node {i} missing strong connection to {}",
                    prev
                );
            }

            let next = i + 1;
            if next < n {
                assert!(
                    row.contains(&next),
                    "node {i} missing strong connection to {}",
                    next
                );
            }
        }
    }

    /// Each F-point must have a C-point among its strong connections, and the
    /// splitting must contain both kinds of point.
    #[test]
    fn test_cf_splitting_has_c_neighbor() {
        let n = 10;
        let strong = strong_connections(&make_1d_poisson(n), 0.25);
        let is_c = cf_splitting(&strong);

        for i in (0..n).filter(|&idx| !is_c[idx]) {
            let neighbours = &strong[i];
            assert!(
                neighbours.iter().any(|&j| is_c[j]),
                "F-point {i} has no C-point among its strong neighbors {:?}",
                strong[i]
            );
        }

        let coarse_count = is_c.iter().filter(|&&c| c).count();
        assert!(coarse_count > 0, "No C-points found");
        assert!(coarse_count < is_c.len(), "No F-points found");
    }

    /// In a three-node star the hub has the largest weight and must be
    /// selected as a C-point; any leaf left as F must see a C-point.
    #[test]
    fn test_cf_splitting_heap_ordering() {
        let strong: Vec<Vec<usize>> = vec![vec![1, 2], vec![0], vec![0]];
        let is_c = cf_splitting(&strong);

        assert!(is_c[0], "Center node should be C-point");

        for i in 1usize..=2 {
            if is_c[i] {
                continue;
            }
            assert!(
                strong[i].iter().any(|&j| is_c[j]),
                "F-point {i} has no C-point"
            );
        }
    }
}