// dess_core/solver.rs
1use crate::imports::*;
2
#[common_derives]
/// Available numerical integration schemes.
pub enum SolverTypes {
    /// Euler with fixed time step.
    /// parameter `dt` provides time step size for whenever solver is between
    /// `t_report` times.
    EulerFixed { dt: f64 },
    /// Heun's Method. (basic Runge-Kutta 2nd order with fixed time step)
    HeunsMethod { dt: f64 },
    /// Midpoint Method. (alternate Runge-Kutta 2nd order with fixed time step)
    MidpointMethod { dt: f64 },
    /// Ralston's Method. (alternate Runge-Kutta 2nd order with fixed time step)
    RalstonsMethod { dt: f64 },
    /// Bogacki-Shampine Method. Runge-Kutta 2/3 order adaptive solver
    RK23BogackiShampine(Box<AdaptiveSolverConfig>),
    /// Runge-Kutta 4th order with fixed time step
    /// parameter `dt` provides time step size for whenever solver is between
    /// `t_report` times.
    RK4Fixed { dt: f64 },
    // TODO: add this stuff back into fixed options
    // /// time step to use if `t_report` is larger than `dt`
    // dt: f64,
    /// Runge-Kutta 4/5 order adaptive, Cash-Karp method
    /// <https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method>
    RK45CashKarp(Box<AdaptiveSolverConfig>),
    // TODO: add more variants here
}
29
30impl Default for SolverTypes {
31 fn default() -> Self {
32 SolverTypes::RK4Fixed { dt: 0.1 }
33 }
34}
#[pyo3_api(
    #[new]
    fn new_py(
        dt_init: f64,
        dt_max: f64,
        max_iter: u8,
        rtol: f64,
        atol: f64,
        save: Option<bool>,
        save_states: Option<bool>,
    ) -> Self {
        Self{
            dt_max,
            max_iter,
            atol,
            rtol,
            save: save.unwrap_or(false),
            save_states: save_states.unwrap_or(false),
            state: SolverState {
                dt: dt_init,
                ..Default::default()
            },
            history: Default::default(),
        }
    }

    #[pyo3(name = "dt_mean")]
    fn dt_mean_py(&self) -> Option<f64> {
        self.dt_mean()
    }
)]
#[common_derives]
/// Configuration and runtime state for the adaptive (variable time step)
/// solver variants (`RK23BogackiShampine`, `RK45CashKarp`).
///
/// The `pyo3_api` attribute above exposes a Python constructor (`new_py`,
/// with `save`/`save_states` defaulting to `false`) and `dt_mean`.
pub struct AdaptiveSolverConfig {
    /// max allowable dt
    pub dt_max: f64,
    /// max number of iterations per time step
    pub max_iter: u8,
    /// absolute euclidean error tolerance
    pub atol: f64,
    /// relative euclidean error tolerance
    pub rtol: f64,
    /// save iteration history
    pub save: bool,
    /// save states in iteration history
    /// this is computationally expensive and should be generally `false`
    pub save_states: bool,
    /// solver state
    pub state: SolverState,
    /// history of solver state
    pub history: SolverStateHistoryVec,
}
86
87impl Default for AdaptiveSolverConfig {
88 fn default() -> Self {
89 Self {
90 dt_max: 10.,
91 max_iter: 5,
92 rtol: 1e-5,
93 atol: 1e-9,
94 save: false,
95 save_states: false,
96 state: SolverState {
97 dt: 0.1,
98 ..Default::default()
99 },
100 history: Default::default(),
101 }
102 }
103}
104
105impl AdaptiveSolverConfig {
106 pub fn dt_mean(&self) -> Option<f64> {
107 if !self.history.is_empty() {
108 Some(self.history.dt.iter().fold(0., |acc, &x| acc + x) / self.history.len() as f64)
109 } else {
110 None
111 }
112 }
113}
114
/// Identity conversion so generic code can take `impl AsMut<AdaptiveSolverConfig>`
/// and accept either the config itself or a wrapper that derefs to it.
impl AsMut<AdaptiveSolverConfig> for AdaptiveSolverConfig {
    fn as_mut(&mut self) -> &mut AdaptiveSolverConfig {
        self
    }
}
120
#[common_derives]
#[pyo3_api]
#[derive(HistoryVec)]
/// Solver is considered converged when any one of the following conditions are met:
/// - `norm_err` is less than `atol`
/// - `norm_err_rel` is less than `rtol`
/// - `n_iter` >= `n_max_iter`
pub struct SolverState {
    /// time step size used by solver
    pub dt: f64,
    /// number of iterations to achieve tolerance
    pub n_iter: u8,
    /// Absolute error based on difference in L2 (euclidean) norm
    /// (`None` until the first adaptive iteration computes it)
    pub norm_err: Option<f64>,
    /// Relative error based on difference in L2 (euclidean) norm
    /// (`None` when the reference norm is zero, to avoid dividing by zero)
    pub norm_err_rel: Option<f64>,
    /// current system time used in solver
    pub t_curr: f64,
    /// current values of states
    pub states: Vec<f64>,
}
142
143impl Default for SolverState {
144 fn default() -> Self {
145 Self {
146 dt: 0.1,
147 n_iter: 0,
148 norm_err: None,
149 norm_err_rel: None,
150 t_curr: 0.,
151 states: Default::default(),
152 }
153 }
154}
155
/// Core interface a system must implement so the solver variants in
/// [`SolverVariantMethods`] can advance it in time.
pub trait SolverBase: HasStates + Sized {
    /// reset all time derivatives to zero for start of `solve_step`
    fn reset_derivs(&mut self);
    /// Updates time derivatives of states.
    /// This method must be user defined.
    fn update_derivs(&mut self);
    /// steps dt without affecting states
    fn step_time(&mut self, dt: &f64);
    /// Returns `solver_conf`, if applicable (i.e. only for adaptive variants)
    fn sc(&self) -> Option<&AdaptiveSolverConfig>;
    /// Returns mut `solver_conf`, if applicable
    fn sc_mut(&mut self) -> Option<&mut AdaptiveSolverConfig>;
    /// Returns [Self::state]
    fn state(&self) -> &crate::SystemState;
}
171
/// Integration-step implementations shared by all solver variants.
///
/// NOTE(review): these methods rely on state helpers not declared in
/// `SolverBase` as shown here — `bare_clone`, `derivs`, `set_derivs`,
/// `states`, `set_states`, `step_states`, `step_states_by_dt` — presumably
/// supplied by the `HasStates` supertrait; confirm against that trait.
pub trait SolverVariantMethods: SolverBase {
    /// Steps forward by `dt` using Euler's (1st order) method.
    fn euler(&mut self, dt: &f64) {
        self.update_derivs();
        self.step_states_by_dt(dt);
        self.update_derivs();
    }
    /// Heun's Method (starts out with Euler's method but adds an extra step)
    /// See Heun's Method (the first listed Heun's method, not the one also known as Ralston's Method):
    /// https://en.wikipedia.org/wiki/Heun%27s_method
    fn heun(&mut self, dt: &f64) {
        self.update_derivs();
        // making copy without history, to avoid stepping dt twice
        let mut updated_self = self.bare_clone();
        // recording initial derivative value for later use
        let deriv_0: Vec<f64> = self.derivs();
        // this will give euler's formula result
        self.step_states_by_dt(dt);
        self.update_derivs();
        // recording derivative at endpoint of euler's method line
        let deriv_1: Vec<f64> = self.derivs();
        // creating new vector that is the average of deriv_0 and deriv_1
        let deriv_mean: Vec<f64> = deriv_0
            .iter()
            .zip(&deriv_1)
            .map(|(d_1, d_2)| d_1 * 0.5 + d_2 * 0.5)
            .collect::<Vec<f64>>();
        // updates derivative in updated_self to be the average of deriv_0 and deriv_1
        updated_self.set_derivs(&deriv_mean);
        // steps states using the average derivative
        updated_self.step_states_by_dt(dt);
        // saving updated state
        let new_state = updated_self.states();
        // setting state to be the updated state
        self.set_states(new_state);
        self.update_derivs();
    }
    /// Midpoint Method
    /// See: https://en.wikipedia.org/wiki/Midpoint_method
    fn midpoint(&mut self, dt: &f64) {
        self.update_derivs();
        // making copy without history, to avoid stepping dt twice
        let mut updated_self = self.bare_clone();
        // updating time and state to midpoint of line
        updated_self.step_states_by_dt(&(0.5 * dt));
        updated_self.update_derivs();
        // recording derivative at midpoint
        let deriv_1: Vec<f64> = updated_self.derivs();
        // updates derivative in self to be deriv_1
        self.set_derivs(&deriv_1);
        // steps states using the midpoint derivative
        self.step_states_by_dt(dt);
        self.update_derivs();
    }
    /// Ralston's Method
    /// See Ralston's Method: https://en.wikipedia.org/wiki/List_of_Runge%E2%80%93Kutta_methods#Ralston.27s_method
    fn ralston(&mut self, dt: &f64) {
        self.update_derivs();
        // making copy without history, to avoid stepping dt twice
        let mut updated_self = self.bare_clone();
        // recording initial derivative for later
        let deriv_0: Vec<f64> = updated_self.derivs();
        // updating time and state to 2/3 way through line
        updated_self.step_states_by_dt(&(2.0 * dt / 3.0));
        updated_self.update_derivs();
        // recording derivative at 2/3 way through line
        let deriv_1: Vec<f64> = updated_self.derivs();
        // creating new vector that is weighted average (1/4, 3/4) of deriv_0 and deriv_1
        let deriv_mean: Vec<f64> = deriv_0
            .iter()
            .zip(&deriv_1)
            .map(|(d_1, d_2)| d_1 / 4.0 + 3.0 * d_2 / 4.0)
            .collect::<Vec<f64>>();
        // updates derivative in self to be deriv_mean
        self.set_derivs(&deriv_mean);
        // steps states using deriv_mean
        self.step_states_by_dt(dt);
        self.update_derivs();
    }
    /// Solves time step with adaptive Bogacki-Shampine Method (variant of RK23)
    /// and returns the `dt` actually used for the accepted step.
    /// see: https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method
    fn rk23_bogacki_shampine(&mut self, dt_max: &f64) -> f64 {
        // `unwrap` is intentional: this method is only reachable for adaptive
        // variants, which always carry an `AdaptiveSolverConfig`
        let sc_mut = self.sc_mut().unwrap();
        // reset iteration counter
        sc_mut.state.n_iter = 0;
        // clamp dt to both the caller-supplied cap and the configured cap
        sc_mut.state.dt = sc_mut.state.dt.min(*dt_max).min(sc_mut.dt_max);

        // loop to find `dt` that results in meeting tolerance
        // and does not exceed `dt_max`
        let (delta3, dt_used) = loop {
            let sc = self.sc().unwrap();
            let dt = sc.state.dt;

            // run a single step at `dt`
            let (delta2, delta3) = self.rk23_bogacki_shampine_step(dt);

            // reborrow because of the borrow above in `self.rk23_bogacki_shampine_step(dt);`
            let sc = self.sc().unwrap();
            // grab states for later use if solver steps are to be saved
            let states = if sc.save {
                self.states()
                    .clone()
                    .iter()
                    .zip(delta3.clone())
                    .map(|(s, d)| s + d)
                    .collect::<Vec<f64>>()
            } else {
                vec![]
            };

            let t_curr = self.state().time;

            // mutably borrow sc to update it
            let sc_mut = self.sc_mut().unwrap();

            // update `n_iter`, `norm_err`, `norm_err_rel`, `t_curr`, and `states`
            // still need to update dt at some point
            sc_mut.state.n_iter += 1;
            // different way of calculating norm -- could add in via an enum later
            // let mut length = 0.;
            // for _item in &delta2 {
            //     length += 1.;
            // }
            // sc_mut.state.norm_err = Some(
            //     delta2
            //         .iter()
            //         .zip(&delta3)
            //         .map(|(d2, d3)| (((d2 - d3).powi(2)).sqrt()))
            //         .collect::<Vec<f64>>()
            //         .iter()
            //         .sum::<f64>()
            //         / length,
            // );
            // let norm_d3 = delta3
            //     .iter()
            //     .map(|d3| (d3.powi(2)).sqrt())
            //     .collect::<Vec<f64>>()
            //     .iter()
            //     .sum::<f64>()
            //     / length;
            // absolute error: L2 norm of (2nd order delta - 3rd order delta)
            sc_mut.state.norm_err = Some(
                delta2
                    .iter()
                    .zip(&delta3)
                    .map(|(d2, d3)| (d2 - d3).powi(2))
                    .collect::<Vec<f64>>()
                    .iter()
                    .sum::<f64>()
                    .sqrt(),
            );
            // L2 norm of the 3rd order delta, used as the reference for relative error
            let norm_d3 = delta3
                .iter()
                .map(|d3| d3.powi(2))
                .collect::<Vec<f64>>()
                .iter()
                .sum::<f64>()
                .sqrt();
            // making sure that rtol is always considered as long as you don't divide by 0
            sc_mut.state.norm_err_rel = if norm_d3 != 0. {
                // `unwrap` is ok here because `norm_err` will always be some by this point
                Some(sc_mut.state.norm_err.unwrap() / norm_d3)
            } else {
                // avoid dividing by 0
                None
            };

            sc_mut.state.t_curr = t_curr;

            if sc_mut.save_states {
                sc_mut.state.states = states;
            }

            // conditions for breaking loop
            // if there is a relative error, use that
            // otherwise, use the absolute error
            let tol_met = match sc_mut.state.norm_err_rel {
                Some(norm_err_rel) => norm_err_rel <= sc_mut.rtol,
                None => match sc_mut.state.norm_err {
                    Some(norm_err) => norm_err <= sc_mut.atol,
                    None => unreachable!(),
                },
            };

            // Because we need to be able to possibly expand the next time step,
            // regardless of whether break condition is met,
            // adapt dt based on `rtol` if it is Some; use `atol` otherwise
            // this adaptation strategy came directly from Chapra and Canale's section on adapting the time step
            // The approach is to adapt more aggressively to meet rtol when decreasing the time step size
            // than when increasing time step size (exponent 0.2 to grow, flat 0.25 factor to shrink).
            let dt_coeff = match sc_mut.state.norm_err_rel {
                Some(norm_err_rel) => match sc_mut.state.norm_err {
                    // ensures that if either rtol or atol are met, then the step succeeds
                    // prioritizes rtol -- if both are met, then rtol is used
                    // if no atol exists, just considers rtol
                    Some(norm_err) => {
                        if norm_err_rel <= sc_mut.rtol {
                            (sc_mut.rtol / norm_err_rel).powf(0.2)
                        } else if norm_err <= sc_mut.atol {
                            (sc_mut.atol / norm_err).powf(0.2)
                        } else {
                            0.25
                        }
                    }
                    // (sc_mut.rtol / norm_err_rel).powf(
                    //     if norm_err_rel <= sc_mut.rtol || norm_err <= sc_mut.atol {
                    //         0.2
                    //     } else {
                    //         0.25
                    //     },
                    // ),
                    None => (sc_mut.rtol / norm_err_rel).powf(if norm_err_rel <= sc_mut.rtol {
                        0.2
                    } else {
                        0.25
                    }),
                },
                // if no rtol exists, just considers atol
                None => {
                    match sc_mut.state.norm_err {
                        Some(norm_err) => (sc_mut.atol / norm_err)
                            .powf(if norm_err <= sc_mut.atol { 0.2 } else { 0.25 }),
                        None => 1., // don't adapt if there is not enough information to do so (if neither atol or rtol exist)
                    }
                }
            };
            // if tolerance is achieved here, then we proceed to the next time step, and
            // `dt` will be limited to `dt_max` at the start of the next time step. If tolerance
            // is not achieved, then time step will be decreased.
            let break_cond = sc_mut.state.n_iter >= sc_mut.max_iter
                || sc_mut.state.norm_err.unwrap() < sc_mut.atol
                || tol_met;

            if break_cond {
                // save before modifying dt
                if sc_mut.save {
                    sc_mut.history.push(sc_mut.state.clone());
                }
                // store used dt before adapting
                let dt_used = sc_mut.state.dt;
                // adapt for next solver time step
                sc_mut.state.dt *= dt_coeff;
                break (delta3, dt_used);
            };
            // adapt for next iteration in current time step
            sc_mut.state.dt *= dt_coeff;
        };

        // increment forward with 3rd order solution
        self.step_states(delta3);
        self.step_time(&dt_used);
        self.update_derivs();
        // dbg!(self.state.time);
        // dbg!(self.t_report[self.state.i]);
        dt_used
    }
    /// Performs one trial Bogacki-Shampine step of size `dt` (without committing
    /// it to `self`) and returns the (2nd order, 3rd order) state deltas,
    /// whose difference is the embedded error estimate.
    fn rk23_bogacki_shampine_step(&mut self, dt: f64) -> (Vec<f64>, Vec<f64>) {
        self.update_derivs();

        // k1 = f(t_i, x_i)
        let k1s = self.derivs();

        // k2 = f(t_i + 1 / 2 * h, x_i + 1 / 2 * k1 * h)
        let mut sys1 = self.bare_clone();
        sys1.step_states_by_dt(&(dt / 2.));
        sys1.update_derivs();
        let k2s = sys1.derivs();
        // k3 = f(t_i + 3 / 4 * h, x_i + 3 / 4 * k2 * h)
        let mut sys2 = self.bare_clone();
        sys2.set_derivs(&k2s);
        sys2.step_states_by_dt(&(dt * 3. / 4.));
        sys2.update_derivs();
        let k3s = sys2.derivs();
        // k4 = f(x_i + h, y_i + 2 / 9 * k1 * h + 1 / 3 * k2 * h + 4 / 9 * k3 * h) = 3rd order solution
        let mut sys3 = self.bare_clone();
        sys3.step_time(&(dt));
        // 3rd order delta
        let delta3: Vec<f64> = {
            let (k1s, k2s, k3s) = (k1s.clone(), k2s.clone(), k3s.clone());
            let zipped = zip!(k1s, k2s, k3s);
            let mut steps = vec![];
            for (k1, (k2, k3)) in zipped {
                steps.push((2. / 9. * k1 + 1. / 3. * k2 + 4. / 9. * k3) * dt);
            }
            steps
        };
        let delta3_new = delta3.clone();
        sys3.step_states(delta3_new);
        sys3.update_derivs();
        let k4s = sys3.derivs();
        // 2nd order delta
        let mut delta2: Vec<f64> = vec![];
        let zipped = zip!(k1s, k2s, k3s, k4s);
        for (k1, (k2, (k3, k4))) in zipped {
            delta2.push((7. / 24. * k1 + 1. / 4. * k2 + 1. / 3. * k3 + 1. / 8. * k4) * dt);
        }
        (delta2, delta3)
    }
    /// solves time step with 4th order Runge-Kutta method.
    /// See RK4 method: https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Examples
    fn rk4fixed(&mut self, dt: &f64) {
        self.update_derivs();

        // k1 = f(x_i, y_i)
        let k1s = self.derivs();

        // k2 = f(x_i + 1 / 2 * h, y_i + 1 / 2 * k1 * h)
        let mut sys1 = self.bare_clone();
        sys1.step_states_by_dt(&(dt / 2.));
        sys1.update_derivs();
        let k2s = sys1.derivs();

        // k3 = f(x_i + 1 / 2 * h, y_i + 1 / 2 * k2 * h)
        let mut sys2 = self.bare_clone();
        sys2.set_derivs(&k2s);
        sys2.step_states_by_dt(&(dt / 2.));
        sys2.update_derivs();
        let k3s = sys2.derivs();

        // k4 = f(x_i + h, y_i + k3 * h)
        let mut sys3 = self.bare_clone();
        sys3.set_derivs(&k3s);
        sys3.step_states_by_dt(dt);
        sys3.update_derivs();
        let k4s = sys3.derivs();

        // classic RK4 weighted combination: (k1 + 2*k2 + 2*k3 + k4) / 6
        let mut delta: Vec<f64> = vec![];
        let zipped = zip!(k1s, k2s, k3s, k4s);
        for (k1, (k2, (k3, k4))) in zipped {
            delta.push(1. / 6. * (k1 + 2. * k2 + 2. * k3 + k4) * dt);
        }

        self.step_states(delta);
        self.step_time(dt);
        self.update_derivs();
    }
    /// solves time step with adaptive Cash-Karp Method (variant of RK45) and returns `dt` used
    /// https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method
    fn rk45_cash_karp(&mut self, dt_max: &f64) -> f64 {
        // `unwrap` is intentional: adaptive variants always carry a config
        let sc_mut = self.sc_mut().unwrap();
        // reset iteration counter
        sc_mut.state.n_iter = 0;
        // clamp dt to both the caller-supplied cap and the configured cap
        sc_mut.state.dt = sc_mut.state.dt.min(*dt_max).min(sc_mut.dt_max);

        // loop to find `dt` that results in meeting tolerance
        // and does not exceed `dt_max`
        let (delta5, dt_used) = loop {
            let sc = self.sc().unwrap();
            let dt = sc.state.dt;

            // run a single step at `dt`
            let (delta4, delta5) = self.rk45_cash_karp_step(dt);

            // reborrow because of the borrow above in `self.rk45_cash_karp_step(dt);`
            let sc = self.sc().unwrap();
            // grab states for later use if solver steps are to be saved
            let states = if sc.save {
                self.states()
                    .clone()
                    .iter()
                    .zip(delta5.clone())
                    .map(|(s, d)| s + d)
                    .collect::<Vec<f64>>()
            } else {
                vec![]
            };

            let t_curr = self.state().time;

            // mutably borrow sc to update it
            let sc_mut = self.sc_mut().unwrap();

            // update `n_iter`, `norm_err`, `norm_err_rel`, `t_curr`, and `states`
            // still need to update dt at some point
            sc_mut.state.n_iter += 1;
            // another way to calculate norm -- can be added in later via an enum
            // let mut length = 0.;
            // for _item in &delta4 {
            //     length += 1.;
            // }
            // sc_mut.state.norm_err = Some(
            //     delta4
            //         .iter()
            //         .zip(&delta5)
            //         .map(|(d4, d5)| (((d4 - d5).powi(2)).sqrt()))
            //         .collect::<Vec<f64>>()
            //         .iter()
            //         .sum::<f64>()
            //         / length,
            // );
            // let norm_d5 = delta5
            //     .iter()
            //     .map(|d5| (d5.powi(2)).sqrt())
            //     .collect::<Vec<f64>>()
            //     .iter()
            //     .sum::<f64>()
            //     / length;
            // absolute error: L2 norm of (4th order delta - 5th order delta)
            sc_mut.state.norm_err = Some(
                delta4
                    .iter()
                    .zip(&delta5)
                    .map(|(d4, d5)| (d4 - d5).powi(2))
                    .collect::<Vec<f64>>()
                    .iter()
                    .sum::<f64>()
                    .sqrt(),
            );
            // L2 norm of the 5th order delta, used as the reference for relative error
            let norm_d5 = delta5
                .iter()
                .map(|d5| d5.powi(2))
                .collect::<Vec<f64>>()
                .iter()
                .sum::<f64>()
                .sqrt();
            // ensures that rtol is calculated and considered as long as you are not dividing by 0
            sc_mut.state.norm_err_rel = if norm_d5 != 0. {
                // `unwrap` is ok here because `norm_err` will always be some by this point
                Some(sc_mut.state.norm_err.unwrap() / norm_d5)
            } else {
                // avoid dividing by 0
                None
            };

            sc_mut.state.t_curr = t_curr;

            if sc_mut.save_states {
                sc_mut.state.states = states;
            }

            // conditions for breaking loop
            // if there is a relative error, use that
            // otherwise, use the absolute error
            let tol_met = match sc_mut.state.norm_err_rel {
                Some(norm_err_rel) => norm_err_rel <= sc_mut.rtol,
                None => match sc_mut.state.norm_err {
                    Some(norm_err) => norm_err <= sc_mut.atol,
                    None => unreachable!(),
                },
            };

            // Because we need to be able to possibly expand the next time step,
            // regardless of whether break condition is met,
            // adapt dt based on `rtol` if it is Some; use `atol` otherwise
            // this adaptation strategy came directly from Chapra and Canale's section on adapting the time step
            // The approach is to adapt more aggressively to meet rtol when decreasing the time step size
            // than when increasing time step size (exponent 0.2 to grow, flat 0.25 factor to shrink).
            let dt_coeff = match sc_mut.state.norm_err_rel {
                Some(norm_err_rel) => {
                    // ensures that if either rtol or atol are met, then the step succeeds
                    // prioritizes rtol -- if both atol and rtol are met, rtol is used
                    if norm_err_rel <= sc_mut.rtol {
                        (sc_mut.rtol / norm_err_rel).powf(0.2)
                    } else if sc_mut.state.norm_err.unwrap() <= sc_mut.atol {
                        (sc_mut.atol / sc_mut.state.norm_err.unwrap()).powf(0.2)
                    } else {
                        0.25
                    }
                    // (sc_mut.rtol / norm_err_rel).powf(
                    //     if norm_err_rel <= sc_mut.rtol || norm_err <= sc_mut.atol {
                    //         0.2
                    //     } else {
                    //         0.25
                    //     },
                    // ),
                }
                // if rtol doesn't exist just use atol
                None => {
                    match sc_mut.state.norm_err {
                        Some(norm_err) => (sc_mut.atol / norm_err)
                            .powf(if norm_err <= sc_mut.atol { 0.2 } else { 0.25 }),
                        None => 1., // don't adapt if there is not enough information to do so
                    }
                }
            };

            // if tolerance is achieved here, then we proceed to the next time step, and
            // `dt` will be limited to `dt_max` at the start of the next time step. If tolerance
            // is not achieved, then time step will be decreased.
            let break_cond = sc_mut.state.n_iter >= sc_mut.max_iter
                || sc_mut.state.norm_err.unwrap() < sc_mut.atol
                || tol_met;

            if break_cond {
                // save before modifying dt
                if sc_mut.save {
                    sc_mut.history.push(sc_mut.state.clone());
                }
                // store used dt before adapting
                let dt_used = sc_mut.state.dt;
                // adapt for next solver time step
                sc_mut.state.dt *= dt_coeff;
                break (delta5, dt_used);
            };
            // adapt for next iteration in current time step
            sc_mut.state.dt *= dt_coeff;
        };

        // increment forward with 5th order solution
        self.step_states(delta5);
        self.step_time(&dt_used);
        self.update_derivs();
        // dbg!(self.state.time);
        // dbg!(self.t_report[self.state.i]);
        dt_used
    }

    /// Performs one trial Cash-Karp step of size `dt` (without committing it to
    /// `self`) and returns the (4th order, 5th order) state deltas, whose
    /// difference is the embedded error estimate.
    fn rk45_cash_karp_step(&mut self, dt: f64) -> (Vec<f64>, Vec<f64>) {
        self.update_derivs();

        // k1 = f(x_i, y_i)
        let k1s = self.derivs();

        // k2 = f(x_i + 1 / 5 * h, y_i + 1 / 5 * k1 * h)
        let mut sys1 = self.bare_clone();
        sys1.step_states_by_dt(&(dt / 5.));
        sys1.update_derivs();
        let k2s = sys1.derivs();

        // k3 = f(x_i + 3 / 10 * h, y_i + 3 / 40 * k1 * h + 9 / 40 * k2 * h)
        let mut sys2 = self.bare_clone();
        sys2.step_time(&(dt * 3. / 10.));
        sys2.step_states(
            k1s.iter()
                .zip(k2s.clone())
                .map(|(k1, k2)| (3. / 40. * k1 + 9. / 40. * k2) * dt)
                .collect(),
        );
        sys2.update_derivs();
        let k3s = sys2.derivs();

        // k4 = f(x_i + 3 / 5 * h, y_i + 3 / 10 * k1 * h - 9 / 10 * k2 * h + 6 / 5 * k3 * h)
        let mut sys3 = self.bare_clone();
        sys3.step_time(&(dt * 3. / 5.));
        sys3.step_states({
            let (k1s, k2s, k3s) = (k1s.clone(), k2s.clone(), k3s.clone());
            let zipped = zip!(k1s, k2s, k3s);
            let mut steps = vec![];
            for (k1, (k2, k3)) in zipped {
                steps.push((3. / 10. * k1 - 9. / 10. * k2 + 6. / 5. * k3) * dt);
            }
            steps
        });
        sys3.update_derivs();
        let k4s = sys3.derivs();

        // k5 = f(x_i + h, y_i - 11 / 54 * k1 * h + 5 / 2 * k2 * h - 70 / 27 * k3 * h + 35 / 27 * k4 * h)
        let mut sys4 = self.bare_clone();
        sys4.step_time(&dt);
        sys4.step_states({
            let (k1s, k2s, k3s, k4s) = (k1s.clone(), k2s.clone(), k3s.clone(), k4s.clone());
            let zipped = zip!(k1s, k2s, k3s, k4s);
            let mut steps = vec![];
            for (k1, (k2, (k3, k4))) in zipped {
                steps.push((-11. / 54. * k1 + 5. / 2. * k2 - 70. / 27. * k3 + 35. / 27. * k4) * dt);
            }
            steps
        });
        sys4.update_derivs();
        let k5s = sys4.derivs();

        // k6 = f(x_i + 7 / 8 * h, y_i + 1631 / 55296 * k1 * h + 175 / 512 * k2 * h + 575 / 13824 * k3 * h + 44275 / 110592 * k4 * h + 253 / 4096 * k5 * h)
        let mut sys5 = self.bare_clone();
        sys5.step_time(&(dt * 7. / 8.));
        sys5.step_states({
            let (k1s, k2s, k3s, k4s, k5s) = (
                k1s.clone(),
                k2s.clone(),
                k3s.clone(),
                k4s.clone(),
                k5s.clone(),
            );
            let zipped = zip!(k1s, k2s, k3s, k4s, k5s);
            let mut steps = vec![];
            for (k1, (k2, (k3, (k4, k5)))) in zipped {
                steps.push(
                    (1_631. / 55_296. * k1
                        + 175. / 512. * k2
                        + 575. / 13_824. * k3
                        + 44_275. / 110_592. * k4
                        + 253. / 4096. * k5)
                        * dt,
                );
            }
            steps
        });
        sys5.update_derivs();
        let k6s = sys5.derivs();

        // 4th order delta
        let mut delta4: Vec<f64> = vec![];
        // 5th order delta
        let mut delta5: Vec<f64> = vec![];
        // NOTE: k2 is intentionally unused in the final combinations; the
        // Cash-Karp tableau has zero weight for k2 in both solution rows
        let zipped = zip!(k1s, k2s, k3s, k4s, k5s, k6s);
        for (k1, (_k2, (k3, (k4, (k5, k6))))) in zipped {
            delta5.push(
                (37. / 378. * k1 + 250. / 621. * k3 + 125. / 594. * k4 + 512. / 1_771. * k6) * dt,
            );
            delta4.push(
                (2825. / 27_648. * k1
                    + 18_575. / 48_384. * k3
                    + 13_525. / 55_296. * k4
                    + 277. / 14_336. * k5
                    + 1. / 4. * k6)
                    * dt,
            );
        }
        (delta4, delta5)
    }
}
779}