use crate::complexity::{Complexity, ComplexityClass};
use crate::error::{Result, SolverError};
use crate::matrix::Matrix;
use crate::types::Precision;
use alloc::collections::BTreeMap;
/// Zero-sized marker type naming the sparse-solution verification operation,
/// so its complexity class can be queried through the [`Complexity`] trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct VerifySparseSolutionOp;

impl Complexity for VerifySparseSolutionOp {
    /// Declared sublinear, the same class as the orchestrator whose output
    /// this operation audits.
    const CLASS: ComplexityClass = ComplexityClass::SubLinear;

    // NOTE(review): `verify_sparse_solution` computes ‖b‖∞ by scanning every
    // element of `b`, which is O(n); only its residual loop is bounded by the
    // number of entries. The "Independent of n" wording below should be
    // reconciled with that scan.
    const DETAIL: &'static str = concat!(
        "Residual audit restricted to caller-supplied closure entries. ",
        "O(|entries| · avg_row_nnz) — same class as the SubLinear orchestrator ",
        "whose output it verifies. Independent of n for sparse DD matrices."
    );
}
/// Outcome of a sparse residual audit of a candidate linear-system solution.
#[derive(Clone, Debug, PartialEq)]
pub struct WitnessReport {
    /// Largest absolute residual found over the audited rows
    /// (0.0 when no row was audited).
    pub max_residual: Precision,
    /// Acceptance bound the largest residual was compared against.
    pub threshold: Precision,
    /// Whether the audit passed.
    pub ok: bool,
    /// Row that produced `max_residual`, or `None` when no audited row had a
    /// residual above zero.
    pub worst_row: Option<usize>,
}
pub fn verify_sparse_solution(
matrix: &dyn Matrix,
prev_solution: &[Precision],
b: &[Precision],
entries: &[(usize, Precision)],
tolerance: Precision,
) -> Result<WitnessReport> {
let n = matrix.rows();
if prev_solution.len() != n {
return Err(SolverError::DimensionMismatch {
expected: n,
actual: prev_solution.len(),
operation: alloc::string::String::from("verify_sparse_solution::prev_solution"),
});
}
if b.len() != n {
return Err(SolverError::DimensionMismatch {
expected: n,
actual: b.len(),
operation: alloc::string::String::from("verify_sparse_solution::b"),
});
}
let mut overlay: BTreeMap<usize, Precision> = BTreeMap::new();
for &(i, val) in entries {
if i < n {
overlay.insert(i, val);
}
}
let x_at = |j: usize| -> Precision {
match overlay.get(&j) {
Some(&v) => v,
None => {
if j < prev_solution.len() {
prev_solution[j]
} else {
0.0
}
}
}
};
let b_inf = b
.iter()
.map(|v| v.abs())
.fold(0.0_f64, |a, x| if a > x { a } else { x });
let threshold = tolerance * b_inf.max(1.0);
let mut max_residual: Precision = 0.0;
let mut worst_row: Option<usize> = None;
for &(i, _) in entries {
if i >= n {
continue;
}
let mut ax_i: Precision = 0.0;
for (col_idx, a_ij) in matrix.row_iter(i) {
let j = col_idx as usize;
ax_i += a_ij * x_at(j);
}
let r = (b[i] - ax_i).abs();
if r > max_residual {
max_residual = r;
worst_row = Some(i);
}
}
Ok(WitnessReport {
max_residual,
threshold,
ok: max_residual <= threshold,
worst_row,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::SparseMatrix;
    use crate::solver::{neumann::NeumannSolver, SolverAlgorithm, SolverOptions};
    use crate::{solve_on_change_sublinear, SparseDelta};

    /// `n × n` ring with diagonal 5 and entries +1 at offsets +1/+2 and -1 at
    /// offsets -1/-2 (wrapping). For `n >= 5` each row has four distinct
    /// off-diagonal entries of magnitude 1, so it is strictly row diagonally
    /// dominant.
    fn build_ring(n: usize) -> SparseMatrix {
        let mut t = Vec::new();
        for i in 0..n {
            t.push((i, i, 5.0_f64));
            t.push((i, (i + 1) % n, 1.0));
            t.push((i, (i + 2) % n, 1.0));
            t.push((i, (i + n - 1) % n, -1.0));
            t.push((i, (i + n - 2) % n, -1.0));
        }
        SparseMatrix::from_triplets(t, n, n).unwrap()
    }

    /// `n × n` ring with diagonal 10 and ±0.5 on the immediate (wrapping)
    /// neighbours. For `n >= 3` the off-diagonal absolute row sum is 1.
    fn build_strong_ring(n: usize) -> SparseMatrix {
        let mut t = Vec::new();
        for i in 0..n {
            t.push((i, i, 10.0_f64));
            t.push((i, (i + 1) % n, 0.5));
            t.push((i, (i + n - 1) % n, -0.5));
        }
        SparseMatrix::from_triplets(t, n, n).unwrap()
    }

    /// Runtime check that the operation is declared SubLinear.
    #[test]
    fn op_complexity_class_is_sublinear() {
        assert_eq!(
            <VerifySparseSolutionOp as Complexity>::CLASS,
            ComplexityClass::SubLinear
        );
    }

    /// Same check evaluated in a `const` item, so a change of class breaks
    /// the build rather than only this test.
    #[test]
    fn op_compile_time_bound() {
        const _: () = assert!(matches!(
            <VerifySparseSolutionOp as Complexity>::CLASS,
            ComplexityClass::SubLinear
        ));
    }

    /// End-to-end: solve, apply a one-entry delta, compute the sparse update
    /// with `solve_on_change_sublinear`, and check the witness accepts it.
    #[test]
    fn witness_passes_on_genuine_sublinear_output() {
        let n = 32;
        let m = build_strong_ring(n);
        let b_prev: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
        let solver = NeumannSolver::new(64, 1e-12);
        let opts = SolverOptions::default();
        let prev_solution = solver.solve(&m, &b_prev, &opts).unwrap().solution;
        let delta = SparseDelta::new(vec![10], vec![0.5]).unwrap();
        let mut b_new = b_prev.clone();
        delta.apply_to(&mut b_new).unwrap();
        let entries = solve_on_change_sublinear(
            &m,
            &prev_solution,
            &b_new,
            &delta,
            8,
            32,
            1e-12,
        )
        .unwrap();
        let report = verify_sparse_solution(&m, &prev_solution, &b_new, &entries, 1e-3).unwrap();
        assert!(
            report.ok,
            "witness should pass on a genuine SubLinear output; max_residual={}, threshold={}, worst_row={:?}",
            report.max_residual, report.threshold, report.worst_row
        );
    }

    /// Corrupting one returned entry by +100 must make the witness fail and
    /// name an offending row.
    #[test]
    fn witness_fails_on_perturbed_output() {
        let n = 16;
        let m = build_ring(n);
        let b_prev: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
        let solver = NeumannSolver::new(64, 1e-12);
        let opts = SolverOptions::default();
        let prev = solver.solve(&m, &b_prev, &opts).unwrap().solution;
        let delta = SparseDelta::new(vec![3], vec![0.2]).unwrap();
        let mut b_new = b_prev.clone();
        delta.apply_to(&mut b_new).unwrap();
        let mut entries =
            solve_on_change_sublinear(&m, &prev, &b_new, &delta, 6, 32, 1e-10).unwrap();
        // Fail loudly if there is nothing to corrupt. Otherwise the test would
        // audit an empty set, pass trivially, and report a misleading failure.
        let first = entries
            .first_mut()
            .expect("solve_on_change_sublinear returned no entries to corrupt");
        first.1 += 100.0;
        let report = verify_sparse_solution(&m, &prev, &b_new, &entries, 1e-6).unwrap();
        assert!(!report.ok, "witness should fail on a corrupted entry");
        assert!(report.worst_row.is_some());
    }

    /// No entries means no rows audited: zero residual, no worst row.
    #[test]
    fn witness_empty_entries_passes_trivially() {
        let m = build_ring(4);
        let prev = vec![0.0; 4];
        let b = vec![1.0; 4];
        let report = verify_sparse_solution(&m, &prev, &b, &[], 1e-6).unwrap();
        assert!(report.ok);
        assert_eq!(report.max_residual, 0.0);
        assert_eq!(report.worst_row, None);
    }

    /// Both `prev_solution` and `b` lengths are validated against `rows()`.
    #[test]
    fn witness_dimension_mismatch_errors() {
        let m = build_ring(4);
        let bad_prev = vec![0.0; 3];
        let b = vec![1.0; 4];
        let err = verify_sparse_solution(&m, &bad_prev, &b, &[], 1e-6).unwrap_err();
        assert!(matches!(err, SolverError::DimensionMismatch { .. }));
        let prev = vec![0.0; 4];
        let bad_b = vec![1.0; 5];
        let err = verify_sparse_solution(&m, &prev, &bad_b, &[], 1e-6).unwrap_err();
        assert!(matches!(err, SolverError::DimensionMismatch { .. }));
    }

    /// Entries whose row index is out of range are skipped, not errors.
    #[test]
    fn witness_out_of_bound_entries_silently_ignored() {
        let m = build_ring(4);
        let prev = vec![0.0; 4];
        let b = vec![1.0; 4];
        let entries = vec![(99usize, 1.0), (100usize, 2.0)];
        let report = verify_sparse_solution(&m, &prev, &b, &entries, 1e-6).unwrap();
        assert!(report.ok);
        assert_eq!(report.max_residual, 0.0);
    }
}