#[allow(unused_imports)]
use crate::algebra::prelude::*;
/// Stopping criteria for an iterative solver.
///
/// The tolerances are interpreted by [`Convergence::check`], which compares a
/// residual norm `rnorm` against them, scaling the relative and divergence
/// tolerances by a reference norm `bnorm`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Convergence {
    /// Relative tolerance: converged once `rnorm <= rtol * bnorm`.
    pub rtol: R,
    /// Absolute tolerance: converged once `rnorm <= atol`.
    pub atol: R,
    /// Divergence tolerance: `rnorm` is compared against `dtol * bnorm`.
    pub dtol: R,
    /// Iteration budget: `check` reports `DivergedMaxIts` once `iters >= max_iters`.
    pub max_iters: usize,
}
/// Outcome of a convergence test.
///
/// [`Convergence::check`] produces `ConvergedAtol`, `ConvergedRtol`,
/// `DivergedDtol`, `DivergedMaxIts` and `Continued`; the remaining variants are
/// not produced anywhere in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvergedReason {
    /// Residual norm fell to or below `rtol` times the reference norm.
    ConvergedRtol,
    /// Residual norm fell to or below the absolute tolerance `atol`.
    ConvergedAtol,
    /// Presumably reported by trust-region style solvers; not set in this module — confirm.
    ConvergedTrustRegion,
    /// Presumably a Krylov "happy breakdown"; not set in this module — confirm.
    ConvergedHappyBreakdown,
    /// Residual norm reached the divergence threshold derived from `dtol`.
    DivergedDtol,
    /// Iteration count reached `max_iters` without any other criterion firing.
    DivergedMaxIts,
    /// Presumably set when a user monitor requests termination; not set in this module — confirm.
    StoppedByMonitor,
    /// No stopping criterion fired; the solver should keep iterating.
    Continued,
}
/// Tallies of solver events, attached to [`SolveStats`] through `with_counters`.
///
/// Both counters start at zero via `Default`.
#[derive(Debug, Default, Clone)]
pub struct SolverCounters {
    /// Presumably the number of global reductions performed — confirm with the solver that fills it.
    pub num_global_reductions: usize,
    /// Presumably the number of residual replacements performed — confirm with the solver that fills it.
    pub residual_replacements: usize,
}
/// Per-solve performance metrics, available when the `metrics` feature is enabled.
///
/// Every field starts at zero via `Default`.
#[cfg(feature = "metrics")]
#[derive(Debug, Default, Clone)]
pub struct SolveMetrics {
    /// Presumably the count of timed reductions — confirm with the recorder.
    pub reductions: usize,
    /// Reduction wait time (nanoseconds, going by the name) — confirm.
    pub reduction_wait_nanos: u64,
    /// Matvec time (nanoseconds, going by the name) — confirm.
    pub matvec_nanos: u64,
    /// Preconditioner-apply time (nanoseconds, going by the name) — confirm.
    pub pc_apply_nanos: u64,
    /// Presumably the total bytes moved by reductions — confirm.
    pub bytes_reduced: usize,
}

/// Zero-sized stand-in used when the `metrics` feature is disabled, so that
/// [`SolveStats`] keeps the same field without storing any metrics.
#[cfg(not(feature = "metrics"))]
#[derive(Debug, Default, Clone)]
pub struct SolveMetrics;
/// Summary of a convergence check or a finished solve.
///
/// The type parameter `R` is the residual scalar type. It is independent of,
/// and shadows, the non-generic `R` used by [`Convergence`].
#[derive(Clone, Debug)]
#[must_use]
pub struct SolveStats<R> {
    /// Iteration count recorded when these stats were built.
    pub iterations: usize,
    /// Residual norm recorded when these stats were built.
    pub final_residual: R,
    /// Why the solver stopped, or `Continued` if it should keep going.
    pub reason: ConvergedReason,
    /// Event tallies; zeroed by `SolveStats::new`, replaced by `with_counters`.
    pub counters: SolverCounters,
    /// Count of "complex drift" events; not populated in this module — semantics TODO confirm.
    pub complex_drift_events: usize,
    /// Six per-category drift counts; what each slot means is not visible here — TODO document.
    pub complex_drift_counts: [usize; 6],
    /// Presumably the largest relative drift observed — confirm with the solver that sets it.
    pub complex_drift_max_rel: R,
    /// Performance metrics (zero-sized unless the `metrics` feature is enabled).
    pub metrics: SolveMetrics,
}
impl<R: Default> SolveStats<R> {
    /// Builds stats from the three values every solver reports.
    ///
    /// The counters, the metrics and all complex-drift diagnostics start at their
    /// defaults: zero counts and `R::default()` for the drift maximum.
    pub fn new(iterations: usize, final_residual: R, reason: ConvergedReason) -> Self {
        let counters = SolverCounters::default();
        let metrics = SolveMetrics::default();
        Self {
            iterations,
            final_residual,
            reason,
            counters,
            complex_drift_events: 0,
            complex_drift_counts: [0usize; 6],
            complex_drift_max_rel: Default::default(),
            metrics,
        }
    }

    /// Replaces the counters and returns the updated stats (builder style).
    pub fn with_counters(self, counters: SolverCounters) -> Self {
        Self { counters, ..self }
    }
}
impl Convergence {
    /// Creates a set of stopping criteria.
    ///
    /// * `rtol` – relative tolerance, applied as `rnorm <= rtol * bnorm`.
    /// * `atol` – absolute tolerance, applied as `rnorm <= atol`.
    /// * `dtol` – divergence tolerance, applied as `rnorm >= dtol * bnorm`.
    /// * `max_iters` – iteration budget, exhausted once `iters >= max_iters`.
    pub fn new(rtol: R, atol: R, dtol: R, max_iters: usize) -> Self {
        Self {
            rtol,
            atol,
            dtol,
            max_iters,
        }
    }

    /// Tests a residual norm against the stopping criteria.
    ///
    /// `rnorm` is the current residual norm, `bnorm` is the reference norm that
    /// scales the relative (`rtol`) and divergence (`dtol`) tolerances, and
    /// `iters` is the number of iterations performed so far.
    ///
    /// The criteria are tried in this order, and the first match wins:
    ///
    /// 1. `rnorm <= atol` → [`ConvergedReason::ConvergedAtol`]
    /// 2. `rnorm <= rtol * bnorm` → [`ConvergedReason::ConvergedRtol`]
    /// 3. `rnorm` is NaN or infinite, or `bnorm > 0` and `rnorm >= dtol * bnorm`
    ///    → [`ConvergedReason::DivergedDtol`]
    /// 4. `iters >= max_iters` → [`ConvergedReason::DivergedMaxIts`]
    /// 5. otherwise → [`ConvergedReason::Continued`]
    ///
    /// The reason is returned both directly and inside the [`SolveStats`]. The
    /// stats also record `iters` and `rnorm`; all their other fields are defaults.
    pub fn check(&self, rnorm: R, bnorm: R, iters: usize) -> (ConvergedReason, SolveStats<R>) {
        let reason = if rnorm <= self.atol {
            ConvergedReason::ConvergedAtol
        } else if rnorm <= self.rtol * bnorm {
            ConvergedReason::ConvergedRtol
        } else if !rnorm.is_finite() || (bnorm > 0.0 && rnorm >= self.dtol * bnorm) {
            // A NaN residual fails every comparison, so without the finiteness
            // test it would silently run on until `max_iters`. A zero reference
            // norm makes `dtol * bnorm == 0`, which would flag any residual above
            // `atol` as diverged on the first check, so the relative divergence
            // test applies only when `bnorm > 0`.
            ConvergedReason::DivergedDtol
        } else if iters >= self.max_iters {
            ConvergedReason::DivergedMaxIts
        } else {
            ConvergedReason::Continued
        };
        (reason, SolveStats::new(iters, rnorm, reason))
    }
}
impl Convergence {
    /// Legacy boolean form of [`Convergence::check`].
    ///
    /// Delegates to `check(res_norm, res0_norm, i)` and returns
    /// `(should_stop, stats)`. `should_stop` is `true` for every reason other
    /// than [`ConvergedReason::Continued`], so it is `true` on divergence as well
    /// as on convergence. Inspect `stats.reason` to tell the two apart.
    #[deprecated(since = "0.1.0", note = "use check() method instead")]
    pub fn check_legacy(&self, res_norm: R, res0_norm: R, i: usize) -> (bool, SolveStats<R>) {
        let (reason, stats) = self.check(res_norm, res0_norm, i);
        (reason != ConvergedReason::Continued, stats)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `conv.check` and asserts the expected reason in both the tuple and
    /// the stats, and that `iters` and `rnorm` are recorded unchanged.
    fn assert_check(
        conv: &Convergence,
        rnorm: R,
        bnorm: R,
        iters: usize,
        expected: ConvergedReason,
    ) {
        let (reason, stats) = conv.check(rnorm, bnorm, iters);
        assert_eq!(reason, expected);
        assert_eq!(stats.reason, expected);
        assert_eq!(stats.iterations, iters);
        assert_eq!(stats.final_residual, rnorm);
    }

    #[test]
    fn test_convergence_new() {
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 1000);
        assert_eq!(conv.rtol, 1e-6);
        assert_eq!(conv.atol, 1e-12);
        assert_eq!(conv.dtol, 1e3);
        assert_eq!(conv.max_iters, 1000);
    }

    #[test]
    fn test_converged_absolute_tolerance() {
        let conv = Convergence::new(1e-6, 1e-8, 1e3, 100);
        assert_check(&conv, 1e-9, 1.0, 5, ConvergedReason::ConvergedAtol);
    }

    #[test]
    fn test_converged_relative_tolerance() {
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 100);
        assert_check(&conv, 1e-7, 1.0, 10, ConvergedReason::ConvergedRtol);
    }

    #[test]
    fn test_diverged_tolerance() {
        let conv = Convergence::new(1e-6, 1e-12, 2.0, 100);
        assert_check(&conv, 3.0, 1.0, 5, ConvergedReason::DivergedDtol);
    }

    #[test]
    fn test_diverged_max_iterations() {
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 10);
        assert_check(&conv, 1e-3, 1.0, 10, ConvergedReason::DivergedMaxIts);
    }

    #[test]
    fn test_continued() {
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 100);
        assert_check(&conv, 1e-3, 1.0, 5, ConvergedReason::Continued);
    }

    #[test]
    fn test_convergence_precedence() {
        // 1e-9 satisfies both atol (1e-8) and rtol * bnorm (1e-6); atol is tested first.
        let conv = Convergence::new(1e-6, 1e-8, 1e3, 100);
        let (reason, _) = conv.check(1e-9, 1.0, 5);
        assert_eq!(reason, ConvergedReason::ConvergedAtol);
    }

    #[test]
    fn test_convergence_takes_precedence_over_max_iterations() {
        // Reaching the iteration limit with a converged residual is still convergence.
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 10);
        assert_check(&conv, 1e-8, 1.0, 10, ConvergedReason::ConvergedRtol);
    }

    #[test]
    fn test_converged_reason_equality() {
        assert_eq!(ConvergedReason::ConvergedRtol, ConvergedReason::ConvergedRtol);
        assert_eq!(ConvergedReason::ConvergedAtol, ConvergedReason::ConvergedAtol);
        assert_eq!(ConvergedReason::DivergedDtol, ConvergedReason::DivergedDtol);
        assert_eq!(ConvergedReason::DivergedMaxIts, ConvergedReason::DivergedMaxIts);
        assert_eq!(ConvergedReason::Continued, ConvergedReason::Continued);
        assert_ne!(ConvergedReason::ConvergedRtol, ConvergedReason::ConvergedAtol);
        assert_ne!(ConvergedReason::DivergedDtol, ConvergedReason::DivergedMaxIts);
    }

    #[test]
    fn test_converged_reason_debug() {
        let debug_str = format!("{:?}", ConvergedReason::ConvergedRtol);
        assert!(debug_str.contains("ConvergedRtol"));
    }

    #[test]
    fn test_solve_stats_new_defaults() {
        let stats = SolveStats::new(1, 2.0, ConvergedReason::Continued);
        assert_eq!(stats.counters.num_global_reductions, 0);
        assert_eq!(stats.counters.residual_replacements, 0);
        assert_eq!(stats.complex_drift_events, 0);
        assert_eq!(stats.complex_drift_counts, [0; 6]);
        assert_eq!(stats.complex_drift_max_rel, 0.0);
    }

    #[test]
    fn test_solve_stats_with_counters() {
        let counters = SolverCounters {
            num_global_reductions: 4,
            residual_replacements: 2,
        };
        let stats = SolveStats::new(3, 0.5, ConvergedReason::Continued).with_counters(counters);
        assert_eq!(stats.iterations, 3);
        assert_eq!(stats.final_residual, 0.5);
        assert_eq!(stats.counters.num_global_reductions, 4);
        assert_eq!(stats.counters.residual_replacements, 2);
    }

    #[test]
    fn test_solve_stats_clone() {
        let stats = SolveStats::new(42, 1e-8, ConvergedReason::ConvergedRtol);
        let cloned = stats.clone();
        assert_eq!(cloned.iterations, 42);
        assert_eq!(cloned.final_residual, 1e-8);
        assert_eq!(cloned.reason, ConvergedReason::ConvergedRtol);
    }

    #[test]
    fn test_solve_stats_debug() {
        let stats = SolveStats::new(10, 1e-6, ConvergedReason::ConvergedAtol);
        let debug_str = format!("{:?}", stats);
        assert!(debug_str.contains("10"));
        assert!(debug_str.contains("ConvergedAtol"));
    }

    #[test]
    #[allow(deprecated)]
    fn test_legacy_check_convergence() {
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 100);
        let (should_stop, stats) = conv.check_legacy(1e-8, 1.0, 5);
        assert!(should_stop);
        assert_eq!(stats.iterations, 5);
        assert_eq!(stats.final_residual, 1e-8);
    }

    #[test]
    #[allow(deprecated)]
    fn test_legacy_check_continue() {
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 100);
        let (should_stop, stats) = conv.check_legacy(1e-3, 1.0, 5);
        assert!(!should_stop);
        assert_eq!(stats.iterations, 5);
        assert_eq!(stats.final_residual, 1e-3);
        assert_eq!(stats.reason, ConvergedReason::Continued);
    }

    #[test]
    #[allow(deprecated)]
    fn test_legacy_check_stops_on_divergence() {
        // The legacy flag means "stop", so it is also true when the solve fails.
        let conv = Convergence::new(1e-6, 1e-12, 1e3, 10);
        let (should_stop, stats) = conv.check_legacy(1e-3, 1.0, 10);
        assert!(should_stop);
        assert_eq!(stats.reason, ConvergedReason::DivergedMaxIts);
    }

    #[test]
    fn test_rtol_across_tolerance_scales() {
        let loose = Convergence::new(1e-6f64, 1e-12f64, 1e3f64, 100);
        let (reason, _) = loose.check(1e-8f64, 1.0f64, 5);
        assert_eq!(reason, ConvergedReason::ConvergedRtol);
        let tight = Convergence::new(1e-8, 1e-16, 1e6, 50);
        let (reason, _) = tight.check(1e-10, 1.0, 10);
        assert_eq!(reason, ConvergedReason::ConvergedRtol);
    }
}