Skip to main content

sublinear_solver/solver/
mod.rs

1//! Sublinear-time solver algorithms for asymmetric diagonally dominant systems.
2//!
3//! This module implements the core solver algorithms including Neumann series,
4//! forward/backward push methods, and hybrid random-walk approaches.
5
6use crate::matrix::Matrix;
7use crate::types::{
8    Precision, ConvergenceMode, NormType, ErrorBounds, SolverStats,
9    DimensionType, MemoryInfo, ProfileData
10};
11use crate::error::{SolverError, Result};
12use alloc::{vec::Vec, string::String, boxed::Box};
13
14pub mod neumann;
15
16// Re-export solver implementations
17pub use neumann::NeumannSolver;
18
19/// Configuration options for solver algorithms.
20#[derive(Debug, Clone, PartialEq)]
21#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
22pub struct SolverOptions {
23    /// Convergence tolerance
24    pub tolerance: Precision,
25    /// Maximum number of iterations
26    pub max_iterations: usize,
27    /// Convergence detection mode
28    pub convergence_mode: ConvergenceMode,
29    /// Norm type for error measurement
30    pub norm_type: NormType,
31    /// Enable detailed statistics collection
32    pub collect_stats: bool,
33    /// Streaming solution interval (0 = no streaming)
34    pub streaming_interval: usize,
35    /// Initial guess for the solution (if None, use zero)
36    pub initial_guess: Option<Vec<Precision>>,
37    /// Enable error bounds computation
38    pub compute_error_bounds: bool,
39    /// Relative tolerance for error bounds
40    pub error_bounds_tolerance: Precision,
41    /// Enable performance profiling
42    pub enable_profiling: bool,
43    /// Random seed for stochastic algorithms
44    pub random_seed: Option<u64>,
45    /// Coherence gate threshold (ADR-001 roadmap item #3). If the matrix's
46    /// diagonal-dominance margin (`coherence::coherence_score`) falls below
47    /// this value, the solver returns `Err(SolverError::Incoherent { .. })`
48    /// *before* doing any iterative work — refusing to spend polynomial
49    /// cost on a near-singular system that can only produce an ε-quality
50    /// answer.
51    ///
52    /// `0.0` (the default) disables the gate, preserving wire compatibility
53    /// with every existing caller. Recommended starting value when enabling:
54    /// `0.05`.
55    pub coherence_threshold: Precision,
56}
57
58impl Default for SolverOptions {
59    fn default() -> Self {
60        Self {
61            tolerance: 1e-6,
62            max_iterations: 1000,
63            convergence_mode: ConvergenceMode::ResidualNorm,
64            norm_type: NormType::L2,
65            collect_stats: false,
66            streaming_interval: 0,
67            initial_guess: None,
68            compute_error_bounds: false,
69            error_bounds_tolerance: 1e-8,
70            enable_profiling: false,
71            random_seed: None,
72            // Gate disabled by default. Callers opt in by setting > 0.
73            coherence_threshold: 0.0,
74        }
75    }
76}
77
78impl SolverOptions {
79    /// Create options optimized for high precision.
80    pub fn high_precision() -> Self {
81        Self {
82            tolerance: 1e-12,
83            max_iterations: 5000,
84            convergence_mode: ConvergenceMode::Combined,
85            norm_type: NormType::L2,
86            collect_stats: true,
87            streaming_interval: 0,
88            initial_guess: None,
89            compute_error_bounds: true,
90            error_bounds_tolerance: 1e-14,
91            enable_profiling: false,
92            random_seed: None,
93            coherence_threshold: 0.0,
94        }
95    }
96
97    /// Create options optimized for fast solving.
98    pub fn fast() -> Self {
99        Self {
100            tolerance: 1e-3,
101            max_iterations: 100,
102            convergence_mode: ConvergenceMode::ResidualNorm,
103            norm_type: NormType::L2,
104            collect_stats: false,
105            streaming_interval: 0,
106            initial_guess: None,
107            compute_error_bounds: false,
108            error_bounds_tolerance: 1e-4,
109            enable_profiling: false,
110            random_seed: None,
111            coherence_threshold: 0.0,
112        }
113    }
114
115    /// Create options optimized for streaming applications.
116    pub fn streaming(interval: usize) -> Self {
117        Self {
118            tolerance: 1e-4,
119            max_iterations: 1000,
120            convergence_mode: ConvergenceMode::ResidualNorm,
121            norm_type: NormType::L2,
122            collect_stats: true,
123            streaming_interval: interval,
124            initial_guess: None,
125            compute_error_bounds: false,
126            error_bounds_tolerance: 1e-6,
127            enable_profiling: true,
128            random_seed: None,
129            coherence_threshold: 0.0,
130        }
131    }
132}
133
134/// Result of a solver computation.
135#[derive(Debug, Clone, PartialEq)]
136#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
137pub struct SolverResult {
138    /// Final solution vector
139    pub solution: Vec<Precision>,
140    /// Final residual norm
141    pub residual_norm: Precision,
142    /// Number of iterations performed
143    pub iterations: usize,
144    /// Whether the algorithm converged
145    pub converged: bool,
146    /// Error bounds (if computed)
147    pub error_bounds: Option<ErrorBounds>,
148    /// Detailed statistics (if collected)
149    pub stats: Option<SolverStats>,
150    /// Memory usage information
151    pub memory_info: Option<MemoryInfo>,
152    /// Performance profiling data
153    pub profile_data: Option<Vec<ProfileData>>,
154}
155
156impl SolverResult {
157    /// Create a successful result.
158    pub fn success(
159        solution: Vec<Precision>,
160        residual_norm: Precision,
161        iterations: usize,
162    ) -> Self {
163        Self {
164            solution,
165            residual_norm,
166            iterations,
167            converged: true,
168            error_bounds: None,
169            stats: None,
170            memory_info: None,
171            profile_data: None,
172        }
173    }
174
175    /// Create a failure result.
176    pub fn failure(
177        solution: Vec<Precision>,
178        residual_norm: Precision,
179        iterations: usize,
180    ) -> Self {
181        Self {
182            solution,
183            residual_norm,
184            iterations,
185            converged: false,
186            error_bounds: None,
187            stats: None,
188            memory_info: None,
189            profile_data: None,
190        }
191    }
192
193    /// Create an error result.
194    pub fn error(error: SolverError) -> Self {
195        Self {
196            solution: Vec::new(),
197            residual_norm: Precision::INFINITY,
198            iterations: 0,
199            converged: false,
200            error_bounds: None,
201            stats: None,
202            memory_info: None,
203            profile_data: None,
204        }
205    }
206
207    /// Check if the solution meets the specified quality criteria.
208    pub fn meets_quality_criteria(&self, tolerance: Precision) -> bool {
209        self.converged && self.residual_norm <= tolerance
210    }
211}
212
213/// Partial solution for streaming applications.
214#[derive(Debug, Clone, PartialEq)]
215#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
216pub struct PartialSolution {
217    /// Current iteration number
218    pub iteration: usize,
219    /// Current solution estimate
220    pub solution: Vec<Precision>,
221    /// Current residual norm
222    pub residual_norm: Precision,
223    /// Whether convergence has been achieved
224    pub converged: bool,
225    /// Estimated remaining iterations
226    pub estimated_remaining: Option<usize>,
227    /// Timestamp when this solution was computed (not serialized)
228    #[cfg(feature = "std")]
229    #[cfg_attr(feature = "serde", serde(skip, default = "std::time::Instant::now"))]
230    pub timestamp: std::time::Instant,
231    #[cfg(not(feature = "std"))]
232    pub timestamp: u64,
233}
234
235/// Core trait for all solver algorithms.
236///
237/// This trait defines the interface that all sublinear-time solvers must implement,
238/// providing both batch and streaming solution capabilities.
239pub trait SolverAlgorithm: Send + Sync {
240    /// Solver-specific state type
241    type State: SolverState;
242
243    /// Initialize the solver state for a given problem.
244    fn initialize(
245        &self,
246        matrix: &dyn Matrix,
247        b: &[Precision],
248        options: &SolverOptions,
249    ) -> Result<Self::State>;
250
251    /// Perform a single iteration step.
252    fn step(&self, state: &mut Self::State) -> Result<StepResult>;
253
254    /// Check if the current state meets convergence criteria.
255    fn is_converged(&self, state: &Self::State) -> bool;
256
257    /// Extract the current solution from the state.
258    fn extract_solution(&self, state: &Self::State) -> Vec<Precision>;
259
260    /// Update the right-hand side for incremental solving.
261    fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()>;
262
263    /// Get the algorithm name for identification.
264    fn algorithm_name(&self) -> &'static str;
265
266    /// Solve the linear system Ax = b.
267    ///
268    /// This is the main interface for solving linear systems. It handles
269    /// the iteration loop and convergence checking automatically.
270    fn solve(
271        &self,
272        matrix: &dyn Matrix,
273        b: &[Precision],
274        options: &SolverOptions,
275    ) -> Result<SolverResult> {
276        let mut state = self.initialize(matrix, b, options)?;
277        let mut iterations = 0;
278
279        #[cfg(feature = "std")]
280        let start_time = std::time::Instant::now();
281
282        while !self.is_converged(&state) && iterations < options.max_iterations {
283            match self.step(&mut state)? {
284                StepResult::Continue => {
285                    iterations += 1;
286
287                    // Check for numerical issues
288                    let residual = state.residual_norm();
289                    if !residual.is_finite() {
290                        return Err(SolverError::NumericalInstability {
291                            reason: "Non-finite residual norm".to_string(),
292                            iteration: iterations,
293                            residual_norm: residual,
294                        });
295                    }
296                },
297                StepResult::Converged => break,
298                StepResult::Failed(reason) => {
299                    return Err(SolverError::AlgorithmError {
300                        algorithm: self.algorithm_name().to_string(),
301                        message: reason,
302                        context: vec![
303                            ("iteration".to_string(), iterations.to_string()),
304                            ("residual_norm".to_string(), state.residual_norm().to_string()),
305                        ],
306                    });
307                }
308            }
309        }
310
311        let converged = self.is_converged(&state);
312        let solution = self.extract_solution(&state);
313        let residual_norm = state.residual_norm();
314
315        // Check for convergence failure
316        if !converged && iterations >= options.max_iterations {
317            return Err(SolverError::ConvergenceFailure {
318                iterations,
319                residual_norm,
320                tolerance: options.tolerance,
321                algorithm: self.algorithm_name().to_string(),
322            });
323        }
324
325        let mut result = if converged {
326            SolverResult::success(solution, residual_norm, iterations)
327        } else {
328            SolverResult::failure(solution, residual_norm, iterations)
329        };
330
331        // Add optional data if requested
332        if options.collect_stats {
333            #[cfg(feature = "std")]
334            {
335                let total_time = start_time.elapsed().as_millis() as f64;
336                let mut stats = SolverStats::new();
337                stats.total_time_ms = total_time;
338                stats.matvec_count = state.matvec_count();
339                result.stats = Some(stats);
340            }
341        }
342
343        if options.compute_error_bounds {
344            result.error_bounds = state.error_bounds();
345        }
346
347        Ok(result)
348    }
349}
350
351/// Trait for solver state management.
352pub trait SolverState: Send + Sync {
353    /// Get the current residual norm.
354    fn residual_norm(&self) -> Precision;
355
356    /// Get the number of matrix-vector multiplications performed.
357    fn matvec_count(&self) -> usize;
358
359    /// Get error bounds if available.
360    fn error_bounds(&self) -> Option<ErrorBounds>;
361
362    /// Get current memory usage.
363    fn memory_usage(&self) -> MemoryInfo;
364
365    /// Reset the state for a new solve.
366    fn reset(&mut self);
367}
368
/// Outcome reported by [`SolverAlgorithm::step`] after one iteration.
#[derive(Debug, Clone, PartialEq)]
pub enum StepResult {
    /// Not done yet; the driver should keep iterating.
    Continue,
    /// Convergence has been reached; the driver should stop.
    Converged,
    /// The algorithm cannot proceed; the string explains why.
    Failed(String),
}
379
380/// Utility functions for solver implementations.
381pub mod utils {
382    use super::*;
383
384    /// Compute the L2 norm of a vector.
385    pub fn l2_norm(v: &[Precision]) -> Precision {
386        v.iter().map(|x| x * x).sum::<Precision>().sqrt()
387    }
388
389    /// Compute the L1 norm of a vector.
390    pub fn l1_norm(v: &[Precision]) -> Precision {
391        v.iter().map(|x| x.abs()).sum()
392    }
393
394    /// Compute the L∞ norm of a vector.
395    pub fn linf_norm(v: &[Precision]) -> Precision {
396        v.iter().map(|x| x.abs()).fold(0.0, Precision::max)
397    }
398
399    /// Compute vector norm according to specified type.
400    pub fn compute_norm(v: &[Precision], norm_type: NormType) -> Precision {
401        match norm_type {
402            NormType::L1 => l1_norm(v),
403            NormType::L2 => l2_norm(v),
404            NormType::LInfinity => linf_norm(v),
405            NormType::Weighted => l2_norm(v), // Default to L2 for weighted
406        }
407    }
408
409    /// Compute residual vector: r = A*x - b
410    pub fn compute_residual(
411        matrix: &dyn Matrix,
412        x: &[Precision],
413        b: &[Precision],
414        residual: &mut [Precision],
415    ) -> Result<()> {
416        matrix.multiply_vector(x, residual)?;
417        for (r, &b_val) in residual.iter_mut().zip(b.iter()) {
418            *r -= b_val;
419        }
420        Ok(())
421    }
422
423    /// Check convergence based on specified criteria.
424    pub fn check_convergence(
425        residual_norm: Precision,
426        tolerance: Precision,
427        mode: ConvergenceMode,
428        b_norm: Precision,
429        prev_solution: Option<&[Precision]>,
430        current_solution: &[Precision],
431    ) -> bool {
432        match mode {
433            ConvergenceMode::ResidualNorm => residual_norm <= tolerance,
434            ConvergenceMode::RelativeResidual => {
435                if b_norm > 0.0 {
436                    (residual_norm / b_norm) <= tolerance
437                } else {
438                    residual_norm <= tolerance
439                }
440            },
441            ConvergenceMode::SolutionChange => {
442                if let Some(prev) = prev_solution {
443                    let mut change_norm = 0.0;
444                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
445                        let diff = curr - prev_val;
446                        change_norm += diff * diff;
447                    }
448                    change_norm.sqrt() <= tolerance
449                } else {
450                    false
451                }
452            },
453            ConvergenceMode::RelativeSolutionChange => {
454                if let Some(prev) = prev_solution {
455                    let mut change_norm = 0.0;
456                    let mut solution_norm = 0.0;
457                    for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
458                        let diff = curr - prev_val;
459                        change_norm += diff * diff;
460                        solution_norm += prev_val * prev_val;
461                    }
462                    if solution_norm > 0.0 {
463                        (change_norm.sqrt() / solution_norm.sqrt()) <= tolerance
464                    } else {
465                        change_norm.sqrt() <= tolerance
466                    }
467                } else {
468                    false
469                }
470            },
471            ConvergenceMode::Combined => {
472                // Use the most conservative criterion
473                residual_norm <= tolerance &&
474                (b_norm == 0.0 || (residual_norm / b_norm) <= tolerance)
475            },
476        }
477    }
478}
479
// Placeholder solver types. Their `SolverAlgorithm` implementations are
// stubs until the real algorithms land in their own modules.

/// Forward-push solver (placeholder, not yet implemented).
pub struct ForwardPushSolver;
/// Backward-push solver (placeholder, not yet implemented).
pub struct BackwardPushSolver;
/// Hybrid random-walk solver (placeholder, not yet implemented).
pub struct HybridSolver;
484
485// Placeholder implementations - will be implemented in separate modules
486impl SolverAlgorithm for ForwardPushSolver {
487    type State = ();
488
489    fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> {
490        Err(SolverError::AlgorithmError {
491            algorithm: "forward_push".to_string(),
492            message: "Not implemented yet".to_string(),
493            context: vec![],
494        })
495    }
496
497    fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
498        Err(SolverError::AlgorithmError {
499            algorithm: "forward_push".to_string(),
500            message: "Not implemented yet".to_string(),
501            context: vec![],
502        })
503    }
504
505    fn is_converged(&self, _state: &Self::State) -> bool {
506        false
507    }
508
509    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
510        Vec::new()
511    }
512
513    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
514        Err(SolverError::AlgorithmError {
515            algorithm: "forward_push".to_string(),
516            message: "Not implemented yet".to_string(),
517            context: vec![],
518        })
519    }
520
521    fn algorithm_name(&self) -> &'static str {
522        "forward_push"
523    }
524}
525
526impl SolverState for () {
527    fn residual_norm(&self) -> Precision {
528        0.0
529    }
530
531    fn matvec_count(&self) -> usize {
532        0
533    }
534
535    fn error_bounds(&self) -> Option<ErrorBounds> {
536        None
537    }
538
539    fn memory_usage(&self) -> MemoryInfo {
540        MemoryInfo {
541            current_usage_bytes: 0,
542            peak_usage_bytes: 0,
543            matrix_memory_bytes: 0,
544            vector_memory_bytes: 0,
545            workspace_memory_bytes: 0,
546            allocation_count: 0,
547            deallocation_count: 0,
548        }
549    }
550
551    fn reset(&mut self) {}
552}
553
554// Similar placeholder implementations for BackwardPushSolver and HybridSolver
555impl SolverAlgorithm for BackwardPushSolver {
556    type State = ();
557    fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
558    fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
559    fn is_converged(&self, _state: &Self::State) -> bool { true }
560    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
561    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
562    fn algorithm_name(&self) -> &'static str { "backward_push" }
563}
564
565impl SolverAlgorithm for HybridSolver {
566    type State = ();
567    fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
568    fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
569    fn is_converged(&self, _state: &Self::State) -> bool { true }
570    fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
571    fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
572    fn algorithm_name(&self) -> &'static str { "hybrid" }
573}
574
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Presets and defaults carry the documented values.
    #[test]
    fn test_solver_options() {
        let default_opts = SolverOptions::default();
        assert_eq!(default_opts.tolerance, 1e-6);
        assert_eq!(default_opts.max_iterations, 1000);
        // The coherence gate must stay opt-in.
        assert_eq!(default_opts.coherence_threshold, 0.0);

        let fast_opts = SolverOptions::fast();
        assert_eq!(fast_opts.tolerance, 1e-3);
        assert_eq!(fast_opts.max_iterations, 100);

        let precision_opts = SolverOptions::high_precision();
        assert_eq!(precision_opts.tolerance, 1e-12);
        assert!(precision_opts.compute_error_bounds);

        let streaming_opts = SolverOptions::streaming(25);
        assert_eq!(streaming_opts.streaming_interval, 25);
        assert!(streaming_opts.collect_stats);
    }

    /// Quality criteria require both convergence and a small enough residual.
    #[test]
    fn test_solver_result() {
        let result = SolverResult::success(vec![1.0, 2.0], 1e-8, 10);
        assert!(result.converged);
        assert!(result.meets_quality_criteria(1e-6));
        assert!(!result.meets_quality_criteria(1e-10));

        let failed = SolverResult::failure(vec![1.0], 1e-12, 5);
        assert!(!failed.converged);
        // Not converged means the criteria fail whatever the residual is.
        assert!(!failed.meets_quality_criteria(1.0));
    }

    /// Norm helpers agree on a simple vector with a negative entry.
    #[test]
    fn test_norm_calculations() {
        use utils::*;

        let v = [3.0, -4.0];
        assert_eq!(l1_norm(&v), 7.0);
        assert_eq!(l2_norm(&v), 5.0);
        assert_eq!(linf_norm(&v), 4.0);
        assert_eq!(compute_norm(&v, NormType::L1), 7.0);
        assert_eq!(compute_norm(&v, NormType::Weighted), 5.0);
    }
}