1use std::collections::VecDeque;
3
4use crate::{
5 dde::{DDE, DelayNumericalMethod},
6 error::Error,
7 interpolate::{Interpolation, cubic_hermite_interpolate},
8 methods::{Adaptive, Delay, ExplicitRungeKutta, h_init::InitialStepSize},
9 stats::Evals,
10 status::Status,
11 traits::{Real, State},
12 utils::{constrain_step_size, validate_step_size_parameters},
13};
14
15impl<
16 const L: usize,
17 T: Real,
18 Y: State<T>,
19 H: Fn(T) -> Y,
20 const O: usize,
21 const S: usize,
22 const I: usize,
23> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
24{
25 fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
26 where
27 F: DDE<L, T, Y> + ?Sized,
28 {
29 let mut evals = Evals::new();
30
31 if L == 0 {
33 return Err(Error::NoLags);
34 }
35
36 self.t0 = t0;
38 self.t = t0;
39 self.y = y0.clone();
40 self.dydt = y0.zeros_like();
41 self.y_prev = y0.clone();
42 self.dydt_prev = y0.zeros_like();
43 self.k = core::array::from_fn(|_| y0.zeros_like());
44 self.cont = core::array::from_fn(|_| y0.zeros_like());
45 self.t_prev = self.t;
46 self.y_prev = self.y.clone();
47 self.status = Status::Initialized;
48 self.steps = 0;
49 self.stiffness_counter = 0;
50 self.history = VecDeque::new();
51
52 let mut delays = [T::zero(); L];
54 let mut y_delayed = core::array::from_fn(|_| y0.zeros_like());
55
56 dde.lags(self.t, &self.y, &mut delays);
58 for i in 0..L {
59 let t_delayed = self.t - delays[i];
60 if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
62 return Err(Error::BadInput {
63 msg: format!(
64 "Initial delayed time {} is out of history range (t <= {}).",
65 t_delayed, t0
66 ),
67 });
68 }
69 y_delayed[i] = phi(t_delayed);
70 }
71
72 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
74 evals.function += 1;
75 self.dydt_prev = self.dydt.clone(); self.history
77 .push_back((self.t, self.y.clone(), self.dydt.clone()));
78
79 if self.h0 == T::zero() {
81 self.h0 = InitialStepSize::<Delay>::compute(
83 dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
84 &self.k[0], &mut evals,
85 );
86 evals.function += 2; }
88
89 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
91 Ok(h0) => self.h = (self.filter)(h0),
92 Err(status) => return Err(status),
93 }
94 Ok(evals)
95 }
96
97 fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
98 where
99 F: DDE<L, T, Y> + ?Sized,
100 {
101 let mut evals = Evals::new();
102
103 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
105 self.status = Status::Error(Error::StepSize {
106 t: self.t,
107 y: self.y.clone(),
108 });
109 return Err(Error::StepSize {
110 t: self.t,
111 y: self.y.clone(),
112 });
113 }
114
115 if self.steps >= self.max_steps {
117 self.status = Status::Error(Error::MaxSteps {
118 t: self.t,
119 y: self.y.clone(),
120 });
121 return Err(Error::MaxSteps {
122 t: self.t,
123 y: self.y.clone(),
124 });
125 }
126 self.steps += 1;
127
128 let mut delays = [T::zero(); L];
130 let mut y_delayed = core::array::from_fn(|_| self.y.zeros_like());
131
132 self.k[0] = self.dydt.clone();
134
135 let mut min_delay_abs = T::infinity();
137 let y_pred_for_lags = self.y.plus_scaled(self.h, &self.k[0]);
139 dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
140 for i in 0..L {
141 min_delay_abs = min_delay_abs.min(delays[i].abs());
142 }
143
144 let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
146 5
147 } else {
148 1
149 };
150
151 let mut y_next_est = self.y.clone();
152 let mut dydt_next_est = self.y.zeros_like();
153 let mut y_next_est_prev = self.y.clone();
154 let mut dde_iter_failed = false;
155 let mut err_norm: T = T::zero();
156
157 for it in 0..max_iter {
159 if it > 0 {
160 y_next_est_prev = y_next_est.clone();
161 }
162
163 for i in 1..self.stages {
165 let mut y_stage = self.y.clone();
166 for j in 0..i {
167 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
168 }
169 let t_stage = self.t + self.c[i] * self.h;
171 dde.lags(t_stage, &y_stage, &mut delays);
172 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
173 self.status = Status::Error(e.clone());
174 return Err(e);
175 }
176
177 dde.diff(
178 self.t + self.c[i] * self.h,
179 &y_stage,
180 &y_delayed,
181 &mut self.k[i],
182 );
183 }
184 evals.function += self.stages - 1;
185
186 let mut y_high = self.y.clone();
188 for i in 0..self.stages {
189 y_high.add_scaled(self.b[i] * self.h, &self.k[i]);
190 }
191 let mut y_low = self.y.clone();
192 let bh = &self.bh.unwrap();
193 for i in 0..self.stages {
194 y_low.add_scaled(bh[i] * self.h, &self.k[i]);
195 }
196
197 let err = y_high.minus(&y_low);
198 err_norm = self.y.error_norm_inf(&y_high, &err, &self.atol, &self.rtol);
199
200 if max_iter > 1 && it > 0 {
202 let n_dim = self.y.len();
203 let iter_diff = y_high.minus(&y_next_est_prev);
204 let mut iter_err =
205 y_next_est_prev.error_norm(&y_high, &iter_diff, &self.atol, &self.rtol);
206 if n_dim > 0 {
207 iter_err = (iter_err / T::from_usize(n_dim).unwrap()).sqrt();
208 }
209
210 if iter_err <= self.rtol.average() * T::from_f64(0.1).unwrap() {
211 y_next_est = y_high.clone();
212 dde.lags(self.t + self.h, &y_next_est, &mut delays);
213 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
214 self.status = Status::Error(e.clone());
215 return Err(e);
216 }
217 dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
218 evals.function += 1;
219 break;
220 }
221 if it == max_iter - 1 {
222 dde_iter_failed = iter_err > self.rtol.average() * T::from_f64(0.1).unwrap();
223 }
224 }
225
226 y_next_est = y_high.clone();
228
229 dde.lags(self.t + self.h, &y_next_est, &mut delays);
231 if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
232 self.status = Status::Error(e.clone());
233 return Err(e);
234 }
235 dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
236 evals.function += 1;
237 }
238
239 if dde_iter_failed {
241 let sign = self.h.signum();
242 self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
243 if min_delay_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
244 {
245 self.h = min_delay_abs * sign;
246 }
247
248 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
249 self.h = (self.filter)(self.h);
250 self.status = Status::RejectedStep;
251 return Ok(evals);
252 }
253
254 let order = T::from_usize(self.order).unwrap();
256 let error_exponent = T::one() / order;
257 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
258 scale = scale.max(self.min_scale).min(self.max_scale);
259
260 if err_norm <= T::one() {
262 self.t_prev = self.t;
264 self.y_prev = self.y.clone();
265 self.dydt_prev = self.dydt.clone();
266 self.h_prev = self.h;
267
268 if let Status::RejectedStep = self.status {
269 self.stiffness_counter = 0;
271 scale = scale.min(T::one());
272 }
273 self.status = Status::Solving;
274
275 if self.bi.is_some() {
277 for i in 0..(I - S) {
278 let mut y_stage = self.y.clone();
279 for j in 0..self.stages + i {
280 y_stage.add_scaled(self.a[self.stages + i][j] * self.h, &self.k[j]);
281 }
282 let t_stage = self.t + self.c[self.stages + i] * self.h;
283 dde.lags(t_stage, &y_stage, &mut delays);
284 if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
285 self.status = Status::Error(e.clone());
286 return Err(e);
287 }
288 dde.diff(
289 self.t + self.c[self.stages + i] * self.h,
290 &y_stage,
291 &y_delayed,
292 &mut self.k[self.stages + i],
293 );
294 }
295 evals.function += I - S;
296 }
297
298 self.t += self.h;
300 self.y = y_next_est;
301
302 if self.fsal {
304 self.dydt = self.k[S - 1].clone();
305 } else {
306 dde.lags(self.t, &self.y, &mut delays);
307 if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
308 self.status = Status::Error(e.clone());
309 return Err(e);
310 }
311 dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
312 evals.function += 1;
313 }
314
315 self.history
317 .push_back((self.t, self.y.clone(), self.dydt.clone()));
318 if let Some(max_delay) = self.max_delay {
319 let cutoff_time = self.t - max_delay;
320 while let Some((t_front, _, _)) = self.history.get(1) {
321 if *t_front < cutoff_time {
322 self.history.pop_front();
323 } else {
324 break;
325 }
326 }
327 }
328 } else {
329 self.status = Status::RejectedStep;
331 self.stiffness_counter += 1;
332
333 if self.stiffness_counter >= self.max_rejects {
334 self.status = Status::Error(Error::Stiffness {
335 t: self.t,
336 y: self.y.clone(),
337 });
338 return Err(Error::Stiffness {
339 t: self.t,
340 y: self.y.clone(),
341 });
342 }
343 }
344
345 self.h *= scale;
347 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
348 self.h = (self.filter)(self.h);
349
350 Ok(evals)
351 }
352
353 fn t(&self) -> T {
354 self.t
355 }
356 fn y(&self) -> &Y {
357 &self.y
358 }
359 fn t_prev(&self) -> T {
360 self.t_prev
361 }
362 fn y_prev(&self) -> &Y {
363 &self.y_prev
364 }
365 fn h(&self) -> T {
366 self.h
367 }
368 fn set_h(&mut self, h: T) {
369 self.h = (self.filter)(h);
370 }
371 fn status(&self) -> &Status<T, Y> {
372 &self.status
373 }
374 fn set_status(&mut self, status: Status<T, Y>) {
375 self.status = status;
376 }
377}
378
379impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
380 ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
381{
382 fn lagvals<const L: usize, H>(
383 &mut self,
384 t_stage: T,
385 delays: &[T; L],
386 y_delayed: &mut [Y; L],
387 phi: &H,
388 ) -> Result<(), Error<T, Y>>
389 where
390 H: Fn(T) -> Y,
391 {
392 for idx in 0..L {
393 let t_delayed = t_stage - delays[idx];
394
395 if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
397 y_delayed[idx] = phi(t_delayed);
398 } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
400 if let Some(dense_coeffs) = self.bi.as_ref() {
401 let theta = (t_delayed - self.t_prev) / self.h_prev;
402
403 let mut coeffs = [T::zero(); I];
404 for s_idx in 0..I {
405 if s_idx < self.cont.len() && s_idx < dense_coeffs.len() {
406 coeffs[s_idx] = dense_coeffs[s_idx][self.dense_stages - 1];
407 for j in (0..self.dense_stages - 1).rev() {
408 coeffs[s_idx] = coeffs[s_idx] * theta + dense_coeffs[s_idx][j];
409 }
410 coeffs[s_idx] *= theta;
411 }
412 }
413
414 let mut y_interp = self.y_prev.clone();
415 for s_idx in 0..I {
416 if s_idx < self.k.len() && s_idx < self.cont.len() {
417 y_interp.add_scaled(coeffs[s_idx] * self.h_prev, &self.k[s_idx]);
418 }
419 }
420 y_delayed[idx] = y_interp;
421 } else {
422 y_delayed[idx] = cubic_hermite_interpolate(
423 self.t_prev,
424 self.t,
425 &self.y_prev,
426 &self.y,
427 &self.dydt_prev,
428 &self.dydt,
429 t_delayed,
430 );
431 }
432 } else {
434 let mut found = false;
436 let buffer = &self.history;
437 let mut it = buffer.iter();
438 if let Some(mut left) = it.next() {
439 for right in it {
440 let (t_left, y_left, dydt_left) = left;
441 let (t_right, y_right, dydt_right) = right;
442
443 let in_interval = if self.h.signum() > T::zero() {
444 *t_left <= t_delayed && t_delayed <= *t_right
445 } else {
446 *t_right <= t_delayed && t_delayed <= *t_left
447 };
448
449 if in_interval {
450 y_delayed[idx] = cubic_hermite_interpolate(
451 *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
452 t_delayed,
453 );
454 found = true;
455 break;
456 }
457 left = right;
458 }
459 }
460 if !found {
461 return Err(Error::InsufficientHistory {
462 t_delayed,
463 t_prev: self.t_prev,
464 t_curr: self.t,
465 });
466 }
467 }
468 }
469 Ok(())
470 }
471}
472
473impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
474 for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
475{
476 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
478 let dir = (self.t - self.t_prev).signum();
479 if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
480 return Err(Error::OutOfBounds {
481 t_interp,
482 t_prev: self.t_prev,
483 t_curr: self.t,
484 });
485 }
486
487 if let Some(dense_coeffs) = self.bi.as_ref() {
489 let theta = (t_interp - self.t_prev) / self.h_prev;
491
492 let mut coeffs = [T::zero(); I];
493 for i in 0..self.dense_stages {
495 coeffs[i] = dense_coeffs[i][self.order - 1];
497
498 for j in (0..self.order - 1).rev() {
500 coeffs[i] = coeffs[i] * theta + dense_coeffs[i][j];
501 }
502
503 coeffs[i] *= theta;
505 }
506
507 let mut y_interp = self.y_prev.clone();
509 for i in 0..I {
510 y_interp.add_scaled(coeffs[i] * self.h_prev, &self.k[i]);
511 }
512
513 Ok(y_interp)
514 } else {
515 let y_interp = cubic_hermite_interpolate(
517 self.t_prev,
518 self.t,
519 &self.y_prev,
520 &self.y,
521 &self.dydt_prev,
522 &self.dydt,
523 t_interp,
524 );
525
526 Ok(y_interp)
527 }
528 }
529}