Skip to main content

diffsol/ode_solver/
checkpointing.rs

1use num_traits::FromPrimitive;
2use std::cell::RefCell;
3
4use crate::{
5    error::DiffsolError, other_error, OdeEquations, OdeSolverMethod, OdeSolverState, Scalar, Vector,
6};
7use num_traits::{abs, One};
8
9/// Hermite interpolator for ODE solution trajectories.
10///
11/// This interpolator uses cubic Hermite interpolation based on solution values and
12/// derivatives at discrete time points. It provides smooth interpolation between
13/// checkpoints and is used internally by the checkpointing system for adjoint
14/// sensitivity analysis.
15#[derive(Clone)]
16pub struct HermiteInterpolator<V>
17where
18    V: Vector,
19{
20    ys: Vec<V>,
21    ydots: Vec<V>,
22    ts: Vec<V::T>,
23}
24
25impl<V> Default for HermiteInterpolator<V>
26where
27    V: Vector,
28{
29    fn default() -> Self {
30        HermiteInterpolator {
31            ys: Vec::new(),
32            ydots: Vec::new(),
33            ts: Vec::new(),
34        }
35    }
36}
37
38impl<V> HermiteInterpolator<V>
39where
40    V: Vector,
41{
42    /// Create a new Hermite interpolator with the given solution data.
43    ///
44    /// # Arguments
45    /// - `ys`: Vector of solution values at each time point
46    /// - `ydots`: Vector of solution derivatives (dy/dt) at each time point
47    /// - `ts`: Vector of time points (must be sorted)
48    ///
49    /// # Notes
50    /// All three vectors must have the same length. The time points should be
51    /// sorted in increasing or decreasing order for proper interpolation.
52    pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self {
53        HermiteInterpolator { ys, ydots, ts }
54    }
55
56    /// Get the last time point in the interpolator.
57    ///
58    /// # Returns
59    /// The last time value, or `None` if the interpolator is empty.
60    pub fn last_t(&self) -> Option<V::T> {
61        if self.ts.is_empty() {
62            return None;
63        }
64        Some(self.ts[self.ts.len() - 1])
65    }
66
67    /// Get the last time step size in the interpolator.
68    ///
69    /// # Returns
70    /// The difference between the last two time points, or `None` if there are
71    /// fewer than two time points.
72    pub fn last_h(&self) -> Option<V::T> {
73        if self.ts.len() < 2 {
74            return None;
75        }
76        Some(self.ts[self.ts.len() - 1] - self.ts[self.ts.len() - 2])
77    }
78
79    /// Reset the interpolator by solving the ODE between two checkpointed states.
80    ///
81    /// This method clears the current interpolation data and re-solves the ODE
82    /// from `state0` to `state1`, storing all intermediate solution points.
83    ///
84    /// # Arguments
85    /// - `solver`: The ODE solver to use for integration
86    /// - `state0`: The initial state (starting point)
87    /// - `state1`: The final state (target point)
88    ///
89    /// # Returns
90    /// Ok(()) on success, or an error if the solver fails.
91    pub fn reset<'a, Eqn, Method, State>(
92        &mut self,
93        solver: &mut Method,
94        state0: &State,
95        state1: &State,
96    ) -> Result<(), DiffsolError>
97    where
98        Eqn: OdeEquations<V = V, T = V::T> + 'a,
99        Method: OdeSolverMethod<'a, Eqn, State = State>,
100        State: OdeSolverState<V>,
101    {
102        let state0_ref = state0.as_ref();
103        let state1_ref = state1.as_ref();
104        self.ys.clear();
105        self.ydots.clear();
106        self.ts.clear();
107        self.ys.push(state0_ref.y.clone());
108        self.ydots.push(state0_ref.dy.clone());
109        self.ts.push(state0_ref.t);
110
111        solver.set_state(state0.clone());
112        while solver.state().t < state1_ref.t {
113            solver.step()?;
114            self.ys.push(solver.state().y.clone());
115            self.ydots.push(solver.state().dy.clone());
116            self.ts.push(solver.state().t);
117        }
118        Ok(())
119    }
120
121    /// Interpolate the solution at a given time point.
122    ///
123    /// Uses cubic Hermite interpolation to compute the solution value at time `t`.
124    ///
125    /// # Arguments
126    /// - `t`: The time at which to interpolate
127    /// - `y`: Output vector to store the interpolated solution
128    ///
129    /// # Returns
130    /// `Some(())` if the interpolation succeeded (t is within range), `None` if
131    /// t is outside the range of stored time points.
132    pub fn interpolate(&self, t: V::T, y: &mut V) -> Option<()> {
133        if t < self.ts[0] || t > self.ts[self.ts.len() - 1] {
134            return None;
135        }
136        if t == self.ts[0] {
137            y.copy_from(&self.ys[0]);
138            return Some(());
139        }
140        let idx = self
141            .ts
142            .iter()
143            .position(|&t0| t0 > t)
144            .unwrap_or(self.ts.len() - 1);
145        let t0 = self.ts[idx - 1];
146        let t1 = self.ts[idx];
147        let h = t1 - t0;
148        let theta = (t - t0) / h;
149        let u0 = &self.ys[idx - 1];
150        let u1 = &self.ys[idx];
151        let f0 = &self.ydots[idx - 1];
152        let f1 = &self.ydots[idx];
153
154        y.copy_from(u0);
155        y.axpy(V::T::one(), u1, -V::T::one());
156        y.axpy(
157            h * (theta - V::T::from_f64(1.0).unwrap()),
158            f0,
159            V::T::one() - V::T::from_f64(2.0).unwrap() * theta,
160        );
161        y.axpy(h * theta, f1, V::T::one());
162        y.axpy(
163            V::T::from_f64(1.0).unwrap() - theta,
164            u0,
165            theta * (theta - V::T::from_f64(1.0).unwrap()),
166        );
167        y.axpy(theta, u1, V::T::one());
168        Some(())
169    }
170}
171
172/// Checkpointing system for adjoint sensitivity analysis.
173///
174/// This struct manages checkpoints of an ODE solution for use in adjoint sensitivity
175/// computation. It stores solution states at discrete checkpoints and uses Hermite
176/// interpolation between them to provide solution values at arbitrary time points
177/// during the backward adjoint solve.
178///
179/// The checkpointing system uses a two-level interpolation strategy:
180/// 1. Checkpoints: Discrete saved states of the ODE solution
181/// 2. Segments: Hermite interpolators that densely sample between adjacent checkpoints
182///
183/// When interpolating at a time point, the system first checks if it's in the current
184/// segment. If not, it re-solves the ODE between the appropriate checkpoints to create
185/// a new segment, then interpolates within that segment.
186///
187pub struct Checkpointing<Eqn, State>
188where
189    Eqn: OdeEquations,
190    State: OdeSolverState<Eqn::V>,
191{
192    checkpoints: Vec<State>,
193    segment: RefCell<HermiteInterpolator<Eqn::V>>,
194    previous_segment: RefCell<Option<HermiteInterpolator<Eqn::V>>>,
195    terminal_reset_root_idx: Option<usize>,
196}
197
198pub type CheckpointingPath<Eqn, State> = Vec<Checkpointing<Eqn, State>>;
199
200impl<Eqn, State> Clone for Checkpointing<Eqn, State>
201where
202    Eqn: OdeEquations,
203    State: OdeSolverState<Eqn::V>,
204{
205    fn clone(&self) -> Self {
206        Checkpointing {
207            checkpoints: self.checkpoints.clone(),
208            segment: RefCell::new(self.segment.borrow().clone()),
209            previous_segment: RefCell::new(self.previous_segment.borrow().clone()),
210            terminal_reset_root_idx: self.terminal_reset_root_idx,
211        }
212    }
213}
214
215impl<Eqn, State> Checkpointing<Eqn, State>
216where
217    Eqn: OdeEquations,
218    State: OdeSolverState<Eqn::V>,
219{
220    /// Create a new checkpointing system.
221    ///
222    /// # Arguments
223    /// - `solver`: The forward solver to use for re-solving segments
224    /// - `start_idx`: Index of the checkpoint to start from (must be < checkpoints.len() - 1)
225    /// - `checkpoints`: Vector of saved solution states (must have at least 2 elements)
226    /// - `segment`: Optional pre-computed Hermite interpolator for the initial segment given by `start_idx`.
227    ///   If `None`, the segment between `checkpoints[start_idx]` and `checkpoints[start_idx+1]`
228    ///   will be computed automatically.
229    ///
230    /// # Panics
231    /// Panics if `checkpoints.len() < 2` or if `start_idx >= checkpoints.len() - 1`.
232    ///
233    pub fn new<'a, Method>(
234        solver: Option<&mut Method>,
235        start_idx: usize,
236        checkpoints: Vec<State>,
237        segment: Option<HermiteInterpolator<Eqn::V>>,
238    ) -> Self
239    where
240        Eqn: 'a,
241        Method: OdeSolverMethod<'a, Eqn, State = State>,
242    {
243        if checkpoints.len() < 2 {
244            panic!("Checkpoints must have at least 2 elements");
245        }
246        if start_idx >= checkpoints.len() - 1 {
247            panic!("start_idx must be less than checkpoints.len() - 1");
248        }
249        let segment = segment.unwrap_or_else(|| {
250            let solver =
251                solver.expect("solver is required when no initial checkpoint segment is provided");
252            let mut segment = HermiteInterpolator::default();
253            segment
254                .reset(solver, &checkpoints[start_idx], &checkpoints[start_idx + 1])
255                .unwrap();
256            segment
257        });
258        let segment = RefCell::new(segment);
259        let previous_segment = RefCell::new(None);
260        Checkpointing {
261            checkpoints,
262            segment,
263            previous_segment,
264            terminal_reset_root_idx: None,
265        }
266    }
267
268    pub(crate) fn set_terminal_reset_root_idx(&mut self, root_idx: usize) {
269        self.terminal_reset_root_idx = Some(root_idx);
270    }
271
272    #[cfg(test)]
273    pub(crate) fn clear_terminal_reset_root_idx(&mut self) {
274        self.terminal_reset_root_idx = None;
275    }
276
277    /// Returns the root index when this checkpointing segment ended at a root.
278    ///
279    /// `None` means the segment ended at a configured stop time.
280    pub fn terminal_reset_root_idx(&self) -> Option<usize> {
281        self.terminal_reset_root_idx
282    }
283
284    pub fn first_checkpoint(&self) -> &State {
285        &self.checkpoints[0]
286    }
287
288    pub fn last_checkpoint(&self) -> &State {
289        &self.checkpoints[self.checkpoints.len() - 1]
290    }
291
292    /// Get the last (most recent) time point in the current segment.
293    ///
294    /// # Returns
295    /// The last time value in the current interpolation segment.
296    ///
297    /// # Panics
298    /// Panics if the segment is empty (should never happen with valid construction).
299    pub fn last_t(&self) -> Eqn::T {
300        self.segment
301            .borrow()
302            .last_t()
303            .expect("segment should not be empty")
304    }
305
306    pub fn first_t(&self) -> Eqn::T {
307        self.checkpoints[0].as_ref().t
308    }
309
310    pub fn end_t(&self) -> Eqn::T {
311        self.checkpoints[self.checkpoints.len() - 1].as_ref().t
312    }
313
314    /// Get the last time step size in the current segment.
315    ///
316    /// # Returns
317    /// The difference between the last two time points in the current segment,
318    /// or `None` if the segment has fewer than two points.
319    pub fn last_h(&self) -> Option<Eqn::T> {
320        self.segment.borrow().last_h()
321    }
322
323    /// Interpolate the solution at a given time point.
324    ///
325    /// This method provides the forward solution value at time `t` for use during
326    /// the backward adjoint solve. It first checks if `t` is within the current
327    /// or previous segment. If not, it identifies the appropriate checkpoint interval,
328    /// re-solves the ODE between those checkpoints to create a new segment, and then
329    /// interpolates within that segment.
330    ///
331    /// # Arguments
332    /// - `t`: The time at which to interpolate the solution
333    /// - `y`: Output vector to store the interpolated solution
334    ///
335    /// # Returns
336    /// - `Ok(())` if interpolation succeeded
337    /// - `Err(...)` if `t` is outside the range of checkpoints (beyond roundoff tolerance)
338    ///   or if re-solving the ODE fails
339    ///
340    /// # Notes
341    /// Small deviations from the checkpoint range (within roundoff error) are automatically
342    /// snapped to the nearest checkpoint boundary.
343    pub fn interpolate<'a, Method>(
344        &self,
345        solver: Option<&mut Method>,
346        t: Eqn::T,
347        y: &mut Eqn::V,
348    ) -> Result<(), DiffsolError>
349    where
350        Eqn: 'a,
351        Method: OdeSolverMethod<'a, Eqn, State = State>,
352    {
353        {
354            let segment = self.segment.borrow();
355            if segment.interpolate(t, y).is_some() {
356                return Ok(());
357            }
358        }
359
360        {
361            let previous_segment = self.previous_segment.borrow();
362            if let Some(previous_segment) = previous_segment.as_ref() {
363                if previous_segment.interpolate(t, y).is_some() {
364                    return Ok(());
365                }
366            }
367        }
368
369        // if t is before first segment or after last segment, return error
370        let h = self.last_h().unwrap_or(Eqn::T::one());
371        let troundoff = Eqn::T::from_f64(100.0).unwrap() * Eqn::T::EPSILON * (abs(t) + abs(h));
372        if t < self.checkpoints[0].as_ref().t - troundoff
373            || t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t + troundoff
374        {
375            return Err(other_error!("t is outside of the checkpoints"));
376        }
377
378        // snap t to nearest checkpoint if outside of range
379        let t = if t < self.checkpoints[0].as_ref().t {
380            self.checkpoints[0].as_ref().t
381        } else if t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t {
382            self.checkpoints[self.checkpoints.len() - 1].as_ref().t
383        } else {
384            t
385        };
386
387        // else find idx of segment
388        let idx = self
389            .checkpoints
390            .iter()
391            .skip(1)
392            .position(|state| state.as_ref().t > t)
393            .expect("t is not in checkpoints");
394        let solver = solver.ok_or_else(|| {
395            other_error!("solver is required to rebuild checkpoint interpolation segments")
396        })?;
397        if self.previous_segment.borrow().is_none() {
398            self.previous_segment
399                .replace(Some(HermiteInterpolator::default()));
400        }
401        let mut previous_segment = self.previous_segment.borrow_mut();
402        let mut segment = self.segment.borrow_mut();
403        previous_segment.as_mut().unwrap().reset(
404            solver,
405            &self.checkpoints[idx],
406            &self.checkpoints[idx + 1],
407        )?;
408        std::mem::swap(&mut *segment, previous_segment.as_mut().unwrap());
409        segment.interpolate(t, y).unwrap();
410        Ok(())
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::{Checkpointing, HermiteInterpolator};
    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::test_models::robertson::robertson, NalgebraLU, NoCheckpointingSolver,
        OdeEquations, OdeSolverMethod, Vector,
    };

    /// Solve the Robertson problem with BDF while taking a checkpoint every 30 steps,
    /// then check that the checkpointer reproduces the reference solution points when
    /// they are requested in reverse order.
    #[test]
    fn test_checkpointing() {
        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;

        let (problem, soln) = robertson::<M>(false);
        let t_final = soln.solution_points.last().unwrap().t;
        let mut solver = problem.bdf::<LS>().unwrap();

        let steps_per_checkpoint = 30;
        let mut steps_taken = 0;
        let mut checkpoints = vec![solver.checkpoint()];
        // Trajectory recorded since the most recent checkpoint.
        let (mut seg_ts, mut seg_ys, mut seg_ydots) = (Vec::new(), Vec::new(), Vec::new());
        while solver.state().t < t_final {
            seg_ts.push(solver.state().t);
            seg_ys.push(solver.state().y.clone());
            seg_ydots.push(solver.state().dy.clone());
            solver.step().unwrap();
            steps_taken += 1;
            let checkpoint_due = steps_taken % steps_per_checkpoint == 0;
            if checkpoint_due && solver.state().t < t_final {
                checkpoints.push(solver.checkpoint());
                seg_ts.clear();
                seg_ys.clear();
                seg_ydots.clear();
            }
        }
        checkpoints.push(solver.checkpoint());

        // Start from the last checkpoint interval, supplying the recorded segment.
        let last_interval = checkpoints.len() - 2;
        let segment = HermiteInterpolator::new(seg_ys, seg_ydots, seg_ts);
        let checkpointer = Checkpointing::new(
            Some(&mut solver),
            last_interval,
            checkpoints,
            Some(segment),
        );

        let mut y = soln.solution_points.last().unwrap().state.clone();

        // The end of the supplied segment can be served without a solver.
        let t_in_segment = checkpointer.last_t();
        checkpointer
            .interpolate::<NoCheckpointingSolver<_, _>>(None, t_in_segment, &mut y)
            .unwrap();

        // The first solution point needs a rebuilt segment, which requires a solver.
        let t_first_point = soln.solution_points[0].t;
        let err = checkpointer
            .interpolate::<NoCheckpointingSolver<_, _>>(None, t_first_point, &mut y)
            .unwrap_err();
        let message = err.to_string();
        assert!(message.contains("solver is required to rebuild checkpoint interpolation segments"));

        // Walk backwards through the reference solution, as an adjoint solve would.
        for point in soln.solution_points.iter().rev() {
            checkpointer
                .interpolate(Some(&mut solver), point.t, &mut y)
                .unwrap();
            y.assert_eq_norm(&point.state, &problem.atol, problem.rtol, 10.0);
        }
    }

    /// Constructing a checkpointer with neither a solver nor an initial segment panics.
    #[test]
    #[should_panic(expected = "solver is required when no initial checkpoint segment is provided")]
    fn test_checkpointing_requires_solver_without_initial_segment() {
        // The solver argument is only used to let the compiler infer `Eqn` and `Method`.
        fn new_without_solver_panics<'a, Eqn, Method>(
            _solver: &Method,
            checkpoints: Vec<Method::State>,
        ) where
            Eqn: OdeEquations + 'a,
            Method: OdeSolverMethod<'a, Eqn>,
        {
            let _ = Checkpointing::<Eqn, Method::State>::new::<Method>(None, 0, checkpoints, None);
        }

        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;

        let (problem, _soln) = robertson::<M>(false);
        let mut solver = problem.bdf::<LS>().unwrap();
        let first = solver.checkpoint();
        let second = solver.checkpoint();
        new_without_solver_panics(&solver, vec![first, second]);
    }
}