differential_equations/methods/erk/adaptive/
delay.rs

1//! Adaptive Runge-Kutta methods for Delay Differential Equations (DDEs)
2
3use std::collections::VecDeque;
4
5use crate::{
6    dde::{DDE, DelayNumericalMethod},
7    error::Error,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    methods::{Adaptive, Delay, ExplicitRungeKutta, h_init::InitialStepSize},
10    stats::Evals,
11    status::Status,
12    traits::{CallBackData, Real, State},
13    utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<
17    const L: usize,
18    T: Real,
19    Y: State<T>,
20    H: Fn(T) -> Y,
21    D: CallBackData,
22    const O: usize,
23    const S: usize,
24    const I: usize,
25> DelayNumericalMethod<L, T, Y, H, D> for ExplicitRungeKutta<Delay, Adaptive, T, Y, D, O, S, I>
26{
27    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
28    where
29        F: DDE<L, T, Y, D>,
30    {
31        let mut evals = Evals::new();
32
33        // DDE requires at least one lag
34        if L <= 0 {
35            return Err(Error::NoLags);
36        }
37
38        // Init solver state
39        self.t0 = t0;
40        self.t = t0;
41        self.y = *y0;
42        self.t_prev = self.t;
43        self.y_prev = self.y;
44        self.status = Status::Initialized;
45        self.steps = 0;
46        self.stiffness_counter = 0;
47        self.history = VecDeque::new();
48
49        // Delay buffers
50        let mut delays = [T::zero(); L];
51        let mut y_delayed = [Y::zeros(); L];
52
53        // Initial delays and history
54        dde.lags(self.t, &self.y, &mut delays);
55        for i in 0..L {
56            let t_delayed = self.t - delays[i];
57            // Ensure delayed time is within history range
58            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
59                return Err(Error::BadInput {
60                    msg: format!(
61                        "Initial delayed time {} is out of history range (t <= {}).",
62                        t_delayed, t0
63                    ),
64                });
65            }
66            y_delayed[i] = phi(t_delayed);
67        }
68
69        // Initial derivative and seed history
70        dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
71        evals.function += 1;
72        self.dydt_prev = self.dydt; // Store initial state in history
73        self.history.push_back((self.t, self.y, self.dydt));
74
75        // Initial step size
76        if self.h0 == T::zero() {
77            // Adaptive step size for DDEs
78            self.h0 = InitialStepSize::<Delay>::compute(
79                dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi,
80                &self.k[0], &mut evals,
81            );
82            evals.function += 2; // h_init performs 2 function evaluations
83        }
84
85        // Validate initial step size
86        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
87            Ok(h0) => self.h = h0,
88            Err(status) => return Err(status),
89        }
90        Ok(evals)
91    }
92
93    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
94    where
95        F: DDE<L, T, Y, D>,
96    {
97        let mut evals = Evals::new();
98
99        // Validate step size
100        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
101            self.status = Status::Error(Error::StepSize {
102                t: self.t,
103                y: self.y,
104            });
105            return Err(Error::StepSize {
106                t: self.t,
107                y: self.y,
108            });
109        }
110
111        // Max steps
112        if self.steps >= self.max_steps {
113            self.status = Status::Error(Error::MaxSteps {
114                t: self.t,
115                y: self.y,
116            });
117            return Err(Error::MaxSteps {
118                t: self.t,
119                y: self.y,
120            });
121        }
122        self.steps += 1;
123
124        // Step buffers
125        let mut delays = [T::zero(); L];
126        let mut y_delayed = [Y::zeros(); L];
127
128        // Seed k[0]
129        self.k[0] = self.dydt;
130
131        // Check if delay iteration is needed
132        let mut min_delay_abs = T::infinity();
133        // Predict y(t+h) to estimate delays at t+h
134        let y_pred_for_lags = self.y + self.k[0] * self.h;
135        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
136        for i in 0..L {
137            min_delay_abs = min_delay_abs.min(delays[i].abs());
138        }
139
140        // Delay iteration count
141        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
142            5
143        } else {
144            1
145        };
146
147        let mut y_next_est = self.y;
148        let mut dydt_next_est = Y::zeros();
149        let mut y_next_est_prev = self.y;
150        let mut dde_iter_failed = false;
151        let mut err_norm: T = T::zero();
152
153        // DDE iteration loop
154        for it in 0..max_iter {
155            if it > 0 {
156                y_next_est_prev = y_next_est;
157            }
158
159            // Compute stages
160            for i in 1..self.stages {
161                let mut y_stage = self.y;
162                for j in 0..i {
163                    y_stage += self.k[j] * (self.a[i][j] * self.h);
164                }
165                // Delayed states for this stage
166                let t_stage = self.t + self.c[i] * self.h;
167                dde.lags(t_stage, &y_stage, &mut delays);
168                if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
169                    self.status = Status::Error(e.clone());
170                    return Err(e);
171                }
172
173                dde.diff(
174                    self.t + self.c[i] * self.h,
175                    &y_stage,
176                    &y_delayed,
177                    &mut self.k[i],
178                );
179            }
180            evals.function += self.stages - 1;
181
182            // High/low order solutions for error
183            let mut y_high = self.y;
184            for i in 0..self.stages {
185                y_high += self.k[i] * (self.b[i] * self.h);
186            }
187            let mut y_low = self.y;
188            let bh = &self.bh.unwrap();
189            for i in 0..self.stages {
190                y_low += self.k[i] * (bh[i] * self.h);
191            }
192            let err_vec: Y = y_high - y_low;
193
194            // Infinity-norm-like error scaled by atol/rtol
195            err_norm = T::zero();
196            for n in 0..self.y.len() {
197                let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
198                err_norm = err_norm.max((err_vec.get(n) / tol).abs());
199            }
200
201            // Iteration convergence (if iterating)
202            if max_iter > 1 && it > 0 {
203                let mut iter_err = T::zero();
204                let n_dim = self.y.len();
205                for d in 0..n_dim {
206                    let scale = self.atol
207                        + self.rtol * y_next_est_prev.get(d).abs().max(y_high.get(d).abs());
208                    if scale > T::zero() {
209                        let diff_val = y_high.get(d) - y_next_est_prev.get(d);
210                        iter_err += (diff_val / scale).powi(2);
211                    }
212                }
213                if n_dim > 0 {
214                    iter_err = (iter_err / T::from_usize(n_dim).unwrap()).sqrt();
215                }
216
217                if iter_err <= self.rtol * T::from_f64(0.1).unwrap() {
218                    y_next_est = y_high;
219                    dde.lags(self.t + self.h, &y_next_est, &mut delays);
220                    if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
221                        self.status = Status::Error(e.clone());
222                        return Err(e);
223                    }
224                    dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
225                    evals.function += 1;
226                    break;
227                }
228                if it == max_iter - 1 {
229                    dde_iter_failed = iter_err > self.rtol * T::from_f64(0.1).unwrap();
230                }
231            }
232
233            // Update candidate
234            y_next_est = y_high;
235
236            // Derivative at t+h for candidate
237            dde.lags(self.t + self.h, &y_next_est, &mut delays);
238            if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
239                self.status = Status::Error(e.clone());
240                return Err(e);
241            }
242            dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
243            evals.function += 1;
244        }
245
246        // Iteration failed: reduce h and retry
247        if dde_iter_failed {
248            let sign = self.h.signum();
249            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
250            if min_delay_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
251            {
252                self.h = min_delay_abs * sign;
253            }
254
255            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
256            self.status = Status::RejectedStep;
257            return Ok(evals);
258        }
259
260        // Step size scale factor
261        let order = T::from_usize(self.order).unwrap();
262        let error_exponent = T::one() / order;
263        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
264        scale = scale.max(self.min_scale).min(self.max_scale);
265
266        // Accept/reject
267        if err_norm <= T::one() {
268            // Accept
269            self.t_prev = self.t;
270            self.y_prev = self.y;
271            self.dydt_prev = self.dydt;
272            self.h_prev = self.h;
273
274            if let Status::RejectedStep = self.status {
275                // Dampen growth after rejection
276                self.stiffness_counter = 0;
277                scale = scale.min(T::one());
278            }
279            self.status = Status::Solving;
280
281            // Dense output stages
282            if self.bi.is_some() {
283                for i in 0..(I - S) {
284                    let mut y_stage = self.y;
285                    for j in 0..self.stages + i {
286                        y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
287                    }
288                    let t_stage = self.t + self.c[self.stages + i] * self.h;
289                    dde.lags(t_stage, &y_stage, &mut delays);
290                    if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
291                        self.status = Status::Error(e.clone());
292                        return Err(e);
293                    }
294                    dde.diff(
295                        self.t + self.c[self.stages + i] * self.h,
296                        &y_stage,
297                        &y_delayed,
298                        &mut self.k[self.stages + i],
299                    );
300                }
301                evals.function += I - S;
302            }
303
304            // Advance state
305            self.t += self.h;
306            self.y = y_next_est;
307
308            // Derivative for next step
309            if self.fsal {
310                self.dydt = self.k[S - 1];
311            } else {
312                dde.lags(self.t, &self.y, &mut delays);
313                if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
314                    self.status = Status::Error(e.clone());
315                    return Err(e);
316                }
317                dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
318                evals.function += 1;
319            }
320
321            // Append to history and prune
322            self.history.push_back((self.t, self.y, self.dydt));
323            if let Some(max_delay) = self.max_delay {
324                let cutoff_time = self.t - max_delay;
325                while let Some((t_front, _, _)) = self.history.get(1) {
326                    if *t_front < cutoff_time {
327                        self.history.pop_front();
328                    } else {
329                        break;
330                    }
331                }
332            }
333        } else {
334            // Reject
335            self.status = Status::RejectedStep;
336            self.stiffness_counter += 1;
337
338            if self.stiffness_counter >= self.max_rejects {
339                self.status = Status::Error(Error::Stiffness {
340                    t: self.t,
341                    y: self.y,
342                });
343                return Err(Error::Stiffness {
344                    t: self.t,
345                    y: self.y,
346                });
347            }
348        }
349
350        // Update step size
351        self.h *= scale;
352        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
353
354        Ok(evals)
355    }
356
357    fn t(&self) -> T {
358        self.t
359    }
360    fn y(&self) -> &Y {
361        &self.y
362    }
363    fn t_prev(&self) -> T {
364        self.t_prev
365    }
366    fn y_prev(&self) -> &Y {
367        &self.y_prev
368    }
369    fn h(&self) -> T {
370        self.h
371    }
372    fn set_h(&mut self, h: T) {
373        self.h = h;
374    }
375    fn status(&self) -> &Status<T, Y, D> {
376        &self.status
377    }
378    fn set_status(&mut self, status: Status<T, Y, D>) {
379        self.status = status;
380    }
381}
382
383impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
384    ExplicitRungeKutta<Delay, Adaptive, T, Y, D, O, S, I>
385{
386    fn lagvals<const L: usize, H>(
387        &mut self,
388        t_stage: T,
389        delays: &[T; L],
390        y_delayed: &mut [Y; L],
391        phi: &H,
392    ) -> Result<(), Error<T, Y>>
393    where
394        H: Fn(T) -> Y,
395    {
396        for idx in 0..L {
397            let t_delayed = t_stage - delays[idx];
398
399            // History domain (t_delayed <= t0)
400            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
401                y_delayed[idx] = phi(t_delayed);
402            // Within last accepted step (dense if available, else Hermite)
403            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
404                if self.bi.is_some() {
405                    let theta = (t_delayed - self.t_prev) / self.h_prev;
406                    let dense_coeffs = self.bi.as_ref().unwrap();
407
408                    let mut coeffs = [T::zero(); I];
409                    for s_idx in 0..I {
410                        if s_idx < self.cont.len() && s_idx < dense_coeffs.len() {
411                            coeffs[s_idx] = dense_coeffs[s_idx][self.dense_stages - 1];
412                            for j in (0..self.dense_stages - 1).rev() {
413                                coeffs[s_idx] = coeffs[s_idx] * theta + dense_coeffs[s_idx][j];
414                            }
415                            coeffs[s_idx] *= theta;
416                        }
417                    }
418
419                    let mut y_interp = self.y_prev;
420                    for s_idx in 0..I {
421                        if s_idx < self.k.len() && s_idx < self.cont.len() {
422                            y_interp += self.k[s_idx] * (coeffs[s_idx] * self.h_prev);
423                        }
424                    }
425                    y_delayed[idx] = y_interp;
426                } else {
427                    y_delayed[idx] = cubic_hermite_interpolate(
428                        self.t_prev,
429                        self.t,
430                        &self.y_prev,
431                        &self.y,
432                        &self.dydt_prev,
433                        &self.dydt,
434                        t_delayed,
435                    );
436                }
437            // Between earlier history points (internal buffer)
438            } else {
439                // Search history for bracketing interval
440                let mut found = false;
441                let buffer = &self.history;
442                let mut it = buffer.iter();
443                if let Some(mut left) = it.next() {
444                    for right in it {
445                        let (t_left, y_left, dydt_left) = left;
446                        let (t_right, y_right, dydt_right) = right;
447
448                        let in_interval = if self.h.signum() > T::zero() {
449                            *t_left <= t_delayed && t_delayed <= *t_right
450                        } else {
451                            *t_right <= t_delayed && t_delayed <= *t_left
452                        };
453
454                        if in_interval {
455                            y_delayed[idx] = cubic_hermite_interpolate(
456                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
457                                t_delayed,
458                            );
459                            found = true;
460                            break;
461                        }
462                        left = right;
463                    }
464                }
465                if !found {
466                    return Err(Error::InsufficientHistory {
467                        t_delayed,
468                        t_prev: self.t_prev,
469                        t_curr: self.t,
470                    });
471                }
472            }
473        }
474        Ok(())
475    }
476}
477
478impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
479    Interpolation<T, Y> for ExplicitRungeKutta<Delay, Adaptive, T, Y, D, O, S, I>
480{
481    /// Interpolates the solution at a given time `t_interp`.
482    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
483        let dir = (self.t - self.t_prev).signum();
484        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
485            return Err(Error::OutOfBounds {
486                t_interp,
487                t_prev: self.t_prev,
488                t_curr: self.t,
489            });
490        }
491
492        // If method has dense output coefficients, use them
493        if self.bi.is_some() {
494            // Calculate the normalized distance within the step [0, 1]
495            let theta = (t_interp - self.t_prev) / self.h_prev;
496
497            // Get the interpolation coefficients
498            let dense_coeffs = self.bi.as_ref().unwrap();
499
500            let mut coeffs = [T::zero(); I];
501            // Compute the interpolation coefficients using Horner's method
502            for i in 0..self.dense_stages {
503                // Start with the highest-order term
504                coeffs[i] = dense_coeffs[i][self.order - 1];
505
506                // Apply Horner's method
507                for j in (0..self.order - 1).rev() {
508                    coeffs[i] = coeffs[i] * theta + dense_coeffs[i][j];
509                }
510
511                // Multiply by s
512                coeffs[i] *= theta;
513            }
514
515            // Compute the interpolated value
516            let mut y_interp = self.y_prev;
517            for i in 0..I {
518                y_interp += self.k[i] * coeffs[i] * self.h_prev;
519            }
520
521            Ok(y_interp)
522        } else {
523            // Otherwise use cubic Hermite interpolation
524            let y_interp = cubic_hermite_interpolate(
525                self.t_prev,
526                self.t,
527                &self.y_prev,
528                &self.y,
529                &self.dydt_prev,
530                &self.dydt,
531                t_interp,
532            );
533
534            Ok(y_interp)
535        }
536    }
537}