dess_core/
solver.rs

1use crate::imports::*;
2
3#[common_derives]
4pub enum SolverTypes {
5    /// Euler with fixed time step.
6    /// parameter `dt` provides time step size for whenever solver is between
7    /// `t_report` times.  ≥
8    EulerFixed { dt: f64 },
9    /// Heun's Method. (basic Runge-Kutta 2nd order with fixed time step)
10    HeunsMethod { dt: f64 },
11    /// Midpoint Method. ( alternate Runge-Kutta 2nd order with fixed time step)
12    MidpointMethod { dt: f64 },
13    /// Ralston's Method. ( alternate Runge-Kutta 2nd order with fixed time step)
14    RalstonsMethod { dt: f64 },
15    /// Bogacki-Shampine Method. Runge-Kutte 2/3 order adaptive solver
16    RK23BogackiShampine(Box<AdaptiveSolverConfig>),
17    /// Runge-Kutta 4th order with fixed time step
18    /// parameter `dt` provides time step size for whenever solver is between
19    /// `t_report` times.  
20    RK4Fixed { dt: f64 },
21    // TODO: add this stuff back into fixed options
22    // /// time step to use if `t_report` is larger than `dt`
23    // dt: f64,
24    /// Runge-Kutta 4/5 order adaptive, Cash-Karp method
25    /// https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method
26    RK45CashKarp(Box<AdaptiveSolverConfig>),
27    // TODO: add more variants here
28}
29
30impl Default for SolverTypes {
31    fn default() -> Self {
32        SolverTypes::RK4Fixed { dt: 0.1 }
33    }
34}
35#[pyo3_api(
36    #[new]
37    fn new_py(
38        dt_init: f64,
39        dt_max: f64,
40        max_iter: u8,
41        rtol: f64,
42        atol: f64,
43        save: Option<bool>,
44        save_states: Option<bool>,
45    ) -> Self {
46        Self{
47            dt_max,
48            max_iter,
49            atol,
50            rtol,
51            save: save.unwrap_or(false),
52            save_states: save_states.unwrap_or(false),
53            state: SolverState {
54                dt: dt_init,
55                ..Default::default()
56            },
57            history: Default::default(),
58        }
59    }
60
61    #[pyo3(name = "dt_mean")]
62    fn dt_mean_py(&self) -> Option<f64> {
63        self.dt_mean()
64    }
65)]
66#[common_derives]
67pub struct AdaptiveSolverConfig {
68    /// max allowable dt
69    pub dt_max: f64,
70    /// max number of iterations per time step
71    pub max_iter: u8,
72    /// absolute euclidean error tolerance
73    pub atol: f64,
74    /// relative euclidean error tolerance
75    pub rtol: f64,
76    /// save iteration history
77    pub save: bool,
78    /// save states in iteration history
79    /// this is computationally expensive and should be generally `false`
80    pub save_states: bool,
81    /// solver state
82    pub state: SolverState,
83    /// history of solver state
84    pub history: SolverStateHistoryVec,
85}
86
87impl Default for AdaptiveSolverConfig {
88    fn default() -> Self {
89        Self {
90            dt_max: 10.,
91            max_iter: 5,
92            rtol: 1e-5,
93            atol: 1e-9,
94            save: false,
95            save_states: false,
96            state: SolverState {
97                dt: 0.1,
98                ..Default::default()
99            },
100            history: Default::default(),
101        }
102    }
103}
104
105impl AdaptiveSolverConfig {
106    pub fn dt_mean(&self) -> Option<f64> {
107        if !self.history.is_empty() {
108            Some(self.history.dt.iter().fold(0., |acc, &x| acc + x) / self.history.len() as f64)
109        } else {
110            None
111        }
112    }
113}
114
115impl AsMut<AdaptiveSolverConfig> for AdaptiveSolverConfig {
116    fn as_mut(&mut self) -> &mut AdaptiveSolverConfig {
117        self
118    }
119}
120
121#[common_derives]
122#[pyo3_api]
123#[derive(HistoryVec)]
124/// Solver is considered considered converged when any one of the following conditions are met:
125/// - `norm_err` is less than `atol`
126/// - `norm_err_rel` is less than `rtol`
127/// - `n_iter` >= `n_max_iter`
128pub struct SolverState {
129    /// time step size used by solver
130    pub dt: f64,
131    /// number of iterations to achieve tolerance
132    pub n_iter: u8,
133    /// Absolute error based on difference in L2 (euclidean) norm
134    pub norm_err: Option<f64>,
135    /// Relative error based on difference in L2 (euclidean) norm
136    pub norm_err_rel: Option<f64>,
137    /// current system time used in solver
138    pub t_curr: f64,
139    /// current values of states
140    pub states: Vec<f64>,
141}
142
143impl Default for SolverState {
144    fn default() -> Self {
145        Self {
146            dt: 0.1,
147            n_iter: 0,
148            norm_err: None,
149            norm_err_rel: None,
150            t_curr: 0.,
151            states: Default::default(),
152        }
153    }
154}
155
156pub trait SolverBase: HasStates + Sized {
157    /// reset all time derivatives to zero for start of `solve_step`
158    fn reset_derivs(&mut self);
159    /// Updates time derivatives of states.
160    /// This method must be user defined.
161    fn update_derivs(&mut self);
162    /// steps dt without affecting states
163    fn step_time(&mut self, dt: &f64);
164    /// Returns `solver_conf`, if applicable
165    fn sc(&self) -> Option<&AdaptiveSolverConfig>;
166    /// Returns mut `solver_conf`, if applicable
167    fn sc_mut(&mut self) -> Option<&mut AdaptiveSolverConfig>;
168    /// Returns [Self::state]
169    fn state(&self) -> &crate::SystemState;
170}
171
172pub trait SolverVariantMethods: SolverBase {
173    /// Steps forward by `dt`
174    fn euler(&mut self, dt: &f64) {
175        self.update_derivs();
176        self.step_states_by_dt(dt);
177        self.update_derivs();
178    }
179    /// Heun's Method (starts out with Euler's method but adds an extra step)
180    /// See Heun's Method (the first listed Heun's method, not the one also known as Ralston's Method):
181    /// https://en.wikipedia.org/wiki/Heun%27s_method
182    fn heun(&mut self, dt: &f64) {
183        self.update_derivs();
184        //making copy without history, to avoid stepping dt twice
185        let mut updated_self = self.bare_clone();
186        //recording initial derivative value for later use
187        let deriv_0: Vec<f64> = self.derivs();
188        //this will give euler's formula result
189        self.step_states_by_dt(dt);
190        self.update_derivs();
191        //recording derivative at endpoint of euler's method line
192        let deriv_1: Vec<f64> = self.derivs();
193        //creating new vector that is average of deriv_1 and deriv_2
194        let deriv_mean: Vec<f64> = deriv_0
195            .iter()
196            .zip(&deriv_1)
197            .map(|(d_1, d_2)| d_1 * 0.5 + d_2 * 0.5)
198            .collect::<Vec<f64>>();
199        //updates derivative in updated_self to be the average of deriv_0 and deriv_1
200        updated_self.set_derivs(&deriv_mean);
201        //steps states using the average derivative
202        updated_self.step_states_by_dt(dt);
203        //saving updated state
204        let new_state = updated_self.states();
205        //setting state to be the updated state
206        self.set_states(new_state);
207        self.update_derivs();
208    }
209    /// Midpoint Method
210    /// See: https://en.wikipedia.org/wiki/Midpoint_method
211    fn midpoint(&mut self, dt: &f64) {
212        self.update_derivs();
213        //making copy without history, to avoid stepping dt twice
214        let mut updated_self = self.bare_clone();
215        //updating time and state to midpoint of line
216        updated_self.step_states_by_dt(&(0.5 * dt));
217        updated_self.update_derivs();
218        //recording derivative at midpoint
219        let deriv_1: Vec<f64> = updated_self.derivs();
220        //updates derivative in self to be deriv_1
221        self.set_derivs(&deriv_1);
222        //steps states using the midpoint derivative
223        self.step_states_by_dt(dt);
224        self.update_derivs();
225    }
226    /// Ralston's Method
227    /// See Ralston's Method: https://en.wikipedia.org/wiki/List_of_Runge%E2%80%93Kutta_methods#Ralston.27s_method
228    fn ralston(&mut self, dt: &f64) {
229        self.update_derivs();
230        //making copy without history, to avoid stepping dt twice
231        let mut updated_self = self.bare_clone();
232        //recording initial derivative for later
233        let deriv_0: Vec<f64> = updated_self.derivs();
234        //updating time and state to 2/3 way through line
235        updated_self.step_states_by_dt(&(2.0 * dt / 3.0));
236        updated_self.update_derivs();
237        //recording derivative at 2/3 way through line
238        let deriv_1: Vec<f64> = updated_self.derivs();
239        //creating new vector that is weighted average of deriv_0 and deriv_1
240        let deriv_mean: Vec<f64> = deriv_0
241            .iter()
242            .zip(&deriv_1)
243            .map(|(d_1, d_2)| d_1 / 4.0 + 3.0 * d_2 / 4.0)
244            .collect::<Vec<f64>>();
245        //updates derivative in self to be deriv_mean
246        self.set_derivs(&deriv_mean);
247        //steps states using deriv_mean
248        self.step_states_by_dt(dt);
249        self.update_derivs();
250    }
251    ///solves time step with adaptive Bogacki Shampine Method (variant of RK23) and returns 'dt' used
252    ///see: https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method
253    fn rk23_bogacki_shampine(&mut self, dt_max: &f64) -> f64 {
254        let sc_mut = self.sc_mut().unwrap();
255        // reset iteration counter
256        sc_mut.state.n_iter = 0;
257        sc_mut.state.dt = sc_mut.state.dt.min(*dt_max).min(sc_mut.dt_max);
258
259        // loop to find `dt` that results in meeting tolerance
260        // and does not exceed `dt_max`
261        let (delta3, dt_used) = loop {
262            let sc = self.sc().unwrap();
263            let dt = sc.state.dt;
264
265            // run a single step at `dt`
266            let (delta2, delta3) = self.rk23_bogacki_shampine_step(dt);
267
268            // reborrow because of the borrow above in `self.rk23_bogacki_shampine_step(dt);`
269            let sc = self.sc().unwrap();
270            // grab states for later use if solver steps are to be saved
271            let states = if sc.save {
272                self.states()
273                    .clone()
274                    .iter()
275                    .zip(delta3.clone())
276                    .map(|(s, d)| s + d)
277                    .collect::<Vec<f64>>()
278            } else {
279                vec![]
280            };
281
282            let t_curr = self.state().time;
283
284            // mutably borrow sc to update it
285            let sc_mut = self.sc_mut().unwrap();
286
287            // update `n_iter`, `norm_err`, `norm_err_rel`, `t_curr`, and `states`
288            // still need to update dt at some point
289            sc_mut.state.n_iter += 1;
290            // different way of calculating norm -- could add in via an enum later
291            // let mut length = 0.;
292            // for _item in &delta2 {
293            //     length += 1.;
294            // }
295            // sc_mut.state.norm_err = Some(
296            //     delta2
297            //         .iter()
298            //         .zip(&delta3)
299            //         .map(|(d2, d3)| (((d2 - d3).powi(2)).sqrt()))
300            //         .collect::<Vec<f64>>()
301            //         .iter()
302            //         .sum::<f64>()
303            //         / length,
304            // );
305            // let norm_d3 = delta3
306            //     .iter()
307            //     .map(|d3| (d3.powi(2)).sqrt())
308            //     .collect::<Vec<f64>>()
309            //     .iter()
310            //     .sum::<f64>()
311            //     / length;
312            sc_mut.state.norm_err = Some(
313                delta2
314                    .iter()
315                    .zip(&delta3)
316                    .map(|(d2, d3)| (d2 - d3).powi(2))
317                    .collect::<Vec<f64>>()
318                    .iter()
319                    .sum::<f64>()
320                    .sqrt(),
321            );
322            let norm_d3 = delta3
323                .iter()
324                .map(|d3| d3.powi(2))
325                .collect::<Vec<f64>>()
326                .iter()
327                .sum::<f64>()
328                .sqrt();
329            //making sure that rtol is always considered as long as you don't divide by 0
330            sc_mut.state.norm_err_rel = if norm_d3 != 0. {
331                // `unwrap` is ok here because `norm_err` will always be some by this point
332                Some(sc_mut.state.norm_err.unwrap() / norm_d3)
333            } else {
334                // avoid dividing by 0
335                None
336            };
337
338            sc_mut.state.t_curr = t_curr;
339
340            if sc_mut.save_states {
341                sc_mut.state.states = states;
342            }
343
344            // conditions for breaking loop
345            // if there is a relative error, use that
346            // otherwise, use the absolute error
347            let tol_met = match sc_mut.state.norm_err_rel {
348                Some(norm_err_rel) => norm_err_rel <= sc_mut.rtol,
349                None => match sc_mut.state.norm_err {
350                    Some(norm_err) => norm_err <= sc_mut.atol,
351                    None => unreachable!(),
352                },
353            };
354
355            // Because we need to be able to possibly expand the next time step,
356            // regardless of whether break condition is met,
357            // adapt dt based on `rtol` if it is Some; use `atol` otherwise
358            // this adaptation strategy came directly from Chapra and Canale's section on adapting the time step
359            // The approach is to adapt more aggressively to meet rtol when decreasing the time step size
360            // than when increasing time step size.
361            let dt_coeff = match sc_mut.state.norm_err_rel {
362                Some(norm_err_rel) => match sc_mut.state.norm_err {
363                    //ensures that if either rtol or atol are met, then the step succeeds
364                    //prioritizes rtol -- if both are met, then rtol is used
365                    //if no atol exists, just considers rtol
366                    Some(norm_err) => {
367                        if norm_err_rel <= sc_mut.rtol {
368                            (sc_mut.rtol / norm_err_rel).powf(0.2)
369                        } else if norm_err <= sc_mut.atol {
370                            (sc_mut.atol / norm_err).powf(0.2)
371                        } else {
372                            0.25
373                        }
374                    }
375                    // (sc_mut.rtol / norm_err_rel).powf(
376                    //     if norm_err_rel <= sc_mut.rtol || norm_err <= sc_mut.atol {
377                    //         0.2
378                    //     } else {
379                    //         0.25
380                    //     },
381                    // ),
382                    None => (sc_mut.rtol / norm_err_rel).powf(if norm_err_rel <= sc_mut.rtol {
383                        0.2
384                    } else {
385                        0.25
386                    }),
387                },
388                //if no rtol exists, just consideres atol
389                None => {
390                    match sc_mut.state.norm_err {
391                        Some(norm_err) => (sc_mut.atol / norm_err)
392                            .powf(if norm_err <= sc_mut.atol { 0.2 } else { 0.25 }),
393                        None => 1., // don't adapt if there is not enough information to do so (if neither atol or rtol exist)
394                    }
395                }
396            };
397            // if tolerance is achieved here, then we proceed to the next time step, and
398            // `dt` will be limited to `dt_max` at the start of the next time step.  If tolerance
399            // is not achieved, then time step will be decreased.
400            let break_cond = sc_mut.state.n_iter >= sc_mut.max_iter
401                || sc_mut.state.norm_err.unwrap() < sc_mut.atol
402                || tol_met;
403
404            if break_cond {
405                // save before modifying dt
406                if sc_mut.save {
407                    sc_mut.history.push(sc_mut.state.clone());
408                }
409                // store used dt before adapting
410                let dt_used = sc_mut.state.dt;
411                // adapt for next solver time step
412                sc_mut.state.dt *= dt_coeff;
413                break (delta3, dt_used);
414            };
415            // adapt for next iteration in current time step
416            sc_mut.state.dt *= dt_coeff;
417        };
418
419        // increment forward with 3rd order solution
420        self.step_states(delta3);
421        self.step_time(&dt_used);
422        self.update_derivs();
423        // dbg!(self.state.time);
424        // dbg!(self.t_report[self.state.i]);
425        dt_used
426    }
427    fn rk23_bogacki_shampine_step(&mut self, dt: f64) -> (Vec<f64>, Vec<f64>) {
428        self.update_derivs();
429
430        // k1 = f(t_i, x_i)
431        let k1s = self.derivs();
432
433        // k2 = f(t_i + 1 / 2 * h, x_i + 1 / 2 * k1 * h)
434        let mut sys1 = self.bare_clone();
435        sys1.step_states_by_dt(&(dt / 2.));
436        sys1.update_derivs();
437        let k2s = sys1.derivs();
438        // k3 = f(t_i + 3 / 4 * h, x_i + 3 / 4 * k2 * h)
439        let mut sys2 = self.bare_clone();
440        sys2.set_derivs(&k2s);
441        sys2.step_states_by_dt(&(dt * 3. / 4.));
442        sys2.update_derivs();
443        let k3s = sys2.derivs();
444        // k4 = f(x_i + h, y_i + 2 / 9 * k1 * h + 1 / 3 * k2 * h + 4 / 9 * k3 * h) = 3rd order solution
445        let mut sys3 = self.bare_clone();
446        sys3.step_time(&(dt));
447        // 3nd order delta
448        let delta3: Vec<f64> = {
449            let (k1s, k2s, k3s) = (k1s.clone(), k2s.clone(), k3s.clone());
450            let zipped = zip!(k1s, k2s, k3s);
451            let mut steps = vec![];
452            for (k1, (k2, k3)) in zipped {
453                steps.push((2. / 9. * k1 + 1. / 3. * k2 + 4. / 9. * k3) * dt);
454            }
455            steps
456        };
457        let delta3_new = delta3.clone();
458        sys3.step_states(delta3_new);
459        sys3.update_derivs();
460        let k4s = sys3.derivs();
461        // 2nd order delta
462        let mut delta2: Vec<f64> = vec![];
463        let zipped = zip!(k1s, k2s, k3s, k4s);
464        for (k1, (k2, (k3, k4))) in zipped {
465            delta2.push((7. / 24. * k1 + 1. / 4. * k2 + 1. / 3. * k3 + 1. / 8. * k4) * dt);
466        }
467        (delta2, delta3)
468    }
469    /// solves time step with 4th order Runge-Kutta method.
470    /// See RK4 method: https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Examples
471    fn rk4fixed(&mut self, dt: &f64) {
472        self.update_derivs();
473
474        // k1 = f(x_i, y_i)
475        let k1s = self.derivs();
476
477        // k2 = f(x_i + 1 / 2 * h, y_i + 1 / 2 * k1 * h)
478        let mut sys1 = self.bare_clone();
479        sys1.step_states_by_dt(&(dt / 2.));
480        sys1.update_derivs();
481        let k2s = sys1.derivs();
482
483        // k3 = f(x_i + 1 / 2 * h, y_i + 1 / 2 * k2 * h)
484        let mut sys2 = self.bare_clone();
485        sys2.set_derivs(&k2s);
486        sys2.step_states_by_dt(&(dt / 2.));
487        sys2.update_derivs();
488        let k3s = sys2.derivs();
489
490        // k4 = f(x_i + h, y_i + k3 * h)
491        let mut sys3 = self.bare_clone();
492        sys3.set_derivs(&k3s);
493        sys3.step_states_by_dt(dt);
494        sys3.update_derivs();
495        let k4s = sys3.derivs();
496
497        let mut delta: Vec<f64> = vec![];
498        let zipped = zip!(k1s, k2s, k3s, k4s);
499        for (k1, (k2, (k3, k4))) in zipped {
500            delta.push(1. / 6. * (k1 + 2. * k2 + 2. * k3 + k4) * dt);
501        }
502
503        self.step_states(delta);
504        self.step_time(dt);
505        self.update_derivs();
506    }
507    /// solves time step with adaptive Cash-Karp Method (variant of RK45) and returns `dt` used
508    /// https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method
509    fn rk45_cash_karp(&mut self, dt_max: &f64) -> f64 {
510        let sc_mut = self.sc_mut().unwrap();
511        // reset iteration counter
512        sc_mut.state.n_iter = 0;
513        sc_mut.state.dt = sc_mut.state.dt.min(*dt_max).min(sc_mut.dt_max);
514
515        // loop to find `dt` that results in meeting tolerance
516        // and does not exceed `dt_max`
517        let (delta5, dt_used) = loop {
518            let sc = self.sc().unwrap();
519            let dt = sc.state.dt;
520
521            // run a single step at `dt`
522            let (delta4, delta5) = self.rk45_cash_karp_step(dt);
523
524            // reborrow because of the borrow above in `self.rk45_cash_karp_step(dt);`
525            let sc = self.sc().unwrap();
526            // grab states for later use if solver steps are to be saved
527            let states = if sc.save {
528                self.states()
529                    .clone()
530                    .iter()
531                    .zip(delta5.clone())
532                    .map(|(s, d)| s + d)
533                    .collect::<Vec<f64>>()
534            } else {
535                vec![]
536            };
537
538            let t_curr = self.state().time;
539
540            // mutably borrow sc to update it
541            let sc_mut = self.sc_mut().unwrap();
542
543            // update `n_iter`, `norm_err`, `norm_err_rel`, `t_curr`, and `states`
544            // still need to update dt at some point
545            sc_mut.state.n_iter += 1;
546            //another way to calculate norm -- can be added in later via an enum
547            // let mut length = 0.;
548            // for _item in &delta4 {
549            //     length += 1.;
550            // }
551            // sc_mut.state.norm_err = Some(
552            //     delta4
553            //         .iter()
554            //         .zip(&delta5)
555            //         .map(|(d4, d5)| (((d4 - d5).powi(2)).sqrt()))
556            //         .collect::<Vec<f64>>()
557            //         .iter()
558            //         .sum::<f64>()
559            //         / length,
560            // );
561            // let norm_d5 = delta5
562            //     .iter()
563            //     .map(|d5| (d5.powi(2)).sqrt())
564            //     .collect::<Vec<f64>>()
565            //     .iter()
566            //     .sum::<f64>()
567            //     / length;
568            sc_mut.state.norm_err = Some(
569                delta4
570                    .iter()
571                    .zip(&delta5)
572                    .map(|(d4, d5)| (d4 - d5).powi(2))
573                    .collect::<Vec<f64>>()
574                    .iter()
575                    .sum::<f64>()
576                    .sqrt(),
577            );
578            let norm_d5 = delta5
579                .iter()
580                .map(|d5| d5.powi(2))
581                .collect::<Vec<f64>>()
582                .iter()
583                .sum::<f64>()
584                .sqrt();
585            //ensures that rtol is calculated and considered as long as you are not dividing by 0
586            sc_mut.state.norm_err_rel = if norm_d5 != 0. {
587                // `unwrap` is ok here because `norm_err` will always be some by this point
588                Some(sc_mut.state.norm_err.unwrap() / norm_d5)
589            } else {
590                // avoid dividing by 0
591                None
592            };
593
594            sc_mut.state.t_curr = t_curr;
595
596            if sc_mut.save_states {
597                sc_mut.state.states = states;
598            }
599
600            // conditions for breaking loop
601            // if there is a relative error, use that
602            // otherwise, use the absolute error
603            let tol_met = match sc_mut.state.norm_err_rel {
604                Some(norm_err_rel) => norm_err_rel <= sc_mut.rtol,
605                None => match sc_mut.state.norm_err {
606                    Some(norm_err) => norm_err <= sc_mut.atol,
607                    None => unreachable!(),
608                },
609            };
610
611            // Because we need to be able to possibly expand the next time step,
612            // regardless of whether break condition is met,
613            // adapt dt based on `rtol` if it is Some; use `atol` otherwise
614            // this adaptation strategy came directly from Chapra and Canale's section on adapting the time step
615            // The approach is to adapt more aggressively to meet rtol when decreasing the time step size
616            // than when increasing time step size.
617            let dt_coeff = match sc_mut.state.norm_err_rel {
618                Some(norm_err_rel) => {
619                    //ensures that if either rtol or atol are met, then the step succeeds
620                    //prioritizes rtol -- if both atol and rtol are met, rtol is used
621                    if norm_err_rel <= sc_mut.rtol {
622                        (sc_mut.rtol / norm_err_rel).powf(0.2)
623                    } else if sc_mut.state.norm_err.unwrap() <= sc_mut.atol {
624                        (sc_mut.atol / sc_mut.state.norm_err.unwrap()).powf(0.2)
625                    } else {
626                        0.25
627                    }
628                    // (sc_mut.rtol / norm_err_rel).powf(
629                    //     if norm_err_rel <= sc_mut.rtol || norm_err <= sc_mut.atol {
630                    //         0.2
631                    //     } else {
632                    //         0.25
633                    //     },
634                    // ),
635                }
636                //if rtol doesn't exist just use atol
637                None => {
638                    match sc_mut.state.norm_err {
639                        Some(norm_err) => (sc_mut.atol / norm_err)
640                            .powf(if norm_err <= sc_mut.atol { 0.2 } else { 0.25 }),
641                        None => 1., // don't adapt if there is not enough information to do so
642                    }
643                }
644            };
645
646            // if tolerance is achieved here, then we proceed to the next time step, and
647            // `dt` will be limited to `dt_max` at the start of the next time step.  If tolerance
648            // is not achieved, then time step will be decreased.
649            let break_cond = sc_mut.state.n_iter >= sc_mut.max_iter
650                || sc_mut.state.norm_err.unwrap() < sc_mut.atol
651                || tol_met;
652
653            if break_cond {
654                // save before modifying dt
655                if sc_mut.save {
656                    sc_mut.history.push(sc_mut.state.clone());
657                }
658                // store used dt before adapting
659                let dt_used = sc_mut.state.dt;
660                // adapt for next solver time step
661                sc_mut.state.dt *= dt_coeff;
662                break (delta5, dt_used);
663            };
664            // adapt for next iteration in current time step
665            sc_mut.state.dt *= dt_coeff;
666        };
667
668        // increment forward with 5th order solution
669        self.step_states(delta5);
670        self.step_time(&dt_used);
671        self.update_derivs();
672        // dbg!(self.state.time);
673        // dbg!(self.t_report[self.state.i]);
674        dt_used
675    }
676
677    fn rk45_cash_karp_step(&mut self, dt: f64) -> (Vec<f64>, Vec<f64>) {
678        self.update_derivs();
679
680        // k1 = f(x_i, y_i)
681        let k1s = self.derivs();
682
683        // k2 = f(x_i + 1 / 5 * h, y_i + 1 / 5 * k1 * h)
684        let mut sys1 = self.bare_clone();
685        sys1.step_states_by_dt(&(dt / 5.));
686        sys1.update_derivs();
687        let k2s = sys1.derivs();
688
689        // k3 = f(x_i + 3 / 10 * h, y_i + 3 / 40 * k1 * h + 9 / 40 * k2 * h)
690        let mut sys2 = self.bare_clone();
691        sys2.step_time(&(dt * 3. / 10.));
692        sys2.step_states(
693            k1s.iter()
694                .zip(k2s.clone())
695                .map(|(k1, k2)| (3. / 40. * k1 + 9. / 40. * k2) * dt)
696                .collect(),
697        );
698        sys2.update_derivs();
699        let k3s = sys2.derivs();
700
701        // k4 = f(x_i + 3 / 5 * h, y_i + 3 / 10 * k1 * h - 9 / 10 * k2 * h + 6 / 5 * k3 * h)
702        let mut sys3 = self.bare_clone();
703        sys3.step_time(&(dt * 3. / 5.));
704        sys3.step_states({
705            let (k1s, k2s, k3s) = (k1s.clone(), k2s.clone(), k3s.clone());
706            let zipped = zip!(k1s, k2s, k3s);
707            let mut steps = vec![];
708            for (k1, (k2, k3)) in zipped {
709                steps.push((3. / 10. * k1 - 9. / 10. * k2 + 6. / 5. * k3) * dt);
710            }
711            steps
712        });
713        sys3.update_derivs();
714        let k4s = sys3.derivs();
715
716        // k5 = f(x_i + h, y_i - 11 / 54 * k1 * h + 5 / 2 * k2 * h - 70 / 27 * k3 * h + 35 / 27 * k4 * h)
717        let mut sys4 = self.bare_clone();
718        sys4.step_time(&dt);
719        sys4.step_states({
720            let (k1s, k2s, k3s, k4s) = (k1s.clone(), k2s.clone(), k3s.clone(), k4s.clone());
721            let zipped = zip!(k1s, k2s, k3s, k4s);
722            let mut steps = vec![];
723            for (k1, (k2, (k3, k4))) in zipped {
724                steps.push((-11. / 54. * k1 + 5. / 2. * k2 - 70. / 27. * k3 + 35. / 27. * k4) * dt);
725            }
726            steps
727        });
728        sys4.update_derivs();
729        let k5s = sys4.derivs();
730
731        // k6 = f(x_i + 7 / 8 * h, y_i + 1631 / 55296 * k1 * h + 175 / 512 * k2 * h + 575 / 13824 * k3 * h + 44275 / 110592 * k4 * h + 253 / 4096 * k5 * h)
732        let mut sys5 = self.bare_clone();
733        sys5.step_time(&(dt * 7. / 8.));
734        sys5.step_states({
735            let (k1s, k2s, k3s, k4s, k5s) = (
736                k1s.clone(),
737                k2s.clone(),
738                k3s.clone(),
739                k4s.clone(),
740                k5s.clone(),
741            );
742            let zipped = zip!(k1s, k2s, k3s, k4s, k5s);
743            let mut steps = vec![];
744            for (k1, (k2, (k3, (k4, k5)))) in zipped {
745                steps.push(
746                    (1_631. / 55_296. * k1
747                        + 175. / 512. * k2
748                        + 575. / 13_824. * k3
749                        + 44_275. / 110_592. * k4
750                        + 253. / 4096. * k5)
751                        * dt,
752                );
753            }
754            steps
755        });
756        sys5.update_derivs();
757        let k6s = sys5.derivs();
758
759        // 4th order delta
760        let mut delta4: Vec<f64> = vec![];
761        // 5th order delta
762        let mut delta5: Vec<f64> = vec![];
763        let zipped = zip!(k1s, k2s, k3s, k4s, k5s, k6s);
764        for (k1, (_k2, (k3, (k4, (k5, k6))))) in zipped {
765            delta5.push(
766                (37. / 378. * k1 + 250. / 621. * k3 + 125. / 594. * k4 + 512. / 1_771. * k6) * dt,
767            );
768            delta4.push(
769                (2825. / 27_648. * k1
770                    + 18_575. / 48_384. * k3
771                    + 13_525. / 55_296. * k4
772                    + 277. / 14_336. * k5
773                    + 1. / 4. * k6)
774                    * dt,
775            );
776        }
777        (delta4, delta5)
778    }
779}