diffsol/nonlinear_solver/
convergence.rs1use log::trace;
2use num_traits::{FromPrimitive, One, Pow, ToPrimitive};
3
4use crate::{scalar::IndexType, Scalar, Vector};
5
6#[derive(Clone)]
7pub struct Convergence<'a, V: Vector> {
8 pub rtol: V::T,
9 pub atol: &'a V,
10 tol: V::T,
11 max_iter: IndexType,
12 niter: IndexType,
13 old_norm: Option<V::T>,
14 eta: V::T,
15}
16
/// Outcome of a single convergence check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvergenceStatus {
    /// `eta * norm` fell below the tolerance.
    Converged,
    /// The estimated rate is too slow (or growing) to converge within the iteration budget.
    Diverged,
    /// Neither converged nor diverged yet; another iteration is required.
    Continue,
}
22
23impl<'a, V: Vector> Convergence<'a, V> {
24 pub fn max_iter(&self) -> IndexType {
25 self.max_iter
26 }
27 pub fn set_max_iter(&mut self, value: IndexType) {
28 self.max_iter = value;
29 }
30 pub fn niter(&self) -> IndexType {
31 self.niter
32 }
33 pub fn eta(&self) -> V::T {
34 self.eta
35 }
36 pub fn reset_eta(&mut self) {
37 self.eta = V::T::from_f64(20.0.pow(1.25)).unwrap();
38 }
39
40 pub fn reset_eta_timestep_change(&mut self) {
41 self.eta = V::T::from_f64(100.0.pow(1.25)).unwrap();
42 }
43
44 pub fn new(rtol: V::T, atol: &'a V) -> Self {
45 Self::with_tolerance(rtol, atol, V::T::from_f64(0.2).unwrap())
46 }
47
48 pub fn with_tolerance(rtol: V::T, atol: &'a V, tol: V::T) -> Self {
49 Self {
50 rtol,
51 atol,
52 tol,
53 max_iter: 10,
54 old_norm: None,
55 eta: V::T::from_f64(20.0.pow(1.25)).unwrap(),
56 niter: 0,
57 }
58 }
59 pub fn reset(&mut self) {
60 self.niter = 0;
61 self.old_norm = None;
62 }
63
64 pub fn norm(&self, dy: &V, y: &V) -> V::T {
65 dy.squared_norm(y, self.atol, self.rtol).sqrt()
66 }
67
68 pub fn check_norm(&mut self, norm: V::T) -> ConvergenceStatus {
69 trace!(
70 " Iteration {}, check non-linear solver norm = {:.3e}",
71 self.niter + 1,
72 norm.to_f64().unwrap()
73 );
74 self.niter += 1;
75 if let Some(old_norm) = self.old_norm {
76 let rate =
77 (norm / old_norm).pow(V::T::one() / (V::T::from_usize(self.niter - 1).unwrap()));
78
79 if rate > V::T::from_f64(0.9).unwrap() {
81 trace!(" Diverged with rate {}", rate);
82 return ConvergenceStatus::Diverged;
83 }
84
85 if rate.pow(i32::try_from(self.max_iter - self.niter).unwrap()) / (V::T::one() - rate)
88 * norm
89 > self.tol
90 {
91 trace!(
92 " Diverged as will not converge in max iterations with rate {}",
93 rate
94 );
95 return ConvergenceStatus::Diverged;
96 }
97
98 let eta = rate / (V::T::one() - rate);
99 trace!(
100 " Updated mean convergence rate = {:.3e}, eta = {:.3e}",
101 rate.to_f64().unwrap(),
102 eta.to_f64().unwrap()
103 );
104 self.eta = eta;
105 } else {
106 let min_eta = V::T::from_f64(1e4).unwrap() * V::T::EPSILON;
107 if self.eta < min_eta {
108 self.eta = min_eta;
109 }
110 self.eta = self.eta.pow(V::T::from_f64(0.8).unwrap());
111 trace!(
112 " First iteration, set eta = {:.3e}",
113 self.eta.to_f64().unwrap()
114 );
115 };
116 if self.eta * norm < self.tol {
118 trace!(
119 " Converged with eta * norm = {:.3e} < tol = {:.3e}",
120 (self.eta * norm).to_f64().unwrap(),
121 self.tol.to_f64().unwrap()
122 );
123 return ConvergenceStatus::Converged;
124 }
125 trace!(
126 " Not yet converged: eta * norm = {:.3e} >= tol = {:.3e}",
127 (self.eta * norm).to_f64().unwrap(),
128 self.tol.to_f64().unwrap()
129 );
130 ConvergenceStatus::Continue
131 }
132
133 pub fn check_new_iteration(&mut self, norm: V::T) -> ConvergenceStatus {
134 let status = self.check_norm(norm);
135 if self.niter == 1 {
136 self.old_norm = Some(norm);
137 }
138 status
139 }
140}