1use std::collections::VecDeque;
4
5use crate::{
6 dde::{DDE, DelayNumericalMethod},
7 error::Error,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 methods::{Delay, DormandPrince, ExplicitRungeKutta, h_init::InitialStepSize},
10 stats::Evals,
11 status::Status,
12 traits::{Real, State},
13 utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<
17 const L: usize,
18 T: Real,
19 Y: State<T>,
20 H: Fn(T) -> Y,
21 const O: usize,
22 const S: usize,
23 const I: usize,
24> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
25{
26 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
27 where
28 F: DDE<L, T, Y>,
29 {
30 let mut evals = Evals::new();
31
32 if L <= 0 {
34 return Err(Error::NoLags);
35 }
36
37 self.t0 = t0;
39 self.t = t0;
40 self.y = *y0;
41 self.t_prev = self.t;
42 self.y_prev = self.y;
43 self.status = Status::Initialized;
44 self.steps = 0;
45 self.stiffness_counter = 0;
46 self.non_stiffness_counter = 0;
47 self.history = VecDeque::new();
48
49 let mut delays = [T::zero(); L];
51 let mut y_delayed = [Y::zeros(); L];
52
53 dde.lags(self.t, &self.y, &mut delays);
55 for i in 0..L {
56 let t_delayed = self.t - delays[i];
57 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
59 return Err(Error::BadInput {
60 msg: format!("Delayed time {} is beyond initial time {}", t_delayed, t0),
61 });
62 }
63 y_delayed[i] = phi(t_delayed);
64 }
65
66 dde.diff(self.t, &self.y, &y_delayed, &mut self.k[0]);
68 self.dydt = self.k[0];
69 evals.function += 1;
70 self.dydt_prev = self.dydt;
71
72 self.history.push_back((self.t, self.y, self.dydt));
74
75 if self.h0 == T::zero() {
77 self.h0 = InitialStepSize::<Delay>::compute(
78 dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
79 &self.k[0], &mut evals,
80 );
81 }
82
83 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
85 Ok(h0) => self.h = h0,
86 Err(status) => return Err(status),
87 }
88 Ok(evals)
89 }
90
91 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
92 where
93 F: DDE<L, T, Y>,
94 {
95 let mut evals = Evals::new();
96
97 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
99 self.status = Status::Error(Error::StepSize {
100 t: self.t,
101 y: self.y,
102 });
103 return Err(Error::StepSize {
104 t: self.t,
105 y: self.y,
106 });
107 }
108
109 if self.steps >= self.max_steps {
111 self.status = Status::Error(Error::MaxSteps {
112 t: self.t,
113 y: self.y,
114 });
115 return Err(Error::MaxSteps {
116 t: self.t,
117 y: self.y,
118 });
119 }
120 self.steps += 1;
121
122 let mut delays = [T::zero(); L];
124 let mut y_delayed = [Y::zeros(); L];
125
126 let mut min_delay_abs = T::infinity();
128 let y_pred_for_lags = self.y + self.k[0] * self.h;
130 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
131 for i in 0..L {
132 min_delay_abs = min_delay_abs.min(delays[i].abs());
133 }
134
135 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
137 5
138 } else {
139 1
140 };
141 let mut y_next_est = self.y;
142 let mut y_next_est_prev = self.y;
143 let mut dde_iter_failed = false;
144 let mut err_norm: T = T::zero();
145 let mut y_last_stage = Y::zeros();
146
147 for it in 0..max_iter {
149 if it > 0 {
150 y_next_est_prev = y_next_est;
151 }
152
153 let mut y_stage = Y::zeros();
155 for i in 1..self.stages {
156 y_stage = Y::zeros();
157 for j in 0..i {
158 y_stage += self.k[j] * self.a[i][j];
159 }
160 y_stage = self.y + y_stage * self.h;
161
162 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
164 if let Err(e) =
165 self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
166 {
167 self.status = Status::Error(e.clone());
168 return Err(e);
169 }
170 dde.diff(
171 self.t + self.c[i] * self.h,
172 &y_stage,
173 &y_delayed,
174 &mut self.k[i],
175 );
176 }
177 evals.function += self.stages - 1;
178
179 y_last_stage = y_stage;
181
182 let mut yseg = Y::zeros();
184 for i in 0..self.stages {
185 yseg += self.k[i] * self.b[i];
186 }
187
188 let y_new = self.y + yseg * self.h;
189
190 let er = self.er.unwrap();
192 let n = self.y.len();
193 let mut err_val = T::zero();
194 let mut err2 = T::zero();
195 let mut erri;
196 for i in 0..n {
197 let sk = self.atol[i] + self.rtol[i] * self.y.get(i).abs().max(y_new.get(i).abs());
199
200 erri = T::zero();
202 for j in 0..self.stages {
203 erri += er[j] * self.k[j].get(i);
204 }
205 err_val += (erri / sk).powi(2);
206
207 if let Some(bh) = &self.bh {
209 erri = yseg.get(i);
210 for j in 0..self.stages {
211 erri -= bh[j] * self.k[j].get(i);
212 }
213 err2 += (erri / sk).powi(2);
214 }
215 }
216 let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
217 if deno <= T::zero() {
218 deno = T::one();
219 }
220 err_norm =
221 self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
222
223 if max_iter > 1 && it > 0 {
225 let mut dde_iteration_error = T::zero();
226 let n_dim = self.y.len();
227 for i_dim in 0..n_dim {
228 let scale = self.atol[i_dim]
229 + self.rtol[i_dim]
230 * y_next_est_prev.get(i_dim).abs().max(y_new.get(i_dim).abs());
231 if scale > T::zero() {
232 let diff_val = y_new.get(i_dim) - y_next_est_prev.get(i_dim);
233 dde_iteration_error += (diff_val / scale).powi(2);
234 }
235 }
236 if n_dim > 0 {
237 dde_iteration_error =
238 (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
239 }
240
241 if dde_iteration_error <= self.rtol.average() * T::from_f64(0.1).unwrap() {
242 break;
243 }
244 if it == max_iter - 1 {
245 dde_iter_failed =
246 dde_iteration_error > self.rtol.average() * T::from_f64(0.1).unwrap();
247 }
248 }
249 y_next_est = y_new;
250 }
251
252 if dde_iter_failed {
254 let sign = self.h.signum();
255 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
256 if L > 0
257 && min_delay_abs > T::zero()
258 && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
259 {
260 self.h = min_delay_abs * sign;
261 }
262 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
263 self.status = Status::RejectedStep;
264 return Ok(evals);
265 }
266
267 let order = T::from_usize(self.order).unwrap();
269 let error_exponent = T::one() / order;
270 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
271
272 scale = scale.max(self.min_scale).min(self.max_scale);
274
275 if err_norm <= T::one() {
277 let y_new = y_next_est;
278 let t_new = self.t + self.h;
279
280 dde.lags(t_new, &y_new, &mut delays);
282 if let Err(e) = self.lagvals(t_new, &delays, &mut y_delayed, phi) {
283 self.status = Status::Error(e.clone());
284 return Err(e);
285 }
286 dde.diff(t_new, &y_new, &y_delayed, &mut self.dydt);
287 evals.function += 1;
288 let n_stiff_threshold = 100;
290 if self.steps % n_stiff_threshold == 0 {
291 let mut stdnum = T::zero();
292 let mut stden = T::zero();
293 let sqr = {
294 let mut yseg = Y::zeros();
295 for i in 0..self.stages {
296 yseg += self.k[i] * self.b[i];
297 }
298 yseg - self.k[S - 1]
299 };
300 for i in 0..sqr.len() {
301 stdnum += sqr.get(i).powi(2);
302 }
303 let sqr = self.dydt - y_last_stage;
304 for i in 0..sqr.len() {
305 stden += sqr.get(i).powi(2);
306 }
307
308 if stden > T::zero() {
309 let h_lamb = self.h * (stdnum / stden).sqrt();
310 if h_lamb > T::from_f64(6.1).unwrap() {
311 self.non_stiffness_counter = 0;
312 self.stiffness_counter += 1;
313 if self.stiffness_counter == 15 {
314 self.status = Status::Error(Error::Stiffness {
315 t: self.t,
316 y: self.y,
317 });
318 return Err(Error::Stiffness {
319 t: self.t,
320 y: self.y,
321 });
322 }
323 }
324 } else {
325 self.non_stiffness_counter += 1;
326 if self.non_stiffness_counter == 6 {
327 self.stiffness_counter = 0;
328 }
329 }
330 }
331
332 self.cont[0] = self.y;
334 let ydiff = y_new - self.y;
335 self.cont[1] = ydiff;
336 let bspl = self.k[0] * self.h - ydiff;
337 self.cont[2] = bspl;
338 self.cont[3] = ydiff - self.dydt * self.h - bspl;
339
340 if let Some(bi) = &self.bi {
342 if I > S {
343 self.k[self.stages] = self.dydt;
344 for i in S + 1..I {
345 let mut y_stage = Y::zeros();
346 for j in 0..i {
347 y_stage += self.k[j] * self.a[i][j];
348 }
349 y_stage = self.y + y_stage * self.h;
350
351 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
352 for lag_idx in 0..L {
353 let t_delayed = (self.t + self.c[i] * self.h) - delays[lag_idx];
354
355 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
356 y_delayed[lag_idx] = phi(t_delayed);
357 } else if (t_delayed - self.t_prev) * self.h.signum()
358 > T::default_epsilon()
359 {
360 if self.bi.is_some() {
361 let theta = (t_delayed - self.t_prev) / self.h_prev;
362 let one_minus_theta = T::one() - theta;
363 let ilast = self.cont.len() - 1;
364 let poly =
365 (1..ilast).rev().fold(self.cont[ilast], |acc, cont_i| {
366 let factor = if cont_i >= 4 {
367 if (ilast - cont_i) % 2 == 1 {
368 one_minus_theta
369 } else {
370 theta
371 }
372 } else if cont_i % 2 == 1 {
373 one_minus_theta
374 } else {
375 theta
376 };
377 acc * factor + self.cont[cont_i]
378 });
379 y_delayed[lag_idx] = self.cont[0] + poly * theta;
380 } else {
381 y_delayed[lag_idx] = cubic_hermite_interpolate(
382 self.t_prev,
383 self.t,
384 &self.y_prev,
385 &self.y,
386 &self.dydt_prev,
387 &self.dydt,
388 t_delayed,
389 );
390 }
391 } else {
392 let mut found_interpolation = false;
393 let buffer = &self.history;
394 let mut buffer_iter = buffer.iter();
395 if let Some(mut prev_entry) = buffer_iter.next() {
396 for curr_entry in buffer_iter {
397 let (t_left, y_left, dydt_left) = prev_entry;
398 let (t_right, y_right, dydt_right) = curr_entry;
399
400 let is_between = if self.h.signum() > T::zero() {
401 *t_left <= t_delayed && t_delayed <= *t_right
402 } else {
403 *t_right <= t_delayed && t_delayed <= *t_left
404 };
405
406 if is_between {
407 y_delayed[lag_idx] = cubic_hermite_interpolate(
408 *t_left, *t_right, y_left, y_right, dydt_left,
409 dydt_right, t_delayed,
410 );
411 found_interpolation = true;
412 break;
413 }
414 prev_entry = curr_entry;
415 }
416 }
417 if !found_interpolation {
418 return Err(Error::InsufficientHistory {
419 t_delayed,
420 t_prev: self.t_prev,
421 t_curr: self.t,
422 });
423 }
424 }
425 }
426 dde.diff(
427 self.t + self.c[i] * self.h,
428 &y_stage,
429 &y_delayed,
430 &mut self.k[i],
431 );
432 evals.function += 1;
433 }
434 }
435
436 for i in 4..self.order {
438 self.cont[i] = Y::zeros();
439 for j in 0..self.dense_stages {
440 self.cont[i] += self.k[j] * bi[i][j];
441 }
442 self.cont[i] = self.cont[i] * self.h;
443 }
444 }
445
446 self.t_prev = self.t;
448 self.y_prev = self.y;
449 self.dydt_prev = self.k[0];
450 self.h_prev = self.h;
451
452 self.t = t_new;
454 self.y = y_new;
455 self.k[0] = self.dydt;
456
457 self.history.push_back((self.t, self.y, self.dydt));
459 if let Some(max_delay) = self.max_delay {
460 let cutoff_time = self.t - max_delay;
461 while let Some((t_front, _, _)) = self.history.get(1) {
462 if *t_front < cutoff_time {
463 self.history.pop_front();
464 } else {
465 break;
466 }
467 }
468 }
469 if let Status::RejectedStep = self.status {
470 self.status = Status::Solving;
471 scale = scale.min(T::one());
472 }
473 } else {
474 self.status = Status::RejectedStep;
476 }
477
478 self.h *= scale;
480 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
482
483 Ok(evals)
484 }
485
486 fn t(&self) -> T {
487 self.t
488 }
489 fn y(&self) -> &Y {
490 &self.y
491 }
492 fn t_prev(&self) -> T {
493 self.t_prev
494 }
495 fn y_prev(&self) -> &Y {
496 &self.y_prev
497 }
498 fn h(&self) -> T {
499 self.h
500 }
501 fn set_h(&mut self, h: T) {
502 self.h = h;
503 }
504 fn status(&self) -> &Status<T, Y> {
505 &self.status
506 }
507 fn set_status(&mut self, status: Status<T, Y>) {
508 self.status = status;
509 }
510}
511
512impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
513 ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
514{
515 fn lagvals<const L: usize, H>(
516 &mut self,
517 t_stage: T,
518 lags: &[T; L],
519 yd: &mut [Y; L],
520 phi: &H,
521 ) -> Result<(), Error<T, Y>>
522 where
523 H: Fn(T) -> Y,
524 {
525 for i in 0..L {
526 let t_delayed = t_stage - lags[i];
527
528 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
530 yd[i] = phi(t_delayed);
531 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
533 if self.bi.is_some() {
534 let theta = (t_delayed - self.t_prev) / self.h_prev;
535 let one_minus_theta = T::one() - theta;
536
537 let ilast = self.cont.len() - 1;
539 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
540 let factor = if i >= 4 {
541 if (ilast - i) % 2 == 1 {
542 one_minus_theta
543 } else {
544 theta
545 }
546 } else if i % 2 == 1 {
547 one_minus_theta
548 } else {
549 theta
550 };
551 acc * factor + self.cont[i]
552 });
553
554 let y_interp = self.cont[0] + poly * theta;
556 yd[i] = y_interp;
557 } else {
558 yd[i] = cubic_hermite_interpolate(
559 self.t_prev,
560 self.t,
561 &self.y_prev,
562 &self.y,
563 &self.dydt_prev,
564 &self.dydt,
565 t_delayed,
566 );
567 }
568 } else {
570 let mut found_interpolation = false;
572 let buffer = &self.history;
573 let mut buffer_iter = buffer.iter();
575 if let Some(mut prev_entry) = buffer_iter.next() {
576 for curr_entry in buffer_iter {
577 let (t_left, y_left, dydt_left) = prev_entry;
578 let (t_right, y_right, dydt_right) = curr_entry;
579
580 let is_between = if self.h.signum() > T::zero() {
582 *t_left <= t_delayed && t_delayed <= *t_right
583 } else {
584 *t_right <= t_delayed && t_delayed <= *t_left
585 };
586
587 if is_between {
588 yd[i] = cubic_hermite_interpolate(
589 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
590 t_delayed,
591 );
592 found_interpolation = true;
593 break;
594 }
595 prev_entry = curr_entry;
596 }
597 }
598 if !found_interpolation {
600 return Err(Error::InsufficientHistory {
601 t_delayed,
602 t_prev: self.t_prev,
603 t_curr: self.t,
604 });
605 }
606 }
607 }
608 Ok(())
609 }
610}
611
612impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
613 for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
614{
615 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
616 let dir = (self.t - self.t_prev).signum();
618 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
619 return Err(Error::OutOfBounds {
620 t_interp,
621 t_prev: self.t_prev,
622 t_curr: self.t,
623 });
624 }
625
626 let theta = (t_interp - self.t_prev) / self.h_prev;
628 let one_minus_theta = T::one() - theta;
629
630 let ilast = self.cont.len() - 1;
632 let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
633 let factor = if i >= 4 {
634 if (ilast - i) % 2 == 1 {
635 one_minus_theta
636 } else {
637 theta
638 }
639 } else if i % 2 == 1 {
640 one_minus_theta
641 } else {
642 theta
643 };
644 acc * factor + self.cont[i]
645 });
646
647 let y_interp = self.cont[0] + poly * theta;
649
650 Ok(y_interp)
651 }
652}