Skip to main content

diffsol/ode_solver/
checkpointing.rs

1use num_traits::FromPrimitive;
2use std::cell::RefCell;
3
4use crate::{
5    error::DiffsolError, other_error, OdeEquations, OdeSolverMethod, OdeSolverProblem,
6    OdeSolverState, Scalar, Vector,
7};
8use num_traits::{abs, One};
9
10/// Hermite interpolator for ODE solution trajectories.
11///
12/// This interpolator uses cubic Hermite interpolation based on solution values and
13/// derivatives at discrete time points. It provides smooth interpolation between
14/// checkpoints and is used internally by the checkpointing system for adjoint
15/// sensitivity analysis.
16#[derive(Clone)]
17pub struct HermiteInterpolator<V>
18where
19    V: Vector,
20{
21    ys: Vec<V>,
22    ydots: Vec<V>,
23    ts: Vec<V::T>,
24}
25
26impl<V> Default for HermiteInterpolator<V>
27where
28    V: Vector,
29{
30    fn default() -> Self {
31        HermiteInterpolator {
32            ys: Vec::new(),
33            ydots: Vec::new(),
34            ts: Vec::new(),
35        }
36    }
37}
38
39impl<V> HermiteInterpolator<V>
40where
41    V: Vector,
42{
43    /// Create a new Hermite interpolator with the given solution data.
44    ///
45    /// # Arguments
46    /// - `ys`: Vector of solution values at each time point
47    /// - `ydots`: Vector of solution derivatives (dy/dt) at each time point
48    /// - `ts`: Vector of time points (must be sorted)
49    ///
50    /// # Notes
51    /// All three vectors must have the same length. The time points should be
52    /// sorted in increasing or decreasing order for proper interpolation.
53    pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self {
54        HermiteInterpolator { ys, ydots, ts }
55    }
56
57    /// Get the last time point in the interpolator.
58    ///
59    /// # Returns
60    /// The last time value, or `None` if the interpolator is empty.
61    pub fn last_t(&self) -> Option<V::T> {
62        if self.ts.is_empty() {
63            return None;
64        }
65        Some(self.ts[self.ts.len() - 1])
66    }
67
68    /// Get the last time step size in the interpolator.
69    ///
70    /// # Returns
71    /// The difference between the last two time points, or `None` if there are
72    /// fewer than two time points.
73    pub fn last_h(&self) -> Option<V::T> {
74        if self.ts.len() < 2 {
75            return None;
76        }
77        Some(self.ts[self.ts.len() - 1] - self.ts[self.ts.len() - 2])
78    }
79
80    /// Reset the interpolator by solving the ODE between two checkpointed states.
81    ///
82    /// This method clears the current interpolation data and re-solves the ODE
83    /// from `state0` to `state1`, storing all intermediate solution points.
84    ///
85    /// # Arguments
86    /// - `solver`: The ODE solver to use for integration
87    /// - `state0`: The initial state (starting point)
88    /// - `state1`: The final state (target point)
89    ///
90    /// # Returns
91    /// Ok(()) on success, or an error if the solver fails.
92    pub fn reset<'a, Eqn, Method, State>(
93        &mut self,
94        solver: &mut Method,
95        state0: &State,
96        state1: &State,
97    ) -> Result<(), DiffsolError>
98    where
99        Eqn: OdeEquations<V = V, T = V::T> + 'a,
100        Method: OdeSolverMethod<'a, Eqn, State = State>,
101        State: OdeSolverState<V>,
102    {
103        let state0_ref = state0.as_ref();
104        let state1_ref = state1.as_ref();
105        self.ys.clear();
106        self.ydots.clear();
107        self.ts.clear();
108        self.ys.push(state0_ref.y.clone());
109        self.ydots.push(state0_ref.dy.clone());
110        self.ts.push(state0_ref.t);
111
112        solver.set_state(state0.clone());
113        while solver.state().t < state1_ref.t {
114            solver.step()?;
115            self.ys.push(solver.state().y.clone());
116            self.ydots.push(solver.state().dy.clone());
117            self.ts.push(solver.state().t);
118        }
119        Ok(())
120    }
121
122    /// Interpolate the solution at a given time point.
123    ///
124    /// Uses cubic Hermite interpolation to compute the solution value at time `t`.
125    ///
126    /// # Arguments
127    /// - `t`: The time at which to interpolate
128    /// - `y`: Output vector to store the interpolated solution
129    ///
130    /// # Returns
131    /// `Some(())` if the interpolation succeeded (t is within range), `None` if
132    /// t is outside the range of stored time points.
133    pub fn interpolate(&self, t: V::T, y: &mut V) -> Option<()> {
134        if t < self.ts[0] || t > self.ts[self.ts.len() - 1] {
135            return None;
136        }
137        if t == self.ts[0] {
138            y.copy_from(&self.ys[0]);
139            return Some(());
140        }
141        let idx = self
142            .ts
143            .iter()
144            .position(|&t0| t0 > t)
145            .unwrap_or(self.ts.len() - 1);
146        let t0 = self.ts[idx - 1];
147        let t1 = self.ts[idx];
148        let h = t1 - t0;
149        let theta = (t - t0) / h;
150        let u0 = &self.ys[idx - 1];
151        let u1 = &self.ys[idx];
152        let f0 = &self.ydots[idx - 1];
153        let f1 = &self.ydots[idx];
154
155        y.copy_from(u0);
156        y.axpy(V::T::one(), u1, -V::T::one());
157        y.axpy(
158            h * (theta - V::T::from_f64(1.0).unwrap()),
159            f0,
160            V::T::one() - V::T::from_f64(2.0).unwrap() * theta,
161        );
162        y.axpy(h * theta, f1, V::T::one());
163        y.axpy(
164            V::T::from_f64(1.0).unwrap() - theta,
165            u0,
166            theta * (theta - V::T::from_f64(1.0).unwrap()),
167        );
168        y.axpy(theta, u1, V::T::one());
169        Some(())
170    }
171}
172
173/// Checkpointing system for adjoint sensitivity analysis.
174///
175/// This struct manages checkpoints of an ODE solution for use in adjoint sensitivity
176/// computation. It stores solution states at discrete checkpoints and uses Hermite
177/// interpolation between them to provide solution values at arbitrary time points
178/// during the backward adjoint solve.
179///
180/// The checkpointing system uses a two-level interpolation strategy:
181/// 1. Checkpoints: Discrete saved states of the ODE solution
182/// 2. Segments: Hermite interpolators that densely sample between adjacent checkpoints
183///
184/// When interpolating at a time point, the system first checks if it's in the current
185/// segment. If not, it re-solves the ODE between the appropriate checkpoints to create
186/// a new segment, then interpolates within that segment.
187///
188/// # Type Parameters
189/// - `Eqn`: The ODE equations type (inferred from the solver)
190/// - `Method`: The forward solver method type (inferred from the solver)
191pub struct Checkpointing<'a, Eqn, Method>
192where
193    Method: OdeSolverMethod<'a, Eqn>,
194    Eqn: OdeEquations,
195{
196    checkpoints: Vec<Method::State>,
197    segment: RefCell<HermiteInterpolator<Eqn::V>>,
198    previous_segment: RefCell<Option<HermiteInterpolator<Eqn::V>>>,
199    solver: RefCell<Method>,
200}
201
202impl<'a, Eqn, Method> Clone for Checkpointing<'a, Eqn, Method>
203where
204    Method: OdeSolverMethod<'a, Eqn>,
205    Eqn: OdeEquations,
206{
207    fn clone(&self) -> Self {
208        Checkpointing {
209            checkpoints: self.checkpoints.clone(),
210            segment: RefCell::new(self.segment.borrow().clone()),
211            previous_segment: RefCell::new(self.previous_segment.borrow().clone()),
212            solver: RefCell::new(self.solver.borrow().clone()),
213        }
214    }
215}
216
217impl<'a, Eqn, Method> Checkpointing<'a, Eqn, Method>
218where
219    Method: OdeSolverMethod<'a, Eqn>,
220    Eqn: OdeEquations,
221{
222    /// Create a new checkpointing system.
223    ///
224    /// # Arguments
225    /// - `solver`: The forward solver to use for re-solving segments
226    /// - `start_idx`: Index of the checkpoint to start from (must be < checkpoints.len() - 1)
227    /// - `checkpoints`: Vector of saved solution states (must have at least 2 elements)
228    /// - `segment`: Optional pre-computed Hermite interpolator for the initial segment given by `start_idx`.
229    ///   If `None`, the segment between `checkpoints[start_idx]` and `checkpoints[start_idx+1]`
230    ///   will be computed automatically.
231    ///
232    /// # Panics
233    /// Panics if `checkpoints.len() < 2` or if `start_idx >= checkpoints.len() - 1`.
234    ///
235    pub fn new(
236        mut solver: Method,
237        start_idx: usize,
238        checkpoints: Vec<Method::State>,
239        segment: Option<HermiteInterpolator<Eqn::V>>,
240    ) -> Self {
241        if checkpoints.len() < 2 {
242            panic!("Checkpoints must have at least 2 elements");
243        }
244        if start_idx >= checkpoints.len() - 1 {
245            panic!("start_idx must be less than checkpoints.len() - 1");
246        }
247        let segment = segment.unwrap_or_else(|| {
248            let mut segment = HermiteInterpolator::default();
249            segment
250                .reset(
251                    &mut solver,
252                    &checkpoints[start_idx],
253                    &checkpoints[start_idx + 1],
254                )
255                .unwrap();
256            segment
257        });
258        let segment = RefCell::new(segment);
259        let previous_segment = RefCell::new(None);
260        let solver = RefCell::new(solver);
261        Checkpointing {
262            checkpoints,
263            segment,
264            previous_segment,
265            solver,
266        }
267    }
268
269    /// Get the last (most recent) time point in the current segment.
270    ///
271    /// # Returns
272    /// The last time value in the current interpolation segment.
273    ///
274    /// # Panics
275    /// Panics if the segment is empty (should never happen with valid construction).
276    pub fn last_t(&self) -> Eqn::T {
277        self.segment
278            .borrow()
279            .last_t()
280            .expect("segment should not be empty")
281    }
282
283    /// Get the last time step size in the current segment.
284    ///
285    /// # Returns
286    /// The difference between the last two time points in the current segment,
287    /// or `None` if the segment has fewer than two points.
288    pub fn last_h(&self) -> Option<Eqn::T> {
289        self.segment.borrow().last_h()
290    }
291
292    /// Get a reference to the ODE problem associated with this checkpointing system.
293    ///
294    /// # Returns
295    /// A reference to the `OdeSolverProblem` that defines the ODE equations,
296    /// tolerances, and other solver parameters.
297    pub fn problem(&self) -> &'a OdeSolverProblem<Eqn> {
298        self.solver.borrow().problem()
299    }
300
301    /// Interpolate the solution at a given time point.
302    ///
303    /// This method provides the forward solution value at time `t` for use during
304    /// the backward adjoint solve. It first checks if `t` is within the current
305    /// or previous segment. If not, it identifies the appropriate checkpoint interval,
306    /// re-solves the ODE between those checkpoints to create a new segment, and then
307    /// interpolates within that segment.
308    ///
309    /// # Arguments
310    /// - `t`: The time at which to interpolate the solution
311    /// - `y`: Output vector to store the interpolated solution
312    ///
313    /// # Returns
314    /// - `Ok(())` if interpolation succeeded
315    /// - `Err(...)` if `t` is outside the range of checkpoints (beyond roundoff tolerance)
316    ///   or if re-solving the ODE fails
317    ///
318    /// # Notes
319    /// Small deviations from the checkpoint range (within roundoff error) are automatically
320    /// snapped to the nearest checkpoint boundary.
321    pub fn interpolate(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> {
322        {
323            let segment = self.segment.borrow();
324            if segment.interpolate(t, y).is_some() {
325                return Ok(());
326            }
327        }
328
329        {
330            let previous_segment = self.previous_segment.borrow();
331            if let Some(previous_segment) = previous_segment.as_ref() {
332                if previous_segment.interpolate(t, y).is_some() {
333                    return Ok(());
334                }
335            }
336        }
337
338        // if t is before first segment or after last segment, return error
339        let h = self.last_h().unwrap_or(Eqn::T::one());
340        let troundoff = Eqn::T::from_f64(100.0).unwrap() * Eqn::T::EPSILON * (abs(t) + abs(h));
341        if t < self.checkpoints[0].as_ref().t - troundoff
342            || t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t + troundoff
343        {
344            return Err(other_error!("t is outside of the checkpoints"));
345        }
346
347        // snap t to nearest checkpoint if outside of range
348        let t = if t < self.checkpoints[0].as_ref().t {
349            self.checkpoints[0].as_ref().t
350        } else if t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t {
351            self.checkpoints[self.checkpoints.len() - 1].as_ref().t
352        } else {
353            t
354        };
355
356        // else find idx of segment
357        let idx = self
358            .checkpoints
359            .iter()
360            .skip(1)
361            .position(|state| state.as_ref().t > t)
362            .expect("t is not in checkpoints");
363        if self.previous_segment.borrow().is_none() {
364            self.previous_segment
365                .replace(Some(HermiteInterpolator::default()));
366        }
367        let mut solver = self.solver.borrow_mut();
368        let mut previous_segment = self.previous_segment.borrow_mut();
369        let mut segment = self.segment.borrow_mut();
370        previous_segment.as_mut().unwrap().reset(
371            &mut *solver,
372            &self.checkpoints[idx],
373            &self.checkpoints[idx + 1],
374        )?;
375        std::mem::swap(&mut *segment, previous_segment.as_mut().unwrap());
376        segment.interpolate(t, y).unwrap();
377        Ok(())
378    }
379}
380
#[cfg(test)]
mod tests {

    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::test_models::robertson::robertson, Context, NalgebraLU, OdeEquations,
        OdeSolverMethod, Op, Vector,
    };

    use super::{Checkpointing, HermiteInterpolator};

    /// Solve the Robertson problem forwards with a BDF solver, checkpointing
    /// every 30 steps, then walk the reference solution points backwards and
    /// check that the checkpointer reproduces each of them.
    #[test]
    fn test_checkpointing() {
        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;

        let (problem, soln) = robertson::<M>(false);
        let t_final = soln.solution_points.last().unwrap().t;
        let checkpoint_interval = 30;

        let mut solver = problem.bdf::<LS>().unwrap();
        let mut checkpoints = vec![solver.checkpoint()];

        // Trajectory recorded since the most recent checkpoint.
        let mut ts = Vec::new();
        let mut ys = Vec::new();
        let mut ydots = Vec::new();

        let mut steps_taken = 0;
        while solver.state().t < t_final {
            // Record the state *before* stepping.
            {
                let state = solver.state();
                ts.push(state.t);
                ys.push(state.y.clone());
                ydots.push(state.dy.clone());
            }
            solver.step().unwrap();
            steps_taken += 1;

            let at_checkpoint = steps_taken % checkpoint_interval == 0;
            if at_checkpoint && solver.state().t < t_final {
                checkpoints.push(solver.checkpoint());
                // A new interval starts here, so discard the old trajectory.
                ts.clear();
                ys.clear();
                ydots.clear();
            }
        }
        // The final solver state closes the last checkpoint interval.
        checkpoints.push(solver.checkpoint());

        let last_segment = HermiteInterpolator::new(ys, ydots, ts);
        let start_idx = checkpoints.len() - 2;
        let checkpointer = Checkpointing::new(solver, start_idx, checkpoints, Some(last_segment));

        let mut y = problem.context().vector_zeros(problem.eqn.rhs().nstates());
        for point in soln.solution_points.iter().rev() {
            checkpointer.interpolate(point.t, &mut y).unwrap();
            y.assert_eq_norm(&point.state, &problem.atol, problem.rtol, 10.0);
        }
    }
}