use crate::error::{Result, SolverError};
use crate::matrix::Matrix;
use crate::types::{
ConvergenceMode, ErrorBounds, MemoryInfo, NormType, Precision, ProfileData, SolverStats,
};
use alloc::{string::String, vec::Vec};
pub mod neumann;
pub use neumann::NeumannSolver;
/// Tuning knobs for a solve, consumed by [`SolverAlgorithm::solve`] and by the
/// individual solver implementations.
///
/// Build one from [`SolverOptions::default`] or a preset
/// ([`SolverOptions::high_precision`], [`SolverOptions::fast`],
/// [`SolverOptions::streaming`]) and override fields as needed.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverOptions {
/// Convergence tolerance. `solve` reports it in `SolverError::ConvergenceFailure`;
/// how it is compared against the residual is up to each algorithm's `is_converged`.
pub tolerance: Precision,
/// Maximum number of iterations `solve` runs before giving up.
pub max_iterations: usize,
/// Convergence-criterion selector (matches the `mode` parameter of
/// `utils::check_convergence`). Not read by `solve` itself.
pub convergence_mode: ConvergenceMode,
/// Norm selector (matches `utils::compute_norm`). Not read by `solve` itself.
pub norm_type: NormType,
/// When `true`, `solve` attaches wall-clock time and the mat-vec count to the
/// result (only when the `std` feature is enabled).
pub collect_stats: bool,
/// Streaming interval; presets that do not stream use `0`.
/// NOTE(review): the consumer of this field is not visible in this module.
pub streaming_interval: usize,
/// Optional starting vector. NOTE(review): not read in this module; presumably
/// honored by each solver's `initialize`. Confirm.
pub initial_guess: Option<Vec<Precision>>,
/// When `true`, `solve` copies `SolverState::error_bounds()` into the result.
pub compute_error_bounds: bool,
/// Tolerance for error-bound computation. NOTE(review): not read in this module.
pub error_bounds_tolerance: Precision,
/// Profiling switch. NOTE(review): the default `solve` in this module ignores it.
pub enable_profiling: bool,
/// Optional seed, presumably for randomized solvers. NOTE(review): not read here.
pub random_seed: Option<u64>,
/// Solver-specific threshold, `0.0` in every preset.
/// NOTE(review): its meaning is defined by a consumer outside this module.
pub coherence_threshold: Precision,
}
impl Default for SolverOptions {
    /// General-purpose settings.
    ///
    /// Uses residual-norm convergence in the L2 norm at `1e-6`, capped at
    /// 1000 iterations. Statistics, error bounds and profiling are off.
    /// There is no initial guess and no seed.
    fn default() -> Self {
        Self {
            // Convergence criteria.
            convergence_mode: ConvergenceMode::ResidualNorm,
            norm_type: NormType::L2,
            tolerance: 1e-6,
            max_iterations: 1000,
            // Optional diagnostics, all disabled.
            collect_stats: false,
            enable_profiling: false,
            compute_error_bounds: false,
            error_bounds_tolerance: 1e-8,
            // Streaming interval and solver-specific extras.
            streaming_interval: 0,
            coherence_threshold: 0.0,
            // No warm start, no explicit seed.
            initial_guess: None,
            random_seed: None,
        }
    }
}
impl SolverOptions {
    /// Preset for accuracy-critical solves.
    ///
    /// Changes from [`SolverOptions::default`]:
    /// - tolerance `1e-12`, up to 5000 iterations;
    /// - [`ConvergenceMode::Combined`];
    /// - statistics collection on;
    /// - error-bound computation on, with error-bound tolerance `1e-14`.
    ///
    /// Every other field keeps its default value.
    pub fn high_precision() -> Self {
        Self {
            tolerance: 1e-12,
            max_iterations: 5000,
            convergence_mode: ConvergenceMode::Combined,
            collect_stats: true,
            compute_error_bounds: true,
            error_bounds_tolerance: 1e-14,
            ..Self::default()
        }
    }

    /// Preset that trades accuracy for speed.
    ///
    /// Changes from [`SolverOptions::default`]: tolerance `1e-3`, at most
    /// 100 iterations, and error-bound tolerance `1e-4`. Every other field
    /// keeps its default value.
    pub fn fast() -> Self {
        Self {
            tolerance: 1e-3,
            max_iterations: 100,
            error_bounds_tolerance: 1e-4,
            ..Self::default()
        }
    }

    /// Preset for streaming partial results every `interval` iterations.
    ///
    /// Changes from [`SolverOptions::default`]:
    /// - tolerance `1e-4`;
    /// - statistics collection and profiling on;
    /// - error-bound tolerance `1e-6`.
    ///
    /// Every other field keeps its default value.
    pub fn streaming(interval: usize) -> Self {
        Self {
            streaming_interval: interval,
            tolerance: 1e-4,
            collect_stats: true,
            enable_profiling: true,
            error_bounds_tolerance: 1e-6,
            ..Self::default()
        }
    }
}
/// Outcome of a solve, as returned by [`SolverAlgorithm::solve`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverResult {
/// Final solution estimate (empty for results built with `SolverResult::error`).
pub solution: Vec<Precision>,
/// Residual norm reported by the solver state (`INFINITY` for `SolverResult::error`).
pub residual_norm: Precision,
/// Number of iterations performed.
pub iterations: usize,
/// Whether the solver's convergence test was satisfied.
pub converged: bool,
/// Error bounds; `solve` fills this from the state when
/// `SolverOptions::compute_error_bounds` is set.
pub error_bounds: Option<ErrorBounds>,
/// Run statistics; `solve` fills this when `SolverOptions::collect_stats`
/// is set and the `std` feature is enabled.
pub stats: Option<SolverStats>,
/// Memory usage snapshot. NOTE(review): the default `solve` never sets this.
pub memory_info: Option<MemoryInfo>,
/// Profiling samples. NOTE(review): the default `solve` never sets this.
pub profile_data: Option<Vec<ProfileData>>,
}
impl SolverResult {
pub fn success(solution: Vec<Precision>, residual_norm: Precision, iterations: usize) -> Self {
Self {
solution,
residual_norm,
iterations,
converged: true,
error_bounds: None,
stats: None,
memory_info: None,
profile_data: None,
}
}
pub fn failure(solution: Vec<Precision>, residual_norm: Precision, iterations: usize) -> Self {
Self {
solution,
residual_norm,
iterations,
converged: false,
error_bounds: None,
stats: None,
memory_info: None,
profile_data: None,
}
}
pub fn error(error: SolverError) -> Self {
Self {
solution: Vec::new(),
residual_norm: Precision::INFINITY,
iterations: 0,
converged: false,
error_bounds: None,
stats: None,
memory_info: None,
profile_data: None,
}
}
pub fn meets_quality_criteria(&self, tolerance: Precision) -> bool {
self.converged && self.residual_norm <= tolerance
}
}
/// Snapshot of an in-progress solve.
///
/// NOTE(review): presumably emitted when `SolverOptions::streaming_interval` is
/// non-zero; the producer is not visible in this module.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PartialSolution {
/// Iteration at which the snapshot was taken.
pub iteration: usize,
/// Solution estimate at that iteration.
pub solution: Vec<Precision>,
/// Residual norm at that iteration.
pub residual_norm: Precision,
/// Whether the solver had converged at that point.
pub converged: bool,
/// Optional estimate of the iterations still needed (set by the producer, if at all).
pub estimated_remaining: Option<usize>,
/// Capture time. With `serde`, this field is skipped when serializing and set to
/// `Instant::now()` when deserializing.
#[cfg(feature = "std")]
#[cfg_attr(feature = "serde", serde(skip, default = "std::time::Instant::now"))]
pub timestamp: std::time::Instant,
/// Capture time as a raw integer for `no_std` builds.
/// NOTE(review): the unit and epoch are chosen by the producer. Confirm.
#[cfg(not(feature = "std"))]
pub timestamp: u64,
}
/// A pluggable iterative algorithm for solving `A x = b`.
///
/// Implementors provide the per-iteration primitives. The provided
/// [`SolverAlgorithm::solve`] method drives them to completion.
pub trait SolverAlgorithm: Send + Sync {
    /// Per-solve mutable state created by [`SolverAlgorithm::initialize`].
    type State: SolverState;

    /// Builds the initial state for solving `matrix * x = b` under `options`.
    ///
    /// # Errors
    /// Implementation-defined. `solve` propagates any error unchanged.
    fn initialize(
        &self,
        matrix: &dyn Matrix,
        b: &[Precision],
        options: &SolverOptions,
    ) -> Result<Self::State>;

    /// Advances `state` by one iteration and reports whether to keep going.
    fn step(&self, state: &mut Self::State) -> Result<StepResult>;

    /// Returns `true` once `state` satisfies this algorithm's convergence test.
    fn is_converged(&self, state: &Self::State) -> bool;

    /// Returns the current solution estimate held in `state`.
    fn extract_solution(&self, state: &Self::State) -> Vec<Precision>;

    /// Applies `(index, delta)` changes to the right-hand side stored in `state`.
    ///
    /// NOTE(review): whether deltas add to or replace entries is
    /// implementation-defined. Confirm per solver.
    fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()>;

    /// Stable identifier used in error reports.
    fn algorithm_name(&self) -> &'static str;

    /// Runs the algorithm until it converges, fails, or exhausts `options.max_iterations`.
    ///
    /// Each `StepResult::Continue` counts as one iteration. `StepResult::Converged`
    /// stops the loop without incrementing the count.
    ///
    /// When `options.collect_stats` is set and the `std` feature is enabled, the
    /// result carries the wall-clock time in milliseconds (sub-millisecond
    /// resolution) and the state's mat-vec count. When
    /// `options.compute_error_bounds` is set, it carries `state.error_bounds()`.
    ///
    /// # Errors
    /// - any error from `initialize` or `step`;
    /// - `SolverError::NumericalInstability` if the residual norm becomes non-finite;
    /// - `SolverError::AlgorithmError` if a step reports `StepResult::Failed`;
    /// - `SolverError::ConvergenceFailure` if the iteration budget runs out
    ///   before `is_converged` returns `true`.
    fn solve(
        &self,
        matrix: &dyn Matrix,
        b: &[Precision],
        options: &SolverOptions,
    ) -> Result<SolverResult> {
        let mut state = self.initialize(matrix, b, options)?;
        let mut iterations = 0;
        #[cfg(feature = "std")]
        let start_time = std::time::Instant::now();
        while !self.is_converged(&state) && iterations < options.max_iterations {
            match self.step(&mut state)? {
                StepResult::Continue => {
                    iterations += 1;
                    let residual = state.residual_norm();
                    // Stop immediately rather than keep iterating on NaN/inf values.
                    if !residual.is_finite() {
                        return Err(SolverError::NumericalInstability {
                            reason: "Non-finite residual norm".to_string(),
                            iteration: iterations,
                            residual_norm: residual,
                        });
                    }
                }
                // The algorithm signalled convergence itself; not counted as an iteration.
                StepResult::Converged => break,
                StepResult::Failed(reason) => {
                    return Err(SolverError::AlgorithmError {
                        algorithm: self.algorithm_name().to_string(),
                        message: reason,
                        context: vec![
                            ("iteration".to_string(), iterations.to_string()),
                            (
                                "residual_norm".to_string(),
                                state.residual_norm().to_string(),
                            ),
                        ],
                    });
                }
            }
        }
        let converged = self.is_converged(&state);
        let solution = self.extract_solution(&state);
        let residual_norm = state.residual_norm();
        if !converged && iterations >= options.max_iterations {
            return Err(SolverError::ConvergenceFailure {
                iterations,
                residual_norm,
                tolerance: options.tolerance,
                algorithm: self.algorithm_name().to_string(),
            });
        }
        // The `failure` arm is reachable only when `step` returned `Converged`
        // while `is_converged` still reports `false`.
        let mut result = if converged {
            SolverResult::success(solution, residual_norm, iterations)
        } else {
            SolverResult::failure(solution, residual_norm, iterations)
        };
        if options.collect_stats {
            #[cfg(feature = "std")]
            {
                // `as_millis()` would truncate to whole milliseconds, so fast
                // solves would report 0 ms. Keep the fractional part.
                let total_time = start_time.elapsed().as_secs_f64() * 1000.0;
                let mut stats = SolverStats::new();
                stats.total_time_ms = total_time;
                stats.matvec_count = state.matvec_count();
                result.stats = Some(stats);
            }
        }
        if options.compute_error_bounds {
            result.error_bounds = state.error_bounds();
        }
        Ok(result)
    }
}
/// Mutable per-solve state owned by a [`SolverAlgorithm`].
pub trait SolverState: Send + Sync {
/// Current residual norm. `solve` treats a non-finite value as numerical instability.
fn residual_norm(&self) -> Precision;
/// Number of matrix-vector products so far; `solve` copies it into `SolverStats`.
fn matvec_count(&self) -> usize;
/// Error bounds for the current iterate, if the algorithm can provide them.
fn error_bounds(&self) -> Option<ErrorBounds>;
/// Memory usage snapshot for this state. NOTE(review): not used by the default `solve`.
fn memory_usage(&self) -> MemoryInfo;
/// Resets the state. NOTE(review): the exact reset semantics are implementation-defined.
fn reset(&mut self);
}
/// Outcome of a single [`SolverAlgorithm::step`].
#[derive(Debug, Clone, PartialEq)]
pub enum StepResult {
/// Keep iterating; `solve` counts this as one iteration.
Continue,
/// The algorithm considers itself converged; `solve` leaves its loop.
Converged,
/// Abort with the given message; `solve` wraps it in `SolverError::AlgorithmError`.
Failed(String),
}
/// Vector-norm, residual and convergence helpers shared by solver implementations.
pub mod utils {
    use super::*;

    /// Euclidean (L2) norm, `sqrt(sum(x_i^2))`.
    ///
    /// Returns `0.0` for an empty slice. NaN entries propagate.
    pub fn l2_norm(v: &[Precision]) -> Precision {
        v.iter().map(|x| x * x).sum::<Precision>().sqrt()
    }

    /// L1 norm, `sum(|x_i|)`.
    ///
    /// Returns `0.0` for an empty slice. NaN entries propagate.
    pub fn l1_norm(v: &[Precision]) -> Precision {
        v.iter().map(|x| x.abs()).sum()
    }

    /// Infinity norm, `max(|x_i|)`.
    ///
    /// Returns `0.0` for an empty slice. A NaN entry yields NaN, consistent
    /// with [`l1_norm`] and [`l2_norm`].
    pub fn linf_norm(v: &[Precision]) -> Precision {
        v.iter().map(|x| x.abs()).fold(0.0, |max, a| {
            // `Precision::max` silently drops NaN operands, so a NaN entry would
            // vanish (`[NaN]` -> 0.0) and a diverged iterate could look
            // converged. Once NaN is seen it sticks: `a > NaN` is always false.
            if a.is_nan() || a > max {
                a
            } else {
                max
            }
        })
    }

    /// Computes the norm of `v` selected by `norm_type`.
    ///
    /// `NormType::Weighted` falls back to the unweighted L2 norm, because
    /// no weights are available here.
    pub fn compute_norm(v: &[Precision], norm_type: NormType) -> Precision {
        match norm_type {
            NormType::L1 => l1_norm(v),
            NormType::L2 => l2_norm(v),
            NormType::LInfinity => linf_norm(v),
            NormType::Weighted => l2_norm(v),
        }
    }

    /// Writes `A x - b` into `residual`.
    ///
    /// The sign is the negation of the textbook `b - A x`; norms are unaffected.
    /// `b` is subtracted element-wise over the shorter of `residual` and `b`.
    /// Callers are expected to pass equal lengths.
    ///
    /// # Errors
    /// Propagates any error from `matrix.multiply_vector`.
    pub fn compute_residual(
        matrix: &dyn Matrix,
        x: &[Precision],
        b: &[Precision],
        residual: &mut [Precision],
    ) -> Result<()> {
        matrix.multiply_vector(x, residual)?;
        for (r, &b_val) in residual.iter_mut().zip(b.iter()) {
            *r -= b_val;
        }
        Ok(())
    }

    /// Applies the convergence test selected by `mode`.
    ///
    /// - `ResidualNorm`: `residual_norm <= tolerance`.
    /// - `RelativeResidual`: `residual_norm / b_norm <= tolerance`. Falls back
    ///   to the absolute test when `b_norm` is not positive.
    /// - `SolutionChange`: the L2 norm of `current - prev` must be at most
    ///   `tolerance`. Always `false` without `prev_solution`.
    /// - `RelativeSolutionChange`: the same change, divided by the L2 norm of
    ///   `prev`. Absolute when `prev` is all zeros; `false` without `prev_solution`.
    /// - `Combined`: the absolute test and, when `b_norm != 0`, the relative
    ///   residual test must both pass.
    ///
    /// A NaN residual never passes the residual-based tests.
    pub fn check_convergence(
        residual_norm: Precision,
        tolerance: Precision,
        mode: ConvergenceMode,
        b_norm: Precision,
        prev_solution: Option<&[Precision]>,
        current_solution: &[Precision],
    ) -> bool {
        match mode {
            ConvergenceMode::ResidualNorm => residual_norm <= tolerance,
            ConvergenceMode::RelativeResidual => {
                if b_norm > 0.0 {
                    (residual_norm / b_norm) <= tolerance
                } else {
                    residual_norm <= tolerance
                }
            }
            ConvergenceMode::SolutionChange => {
                if let Some(prev) = prev_solution {
                    let mut change_norm = 0.0;
                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
                        let diff = curr - prev_val;
                        change_norm += diff * diff;
                    }
                    change_norm.sqrt() <= tolerance
                } else {
                    false
                }
            }
            ConvergenceMode::RelativeSolutionChange => {
                if let Some(prev) = prev_solution {
                    let mut change_norm = 0.0;
                    let mut solution_norm = 0.0;
                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
                        let diff = curr - prev_val;
                        change_norm += diff * diff;
                        solution_norm += prev_val * prev_val;
                    }
                    if solution_norm > 0.0 {
                        (change_norm.sqrt() / solution_norm.sqrt()) <= tolerance
                    } else {
                        change_norm.sqrt() <= tolerance
                    }
                } else {
                    false
                }
            }
            ConvergenceMode::Combined => {
                residual_norm <= tolerance
                    && (b_norm == 0.0 || (residual_norm / b_norm) <= tolerance)
            }
        }
    }
}
/// Forward-push solver.
/// NOTE(review): its `SolverAlgorithm` impl in this module is a placeholder. Check it before relying on this type.
pub struct ForwardPushSolver;
/// Backward-push solver.
/// NOTE(review): its `SolverAlgorithm` impl in this module is a placeholder. Check it before relying on this type.
pub struct BackwardPushSolver;
/// Hybrid solver.
/// NOTE(review): its `SolverAlgorithm` impl in this module is a placeholder. Check it before relying on this type.
pub struct HybridSolver;
/// Placeholder: forward push is not implemented yet.
///
/// Every fallible entry point reports `SolverError::AlgorithmError` with the
/// message "Not implemented yet". The state is `()`: it never converges and
/// yields an empty solution.
impl SolverAlgorithm for ForwardPushSolver {
    type State = ();

    fn algorithm_name(&self) -> &'static str {
        "forward_push"
    }

    fn initialize(
        &self,
        _matrix: &dyn Matrix,
        _b: &[Precision],
        _options: &SolverOptions,
    ) -> Result<Self::State> {
        Err(SolverError::AlgorithmError {
            algorithm: String::from(self.algorithm_name()),
            message: String::from("Not implemented yet"),
            context: Vec::new(),
        })
    }

    fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
        Err(SolverError::AlgorithmError {
            algorithm: String::from(self.algorithm_name()),
            message: String::from("Not implemented yet"),
            context: Vec::new(),
        })
    }

    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
        Err(SolverError::AlgorithmError {
            algorithm: String::from(self.algorithm_name()),
            message: String::from("Not implemented yet"),
            context: Vec::new(),
        })
    }

    fn is_converged(&self, _state: &Self::State) -> bool {
        false
    }

    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
        Vec::new()
    }
}
/// Trivial state for placeholder solvers.
///
/// Every query reports zero or nothing, and `reset` has nothing to clear.
impl SolverState for () {
    fn reset(&mut self) {}

    fn matvec_count(&self) -> usize {
        0
    }

    fn residual_norm(&self) -> Precision {
        0.0
    }

    fn error_bounds(&self) -> Option<ErrorBounds> {
        None
    }

    fn memory_usage(&self) -> MemoryInfo {
        // No buffers are held, so every counter is zero.
        MemoryInfo {
            allocation_count: 0,
            deallocation_count: 0,
            current_usage_bytes: 0,
            peak_usage_bytes: 0,
            matrix_memory_bytes: 0,
            vector_memory_bytes: 0,
            workspace_memory_bytes: 0,
        }
    }
}
impl SolverAlgorithm for BackwardPushSolver {
type State = ();
fn initialize(
&self,
_matrix: &dyn Matrix,
_b: &[Precision],
_options: &SolverOptions,
) -> Result<Self::State> {
Ok(())
}
fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
Ok(StepResult::Converged)
}
fn is_converged(&self, _state: &Self::State) -> bool {
true
}
fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
Vec::new()
}
fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
Ok(())
}
fn algorithm_name(&self) -> &'static str {
"backward_push"
}
}
impl SolverAlgorithm for HybridSolver {
type State = ();
fn initialize(
&self,
_matrix: &dyn Matrix,
_b: &[Precision],
_options: &SolverOptions,
) -> Result<Self::State> {
Ok(())
}
fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
Ok(StepResult::Converged)
}
fn is_converged(&self, _state: &Self::State) -> bool {
true
}
fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
Vec::new()
}
fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
Ok(())
}
fn algorithm_name(&self) -> &'static str {
"hybrid"
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn test_solver_options() {
        let default_opts = SolverOptions::default();
        assert_eq!(default_opts.tolerance, 1e-6);
        assert_eq!(default_opts.max_iterations, 1000);
        let fast_opts = SolverOptions::fast();
        assert_eq!(fast_opts.tolerance, 1e-3);
        assert_eq!(fast_opts.max_iterations, 100);
        let precision_opts = SolverOptions::high_precision();
        assert_eq!(precision_opts.tolerance, 1e-12);
        assert!(precision_opts.compute_error_bounds);
    }

    #[test]
    fn test_streaming_options() {
        let opts = SolverOptions::streaming(5);
        assert_eq!(opts.streaming_interval, 5);
        assert_eq!(opts.tolerance, 1e-4);
        assert!(opts.collect_stats);
        assert!(opts.enable_profiling);
    }

    #[test]
    fn test_solver_result() {
        let result = SolverResult::success(vec![1.0, 2.0], 1e-8, 10);
        assert!(result.converged);
        assert!(result.meets_quality_criteria(1e-6));
        assert!(!result.meets_quality_criteria(1e-10));
    }

    #[test]
    fn test_failed_result_never_meets_criteria() {
        let result = SolverResult::failure(vec![1.0], 0.0, 3);
        assert!(!result.converged);
        assert_eq!(result.iterations, 3);
        assert!(!result.meets_quality_criteria(1.0));
    }

    #[test]
    fn test_norm_calculations() {
        use utils::*;
        let v = [3.0, 4.0];
        assert_eq!(l1_norm(&v), 7.0);
        assert_eq!(l2_norm(&v), 5.0);
        assert_eq!(linf_norm(&v), 4.0);
        assert_eq!(compute_norm(&v, NormType::L1), 7.0);
        assert_eq!(compute_norm(&v, NormType::LInfinity), 4.0);
    }

    #[test]
    fn test_check_convergence() {
        use utils::check_convergence;
        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[1.0]));
        assert!(!check_convergence(1e-5, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[1.0]));
        // Without a previous iterate, solution-change modes never report convergence.
        assert!(!check_convergence(0.0, 1e-6, ConvergenceMode::SolutionChange, 1.0, None, &[1.0]));
        let prev = [1.0, 2.0];
        assert!(check_convergence(
            0.0,
            1e-6,
            ConvergenceMode::SolutionChange,
            1.0,
            Some(&prev[..]),
            &prev
        ));
    }
}