1use std::collections::VecDeque;
3
4use crate::{
5 dde::{DDE, DelayNumericalMethod},
6 error::Error,
7 interpolate::{Interpolation, cubic_hermite_interpolate},
8 methods::{Delay, DormandPrince, ExplicitRungeKutta, h_init::InitialStepSize},
9 stats::Evals,
10 status::Status,
11 traits::{Real, State},
12 utils::{constrain_step_size, validate_step_size_parameters},
13};
14
15impl<
16 const L: usize,
17 T: Real,
18 Y: State<T>,
19 H: Fn(T) -> Y,
20 const O: usize,
21 const S: usize,
22 const I: usize,
23> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
24{
25 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
26 where
27 F: DDE<L, T, Y> + ?Sized,
28 {
29 let mut evals = Evals::new();
30
31 if L == 0 {
33 return Err(Error::NoLags);
34 }
35
36 self.t0 = t0;
38 self.t = t0;
39 self.y = y0.clone();
40 self.dydt = y0.zeros_like();
41 self.y_prev = y0.clone();
42 self.dydt_prev = y0.zeros_like();
43 self.k = core::array::from_fn(|_| y0.zeros_like());
44 self.cont = core::array::from_fn(|_| y0.zeros_like());
45 self.t_prev = self.t;
46 self.y_prev = self.y.clone();
47 self.status = Status::Initialized;
48 self.steps = 0;
49 self.stiffness_counter = 0;
50 self.non_stiffness_counter = 0;
51 self.history = VecDeque::new();
52
53 let mut delays = [T::zero(); L];
55 let mut y_delayed = core::array::from_fn(|_| y0.zeros_like());
56
57 dde.lags(self.t, &self.y, &mut delays);
59 for i in 0..L {
60 let t_delayed = self.t - delays[i];
61 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
63 return Err(Error::BadInput {
64 msg: format!("Delayed time {} is beyond initial time {}", t_delayed, t0),
65 });
66 }
67 y_delayed[i] = phi(t_delayed);
68 }
69
70 dde.diff(self.t, &self.y, &y_delayed, &mut self.k[0]);
72 self.dydt = self.k[0].clone();
73 evals.function += 1;
74 self.dydt_prev = self.dydt.clone();
75
76 self.history
78 .push_back((self.t, self.y.clone(), self.dydt.clone()));
79
80 if self.h0 == T::zero() {
82 self.h0 = InitialStepSize::<Delay>::compute(
83 dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
84 &self.k[0], &mut evals,
85 );
86 }
87
88 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
90 Ok(h0) => self.h = (self.filter)(h0),
91 Err(status) => return Err(status),
92 }
93 Ok(evals)
94 }
95
96 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
97 where
98 F: DDE<L, T, Y> + ?Sized,
99 {
100 let mut evals = Evals::new();
101
102 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
104 self.status = Status::Error(Error::StepSize {
105 t: self.t,
106 y: self.y.clone(),
107 });
108 return Err(Error::StepSize {
109 t: self.t,
110 y: self.y.clone(),
111 });
112 }
113
114 if self.steps >= self.max_steps {
116 self.status = Status::Error(Error::MaxSteps {
117 t: self.t,
118 y: self.y.clone(),
119 });
120 return Err(Error::MaxSteps {
121 t: self.t,
122 y: self.y.clone(),
123 });
124 }
125 self.steps += 1;
126
127 let mut delays = [T::zero(); L];
129 let mut y_delayed = core::array::from_fn(|_| self.y.zeros_like());
130
131 let mut min_delay_abs = T::infinity();
133 let y_pred_for_lags = self.y.plus_scaled(self.h, &self.k[0]);
135 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
136 for i in 0..L {
137 min_delay_abs = min_delay_abs.min(delays[i].abs());
138 }
139
140 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
142 5
143 } else {
144 1
145 };
146 let mut y_next_est = self.y.clone();
147 let mut y_next_est_prev = self.y.clone();
148 let mut dde_iter_failed = false;
149 let mut err_norm: T = T::zero();
150 let mut y_last_stage = self.y.zeros_like();
151
152 for it in 0..max_iter {
154 if it > 0 {
155 y_next_est_prev = y_next_est.clone();
156 }
157
158 let mut y_stage = self.y.zeros_like();
160 for i in 1..self.stages {
161 y_stage = self.y.clone();
162 for j in 0..i {
163 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
164 }
165
166 dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
168 if let Err(e) =
169 self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
170 {
171 self.status = Status::Error(e.clone());
172 return Err(e);
173 }
174 dde.diff(
175 self.t + self.c[i] * self.h,
176 &y_stage,
177 &y_delayed,
178 &mut self.k[i],
179 );
180 }
181 evals.function += self.stages - 1;
182
183 y_last_stage = y_stage.clone();
185
186 let mut yseg = self.y.zeros_like();
188 for i in 0..self.stages {
189 yseg.add_scaled(self.b[i], &self.k[i]);
190 }
191
192 let y_new = self.y.plus_scaled(self.h, &yseg);
193
194 let er = self.er.unwrap();
196 let n = self.y.len();
197 let mut err2 = T::zero();
198 let mut err_state = self.y.zeros_like();
199 for (j, coefficient) in er.iter().enumerate().take(self.stages) {
200 err_state.add_scaled(*coefficient, &self.k[j]);
201 }
202 let err_val = self
203 .y
204 .error_norm(&y_new, &err_state, &self.atol, &self.rtol);
205
206 if let Some(bh) = &self.bh {
207 let mut err2_state = yseg.clone();
208 for (j, coefficient) in bh.iter().enumerate().take(self.stages) {
209 err2_state.add_scaled(-*coefficient, &self.k[j]);
210 }
211 err2 = self
212 .y
213 .error_norm(&y_new, &err2_state, &self.atol, &self.rtol);
214 }
215 let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
216 if deno <= T::zero() {
217 deno = T::one();
218 }
219 err_norm =
220 self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
221
222 if max_iter > 1 && it > 0 {
224 let n_dim = self.y.len();
225 let iter_diff = y_new.minus(&y_next_est_prev);
226 let mut dde_iteration_error =
227 y_next_est_prev.error_norm(&y_new, &iter_diff, &self.atol, &self.rtol);
228 if n_dim > 0 {
229 dde_iteration_error =
230 (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
231 }
232
233 if dde_iteration_error <= self.rtol.average() * T::from_f64(0.1).unwrap() {
234 break;
235 }
236 if it == max_iter - 1 {
237 dde_iter_failed =
238 dde_iteration_error > self.rtol.average() * T::from_f64(0.1).unwrap();
239 }
240 }
241 y_next_est = y_new.clone();
242 }
243
244 if dde_iter_failed {
246 let sign = self.h.signum();
247 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
248 if L > 0
249 && min_delay_abs > T::zero()
250 && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
251 {
252 self.h = min_delay_abs * sign;
253 }
254 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
255 self.h = (self.filter)(self.h);
256 self.status = Status::RejectedStep;
257 return Ok(evals);
258 }
259
260 let order = T::from_usize(self.order).unwrap();
262 let error_exponent = T::one() / order;
263 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
264
265 scale = scale.max(self.min_scale).min(self.max_scale);
267
268 if err_norm <= T::one() {
270 let y_new = y_next_est.clone();
271 let t_new = self.t + self.h;
272
273 dde.lags(t_new, &y_new, &mut delays);
275 if let Err(e) = self.lagvals(t_new, &delays, &mut y_delayed, phi) {
276 self.status = Status::Error(e.clone());
277 return Err(e);
278 }
279 dde.diff(t_new, &y_new, &y_delayed, &mut self.dydt);
280 evals.function += 1;
281 let n_stiff_threshold = 100;
283 if self.steps.is_multiple_of(n_stiff_threshold) {
284 let mut yseg = self.y.zeros_like();
285 for i in 0..self.stages {
286 yseg.add_scaled(self.b[i], &self.k[i]);
287 }
288 let stdnum = yseg.diff_norm_squared(&self.k[S - 1]);
289 let stden = self.dydt.diff_norm_squared(&y_last_stage);
290
291 if stden > T::zero() {
292 let h_lamb = self.h * (stdnum / stden).sqrt();
293 if h_lamb > T::from_f64(6.1).unwrap() {
294 self.non_stiffness_counter = 0;
295 self.stiffness_counter += 1;
296 if self.stiffness_counter == 15 {
297 self.status = Status::Error(Error::Stiffness {
298 t: self.t,
299 y: self.y.clone(),
300 });
301 return Err(Error::Stiffness {
302 t: self.t,
303 y: self.y.clone(),
304 });
305 }
306 }
307 } else {
308 self.non_stiffness_counter += 1;
309 if self.non_stiffness_counter == 6 {
310 self.stiffness_counter = 0;
311 }
312 }
313 }
314
315 self.cont[0] = self.y.clone();
317 let ydiff = y_new.minus(&self.y);
318 self.cont[1] = ydiff.clone();
319 let mut bspl = ydiff.zeros_like();
320 bspl.add_scaled(self.h, &self.k[0]);
321 bspl.add_scaled(-T::one(), &ydiff);
322 self.cont[2] = bspl.clone();
323 let mut cont3 = ydiff;
324 cont3.add_scaled(-self.h, &self.dydt);
325 cont3.add_scaled(-T::one(), &bspl);
326 self.cont[3] = cont3;
327
328 if self.bi.is_some() {
330 if I > S {
331 self.k[self.stages] = self.dydt.clone();
332 for i in S + 1..I {
333 let mut y_stage = self.y.clone();
334 for j in 0..i {
335 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
336 }
337
338 let t_stage = self.t + self.c[i] * self.h;
339 dde.lags(t_stage, &y_stage, &mut delays);
340 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
341 self.status = Status::Error(e.clone());
342 return Err(e);
343 }
344 dde.diff(t_stage, &y_stage, &y_delayed, &mut self.k[i]);
345 evals.function += 1;
346 }
347 }
348
349 for i in 4..self.order {
351 self.cont[i].fill(T::zero());
352 for j in 0..self.dense_stages {
353 let bi = self.bi.as_ref().expect("dense output coefficients checked");
354 self.cont[i].add_scaled(bi[i][j], &self.k[j]);
355 }
356 self.cont[i].scale_by(self.h);
357 }
358 }
359
360 self.t_prev = self.t;
362 self.y_prev = self.y.clone();
363 self.dydt_prev = self.k[0].clone();
364 self.h_prev = self.h;
365
366 self.t = t_new;
368 self.y = y_new;
369 self.k[0] = self.dydt.clone();
370
371 self.history
373 .push_back((self.t, self.y.clone(), self.dydt.clone()));
374 if let Some(max_delay) = self.max_delay {
375 let cutoff_time = self.t - max_delay;
376 while let Some((t_front, _, _)) = self.history.get(1) {
377 if *t_front < cutoff_time {
378 self.history.pop_front();
379 } else {
380 break;
381 }
382 }
383 }
384 if let Status::RejectedStep = self.status {
385 self.status = Status::Solving;
386 scale = scale.min(T::one());
387 }
388 } else {
389 self.status = Status::RejectedStep;
391 }
392
393 self.h *= scale;
395 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
397 self.h = (self.filter)(self.h);
399
400 Ok(evals)
401 }
402
403 fn t(&self) -> T {
404 self.t
405 }
406 fn y(&self) -> &Y {
407 &self.y
408 }
409 fn t_prev(&self) -> T {
410 self.t_prev
411 }
412 fn y_prev(&self) -> &Y {
413 &self.y_prev
414 }
415 fn h(&self) -> T {
416 self.h
417 }
418 fn set_h(&mut self, h: T) {
419 self.h = (self.filter)(h);
420 }
421 fn status(&self) -> &Status<T, Y> {
422 &self.status
423 }
424 fn set_status(&mut self, status: Status<T, Y>) {
425 self.status = status;
426 }
427}
428
429impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
430 ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
431{
432 fn lagvals<const L: usize, H>(
433 &mut self,
434 t_stage: T,
435 lags: &[T; L],
436 yd: &mut [Y; L],
437 phi: &H,
438 ) -> Result<(), Error<T, Y>>
439 where
440 H: Fn(T) -> Y,
441 {
442 for i in 0..L {
443 let t_delayed = t_stage - lags[i];
444
445 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
447 yd[i] = phi(t_delayed);
448 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
450 if self.bi.is_some() {
451 let theta = (t_delayed - self.t_prev) / self.h_prev;
452 let one_minus_theta = T::one() - theta;
453
454 let ilast = self.cont.len() - 1;
456 let poly = (1..ilast)
457 .rev()
458 .fold(self.cont[ilast].clone(), |mut acc, i| {
459 let factor = if i >= 4 {
460 if (ilast - i) % 2 == 1 {
461 one_minus_theta
462 } else {
463 theta
464 }
465 } else if i % 2 == 1 {
466 one_minus_theta
467 } else {
468 theta
469 };
470 acc.scale_by(factor);
471 acc.add_scaled(T::one(), &self.cont[i]);
472 acc
473 });
474
475 let y_interp = self.cont[0].plus_scaled(theta, &poly);
477 yd[i] = y_interp;
478 } else {
479 yd[i] = cubic_hermite_interpolate(
480 self.t_prev,
481 self.t,
482 &self.y_prev,
483 &self.y,
484 &self.dydt_prev,
485 &self.dydt,
486 t_delayed,
487 );
488 }
489 } else {
491 let mut found_interpolation = false;
493 let buffer = &self.history;
494 let mut buffer_iter = buffer.iter();
496 if let Some(mut prev_entry) = buffer_iter.next() {
497 for curr_entry in buffer_iter {
498 let (t_left, y_left, dydt_left) = prev_entry;
499 let (t_right, y_right, dydt_right) = curr_entry;
500
501 let is_between = if self.h.signum() > T::zero() {
503 *t_left <= t_delayed && t_delayed <= *t_right
504 } else {
505 *t_right <= t_delayed && t_delayed <= *t_left
506 };
507
508 if is_between {
509 yd[i] = cubic_hermite_interpolate(
510 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
511 t_delayed,
512 );
513 found_interpolation = true;
514 break;
515 }
516 prev_entry = curr_entry;
517 }
518 }
519 if !found_interpolation {
521 return Err(Error::InsufficientHistory {
522 t_delayed,
523 t_prev: self.t_prev,
524 t_curr: self.t,
525 });
526 }
527 }
528 }
529 Ok(())
530 }
531}
532
533impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
534 for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
535{
536 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
537 let dir = (self.t - self.t_prev).signum();
539 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
540 return Err(Error::OutOfBounds {
541 t_interp,
542 t_prev: self.t_prev,
543 t_curr: self.t,
544 });
545 }
546
547 let theta = (t_interp - self.t_prev) / self.h_prev;
549 let one_minus_theta = T::one() - theta;
550
551 let ilast = self.cont.len() - 1;
553 let poly = (1..ilast)
554 .rev()
555 .fold(self.cont[ilast].clone(), |mut acc, i| {
556 let factor = if i >= 4 {
557 if (ilast - i) % 2 == 1 {
558 one_minus_theta
559 } else {
560 theta
561 }
562 } else if i % 2 == 1 {
563 one_minus_theta
564 } else {
565 theta
566 };
567 acc.scale_by(factor);
568 acc.add_scaled(T::one(), &self.cont[i]);
569 acc
570 });
571
572 let y_interp = self.cont[0].plus_scaled(theta, &poly);
574
575 Ok(y_interp)
576 }
577}