Skip to main content

diffsol/nonlinear_solver/
convergence.rs

1use log::trace;
2use num_traits::{FromPrimitive, One, Pow, ToPrimitive};
3
4use crate::{scalar::IndexType, Scalar, Vector};
5
6#[derive(Clone)]
7pub struct Convergence<'a, V: Vector> {
8    pub rtol: V::T,
9    pub atol: &'a V,
10    tol: V::T,
11    max_iter: IndexType,
12    niter: IndexType,
13    old_norm: Option<V::T>,
14    eta: V::T,
15}
16
/// Outcome of checking one iteration of the nonlinear solver.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvergenceStatus {
    /// The iteration satisfies the convergence test.
    Converged,
    /// The iteration is diverging, or is not expected to converge within the
    /// iteration limit.
    Diverged,
    /// Neither converged nor diverged yet; another iteration is needed.
    Continue,
}

23impl<'a, V: Vector> Convergence<'a, V> {
24    pub fn max_iter(&self) -> IndexType {
25        self.max_iter
26    }
27    pub fn set_max_iter(&mut self, value: IndexType) {
28        self.max_iter = value;
29    }
30    pub fn niter(&self) -> IndexType {
31        self.niter
32    }
33    pub fn eta(&self) -> V::T {
34        self.eta
35    }
36    pub fn reset_eta(&mut self) {
37        self.eta = V::T::from_f64(20.0.pow(1.25)).unwrap();
38    }
39
40    pub fn reset_eta_timestep_change(&mut self) {
41        self.eta = V::T::from_f64(100.0.pow(1.25)).unwrap();
42    }
43
44    pub fn new(rtol: V::T, atol: &'a V) -> Self {
45        Self::with_tolerance(rtol, atol, V::T::from_f64(0.2).unwrap())
46    }
47
48    pub fn with_tolerance(rtol: V::T, atol: &'a V, tol: V::T) -> Self {
49        Self {
50            rtol,
51            atol,
52            tol,
53            max_iter: 10,
54            old_norm: None,
55            eta: V::T::from_f64(20.0.pow(1.25)).unwrap(),
56            niter: 0,
57        }
58    }
59    pub fn reset(&mut self) {
60        self.niter = 0;
61        self.old_norm = None;
62    }
63
64    pub fn norm(&self, dy: &V, y: &V) -> V::T {
65        dy.squared_norm(y, self.atol, self.rtol).sqrt()
66    }
67
68    pub fn check_norm(&mut self, norm: V::T) -> ConvergenceStatus {
69        trace!(
70            "  Iteration {}, check non-linear solver norm = {:.3e}",
71            self.niter + 1,
72            norm.to_f64().unwrap()
73        );
74        self.niter += 1;
75        if let Some(old_norm) = self.old_norm {
76            let rate =
77                (norm / old_norm).pow(V::T::one() / (V::T::from_usize(self.niter - 1).unwrap()));
78
79            // check if iteration is diverging
80            if rate > V::T::from_f64(0.9).unwrap() {
81                trace!("  Diverged with rate {}", rate);
82                return ConvergenceStatus::Diverged;
83            }
84
85            // if iteration is not going to converge in max_iter
86            // (assuming the current rate), then abort
87            if rate.pow(i32::try_from(self.max_iter - self.niter).unwrap()) / (V::T::one() - rate)
88                * norm
89                > self.tol
90            {
91                trace!(
92                    "  Diverged as will not converge in max iterations with rate {}",
93                    rate
94                );
95                return ConvergenceStatus::Diverged;
96            }
97
98            let eta = rate / (V::T::one() - rate);
99            trace!(
100                "  Updated mean convergence rate = {:.3e}, eta = {:.3e}",
101                rate.to_f64().unwrap(),
102                eta.to_f64().unwrap()
103            );
104            self.eta = eta;
105        } else {
106            let min_eta = V::T::from_f64(1e4).unwrap() * V::T::EPSILON;
107            if self.eta < min_eta {
108                self.eta = min_eta;
109            }
110            self.eta = self.eta.pow(V::T::from_f64(0.8).unwrap());
111            trace!(
112                "  First iteration, set eta = {:.3e}",
113                self.eta.to_f64().unwrap()
114            );
115        };
116        // check if iteration is converged
117        if self.eta * norm < self.tol {
118            trace!(
119                "  Converged with eta * norm = {:.3e} < tol = {:.3e}",
120                (self.eta * norm).to_f64().unwrap(),
121                self.tol.to_f64().unwrap()
122            );
123            return ConvergenceStatus::Converged;
124        }
125        trace!(
126            "  Not yet converged: eta * norm = {:.3e} >= tol = {:.3e}",
127            (self.eta * norm).to_f64().unwrap(),
128            self.tol.to_f64().unwrap()
129        );
130        ConvergenceStatus::Continue
131    }
132
133    pub fn check_new_iteration(&mut self, norm: V::T) -> ConvergenceStatus {
134        let status = self.check_norm(norm);
135        if self.niter == 1 {
136            self.old_norm = Some(norm);
137        }
138        status
139    }
140}