diffsol/ode_solver/
checkpointing.rs

1use std::cell::RefCell;
2
3use crate::{
4    error::DiffsolError, other_error, OdeEquations, OdeSolverMethod, OdeSolverProblem,
5    OdeSolverState, Scalar, Vector,
6};
7use num_traits::{abs, One};
8
9#[derive(Clone)]
10pub struct HermiteInterpolator<V>
11where
12    V: Vector,
13{
14    ys: Vec<V>,
15    ydots: Vec<V>,
16    ts: Vec<V::T>,
17}
18
19impl<V> Default for HermiteInterpolator<V>
20where
21    V: Vector,
22{
23    fn default() -> Self {
24        HermiteInterpolator {
25            ys: Vec::new(),
26            ydots: Vec::new(),
27            ts: Vec::new(),
28        }
29    }
30}
31
32impl<V> HermiteInterpolator<V>
33where
34    V: Vector,
35{
36    pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self {
37        HermiteInterpolator { ys, ydots, ts }
38    }
39    pub fn last_t(&self) -> Option<V::T> {
40        if self.ts.is_empty() {
41            return None;
42        }
43        Some(self.ts[self.ts.len() - 1])
44    }
45    pub fn last_h(&self) -> Option<V::T> {
46        if self.ts.len() < 2 {
47            return None;
48        }
49        Some(self.ts[self.ts.len() - 1] - self.ts[self.ts.len() - 2])
50    }
51    pub fn reset<'a, Eqn, Method, State>(
52        &mut self,
53        solver: &mut Method,
54        state0: &State,
55        state1: &State,
56    ) -> Result<(), DiffsolError>
57    where
58        Eqn: OdeEquations<V = V, T = V::T> + 'a,
59        Method: OdeSolverMethod<'a, Eqn, State = State>,
60        State: OdeSolverState<V>,
61    {
62        let state0_ref = state0.as_ref();
63        let state1_ref = state1.as_ref();
64        self.ys.clear();
65        self.ydots.clear();
66        self.ts.clear();
67        self.ys.push(state0_ref.y.clone());
68        self.ydots.push(state0_ref.dy.clone());
69        self.ts.push(state0_ref.t);
70
71        solver.set_state(state0.clone());
72        while solver.state().t < state1_ref.t {
73            solver.step()?;
74            self.ys.push(solver.state().y.clone());
75            self.ydots.push(solver.state().dy.clone());
76            self.ts.push(solver.state().t);
77        }
78        Ok(())
79    }
80
81    pub fn interpolate(&self, t: V::T, y: &mut V) -> Option<()> {
82        if t < self.ts[0] || t > self.ts[self.ts.len() - 1] {
83            return None;
84        }
85        if t == self.ts[0] {
86            y.copy_from(&self.ys[0]);
87            return Some(());
88        }
89        let idx = self
90            .ts
91            .iter()
92            .position(|&t0| t0 > t)
93            .unwrap_or(self.ts.len() - 1);
94        let t0 = self.ts[idx - 1];
95        let t1 = self.ts[idx];
96        let h = t1 - t0;
97        let theta = (t - t0) / h;
98        let u0 = &self.ys[idx - 1];
99        let u1 = &self.ys[idx];
100        let f0 = &self.ydots[idx - 1];
101        let f1 = &self.ydots[idx];
102
103        y.copy_from(u0);
104        y.axpy(V::T::one(), u1, -V::T::one());
105        y.axpy(
106            h * (theta - V::T::from(1.0)),
107            f0,
108            V::T::one() - V::T::from(2.0) * theta,
109        );
110        y.axpy(h * theta, f1, V::T::one());
111        y.axpy(
112            V::T::from(1.0) - theta,
113            u0,
114            theta * (theta - V::T::from(1.0)),
115        );
116        y.axpy(theta, u1, V::T::one());
117        Some(())
118    }
119}
120
121pub struct Checkpointing<'a, Eqn, Method>
122where
123    Method: OdeSolverMethod<'a, Eqn>,
124    Eqn: OdeEquations,
125{
126    checkpoints: Vec<Method::State>,
127    segment: RefCell<HermiteInterpolator<Eqn::V>>,
128    previous_segment: RefCell<Option<HermiteInterpolator<Eqn::V>>>,
129    solver: RefCell<Method>,
130}
131
132impl<'a, Eqn, Method> Clone for Checkpointing<'a, Eqn, Method>
133where
134    Method: OdeSolverMethod<'a, Eqn>,
135    Eqn: OdeEquations,
136{
137    fn clone(&self) -> Self {
138        Checkpointing {
139            checkpoints: self.checkpoints.clone(),
140            segment: RefCell::new(self.segment.borrow().clone()),
141            previous_segment: RefCell::new(self.previous_segment.borrow().clone()),
142            solver: RefCell::new(self.solver.borrow().clone()),
143        }
144    }
145}
146
147impl<'a, Eqn, Method> Checkpointing<'a, Eqn, Method>
148where
149    Method: OdeSolverMethod<'a, Eqn>,
150    Eqn: OdeEquations,
151{
152    pub fn new(
153        mut solver: Method,
154        start_idx: usize,
155        checkpoints: Vec<Method::State>,
156        segment: Option<HermiteInterpolator<Eqn::V>>,
157    ) -> Self {
158        if checkpoints.len() < 2 {
159            panic!("Checkpoints must have at least 2 elements");
160        }
161        if start_idx >= checkpoints.len() - 1 {
162            panic!("start_idx must be less than checkpoints.len() - 1");
163        }
164        let segment = segment.unwrap_or_else(|| {
165            let mut segment = HermiteInterpolator::default();
166            segment
167                .reset(
168                    &mut solver,
169                    &checkpoints[start_idx],
170                    &checkpoints[start_idx + 1],
171                )
172                .unwrap();
173            segment
174        });
175        let segment = RefCell::new(segment);
176        let previous_segment = RefCell::new(None);
177        let solver = RefCell::new(solver);
178        Checkpointing {
179            checkpoints,
180            segment,
181            previous_segment,
182            solver,
183        }
184    }
185
186    pub fn last_t(&self) -> Eqn::T {
187        self.segment
188            .borrow()
189            .last_t()
190            .expect("segment should not be empty")
191    }
192
193    pub fn last_h(&self) -> Option<Eqn::T> {
194        self.segment.borrow().last_h()
195    }
196
197    pub fn problem(&self) -> &'a OdeSolverProblem<Eqn> {
198        self.solver.borrow().problem()
199    }
200
201    pub fn interpolate(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> {
202        {
203            let segment = self.segment.borrow();
204            if segment.interpolate(t, y).is_some() {
205                return Ok(());
206            }
207        }
208
209        {
210            let previous_segment = self.previous_segment.borrow();
211            if let Some(previous_segment) = previous_segment.as_ref() {
212                if previous_segment.interpolate(t, y).is_some() {
213                    return Ok(());
214                }
215            }
216        }
217
218        // if t is before first segment or after last segment, return error
219        let h = self.last_h().unwrap_or(Eqn::T::one());
220        let troundoff = Eqn::T::from(100.0) * Eqn::T::EPSILON * (abs(t) + abs(h));
221        if t < self.checkpoints[0].as_ref().t - troundoff
222            || t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t + troundoff
223        {
224            return Err(other_error!("t is outside of the checkpoints"));
225        }
226
227        // snap t to nearest checkpoint if outside of range
228        let t = if t < self.checkpoints[0].as_ref().t {
229            self.checkpoints[0].as_ref().t
230        } else if t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t {
231            self.checkpoints[self.checkpoints.len() - 1].as_ref().t
232        } else {
233            t
234        };
235
236        // else find idx of segment
237        let idx = self
238            .checkpoints
239            .iter()
240            .skip(1)
241            .position(|state| state.as_ref().t > t)
242            .expect("t is not in checkpoints");
243        if self.previous_segment.borrow().is_none() {
244            self.previous_segment
245                .replace(Some(HermiteInterpolator::default()));
246        }
247        let mut solver = self.solver.borrow_mut();
248        let mut previous_segment = self.previous_segment.borrow_mut();
249        let mut segment = self.segment.borrow_mut();
250        previous_segment.as_mut().unwrap().reset(
251            &mut *solver,
252            &self.checkpoints[idx],
253            &self.checkpoints[idx + 1],
254        )?;
255        std::mem::swap(&mut *segment, previous_segment.as_mut().unwrap());
256        segment.interpolate(t, y).unwrap();
257        Ok(())
258    }
259}
260
#[cfg(test)]
mod tests {

    use super::{Checkpointing, HermiteInterpolator};
    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::test_models::robertson::robertson, Context, NalgebraLU, OdeEquations,
        OdeSolverMethod, Op, Vector,
    };

    /// Solves the Robertson test problem while taking a checkpoint every
    /// `STEPS_PER_CHECKPOINT` steps, then checks that the checkpointer
    /// reproduces the reference solution points when queried in reverse order.
    #[test]
    fn test_checkpointing() {
        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;
        const STEPS_PER_CHECKPOINT: usize = 30;

        let (problem, soln) = robertson::<M>(false);
        let t_final = soln.solution_points.last().unwrap().t;
        let mut solver = problem.bdf::<LS>().unwrap();

        let mut checkpoints = vec![solver.checkpoint()];
        // Samples of the segment that follows the most recent checkpoint.
        let mut seg_ts = Vec::new();
        let mut seg_ys = Vec::new();
        let mut seg_ydots = Vec::new();
        let mut steps_taken: usize = 0;

        while solver.state().t < t_final {
            // Record the state *before* stepping.
            seg_ts.push(solver.state().t);
            seg_ys.push(solver.state().y.clone());
            seg_ydots.push(solver.state().dy.clone());

            solver.step().unwrap();
            steps_taken += 1;

            let at_checkpoint = steps_taken % STEPS_PER_CHECKPOINT == 0;
            if at_checkpoint && solver.state().t < t_final {
                checkpoints.push(solver.checkpoint());
                // A new segment starts at this checkpoint.
                seg_ts.clear();
                seg_ys.clear();
                seg_ydots.clear();
            }
        }
        checkpoints.push(solver.checkpoint());

        let last_interval = checkpoints.len() - 2;
        let final_segment = HermiteInterpolator::new(seg_ys, seg_ydots, seg_ts);
        let checkpointer =
            Checkpointing::new(solver, last_interval, checkpoints, Some(final_segment));

        let mut y = problem.context().vector_zeros(problem.eqn.rhs().nstates());
        for point in soln.solution_points.iter().rev() {
            checkpointer.interpolate(point.t, &mut y).unwrap();
            y.assert_eq_norm(&point.state, &problem.atol, problem.rtol, 10.0);
        }
    }
}
308}