1use std::cell::RefCell;
2
3use crate::{
4 error::DiffsolError, other_error, OdeEquations, OdeSolverMethod, OdeSolverProblem,
5 OdeSolverState, Scalar, Vector,
6};
7use num_traits::{abs, One};
8
9#[derive(Clone)]
10pub struct HermiteInterpolator<V>
11where
12 V: Vector,
13{
14 ys: Vec<V>,
15 ydots: Vec<V>,
16 ts: Vec<V::T>,
17}
18
19impl<V> Default for HermiteInterpolator<V>
20where
21 V: Vector,
22{
23 fn default() -> Self {
24 HermiteInterpolator {
25 ys: Vec::new(),
26 ydots: Vec::new(),
27 ts: Vec::new(),
28 }
29 }
30}
31
32impl<V> HermiteInterpolator<V>
33where
34 V: Vector,
35{
36 pub fn new(ys: Vec<V>, ydots: Vec<V>, ts: Vec<V::T>) -> Self {
37 HermiteInterpolator { ys, ydots, ts }
38 }
39 pub fn last_t(&self) -> Option<V::T> {
40 if self.ts.is_empty() {
41 return None;
42 }
43 Some(self.ts[self.ts.len() - 1])
44 }
45 pub fn last_h(&self) -> Option<V::T> {
46 if self.ts.len() < 2 {
47 return None;
48 }
49 Some(self.ts[self.ts.len() - 1] - self.ts[self.ts.len() - 2])
50 }
51 pub fn reset<'a, Eqn, Method, State>(
52 &mut self,
53 solver: &mut Method,
54 state0: &State,
55 state1: &State,
56 ) -> Result<(), DiffsolError>
57 where
58 Eqn: OdeEquations<V = V, T = V::T> + 'a,
59 Method: OdeSolverMethod<'a, Eqn, State = State>,
60 State: OdeSolverState<V>,
61 {
62 let state0_ref = state0.as_ref();
63 let state1_ref = state1.as_ref();
64 self.ys.clear();
65 self.ydots.clear();
66 self.ts.clear();
67 self.ys.push(state0_ref.y.clone());
68 self.ydots.push(state0_ref.dy.clone());
69 self.ts.push(state0_ref.t);
70
71 solver.set_state(state0.clone());
72 while solver.state().t < state1_ref.t {
73 solver.step()?;
74 self.ys.push(solver.state().y.clone());
75 self.ydots.push(solver.state().dy.clone());
76 self.ts.push(solver.state().t);
77 }
78 Ok(())
79 }
80
81 pub fn interpolate(&self, t: V::T, y: &mut V) -> Option<()> {
82 if t < self.ts[0] || t > self.ts[self.ts.len() - 1] {
83 return None;
84 }
85 if t == self.ts[0] {
86 y.copy_from(&self.ys[0]);
87 return Some(());
88 }
89 let idx = self
90 .ts
91 .iter()
92 .position(|&t0| t0 > t)
93 .unwrap_or(self.ts.len() - 1);
94 let t0 = self.ts[idx - 1];
95 let t1 = self.ts[idx];
96 let h = t1 - t0;
97 let theta = (t - t0) / h;
98 let u0 = &self.ys[idx - 1];
99 let u1 = &self.ys[idx];
100 let f0 = &self.ydots[idx - 1];
101 let f1 = &self.ydots[idx];
102
103 y.copy_from(u0);
104 y.axpy(V::T::one(), u1, -V::T::one());
105 y.axpy(
106 h * (theta - V::T::from(1.0)),
107 f0,
108 V::T::one() - V::T::from(2.0) * theta,
109 );
110 y.axpy(h * theta, f1, V::T::one());
111 y.axpy(
112 V::T::from(1.0) - theta,
113 u0,
114 theta * (theta - V::T::from(1.0)),
115 );
116 y.axpy(theta, u1, V::T::one());
117 Some(())
118 }
119}
120
121pub struct Checkpointing<'a, Eqn, Method>
122where
123 Method: OdeSolverMethod<'a, Eqn>,
124 Eqn: OdeEquations,
125{
126 checkpoints: Vec<Method::State>,
127 segment: RefCell<HermiteInterpolator<Eqn::V>>,
128 previous_segment: RefCell<Option<HermiteInterpolator<Eqn::V>>>,
129 solver: RefCell<Method>,
130}
131
132impl<'a, Eqn, Method> Clone for Checkpointing<'a, Eqn, Method>
133where
134 Method: OdeSolverMethod<'a, Eqn>,
135 Eqn: OdeEquations,
136{
137 fn clone(&self) -> Self {
138 Checkpointing {
139 checkpoints: self.checkpoints.clone(),
140 segment: RefCell::new(self.segment.borrow().clone()),
141 previous_segment: RefCell::new(self.previous_segment.borrow().clone()),
142 solver: RefCell::new(self.solver.borrow().clone()),
143 }
144 }
145}
146
147impl<'a, Eqn, Method> Checkpointing<'a, Eqn, Method>
148where
149 Method: OdeSolverMethod<'a, Eqn>,
150 Eqn: OdeEquations,
151{
152 pub fn new(
153 mut solver: Method,
154 start_idx: usize,
155 checkpoints: Vec<Method::State>,
156 segment: Option<HermiteInterpolator<Eqn::V>>,
157 ) -> Self {
158 if checkpoints.len() < 2 {
159 panic!("Checkpoints must have at least 2 elements");
160 }
161 if start_idx >= checkpoints.len() - 1 {
162 panic!("start_idx must be less than checkpoints.len() - 1");
163 }
164 let segment = segment.unwrap_or_else(|| {
165 let mut segment = HermiteInterpolator::default();
166 segment
167 .reset(
168 &mut solver,
169 &checkpoints[start_idx],
170 &checkpoints[start_idx + 1],
171 )
172 .unwrap();
173 segment
174 });
175 let segment = RefCell::new(segment);
176 let previous_segment = RefCell::new(None);
177 let solver = RefCell::new(solver);
178 Checkpointing {
179 checkpoints,
180 segment,
181 previous_segment,
182 solver,
183 }
184 }
185
186 pub fn last_t(&self) -> Eqn::T {
187 self.segment
188 .borrow()
189 .last_t()
190 .expect("segment should not be empty")
191 }
192
193 pub fn last_h(&self) -> Option<Eqn::T> {
194 self.segment.borrow().last_h()
195 }
196
197 pub fn problem(&self) -> &'a OdeSolverProblem<Eqn> {
198 self.solver.borrow().problem()
199 }
200
201 pub fn interpolate(&self, t: Eqn::T, y: &mut Eqn::V) -> Result<(), DiffsolError> {
202 {
203 let segment = self.segment.borrow();
204 if segment.interpolate(t, y).is_some() {
205 return Ok(());
206 }
207 }
208
209 {
210 let previous_segment = self.previous_segment.borrow();
211 if let Some(previous_segment) = previous_segment.as_ref() {
212 if previous_segment.interpolate(t, y).is_some() {
213 return Ok(());
214 }
215 }
216 }
217
218 let h = self.last_h().unwrap_or(Eqn::T::one());
220 let troundoff = Eqn::T::from(100.0) * Eqn::T::EPSILON * (abs(t) + abs(h));
221 if t < self.checkpoints[0].as_ref().t - troundoff
222 || t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t + troundoff
223 {
224 return Err(other_error!("t is outside of the checkpoints"));
225 }
226
227 let t = if t < self.checkpoints[0].as_ref().t {
229 self.checkpoints[0].as_ref().t
230 } else if t > self.checkpoints[self.checkpoints.len() - 1].as_ref().t {
231 self.checkpoints[self.checkpoints.len() - 1].as_ref().t
232 } else {
233 t
234 };
235
236 let idx = self
238 .checkpoints
239 .iter()
240 .skip(1)
241 .position(|state| state.as_ref().t > t)
242 .expect("t is not in checkpoints");
243 if self.previous_segment.borrow().is_none() {
244 self.previous_segment
245 .replace(Some(HermiteInterpolator::default()));
246 }
247 let mut solver = self.solver.borrow_mut();
248 let mut previous_segment = self.previous_segment.borrow_mut();
249 let mut segment = self.segment.borrow_mut();
250 previous_segment.as_mut().unwrap().reset(
251 &mut *solver,
252 &self.checkpoints[idx],
253 &self.checkpoints[idx + 1],
254 )?;
255 std::mem::swap(&mut *segment, previous_segment.as_mut().unwrap());
256 segment.interpolate(t, y).unwrap();
257 Ok(())
258 }
259}
260
#[cfg(test)]
mod tests {

    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::test_models::robertson::robertson, Context, NalgebraLU, OdeEquations,
        OdeSolverMethod, Op, Vector,
    };

    use super::{Checkpointing, HermiteInterpolator};

    /// Integrates the Robertson problem, storing a checkpoint every `n_steps` steps. It then
    /// checks that `Checkpointing::interpolate` reproduces the reference solution when queried
    /// in reverse time order.
    #[test]
    fn test_checkpointing() {
        type M = NalgebraMat<f64>;
        type LS = NalgebraLU<f64>;
        let (problem, soln) = robertson::<M>(false);
        let t_final = soln.solution_points.last().unwrap().t;
        let n_steps = 30;
        let mut solver = problem.bdf::<LS>().unwrap();
        let mut checkpoints = vec![solver.checkpoint()];
        let mut i = 0;
        // Dense data for the interval that starts at the most recent checkpoint.
        let mut ys = Vec::new();
        let mut ts = Vec::new();
        let mut ydots = Vec::new();
        while solver.state().t < t_final {
            ts.push(solver.state().t);
            ys.push(solver.state().y.clone());
            ydots.push(solver.state().dy.clone());
            solver.step().unwrap();
            i += 1;
            if i % n_steps == 0 && solver.state().t < t_final {
                checkpoints.push(solver.checkpoint());
                ts.clear();
                ys.clear();
                ydots.clear();
            }
        }
        // Record the state after the final step as well. The supplied segment then spans the
        // whole last checkpoint interval, just as `HermiteInterpolator::reset` would build it.
        ts.push(solver.state().t);
        ys.push(solver.state().y.clone());
        ydots.push(solver.state().dy.clone());
        checkpoints.push(solver.checkpoint());
        let segment = HermiteInterpolator::new(ys, ydots, ts);
        let checkpointer =
            Checkpointing::new(solver, checkpoints.len() - 2, checkpoints, Some(segment));
        let mut y = problem.context().vector_zeros(problem.eqn.rhs().nstates());
        for point in soln.solution_points.iter().rev() {
            checkpointer.interpolate(point.t, &mut y).unwrap();
            y.assert_eq_norm(&point.state, &problem.atol, problem.rtol, 10.0);
        }
    }
}