1use std::collections::VecDeque;
4
5use crate::{
6 dde::{DDE, DelayNumericalMethod},
7 error::Error,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 methods::{Adaptive, Delay, ExplicitRungeKutta, h_init::InitialStepSize},
10 stats::Evals,
11 status::Status,
12 traits::{CallBackData, Real, State},
13 utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<
17 const L: usize,
18 T: Real,
19 Y: State<T>,
20 H: Fn(T) -> Y,
21 D: CallBackData,
22 const O: usize,
23 const S: usize,
24 const I: usize,
25> DelayNumericalMethod<L, T, Y, H, D> for ExplicitRungeKutta<Delay, Adaptive, T, Y, D, O, S, I>
26{
27 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
28 where
29 F: DDE<L, T, Y, D>,
30 {
31 let mut evals = Evals::new();
32
33 if L <= 0 {
35 return Err(Error::NoLags);
36 }
37
38 self.t0 = t0;
40 self.t = t0;
41 self.y = *y0;
42 self.t_prev = self.t;
43 self.y_prev = self.y;
44 self.status = Status::Initialized;
45 self.steps = 0;
46 self.stiffness_counter = 0;
47 self.history = VecDeque::new();
48
49 let mut delays = [T::zero(); L];
51 let mut y_delayed = [Y::zeros(); L];
52
53 dde.lags(self.t, &self.y, &mut delays);
55 for i in 0..L {
56 let t_delayed = self.t - delays[i];
57 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
59 return Err(Error::BadInput {
60 msg: format!(
61 "Initial delayed time {} is out of history range (t <= {}).",
62 t_delayed, t0
63 ),
64 });
65 }
66 y_delayed[i] = phi(t_delayed);
67 }
68
69 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
71 evals.function += 1;
72 self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
74
75 if self.h0 == T::zero() {
77 self.h0 = InitialStepSize::<Delay>::compute(
79 dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi,
80 &self.k[0], &mut evals,
81 );
82 evals.function += 2; }
84
85 match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
87 Ok(h0) => self.h = h0,
88 Err(status) => return Err(status),
89 }
90 Ok(evals)
91 }
92
93 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
94 where
95 F: DDE<L, T, Y, D>,
96 {
97 let mut evals = Evals::new();
98
99 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
101 self.status = Status::Error(Error::StepSize {
102 t: self.t,
103 y: self.y,
104 });
105 return Err(Error::StepSize {
106 t: self.t,
107 y: self.y,
108 });
109 }
110
111 if self.steps >= self.max_steps {
113 self.status = Status::Error(Error::MaxSteps {
114 t: self.t,
115 y: self.y,
116 });
117 return Err(Error::MaxSteps {
118 t: self.t,
119 y: self.y,
120 });
121 }
122 self.steps += 1;
123
124 let mut delays = [T::zero(); L];
126 let mut y_delayed = [Y::zeros(); L];
127
128 self.k[0] = self.dydt;
130
131 let mut min_delay_abs = T::infinity();
133 let y_pred_for_lags = self.y + self.k[0] * self.h;
135 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
136 for i in 0..L {
137 min_delay_abs = min_delay_abs.min(delays[i].abs());
138 }
139
140 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
142 5
143 } else {
144 1
145 };
146
147 let mut y_next_est = self.y;
148 let mut dydt_next_est = Y::zeros();
149 let mut y_next_est_prev = self.y;
150 let mut dde_iter_failed = false;
151 let mut err_norm: T = T::zero();
152
153 for it in 0..max_iter {
155 if it > 0 {
156 y_next_est_prev = y_next_est;
157 }
158
159 for i in 1..self.stages {
161 let mut y_stage = self.y;
162 for j in 0..i {
163 y_stage += self.k[j] * (self.a[i][j] * self.h);
164 }
165 let t_stage = self.t + self.c[i] * self.h;
167 dde.lags(t_stage, &y_stage, &mut delays);
168 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
169 self.status = Status::Error(e.clone());
170 return Err(e);
171 }
172
173 dde.diff(
174 self.t + self.c[i] * self.h,
175 &y_stage,
176 &y_delayed,
177 &mut self.k[i],
178 );
179 }
180 evals.function += self.stages - 1;
181
182 let mut y_high = self.y;
184 for i in 0..self.stages {
185 y_high += self.k[i] * (self.b[i] * self.h);
186 }
187 let mut y_low = self.y;
188 let bh = &self.bh.unwrap();
189 for i in 0..self.stages {
190 y_low += self.k[i] * (bh[i] * self.h);
191 }
192 let err_vec: Y = y_high - y_low;
193
194 err_norm = T::zero();
196 for n in 0..self.y.len() {
197 let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
198 err_norm = err_norm.max((err_vec.get(n) / tol).abs());
199 }
200
201 if max_iter > 1 && it > 0 {
203 let mut iter_err = T::zero();
204 let n_dim = self.y.len();
205 for d in 0..n_dim {
206 let scale = self.atol
207 + self.rtol * y_next_est_prev.get(d).abs().max(y_high.get(d).abs());
208 if scale > T::zero() {
209 let diff_val = y_high.get(d) - y_next_est_prev.get(d);
210 iter_err += (diff_val / scale).powi(2);
211 }
212 }
213 if n_dim > 0 {
214 iter_err = (iter_err / T::from_usize(n_dim).unwrap()).sqrt();
215 }
216
217 if iter_err <= self.rtol * T::from_f64(0.1).unwrap() {
218 y_next_est = y_high;
219 dde.lags(self.t + self.h, &y_next_est, &mut delays);
220 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
221 self.status = Status::Error(e.clone());
222 return Err(e);
223 }
224 dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
225 evals.function += 1;
226 break;
227 }
228 if it == max_iter - 1 {
229 dde_iter_failed = iter_err > self.rtol * T::from_f64(0.1).unwrap();
230 }
231 }
232
233 y_next_est = y_high;
235
236 dde.lags(self.t + self.h, &y_next_est, &mut delays);
238 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
239 self.status = Status::Error(e.clone());
240 return Err(e);
241 }
242 dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
243 evals.function += 1;
244 }
245
246 if dde_iter_failed {
248 let sign = self.h.signum();
249 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
250 if min_delay_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
251 {
252 self.h = min_delay_abs * sign;
253 }
254
255 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
256 self.status = Status::RejectedStep;
257 return Ok(evals);
258 }
259
260 let order = T::from_usize(self.order).unwrap();
262 let error_exponent = T::one() / order;
263 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
264 scale = scale.max(self.min_scale).min(self.max_scale);
265
266 if err_norm <= T::one() {
268 self.t_prev = self.t;
270 self.y_prev = self.y;
271 self.dydt_prev = self.dydt;
272 self.h_prev = self.h;
273
274 if let Status::RejectedStep = self.status {
275 self.stiffness_counter = 0;
277 scale = scale.min(T::one());
278 }
279 self.status = Status::Solving;
280
281 if self.bi.is_some() {
283 for i in 0..(I - S) {
284 let mut y_stage = self.y;
285 for j in 0..self.stages + i {
286 y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
287 }
288 let t_stage = self.t + self.c[self.stages + i] * self.h;
289 dde.lags(t_stage, &y_stage, &mut delays);
290 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
291 self.status = Status::Error(e.clone());
292 return Err(e);
293 }
294 dde.diff(
295 self.t + self.c[self.stages + i] * self.h,
296 &y_stage,
297 &y_delayed,
298 &mut self.k[self.stages + i],
299 );
300 }
301 evals.function += I - S;
302 }
303
304 self.t += self.h;
306 self.y = y_next_est;
307
308 if self.fsal {
310 self.dydt = self.k[S - 1];
311 } else {
312 dde.lags(self.t, &self.y, &mut delays);
313 if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
314 self.status = Status::Error(e.clone());
315 return Err(e);
316 }
317 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
318 evals.function += 1;
319 }
320
321 self.history.push_back((self.t, self.y, self.dydt));
323 if let Some(max_delay) = self.max_delay {
324 let cutoff_time = self.t - max_delay;
325 while let Some((t_front, _, _)) = self.history.get(1) {
326 if *t_front < cutoff_time {
327 self.history.pop_front();
328 } else {
329 break;
330 }
331 }
332 }
333 } else {
334 self.status = Status::RejectedStep;
336 self.stiffness_counter += 1;
337
338 if self.stiffness_counter >= self.max_rejects {
339 self.status = Status::Error(Error::Stiffness {
340 t: self.t,
341 y: self.y,
342 });
343 return Err(Error::Stiffness {
344 t: self.t,
345 y: self.y,
346 });
347 }
348 }
349
350 self.h *= scale;
352 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
353
354 Ok(evals)
355 }
356
357 fn t(&self) -> T {
358 self.t
359 }
360 fn y(&self) -> &Y {
361 &self.y
362 }
363 fn t_prev(&self) -> T {
364 self.t_prev
365 }
366 fn y_prev(&self) -> &Y {
367 &self.y_prev
368 }
369 fn h(&self) -> T {
370 self.h
371 }
372 fn set_h(&mut self, h: T) {
373 self.h = h;
374 }
375 fn status(&self) -> &Status<T, Y, D> {
376 &self.status
377 }
378 fn set_status(&mut self, status: Status<T, Y, D>) {
379 self.status = status;
380 }
381}
382
383impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
384 ExplicitRungeKutta<Delay, Adaptive, T, Y, D, O, S, I>
385{
386 fn lagvals<const L: usize, H>(
387 &mut self,
388 t_stage: T,
389 delays: &[T; L],
390 y_delayed: &mut [Y; L],
391 phi: &H,
392 ) -> Result<(), Error<T, Y>>
393 where
394 H: Fn(T) -> Y,
395 {
396 for idx in 0..L {
397 let t_delayed = t_stage - delays[idx];
398
399 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
401 y_delayed[idx] = phi(t_delayed);
402 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
404 if self.bi.is_some() {
405 let theta = (t_delayed - self.t_prev) / self.h_prev;
406 let dense_coeffs = self.bi.as_ref().unwrap();
407
408 let mut coeffs = [T::zero(); I];
409 for s_idx in 0..I {
410 if s_idx < self.cont.len() && s_idx < dense_coeffs.len() {
411 coeffs[s_idx] = dense_coeffs[s_idx][self.dense_stages - 1];
412 for j in (0..self.dense_stages - 1).rev() {
413 coeffs[s_idx] = coeffs[s_idx] * theta + dense_coeffs[s_idx][j];
414 }
415 coeffs[s_idx] *= theta;
416 }
417 }
418
419 let mut y_interp = self.y_prev;
420 for s_idx in 0..I {
421 if s_idx < self.k.len() && s_idx < self.cont.len() {
422 y_interp += self.k[s_idx] * (coeffs[s_idx] * self.h_prev);
423 }
424 }
425 y_delayed[idx] = y_interp;
426 } else {
427 y_delayed[idx] = cubic_hermite_interpolate(
428 self.t_prev,
429 self.t,
430 &self.y_prev,
431 &self.y,
432 &self.dydt_prev,
433 &self.dydt,
434 t_delayed,
435 );
436 }
437 } else {
439 let mut found = false;
441 let buffer = &self.history;
442 let mut it = buffer.iter();
443 if let Some(mut left) = it.next() {
444 for right in it {
445 let (t_left, y_left, dydt_left) = left;
446 let (t_right, y_right, dydt_right) = right;
447
448 let in_interval = if self.h.signum() > T::zero() {
449 *t_left <= t_delayed && t_delayed <= *t_right
450 } else {
451 *t_right <= t_delayed && t_delayed <= *t_left
452 };
453
454 if in_interval {
455 y_delayed[idx] = cubic_hermite_interpolate(
456 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
457 t_delayed,
458 );
459 found = true;
460 break;
461 }
462 left = right;
463 }
464 }
465 if !found {
466 return Err(Error::InsufficientHistory {
467 t_delayed,
468 t_prev: self.t_prev,
469 t_curr: self.t,
470 });
471 }
472 }
473 }
474 Ok(())
475 }
476}
477
478impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
479 Interpolation<T, Y> for ExplicitRungeKutta<Delay, Adaptive, T, Y, D, O, S, I>
480{
481 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
483 let dir = (self.t - self.t_prev).signum();
484 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
485 return Err(Error::OutOfBounds {
486 t_interp,
487 t_prev: self.t_prev,
488 t_curr: self.t,
489 });
490 }
491
492 if self.bi.is_some() {
494 let theta = (t_interp - self.t_prev) / self.h_prev;
496
497 let dense_coeffs = self.bi.as_ref().unwrap();
499
500 let mut coeffs = [T::zero(); I];
501 for i in 0..self.dense_stages {
503 coeffs[i] = dense_coeffs[i][self.order - 1];
505
506 for j in (0..self.order - 1).rev() {
508 coeffs[i] = coeffs[i] * theta + dense_coeffs[i][j];
509 }
510
511 coeffs[i] *= theta;
513 }
514
515 let mut y_interp = self.y_prev;
517 for i in 0..I {
518 y_interp += self.k[i] * coeffs[i] * self.h_prev;
519 }
520
521 Ok(y_interp)
522 } else {
523 let y_interp = cubic_hermite_interpolate(
525 self.t_prev,
526 self.t,
527 &self.y_prev,
528 &self.y,
529 &self.dydt_prev,
530 &self.dydt,
531 t_interp,
532 );
533
534 Ok(y_interp)
535 }
536 }
537}