// differential_equations/methods/erk/adaptive/delay.rs

//! Adaptive Runge-Kutta methods for Delay Differential Equations (DDEs)

3use std::collections::VecDeque;
4
5use crate::{
6    dde::{DDE, DelayNumericalMethod},
7    error::Error,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    methods::{Adaptive, Delay, ExplicitRungeKutta, h_init::InitialStepSize},
10    stats::Evals,
11    status::Status,
12    traits::{Real, State},
13    utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<
17    const L: usize,
18    T: Real,
19    Y: State<T>,
20    H: Fn(T) -> Y,
21    const O: usize,
22    const S: usize,
23    const I: usize,
24> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
25{
26    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
27    where
28        F: DDE<L, T, Y>,
29    {
30        let mut evals = Evals::new();
31
32        // DDE requires at least one lag
33        if L <= 0 {
34            return Err(Error::NoLags);
35        }
36
37        // Init solver state
38        self.t0 = t0;
39        self.t = t0;
40        self.y = *y0;
41        self.t_prev = self.t;
42        self.y_prev = self.y;
43        self.status = Status::Initialized;
44        self.steps = 0;
45        self.stiffness_counter = 0;
46        self.history = VecDeque::new();
47
48        // Delay buffers
49        let mut delays = [T::zero(); L];
50        let mut y_delayed = [Y::zeros(); L];
51
52        // Initial delays and history
53        dde.lags(self.t, &self.y, &mut delays);
54        for i in 0..L {
55            let t_delayed = self.t - delays[i];
56            // Ensure delayed time is within history range
57            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
58                return Err(Error::BadInput {
59                    msg: format!(
60                        "Initial delayed time {} is out of history range (t <= {}).",
61                        t_delayed, t0
62                    ),
63                });
64            }
65            y_delayed[i] = phi(t_delayed);
66        }
67
68        // Initial derivative and seed history
69        dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
70        evals.function += 1;
71        self.dydt_prev = self.dydt; // Store initial state in history
72        self.history.push_back((self.t, self.y, self.dydt));
73
74        // Initial step size
75        if self.h0 == T::zero() {
76            // Adaptive step size for DDEs
77            self.h0 = InitialStepSize::<Delay>::compute(
78                dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
79                &self.k[0], &mut evals,
80            );
81            evals.function += 2; // h_init performs 2 function evaluations
82        }
83
84        // Validate initial step size
85        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
86            Ok(h0) => self.h = h0,
87            Err(status) => return Err(status),
88        }
89        Ok(evals)
90    }
91
92    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
93    where
94        F: DDE<L, T, Y>,
95    {
96        let mut evals = Evals::new();
97
98        // Validate step size
99        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
100            self.status = Status::Error(Error::StepSize {
101                t: self.t,
102                y: self.y,
103            });
104            return Err(Error::StepSize {
105                t: self.t,
106                y: self.y,
107            });
108        }
109
110        // Max steps
111        if self.steps >= self.max_steps {
112            self.status = Status::Error(Error::MaxSteps {
113                t: self.t,
114                y: self.y,
115            });
116            return Err(Error::MaxSteps {
117                t: self.t,
118                y: self.y,
119            });
120        }
121        self.steps += 1;
122
123        // Step buffers
124        let mut delays = [T::zero(); L];
125        let mut y_delayed = [Y::zeros(); L];
126
127        // Seed k[0]
128        self.k[0] = self.dydt;
129
130        // Check if delay iteration is needed
131        let mut min_delay_abs = T::infinity();
132        // Predict y(t+h) to estimate delays at t+h
133        let y_pred_for_lags = self.y + self.k[0] * self.h;
134        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
135        for i in 0..L {
136            min_delay_abs = min_delay_abs.min(delays[i].abs());
137        }
138
139        // Delay iteration count
140        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
141            5
142        } else {
143            1
144        };
145
146        let mut y_next_est = self.y;
147        let mut dydt_next_est = Y::zeros();
148        let mut y_next_est_prev = self.y;
149        let mut dde_iter_failed = false;
150        let mut err_norm: T = T::zero();
151
152        // DDE iteration loop
153        for it in 0..max_iter {
154            if it > 0 {
155                y_next_est_prev = y_next_est;
156            }
157
158            // Compute stages
159            for i in 1..self.stages {
160                let mut y_stage = self.y;
161                for j in 0..i {
162                    y_stage += self.k[j] * (self.a[i][j] * self.h);
163                }
164                // Delayed states for this stage
165                let t_stage = self.t + self.c[i] * self.h;
166                dde.lags(t_stage, &y_stage, &mut delays);
167                if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
168                    self.status = Status::Error(e.clone());
169                    return Err(e);
170                }
171
172                dde.diff(
173                    self.t + self.c[i] * self.h,
174                    &y_stage,
175                    &y_delayed,
176                    &mut self.k[i],
177                );
178            }
179            evals.function += self.stages - 1;
180
181            // High/low order solutions for error
182            let mut y_high = self.y;
183            for i in 0..self.stages {
184                y_high += self.k[i] * (self.b[i] * self.h);
185            }
186            let mut y_low = self.y;
187            let bh = &self.bh.unwrap();
188            for i in 0..self.stages {
189                y_low += self.k[i] * (bh[i] * self.h);
190            }
191            let err_vec: Y = y_high - y_low;
192
193            // Infinity-norm-like error scaled by atol/rtol
194            err_norm = T::zero();
195            for n in 0..self.y.len() {
196                let tol =
197                    self.atol[n] + self.rtol[n] * self.y.get(n).abs().max(y_high.get(n).abs());
198                err_norm = err_norm.max((err_vec.get(n) / tol).abs());
199            }
200
201            // Iteration convergence (if iterating)
202            if max_iter > 1 && it > 0 {
203                let mut iter_err = T::zero();
204                let n_dim = self.y.len();
205                for d in 0..n_dim {
206                    let scale = self.atol[d]
207                        + self.rtol[d] * y_next_est_prev.get(d).abs().max(y_high.get(d).abs());
208                    if scale > T::zero() {
209                        let diff_val = y_high.get(d) - y_next_est_prev.get(d);
210                        iter_err += (diff_val / scale).powi(2);
211                    }
212                }
213                if n_dim > 0 {
214                    iter_err = (iter_err / T::from_usize(n_dim).unwrap()).sqrt();
215                }
216
217                if iter_err <= self.rtol.average() * T::from_f64(0.1).unwrap() {
218                    y_next_est = y_high;
219                    dde.lags(self.t + self.h, &y_next_est, &mut delays);
220                    if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
221                        self.status = Status::Error(e.clone());
222                        return Err(e);
223                    }
224                    dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
225                    evals.function += 1;
226                    break;
227                }
228                if it == max_iter - 1 {
229                    dde_iter_failed = iter_err > self.rtol.average() * T::from_f64(0.1).unwrap();
230                }
231            }
232
233            // Update candidate
234            y_next_est = y_high;
235
236            // Derivative at t+h for candidate
237            dde.lags(self.t + self.h, &y_next_est, &mut delays);
238            if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
239                self.status = Status::Error(e.clone());
240                return Err(e);
241            }
242            dde.diff(self.t + self.h, &y_next_est, &y_delayed, &mut dydt_next_est);
243            evals.function += 1;
244        }
245
246        // Iteration failed: reduce h and retry
247        if dde_iter_failed {
248            let sign = self.h.signum();
249            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
250            if min_delay_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
251            {
252                self.h = min_delay_abs * sign;
253            }
254
255            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
256            self.status = Status::RejectedStep;
257            return Ok(evals);
258        }
259
260        // Step size scale factor
261        let order = T::from_usize(self.order).unwrap();
262        let error_exponent = T::one() / order;
263        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
264        scale = scale.max(self.min_scale).min(self.max_scale);
265
266        // Accept/reject
267        if err_norm <= T::one() {
268            // Accept
269            self.t_prev = self.t;
270            self.y_prev = self.y;
271            self.dydt_prev = self.dydt;
272            self.h_prev = self.h;
273
274            if let Status::RejectedStep = self.status {
275                // Dampen growth after rejection
276                self.stiffness_counter = 0;
277                scale = scale.min(T::one());
278            }
279            self.status = Status::Solving;
280
281            // Dense output stages
282            if self.bi.is_some() {
283                for i in 0..(I - S) {
284                    let mut y_stage = self.y;
285                    for j in 0..self.stages + i {
286                        y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
287                    }
288                    let t_stage = self.t + self.c[self.stages + i] * self.h;
289                    dde.lags(t_stage, &y_stage, &mut delays);
290                    if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
291                        self.status = Status::Error(e.clone());
292                        return Err(e);
293                    }
294                    dde.diff(
295                        self.t + self.c[self.stages + i] * self.h,
296                        &y_stage,
297                        &y_delayed,
298                        &mut self.k[self.stages + i],
299                    );
300                }
301                evals.function += I - S;
302            }
303
304            // Advance state
305            self.t += self.h;
306            self.y = y_next_est;
307
308            // Derivative for next step
309            if self.fsal {
310                self.dydt = self.k[S - 1];
311            } else {
312                dde.lags(self.t, &self.y, &mut delays);
313                if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
314                    self.status = Status::Error(e.clone());
315                    return Err(e);
316                }
317                dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
318                evals.function += 1;
319            }
320
321            // Append to history and prune
322            self.history.push_back((self.t, self.y, self.dydt));
323            if let Some(max_delay) = self.max_delay {
324                let cutoff_time = self.t - max_delay;
325                while let Some((t_front, _, _)) = self.history.get(1) {
326                    if *t_front < cutoff_time {
327                        self.history.pop_front();
328                    } else {
329                        break;
330                    }
331                }
332            }
333        } else {
334            // Reject
335            self.status = Status::RejectedStep;
336            self.stiffness_counter += 1;
337
338            if self.stiffness_counter >= self.max_rejects {
339                self.status = Status::Error(Error::Stiffness {
340                    t: self.t,
341                    y: self.y,
342                });
343                return Err(Error::Stiffness {
344                    t: self.t,
345                    y: self.y,
346                });
347            }
348        }
349
350        // Update step size
351        self.h *= scale;
352        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
353
354        Ok(evals)
355    }
356
357    fn t(&self) -> T {
358        self.t
359    }
360    fn y(&self) -> &Y {
361        &self.y
362    }
363    fn t_prev(&self) -> T {
364        self.t_prev
365    }
366    fn y_prev(&self) -> &Y {
367        &self.y_prev
368    }
369    fn h(&self) -> T {
370        self.h
371    }
372    fn set_h(&mut self, h: T) {
373        self.h = h;
374    }
375    fn status(&self) -> &Status<T, Y> {
376        &self.status
377    }
378    fn set_status(&mut self, status: Status<T, Y>) {
379        self.status = status;
380    }
381}
382
383impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
384    ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
385{
386    fn lagvals<const L: usize, H>(
387        &mut self,
388        t_stage: T,
389        delays: &[T; L],
390        y_delayed: &mut [Y; L],
391        phi: &H,
392    ) -> Result<(), Error<T, Y>>
393    where
394        H: Fn(T) -> Y,
395    {
396        for idx in 0..L {
397            let t_delayed = t_stage - delays[idx];
398
399            // History domain (t_delayed <= t0)
400            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
401                y_delayed[idx] = phi(t_delayed);
402            // Within last accepted step (dense if available, else Hermite)
403            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
404                if self.bi.is_some() {
405                    let theta = (t_delayed - self.t_prev) / self.h_prev;
406                    let dense_coeffs = self.bi.as_ref().unwrap();
407
408                    let mut coeffs = [T::zero(); I];
409                    for s_idx in 0..I {
410                        if s_idx < self.cont.len() && s_idx < dense_coeffs.len() {
411                            coeffs[s_idx] = dense_coeffs[s_idx][self.dense_stages - 1];
412                            for j in (0..self.dense_stages - 1).rev() {
413                                coeffs[s_idx] = coeffs[s_idx] * theta + dense_coeffs[s_idx][j];
414                            }
415                            coeffs[s_idx] *= theta;
416                        }
417                    }
418
419                    let mut y_interp = self.y_prev;
420                    for s_idx in 0..I {
421                        if s_idx < self.k.len() && s_idx < self.cont.len() {
422                            y_interp += self.k[s_idx] * (coeffs[s_idx] * self.h_prev);
423                        }
424                    }
425                    y_delayed[idx] = y_interp;
426                } else {
427                    y_delayed[idx] = cubic_hermite_interpolate(
428                        self.t_prev,
429                        self.t,
430                        &self.y_prev,
431                        &self.y,
432                        &self.dydt_prev,
433                        &self.dydt,
434                        t_delayed,
435                    );
436                }
437            // Between earlier history points (internal buffer)
438            } else {
439                // Search history for bracketing interval
440                let mut found = false;
441                let buffer = &self.history;
442                let mut it = buffer.iter();
443                if let Some(mut left) = it.next() {
444                    for right in it {
445                        let (t_left, y_left, dydt_left) = left;
446                        let (t_right, y_right, dydt_right) = right;
447
448                        let in_interval = if self.h.signum() > T::zero() {
449                            *t_left <= t_delayed && t_delayed <= *t_right
450                        } else {
451                            *t_right <= t_delayed && t_delayed <= *t_left
452                        };
453
454                        if in_interval {
455                            y_delayed[idx] = cubic_hermite_interpolate(
456                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
457                                t_delayed,
458                            );
459                            found = true;
460                            break;
461                        }
462                        left = right;
463                    }
464                }
465                if !found {
466                    return Err(Error::InsufficientHistory {
467                        t_delayed,
468                        t_prev: self.t_prev,
469                        t_curr: self.t,
470                    });
471                }
472            }
473        }
474        Ok(())
475    }
476}
477
478impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
479    for ExplicitRungeKutta<Delay, Adaptive, T, Y, O, S, I>
480{
481    /// Interpolates the solution at a given time `t_interp`.
482    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
483        let dir = (self.t - self.t_prev).signum();
484        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
485            return Err(Error::OutOfBounds {
486                t_interp,
487                t_prev: self.t_prev,
488                t_curr: self.t,
489            });
490        }
491
492        // If method has dense output coefficients, use them
493        if self.bi.is_some() {
494            // Calculate the normalized distance within the step [0, 1]
495            let theta = (t_interp - self.t_prev) / self.h_prev;
496
497            // Get the interpolation coefficients
498            let dense_coeffs = self.bi.as_ref().unwrap();
499
500            let mut coeffs = [T::zero(); I];
501            // Compute the interpolation coefficients using Horner's method
502            for i in 0..self.dense_stages {
503                // Start with the highest-order term
504                coeffs[i] = dense_coeffs[i][self.order - 1];
505
506                // Apply Horner's method
507                for j in (0..self.order - 1).rev() {
508                    coeffs[i] = coeffs[i] * theta + dense_coeffs[i][j];
509                }
510
511                // Multiply by s
512                coeffs[i] *= theta;
513            }
514
515            // Compute the interpolated value
516            let mut y_interp = self.y_prev;
517            for i in 0..I {
518                y_interp += self.k[i] * coeffs[i] * self.h_prev;
519            }
520
521            Ok(y_interp)
522        } else {
523            // Otherwise use cubic Hermite interpolation
524            let y_interp = cubic_hermite_interpolate(
525                self.t_prev,
526                self.t,
527                &self.y_prev,
528                &self.y,
529                &self.dydt_prev,
530                &self.dydt,
531                t_interp,
532            );
533
534            Ok(y_interp)
535        }
536    }
537}