1use num_traits::FromPrimitive;
2use std::cell::RefCell;
3
4use crate::{
5 error::DiffsolError, other_error, OdeEquations, OdeSolverMethod, OdeSolverProblem,
6 OdeSolverState, Scalar, Vector,
7};
8use num_traits::{abs, One};
9
10#[derive(Clone)]
17pub struct HermiteInterpolator<V>
18where
19 V: Vector,
20{
21 ys: Vec<V>,
22 ydots: Vec<V>,
23 ts: Vec<V::T>,
24}
25
26impl<V> Default for HermiteInterpolator<V>
27where
28 V: Vector,
29{
30 fn default() -> Self {
31 HermiteInterpolator {
32 ys: Vec::new(),
33 ydots: Vec::new(),
34 ts: Vec::new(),
35 }
36 }
37}
38
39impl<V> HermiteInterpolator<V>
40where
41 V: Vector,
42{
43 pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self {
54 HermiteInterpolator { ys, ydots, ts }
55 }
56
57 pub fn last_t(&self) -> Option<V::T> {
62 if self.ts.is_empty() {
63 return None;
64 }
65 Some(self.ts[self.ts.len() - 1])
66 }
67
68 pub fn last_h(&self) -> Option<V::T> {
74 if self.ts.len() < 2 {
75 return None;
76 }
77 Some(self.ts[self.ts.len() - 1] - self.ts[self.ts.len() - 2])
78 }
79
80 pub fn reset<'a, Eqn, Method, State>(
93 &mut self,
94 solver: &mut Method,
95 state0: &State,
96 state1: &State,
97 ) -> Result<(), DiffsolError>
98 where
99 Eqn: OdeEquations<V = V, T = V::T> + 'a,
100 Method: OdeSolverMethod<'a, Eqn, State = State>,
101 State: OdeSolverState<V>,
102 {
103 let state0_ref = state0.as_ref();
104 let state1_ref = state1.as_ref();
105 self.ys.clear();
106 self.ydots.clear();
107 self.ts.clear();
108 self.ys.push(state0_ref.y.clone());
109 self.ydots.push(state0_ref.dy.clone());
110 self.ts.push(state0_ref.t);
111
112 solver.set_state(state0.clone());
113 while solver.state().t < state1_ref.t {
114 solver.step()?;
115 self.ys.push(solver.state().y.clone());
116 self.ydots.push(solver.state().dy.clone());
117 self.ts.push(solver.state().t);
118 }
119 Ok(())
120 }
121
122 pub fn interpolate(&self, t: V::T, y: &mut V) -> Option<()> {
134 if t < self.ts[0] || t > self.ts[self.ts.len() - 1] {
135 return None;
136 }
137 if t == self.ts[0] {
138 y.copy_from(&self.ys[0]);
139 return Some(());
140 }
141 let idx = self
142 .ts
143 .iter()
144 .position(|&t0| t0 > t)
145 .unwrap_or(self.ts.len() - 1);
146 let t0 = self.ts[idx - 1];
147 let t1 = self.ts[idx];
148 let h = t1 - t0;
149 let theta = (t - t0) / h;
150 let u0 = &self.ys[idx - 1];
151 let u1 = &self.ys[idx];
152 let f0 = &self.ydots[idx - 1];
153 let f1 = &self.ydots[idx];
154
155 y.copy_from(u0);
156 y.axpy(V::T::one(), u1, -V::T::one());
157 y.axpy(
158 h * (theta - V::T::from_f64(1.0).unwrap()),
159 f0,
160 V::T::one() - V::T::from_f64(2.0).unwrap() * theta,
161 );
162 y.axpy(h * theta, f1, V::T::one());
163 y.axpy(
164 V::T::from_f64(1.0).unwrap() - theta,
165 u0,
166 theta * (theta - V::T::from_f64(1.0).unwrap()),
167 );
168 y.axpy(theta, u1, V::T::one());
169 Some(())
170 }
171}
172
173pub struct Checkpointing<'a, Eqn, Method>
192where
193 Method: OdeSolverMethod<'a, Eqn>,
194 Eqn: OdeEquations,
195{
196 checkpoints: Vec<Method::State>,
197 segment: RefCell<HermiteInterpolator<Eqn::V>>,
198 previous_segment: RefCell<Option<HermiteInterpolator<Eqn::V>>>,
199 solver: RefCell<Method>,
200}
201
202impl<'a, Eqn, Method> Clone for Checkpointing<'a, Eqn, Method>
203where
204 Method: OdeSolverMethod<'a, Eqn>,
205 Eqn: OdeEquations,
206{
207 fn clone(&self) -> Self {
208 Checkpointing {
209 checkpoints: self.checkpoints.clone(),
210 segment: RefCell::new(self.segment.borrow().clone()),
211 previous_segment: RefCell::new(self.previous_segment.borrow().clone()),
212 solver: RefCell::new(self.solver.borrow().clone()),
213 }
214 }
215}
216
217impl<'a, Eqn, Method> Checkpointing<'a, Eqn, Method>
218where
219 Method: OdeSolverMethod<'a, Eqn>,
220 Eqn: OdeEquations,
221{
222 pub fn new(
236 mut solver: Method,
237 start_idx: usize,
238 checkpoints: Vec<Method::State>,
239 segment: Option<HermiteInterpolator<Eqn::V>>,
240 ) -> Self {
241 if checkpoints.len() < 2 {
242 panic!("Checkpoints must have at least 2 elements");
243 }
244 if start_idx >= checkpoints.len() - 1 {
245 panic!("start_idx must be less than checkpoints.len() - 1");
246 }
247 let segment = segment.unwrap_or_else(|| {
248 let mut segment = HermiteInterpolator::default();
249 segment
250 .reset(
251 &mut solver,
252 &checkpoints[start_idx],
253 &checkpoints[start_idx + 1],
254 )
255 .unwrap();
256 segment
257 });
258 let segment = RefCell::new(segment);
259 let previous_segment = RefCell::new(None);
260 let solver = RefCell::new(solver);
261 Checkpointing {
262 checkpoints,
263 segment,
264 previous_segment,
265 solver,
266 }
267 }
268
269 pub fn last_t(&self) -> Eqn::T {
277 self.segment
278 .borrow()
279 .last_t()
280 .expect("segment should not be empty")
281 }
282
283 pub fn last_h(&self) -> Option<Eqn::T> {
289 self.segment.borrow().last_h()
290 }
291
292 pub fn problem(&self) -> &'a OdeSolverProblem<Eqn> {
298 self.solver.borrow().problem()
299 }
300
301 pub fn interpolate(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> {
322 {
323 let segment = self.segment.borrow();
324 if segment.interpolate(t, y).is_some() {
325 return Ok(());
326 }
327 }
328
329 {
330 let previous_segment = self.previous_segment.borrow();
331 if let Some(previous_segment) = previous_segment.as_ref() {
332 if previous_segment.interpolate(t, y).is_some() {
333 return Ok(());
334 }
335 }
336 }
337
338 let h = self.last_h().unwrap_or(Eqn::T::one());
340 let troundoff = Eqn::T::from_f64(100.0).unwrap() * Eqn::T::EPSILON * (abs(t) + abs(h));
341 if t < self.checkpoints[0].as_ref().t - troundoff
342 || t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t + troundoff
343 {
344 return Err(other_error!("t is outside of the checkpoints"));
345 }
346
347 let t = if t < self.checkpoints[0].as_ref().t {
349 self.checkpoints[0].as_ref().t
350 } else if t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t {
351 self.checkpoints[self.checkpoints.len() - 1].as_ref().t
352 } else {
353 t
354 };
355
356 let idx = self
358 .checkpoints
359 .iter()
360 .skip(1)
361 .position(|state| state.as_ref().t > t)
362 .expect("t is not in checkpoints");
363 if self.previous_segment.borrow().is_none() {
364 self.previous_segment
365 .replace(Some(HermiteInterpolator::default()));
366 }
367 let mut solver = self.solver.borrow_mut();
368 let mut previous_segment = self.previous_segment.borrow_mut();
369 let mut segment = self.segment.borrow_mut();
370 previous_segment.as_mut().unwrap().reset(
371 &mut *solver,
372 &self.checkpoints[idx],
373 &self.checkpoints[idx + 1],
374 )?;
375 std::mem::swap(&mut *segment, previous_segment.as_mut().unwrap());
376 segment.interpolate(t, y).unwrap();
377 Ok(())
378 }
379}
380
#[cfg(test)]
mod tests {

    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::test_models::robertson::robertson, Context, NalgebraLU, OdeEquations,
        OdeSolverMethod, Op, Vector,
    };

    use super::{Checkpointing, HermiteInterpolator};

    /// Solves the Robertson test problem forward with BDF, saving a checkpoint
    /// every `n_steps` steps. It then checks that `Checkpointing` reproduces
    /// the reference solution points when they are queried in reverse time order.
    #[test]
    fn test_checkpointing() {
        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;
        let (problem, soln) = robertson::<M>(false);
        let t_final = soln.solution_points.last().unwrap().t;
        let n_steps = 30;
        let mut solver = problem.bdf::<LS>().unwrap();
        let mut checkpoints = vec![solver.checkpoint()];
        let mut i = 0;
        // (t, y, dy) recorded since the most recent checkpoint; after the loop
        // they hold the data for the final segment.
        let mut ys = Vec::new();
        let mut ts = Vec::new();
        let mut ydots = Vec::new();
        while solver.state().t < t_final {
            // Record the pre-step state. After a checkpoint, the first entry
            // pushed is therefore the checkpointed state itself.
            ts.push(solver.state().t);
            ys.push(solver.state().y.clone());
            ydots.push(solver.state().dy.clone());
            solver.step().unwrap();
            i += 1;
            if i % n_steps == 0 && solver.state().t < t_final {
                checkpoints.push(solver.checkpoint());
                // Start recording a fresh segment from this checkpoint.
                ts.clear();
                ys.clear();
                ydots.clear();
            }
        }
        // Close the last segment with the final solver state.
        // NOTE(review): that final post-step state is never pushed into
        // ts/ys/ydots. The supplied segment therefore stops at the last
        // pre-step time, and queries after it go through the re-solve path
        // in `Checkpointing::interpolate`.
        checkpoints.push(solver.checkpoint());
        let segment = HermiteInterpolator::new(ys, ydots, ts);
        // The recorded data belongs to the last segment, hence start_idx = len - 2.
        let checkpointer =
            Checkpointing::new(solver, checkpoints.len() - 2, checkpoints, Some(segment));
        let mut y = problem.context().vector_zeros(problem.eqn.rhs().nstates());
        // Query from the end backwards, so earlier segments must be regenerated.
        for point in soln.solution_points.iter().rev() {
            checkpointer.interpolate(point.t, &mut y).unwrap();
            // Compared using the problem's atol/rtol; 10.0 is presumably a
            // tolerance scaling factor — confirm against `assert_eq_norm`.
            y.assert_eq_norm(&point.state, &problem.atol, problem.rtol, 10.0);
        }
    }
}