1use num_traits::FromPrimitive;
2use std::cell::RefCell;
3
4use crate::{
5 error::DiffsolError, other_error, OdeEquations, OdeSolverMethod, OdeSolverState, Scalar, Vector,
6};
7use num_traits::{abs, One};
8
9#[derive(Clone)]
16pub struct HermiteInterpolator<V>
17where
18 V: Vector,
19{
20 ys: Vec<V>,
21 ydots: Vec<V>,
22 ts: Vec<V::T>,
23}
24
25impl<V> Default for HermiteInterpolator<V>
26where
27 V: Vector,
28{
29 fn default() -> Self {
30 HermiteInterpolator {
31 ys: Vec::new(),
32 ydots: Vec::new(),
33 ts: Vec::new(),
34 }
35 }
36}
37
38impl<V> HermiteInterpolator<V>
39where
40 V: Vector,
41{
42 pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self {
53 HermiteInterpolator { ys, ydots, ts }
54 }
55
56 pub fn last_t(&self) -> Option<V::T> {
61 if self.ts.is_empty() {
62 return None;
63 }
64 Some(self.ts[self.ts.len() - 1])
65 }
66
67 pub fn last_h(&self) -> Option<V::T> {
73 if self.ts.len() < 2 {
74 return None;
75 }
76 Some(self.ts[self.ts.len() - 1] - self.ts[self.ts.len() - 2])
77 }
78
79 pub fn reset<'a, Eqn, Method, State>(
92 &mut self,
93 solver: &mut Method,
94 state0: &State,
95 state1: &State,
96 ) -> Result<(), DiffsolError>
97 where
98 Eqn: OdeEquations<V = V, T = V::T> + 'a,
99 Method: OdeSolverMethod<'a, Eqn, State = State>,
100 State: OdeSolverState<V>,
101 {
102 let state0_ref = state0.as_ref();
103 let state1_ref = state1.as_ref();
104 self.ys.clear();
105 self.ydots.clear();
106 self.ts.clear();
107 self.ys.push(state0_ref.y.clone());
108 self.ydots.push(state0_ref.dy.clone());
109 self.ts.push(state0_ref.t);
110
111 solver.set_state(state0.clone());
112 while solver.state().t < state1_ref.t {
113 solver.step()?;
114 self.ys.push(solver.state().y.clone());
115 self.ydots.push(solver.state().dy.clone());
116 self.ts.push(solver.state().t);
117 }
118 Ok(())
119 }
120
121 pub fn interpolate(&self, t: V::T, y: &mut V) -> Option<()> {
133 if t < self.ts[0] || t > self.ts[self.ts.len() - 1] {
134 return None;
135 }
136 if t == self.ts[0] {
137 y.copy_from(&self.ys[0]);
138 return Some(());
139 }
140 let idx = self
141 .ts
142 .iter()
143 .position(|&t0| t0 > t)
144 .unwrap_or(self.ts.len() - 1);
145 let t0 = self.ts[idx - 1];
146 let t1 = self.ts[idx];
147 let h = t1 - t0;
148 let theta = (t - t0) / h;
149 let u0 = &self.ys[idx - 1];
150 let u1 = &self.ys[idx];
151 let f0 = &self.ydots[idx - 1];
152 let f1 = &self.ydots[idx];
153
154 y.copy_from(u0);
155 y.axpy(V::T::one(), u1, -V::T::one());
156 y.axpy(
157 h * (theta - V::T::from_f64(1.0).unwrap()),
158 f0,
159 V::T::one() - V::T::from_f64(2.0).unwrap() * theta,
160 );
161 y.axpy(h * theta, f1, V::T::one());
162 y.axpy(
163 V::T::from_f64(1.0).unwrap() - theta,
164 u0,
165 theta * (theta - V::T::from_f64(1.0).unwrap()),
166 );
167 y.axpy(theta, u1, V::T::one());
168 Some(())
169 }
170}
171
172pub struct Checkpointing<Eqn, State>
188where
189 Eqn: OdeEquations,
190 State: OdeSolverState<Eqn::V>,
191{
192 checkpoints: Vec<State>,
193 segment: RefCell<HermiteInterpolator<Eqn::V>>,
194 previous_segment: RefCell<Option<HermiteInterpolator<Eqn::V>>>,
195 terminal_reset_root_idx: Option<usize>,
196}
197
/// A sequence of [`Checkpointing`] values sharing the same equation and state types.
pub type CheckpointingPath<Eqn, State> = Vec<Checkpointing<Eqn, State>>;
199
200impl<Eqn, State> Clone for Checkpointing<Eqn, State>
201where
202 Eqn: OdeEquations,
203 State: OdeSolverState<Eqn::V>,
204{
205 fn clone(&self) -> Self {
206 Checkpointing {
207 checkpoints: self.checkpoints.clone(),
208 segment: RefCell::new(self.segment.borrow().clone()),
209 previous_segment: RefCell::new(self.previous_segment.borrow().clone()),
210 terminal_reset_root_idx: self.terminal_reset_root_idx,
211 }
212 }
213}
214
215impl<Eqn, State> Checkpointing<Eqn, State>
216where
217 Eqn: OdeEquations,
218 State: OdeSolverState<Eqn::V>,
219{
220 pub fn new<'a, Method>(
234 solver: Option<&mut Method>,
235 start_idx: usize,
236 checkpoints: Vec<State>,
237 segment: Option<HermiteInterpolator<Eqn::V>>,
238 ) -> Self
239 where
240 Eqn: 'a,
241 Method: OdeSolverMethod<'a, Eqn, State = State>,
242 {
243 if checkpoints.len() < 2 {
244 panic!("Checkpoints must have at least 2 elements");
245 }
246 if start_idx >= checkpoints.len() - 1 {
247 panic!("start_idx must be less than checkpoints.len() - 1");
248 }
249 let segment = segment.unwrap_or_else(|| {
250 let solver =
251 solver.expect("solver is required when no initial checkpoint segment is provided");
252 let mut segment = HermiteInterpolator::default();
253 segment
254 .reset(solver, &checkpoints[start_idx], &checkpoints[start_idx + 1])
255 .unwrap();
256 segment
257 });
258 let segment = RefCell::new(segment);
259 let previous_segment = RefCell::new(None);
260 Checkpointing {
261 checkpoints,
262 segment,
263 previous_segment,
264 terminal_reset_root_idx: None,
265 }
266 }
267
268 pub(crate) fn set_terminal_reset_root_idx(&mut self, root_idx: usize) {
269 self.terminal_reset_root_idx = Some(root_idx);
270 }
271
272 #[cfg(test)]
273 pub(crate) fn clear_terminal_reset_root_idx(&mut self) {
274 self.terminal_reset_root_idx = None;
275 }
276
277 pub fn terminal_reset_root_idx(&self) -> Option<usize> {
281 self.terminal_reset_root_idx
282 }
283
284 pub fn first_checkpoint(&self) -> &State {
285 &self.checkpoints[0]
286 }
287
288 pub fn last_checkpoint(&self) -> &State {
289 &self.checkpoints[self.checkpoints.len() - 1]
290 }
291
292 pub fn last_t(&self) -> Eqn::T {
300 self.segment
301 .borrow()
302 .last_t()
303 .expect("segment should not be empty")
304 }
305
306 pub fn first_t(&self) -> Eqn::T {
307 self.checkpoints[0].as_ref().t
308 }
309
310 pub fn end_t(&self) -> Eqn::T {
311 self.checkpoints[self.checkpoints.len() - 1].as_ref().t
312 }
313
314 pub fn last_h(&self) -> Option<Eqn::T> {
320 self.segment.borrow().last_h()
321 }
322
323 pub fn interpolate<'a, Method>(
344 &self,
345 solver: Option<&mut Method>,
346 t: Eqn::T,
347 y: &mut Eqn::V,
348 ) -> Result<(), DiffsolError>
349 where
350 Eqn: 'a,
351 Method: OdeSolverMethod<'a, Eqn, State = State>,
352 {
353 {
354 let segment = self.segment.borrow();
355 if segment.interpolate(t, y).is_some() {
356 return Ok(());
357 }
358 }
359
360 {
361 let previous_segment = self.previous_segment.borrow();
362 if let Some(previous_segment) = previous_segment.as_ref() {
363 if previous_segment.interpolate(t, y).is_some() {
364 return Ok(());
365 }
366 }
367 }
368
369 let h = self.last_h().unwrap_or(Eqn::T::one());
371 let troundoff = Eqn::T::from_f64(100.0).unwrap() * Eqn::T::EPSILON * (abs(t) + abs(h));
372 if t < self.checkpoints[0].as_ref().t - troundoff
373 || t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t + troundoff
374 {
375 return Err(other_error!("t is outside of the checkpoints"));
376 }
377
378 let t = if t < self.checkpoints[0].as_ref().t {
380 self.checkpoints[0].as_ref().t
381 } else if t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t {
382 self.checkpoints[self.checkpoints.len() - 1].as_ref().t
383 } else {
384 t
385 };
386
387 let idx = self
389 .checkpoints
390 .iter()
391 .skip(1)
392 .position(|state| state.as_ref().t > t)
393 .expect("t is not in checkpoints");
394 let solver = solver.ok_or_else(|| {
395 other_error!("solver is required to rebuild checkpoint interpolation segments")
396 })?;
397 if self.previous_segment.borrow().is_none() {
398 self.previous_segment
399 .replace(Some(HermiteInterpolator::default()));
400 }
401 let mut previous_segment = self.previous_segment.borrow_mut();
402 let mut segment = self.segment.borrow_mut();
403 previous_segment.as_mut().unwrap().reset(
404 solver,
405 &self.checkpoints[idx],
406 &self.checkpoints[idx + 1],
407 )?;
408 std::mem::swap(&mut *segment, previous_segment.as_mut().unwrap());
409 segment.interpolate(t, y).unwrap();
410 Ok(())
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::{Checkpointing, HermiteInterpolator};
    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::test_models::robertson::robertson, NalgebraLU, NoCheckpointingSolver,
        OdeEquations, OdeSolverMethod, Vector,
    };

    /// Solves the Robertson problem while recording checkpoints, then checks
    /// that interpolation through the checkpoints reproduces the reference
    /// solution points and that rebuilding a segment without a solver is
    /// reported as an error.
    #[test]
    fn test_checkpointing() {
        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;

        let (problem, soln) = robertson::<M>(false);
        let t_final = soln.solution_points.last().unwrap().t;
        let steps_per_checkpoint = 30;
        let mut solver = problem.bdf::<LS>().unwrap();

        // Samples of the segment that follows the most recent checkpoint.
        let (mut ys, mut ts, mut ydots) = (Vec::new(), Vec::new(), Vec::new());
        let mut checkpoints = vec![solver.checkpoint()];
        let mut steps_taken = 0;
        while solver.state().t < t_final {
            ts.push(solver.state().t);
            ys.push(solver.state().y.clone());
            ydots.push(solver.state().dy.clone());
            solver.step().unwrap();
            steps_taken += 1;
            if steps_taken % steps_per_checkpoint == 0 && solver.state().t < t_final {
                checkpoints.push(solver.checkpoint());
                ts.clear();
                ys.clear();
                ydots.clear();
            }
        }
        checkpoints.push(solver.checkpoint());

        let last_segment = HermiteInterpolator::new(ys, ydots, ts);
        let start_idx = checkpoints.len() - 2;
        let checkpointer = Checkpointing::new(
            Some(&mut solver),
            start_idx,
            checkpoints,
            Some(last_segment),
        );
        let mut y = soln.solution_points.last().unwrap().state.clone();

        // A time covered by the supplied segment needs no solver.
        let t_covered = checkpointer.last_t();
        checkpointer
            .interpolate::<NoCheckpointingSolver<_, _>>(None, t_covered, &mut y)
            .unwrap();

        // The first solution point is requested without a solver.
        let t_start = soln.solution_points[0].t;
        let err = checkpointer
            .interpolate::<NoCheckpointingSolver<_, _>>(None, t_start, &mut y)
            .unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("solver is required to rebuild checkpoint interpolation segments")
        );

        // Walk the reference solution from the end back to the start.
        for point in soln.solution_points.iter().rev() {
            checkpointer
                .interpolate(Some(&mut solver), point.t, &mut y)
                .unwrap();
            y.assert_eq_norm(&point.state, &problem.atol, problem.rtol, 10.0);
        }
    }

    /// Constructing without both an initial segment and a solver must panic.
    #[test]
    #[should_panic(expected = "solver is required when no initial checkpoint segment is provided")]
    fn test_checkpointing_requires_solver_without_initial_segment() {
        // The solver argument is only used to pin down `Method` for type inference.
        fn build_without_solver<'a, Eqn, Method>(_solver: &Method, checkpoints: Vec<Method::State>)
        where
            Eqn: OdeEquations + 'a,
            Method: OdeSolverMethod<'a, Eqn>,
        {
            let _ = Checkpointing::<Eqn, Method::State>::new::<Method>(None, 0, checkpoints, None);
        }

        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;
        let (problem, _soln) = robertson::<M>(false);
        let mut solver = problem.bdf::<LS>().unwrap();
        let two_checkpoints = vec![solver.checkpoint(), solver.checkpoint()];
        build_without_solver(&solver, two_checkpoints);
    }
}