numeric_algs/integration/
dormand_prince.rs1use super::{Integrator, StepSize};
2use crate::traits::{State, StateDerivative};
3use std::mem;
4
5pub struct DPIntegrator<S: State> {
6 default_step: f64,
7 original_default_step: f64,
8 max_err: f64,
9 min_step: f64,
10 max_step: f64,
11 last_derivative: Option<S::Derivative>,
12}
13
14impl<S: State> DPIntegrator<S> {
15 pub fn new(default_step: f64, min_step: f64, max_step: f64, max_err: f64) -> Self {
16 DPIntegrator {
17 default_step: default_step,
18 original_default_step: default_step,
19 min_step: min_step,
20 max_step: max_step,
21 max_err: max_err,
22 last_derivative: None,
23 }
24 }
25
26 pub fn reset_default_step(&mut self) {
27 self.default_step = self.original_default_step;
28 }
29
30 pub fn reset(&mut self) {
31 self.last_derivative = None;
32 self.reset_default_step();
33 }
34}
35
36impl<S: State> Integrator<S> for DPIntegrator<S> {
37 fn propagate_in_place<D>(&mut self, start: &mut S, diff_eq: D, step_size: StepSize)
38 where
39 D: Fn(&S) -> S::Derivative,
40 {
41 let h = match step_size {
42 StepSize::UseDefault => self.default_step,
43 StepSize::Step(x) => x,
44 };
45
46 let k1 = if let Some(last_derivative) = mem::replace(&mut self.last_derivative, None) {
47 last_derivative
48 } else {
49 diff_eq(start)
50 };
51 let k2 = diff_eq(&start.shift(&(k1.clone() / 5.0), h));
52 let k3 = diff_eq(&start.shift(&(k1.clone() * 3.0 / 40.0 + k2.clone() * 9.0 / 40.0), h));
53 let k4 = diff_eq(&start.shift(
54 &(k1.clone() * 44.0 / 45.0 - k2.clone() * 56.0 / 15.0 + k3.clone() * 32.0 / 9.0),
55 h,
56 ));
57 let k5 = diff_eq(&start.shift(
58 &(k1.clone() * 19372.0 / 6561.0 - k2.clone() * 25360.0 / 2187.0
59 + k3.clone() * 64448.0 / 6561.0
60 - k4.clone() * 212.0 / 729.0),
61 h,
62 ));
63 let k6 = diff_eq(&start.shift(
64 &(k1.clone() * 9017.0 / 3168.0 - k2.clone() * 355.0 / 33.0
65 + k3.clone() * 46732.0 / 5247.0
66 + k4.clone() * 49.0 / 176.0
67 - k5.clone() * 5103.0 / 18656.0),
68 h,
69 ));
70
71 let new_state = start.shift(
72 &(k1.clone() * 35.0 / 384.0 + k3.clone() * 500.0 / 1113.0 + k4.clone() * 125.0 / 192.0
73 - k5.clone() * 2187.0 / 6784.0
74 + k6.clone() * 11.0 / 84.0),
75 h,
76 );
77
78 let k7 = diff_eq(&new_state);
79
80 let error = ((k1 * 71.0 / 57600.0 - k3 * 71.0 / 16695.0 + k4 * 71.0 / 1920.0
81 - k5 * 17253.0 / 339200.0
82 + k6 * 22.0 / 525.0
83 - k7.clone() / 40.0)
84 * h)
85 .abs();
86
87 if error != 0.0 {
88 self.default_step = h * (self.max_err / error).powf(0.25);
89 } else {
90 self.default_step = self.max_step;
91 }
92
93 if self.default_step < self.min_step {
94 self.default_step = self.min_step;
95 }
96 if self.default_step > self.max_step {
97 self.default_step = self.max_step;
98 }
99
100 if self.default_step < 0.8 * h && step_size == StepSize::UseDefault {
103 self.propagate_in_place(start, diff_eq, step_size);
104 return;
105 }
106
107 *start = new_state;
108
109 self.last_derivative = Some(k7);
111 }
112}