1use crate::{
4 Error, Status,
5 alias::Evals,
6 dde::{DDE, DDENumericalMethod, methods::h_init::h_init},
7 interpolate::Interpolation,
8 traits::{CallBackData, Real, State},
9 utils::{constrain_step_size, validate_step_size_parameters},
10};
11use std::collections::VecDeque;
12
13pub struct DOPRI5<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> {
69 pub h0: T,
70 t: T,
71 y: V,
72 h: T,
73 pub rtol: T,
74 pub atol: T,
75 pub h_max: T,
76 pub h_min: T,
77 pub max_steps: usize,
78 pub safe: T,
79 pub fac1: T,
80 pub fac2: T,
81 pub beta: T,
82 pub max_delay: Option<T>,
83 expo1: T,
84 facc1: T,
85 facc2: T,
86 facold: T,
87 fac11: T,
88 fac: T,
89 status: Status<T, V, D>,
90 steps: usize,
91 n_accepted: usize,
92 a: [[T; 7]; 7], b: [T; 7], c: [T; 7], er: [T; 7], d: [T; 7], k: [V; 7], y_old: V,
99 t_old: T,
100 h_old: T,
101 cont: [V; 5], cont_buffer: VecDeque<(T, T, T, [V; 5])>, phi: Option<H>,
104 t0: T,
105 tf: T,
106 posneg: T,
107 lags: [T; L],
108 yd: [V; L],
109}
110
111impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData>
112 DDENumericalMethod<L, T, V, H, D> for DOPRI5<L, T, V, H, D>
113{
114 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: H) -> Result<Evals, Error<T, V>>
115 where
116 F: DDE<L, T, V, D>,
117 {
118 let mut evals = Evals::new();
119
120 self.t = t0;
121 self.y = *y0;
122 self.t0 = t0;
123 self.tf = tf;
124 self.posneg = (tf - t0).signum();
125 self.phi = Some(phi);
126
127 if L > 0 {
128 dde.lags(self.t, &self.y, &mut self.lags);
129 for i in 0..L {
130 if self.lags[i] <= T::zero() {
131 return Err(Error::BadInput {
132 msg: "All lags must be positive.".to_string(),
133 });
134 }
135 let t_delayed = self.t - self.lags[i];
136 if (t_delayed - self.t0) * self.posneg > T::default_epsilon() {
137 return Err(Error::BadInput {
138 msg: format!(
139 "Initial delayed time {} is out of history range (t <= {}).",
140 t_delayed, self.t0
141 ),
142 });
143 }
144 self.yd[i] = (self.phi.as_ref().unwrap())(t_delayed);
145 }
146 }
147 dde.diff(self.t, &self.y, &self.yd, &mut self.k[0]); evals.fcn += 1;
149
150 if self.h0 == T::zero() {
151 let h_est = h_init(
152 dde,
153 self.t,
154 self.tf,
155 &self.y,
156 5,
157 self.rtol,
158 self.atol,
159 self.h_min,
160 self.h_max,
161 self.phi.as_ref().unwrap(),
162 &self.k[0],
163 &mut evals,
164 );
165 self.h0 = h_est;
166 }
167
168 match validate_step_size_parameters::<T, V, D>(
169 self.h0, self.h_min, self.h_max, self.t, self.tf,
170 ) {
171 Ok(h0_validated) => self.h = h0_validated,
172 Err(status) => return Err(status),
173 }
174
175 self.t_old = self.t;
176 self.y_old = self.y;
177 self.h_old = self.h;
178
179 self.steps = 0;
180 self.n_accepted = 0;
181 self.status = Status::Initialized;
182 Ok(evals)
183 }
184
185 fn step<F>(&mut self, dde: &F) -> Result<Evals, Error<T, V>>
186 where
187 F: DDE<L, T, V, D>,
188 {
189 let mut evals = Evals::new();
190
191 if self.steps >= self.max_steps {
192 self.status = Status::Error(Error::MaxSteps {
193 t: self.t,
194 y: self.y,
195 });
196 return Err(Error::MaxSteps {
197 t: self.t,
198 y: self.y,
199 });
200 }
201
202 let t_current_step_start = self.t;
203 let y_current_step_start = self.y;
204 let k0_at_step_start = self.k[0]; let mut min_lag_abs = T::infinity();
207 if L > 0 {
208 let temp_y_for_lags = y_current_step_start + k0_at_step_start * self.h; dde.lags(
211 t_current_step_start + self.h,
212 &temp_y_for_lags,
213 &mut self.lags,
214 );
215 for i in 0..L {
216 min_lag_abs = min_lag_abs.min(self.lags[i].abs());
217 }
218 }
219
220 let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
221 5
222 } else {
223 1
224 };
225
226 let mut y_new_from_iter = y_current_step_start; let mut k_fnew_iter = V::zeros(); let mut k_stages_iter = [V::zeros(); 7]; let mut y_for_errit_prev_iter = y_current_step_start;
231 let mut iteration_failed_to_converge = false;
232
233 for iter_idx in 0..max_iter {
234 if iter_idx > 0 {
235 y_for_errit_prev_iter = y_new_from_iter; }
237
238 for j in 1..=5 {
241 let mut yi_stage_sum = k0_at_step_start * self.a[j][0];
243 for l_idx in 1..j {
244 yi_stage_sum += k_stages_iter[l_idx] * self.a[j][l_idx];
246 }
247 let yi = y_current_step_start + yi_stage_sum * self.h;
248 let ti = t_current_step_start + self.c[j] * self.h;
249
250 if L > 0 {
251 dde.lags(ti, &yi, &mut self.lags);
252 self.lagvals(ti, &yi);
253 }
254 dde.diff(ti, &yi, &self.yd, &mut k_stages_iter[j]);
255 }
256 evals.fcn += 5;
257
258 let mut ysti_sum = k0_at_step_start * self.a[6][0]; for l_idx in 2..=5 {
261 ysti_sum += k_stages_iter[l_idx] * self.a[6][l_idx];
263 }
264 let ysti = y_current_step_start + ysti_sum * self.h;
265 let t_sti = t_current_step_start + self.c[6] * self.h; if L > 0 {
268 dde.lags(t_sti, &ysti, &mut self.lags);
269 self.lagvals(t_sti, &ysti);
270 }
271 dde.diff(t_sti, &ysti, &self.yd, &mut k_stages_iter[6]);
272 evals.fcn += 1;
273
274 let mut sum_for_y_new = k0_at_step_start * self.b[0]; for l_idx in 2..=6 {
277 sum_for_y_new += k_stages_iter[l_idx] * self.b[l_idx];
279 }
280 y_new_from_iter = y_current_step_start + sum_for_y_new * self.h;
281
282 let t_new_val_for_k_fnew = t_current_step_start + self.h; if L > 0 {
285 dde.lags(t_new_val_for_k_fnew, &y_new_from_iter, &mut self.lags);
286 self.lagvals(t_new_val_for_k_fnew, &y_new_from_iter);
287 }
288 dde.diff(
289 t_new_val_for_k_fnew,
290 &y_new_from_iter,
291 &self.yd,
292 &mut k_fnew_iter,
293 );
294 evals.fcn += 1;
295
296 if max_iter > 1 && iter_idx > 0 {
297 let mut errit_val = T::zero();
298 let n_dim = y_current_step_start.len();
299 for i_dim in 0..n_dim {
300 let scale = self.atol
301 + self.rtol
302 * y_for_errit_prev_iter
303 .get(i_dim)
304 .abs()
305 .max(y_new_from_iter.get(i_dim).abs());
306 if scale > T::zero() {
307 let diff_val =
308 y_new_from_iter.get(i_dim) - y_for_errit_prev_iter.get(i_dim);
309 errit_val += (diff_val / scale).powi(2);
310 }
311 }
312 if n_dim > 0 {
313 errit_val = (errit_val / T::from_usize(n_dim).unwrap()).sqrt();
314 }
315
316 if errit_val <= self.rtol * T::from_f64(0.1).unwrap() {
317 break;
318 }
319 if iter_idx == max_iter - 1 {
320 iteration_failed_to_converge =
322 errit_val > self.rtol * T::from_f64(0.1).unwrap();
323 }
324 }
325 }
326
327 if iteration_failed_to_converge {
328 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * self.posneg;
329 if L > 0
330 && min_lag_abs > T::zero()
331 && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs
332 {
333 self.h = min_lag_abs * self.posneg; }
335 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
336 self.status = Status::RejectedStep;
337 self.k[0] = k0_at_step_start;
339 return Ok(evals);
340 }
341
342 let mut err_final = T::zero();
343 let n = y_current_step_start.len();
344 for i in 0..n {
345 let sk = self.atol
346 + self.rtol
347 * y_current_step_start
348 .get(i)
349 .abs()
350 .max(y_new_from_iter.get(i).abs());
351 let mut err_comp_sum = k0_at_step_start.get(i) * self.er[0];
354 for j in 2..=6 {
355 err_comp_sum += k_stages_iter[j].get(i) * self.er[j];
357 }
358 let erri = self.h * err_comp_sum;
359 if sk > T::zero() {
360 err_final += (erri / sk).powi(2);
361 }
362 }
363 if n > 0 {
364 err_final = (err_final / T::from_usize(n).unwrap()).sqrt();
365 }
366
367 self.fac11 = err_final.powf(self.expo1);
368 let fac_beta = if self.beta > T::zero() && self.facold > T::zero() {
369 self.facold.powf(self.beta)
370 } else {
371 T::one()
372 };
373 self.fac = self.fac11 / fac_beta;
374 self.fac = self.facc2.max(self.facc1.min(self.fac / self.safe));
375 let mut h_new_final = self.h / self.fac;
376
377 let t_new_val = t_current_step_start + self.h;
378
379 if err_final <= T::one() {
380 self.facold = err_final.max(T::from_f64(1.0e-4).unwrap());
381 self.n_accepted += 1;
382
383 let ydiff = y_new_from_iter - y_current_step_start;
385 let bspl = k0_at_step_start * self.h - ydiff;
386
387 self.cont[0] = y_current_step_start;
388 self.cont[1] = ydiff;
389 self.cont[2] = bspl;
390 self.cont[3] = ydiff - k_fnew_iter * self.h - bspl;
391
392 let mut d_sum = k0_at_step_start * self.d[0];
394 for j in 2..=6 {
395 d_sum += k_stages_iter[j] * self.d[j];
397 }
398 self.cont[4] = d_sum * self.h;
399
400 self.cont_buffer
401 .push_back((t_current_step_start, t_new_val, self.h, self.cont));
402
403 if let Some(max_delay_val) = self.max_delay {
404 let prune_time = if self.posneg > T::zero() {
405 t_new_val - max_delay_val
406 } else {
407 t_new_val + max_delay_val
408 };
409 while let Some((buf_t_start, buf_t_end, _, _)) = self.cont_buffer.front() {
410 if (self.posneg > T::zero() && *buf_t_end < prune_time)
411 || (self.posneg < T::zero() && *buf_t_start > prune_time)
412 {
413 self.cont_buffer.pop_front();
414 } else {
415 break;
416 }
417 }
418 }
419
420 self.y_old = y_current_step_start;
421 self.t_old = t_current_step_start;
422 self.h_old = self.h;
423
424 self.k[0] = k_fnew_iter; self.y = y_new_from_iter;
426 self.t = t_new_val;
427
428 if let Status::RejectedStep = self.status {
429 h_new_final = self.h_old.min(h_new_final);
430 self.status = Status::Solving;
431 }
432 } else {
433 h_new_final = self.h / self.facc1.min(self.fac11 / self.safe);
434 self.status = Status::RejectedStep;
435 }
436
437 self.steps += 1;
438 self.h = constrain_step_size(h_new_final, self.h_min, self.h_max);
439 Ok(evals)
440 }
441
442 fn t(&self) -> T {
443 self.t
444 }
445 fn y(&self) -> &V {
446 &self.y
447 }
448 fn t_prev(&self) -> T {
449 self.t_old
450 }
451 fn y_prev(&self) -> &V {
452 &self.y_old
453 }
454 fn h(&self) -> T {
455 self.h
456 }
457 fn set_h(&mut self, h: T) {
458 self.h = h;
459 }
460 fn status(&self) -> &Status<T, V, D> {
461 &self.status
462 }
463 fn set_status(&mut self, status: Status<T, V, D>) {
464 self.status = status;
465 }
466}
467
468impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> Interpolation<T, V>
469 for DOPRI5<L, T, V, H, D>
470{
471 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
472 if ((t_interp - self.t_old) * self.posneg < T::zero()
473 || (t_interp - self.t) * self.posneg > T::zero())
474 && (t_interp - self.t_old).abs() > T::default_epsilon()
475 && (t_interp - self.t).abs() > T::default_epsilon()
476 {
477 return Err(Error::OutOfBounds {
478 t_interp,
479 t_prev: self.t_old,
480 t_curr: self.t,
481 });
482 }
483
484 let s = if self.h_old == T::zero() {
485 if (t_interp - self.t_old).abs() < T::default_epsilon() {
486 T::zero()
487 } else {
488 T::one()
489 }
490 } else {
491 (t_interp - self.t_old) / self.h_old
492 };
493 let s1 = T::one() - s;
494
495 let y_interp = self.cont[0]
497 + (self.cont[1] + (self.cont[2] + (self.cont[3] + self.cont[4] * s1) * s) * s1) * s;
498 Ok(y_interp)
499 }
500}
501
502impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> DOPRI5<L, T, V, H, D> {
503 pub fn new() -> Self {
504 Self::default()
505 }
506
507 pub fn rtol(mut self, rtol: T) -> Self {
508 self.rtol = rtol;
509 self
510 }
511 pub fn atol(mut self, atol: T) -> Self {
512 self.atol = atol;
513 self
514 }
515 pub fn h0(mut self, h0: T) -> Self {
516 self.h0 = h0;
517 self
518 }
519 pub fn h_max(mut self, h_max: T) -> Self {
520 self.h_max = h_max;
521 self
522 }
523 pub fn h_min(mut self, h_min: T) -> Self {
524 self.h_min = h_min;
525 self
526 }
527 pub fn max_steps(mut self, max_steps: usize) -> Self {
528 self.max_steps = max_steps;
529 self
530 }
531 pub fn safe(mut self, safe: T) -> Self {
532 self.safe = safe;
533 self
534 }
535 pub fn fac1(mut self, fac1: T) -> Self {
536 self.fac1 = fac1;
537 self.facc1 = T::one() / fac1;
538 self
539 }
540 pub fn fac2(mut self, fac2: T) -> Self {
541 self.fac2 = fac2;
542 self.facc2 = T::one() / fac2;
543 self
544 }
545 pub fn beta(mut self, beta: T) -> Self {
546 self.beta = beta;
547 self
548 }
549 pub fn max_delay(mut self, max_delay: T) -> Self {
550 self.max_delay = Some(max_delay.abs());
551 self
552 }
553
554 fn lagvals(&mut self, t_stage: T, _y_stage: &V) {
555 for i in 0..L {
556 let t_delayed = t_stage - self.lags[i];
557 if (t_delayed - self.t0) * self.posneg <= T::default_epsilon() {
558 self.yd[i] = (self.phi.as_ref().unwrap())(t_delayed);
559 } else {
560 let mut found_in_buffer = false;
561 for (buf_t_start, buf_t_end, buf_h, buf_cont) in self.cont_buffer.iter().rev() {
562 if (t_delayed - *buf_t_start) * self.posneg >= -T::default_epsilon()
563 && (t_delayed - *buf_t_end) * self.posneg <= T::default_epsilon()
564 {
565 let s = if *buf_h == T::zero() {
566 if (t_delayed - *buf_t_start).abs() < T::default_epsilon() {
567 T::zero()
568 } else {
569 T::one()
570 }
571 } else {
572 (t_delayed - *buf_t_start) / *buf_h
573 };
574 self.yd[i] = buf_cont[0]
575 + (buf_cont[1]
576 + (buf_cont[2] + (buf_cont[3] + buf_cont[4] * (T::one() - s)) * s)
577 * (T::one() - s))
578 * s;
579 found_in_buffer = true;
580 break;
581 }
582 }
583 if !found_in_buffer {
584 if let Some((buf_t_start, _buf_t_end, buf_h, buf_cont)) =
585 self.cont_buffer.back()
586 {
587 let s = if *buf_h == T::zero() {
588 T::one()
589 } else {
590 (t_delayed - *buf_t_start) / *buf_h
591 };
592 self.yd[i] = buf_cont[0]
593 + (buf_cont[1]
594 + (buf_cont[2] + (buf_cont[3] + buf_cont[4] * (T::one() - s)) * s)
595 * (T::one() - s))
596 * s;
597 } else {
598 self.yd[i] = (self.phi.as_ref().unwrap())(t_delayed);
599 }
600 }
601 }
602 }
603 }
604}
605
606impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData> Default
607 for DOPRI5<L, T, V, H, D>
608{
609 fn default() -> Self {
610 let a_conv = DOPRI5_A.map(|row| row.map(|x| T::from_f64(x).unwrap()));
612 let b_conv = DOPRI5_B.map(|x| T::from_f64(x).unwrap());
613 let c_conv = DOPRI5_C.map(|x| T::from_f64(x).unwrap());
614 let er_conv = DOPRI5_E.map(|x| T::from_f64(x).unwrap());
615 let d_conv = DOPRI5_D.map(|x| T::from_f64(x).unwrap());
616
617 let expo1_final = T::one() / T::from_f64(5.0).unwrap();
618
619 let fac1_default = T::from_f64(0.2).unwrap();
620 let fac2_default = T::from_f64(10.0).unwrap();
621 let beta_default = T::from_f64(0.04).unwrap();
622
623 DOPRI5 {
624 t: T::zero(),
625 y: V::zeros(),
626 h: T::zero(),
627 h0: T::zero(),
628 rtol: T::from_f64(1e-3).unwrap(),
629 atol: T::from_f64(1e-6).unwrap(),
630 h_max: T::infinity(),
631 h_min: T::zero(),
632 max_steps: 100_000,
633 safe: T::from_f64(0.9).unwrap(),
634 fac1: fac1_default,
635 fac2: fac2_default,
636 beta: beta_default,
637 max_delay: None,
638 expo1: expo1_final,
639 facc1: T::one() / fac1_default,
640 facc2: T::one() / fac2_default,
641 facold: T::from_f64(1.0e-4).unwrap(),
642 fac11: T::zero(),
643 fac: T::zero(),
644 status: Status::Uninitialized,
645 steps: 0,
646 n_accepted: 0,
647 a: a_conv,
648 b: b_conv,
649 c: c_conv,
650 er: er_conv,
651 d: d_conv,
652 k: [V::zeros(); 7],
653 y_old: V::zeros(),
654 t_old: T::zero(),
655 h_old: T::zero(),
656 cont: [V::zeros(); 5], cont_buffer: VecDeque::new(),
658 phi: None,
659 t0: T::zero(),
660 tf: T::zero(),
661 posneg: T::zero(),
662 lags: [T::zero(); L],
663 yd: [V::zeros(); L],
664 }
665 }
666}
667
/// Dormand–Prince 5(4) Runge–Kutta matrix. Row `i` holds `a[i][j]` for `j < i`, the weights
/// that combine earlier stage derivatives into stage `i`. The last row equals the solution
/// weights `DOPRI5_B`.
const DOPRI5_A: [[f64; 7]; 7] = [
    // stage 1
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    // stage 2
    [0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    // stage 3
    [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    // stage 4
    [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0, 0.0, 0.0],
    // stage 5
    [
        19372.0 / 6561.0,
        -25360.0 / 2187.0,
        64448.0 / 6561.0,
        -212.0 / 729.0,
        0.0,
        0.0,
        0.0,
    ],
    // stage 6
    [
        9017.0 / 3168.0,
        -355.0 / 33.0,
        46732.0 / 5247.0,
        49.0 / 176.0,
        -5103.0 / 18656.0,
        0.0,
        0.0,
    ],
    // stage 7
    [
        35.0 / 384.0,
        0.0,
        500.0 / 1113.0,
        125.0 / 192.0,
        -2187.0 / 6784.0,
        11.0 / 84.0,
        0.0,
    ],
];
705
/// Stage nodes `c[i]`: stage `i` is evaluated at `t + c[i] * h`.
const DOPRI5_C: [f64; 7] = [
    0.0,       // stage 1
    0.2,       // stage 2
    0.3,       // stage 3
    0.8,       // stage 4
    8.0 / 9.0, // stage 5
    1.0,       // stage 6
    1.0,       // stage 7 (end of step)
];
716
/// Fifth-order solution weights. The seventh weight is zero, so the last stage contributes
/// only to the error estimate and the dense output.
const DOPRI5_B: [f64; 7] = [
    35.0 / 384.0,
    0.0,
    500.0 / 1113.0,
    125.0 / 192.0,
    -2187.0 / 6784.0,
    11.0 / 84.0,
    0.0,
];
727
/// Coefficients of the embedded error estimate (difference between the 5th- and 4th-order
/// weights). They sum to zero.
const DOPRI5_E: [f64; 7] = [
    71.0 / 57600.0,
    0.0,
    -71.0 / 16695.0,
    71.0 / 1920.0,
    -17253.0 / 339200.0,
    22.0 / 525.0,
    -1.0 / 40.0,
];
738
/// Weights of the highest-order dense-output coefficient (`cont[4]`), one per stage.
const DOPRI5_D: [f64; 7] = [
    -12715105075.0 / 11282082432.0,
    0.0,
    87487479700.0 / 32700410799.0,
    -10690763975.0 / 1880347072.0,
    701980252875.0 / 199316789632.0,
    -1453857185.0 / 822651844.0,
    69997945.0 / 29380423.0,
];