use crate::closure::closure_indices;
use crate::complexity::{Complexity, ComplexityClass};
use crate::error::{Result, SolverError};
use crate::matrix::Matrix;
use crate::types::Precision;
use alloc::collections::BTreeMap;
use alloc::vec::Vec;
/// Zero-sized marker type naming the single-entry Neumann solve, so that its
/// cost model can be attached through a `Complexity` implementation.
#[derive(Clone, Copy, Debug, Default)]
pub struct SolveSingleEntryNeumannOp;
/// Cost model for [`solve_single_entry_neumann`]: advertised as sub-linear,
/// with the detail text spelling out when that bound widens to linear.
impl Complexity for SolveSingleEntryNeumannOp {
    const DETAIL: &'static str = concat!(
        "Single-entry Neumann: O(max_terms · |closure| · branching). ",
        "Independent of n for sparse DD matrices with bounded degree + bounded max_terms. ",
        "Widens to Linear when the closure spans the whole graph."
    );
    const CLASS: ComplexityClass = ComplexityClass::SubLinear;
}
/// Estimates the single solution entry `x[target]` of `A x = b` with a
/// truncated Neumann series evaluated only over the neighbourhood of `target`.
///
/// Writing `A = D + R` (`D` the diagonal, `R` the off-diagonal part), the
/// series is `x = Σ_k y_k` with `y_0 = D⁻¹ b` and `y_{k+1} = -D⁻¹ R y_k`.
/// Only rows returned by `closure_indices(matrix, &[target], max_terms)` are
/// evaluated. Off-diagonal couplings to rows outside that set are dropped.
///
/// At most `max_terms` correction terms are added after `y_0`. Iteration stops
/// earlier once the largest absolute entry of the latest term (over the
/// closure) falls below `tolerance`, or once that term is identically zero.
///
/// The stopping test deliberately looks at the whole term, not only at
/// `y_k[target]`. The target component of a term can be exactly zero while
/// later terms are not: for example, `b` concentrated at `target` on a
/// bipartite sparsity pattern such as a chain. A target-only test would stop
/// after `y_0` in that case.
///
/// The series converges only when the spectral radius of `D⁻¹ R` is below
/// one (e.g. strictly diagonally dominant `A`); this is not checked here.
///
/// # Arguments
///
/// * `matrix` — square system matrix; every diagonal entry in the closure
///   must be present and non-zero.
/// * `b` — right-hand side, of length `matrix.rows()`.
/// * `target` — row index of the requested solution entry.
/// * `max_terms` — number of correction terms after `y_0`. It is also
///   forwarded to `closure_indices`, presumably as the hop limit (confirm
///   against that helper).
/// * `tolerance` — absolute threshold on the max-abs entry of a term.
///
/// # Errors
///
/// * [`SolverError::IndexOutOfBounds`] if `target >= matrix.rows()`.
/// * [`SolverError::DimensionMismatch`] if `b.len() != matrix.rows()`.
/// * [`SolverError::InvalidInput`] if a row in the closure has a zero or
///   missing diagonal entry.
///
/// Returns `Ok(0.0)` when the closure is empty.
pub fn solve_single_entry_neumann(
    matrix: &dyn Matrix,
    b: &[Precision],
    target: usize,
    max_terms: usize,
    tolerance: Precision,
) -> Result<Precision> {
    let n = matrix.rows();
    if target >= n {
        return Err(SolverError::IndexOutOfBounds {
            index: target,
            max_index: n.saturating_sub(1),
            context: alloc::string::String::from("solve_single_entry_neumann::target"),
        });
    }
    if b.len() != n {
        return Err(SolverError::DimensionMismatch {
            expected: n,
            actual: b.len(),
            operation: alloc::string::String::from("solve_single_entry_neumann::b.len()"),
        });
    }
    let closure_set = closure_indices(matrix, &[target], max_terms);
    if closure_set.is_empty() {
        return Ok(0.0);
    }

    // Validate each closure row's diagonal once and cache it (every term
    // reuses it). Also seed the zeroth term y_0 = D⁻¹ b on the closure.
    let mut diagonal: Vec<Precision> = Vec::with_capacity(closure_set.len());
    let mut y: BTreeMap<usize, Precision> = BTreeMap::new();
    for &j in &closure_set {
        let a_jj = matrix.get(j, j).unwrap_or(0.0);
        if a_jj == 0.0 {
            return Err(SolverError::InvalidInput {
                message: alloc::format!(
                    "solve_single_entry_neumann: zero diagonal at row {} in closure of target {}",
                    j,
                    target,
                ),
                parameter: Some(alloc::string::String::from("matrix")),
            });
        }
        diagonal.push(a_jj);
        y.insert(j, b[j] / a_jj);
    }

    let mut x_target = y.get(&target).copied().unwrap_or(0.0);
    let mut y_next: BTreeMap<usize, Precision> = BTreeMap::new();
    for _term in 1..=max_terms {
        y_next.clear();
        // Largest |entry| of the term being built; drives the stopping test.
        let mut term_max: Precision = 0.0;
        for (&j, &a_jj) in closure_set.iter().zip(diagonal.iter()) {
            let mut sum: Precision = 0.0;
            for (m_idx, a_jm) in matrix.row_iter(j) {
                let m = m_idx as usize;
                if m == j {
                    continue;
                }
                // `y` only ever holds closure rows, so this lookup also
                // discards couplings that leave the closure.
                if let Some(&y_m) = y.get(&m) {
                    sum += a_jm * y_m;
                }
            }
            let val = -sum / a_jj;
            if val != 0.0 {
                let magnitude = val.abs();
                // Once NaN, term_max stays NaN, so a blown-up series never
                // counts as converged.
                if magnitude.is_nan() || magnitude > term_max {
                    term_max = magnitude;
                }
                y_next.insert(j, val);
            }
        }
        x_target += y_next.get(&target).copied().unwrap_or(0.0);
        // Judge convergence on the whole term, not just its target component
        // (see the function docs for why the latter stops too early).
        if y_next.is_empty() || term_max < tolerance {
            break;
        }
        core::mem::swap(&mut y, &mut y_next);
    }
    Ok(x_target)
}
/// Batched front-end for [`solve_single_entry_neumann`].
///
/// Each index in `targets` is solved independently, in the order given, with
/// the same `matrix`, `b`, `max_terms` and `tolerance`. The resulting
/// `(index, value)` pairs are returned sorted by ascending index; duplicate
/// targets produce duplicate pairs.
///
/// # Errors
///
/// The first per-target error aborts the batch and is returned unchanged.
pub fn solve_single_entries_neumann(
    matrix: &dyn Matrix,
    b: &[Precision],
    targets: &[usize],
    max_terms: usize,
    tolerance: Precision,
) -> Result<Vec<(usize, Precision)>> {
    let mut pairs = targets
        .iter()
        .map(|&target| {
            solve_single_entry_neumann(matrix, b, target, max_terms, tolerance)
                .map(|value| (target, value))
        })
        .collect::<Result<Vec<(usize, Precision)>>>()?;
    pairs.sort_by_key(|&(index, _)| index);
    Ok(pairs)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::SparseMatrix;
    use crate::solver::neumann::NeumannSolver;
    use crate::solver::{SolverAlgorithm, SolverOptions};

    /// Tridiagonal `n × n` fixture: 4.0 on the diagonal and -1.0 on both
    /// neighbouring off-diagonals.
    fn build_chain(n: usize) -> SparseMatrix {
        let mut entries = Vec::with_capacity(3 * n);
        for row in 0..n {
            entries.push((row, row, 4.0));
            let next = row + 1;
            if next < n {
                entries.push((row, next, -1.0));
                entries.push((next, row, -1.0));
            }
        }
        SparseMatrix::from_triplets(entries, n, n).unwrap()
    }

    /// Every single-entry estimate must agree with the full Neumann solve.
    #[test]
    fn matches_full_solve_on_chain() {
        let n = 8;
        let a = build_chain(n);
        let b: Vec<Precision> = (1..=n).map(|i| i as Precision).collect();
        let reference_solver = NeumannSolver::new(64, 1e-12);
        let options = SolverOptions::default();
        let reference = reference_solver.solve(&a, &b, &options).unwrap();
        for target in 0..n {
            let expected = reference.solution[target];
            let est = solve_single_entry_neumann(&a, &b, target, 32, 1e-10).unwrap();
            let diff = (est - expected).abs();
            assert!(
                diff < 1e-6,
                "single-entry estimate diverged at row {}: est={}, full={}, diff={}",
                target,
                est,
                expected,
                diff
            );
        }
    }

    /// With no off-diagonal entries the answer is exactly `b[i] / a_ii`.
    #[test]
    fn diagonal_matrix_returns_b_over_diag() {
        let n = 4;
        let triplets: Vec<_> = (0..n).map(|i| (i, i, 2.0)).collect();
        let a = SparseMatrix::from_triplets(triplets, n, n).unwrap();
        let b = alloc::vec![3.0, 6.0, 9.0, 12.0];
        for (i, &rhs) in b.iter().enumerate() {
            let est = solve_single_entry_neumann(&a, &b, i, 0, 1e-12).unwrap();
            assert!((est - rhs / 2.0).abs() < 1e-12);
        }
    }

    /// `max_terms == 0` keeps only `y_0 = D⁻¹ b`.
    #[test]
    fn max_terms_zero_returns_zeroth_neumann_term() {
        let a = build_chain(4);
        let b = alloc::vec![1.0, 2.0, 3.0, 4.0];
        for (i, &rhs) in b.iter().enumerate() {
            let est = solve_single_entry_neumann(&a, &b, i, 0, 1e-12).unwrap();
            assert!((est - rhs / 4.0).abs() < 1e-12);
        }
    }

    #[test]
    fn target_out_of_bounds_errors() {
        let n = 4;
        let a = build_chain(n);
        let b = alloc::vec![1.0; n];
        let outcome = solve_single_entry_neumann(&a, &b, 99, 8, 1e-10);
        assert!(matches!(outcome, Err(SolverError::IndexOutOfBounds { .. })));
    }

    #[test]
    fn b_length_mismatch_errors() {
        let n = 4;
        let a = build_chain(n);
        let b = alloc::vec![1.0; n + 3];
        let outcome = solve_single_entry_neumann(&a, &b, 0, 8, 1e-10);
        assert!(matches!(outcome, Err(SolverError::DimensionMismatch { .. })));
    }

    /// Row 0 has no diagonal entry, so it is rejected as invalid input.
    #[test]
    fn zero_diagonal_at_target_errors() {
        let a = SparseMatrix::from_triplets(alloc::vec![(0, 1, 1.0), (1, 1, 1.0)], 2, 2).unwrap();
        let b = alloc::vec![1.0, 1.0];
        let outcome = solve_single_entry_neumann(&a, &b, 0, 4, 1e-10);
        assert!(matches!(outcome, Err(SolverError::InvalidInput { .. })));
    }

    /// The batched API sorts by index and agrees with the scalar API.
    #[test]
    fn batched_matches_per_target() {
        let a = build_chain(6);
        let b = alloc::vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let targets = alloc::vec![4usize, 1, 2];
        let batched = solve_single_entries_neumann(&a, &b, &targets, 32, 1e-10).unwrap();
        let order: Vec<usize> = batched.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(order, alloc::vec![1, 2, 4]);
        for &(idx, val) in &batched {
            let scalar = solve_single_entry_neumann(&a, &b, idx, 32, 1e-10).unwrap();
            assert!((val - scalar).abs() < 1e-12);
        }
    }

    #[test]
    fn op_complexity_class_is_sublinear() {
        let class = <SolveSingleEntryNeumannOp as Complexity>::CLASS;
        assert_eq!(class, ComplexityClass::SubLinear);
    }

    /// Same property as above, but enforced at compile time.
    #[test]
    fn op_compile_time_bound() {
        const _: () = assert!(matches!(
            <SolveSingleEntryNeumannOp as Complexity>::CLASS,
            ComplexityClass::SubLinear
        ));
    }
}