1use std::collections::VecDeque;
4
5use crate::{
6 dde::{DDE, DelayNumericalMethod},
7 error::Error,
8 interpolate::{Interpolation, cubic_hermite_interpolate},
9 methods::{Adaptive, Delay, ExplicitRungeKutta, h_init::InitialStepSize},
10 stats::Evals,
11 status::Status,
12 traits::{Real, State},
13 utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<
17 const L: usize,
18 T: Real,
19 Y: State<T>,
20 H: Fn(T) -> Y,
21 const O: usize,
22 const S: usize,
23 const I: usize,
24> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
25{
26 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
27 where
28 F: DDE<L, T, Y>,
29 {
30 let mut evals = Evals::new();
31
32 if L <= 0 {
34 return Err(Error::NoLags);
35 }
36
37 self.t0 = t0;
39 self.t = t0;
40 self.y = *y0;
41 self.t_prev = self.t;
42 self.y_prev = self.y;
43 self.status = Status::Initialized;
44 self.steps = 0;
45 self.stiffness_counter = 0;
46 self.history = VecDeque::new();
47
48 let mut delays = [T::zero(); L];
50 let mut y_delayed = [Y::zeros(); L];
51
52 dde.lags(self.t, &self.y, &mut delays);
54 for i in 0..L {
55 let t_delayed = self.t - delays[i];
56 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
58 return Err(Error::BadInput {
59 msg: format!(
60 "Initial delayed time {} is out of history range (t <= {}).",
61 t_delayed, t0
62 ),
63 });
64 }
65 y_delayed[i] = phi(t_delayed);
66 }
67
68 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
70 evals.function += 1;
71 self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
73
74 if self.h0 == T::zero() {
76 self.h0 = InitialStepSize::<Delay>::compute(
78 dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
79 &self.k[0], &mut evals,
80 );
81 evals.function += 2; }
83
84 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
86 Ok(h0) => self.h = h0,
87 Err(status) => return Err(status),
88 }
89 Ok(evals)
90 }
91
92 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
93 where
94 F: DDE<L, T, Y>,
95 {
96 let mut evals = Evals::new();
97
98 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
100 self.status = Status::Error(Error::StepSize {
101 t: self.t,
102 y: self.y,
103 });
104 return Err(Error::StepSize {
105 t: self.t,
106 y: self.y,
107 });
108 }
109
110 if self.steps >= self.max_steps {
112 self.status = Status::Error(Error::MaxSteps {
113 t: self.t,
114 y: self.y,
115 });
116 return Err(Error::MaxSteps {
117 t: self.t,
118 y: self.y,
119 });
120 }
121 self.steps += 1;
122
123 let mut delays = [T::zero(); L];
125 let mut y_delayed = [Y::zeros(); L];
126
127 self.k[0] = self.dydt;
129
130 let mut min_delay_abs = T::infinity();
132 let y_pred_for_lags = self.y + self.k[0] * self.h;
134 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
135 for i in 0..L {
136 min_delay_abs = min_delay_abs.min(delays[i].abs());
137 }
138
139 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
141 5
142 } else {
143 1
144 };
145
146 let mut y_next_est = self.y;
147 let mut dydt_next_est = Y::zeros();
148 let mut y_next_est_prev = self.y;
149 let mut dde_iter_failed = false;
150 let mut err_norm: T = T::zero();
151
152 for it in 0..max_iter {
154 if it > 0 {
155 y_next_est_prev = y_next_est;
156 }
157
158 for i in 1..self.stages {
160 let mut y_stage = self.y;
161 for j in 0..i {
162 y_stage += self.k[j] * (self.a[i][j] * self.h);
163 }
164 let t_stage = self.t + self.c[i] * self.h;
166 dde.lags(t_stage, &y_stage, &mut delays);
167 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
168 self.status = Status::Error(e.clone());
169 return Err(e);
170 }
171
172 dde.diff(
173 self.t + self.c[i] * self.h,
174 &y_stage,
175 &y_delayed,
176 &mut self.k[i],
177 );
178 }
179 evals.function += self.stages - 1;
180
181 let mut y_high = self.y;
183 for i in 0..self.stages {
184 y_high += self.k[i] * (self.b[i] * self.h);
185 }
186 let mut y_low = self.y;
187 let bh = &self.bh.unwrap();
188 for i in 0..self.stages {
189 y_low += self.k[i] * (bh[i] * self.h);
190 }
191 let err_vec: Y = y_high - y_low;
192
193 err_norm = T::zero();
195 for n in 0..self.y.len() {
196 let tol =
197 self.atol[n] + self.rtol[n] * self.y.get(n).abs().max(y_high.get(n).abs());
198 err_norm = err_norm.max((err_vec.get(n) / tol).abs());
199 }
200
201 if max_iter > 1 && it > 0 {
203 let mut iter_err = T::zero();
204 let n_dim = self.y.len();
205 for d in 0..n_dim {
206 let scale = self.atol[d]
207 + self.rtol[d] * y_next_est_prev.get(d).abs().max(y_high.get(d).abs());
208 if scale > T::zero() {
209 let diff_val = y_high.get(d) - y_next_est_prev.get(d);
210 iter_err += (diff_val / scale).powi(2);
211 }
212 }
213 if n_dim > 0 {
214 iter_err = (iter_err / T::from_usize(n_dim).unwrap()).sqrt();
215 }
216
217 if iter_err <= self.rtol.average() * T::from_f64(0.1).unwrap() {
218 y_next_est = y_high;
219 dde.lags(self.t + self.h, &y_next_est, &mut delays);
220 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
221 self.status = Status::Error(e.clone());
222 return Err(e);
223 }
224 dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
225 evals.function += 1;
226 break;
227 }
228 if it == max_iter - 1 {
229 dde_iter_failed = iter_err > self.rtol.average() * T::from_f64(0.1).unwrap();
230 }
231 }
232
233 y_next_est = y_high;
235
236 dde.lags(self.t + self.h, &y_next_est, &mut delays);
238 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
239 self.status = Status::Error(e.clone());
240 return Err(e);
241 }
242 dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
243 evals.function += 1;
244 }
245
246 if dde_iter_failed {
248 let sign = self.h.signum();
249 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
250 if min_delay_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
251 {
252 self.h = min_delay_abs * sign;
253 }
254
255 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
256 self.status = Status::RejectedStep;
257 return Ok(evals);
258 }
259
260 let order = T::from_usize(self.order).unwrap();
262 let error_exponent = T::one() / order;
263 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
264 scale = scale.max(self.min_scale).min(self.max_scale);
265
266 if err_norm <= T::one() {
268 self.t_prev = self.t;
270 self.y_prev = self.y;
271 self.dydt_prev = self.dydt;
272 self.h_prev = self.h;
273
274 if let Status::RejectedStep = self.status {
275 self.stiffness_counter = 0;
277 scale = scale.min(T::one());
278 }
279 self.status = Status::Solving;
280
281 if self.bi.is_some() {
283 for i in 0..(I - S) {
284 let mut y_stage = self.y;
285 for j in 0..self.stages + i {
286 y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
287 }
288 let t_stage = self.t + self.c[self.stages + i] * self.h;
289 dde.lags(t_stage, &y_stage, &mut delays);
290 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
291 self.status = Status::Error(e.clone());
292 return Err(e);
293 }
294 dde.diff(
295 self.t + self.c[self.stages + i] * self.h,
296 &y_stage,
297 &y_delayed,
298 &mut self.k[self.stages + i],
299 );
300 }
301 evals.function += I - S;
302 }
303
304 self.t += self.h;
306 self.y = y_next_est;
307
308 if self.fsal {
310 self.dydt = self.k[S - 1];
311 } else {
312 dde.lags(self.t, &self.y, &mut delays);
313 if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
314 self.status = Status::Error(e.clone());
315 return Err(e);
316 }
317 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
318 evals.function += 1;
319 }
320
321 self.history.push_back((self.t, self.y, self.dydt));
323 if let Some(max_delay) = self.max_delay {
324 let cutoff_time = self.t - max_delay;
325 while let Some((t_front, _, _)) = self.history.get(1) {
326 if *t_front < cutoff_time {
327 self.history.pop_front();
328 } else {
329 break;
330 }
331 }
332 }
333 } else {
334 self.status = Status::RejectedStep;
336 self.stiffness_counter += 1;
337
338 if self.stiffness_counter >= self.max_rejects {
339 self.status = Status::Error(Error::Stiffness {
340 t: self.t,
341 y: self.y,
342 });
343 return Err(Error::Stiffness {
344 t: self.t,
345 y: self.y,
346 });
347 }
348 }
349
350 self.h *= scale;
352 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
353
354 Ok(evals)
355 }
356
357 fn t(&self) -> T {
358 self.t
359 }
360 fn y(&self) -> &Y {
361 &self.y
362 }
363 fn t_prev(&self) -> T {
364 self.t_prev
365 }
366 fn y_prev(&self) -> &Y {
367 &self.y_prev
368 }
369 fn h(&self) -> T {
370 self.h
371 }
372 fn set_h(&mut self, h: T) {
373 self.h = h;
374 }
375 fn status(&self) -> &Status<T, Y> {
376 &self.status
377 }
378 fn set_status(&mut self, status: Status<T, Y>) {
379 self.status = status;
380 }
381}
382
383impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
384 ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
385{
386 fn lagvals<const L: usize, H>(
387 &mut self,
388 t_stage: T,
389 delays: &[T; L],
390 y_delayed: &mut [Y; L],
391 phi: &H,
392 ) -> Result<(), Error<T, Y>>
393 where
394 H: Fn(T) -> Y,
395 {
396 for idx in 0..L {
397 let t_delayed = t_stage - delays[idx];
398
399 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
401 y_delayed[idx] = phi(t_delayed);
402 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
404 if self.bi.is_some() {
405 let theta = (t_delayed - self.t_prev) / self.h_prev;
406 let dense_coeffs = self.bi.as_ref().unwrap();
407
408 let mut coeffs = [T::zero(); I];
409 for s_idx in 0..I {
410 if s_idx < self.cont.len() && s_idx < dense_coeffs.len() {
411 coeffs[s_idx] = dense_coeffs[s_idx][self.dense_stages - 1];
412 for j in (0..self.dense_stages - 1).rev() {
413 coeffs[s_idx] = coeffs[s_idx] * theta + dense_coeffs[s_idx][j];
414 }
415 coeffs[s_idx] *= theta;
416 }
417 }
418
419 let mut y_interp = self.y_prev;
420 for s_idx in 0..I {
421 if s_idx < self.k.len() && s_idx < self.cont.len() {
422 y_interp += self.k[s_idx] * (coeffs[s_idx] * self.h_prev);
423 }
424 }
425 y_delayed[idx] = y_interp;
426 } else {
427 y_delayed[idx] = cubic_hermite_interpolate(
428 self.t_prev,
429 self.t,
430 &self.y_prev,
431 &self.y,
432 &self.dydt_prev,
433 &self.dydt,
434 t_delayed,
435 );
436 }
437 } else {
439 let mut found = false;
441 let buffer = &self.history;
442 let mut it = buffer.iter();
443 if let Some(mut left) = it.next() {
444 for right in it {
445 let (t_left, y_left, dydt_left) = left;
446 let (t_right, y_right, dydt_right) = right;
447
448 let in_interval = if self.h.signum() > T::zero() {
449 *t_left <= t_delayed && t_delayed <= *t_right
450 } else {
451 *t_right <= t_delayed && t_delayed <= *t_left
452 };
453
454 if in_interval {
455 y_delayed[idx] = cubic_hermite_interpolate(
456 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
457 t_delayed,
458 );
459 found = true;
460 break;
461 }
462 left = right;
463 }
464 }
465 if !found {
466 return Err(Error::InsufficientHistory {
467 t_delayed,
468 t_prev: self.t_prev,
469 t_curr: self.t,
470 });
471 }
472 }
473 }
474 Ok(())
475 }
476}
477
478impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
479 for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
480{
481 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
483 let dir = (self.t - self.t_prev).signum();
484 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
485 return Err(Error::OutOfBounds {
486 t_interp,
487 t_prev: self.t_prev,
488 t_curr: self.t,
489 });
490 }
491
492 if self.bi.is_some() {
494 let theta = (t_interp - self.t_prev) / self.h_prev;
496
497 let dense_coeffs = self.bi.as_ref().unwrap();
499
500 let mut coeffs = [T::zero(); I];
501 for i in 0..self.dense_stages {
503 coeffs[i] = dense_coeffs[i][self.order - 1];
505
506 for j in (0..self.order - 1).rev() {
508 coeffs[i] = coeffs[i] * theta + dense_coeffs[i][j];
509 }
510
511 coeffs[i] *= theta;
513 }
514
515 let mut y_interp = self.y_prev;
517 for i in 0..I {
518 y_interp += self.k[i] * coeffs[i] * self.h_prev;
519 }
520
521 Ok(y_interp)
522 } else {
523 let y_interp = cubic_hermite_interpolate(
525 self.t_prev,
526 self.t,
527 &self.y_prev,
528 &self.y,
529 &self.dydt_prev,
530 &self.dydt,
531 t_interp,
532 );
533
534 Ok(y_interp)
535 }
536 }
537}