use std::time::Instant;
use tracing::{debug, trace};
use crate::error::{SolverError, ValidationError};
use crate::traits::SolverEngine;
use crate::types::{
Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
SolverResult, SparsityProfile,
};
// Relative strength-of-connection threshold for aggregation: an off-diagonal
// entry counts as "strong" if its magnitude is at least
// STRONG_THRESHOLD * (largest off-diagonal magnitude in its row).
const STRONG_THRESHOLD: f64 = 0.25;
// Number of Gauss-Seidel pre- and post-smoothing sweeps per V-cycle level.
const SMOOTH_STEPS: usize = 3;
// Systems of this size or smaller are solved with dense Gaussian elimination
// instead of (further) multigrid recursion.
const COARSEST_DIRECT_LIMIT: usize = 100;
// Soft cap on how many fine nodes a single aggregate absorbs from its seed row.
const TARGET_AGGREGATE_SIZE: usize = 4;
// Magnitudes below this are treated as zero (pivots, diagonals).
const NEAR_ZERO_F64: f64 = 1e-15;
/// Algebraic-multigrid style solver ("BMSSP") for sparse systems `A x = b`.
///
/// Builds an aggregation-based coarse-grid hierarchy, then iterates V-cycles
/// with Gauss-Seidel smoothing until the residual l2 norm drops below the
/// tolerance or the iteration/time budget runs out.
#[derive(Debug, Clone)]
pub struct BmsspSolver {
    // Convergence target for the residual norm (tightened by the budget's).
    tolerance: f64,
    // Upper bound on V-cycle iterations (tightened by the budget's).
    max_iterations: usize,
    // Maximum depth of the multigrid hierarchy.
    max_levels: usize,
    // Target coarse/fine size ratio per level (0.5 roughly halves each level).
    coarsening_ratio: f64,
}
impl BmsspSolver {
    /// Creates a solver with the given tolerance and iteration cap, using the
    /// default hierarchy depth (20) and coarsening ratio (0.5).
    pub fn new(tolerance: f64, max_iterations: usize) -> Self {
        // Delegate to `with_params` so the defaults live in exactly one place.
        Self::with_params(tolerance, max_iterations, 20, 0.5)
    }

    /// Creates a solver with explicit control over every multigrid parameter.
    ///
    /// * `max_levels` — maximum depth of the coarse-grid hierarchy.
    /// * `coarsening_ratio` — target coarse/fine size ratio per level.
    pub fn with_params(
        tolerance: f64,
        max_iterations: usize,
        max_levels: usize,
        coarsening_ratio: f64,
    ) -> Self {
        Self {
            tolerance,
            max_iterations,
            max_levels,
            coarsening_ratio,
        }
    }
}
/// One level of the multigrid hierarchy.
struct MultigridLevel {
    // Operator A at this level.
    matrix: CsrMatrix<f64>,
    // Coarse -> fine transfer (P); empty CSR on the coarsest level.
    prolongation: CsrMatrix<f64>,
    // Fine -> coarse transfer (R = P^T); empty CSR on the coarsest level.
    restriction: CsrMatrix<f64>,
}
/// Full coarse-grid hierarchy produced by `build_hierarchy`.
struct MultigridHierarchy {
    // Levels ordered fine -> coarse; the last level is solved directly.
    levels: Vec<MultigridLevel>,
    // Row count of the coarsest operator (used for trace logging).
    coarsest_size: usize,
}
/// Builds an aggregation-based AMG hierarchy from `matrix`.
///
/// Each level stores the fine-grid operator plus the prolongation (coarse ->
/// fine) and restriction (fine -> coarse) transfer operators; each coarse
/// operator is the Galerkin product `R * A * P`. Construction stops when a
/// level is small enough for a direct solve, the level cap is reached, or
/// coarsening stops making progress.
///
/// Note: the original text contained mojibake (`¤t` for `&current`)
/// on the `build_aggregates` and `sparse_matmul` calls, which does not
/// compile; restored here.
fn build_hierarchy(
    matrix: &CsrMatrix<f64>,
    max_levels: usize,
    coarsening_ratio: f64,
) -> Result<MultigridHierarchy, SolverError> {
    let mut levels: Vec<MultigridLevel> = Vec::new();
    let mut current = matrix.clone();
    for lvl in 0..max_levels {
        let n = current.rows;
        if n <= COARSEST_DIRECT_LIMIT {
            debug!(level = lvl, size = n, "coarsest level reached");
            break;
        }
        // Aim for roughly `coarsening_ratio * n` aggregates, at least one.
        // (The .max(1.0) before the cast already guarantees >= 1; the
        // original's second .max(1) was redundant.)
        let target_coarse = ((n as f64) * coarsening_ratio).max(1.0) as usize;
        let aggregates = build_aggregates(&current, target_coarse);
        let num_aggregates = aggregates.iter().copied().max().map_or(0, |m| m + 1);
        // If aggregation produced no reduction, deeper levels would not help.
        if num_aggregates == 0 || num_aggregates >= n {
            debug!(
                level = lvl,
                n, num_aggregates, "coarsening stalled, stopping hierarchy build"
            );
            break;
        }
        let prolongation = build_prolongation(n, num_aggregates, &aggregates);
        let restriction = transpose_csr(&prolongation);
        // Galerkin coarse operator: A_c = R * (A * P).
        let ap = sparse_matmul(&current, &prolongation);
        let coarse_matrix = sparse_matmul(&restriction, &ap);
        trace!(
            level = lvl,
            fine = n,
            coarse = coarse_matrix.rows,
            nnz = coarse_matrix.nnz(),
            "built multigrid level"
        );
        levels.push(MultigridLevel {
            matrix: current,
            prolongation,
            restriction,
        });
        current = coarse_matrix;
    }
    // The last level is solved directly, so its transfer operators stay empty.
    let coarsest_size = current.rows;
    levels.push(MultigridLevel {
        matrix: current,
        prolongation: empty_csr(),
        restriction: empty_csr(),
    });
    debug!(
        total_levels = levels.len(),
        coarsest_size, "multigrid hierarchy built"
    );
    Ok(MultigridHierarchy {
        levels,
        coarsest_size,
    })
}
/// Returns a 0x0 CSR matrix, used as a placeholder transfer operator on the
/// coarsest hierarchy level.
#[inline]
fn empty_csr() -> CsrMatrix<f64> {
    CsrMatrix {
        rows: 0,
        cols: 0,
        row_ptr: vec![0],
        col_indices: vec![],
        values: vec![],
    }
}
/// Greedy aggregation of the `n` fine nodes for one coarsening step.
///
/// Returns the aggregate id of every fine node. Pass 1 seeds an aggregate at
/// each unassigned node and absorbs its strongly-connected, still-unassigned
/// neighbours (capped at TARGET_AGGREGATE_SIZE per seed). Pass 2 attaches any
/// leftover node to the neighbouring aggregate with the strongest connection,
/// or makes it a singleton aggregate if it has no assigned neighbour.
fn build_aggregates(matrix: &CsrMatrix<f64>, target_coarse: usize) -> Vec<usize> {
    let n = matrix.rows;
    // usize::MAX marks "not yet assigned to an aggregate".
    let mut aggregate_id = vec![usize::MAX; n];
    let mut current_agg: usize = 0;
    // Per-row maximum off-diagonal magnitude; scales the strength threshold.
    let max_off_diag: Vec<f64> = (0..n)
        .map(|i| {
            let start = matrix.row_ptr[i];
            let end = matrix.row_ptr[i + 1];
            let mut max_val: f64 = 0.0;
            for idx in start..end {
                if matrix.col_indices[idx] != i {
                    max_val = max_val.max(matrix.values[idx].abs());
                }
            }
            max_val
        })
        .collect();
    // Pass 1: grow an aggregate around each unassigned seed, in node order.
    for seed in 0..n {
        if aggregate_id[seed] != usize::MAX {
            continue;
        }
        aggregate_id[seed] = current_agg;
        let mut agg_size = 1usize;
        // Neighbour j is "strong" when |a_ij| >= threshold for this row.
        let threshold = STRONG_THRESHOLD * max_off_diag[seed];
        let start = matrix.row_ptr[seed];
        let end = matrix.row_ptr[seed + 1];
        for idx in start..end {
            let j = matrix.col_indices[idx];
            if j == seed || aggregate_id[j] != usize::MAX {
                continue;
            }
            if matrix.values[idx].abs() >= threshold {
                aggregate_id[j] = current_agg;
                agg_size += 1;
                if agg_size >= TARGET_AGGREGATE_SIZE {
                    break;
                }
            }
        }
        current_agg += 1;
        // Stop seeding once enough aggregates exist and most nodes have been
        // visited; leftovers are absorbed by pass 2 below.
        if current_agg >= target_coarse && seed > n / 2 {
            break;
        }
    }
    // Pass 2: attach each leftover node to its strongest assigned neighbour.
    for i in 0..n {
        if aggregate_id[i] != usize::MAX {
            continue;
        }
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        // Placeholder init; overwritten whenever an assigned neighbour is
        // found, and unused otherwise (the singleton branch below runs).
        let mut best_agg = if current_agg > 0 { current_agg - 1 } else { 0 };
        let mut best_strength: f64 = -1.0;
        for idx in start..end {
            let j = matrix.col_indices[idx];
            if j != i && aggregate_id[j] != usize::MAX {
                let strength = matrix.values[idx].abs();
                if strength > best_strength {
                    best_strength = strength;
                    best_agg = aggregate_id[j];
                }
            }
        }
        if best_strength < 0.0 {
            // No assigned neighbour: isolated node becomes its own aggregate.
            aggregate_id[i] = current_agg;
            current_agg += 1;
        } else {
            aggregate_id[i] = best_agg;
        }
    }
    aggregate_id
}
/// Builds the piecewise-constant prolongation operator P (fine x coarse):
/// row i carries a single 1.0 in the column of node i's aggregate.
fn build_prolongation(
    fine_rows: usize,
    coarse_cols: usize,
    aggregates: &[usize],
) -> CsrMatrix<f64> {
    // Exactly one entry per fine row, so row_ptr is simply 0..=nnz.
    let col_indices: Vec<usize> = aggregates.iter().take(fine_rows).copied().collect();
    let nnz = col_indices.len();
    let row_ptr: Vec<usize> = (0..=nnz).collect();
    let values = vec![1.0f64; nnz];
    CsrMatrix {
        row_ptr,
        col_indices,
        values,
        rows: fine_rows,
        cols: coarse_cols,
    }
}
/// Transposes a CSR matrix, producing the CSR representation of A^T.
fn transpose_csr(a: &CsrMatrix<f64>) -> CsrMatrix<f64> {
    let (rows, cols, nnz) = (a.rows, a.cols, a.nnz());
    // Count entries per output row (i.e. per input column), shifted by one so
    // an in-place prefix sum turns the counts into row starts.
    let mut out_ptr = vec![0usize; cols + 1];
    for &c in &a.col_indices {
        out_ptr[c + 1] += 1;
    }
    for j in 0..cols {
        out_ptr[j + 1] += out_ptr[j];
    }
    // Scatter pass: a per-output-row write cursor advances as entries land.
    let mut out_cols = vec![0usize; nnz];
    let mut out_vals = vec![0.0f64; nnz];
    let mut cursor = out_ptr.clone();
    for r in 0..rows {
        for idx in a.row_ptr[r]..a.row_ptr[r + 1] {
            let c = a.col_indices[idx];
            let pos = cursor[c];
            out_cols[pos] = r;
            out_vals[pos] = a.values[idx];
            cursor[c] += 1;
        }
    }
    CsrMatrix {
        row_ptr: out_ptr,
        col_indices: out_cols,
        values: out_vals,
        rows: cols,
        cols: rows,
    }
}
/// Computes the sparse product C = A * B in CSR form.
///
/// Uses a dense accumulator row. The original tracked column occupancy via
/// `acc[j] == 0.0`, which pushes the same column into `nz_cols` twice when
/// partial sums cancel exactly to zero; an explicit `touched` marker avoids
/// the duplicate bookkeeping (the emitted matrix is unchanged).
/// Entries whose final magnitude is <= f64::EPSILON are dropped.
///
/// # Panics
///
/// Panics if the inner dimensions do not match.
fn sparse_matmul(a: &CsrMatrix<f64>, b: &CsrMatrix<f64>) -> CsrMatrix<f64> {
    assert_eq!(
        a.cols, b.rows,
        "sparse_matmul: dimension mismatch {}x{} * {}x{}",
        a.rows, a.cols, b.rows, b.cols
    );
    let m = a.rows;
    let n = b.cols;
    let mut row_ptr = Vec::with_capacity(m + 1);
    let mut col_indices = Vec::new();
    let mut values = Vec::new();
    // Scratch state reused across rows: dense accumulator plus occupancy.
    let mut acc = vec![0.0f64; n];
    let mut touched = vec![false; n];
    let mut nz_cols: Vec<usize> = Vec::new();
    row_ptr.push(0);
    for i in 0..m {
        for a_idx in a.row_ptr[i]..a.row_ptr[i + 1] {
            let k = a.col_indices[a_idx];
            let a_val = a.values[a_idx];
            for b_idx in b.row_ptr[k]..b.row_ptr[k + 1] {
                let j = b.col_indices[b_idx];
                if !touched[j] {
                    touched[j] = true;
                    nz_cols.push(j);
                }
                acc[j] += a_val * b.values[b_idx];
            }
        }
        // Emit the row in ascending column order, resetting scratch state.
        nz_cols.sort_unstable();
        for &j in &nz_cols {
            let v = acc[j];
            if v.abs() > f64::EPSILON {
                col_indices.push(j);
                values.push(v);
            }
            acc[j] = 0.0;
            touched[j] = false;
        }
        nz_cols.clear();
        row_ptr.push(col_indices.len());
    }
    CsrMatrix {
        row_ptr,
        col_indices,
        values,
        rows: m,
        cols: n,
    }
}
/// Performs one forward Gauss-Seidel sweep on `x` for the system `A x = b`.
/// Rows with a near-zero diagonal are left unchanged.
#[inline]
fn gauss_seidel_sweep(matrix: &CsrMatrix<f64>, x: &mut [f64], b: &[f64]) {
    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        let mut off_diag_sum = 0.0f64;
        let mut diag = 0.0f64;
        // Split the row into its diagonal entry and the off-diagonal dot
        // product with the current iterate.
        for (&j, &v) in matrix.col_indices[start..end]
            .iter()
            .zip(&matrix.values[start..end])
        {
            if j == i {
                diag = v;
            } else {
                off_diag_sum += v * x[j];
            }
        }
        // Skip the update rather than divide by a (near-)zero diagonal.
        if diag.abs() > NEAR_ZERO_F64 {
            x[i] = (b[i] - off_diag_sum) / diag;
        }
    }
}
/// Solves the dense system `A x = b` by Gaussian elimination with partial
/// pivoting, after densifying the sparse input into an augmented [A | b].
///
/// Used on the coarsest multigrid level (size <= COARSEST_DIRECT_LIMIT).
/// Near-singular pivots/diagonals are skipped, leaving the corresponding
/// unknowns at zero rather than dividing by a near-zero value.
fn dense_direct_solve(matrix: &CsrMatrix<f64>, b: &[f64]) -> Vec<f64> {
    let n = matrix.rows;
    if n == 0 {
        return Vec::new();
    }
    // Augmented matrix [A | b], row-major with stride n + 1; column n is b.
    let stride = n + 1;
    let mut aug = vec![0.0f64; n * stride];
    for i in 0..n {
        aug[i * stride + n] = b[i];
        for idx in matrix.row_ptr[i]..matrix.row_ptr[i + 1] {
            let j = matrix.col_indices[idx];
            aug[i * stride + j] = matrix.values[idx];
        }
    }
    // Forward elimination with partial (row) pivoting.
    for col in 0..n {
        let mut max_val = aug[col * stride + col].abs();
        let mut max_row = col;
        for row in (col + 1)..n {
            let v = aug[row * stride + col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_row != col {
            // max_row is searched only in (col+1)..n, so max_row > col always
            // holds here; the original's split_at_mut construction (with a
            // dead `col > max_row` arm) is replaced by element-wise swaps.
            for k in 0..stride {
                aug.swap(col * stride + k, max_row * stride + k);
            }
        }
        let pivot = aug[col * stride + col];
        if pivot.abs() < NEAR_ZERO_F64 {
            // (Near-)singular column: skip elimination for this pivot.
            continue;
        }
        for row in (col + 1)..n {
            let factor = aug[row * stride + col] / pivot;
            aug[row * stride + col] = 0.0;
            for k in (col + 1)..stride {
                aug[row * stride + k] -= factor * aug[col * stride + k];
            }
        }
    }
    // Back substitution; rows with a near-zero diagonal keep x[i] = 0.
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let mut sum = aug[i * stride + n];
        for j in (i + 1)..n {
            sum -= aug[i * stride + j] * x[j];
        }
        let diag = aug[i * stride + i];
        if diag.abs() > NEAR_ZERO_F64 {
            x[i] = sum / diag;
        }
    }
    x
}
/// Runs one recursive multigrid V-cycle at `level`, updating `x` in place.
///
/// Pre-smooths, restricts the residual to the coarse grid, recursively solves
/// the coarse error equation, prolongs and applies the correction, then
/// post-smooths. The coarsest (or sufficiently small) level is solved directly.
fn v_cycle(hierarchy: &MultigridHierarchy, x: &mut [f64], b: &[f64], level: usize) {
    let lvl = &hierarchy.levels[level];
    let mat = &lvl.matrix;
    let n = mat.rows;
    let is_coarsest = level + 1 == hierarchy.levels.len();
    if is_coarsest || n <= COARSEST_DIRECT_LIMIT {
        let direct = dense_direct_solve(mat, b);
        x[..n].copy_from_slice(&direct);
        return;
    }
    // Pre-smoothing.
    for _ in 0..SMOOTH_STEPS {
        gauss_seidel_sweep(mat, x, b);
    }
    // Fine-grid residual r = b - A x.
    let mut ax = vec![0.0f64; n];
    mat.spmv(x, &mut ax);
    let residual: Vec<f64> = b.iter().zip(&ax).map(|(&bi, &axi)| bi - axi).collect();
    // Restrict the residual and solve the coarse error equation recursively.
    let coarse_n = lvl.restriction.rows;
    let mut r_coarse = vec![0.0f64; coarse_n];
    lvl.restriction.spmv(&residual, &mut r_coarse);
    let mut e_coarse = vec![0.0f64; coarse_n];
    v_cycle(hierarchy, &mut e_coarse, &r_coarse, level + 1);
    // Prolong the coarse error back to this grid and correct the iterate.
    let mut correction = vec![0.0f64; n];
    lvl.prolongation.spmv(&e_coarse, &mut correction);
    for (xi, &ci) in x.iter_mut().zip(&correction) {
        *xi += ci;
    }
    // Post-smoothing.
    for _ in 0..SMOOTH_STEPS {
        gauss_seidel_sweep(mat, x, b);
    }
}
/// Validates the linear system before solving.
///
/// Rejects zero-sized or non-square matrices, RHS/matrix dimension
/// mismatches, malformed CSR structure (row_ptr length, col/value count
/// disagreement, wrong terminal row_ptr, out-of-range column indices — any of
/// which would otherwise panic deep inside the solver), and non-finite values
/// in the matrix or RHS.
///
/// # Errors
///
/// Returns `SolverError::InvalidInput` describing the first problem found.
fn validate_inputs(matrix: &CsrMatrix<f64>, rhs: &[f64]) -> Result<(), SolverError> {
    if matrix.rows == 0 || matrix.cols == 0 {
        return Err(SolverError::InvalidInput(
            ValidationError::DimensionMismatch("matrix has zero dimension".into()),
        ));
    }
    if matrix.rows != matrix.cols {
        return Err(SolverError::InvalidInput(
            ValidationError::DimensionMismatch(format!(
                "BMSSP requires square matrix, got {}x{}",
                matrix.rows, matrix.cols
            )),
        ));
    }
    if rhs.len() != matrix.rows {
        return Err(SolverError::InvalidInput(
            ValidationError::DimensionMismatch(format!(
                "RHS length {} does not match matrix dimension {}",
                rhs.len(),
                matrix.rows
            )),
        ));
    }
    if matrix.row_ptr.len() != matrix.rows + 1 {
        return Err(SolverError::InvalidInput(
            ValidationError::DimensionMismatch(format!(
                "row_ptr length {} != rows + 1 = {}",
                matrix.row_ptr.len(),
                matrix.rows + 1
            )),
        ));
    }
    // CSR internal consistency: the parallel arrays must agree and the
    // terminal row pointer must equal the nonzero count.
    if matrix.col_indices.len() != matrix.values.len() {
        return Err(SolverError::InvalidInput(
            ValidationError::DimensionMismatch(format!(
                "col_indices length {} != values length {}",
                matrix.col_indices.len(),
                matrix.values.len()
            )),
        ));
    }
    if matrix.row_ptr[matrix.rows] != matrix.values.len() {
        return Err(SolverError::InvalidInput(
            ValidationError::DimensionMismatch(format!(
                "row_ptr terminal entry {} != nnz {}",
                matrix.row_ptr[matrix.rows],
                matrix.values.len()
            )),
        ));
    }
    if let Some(pos) = matrix.col_indices.iter().position(|&c| c >= matrix.cols) {
        return Err(SolverError::InvalidInput(
            ValidationError::DimensionMismatch(format!(
                "column index {} at position {pos} out of bounds for {} columns",
                matrix.col_indices[pos], matrix.cols
            )),
        ));
    }
    for (idx, &v) in matrix.values.iter().enumerate() {
        if !v.is_finite() {
            return Err(SolverError::InvalidInput(ValidationError::NonFiniteValue(
                format!("matrix value at index {idx}"),
            )));
        }
    }
    for (idx, &v) in rhs.iter().enumerate() {
        if !v.is_finite() {
            return Err(SolverError::InvalidInput(ValidationError::NonFiniteValue(
                format!("RHS value at index {idx}"),
            )));
        }
    }
    Ok(())
}
/// Returns the l2 norm of the residual `b - A x`.
fn residual_l2(matrix: &CsrMatrix<f64>, x: &[f64], b: &[f64]) -> f64 {
    let n = matrix.rows;
    let mut ax = vec![0.0f64; n];
    matrix.spmv(x, &mut ax);
    (0..n)
        .map(|i| {
            let r = b[i] - ax[i];
            r * r
        })
        .sum::<f64>()
        .sqrt()
}
impl SolverEngine for BmsspSolver {
fn solve(
&self,
matrix: &CsrMatrix<f64>,
rhs: &[f64],
budget: &ComputeBudget,
) -> Result<SolverResult, SolverError> {
validate_inputs(matrix, rhs)?;
let start = Instant::now();
let n = matrix.rows;
let tol = self.tolerance.min(budget.tolerance);
let max_iter = self.max_iterations.min(budget.max_iterations);
if n == 1 {
let diag = if matrix.nnz() > 0 {
matrix.values[0]
} else {
0.0
};
if diag.abs() <= f64::EPSILON {
return Err(SolverError::NumericalInstability {
iteration: 0,
detail: "1x1 system with zero diagonal".into(),
});
}
return Ok(SolverResult {
solution: vec![(rhs[0] / diag) as f32],
iterations: 0,
residual_norm: 0.0,
wall_time: start.elapsed(),
convergence_history: Vec::new(),
algorithm: Algorithm::BMSSP,
});
}
if n <= COARSEST_DIRECT_LIMIT {
debug!(n, "small system, using direct solve");
let sol = dense_direct_solve(matrix, rhs);
let res = residual_l2(matrix, &sol, rhs);
return Ok(SolverResult {
solution: sol.iter().map(|&v| v as f32).collect(),
iterations: 0,
residual_norm: res,
wall_time: start.elapsed(),
convergence_history: Vec::new(),
algorithm: Algorithm::BMSSP,
});
}
let hierarchy = build_hierarchy(matrix, self.max_levels, self.coarsening_ratio)?;
trace!(
levels = hierarchy.levels.len(),
coarsest = hierarchy.coarsest_size,
"AMG hierarchy ready, starting V-cycle iterations"
);
let mut x = vec![0.0f64; n];
let mut ax_buf = vec![0.0f64; n];
let mut convergence_history = Vec::with_capacity(max_iter);
let b_norm = {
let mut s = 0.0f64;
for &v in rhs {
s += v * v;
}
s.sqrt()
};
if b_norm < tol {
return Ok(SolverResult {
solution: vec![0.0f32; n],
iterations: 0,
residual_norm: b_norm,
wall_time: start.elapsed(),
convergence_history: Vec::new(),
algorithm: Algorithm::BMSSP,
});
}
for iter in 0..max_iter {
if start.elapsed() > budget.max_time {
return Err(SolverError::BudgetExhausted {
reason: format!(
"wall-time limit {}ms exceeded at iteration {iter}",
budget.max_time.as_millis()
),
elapsed: start.elapsed(),
});
}
v_cycle(&hierarchy, &mut x, rhs, 0);
matrix.spmv(&x, &mut ax_buf);
let res = (0..n)
.map(|i| {
let r = rhs[i] - ax_buf[i];
r * r
})
.sum::<f64>()
.sqrt();
convergence_history.push(ConvergenceInfo {
iteration: iter,
residual_norm: res,
});
trace!(iteration = iter, residual = res, "V-cycle completed");
if !res.is_finite() {
return Err(SolverError::NumericalInstability {
iteration: iter,
detail: "residual became NaN or Inf during V-cycle".into(),
});
}
if res < tol {
debug!(iterations = iter + 1, residual = res, "BMSSP converged");
return Ok(SolverResult {
solution: x.iter().map(|&v| v as f32).collect(),
iterations: iter + 1,
residual_norm: res,
wall_time: start.elapsed(),
convergence_history,
algorithm: Algorithm::BMSSP,
});
}
}
let final_residual = convergence_history
.last()
.map_or(b_norm, |c| c.residual_norm);
Err(SolverError::NonConvergence {
iterations: max_iter,
residual: final_residual,
tolerance: tol,
})
}
fn estimate_complexity(&self, profile: &SparsityProfile, n: usize) -> ComplexityEstimate {
let log_n = ((n as f64).ln().max(1.0)) as u64;
let nnz = profile.nnz as u64;
let estimated_iters = log_n as usize;
let flops_per_iter = nnz * 6; let total_flops = flops_per_iter * log_n;
let mem = profile.nnz * 16 + n * 8;
ComplexityEstimate {
algorithm: Algorithm::BMSSP,
estimated_flops: total_flops,
estimated_iterations: estimated_iters,
estimated_memory_bytes: mem,
complexity_class: ComplexityClass::SublinearNnz,
}
}
fn algorithm(&self) -> Algorithm {
Algorithm::BMSSP
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Builds the standard 1-D Laplacian (tridiagonal [-1, 2, -1]) of size n.
    fn laplacian_1d(n: usize) -> CsrMatrix<f64> {
        let mut entries = Vec::new();
        for i in 0..n {
            entries.push((i, i, 2.0f64));
            if i > 0 {
                entries.push((i, i - 1, -1.0));
            }
            if i + 1 < n {
                entries.push((i, i + 1, -1.0));
            }
        }
        CsrMatrix::<f64>::from_coo(n, n, entries)
    }
    /// Small symmetric, strictly diagonally dominant test matrix.
    fn diag_dominant_3x3() -> CsrMatrix<f64> {
        CsrMatrix::<f64>::from_coo(
            3,
            3,
            vec![
                (0, 0, 4.0),
                (0, 1, -1.0),
                (0, 2, -1.0),
                (1, 0, -1.0),
                (1, 1, 4.0),
                (1, 2, -1.0),
                (2, 0, -1.0),
                (2, 1, -1.0),
                (2, 2, 4.0),
            ],
        )
    }
    /// Default compute budget shared by all tests.
    fn budget() -> ComputeBudget {
        ComputeBudget::default()
    }
    // A 3x3 system goes through the small-system direct-solve fast path.
    #[test]
    fn solve_small_direct() {
        let matrix = diag_dominant_3x3();
        let rhs = vec![1.0, 2.0, 3.0];
        let solver = BmsspSolver::new(1e-8, 100);
        let result = solver.solve(&matrix, &rhs, &budget()).unwrap();
        assert!(
            result.residual_norm < 1e-5,
            "residual too high: {}",
            result.residual_norm
        );
        assert_eq!(result.algorithm, Algorithm::BMSSP);
        assert_eq!(result.solution.len(), 3);
    }
    // n = 50 is still within the direct-solve limit.
    #[test]
    fn solve_1d_laplacian_small() {
        let n = 50;
        let matrix = laplacian_1d(n);
        let rhs = vec![1.0f64; n];
        let solver = BmsspSolver::new(1e-6, 200);
        let result = solver.solve(&matrix, &rhs, &budget()).unwrap();
        assert!(
            result.residual_norm < 1e-5,
            "residual: {}",
            result.residual_norm
        );
        assert_eq!(result.solution.len(), n);
    }
    // n = 500 exceeds the direct-solve limit, exercising the AMG V-cycle path.
    #[test]
    fn solve_1d_laplacian_medium() {
        let n = 500;
        let matrix = laplacian_1d(n);
        let rhs: Vec<f64> = (0..n).map(|i| (i as f64) / (n as f64)).collect();
        let solver = BmsspSolver::new(1e-4, 500);
        let result = solver.solve(&matrix, &rhs, &budget()).unwrap();
        assert!(
            result.residual_norm < 1e-3,
            "residual: {}",
            result.residual_norm
        );
        assert!(result.iterations > 0, "should have iterated");
    }
    // Identity system: the solution must equal the RHS.
    #[test]
    fn solve_identity() {
        let n = 10;
        let matrix = CsrMatrix::<f64>::identity(n);
        let rhs: Vec<f64> = (1..=n as i32).map(|i| i as f64).collect();
        let solver = BmsspSolver::new(1e-10, 100);
        let result = solver.solve(&matrix, &rhs, &budget()).unwrap();
        for i in 0..n {
            assert!(
                (result.solution[i] as f64 - rhs[i]).abs() < 1e-3,
                "mismatch at {}: {} vs {}",
                i,
                result.solution[i],
                rhs[i]
            );
        }
    }
    // Zero RHS must yield the zero solution.
    #[test]
    fn solve_zero_rhs() {
        let matrix = laplacian_1d(10);
        let rhs = vec![0.0f64; 10];
        let solver = BmsspSolver::new(1e-8, 100);
        let result = solver.solve(&matrix, &rhs, &budget()).unwrap();
        for &v in &result.solution {
            assert!(v.abs() < 1e-6, "expected zero solution, got {v}");
        }
    }
    // RHS shorter than the matrix dimension must be rejected by validation.
    #[test]
    fn reject_dimension_mismatch_rhs() {
        let matrix = laplacian_1d(5);
        let rhs = vec![1.0f64; 3];
        let solver = BmsspSolver::new(1e-8, 100);
        assert!(solver.solve(&matrix, &rhs, &budget()).is_err());
    }
    // Non-square matrices must be rejected by validation.
    #[test]
    fn reject_nonsquare_matrix() {
        let matrix = CsrMatrix {
            row_ptr: vec![0, 1, 2],
            col_indices: vec![0, 1],
            values: vec![1.0f64, 1.0],
            rows: 2,
            cols: 3,
        };
        let rhs = vec![1.0, 1.0];
        let solver = BmsspSolver::new(1e-8, 100);
        assert!(solver.solve(&matrix, &rhs, &budget()).is_err());
    }
    // 1x1 systems use the single-division fast path: 10 / 5 = 2.
    #[test]
    fn solve_1x1_system() {
        let matrix = CsrMatrix::<f64>::from_coo(1, 1, vec![(0, 0, 5.0)]);
        let rhs = vec![10.0];
        let solver = BmsspSolver::new(1e-10, 100);
        let result = solver.solve(&matrix, &rhs, &budget()).unwrap();
        assert!((result.solution[0] as f64 - 2.0).abs() < 1e-5);
    }
    // `with_params` must store all four configuration fields verbatim.
    #[test]
    fn with_params_stores_values() {
        let solver = BmsspSolver::with_params(1e-6, 300, 10, 0.3);
        assert!((solver.tolerance - 1e-6).abs() < f64::EPSILON);
        assert_eq!(solver.max_iterations, 300);
        assert_eq!(solver.max_levels, 10);
        assert!((solver.coarsening_ratio - 0.3).abs() < f64::EPSILON);
    }
    // Iterative path must record per-iteration residuals that trend downward.
    #[test]
    fn convergence_history_populated() {
        let n = 200;
        let matrix = laplacian_1d(n);
        let rhs = vec![1.0f64; n];
        let solver = BmsspSolver::new(1e-6, 500);
        let result = solver.solve(&matrix, &rhs, &budget()).unwrap();
        assert!(!result.convergence_history.is_empty());
        let first = result.convergence_history.first().unwrap().residual_norm;
        let last = result.convergence_history.last().unwrap().residual_norm;
        assert!(
            last < first || first < 1e-6,
            "residual did not decrease: first={first}, last={last}"
        );
    }
    // Transposing the identity must return the identity.
    #[test]
    fn transpose_csr_identity() {
        let id = CsrMatrix::<f64>::identity(5);
        let id_t = transpose_csr(&id);
        assert_eq!(id_t.rows, 5);
        assert_eq!(id_t.cols, 5);
        assert_eq!(id_t.nnz(), 5);
        for i in 0..5 {
            assert_eq!(id_t.col_indices[i], i);
            assert!((id_t.values[i] - 1.0).abs() < f64::EPSILON);
        }
    }
    // I * A must reproduce A's shape and sparsity.
    #[test]
    fn sparse_matmul_identity() {
        let id = CsrMatrix::<f64>::identity(4);
        let a = CsrMatrix::<f64>::from_coo(
            4,
            4,
            vec![
                (0, 0, 2.0),
                (0, 1, 1.0),
                (1, 1, 3.0),
                (2, 2, 4.0),
                (3, 3, 5.0),
            ],
        );
        let result = sparse_matmul(&id, &a);
        assert_eq!(result.rows, 4);
        assert_eq!(result.cols, 4);
        assert_eq!(result.nnz(), a.nnz());
    }
    // On a purely diagonal system one Gauss-Seidel sweep solves exactly.
    #[test]
    fn gauss_seidel_diagonal_system() {
        let matrix = CsrMatrix::<f64>::from_coo(2, 2, vec![(0, 0, 4.0), (1, 1, 4.0)]);
        let b = [8.0f64, 12.0];
        let mut x = [0.0f64; 2];
        gauss_seidel_sweep(&matrix, &mut x, &b);
        assert!((x[0] - 2.0).abs() < 1e-10);
        assert!((x[1] - 3.0).abs() < 1e-10);
    }
    // Dense elimination must reproduce the RHS when the solution is re-applied.
    #[test]
    fn dense_direct_solve_3x3() {
        let matrix = diag_dominant_3x3();
        let rhs = [2.0f64, 2.0, 2.0];
        let x = dense_direct_solve(&matrix, &rhs);
        let mut ax = vec![0.0f64; 3];
        matrix.spmv(&x, &mut ax);
        for i in 0..3 {
            assert!(
                (ax[i] - rhs[i]).abs() < 1e-10,
                "dense solve mismatch at {i}: {} vs {}",
                ax[i],
                rhs[i],
            );
        }
    }
    // Engine identification sanity check.
    #[test]
    fn algorithm_returns_bmssp() {
        let solver = BmsspSolver::new(1e-6, 100);
        assert_eq!(solver.algorithm(), Algorithm::BMSSP);
    }
    // Cost-model output must be populated with positive, plausible values.
    #[test]
    fn estimate_complexity_reasonable() {
        let solver = BmsspSolver::new(1e-6, 100);
        let profile = SparsityProfile {
            rows: 1000,
            cols: 1000,
            nnz: 5000,
            density: 0.005,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 10.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 7,
        };
        let est = solver.estimate_complexity(&profile, 1000);
        assert_eq!(est.algorithm, Algorithm::BMSSP);
        assert_eq!(est.complexity_class, ComplexityClass::SublinearNnz);
        assert!(est.estimated_flops > 0);
        assert!(est.estimated_iterations > 0);
        assert!(est.estimated_memory_bytes > 0);
    }
}