use crate::complexity::{Complexity, ComplexityClass};
use crate::error::{Result, SolverError};
use crate::matrix::Matrix;
use crate::solver::{SolverAlgorithm, SolverOptions, SolverResult};
use crate::types::Precision;
use alloc::vec::Vec;
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SparseDelta {
pub indices: Vec<usize>,
pub values: Vec<Precision>,
}
impl SparseDelta {
pub fn new(indices: Vec<usize>, values: Vec<Precision>) -> Result<Self> {
if indices.len() != values.len() {
return Err(SolverError::InvalidInput {
message: alloc::format!(
"SparseDelta::new: indices.len()={} != values.len()={}",
indices.len(),
values.len()
),
parameter: Some(alloc::string::String::from("indices/values")),
});
}
Ok(Self { indices, values })
}
pub fn empty() -> Self {
Self {
indices: Vec::new(),
values: Vec::new(),
}
}
pub fn nnz(&self) -> usize {
self.indices.len()
}
pub fn is_empty(&self) -> bool {
self.indices.is_empty()
}
pub fn apply_to(&self, b: &mut [Precision]) -> Result<()> {
for (&i, &v) in self.indices.iter().zip(self.values.iter()) {
if i >= b.len() {
return Err(SolverError::IndexOutOfBounds {
index: i,
max_index: b.len().saturating_sub(1),
context: alloc::string::String::from("SparseDelta::apply_to"),
});
}
b[i] += v;
}
Ok(())
}
pub fn as_pairs(&self) -> Vec<(usize, Precision)> {
self.indices
.iter()
.copied()
.zip(self.values.iter().copied())
.collect()
}
}
/// Tuning knobs for incremental re-solves.
///
/// NOTE(review): the default `IncrementalSolver::solve_on_change_with` in this
/// module receives this struct but does not read either field — confirm where
/// (if anywhere) they are honored.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct IncrementalConfig {
    /// Delta size (number of entries) that `IncrementalSolveOp`'s description
    /// names as the threshold for falling back to a full solve. Defaults to 64.
    pub full_solve_break_even: usize,
    /// Whether a warm start from the previous solution is requested. Defaults to `true`.
    pub warm_start: bool,
}

impl Default for IncrementalConfig {
    /// Break-even of 64 entries, warm start enabled.
    fn default() -> Self {
        let warm_start = true;
        let full_solve_break_even = 64;
        IncrementalConfig {
            warm_start,
            full_solve_break_even,
        }
    }
}
/// Incremental re-solve of `A·x = b` after a sparse change to the right-hand side.
///
/// If `prev_solution` solves `A·x = b`, then the solution of `A·x' = b + delta`
/// is `x' = prev_solution + dx`, where `A·dx = delta` (by linearity). Only that
/// correction system is passed to [`SolverAlgorithm::solve`], so the original
/// `b` is never needed. Any error already present in `prev_solution` carries
/// over unchanged into the result.
///
/// Every [`SolverAlgorithm`] gets this trait through the blanket impl that
/// follows the trait definition.
pub trait IncrementalSolver: SolverAlgorithm {
    /// Same as [`IncrementalSolver::solve_on_change_with`] with
    /// `IncrementalConfig::default()`.
    fn solve_on_change(
        &self,
        matrix: &dyn Matrix,
        prev_solution: &[Precision],
        delta: &SparseDelta,
        options: &SolverOptions,
    ) -> Result<SolverResult> {
        self.solve_on_change_with(
            matrix,
            prev_solution,
            delta,
            options,
            &IncrementalConfig::default(),
        )
    }

    /// Updates `prev_solution` for a right-hand side changed by `delta`.
    ///
    /// The returned `solution` is `prev_solution + dx`. All other fields
    /// (`residual_norm`, `iterations`, `converged`, `error_bounds`, `stats`,
    /// `memory_info`, `profile_data`) are those reported by the correction
    /// solve `A·dx = delta`.
    ///
    /// `_inc_config` is currently not consulted by this default implementation.
    ///
    /// # Errors
    ///
    /// - [`SolverError::DimensionMismatch`] if `prev_solution.len() != matrix.rows()`.
    /// - [`SolverError::IndexOutOfBounds`] if `delta` addresses an index
    ///   `>= matrix.rows()`.
    /// - Any error returned by [`SolverAlgorithm::solve`].
    fn solve_on_change_with(
        &self,
        matrix: &dyn Matrix,
        prev_solution: &[Precision],
        delta: &SparseDelta,
        options: &SolverOptions,
        _inc_config: &IncrementalConfig,
    ) -> Result<SolverResult> {
        let n = matrix.rows();
        if prev_solution.len() != n {
            return Err(SolverError::DimensionMismatch {
                expected: n,
                actual: prev_solution.len(),
                operation: alloc::string::String::from("solve_on_change.prev_solution"),
            });
        }

        // Right-hand side of the correction system A·dx = delta. No residual
        // b - A·prev_solution is formed: `b` is not available here and, by
        // linearity, is not needed. Skipping it avoids a full matrix-vector product.
        let mut rhs: Vec<Precision> = alloc::vec![0.0; n];
        delta.apply_to(&mut rhs)?;

        let correction = self.solve(matrix, &rhs, options)?;

        let mut solution = prev_solution.to_vec();
        for (xi, dxi) in solution.iter_mut().zip(correction.solution.iter()) {
            *xi += dxi;
        }

        // Every diagnostic field is taken from the correction solve.
        Ok(SolverResult {
            solution,
            ..correction
        })
    }
}

impl<T: SolverAlgorithm + ?Sized> IncrementalSolver for T {}
/// Marker type that carries the complexity metadata for incremental solves.
pub struct IncrementalSolveOp;

impl Complexity for IncrementalSolveOp {
    // Human-readable cost description, assembled from pieces with `concat!`.
    // The resulting string is identical to a single literal.
    // NOTE(review): the text mentions a fallback to a full solve above
    // `full_solve_break_even`. The default `IncrementalSolver::solve_on_change_with`
    // in this module does not read its `IncrementalConfig`, so any such fallback
    // would have to live in a caller — confirm.
    const DETAIL: &'static str = concat!(
        "O(k_warm · nnz(A)) per call where k_warm ≪ k_cold for small deltas on ",
        "well-conditioned DD systems; falls back to full solve when ",
        "nnz(delta) > full_solve_break_even (default 64)."
    );

    // Adaptive class whose default and worst case are both linear.
    const CLASS: ComplexityClass = ComplexityClass::Adaptive {
        worst: &ComplexityClass::Linear,
        default: &ComplexityClass::Linear,
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix::SparseMatrix;
    use crate::solver::neumann::NeumannSolver;

    /// 5×5 tridiagonal system (diagonal 5, off-diagonals 1) with
    /// right-hand side `[1, 2, 3, 4, 5]`.
    fn build_test_system() -> (SparseMatrix, Vec<Precision>) {
        let triplets = alloc::vec![
            (0usize, 0, 5.0), (0, 1, 1.0),
            (1, 0, 1.0), (1, 1, 5.0), (1, 2, 1.0),
            (2, 1, 1.0), (2, 2, 5.0), (2, 3, 1.0),
            (3, 2, 1.0), (3, 3, 5.0), (3, 4, 1.0),
            (4, 3, 1.0), (4, 4, 5.0),
        ];
        let m = SparseMatrix::from_triplets(triplets, 5, 5).unwrap();
        let b = alloc::vec![1.0, 2.0, 3.0, 4.0, 5.0];
        (m, b)
    }

    #[test]
    fn sparse_delta_apply_correct() {
        let mut b = alloc::vec![0.0; 5];
        let d = SparseDelta::new(alloc::vec![1, 3], alloc::vec![10.0, -5.0]).unwrap();
        d.apply_to(&mut b).unwrap();
        assert_eq!(b, alloc::vec![0.0, 10.0, 0.0, -5.0, 0.0]);
    }

    #[test]
    fn sparse_delta_validation_rejects_length_mismatch() {
        let r = SparseDelta::new(alloc::vec![1, 3], alloc::vec![10.0]);
        assert!(r.is_err(), "should reject mismatched lengths");
    }

    #[test]
    fn sparse_delta_apply_rejects_out_of_bounds() {
        let mut b = alloc::vec![0.0; 3];
        let d = SparseDelta::new(alloc::vec![10], alloc::vec![1.0]).unwrap();
        let r = d.apply_to(&mut b);
        assert!(matches!(r, Err(SolverError::IndexOutOfBounds { .. })));
    }

    #[test]
    fn sparse_delta_accessors() {
        let d = SparseDelta::new(alloc::vec![4, 0], alloc::vec![1.5, -2.0]).unwrap();
        assert_eq!(d.nnz(), 2);
        assert!(!d.is_empty());
        assert_eq!(d.as_pairs(), alloc::vec![(4, 1.5), (0, -2.0)]);
        let e = SparseDelta::empty();
        assert_eq!(e.nnz(), 0);
        assert!(e.is_empty());
    }

    #[test]
    fn incremental_solve_matches_full_solve_on_same_b() {
        let (m, b) = build_test_system();
        let solver = NeumannSolver::new(64, 1e-12);
        let opts = SolverOptions::default();
        let full = solver.solve(&m, &b, &opts).unwrap();
        let empty = SparseDelta::empty();
        let inc = solver
            .solve_on_change(&m, &full.solution, &empty, &opts)
            .unwrap();
        // `zip` below stops at the shorter vector, so pin the lengths first.
        assert_eq!(full.solution.len(), inc.solution.len());
        for (a, c) in full.solution.iter().zip(inc.solution.iter()) {
            assert!(
                (a - c).abs() < 1e-6,
                "full {a} vs incremental {c} diverge beyond tolerance"
            );
        }
    }

    #[test]
    fn incremental_solve_tracks_new_solution_when_b_changes() {
        let (m, b) = build_test_system();
        let solver = NeumannSolver::new(64, 1e-12);
        let opts = SolverOptions::default();
        let prev = solver.solve(&m, &b, &opts).unwrap();
        let delta = SparseDelta::new(alloc::vec![2], alloc::vec![0.5]).unwrap();
        let inc = solver
            .solve_on_change(&m, &prev.solution, &delta, &opts)
            .unwrap();
        let mut b_new = b.clone();
        delta.apply_to(&mut b_new).unwrap();
        let cold = solver.solve(&m, &b_new, &opts).unwrap();
        // `zip` below stops at the shorter vector, so pin the lengths first.
        assert_eq!(cold.solution.len(), inc.solution.len());
        for (a, c) in cold.solution.iter().zip(inc.solution.iter()) {
            assert!(
                (a - c).abs() < 1e-4,
                "cold {a} vs incremental {c} differ beyond tolerance"
            );
        }
    }

    #[test]
    fn solve_on_change_rejects_wrong_prev_solution_length() {
        let (m, _b) = build_test_system();
        let solver = NeumannSolver::new(64, 1e-12);
        let r = solver.solve_on_change(
            &m,
            &[0.0; 3],
            &SparseDelta::empty(),
            &SolverOptions::default(),
        );
        assert!(matches!(r, Err(SolverError::DimensionMismatch { .. })));
    }

    #[test]
    fn solve_on_change_rejects_out_of_bounds_delta() {
        let (m, _b) = build_test_system();
        let solver = NeumannSolver::new(64, 1e-12);
        let prev = alloc::vec![0.0; 5];
        let delta = SparseDelta::new(alloc::vec![5], alloc::vec![1.0]).unwrap();
        let r = solver.solve_on_change(&m, &prev, &delta, &SolverOptions::default());
        assert!(matches!(r, Err(SolverError::IndexOutOfBounds { .. })));
    }

    #[test]
    fn warm_start_uses_fewer_iters_than_cold_for_small_delta() {
        let (m, b) = build_test_system();
        let solver = NeumannSolver::new(64, 1e-10);
        let opts = SolverOptions {
            tolerance: 1e-8,
            max_iterations: 200,
            ..SolverOptions::default()
        };
        let prev = solver.solve(&m, &b, &opts).unwrap();
        let delta = SparseDelta::new(alloc::vec![2], alloc::vec![0.05]).unwrap();
        let warm = solver
            .solve_on_change(&m, &prev.solution, &delta, &opts)
            .unwrap();
        let mut b_new = b.clone();
        delta.apply_to(&mut b_new).unwrap();
        let cold = solver.solve(&m, &b_new, &opts).unwrap();
        assert!(warm.converged, "warm-start must converge");
        assert!(cold.converged, "cold-start must converge");
        assert!(
            warm.iterations <= cold.iterations,
            "warm-start iterations ({}) should be <= cold-start ({}) on a small delta",
            warm.iterations, cold.iterations,
        );
    }
}