1use crate::{
4 Error, Status,
5 alias::Evals,
6 interpolate::Interpolation,
7 ode::{ODENumericalMethod, ODE, methods::h_init},
8 traits::{CallBackData, Real, State},
9 utils::{constrain_step_size, validate_step_size_parameters},
10};
11
12pub struct DOPRI5<T: Real, V: State<T>, D: CallBackData> {
67 pub h0: T, t: T,
72 y: V,
73 h: T,
74
75 pub rtol: T,
77 pub atol: T,
78
79 pub h_max: T,
81 pub h_min: T,
82 pub max_steps: usize,
83 pub n_stiff: usize,
84
85 pub safe: T,
87 pub fac1: T,
88 pub fac2: T,
89 pub beta: T,
90
91 expo1: T,
93 facc1: T,
94 facc2: T,
95 facold: T,
96 fac11: T,
97 fac: T,
98
99 status: Status<T, V, D>,
101 steps: usize, n_accepted: usize, h_lamb: T,
106 non_stiff_counter: usize,
107 stiffness_counter: usize,
108
109 a: [[T; 7]; 7],
111 b: [T; 7],
112 c: [T; 7],
113 er: [T; 7],
114
115 d: [T; 7],
117
118 k: [V; 7], y_old: V, t_old: T, h_old: T, cont: [V; 5], }
127
128impl<T: Real, V: State<T>, D: CallBackData> ODENumericalMethod<T, V, D> for DOPRI5<T, V, D> {
129 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
130 where
131 F: ODE<T, V, D>,
132 {
133 let mut evals = Evals::new();
134
135 self.t = t0;
137 self.y = *y0;
138
139 ode.diff(t0, y0, &mut self.k[0]);
141 evals.fcn += 1; self.t_old = self.t;
145 self.y_old = self.y;
146
147 if self.h0 == T::zero() {
149 self.h0 = h_init(
150 ode, t0, tf, y0, 5, self.rtol, self.atol, self.h_min, self.h_max,
151 );
152 evals.fcn += 1; let posneg = (tf - t0).signum();
156 if self.h0.abs() < self.h_min.abs() {
157 self.h0 = self.h_min.abs() * posneg;
158 } else if self.h0.abs() > self.h_max.abs() {
159 self.h0 = self.h_max.abs() * posneg;
160 }
161 }
162
163 match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
165 Ok(h0) => self.h = h0,
166 Err(status) => return Err(status),
167 }
168
169 self.h_lamb = T::zero();
171 self.non_stiff_counter = 0;
172 self.stiffness_counter = 0;
173
174 self.status = Status::Initialized;
176
177 Ok(evals)
178 }
179
180 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
181 where
182 F: ODE<T, V, D>,
183 {
184 let mut evals = Evals::new();
185
186 if self.steps >= self.max_steps {
188 self.status = Status::Error(Error::MaxSteps {
189 t: self.t,
190 y: self.y,
191 });
192 return Err(Error::MaxSteps {
193 t: self.t,
194 y: self.y,
195 });
196 }
197
198 if self.h.abs() < T::default_epsilon() {
200 self.status = Status::Error(Error::StepSize {
201 t: self.t,
202 y: self.y,
203 });
204 return Err(Error::StepSize {
205 t: self.t,
206 y: self.y,
207 });
208 }
209
210 ode.diff(
212 self.t + self.c[1] * self.h,
213 &(self.y + self.k[0] * (self.a[1][0] * self.h)),
214 &mut self.k[1],
215 );
216 ode.diff(
217 self.t + self.c[2] * self.h,
218 &(self.y + self.k[0] * (self.a[2][0] * self.h) + self.k[1] * (self.a[2][1] * self.h)),
219 &mut self.k[2],
220 );
221 ode.diff(
222 self.t + self.c[3] * self.h,
223 &(self.y
224 + self.k[0] * (self.a[3][0] * self.h)
225 + self.k[1] * (self.a[3][1] * self.h)
226 + self.k[2] * (self.a[3][2] * self.h)),
227 &mut self.k[3],
228 );
229 ode.diff(
230 self.t + self.c[4] * self.h,
231 &(self.y
232 + self.k[0] * (self.a[4][0] * self.h)
233 + self.k[1] * (self.a[4][1] * self.h)
234 + self.k[2] * (self.a[4][2] * self.h)
235 + self.k[3] * (self.a[4][3] * self.h)),
236 &mut self.k[4],
237 );
238 ode.diff(
239 self.t + self.c[5] * self.h,
240 &(self.y
241 + self.k[0] * (self.a[5][0] * self.h)
242 + self.k[1] * (self.a[5][1] * self.h)
243 + self.k[2] * (self.a[5][2] * self.h)
244 + self.k[3] * (self.a[5][3] * self.h)
245 + self.k[4] * (self.a[5][4] * self.h)),
246 &mut self.k[5],
247 );
248
249 let ysti = self.y
250 + self.k[0] * (self.a[6][0] * self.h)
251 + self.k[2] * (self.a[6][2] * self.h)
252 + self.k[3] * (self.a[6][3] * self.h)
253 + self.k[4] * (self.a[6][4] * self.h)
254 + self.k[5] * (self.a[6][5] * self.h);
255
256 let t_new = self.t + self.h;
257 ode.diff(t_new, &ysti, &mut self.k[6]);
258
259 let y_new = self.y
260 + self.k[0] * (self.b[0] * self.h)
261 + self.k[2] * (self.b[2] * self.h)
262 + self.k[3] * (self.b[3] * self.h)
263 + self.k[4] * (self.b[4] * self.h)
264 + self.k[5] * (self.b[5] * self.h)
265 + self.k[6] * (self.b[6] * self.h);
266
267 ode.diff(t_new, &y_new, &mut self.k[1]);
268
269 evals.fcn += 7; let mut err = T::zero();
273
274 let n = self.y.len();
275 for i in 0..n {
276 let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
277 let erri = self.h
278 * (self.er[0] * self.k[0].get(i)
279 + self.er[2] * self.k[2].get(i)
280 + self.er[3] * self.k[3].get(i)
281 + self.er[4] * self.k[4].get(i)
282 + self.er[5] * self.k[5].get(i)
283 + self.er[6] * self.k[6].get(i));
284 err += (erri / sk).powi(2);
285 }
286 err = (err / T::from_usize(n).unwrap()).sqrt();
287
288 self.fac11 = err.powf(self.expo1);
290 self.fac = self.fac11 / self.facold.powf(self.beta);
292 self.fac = self.facc2.max(self.facc1.min(self.fac / self.safe));
294 let mut h_new = self.h / self.fac;
295
296 if err <= T::one() {
297 self.facold = err.max(T::from_f64(1.0e-4).unwrap());
299 self.n_accepted += 1;
300
301 if self.n_accepted % self.n_stiff == 0 || self.stiffness_counter > 0 {
303 let mut stnum = T::zero();
304 let mut stden = T::zero();
305
306 for i in 0..n {
307 let stnum_i = self.k[1].get(i) - self.k[6].get(i);
308 stnum += stnum_i * stnum_i;
309
310 let stden_i = y_new.get(i) - ysti.get(i);
311 stden += stden_i * stden_i;
312 }
313
314 if stden > T::zero() {
315 self.h_lamb = self.h * (stnum / stden).sqrt();
316 }
317
318 if self.h_lamb > T::from_f64(3.25).unwrap() {
319 self.non_stiff_counter = 0;
320 self.stiffness_counter += 1;
321 if self.stiffness_counter == 15 {
322 self.status = Status::Error(Error::Stiffness {
324 t: self.t,
325 y: self.y,
326 });
327 return Err(Error::Stiffness {
328 t: self.t,
329 y: self.y,
330 });
331 }
332 } else {
333 self.non_stiff_counter += 1;
334 if self.non_stiff_counter == 6 {
335 self.stiffness_counter = 0;
336 }
337 }
338 }
339
340 let ydiff = y_new - self.y;
343 let bspl = self.k[0] * self.h - ydiff;
344
345 self.cont[0] = self.y;
346 self.cont[1] = ydiff;
347 self.cont[2] = bspl;
348 self.cont[3] = ydiff - self.k[1] * self.h - bspl;
349
350 self.cont[4] = (self.k[0] * self.d[0]
352 + self.k[2] * self.d[2]
353 + self.k[3] * self.d[3]
354 + self.k[4] * self.d[4]
355 + self.k[5] * self.d[5]
356 + self.k[6] * self.d[6])
357 * self.h;
358
359 self.y_old = self.y;
361 self.t_old = self.t;
362 self.h_old = self.h;
363
364 self.k[0] = self.k[1];
366 self.y = y_new;
367 self.t = t_new;
368
369 if let Status::RejectedStep = self.status {
371 h_new = self.h.min(h_new);
372 self.status = Status::Solving;
373 }
374 } else {
375 h_new = self.h / self.facc1.min(self.fac11 / self.safe);
377 self.status = Status::RejectedStep;
378 }
379
380 self.h = constrain_step_size(h_new, self.h_min, self.h_max);
382 Ok(evals)
383 }
384
385 fn t(&self) -> T {
386 self.t
387 }
388
389 fn y(&self) -> &V {
390 &self.y
391 }
392
393 fn t_prev(&self) -> T {
394 self.t_old
395 }
396
397 fn y_prev(&self) -> &V {
398 &self.y_old
399 }
400
401 fn h(&self) -> T {
402 self.h
403 }
404
405 fn set_h(&mut self, h: T) {
406 self.h = h;
407 }
408
409 fn status(&self) -> &Status<T, V, D> {
410 &self.status
411 }
412
413 fn set_status(&mut self, status: Status<T, V, D>) {
414 self.status = status;
415 }
416}
417
418impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for DOPRI5<T, V, D> {
419 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
420 if t_interp < self.t_old || t_interp > self.t {
422 return Err(Error::OutOfBounds {
423 t_interp,
424 t_prev: self.t_old,
425 t_curr: self.t,
426 });
427 }
428
429 let s = (t_interp - self.t_old) / self.h_old;
431 let s1 = T::one() - s;
432
433 let y_interp = self.cont[0]
435 + (self.cont[1] + (self.cont[2] + (self.cont[3] + self.cont[4] * s1) * s) * s1) * s;
436
437 Ok(y_interp)
438 }
439}
440
441impl<T: Real, V: State<T>, D: CallBackData> DOPRI5<T, V, D> {
442 pub fn new() -> Self {
448 DOPRI5 {
449 ..Default::default()
450 }
451 }
452
453 pub fn rtol(mut self, rtol: T) -> Self {
455 self.rtol = rtol;
456 self
457 }
458
459 pub fn atol(mut self, atol: T) -> Self {
460 self.atol = atol;
461 self
462 }
463
464 pub fn h0(mut self, h0: T) -> Self {
465 self.h0 = h0;
466 self
467 }
468
469 pub fn h_max(mut self, h_max: T) -> Self {
470 self.h_max = h_max;
471 self
472 }
473
474 pub fn h_min(mut self, h_min: T) -> Self {
475 self.h_min = h_min;
476 self
477 }
478
479 pub fn max_steps(mut self, max_steps: usize) -> Self {
480 self.max_steps = max_steps;
481 self
482 }
483
484 pub fn n_stiff(mut self, n_stiff: usize) -> Self {
485 self.n_stiff = n_stiff;
486 self
487 }
488
489 pub fn safe(mut self, safe: T) -> Self {
490 self.safe = safe;
491 self
492 }
493
494 pub fn beta(mut self, beta: T) -> Self {
495 self.beta = beta;
496 self
497 }
498
499 pub fn fac1(mut self, fac1: T) -> Self {
500 self.fac1 = fac1;
501 self
502 }
503
504 pub fn fac2(mut self, fac2: T) -> Self {
505 self.fac2 = fac2;
506 self
507 }
508
509 pub fn expo1(mut self, expo1: T) -> Self {
510 self.expo1 = expo1;
511 self
512 }
513
514 pub fn facc1(mut self, facc1: T) -> Self {
515 self.facc1 = facc1;
516 self
517 }
518
519 pub fn facc2(mut self, facc2: T) -> Self {
520 self.facc2 = facc2;
521 self
522 }
523}
524
525impl<T: Real, V: State<T>, D: CallBackData> Default for DOPRI5<T, V, D> {
526 fn default() -> Self {
527 let a = DOPRI5_A.map(|row| row.map(|x| T::from_f64(x).unwrap()));
529 let b = DOPRI5_B.map(|x| T::from_f64(x).unwrap());
530 let c = DOPRI5_C.map(|x| T::from_f64(x).unwrap());
531 let er = DOPRI5_E.map(|x| T::from_f64(x).unwrap());
532 let d = DOPRI5_D.map(|x| T::from_f64(x).unwrap());
533
534 let k_zeros = [V::zeros(); 7];
536 let cont_zeros = [V::zeros(); 5];
537
538 DOPRI5 {
539 t: T::zero(),
541 y: V::zeros(),
542 h: T::zero(),
543
544 h0: T::zero(),
546 rtol: T::from_f64(1e-3).unwrap(),
547 atol: T::from_f64(1e-6).unwrap(),
548 h_max: T::infinity(),
549 h_min: T::zero(),
550 max_steps: 100_000,
551 n_stiff: 1000,
552 safe: T::from_f64(0.9).unwrap(),
553 fac1: T::from_f64(0.2).unwrap(),
554 fac2: T::from_f64(10.0).unwrap(),
555 beta: T::from_f64(0.04).unwrap(),
556 expo1: T::from_f64(1.0 / 5.0).unwrap(),
557 facc1: T::from_f64(1.0 / 0.2).unwrap(),
558 facc2: T::from_f64(1.0 / 10.0).unwrap(),
559 facold: T::from_f64(1.0e-4).unwrap(),
560 fac11: T::zero(),
561 fac: T::zero(),
562
563 a,
565 b,
566 c,
567 er,
568 d,
569
570 status: Status::Uninitialized,
572 h_lamb: T::zero(),
573 non_stiff_counter: 0,
574 stiffness_counter: 0,
575 steps: 0,
576 n_accepted: 0,
577
578 k: k_zeros,
580 y_old: V::zeros(),
581 t_old: T::zero(),
582 h_old: T::zero(),
583 cont: cont_zeros,
584 }
585 }
586}
587
/// Dormand–Prince Runge–Kutta matrix: row `i` holds the weights used to build
/// the input of stage `i + 1` from the earlier stage derivatives.
///
/// The last row is identical to `DOPRI5_B`. Stage 7 is therefore evaluated at
/// the solution the step propagates, which lets its derivative be reused as
/// the first stage of the next step.
const DOPRI5_A: [[f64; 7]; 7] = [
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0, 0.0, 0.0],
    [
        19372.0 / 6561.0,
        -25360.0 / 2187.0,
        64448.0 / 6561.0,
        -212.0 / 729.0,
        0.0,
        0.0,
        0.0,
    ],
    [
        9017.0 / 3168.0,
        -355.0 / 33.0,
        46732.0 / 5247.0,
        49.0 / 176.0,
        -5103.0 / 18656.0,
        0.0,
        0.0,
    ],
    [
        35.0 / 384.0,
        0.0,
        500.0 / 1113.0,
        125.0 / 192.0,
        -2187.0 / 6784.0,
        11.0 / 84.0,
        0.0,
    ],
];

/// Stage nodes: stage `i` is evaluated at `t + c[i] * h`. The last two nodes
/// are both 1, so stages 6 and 7 share the time `t + h`.
const DOPRI5_C: [f64; 7] = [
    0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0, ];

/// Weights of the propagated solution `y + h * sum(b[i] * k[i])`. The entries
/// for stages 2 and 7 are zero.
const DOPRI5_B: [f64; 7] = [
    35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0, ];

/// Error-estimate weights: the local error of component `i` is taken as
/// `h * sum(er[j] * k[j][i])`.
const DOPRI5_E: [f64; 7] = [
    71.0 / 57600.0, 0.0, -71.0 / 16695.0, 71.0 / 1920.0, -17253.0 / 339200.0, 22.0 / 525.0, -1.0 / 40.0, ];

/// Dense-output weights, combined with the stage derivatives (times `h`) into
/// the highest-order interpolation coefficient `cont[4]`.
const DOPRI5_D: [f64; 7] = [
    -12715105075.0 / 11282082432.0, 0.0, 87487479700.0 / 32700410799.0, -10690763975.0 / 1880347072.0, 701980252875.0 / 199316789632.0, -1453857185.0 / 822651844.0, 69997945.0 / 29380423.0, ];