differential_equations/methods/erk/dormandprince/
ordinary.rs1use super::{ExplicitRungeKutta, Ordinary, DormandPrince};
4use crate::{
5 Error, Status,
6 alias::Evals,
7 methods::h_init::InitialStepSize,
8 interpolate::Interpolation,
9 ode::{OrdinaryNumericalMethod, ODE},
10 traits::{CallBackData, Real, State},
11 utils::{constrain_step_size, validate_step_size_parameters},
12};
13
14impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> OrdinaryNumericalMethod<T, V, D> for ExplicitRungeKutta<Ordinary, DormandPrince, T, V, D, O, S, I> {
15 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
16 where
17 F: ODE<T, V, D>,
18 {
19 let mut evals = Evals::new();
20
21 if self.h0 == T::zero() {
23 self.h0 = InitialStepSize::<Ordinary>::compute(ode, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, &mut evals);
25 }
26
27 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
29 Ok(h0) => self.h = h0,
30 Err(status) => return Err(status),
31 }
32
33 self.stiffness_counter = 0;
35
36 self.t = t0;
38 self.y = *y0;
39 ode.diff(self.t, &self.y, &mut self.k[0]);
40 self.dydt = self.k[0];
41 evals.fcn += 1;
42
43 self.t_prev = self.t;
45 self.y_prev = self.y;
46 self.dydt_prev = self.dydt;
47
48 self.status = Status::Initialized;
50
51 Ok(evals)
52 }
53
54 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
55 where
56 F: ODE<T, V, D>,
57 {
58 let mut evals = Evals::new();
59
60 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
62 self.status = Status::Error(Error::StepSize {
63 t: self.t, y: self.y
64 });
65 return Err(Error::StepSize {
66 t: self.t, y: self.y
67 });
68 }
69
70 if self.steps >= self.max_steps {
72 self.status = Status::Error(Error::MaxSteps {
73 t: self.t, y: self.y
74 });
75 return Err(Error::MaxSteps {
76 t: self.t, y: self.y
77 });
78 }
79 self.steps += 1;
80
81 let mut y_stage = V::zeros();
83 for i in 1..self.stages {
84 y_stage = V::zeros();
85
86 for j in 0..i {
87 y_stage += self.k[j] * self.a[i][j];
88 }
89 y_stage = self.y + y_stage * self.h;
90
91 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
92 }
93
94 let ysti = y_stage;
96
97 let mut yseg = V::zeros();
99 for i in 0..self.stages {
100 yseg += self.k[i] * self.b[i];
101 }
102
103 let y_new = self.y + yseg * self.h;
105
106 let t_new = self.t + self.h;
108
109 evals.fcn += self.stages - 1; let er = self.er.unwrap();
114 let n = self.y.len();
115 let mut err = T::zero();
116 let mut err2 = T::zero();
117 let mut erri;
118 for i in 0..n {
119 let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
121
122 erri = T::zero();
124 for j in 0..self.stages {
125 erri += er[j] * self.k[j].get(i);
126 }
127 err += (erri / sk).powi(2);
128
129 if let Some(bh) = &self.bh {
131 erri = yseg.get(i);
132 for j in 0..self.stages {
133 erri -= bh[j] * self.k[j].get(i);
134 }
135 err2 += (erri / sk).powi(2);
136 }
137 }
138 let mut deno = err + T::from_f64(0.01).unwrap() * err2;
139 if deno <= T::zero() {
140 deno = T::one();
141 }
142 err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
143
144 let order = T::from_usize(self.order).unwrap();
146 let error_exponent = T::one() / order;
147 let mut scale = self.safety_factor * err.powf(-error_exponent);
148
149 scale = scale.max(self.min_scale).min(self.max_scale);
151
152 if err <= T::one() {
154 ode.diff(t_new, &y_new, &mut self.dydt);
156 evals.fcn += 1;
157
158 let n_stiff_threshold = 100;
160 if self.steps % n_stiff_threshold == 0 {
161 let mut stdnum = T::zero();
162 let mut stden = T::zero();
163 let sqr = yseg - self.k[S-1];
164 for i in 0..sqr.len() {
165 stdnum += sqr.get(i).powi(2);
166 }
167 let sqr = self.dydt - ysti;
168 for i in 0..sqr.len() {
169 stden += sqr.get(i).powi(2);
170 }
171
172 if stden > T::zero() {
173 let h_lamb = self.h * (stdnum / stden).sqrt();
174 if h_lamb > T::from_f64(6.1).unwrap() {
175 self.non_stiffness_counter = 0;
176 self.stiffness_counter += 1;
177 if self.stiffness_counter == 15 {
178 self.status = Status::Error(Error::Stiffness {
180 t: self.t,
181 y: self.y,
182 });
183 return Err(Error::Stiffness {
184 t: self.t,
185 y: self.y,
186 });
187 }
188 }
189 } else {
190 self.non_stiffness_counter += 1;
191 if self.non_stiffness_counter == 6 {
192 self.stiffness_counter = 0;
193 }
194 }
195 }
196
197 self.cont[0] = self.y;
199 let ydiff = y_new - self.y;
200 self.cont[1] = ydiff;
201 let bspl = self.k[0] * self.h - ydiff;
202 self.cont[2] = bspl;
203 self.cont[3] = ydiff - self.dydt * self.h - bspl;
204
205 if let Some(bi) = &self.bi {
207 if I > S {
209 self.k[self.stages] = self.dydt;
211
212 for i in S+1..I {
213 let mut y_stage = V::zeros();
214 for j in 0..i {
215 y_stage += self.k[j] * self.a[i][j];
216 }
217 y_stage = self.y + y_stage * self.h;
218
219 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
220 evals.fcn += 1;
221 }
222 }
223
224 for i in 4..self.order {
226 self.cont[i] = V::zeros();
227 for j in 0..self.dense_stages {
228 self.cont[i] += self.k[j] * bi[i][j];
229 }
230 self.cont[i] = self.cont[i] * self.h;
231 }
232 }
233
234 self.t_prev = self.t;
236 self.y_prev = self.y;
237 self.dydt_prev = self.k[0];
238 self.h_prev = self.h;
239
240 self.t = t_new;
242 self.y = y_new;
243 self.k[0] = self.dydt;
244
245 if let Status::RejectedStep = self.status {
247 self.status = Status::Solving;
248
249 scale = scale.min(T::one());
251 }
252 } else {
253 self.status = Status::RejectedStep;
255 }
256
257 self.h *= scale;
259
260 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
262
263 Ok(evals)
264 }
265
266 fn t(&self) -> T { self.t }
267 fn y(&self) -> &V { &self.y }
268 fn t_prev(&self) -> T { self.t_prev }
269 fn y_prev(&self) -> &V { &self.y_prev }
270 fn h(&self) -> T { self.h }
271 fn set_h(&mut self, h: T) { self.h = h; }
272 fn status(&self) -> &Status<T, V, D> { &self.status }
273 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
274}
275
276impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Ordinary, DormandPrince, T, V, D, O, S, I> {
277 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
278 if t_interp < self.t_prev || t_interp > self.t {
280 return Err(Error::OutOfBounds {
281 t_interp,
282 t_prev: self.t_prev,
283 t_curr: self.t,
284 });
285 }
286
287 let s = (t_interp - self.t_prev) / self.h_prev;
289 let s1 = T::one() - s;
290
291 let ilast = self.cont.len() - 1;
293 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
294 let factor = if i >= 4 {
295 if (ilast - i) % 2 == 1 { s1 } else { s }
297 } else {
298 if i % 2 == 1 { s1 } else { s }
300 };
301 acc * factor + self.cont[i]
302 });
303
304 let y_interp = self.cont[0] + poly * s;
306
307 Ok(y_interp)
308 }
309}