1use crate::matrix::Matrix;
7use crate::types::{
8 Precision, ConvergenceMode, NormType, ErrorBounds, SolverStats,
9 DimensionType, MemoryInfo, ProfileData
10};
11use crate::error::{SolverError, Result};
12use alloc::{vec::Vec, string::String, boxed::Box};
13
14pub mod neumann;
15
16pub use neumann::NeumannSolver;
18
/// Configuration shared by every [`SolverAlgorithm`].
///
/// Construct with [`Default::default`] or one of the presets
/// (`SolverOptions::high_precision`, `SolverOptions::fast`, `SolverOptions::streaming`).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverOptions {
    /// Convergence tolerance; reported back in `SolverError::ConvergenceFailure`.
    /// How it is compared (absolute vs. relative) depends on `convergence_mode`
    /// (see `utils::check_convergence`).
    pub tolerance: Precision,
    /// Upper bound on counted iterations in `SolverAlgorithm::solve`.
    pub max_iterations: usize,
    /// Convergence criterion.
    // NOTE(review): presumably consumed by algorithms via `utils::check_convergence` — confirm.
    pub convergence_mode: ConvergenceMode,
    /// Norm used for residual/solution measurements.
    // NOTE(review): not read in this file; presumably passed to `utils::compute_norm`.
    pub norm_type: NormType,
    /// When set, `solve` attaches `SolverStats` to the result (std builds only).
    pub collect_stats: bool,
    /// Streaming interval; `0` in all presets except `streaming(interval)`.
    // NOTE(review): presumably the iteration period for emitting `PartialSolution`s — confirm.
    pub streaming_interval: usize,
    /// Optional starting vector.
    // NOTE(review): not read in this file; presumably used by `initialize` — confirm.
    pub initial_guess: Option<Vec<Precision>>,
    /// When set, `solve` copies `state.error_bounds()` into the result.
    pub compute_error_bounds: bool,
    /// Tolerance associated with error-bound estimation (not read in this file).
    pub error_bounds_tolerance: Precision,
    /// Profiling toggle (not read in this file).
    pub enable_profiling: bool,
    /// Optional RNG seed for randomized algorithms (not read in this file).
    pub random_seed: Option<u64>,
    /// Threshold defaulting to `0.0` in all presets (not read in this file).
    // NOTE(review): semantics defined by the algorithms that consume it — confirm.
    pub coherence_threshold: Precision,
}
57
58impl Default for SolverOptions {
59 fn default() -> Self {
60 Self {
61 tolerance: 1e-6,
62 max_iterations: 1000,
63 convergence_mode: ConvergenceMode::ResidualNorm,
64 norm_type: NormType::L2,
65 collect_stats: false,
66 streaming_interval: 0,
67 initial_guess: None,
68 compute_error_bounds: false,
69 error_bounds_tolerance: 1e-8,
70 enable_profiling: false,
71 random_seed: None,
72 coherence_threshold: 0.0,
74 }
75 }
76}
77
78impl SolverOptions {
79 pub fn high_precision() -> Self {
81 Self {
82 tolerance: 1e-12,
83 max_iterations: 5000,
84 convergence_mode: ConvergenceMode::Combined,
85 norm_type: NormType::L2,
86 collect_stats: true,
87 streaming_interval: 0,
88 initial_guess: None,
89 compute_error_bounds: true,
90 error_bounds_tolerance: 1e-14,
91 enable_profiling: false,
92 random_seed: None,
93 coherence_threshold: 0.0,
94 }
95 }
96
97 pub fn fast() -> Self {
99 Self {
100 tolerance: 1e-3,
101 max_iterations: 100,
102 convergence_mode: ConvergenceMode::ResidualNorm,
103 norm_type: NormType::L2,
104 collect_stats: false,
105 streaming_interval: 0,
106 initial_guess: None,
107 compute_error_bounds: false,
108 error_bounds_tolerance: 1e-4,
109 enable_profiling: false,
110 random_seed: None,
111 coherence_threshold: 0.0,
112 }
113 }
114
115 pub fn streaming(interval: usize) -> Self {
117 Self {
118 tolerance: 1e-4,
119 max_iterations: 1000,
120 convergence_mode: ConvergenceMode::ResidualNorm,
121 norm_type: NormType::L2,
122 collect_stats: true,
123 streaming_interval: interval,
124 initial_guess: None,
125 compute_error_bounds: false,
126 error_bounds_tolerance: 1e-6,
127 enable_profiling: true,
128 random_seed: None,
129 coherence_threshold: 0.0,
130 }
131 }
132}
133
/// Outcome of a completed [`SolverAlgorithm::solve`] call.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverResult {
    /// Final solution estimate (empty for results built with `SolverResult::error`).
    pub solution: Vec<Precision>,
    /// Residual norm of `solution` as reported by the solver state
    /// (`INFINITY` for `SolverResult::error`).
    pub residual_norm: Precision,
    /// Number of iterations counted by `solve`.
    pub iterations: usize,
    /// Whether the solver's convergence check passed.
    pub converged: bool,
    /// Error bounds; filled by `solve` when `SolverOptions::compute_error_bounds` is set.
    pub error_bounds: Option<ErrorBounds>,
    /// Statistics; filled by `solve` when `SolverOptions::collect_stats` is set (std builds).
    pub stats: Option<SolverStats>,
    /// Memory usage report (not populated by code in this file).
    pub memory_info: Option<MemoryInfo>,
    /// Profiling samples (not populated by code in this file).
    pub profile_data: Option<Vec<ProfileData>>,
}
155
156impl SolverResult {
157 pub fn success(
159 solution: Vec<Precision>,
160 residual_norm: Precision,
161 iterations: usize,
162 ) -> Self {
163 Self {
164 solution,
165 residual_norm,
166 iterations,
167 converged: true,
168 error_bounds: None,
169 stats: None,
170 memory_info: None,
171 profile_data: None,
172 }
173 }
174
175 pub fn failure(
177 solution: Vec<Precision>,
178 residual_norm: Precision,
179 iterations: usize,
180 ) -> Self {
181 Self {
182 solution,
183 residual_norm,
184 iterations,
185 converged: false,
186 error_bounds: None,
187 stats: None,
188 memory_info: None,
189 profile_data: None,
190 }
191 }
192
193 pub fn error(error: SolverError) -> Self {
195 Self {
196 solution: Vec::new(),
197 residual_norm: Precision::INFINITY,
198 iterations: 0,
199 converged: false,
200 error_bounds: None,
201 stats: None,
202 memory_info: None,
203 profile_data: None,
204 }
205 }
206
207 pub fn meets_quality_criteria(&self, tolerance: Precision) -> bool {
209 self.converged && self.residual_norm <= tolerance
210 }
211}
212
/// Snapshot of an in-progress solve.
// NOTE(review): not constructed in this file; presumably emitted when
// `SolverOptions::streaming_interval` is nonzero — confirm against the streaming code.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PartialSolution {
    /// Iteration at which the snapshot was taken.
    pub iteration: usize,
    /// Solution estimate at that iteration.
    pub solution: Vec<Precision>,
    /// Residual norm at that iteration.
    pub residual_norm: Precision,
    /// Whether the solver considered itself converged at that point.
    pub converged: bool,
    /// Estimated iterations left, if the algorithm can estimate it.
    pub estimated_remaining: Option<usize>,
    /// Capture time (std builds).
    #[cfg(feature = "std")]
    // `Instant` is not serializable: serde skips it and restores `Instant::now()` on deserialize.
    #[cfg_attr(feature = "serde", serde(skip, default = "std::time::Instant::now"))]
    pub timestamp: std::time::Instant,
    /// Capture time (no_std builds).
    // NOTE(review): unit and epoch of this counter are not defined here — confirm.
    #[cfg(not(feature = "std"))]
    pub timestamp: u64,
}
234
235pub trait SolverAlgorithm: Send + Sync {
240 type State: SolverState;
242
243 fn initialize(
245 &self,
246 matrix: &dyn Matrix,
247 b: &[Precision],
248 options: &SolverOptions,
249 ) -> Result<Self::State>;
250
251 fn step(&self, state: &mut Self::State) -> Result<StepResult>;
253
254 fn is_converged(&self, state: &Self::State) -> bool;
256
257 fn extract_solution(&self, state: &Self::State) -> Vec<Precision>;
259
260 fn update_rhs(&self, state: &mut Self::State, delta_b: &[(usize, Precision)]) -> Result<()>;
262
263 fn algorithm_name(&self) -> &'static str;
265
266 fn solve(
271 &self,
272 matrix: &dyn Matrix,
273 b: &[Precision],
274 options: &SolverOptions,
275 ) -> Result<SolverResult> {
276 let mut state = self.initialize(matrix, b, options)?;
277 let mut iterations = 0;
278
279 #[cfg(feature = "std")]
280 let start_time = std::time::Instant::now();
281
282 while !self.is_converged(&state) && iterations < options.max_iterations {
283 match self.step(&mut state)? {
284 StepResult::Continue => {
285 iterations += 1;
286
287 let residual = state.residual_norm();
289 if !residual.is_finite() {
290 return Err(SolverError::NumericalInstability {
291 reason: "Non-finite residual norm".to_string(),
292 iteration: iterations,
293 residual_norm: residual,
294 });
295 }
296 },
297 StepResult::Converged => break,
298 StepResult::Failed(reason) => {
299 return Err(SolverError::AlgorithmError {
300 algorithm: self.algorithm_name().to_string(),
301 message: reason,
302 context: vec![
303 ("iteration".to_string(), iterations.to_string()),
304 ("residual_norm".to_string(), state.residual_norm().to_string()),
305 ],
306 });
307 }
308 }
309 }
310
311 let converged = self.is_converged(&state);
312 let solution = self.extract_solution(&state);
313 let residual_norm = state.residual_norm();
314
315 if !converged && iterations >= options.max_iterations {
317 return Err(SolverError::ConvergenceFailure {
318 iterations,
319 residual_norm,
320 tolerance: options.tolerance,
321 algorithm: self.algorithm_name().to_string(),
322 });
323 }
324
325 let mut result = if converged {
326 SolverResult::success(solution, residual_norm, iterations)
327 } else {
328 SolverResult::failure(solution, residual_norm, iterations)
329 };
330
331 if options.collect_stats {
333 #[cfg(feature = "std")]
334 {
335 let total_time = start_time.elapsed().as_millis() as f64;
336 let mut stats = SolverStats::new();
337 stats.total_time_ms = total_time;
338 stats.matvec_count = state.matvec_count();
339 result.stats = Some(stats);
340 }
341 }
342
343 if options.compute_error_bounds {
344 result.error_bounds = state.error_bounds();
345 }
346
347 Ok(result)
348 }
349}
350
/// Per-solve state owned by a [`SolverAlgorithm`] (its associated `State` type).
pub trait SolverState: Send + Sync {
    /// Current residual norm. `SolverAlgorithm::solve` checks it for finiteness after each
    /// `Continue` step and reports it in the returned result.
    fn residual_norm(&self) -> Precision;

    /// Count copied into `SolverStats::matvec_count` when stats are collected (std builds).
    // NOTE(review): presumably the cumulative number of matrix-vector products — confirm.
    fn matvec_count(&self) -> usize;

    /// Error bounds, if available; read by `solve` when `compute_error_bounds` is set.
    fn error_bounds(&self) -> Option<ErrorBounds>;

    /// Memory usage report for this state (not read by `solve` in this file).
    fn memory_usage(&self) -> MemoryInfo;

    /// Resets the state.
    // NOTE(review): exact reset semantics are implementation-defined — confirm per solver.
    fn reset(&mut self);
}
368
/// Outcome of a single [`SolverAlgorithm::step`] call.
#[derive(Debug, Clone, PartialEq)]
pub enum StepResult {
    /// Keep iterating; `solve` counts this step as an iteration.
    Continue,
    /// The algorithm reports convergence; `solve` stops looping (this step is not counted).
    Converged,
    /// The step failed; `solve` wraps the message in `SolverError::AlgorithmError`.
    Failed(String),
}
379
380pub mod utils {
382 use super::*;
383
384 pub fn l2_norm(v: &[Precision]) -> Precision {
386 v.iter().map(|x| x * x).sum::<Precision>().sqrt()
387 }
388
389 pub fn l1_norm(v: &[Precision]) -> Precision {
391 v.iter().map(|x| x.abs()).sum()
392 }
393
394 pub fn linf_norm(v: &[Precision]) -> Precision {
396 v.iter().map(|x| x.abs()).fold(0.0, Precision::max)
397 }
398
399 pub fn compute_norm(v: &[Precision], norm_type: NormType) -> Precision {
401 match norm_type {
402 NormType::L1 => l1_norm(v),
403 NormType::L2 => l2_norm(v),
404 NormType::LInfinity => linf_norm(v),
405 NormType::Weighted => l2_norm(v), }
407 }
408
409 pub fn compute_residual(
411 matrix: &dyn Matrix,
412 x: &[Precision],
413 b: &[Precision],
414 residual: &mut [Precision],
415 ) -> Result<()> {
416 matrix.multiply_vector(x, residual)?;
417 for (r, &b_val) in residual.iter_mut().zip(b.iter()) {
418 *r -= b_val;
419 }
420 Ok(())
421 }
422
423 pub fn check_convergence(
425 residual_norm: Precision,
426 tolerance: Precision,
427 mode: ConvergenceMode,
428 b_norm: Precision,
429 prev_solution: Option<&[Precision]>,
430 current_solution: &[Precision],
431 ) -> bool {
432 match mode {
433 ConvergenceMode::ResidualNorm => residual_norm <= tolerance,
434 ConvergenceMode::RelativeResidual => {
435 if b_norm > 0.0 {
436 (residual_norm / b_norm) <= tolerance
437 } else {
438 residual_norm <= tolerance
439 }
440 },
441 ConvergenceMode::SolutionChange => {
442 if let Some(prev) = prev_solution {
443 let mut change_norm = 0.0;
444 for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
445 let diff = curr - prev_val;
446 change_norm += diff * diff;
447 }
448 change_norm.sqrt() <= tolerance
449 } else {
450 false
451 }
452 },
453 ConvergenceMode::RelativeSolutionChange => {
454 if let Some(prev) = prev_solution {
455 let mut change_norm = 0.0;
456 let mut solution_norm = 0.0;
457 for (&curr, &prev_val) in current_solution.iter().zip(prev.iter()) {
458 let diff = curr - prev_val;
459 change_norm += diff * diff;
460 solution_norm += prev_val * prev_val;
461 }
462 if solution_norm > 0.0 {
463 (change_norm.sqrt() / solution_norm.sqrt()) <= tolerance
464 } else {
465 change_norm.sqrt() <= tolerance
466 }
467 } else {
468 false
469 }
470 },
471 ConvergenceMode::Combined => {
472 residual_norm <= tolerance &&
474 (b_norm == 0.0 || (residual_norm / b_norm) <= tolerance)
475 },
476 }
477 }
478}
479
/// Placeholder for the forward-push algorithm; its `SolverAlgorithm` impl is not
/// implemented yet.
pub struct ForwardPushSolver;
/// Placeholder for the backward-push algorithm (stub; does not compute a solution yet).
pub struct BackwardPushSolver;
/// Placeholder for the hybrid algorithm (stub; does not compute a solution yet).
pub struct HybridSolver;
484
485impl SolverAlgorithm for ForwardPushSolver {
487 type State = ();
488
489 fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> {
490 Err(SolverError::AlgorithmError {
491 algorithm: "forward_push".to_string(),
492 message: "Not implemented yet".to_string(),
493 context: vec![],
494 })
495 }
496
497 fn step(&self, _state: &mut Self::State) -> Result<StepResult> {
498 Err(SolverError::AlgorithmError {
499 algorithm: "forward_push".to_string(),
500 message: "Not implemented yet".to_string(),
501 context: vec![],
502 })
503 }
504
505 fn is_converged(&self, _state: &Self::State) -> bool {
506 false
507 }
508
509 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> {
510 Vec::new()
511 }
512
513 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> {
514 Err(SolverError::AlgorithmError {
515 algorithm: "forward_push".to_string(),
516 message: "Not implemented yet".to_string(),
517 context: vec![],
518 })
519 }
520
521 fn algorithm_name(&self) -> &'static str {
522 "forward_push"
523 }
524}
525
526impl SolverState for () {
527 fn residual_norm(&self) -> Precision {
528 0.0
529 }
530
531 fn matvec_count(&self) -> usize {
532 0
533 }
534
535 fn error_bounds(&self) -> Option<ErrorBounds> {
536 None
537 }
538
539 fn memory_usage(&self) -> MemoryInfo {
540 MemoryInfo {
541 current_usage_bytes: 0,
542 peak_usage_bytes: 0,
543 matrix_memory_bytes: 0,
544 vector_memory_bytes: 0,
545 workspace_memory_bytes: 0,
546 allocation_count: 0,
547 deallocation_count: 0,
548 }
549 }
550
551 fn reset(&mut self) {}
552}
553
554impl SolverAlgorithm for BackwardPushSolver {
556 type State = ();
557 fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
558 fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
559 fn is_converged(&self, _state: &Self::State) -> bool { true }
560 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
561 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
562 fn algorithm_name(&self) -> &'static str { "backward_push" }
563}
564
565impl SolverAlgorithm for HybridSolver {
566 type State = ();
567 fn initialize(&self, _matrix: &dyn Matrix, _b: &[Precision], _options: &SolverOptions) -> Result<Self::State> { Ok(()) }
568 fn step(&self, _state: &mut Self::State) -> Result<StepResult> { Ok(StepResult::Converged) }
569 fn is_converged(&self, _state: &Self::State) -> bool { true }
570 fn extract_solution(&self, _state: &Self::State) -> Vec<Precision> { Vec::new() }
571 fn update_rhs(&self, _state: &mut Self::State, _delta_b: &[(usize, Precision)]) -> Result<()> { Ok(()) }
572 fn algorithm_name(&self) -> &'static str { "hybrid" }
573}
574
#[cfg(all(test, feature = "std"))]
mod tests {
    // The former `use crate::matrix::SparseMatrix;` was never used and only produced an
    // `unused_imports` warning.
    use super::*;

    #[test]
    fn test_solver_options() {
        let default_opts = SolverOptions::default();
        assert_eq!(default_opts.tolerance, 1e-6);
        assert_eq!(default_opts.max_iterations, 1000);

        let fast_opts = SolverOptions::fast();
        assert_eq!(fast_opts.tolerance, 1e-3);
        assert_eq!(fast_opts.max_iterations, 100);

        let precision_opts = SolverOptions::high_precision();
        assert_eq!(precision_opts.tolerance, 1e-12);
        assert!(precision_opts.compute_error_bounds);

        let streaming_opts = SolverOptions::streaming(25);
        assert_eq!(streaming_opts.streaming_interval, 25);
        assert!(streaming_opts.collect_stats);
        assert!(streaming_opts.enable_profiling);
    }

    #[test]
    fn test_solver_result() {
        let result = SolverResult::success(vec![1.0, 2.0], 1e-8, 10);
        assert!(result.converged);
        assert!(result.meets_quality_criteria(1e-6));
        assert!(!result.meets_quality_criteria(1e-10));
    }

    #[test]
    fn test_failure_and_error_results() {
        let failed = SolverResult::failure(vec![0.5], 1e-2, 1000);
        assert!(!failed.converged);
        // A non-converged result never meets quality criteria, however loose.
        assert!(!failed.meets_quality_criteria(1.0));

        let errored = SolverResult::error(SolverError::AlgorithmError {
            algorithm: "test".to_string(),
            message: "boom".to_string(),
            context: vec![],
        });
        assert!(errored.solution.is_empty());
        assert_eq!(errored.iterations, 0);
        assert!(errored.residual_norm.is_infinite());
        assert!(!errored.converged);
    }

    #[test]
    fn test_norm_calculations() {
        use utils::*;

        let v = vec![3.0, 4.0];
        assert_eq!(l1_norm(&v), 7.0);
        assert_eq!(l2_norm(&v), 5.0);
        assert_eq!(linf_norm(&v), 4.0);
        // The weighted norm currently falls back to L2.
        assert_eq!(compute_norm(&v, NormType::Weighted), 5.0);
    }

    #[test]
    fn test_check_convergence_modes() {
        use utils::check_convergence;

        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::ResidualNorm, 1.0, None, &[]));
        assert!(check_convergence(
            0.5, 1e-3, ConvergenceMode::RelativeResidual, 1000.0, None, &[]
        ));
        // Solution-change modes cannot converge without a previous iterate.
        assert!(!check_convergence(0.0, 1.0, ConvergenceMode::SolutionChange, 1.0, None, &[1.0]));
        assert!(check_convergence(
            0.0,
            0.05,
            ConvergenceMode::RelativeSolutionChange,
            1.0,
            Some(&[3.0, 4.0][..]),
            &[3.0, 4.1],
        ));
        assert!(check_convergence(1e-7, 1e-6, ConvergenceMode::Combined, 0.0, None, &[]));
        assert!(!check_convergence(1e-7, 1e-6, ConvergenceMode::Combined, 0.01, None, &[]));
    }
}