Skip to main content

differential_equations/methods/erk/adaptive/
delay.rs

1//! Adaptive Runge-Kutta methods for Delay Differential Equations (DDEs)
2use std::collections::VecDeque;
3
4use crate::{
5    dde::{DDE, DelayNumericalMethod},
6    error::Error,
7    interpolate::{Interpolation, cubic_hermite_interpolate},
8    methods::{Adaptive, Delay, ExplicitRungeKutta, h_init::InitialStepSize},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::{constrain_step_size, validate_step_size_parameters},
13};
14
15impl<
16    const L: usize,
17    T: Real,
18    Y: State<T>,
19    H: Fn(T) -> Y,
20    const O: usize,
21    const S: usize,
22    const I: usize,
23> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
24{
25    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
26    where
27        F: DDE<L, T, Y> + ?Sized,
28    {
29        let mut evals = Evals::new();
30
31        // DDE requires at least one lag
32        if L == 0 {
33            return Err(Error::NoLags);
34        }
35
36        // Init solver state
37        self.t0 = t0;
38        self.t = t0;
39        self.y = y0.clone();
40        self.dydt = y0.zeros_like();
41        self.y_prev = y0.clone();
42        self.dydt_prev = y0.zeros_like();
43        self.k = core::array::from_fn(|_| y0.zeros_like());
44        self.cont = core::array::from_fn(|_| y0.zeros_like());
45        self.t_prev = self.t;
46        self.y_prev = self.y.clone();
47        self.status = Status::Initialized;
48        self.steps = 0;
49        self.stiffness_counter = 0;
50        self.history = VecDeque::new();
51
52        // Delay buffers
53        let mut delays = [T::zero(); L];
54        let mut y_delayed = core::array::from_fn(|_| y0.zeros_like());
55
56        // Initial delays and history
57        dde.lags(self.t, &self.y, &mut delays);
58        for i in 0..L {
59            let t_delayed = self.t - delays[i];
60            // Ensure delayed time is within history range
61            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
62                return Err(Error::BadInput {
63                    msg: format!(
64                        "Initial delayed time {} is out of history range (t <= {}).",
65                        t_delayed, t0
66                    ),
67                });
68            }
69            y_delayed[i] = phi(t_delayed);
70        }
71
72        // Initial derivative and seed history
73        dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
74        evals.function += 1;
75        self.dydt_prev = self.dydt.clone(); // Store initial state in history
76        self.history
77            .push_back((self.t, self.y.clone(), self.dydt.clone()));
78
79        // Initial step size
80        if self.h0 == T::zero() {
81            // Adaptive step size for DDEs
82            self.h0 = InitialStepSize::<Delay>::compute(
83                dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
84                &self.k[0], &mut evals,
85            );
86            evals.function += 2; // h_init performs 2 function evaluations
87        }
88
89        // Validate initial step size
90        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
91            Ok(h0) => self.h = (self.filter)(h0),
92            Err(status) => return Err(status),
93        }
94        Ok(evals)
95    }
96
97    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
98    where
99        F: DDE<L, T, Y> + ?Sized,
100    {
101        let mut evals = Evals::new();
102
103        // Validate step size
104        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
105            self.status = Status::Error(Error::StepSize {
106                t: self.t,
107                y: self.y.clone(),
108            });
109            return Err(Error::StepSize {
110                t: self.t,
111                y: self.y.clone(),
112            });
113        }
114
115        // Max steps
116        if self.steps >= self.max_steps {
117            self.status = Status::Error(Error::MaxSteps {
118                t: self.t,
119                y: self.y.clone(),
120            });
121            return Err(Error::MaxSteps {
122                t: self.t,
123                y: self.y.clone(),
124            });
125        }
126        self.steps += 1;
127
128        // Step buffers
129        let mut delays = [T::zero(); L];
130        let mut y_delayed = core::array::from_fn(|_| self.y.zeros_like());
131
132        // Seed k[0]
133        self.k[0] = self.dydt.clone();
134
135        // Check if delay iteration is needed
136        let mut min_delay_abs = T::infinity();
137        // Predict y(t+h) to estimate delays at t+h
138        let y_pred_for_lags = self.y.plus_scaled(self.h, &self.k[0]);
139        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
140        for i in 0..L {
141            min_delay_abs = min_delay_abs.min(delays[i].abs());
142        }
143
144        // Delay iteration count
145        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
146            5
147        } else {
148            1
149        };
150
151        let mut y_next_est = self.y.clone();
152        let mut dydt_next_est = self.y.zeros_like();
153        let mut y_next_est_prev = self.y.clone();
154        let mut dde_iter_failed = false;
155        let mut err_norm: T = T::zero();
156
157        // DDE iteration loop
158        for it in 0..max_iter {
159            if it > 0 {
160                y_next_est_prev = y_next_est.clone();
161            }
162
163            // Compute stages
164            for i in 1..self.stages {
165                let mut y_stage = self.y.clone();
166                for j in 0..i {
167                    y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
168                }
169                // Delayed states for this stage
170                let t_stage = self.t + self.c[i] * self.h;
171                dde.lags(t_stage, &y_stage, &mut delays);
172                if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
173                    self.status = Status::Error(e.clone());
174                    return Err(e);
175                }
176
177                dde.diff(
178                    self.t + self.c[i] * self.h,
179                    &y_stage,
180                    &y_delayed,
181                    &mut self.k[i],
182                );
183            }
184            evals.function += self.stages - 1;
185
186            // High/low order solutions for error
187            let mut y_high = self.y.clone();
188            for i in 0..self.stages {
189                y_high.add_scaled(self.b[i] * self.h, &self.k[i]);
190            }
191            let mut y_low = self.y.clone();
192            let bh = &self.bh.unwrap();
193            for i in 0..self.stages {
194                y_low.add_scaled(bh[i] * self.h, &self.k[i]);
195            }
196
197            let err = y_high.minus(&y_low);
198            err_norm = self.y.error_norm_inf(&y_high, &err, &self.atol, &self.rtol);
199
200            // Iteration convergence (if iterating)
201            if max_iter > 1 && it > 0 {
202                let n_dim = self.y.len();
203                let iter_diff = y_high.minus(&y_next_est_prev);
204                let mut iter_err =
205                    y_next_est_prev.error_norm(&y_high, &iter_diff, &self.atol, &self.rtol);
206                if n_dim > 0 {
207                    iter_err = (iter_err / T::from_usize(n_dim).unwrap()).sqrt();
208                }
209
210                if iter_err <= self.rtol.average() * T::from_f64(0.1).unwrap() {
211                    y_next_est = y_high.clone();
212                    dde.lags(self.t + self.h, &y_next_est, &mut delays);
213                    if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
214                        self.status = Status::Error(e.clone());
215                        return Err(e);
216                    }
217                    dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
218                    evals.function += 1;
219                    break;
220                }
221                if it == max_iter - 1 {
222                    dde_iter_failed = iter_err > self.rtol.average() * T::from_f64(0.1).unwrap();
223                }
224            }
225
226            // Update candidate
227            y_next_est = y_high.clone();
228
229            // Derivative at t+h for candidate
230            dde.lags(self.t + self.h, &y_next_est, &mut delays);
231            if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
232                self.status = Status::Error(e.clone());
233                return Err(e);
234            }
235            dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
236            evals.function += 1;
237        }
238
239        // Iteration failed: reduce h and retry
240        if dde_iter_failed {
241            let sign = self.h.signum();
242            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
243            if min_delay_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
244            {
245                self.h = min_delay_abs * sign;
246            }
247
248            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
249            self.h = (self.filter)(self.h);
250            self.status = Status::RejectedStep;
251            return Ok(evals);
252        }
253
254        // Step size scale factor
255        let order = T::from_usize(self.order).unwrap();
256        let error_exponent = T::one() / order;
257        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
258        scale = scale.max(self.min_scale).min(self.max_scale);
259
260        // Accept/reject
261        if err_norm <= T::one() {
262            // Accept
263            self.t_prev = self.t;
264            self.y_prev = self.y.clone();
265            self.dydt_prev = self.dydt.clone();
266            self.h_prev = self.h;
267
268            if let Status::RejectedStep = self.status {
269                // Dampen growth after rejection
270                self.stiffness_counter = 0;
271                scale = scale.min(T::one());
272            }
273            self.status = Status::Solving;
274
275            // Dense output stages
276            if self.bi.is_some() {
277                for i in 0..(I - S) {
278                    let mut y_stage = self.y.clone();
279                    for j in 0..self.stages + i {
280                        y_stage.add_scaled(self.a[self.stages + i][j] * self.h, &self.k[j]);
281                    }
282                    let t_stage = self.t + self.c[self.stages + i] * self.h;
283                    dde.lags(t_stage, &y_stage, &mut delays);
284                    if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
285                        self.status = Status::Error(e.clone());
286                        return Err(e);
287                    }
288                    dde.diff(
289                        self.t + self.c[self.stages + i] * self.h,
290                        &y_stage,
291                        &y_delayed,
292                        &mut self.k[self.stages + i],
293                    );
294                }
295                evals.function += I - S;
296            }
297
298            // Advance state
299            self.t += self.h;
300            self.y = y_next_est;
301
302            // Derivative for next step
303            if self.fsal {
304                self.dydt = self.k[S - 1].clone();
305            } else {
306                dde.lags(self.t, &self.y, &mut delays);
307                if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
308                    self.status = Status::Error(e.clone());
309                    return Err(e);
310                }
311                dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
312                evals.function += 1;
313            }
314
315            // Append to history and prune
316            self.history
317                .push_back((self.t, self.y.clone(), self.dydt.clone()));
318            if let Some(max_delay) = self.max_delay {
319                let cutoff_time = self.t - max_delay;
320                while let Some((t_front, _, _)) = self.history.get(1) {
321                    if *t_front < cutoff_time {
322                        self.history.pop_front();
323                    } else {
324                        break;
325                    }
326                }
327            }
328        } else {
329            // Reject
330            self.status = Status::RejectedStep;
331            self.stiffness_counter += 1;
332
333            if self.stiffness_counter >= self.max_rejects {
334                self.status = Status::Error(Error::Stiffness {
335                    t: self.t,
336                    y: self.y.clone(),
337                });
338                return Err(Error::Stiffness {
339                    t: self.t,
340                    y: self.y.clone(),
341                });
342            }
343        }
344
345        // Update step size
346        self.h *= scale;
347        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
348        self.h = (self.filter)(self.h);
349
350        Ok(evals)
351    }
352
353    fn t(&self) -> T {
354        self.t
355    }
356    fn y(&self) -> &Y {
357        &self.y
358    }
359    fn t_prev(&self) -> T {
360        self.t_prev
361    }
362    fn y_prev(&self) -> &Y {
363        &self.y_prev
364    }
365    fn h(&self) -> T {
366        self.h
367    }
368    fn set_h(&mut self, h: T) {
369        self.h = (self.filter)(h);
370    }
371    fn status(&self) -> &Status<T, Y> {
372        &self.status
373    }
374    fn set_status(&mut self, status: Status<T, Y>) {
375        self.status = status;
376    }
377}
378
379impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
380    ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
381{
382    fn lagvals<const L: usize, H>(
383        &mut self,
384        t_stage: T,
385        delays: &[T; L],
386        y_delayed: &mut [Y; L],
387        phi: &H,
388    ) -> Result<(), Error<T, Y>>
389    where
390        H: Fn(T) -> Y,
391    {
392        for idx in 0..L {
393            let t_delayed = t_stage - delays[idx];
394
395            // History domain (t_delayed <= t0)
396            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
397                y_delayed[idx] = phi(t_delayed);
398            // Within last accepted step (dense if available, else Hermite)
399            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
400                if let Some(dense_coeffs) = self.bi.as_ref() {
401                    let theta = (t_delayed - self.t_prev) / self.h_prev;
402
403                    let mut coeffs = [T::zero(); I];
404                    for s_idx in 0..I {
405                        if s_idx < self.cont.len() && s_idx < dense_coeffs.len() {
406                            coeffs[s_idx] = dense_coeffs[s_idx][self.dense_stages - 1];
407                            for j in (0..self.dense_stages - 1).rev() {
408                                coeffs[s_idx] = coeffs[s_idx] * theta + dense_coeffs[s_idx][j];
409                            }
410                            coeffs[s_idx] *= theta;
411                        }
412                    }
413
414                    let mut y_interp = self.y_prev.clone();
415                    for s_idx in 0..I {
416                        if s_idx < self.k.len() && s_idx < self.cont.len() {
417                            y_interp.add_scaled(coeffs[s_idx] * self.h_prev, &self.k[s_idx]);
418                        }
419                    }
420                    y_delayed[idx] = y_interp;
421                } else {
422                    y_delayed[idx] = cubic_hermite_interpolate(
423                        self.t_prev,
424                        self.t,
425                        &self.y_prev,
426                        &self.y,
427                        &self.dydt_prev,
428                        &self.dydt,
429                        t_delayed,
430                    );
431                }
432            // Between earlier history points (internal buffer)
433            } else {
434                // Search history for bracketing interval
435                let mut found = false;
436                let buffer = &self.history;
437                let mut it = buffer.iter();
438                if let Some(mut left) = it.next() {
439                    for right in it {
440                        let (t_left, y_left, dydt_left) = left;
441                        let (t_right, y_right, dydt_right) = right;
442
443                        let in_interval = if self.h.signum() > T::zero() {
444                            *t_left <= t_delayed && t_delayed <= *t_right
445                        } else {
446                            *t_right <= t_delayed && t_delayed <= *t_left
447                        };
448
449                        if in_interval {
450                            y_delayed[idx] = cubic_hermite_interpolate(
451                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
452                                t_delayed,
453                            );
454                            found = true;
455                            break;
456                        }
457                        left = right;
458                    }
459                }
460                if !found {
461                    return Err(Error::InsufficientHistory {
462                        t_delayed,
463                        t_prev: self.t_prev,
464                        t_curr: self.t,
465                    });
466                }
467            }
468        }
469        Ok(())
470    }
471}
472
473impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
474    for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
475{
476    /// Interpolates the solution at a given time `t_interp`.
477    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
478        let dir = (self.t - self.t_prev).signum();
479        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
480            return Err(Error::OutOfBounds {
481                t_interp,
482                t_prev: self.t_prev,
483                t_curr: self.t,
484            });
485        }
486
487        // If method has dense output coefficients, use them
488        if let Some(dense_coeffs) = self.bi.as_ref() {
489            // Calculate the normalized distance within the step [0, 1]
490            let theta = (t_interp - self.t_prev) / self.h_prev;
491
492            let mut coeffs = [T::zero(); I];
493            // Compute the interpolation coefficients using Horner's method
494            for i in 0..self.dense_stages {
495                // Start with the highest-order term
496                coeffs[i] = dense_coeffs[i][self.order - 1];
497
498                // Apply Horner's method
499                for j in (0..self.order - 1).rev() {
500                    coeffs[i] = coeffs[i] * theta + dense_coeffs[i][j];
501                }
502
503                // Multiply by s
504                coeffs[i] *= theta;
505            }
506
507            // Compute the interpolated value
508            let mut y_interp = self.y_prev.clone();
509            for i in 0..I {
510                y_interp.add_scaled(coeffs[i] * self.h_prev, &self.k[i]);
511            }
512
513            Ok(y_interp)
514        } else {
515            // Otherwise use cubic Hermite interpolation
516            let y_interp = cubic_hermite_interpolate(
517                self.t_prev,
518                self.t,
519                &self.y_prev,
520                &self.y,
521                &self.dydt_prev,
522                &self.dydt,
523                t_interp,
524            );
525
526            Ok(y_interp)
527        }
528    }
529}