use std::time::Instant;
use tracing::{debug, info, instrument, warn};
use crate::error::{SolverError, ValidationError};
use crate::traits::SolverEngine;
use crate::types::{
Algorithm, ComplexityClass, ComplexityEstimate, ComputeBudget, ConvergenceInfo, CsrMatrix,
SolverResult, SparsityProfile,
};
/// Number of power-iteration steps performed when estimating the spectral
/// radius of the iteration matrix `I - D^-1 A`.
const POWER_ITERATION_STEPS: usize = 10;
/// Ratio between two consecutive residual norms above which `solve` aborts
/// with a numerical-instability error.
const INSTABILITY_GROWTH_FACTOR: f64 = 2.0;
/// Configuration for the Neumann-series iteration implemented by
/// `NeumannSolver::solve`.
#[derive(Debug, Clone)]
pub struct NeumannSolver {
/// Residual-norm threshold: `solve` returns successfully as soon as the
/// residual norm drops strictly below this value.
pub tolerance: f64,
/// Upper bound on the number of iterations `solve` performs before giving
/// up with a non-convergence error.
pub max_iterations: usize,
}
impl NeumannSolver {
pub fn new(tolerance: f64, max_iterations: usize) -> Self {
Self {
tolerance,
max_iterations,
}
}
#[instrument(skip(matrix), fields(n = matrix.rows))]
pub fn estimate_spectral_radius(matrix: &CsrMatrix<f32>) -> f64 {
let n = matrix.rows;
if n == 0 {
return 0.0;
}
let d_inv = extract_diag_inv_f32(matrix);
Self::estimate_spectral_radius_with_diag(matrix, &d_inv)
}
fn estimate_spectral_radius_with_diag(matrix: &CsrMatrix<f32>, d_inv: &[f32]) -> f64 {
let n = matrix.rows;
if n == 0 {
return 0.0;
}
let mut v: Vec<f32> = (0..n)
.map(|i| ((i.wrapping_mul(7).wrapping_add(13)) % 100) as f32 / 100.0)
.collect();
let norm = l2_norm_f32(&v);
if norm > 1e-12 {
scale_vec_f32(&mut v, 1.0 / norm);
}
let mut av = vec![0.0f32; n]; let mut w = vec![0.0f32; n]; let mut eigenvalue_estimate = 0.0f64;
for _ in 0..POWER_ITERATION_STEPS {
matrix.spmv(&v, &mut av);
for j in 0..n {
w[j] = v[j] - d_inv[j] * av[j];
}
let dot: f64 = v
.iter()
.zip(w.iter())
.map(|(&a, &b)| a as f64 * b as f64)
.sum();
eigenvalue_estimate = dot;
let w_norm = l2_norm_f32(&w);
if w_norm < 1e-12 {
break;
}
for j in 0..n {
v[j] = w[j] / w_norm as f32;
}
}
let rho = eigenvalue_estimate.abs();
debug!(rho, "estimated spectral radius of (I - D^-1 A)");
rho
}
#[instrument(skip(self, matrix, rhs), fields(n = matrix.rows, nnz = matrix.nnz()))]
pub fn solve(&self, matrix: &CsrMatrix<f32>, rhs: &[f32]) -> Result<SolverResult, SolverError> {
let start = Instant::now();
let n = matrix.rows;
if matrix.rows != matrix.cols {
return Err(SolverError::InvalidInput(
ValidationError::DimensionMismatch(format!(
"matrix must be square: got {}x{}",
matrix.rows, matrix.cols,
)),
));
}
if rhs.len() != n {
return Err(SolverError::InvalidInput(
ValidationError::DimensionMismatch(format!(
"rhs length {} does not match matrix dimension {}",
rhs.len(),
n,
)),
));
}
if n == 0 {
return Ok(SolverResult {
solution: Vec::new(),
iterations: 0,
residual_norm: 0.0,
wall_time: start.elapsed(),
convergence_history: Vec::new(),
algorithm: Algorithm::Neumann,
});
}
let d_inv = extract_diag_inv_f32(matrix);
let rho = Self::estimate_spectral_radius_with_diag(matrix, &d_inv);
if rho >= 1.0 {
warn!(rho, "spectral radius >= 1.0, Neumann series will diverge");
return Err(SolverError::SpectralRadiusExceeded {
spectral_radius: rho,
limit: 1.0,
algorithm: Algorithm::Neumann,
});
}
info!(rho, "spectral radius check passed");
let mut x: Vec<f32> = (0..n).map(|i| d_inv[i] * rhs[i]).collect();
let mut r = vec![0.0f32; n];
let mut convergence_history = Vec::with_capacity(self.max_iterations.min(256));
let mut prev_residual_norm = f64::MAX;
let final_residual_norm: f64;
let mut iterations_done: usize = 0;
for k in 0..self.max_iterations {
let residual_norm_sq = matrix.fused_residual_norm_sq(&x, rhs, &mut r);
let residual_norm = residual_norm_sq.sqrt();
iterations_done = k + 1;
convergence_history.push(ConvergenceInfo {
iteration: k,
residual_norm,
});
debug!(iteration = k, residual_norm, "neumann iteration");
if residual_norm < self.tolerance {
final_residual_norm = residual_norm;
info!(iterations = iterations_done, residual_norm, "converged");
return Ok(SolverResult {
solution: x,
iterations: iterations_done,
residual_norm: final_residual_norm,
wall_time: start.elapsed(),
convergence_history,
algorithm: Algorithm::Neumann,
});
}
if residual_norm.is_nan() || residual_norm.is_infinite() {
return Err(SolverError::NumericalInstability {
iteration: k,
detail: format!("residual became {residual_norm}"),
});
}
if k > 0
&& prev_residual_norm < f64::MAX
&& prev_residual_norm > 0.0
&& residual_norm > INSTABILITY_GROWTH_FACTOR * prev_residual_norm
{
warn!(
iteration = k,
prev = prev_residual_norm,
current = residual_norm,
"residual diverging",
);
return Err(SolverError::NumericalInstability {
iteration: k,
detail: format!(
"residual grew from {prev_residual_norm:.6e} to \
{residual_norm:.6e} (>{INSTABILITY_GROWTH_FACTOR:.0}x)",
),
});
}
let chunks = n / 4;
for c in 0..chunks {
let j = c * 4;
x[j] += d_inv[j] * r[j];
x[j + 1] += d_inv[j + 1] * r[j + 1];
x[j + 2] += d_inv[j + 2] * r[j + 2];
x[j + 3] += d_inv[j + 3] * r[j + 3];
}
for j in (chunks * 4)..n {
x[j] += d_inv[j] * r[j];
}
prev_residual_norm = residual_norm;
}
final_residual_norm = prev_residual_norm;
Err(SolverError::NonConvergence {
iterations: iterations_done,
residual: final_residual_norm,
tolerance: self.tolerance,
})
}
}
impl SolverEngine for NeumannSolver {
    /// Solves an `f64` system by down-casting it to `f32` and delegating to
    /// the inherent `NeumannSolver::solve`.
    ///
    /// The effective iteration cap is the smaller of `self.max_iterations`
    /// and `budget.max_iterations`. The effective tolerance is the smaller of
    /// `self.tolerance` and `budget.tolerance`, floored at `4 * f32::EPSILON`
    /// because the inner solve runs in single precision.
    ///
    /// # Errors
    ///
    /// * `InvalidInput(NonFiniteValue)` if a *finite* matrix or rhs value has
    ///   a magnitude larger than `f32::MAX`. Non-finite inputs are not
    ///   rejected by this check.
    /// * Any error returned by the inner `f32` solve.
    /// * `BudgetExhausted` if the elapsed wall-clock time exceeds
    ///   `budget.max_time`; this is checked once, after the inner solve
    ///   has returned.
    fn solve(
        &self,
        matrix: &CsrMatrix<f64>,
        rhs: &[f64],
        budget: &ComputeBudget,
    ) -> Result<SolverResult, SolverError> {
        let start = Instant::now();

        // Reject finite values that would turn into +/-inf when cast to f32.
        for (i, &v) in matrix.values.iter().enumerate() {
            if v.is_finite() && v.abs() > f32::MAX as f64 {
                return Err(SolverError::InvalidInput(ValidationError::NonFiniteValue(
                    format!("matrix value at index {i} ({v:.6e}) overflows f32"),
                )));
            }
        }
        for (i, &v) in rhs.iter().enumerate() {
            if v.is_finite() && v.abs() > f32::MAX as f64 {
                return Err(SolverError::InvalidInput(ValidationError::NonFiniteValue(
                    format!("rhs value at index {i} ({v:.6e}) overflows f32"),
                )));
            }
        }

        // Single-precision copy of the system; the sparsity structure is reused as is.
        let f32_matrix = CsrMatrix {
            row_ptr: matrix.row_ptr.clone(),
            col_indices: matrix.col_indices.clone(),
            values: matrix.values.iter().map(|&v| v as f32).collect(),
            rows: matrix.rows,
            cols: matrix.cols,
        };
        let f32_rhs: Vec<f32> = rhs.iter().map(|&v| v as f32).collect();

        let max_iters = self.max_iterations.min(budget.max_iterations);
        let tol = self
            .tolerance
            .min(budget.tolerance)
            .max(f32::EPSILON as f64 * 4.0);

        let inner_solver = NeumannSolver::new(tol, max_iters);
        let mut result = inner_solver.solve(&f32_matrix, &f32_rhs)?;

        // Sample the clock once so the reported value is the one that was
        // actually compared against the budget.
        let elapsed = start.elapsed();
        if elapsed > budget.max_time {
            return Err(SolverError::BudgetExhausted {
                reason: "wall-clock time limit exceeded".to_string(),
                elapsed,
            });
        }
        // Overwrite the inner timing so it includes the conversion overhead.
        result.wall_time = elapsed;
        Ok(result)
    }

    /// Predicts the cost of solving a system with the given sparsity profile.
    ///
    /// The iteration count assumes geometric convergence, `rho^k <= tolerance`,
    /// i.e. `k = ceil(ln(1 / tolerance) / |ln(rho)|)`, then clamps the result
    /// to `1..=self.max_iterations`. `rho` is the profile's estimated spectral
    /// radius limited to `[0.01, 0.999]`.
    ///
    /// Flops are estimated as `2 * nnz` per iteration and memory as three
    /// `f32` vectors of length `n`; both saturate instead of overflowing.
    fn estimate_complexity(&self, profile: &SparsityProfile, n: usize) -> ComplexityEstimate {
        // `max` then `min` (rather than `clamp`) so a NaN radius maps to 0.01.
        let rho = profile.estimated_spectral_radius.max(0.01).min(0.999);
        // The rate is governed by ln(rho), not ln(1 - rho): a radius close to
        // 1 must predict MORE iterations, not fewer.
        let est_iters = ((1.0 / self.tolerance).ln() / rho.ln().abs()).ceil() as usize;
        let est_iters = est_iters.min(self.max_iterations).max(1);
        ComplexityEstimate {
            algorithm: Algorithm::Neumann,
            estimated_flops: (est_iters as u64)
                .saturating_mul(profile.nnz as u64)
                .saturating_mul(2),
            estimated_iterations: est_iters,
            estimated_memory_bytes: n.saturating_mul(4 * 3),
            complexity_class: ComplexityClass::SublinearNnz,
        }
    }

    /// Identifies this engine as the Neumann algorithm.
    fn algorithm(&self) -> Algorithm {
        Algorithm::Neumann
    }
}
/// Returns the element-wise inverse of the diagonal of `matrix`, one entry
/// per row.
///
/// For each row the first stored entry whose column index equals the row
/// index is taken as the diagonal. If its magnitude is at most `1e-15`, or if
/// the row stores no diagonal entry at all, `1.0` is substituted and a
/// warning is logged, since either case means the diagonal scaling is not
/// well defined for that row.
///
/// # Panics
///
/// Panics if `row_ptr`, `col_indices` or `values` are too short for the
/// declared number of rows (plain slice indexing is used).
fn extract_diag_inv_f32(matrix: &CsrMatrix<f32>) -> Vec<f32> {
    let n = matrix.rows;
    let mut d_inv = vec![1.0f32; n];

    for (i, slot) in d_inv.iter_mut().enumerate() {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];

        // First stored entry on the diagonal of row `i`, if any.
        let diag = (start..end)
            .find(|&idx| matrix.col_indices[idx] == i)
            .map(|idx| matrix.values[idx]);

        match diag {
            Some(d) if d.abs() > 1e-15 => {
                *slot = 1.0 / d;
            }
            Some(d) => {
                warn!(
                    row = i,
                    diag_value = %d,
                    "zero or near-zero diagonal entry; substituting 1.0 — matrix may be singular"
                );
            }
            None => {
                // A structurally missing diagonal is an implicit zero; report
                // it just like an explicit one instead of staying silent.
                warn!(
                    row = i,
                    "missing diagonal entry; substituting 1.0 — matrix may be singular"
                );
            }
        }
    }

    d_inv
}
/// Euclidean (L2) norm of `v`.
///
/// Squares are accumulated in `f64` and only the final square root is
/// narrowed back to `f32`.
#[inline]
fn l2_norm_f32(v: &[f32]) -> f32 {
    let sum_of_squares = v
        .iter()
        .map(|&x| f64::from(x))
        .map(|x| x * x)
        .sum::<f64>();
    sum_of_squares.sqrt() as f32
}
/// Multiplies every element of `v` by `s`, in place.
#[inline]
fn scale_vec_f32(v: &mut [f32], s: f32) {
    v.iter_mut().for_each(|x| *x *= s);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CsrMatrix;

    /// `n x n` tridiagonal matrix with `diag_val` on the main diagonal and
    /// `off_val` on the first sub- and super-diagonals.
    fn tridiag_f32(n: usize, diag_val: f32, off_val: f32) -> CsrMatrix<f32> {
        let mut triplets = Vec::new();
        for row in 0..n {
            triplets.push((row, row, diag_val));
            if row > 0 {
                triplets.push((row, row - 1, off_val));
            }
            if row + 1 < n {
                triplets.push((row, row + 1, off_val));
            }
        }
        CsrMatrix::<f32>::from_coo(n, n, triplets)
    }

    /// 3x3 diagonally dominant `f64` matrix used by the trait-level test.
    fn test_matrix_f64() -> CsrMatrix<f64> {
        let triplets = vec![
            (0, 0, 1.0),
            (0, 1, -0.1),
            (1, 0, -0.1),
            (1, 1, 1.0),
            (1, 2, -0.1),
            (2, 1, -0.1),
            (2, 2, 1.0),
        ];
        CsrMatrix::<f64>::from_coo(3, 3, triplets)
    }

    /// 3x3 diagonal matrix with `0.5` on the diagonal.
    fn half_diag_3x3() -> CsrMatrix<f32> {
        CsrMatrix::<f32>::from_coo(3, 3, vec![(0, 0, 0.5_f32), (1, 1, 0.5), (2, 2, 0.5)])
    }

    /// 2x2 matrix whose off-diagonal entries dominate the diagonal.
    fn off_diag_heavy_2x2() -> CsrMatrix<f32> {
        CsrMatrix::<f32>::from_coo(
            2,
            2,
            vec![(0, 0, 1.0_f32), (0, 1, 2.0), (1, 0, 2.0), (1, 1, 1.0)],
        )
    }

    /// 0x0 matrix in raw CSR form.
    fn empty_f32() -> CsrMatrix<f32> {
        CsrMatrix::<f32> {
            row_ptr: vec![0],
            col_indices: vec![],
            values: vec![],
            rows: 0,
            cols: 0,
        }
    }

    #[test]
    fn test_new() {
        let NeumannSolver {
            tolerance,
            max_iterations,
        } = NeumannSolver::new(1e-8, 100);
        assert_eq!(tolerance, 1e-8);
        assert_eq!(max_iterations, 100);
    }

    #[test]
    fn test_spectral_radius_identity() {
        let rho = NeumannSolver::estimate_spectral_radius(&CsrMatrix::<f32>::identity(4));
        assert!(rho < 0.1, "expected rho ~ 0 for identity, got {rho}");
    }

    #[test]
    fn test_spectral_radius_pure_diagonal() {
        let rho = NeumannSolver::estimate_spectral_radius(&half_diag_3x3());
        assert!(rho < 0.1, "expected rho ~ 0 for diagonal matrix, got {rho}");
    }

    #[test]
    fn test_spectral_radius_empty() {
        let rho = NeumannSolver::estimate_spectral_radius(&empty_f32());
        assert_eq!(rho, 0.0);
    }

    #[test]
    fn test_spectral_radius_non_diag_dominant() {
        let rho = NeumannSolver::estimate_spectral_radius(&off_diag_heavy_2x2());
        assert!(
            rho > 1.0,
            "expected rho > 1 for non-diag-dominant matrix, got {rho}"
        );
    }

    #[test]
    fn test_solve_identity() {
        let rhs = vec![1.0_f32, 2.0, 3.0];
        let result = NeumannSolver::new(1e-6, 100)
            .solve(&CsrMatrix::<f32>::identity(3), &rhs)
            .unwrap();
        for (i, (&a, &e)) in result.solution.iter().zip(rhs.iter()).enumerate() {
            assert!((e - a).abs() < 1e-4, "index {i}: expected {e}, got {a}");
        }
        assert!(result.residual_norm < 1e-6);
    }

    #[test]
    fn test_solve_diagonal() {
        let ones = vec![1.0_f32, 1.0, 1.0];
        let result = NeumannSolver::new(1e-6, 200)
            .solve(&half_diag_3x3(), &ones)
            .unwrap();
        for (i, &val) in result.solution.iter().enumerate() {
            assert!(
                (val - 2.0).abs() < 0.01,
                "index {i}: expected ~2.0, got {val}"
            );
        }
    }

    #[test]
    fn test_solve_tridiagonal() {
        let system = tridiag_f32(5, 1.0, -0.1);
        let rhs = vec![1.0_f32, 0.0, 1.0, 0.0, 1.0];
        let outcome = NeumannSolver::new(1e-6, 1000).solve(&system, &rhs).unwrap();
        assert!(outcome.residual_norm < 1e-4);
        assert!(outcome.iterations > 0);
        assert!(!outcome.convergence_history.is_empty());
    }

    #[test]
    fn test_solve_empty_system() {
        let solver = NeumannSolver::new(1e-6, 10);
        let outcome = solver.solve(&empty_f32(), &[]).unwrap();
        assert_eq!(outcome.iterations, 0);
        assert!(outcome.solution.is_empty());
    }

    #[test]
    fn test_solve_dimension_mismatch() {
        let too_short = vec![1.0_f32, 2.0];
        let msg = NeumannSolver::new(1e-6, 100)
            .solve(&CsrMatrix::<f32>::identity(3), &too_short)
            .unwrap_err()
            .to_string();
        let mentions_problem = msg.contains("dimension") || msg.contains("mismatch");
        assert!(mentions_problem, "got: {msg}");
    }

    #[test]
    fn test_solve_non_square() {
        let wide = CsrMatrix::<f32>::from_coo(2, 3, vec![(0, 0, 1.0_f32), (1, 1, 1.0)]);
        let rhs = vec![1.0_f32, 1.0];
        let msg = NeumannSolver::new(1e-6, 100)
            .solve(&wide, &rhs)
            .unwrap_err()
            .to_string();
        let mentions_problem = msg.contains("square") || msg.contains("dimension");
        assert!(mentions_problem, "got: {msg}");
    }

    #[test]
    fn test_solve_divergent_matrix() {
        let rhs = vec![1.0_f32, 1.0];
        let err = NeumannSolver::new(1e-6, 100)
            .solve(&off_diag_heavy_2x2(), &rhs)
            .unwrap_err();
        assert!(err.to_string().contains("spectral radius"), "got: {}", err);
    }

    #[test]
    fn test_convergence_history_monotone() {
        let rhs = vec![1.0_f32; 4];
        let result = NeumannSolver::new(1e-10, 50)
            .solve(&CsrMatrix::<f32>::identity(4), &rhs)
            .unwrap();
        let history = &result.convergence_history;
        assert!(!history.is_empty());
        // Compare each recorded residual with its successor.
        for (earlier, later) in history.iter().zip(history.iter().skip(1)) {
            assert!(
                later.residual_norm <= earlier.residual_norm + 1e-12,
                "residual not decreasing: {} -> {}",
                earlier.residual_norm,
                later.residual_norm,
            );
        }
    }

    #[test]
    fn test_algorithm_tag() {
        let rhs = vec![1.0_f32; 2];
        let outcome = NeumannSolver::new(1e-6, 100)
            .solve(&CsrMatrix::<f32>::identity(2), &rhs)
            .unwrap();
        assert_eq!(outcome.algorithm, Algorithm::Neumann);
    }

    #[test]
    fn test_solver_engine_trait_f64() {
        let solver = NeumannSolver::new(1e-6, 200);
        let engine: &dyn SolverEngine = &solver;
        let budget = ComputeBudget::default();
        let rhs = vec![1.0_f64, 0.0, 1.0];
        let outcome = engine.solve(&test_matrix_f64(), &rhs, &budget).unwrap();
        assert!(outcome.residual_norm < 1e-4);
        assert_eq!(outcome.algorithm, Algorithm::Neumann);
    }

    #[test]
    fn test_larger_system_accuracy() {
        let n = 50;
        let system = tridiag_f32(n, 1.0, -0.1);
        let rhs: Vec<f32> = (0..n).map(|i| (i as f32 + 1.0) / n as f32).collect();
        let result = NeumannSolver::new(1e-6, 2000).solve(&system, &rhs).unwrap();
        assert!(
            result.residual_norm < 1e-6,
            "residual too large: {}",
            result.residual_norm
        );

        // Independently verify A * x against b, entry by entry.
        let mut ax = vec![0.0f32; n];
        system.spmv(&result.solution, &mut ax);
        for (i, (lhs, b)) in ax.iter().zip(rhs.iter()).enumerate() {
            assert!(
                (lhs - b).abs() < 1e-4,
                "A*x[{i}]={} but b[{i}]={}",
                lhs,
                b
            );
        }
    }

    #[test]
    fn test_scalar_system() {
        let one_by_one = CsrMatrix::<f32>::from_coo(1, 1, vec![(0, 0, 0.5_f32)]);
        let result = NeumannSolver::new(1e-8, 200)
            .solve(&one_by_one, &[4.0_f32])
            .unwrap();
        assert!(
            (result.solution[0] - 8.0).abs() < 0.01,
            "expected ~8.0, got {}",
            result.solution[0]
        );
    }

    #[test]
    fn test_estimate_complexity() {
        let profile = SparsityProfile {
            rows: 100,
            cols: 100,
            nnz: 500,
            density: 0.05,
            is_diag_dominant: true,
            estimated_spectral_radius: 0.5,
            estimated_condition: 3.0,
            is_symmetric_structure: true,
            avg_nnz_per_row: 5.0,
            max_nnz_per_row: 8,
        };
        let estimate = NeumannSolver::new(1e-6, 1000).estimate_complexity(&profile, 100);
        assert_eq!(estimate.algorithm, Algorithm::Neumann);
        assert_eq!(estimate.complexity_class, ComplexityClass::SublinearNnz);
        assert!(estimate.estimated_flops > 0);
        assert!(estimate.estimated_iterations > 0);
        assert!(estimate.estimated_memory_bytes > 0);
    }
}