Skip to main content

differential_equations/methods/erk/dormandprince/
delay.rs

1//! Dormand–Prince explicit Runge–Kutta methods for Delay Differential Equations (DDEs)
2use std::collections::VecDeque;
3
4use crate::{
5    dde::{DDE, DelayNumericalMethod},
6    error::Error,
7    interpolate::{Interpolation, cubic_hermite_interpolate},
8    methods::{Delay, DormandPrince, ExplicitRungeKutta, h_init::InitialStepSize},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::{constrain_step_size, validate_step_size_parameters},
13};
14
15impl<
16    const L: usize,
17    T: Real,
18    Y: State<T>,
19    H: Fn(T) -> Y,
20    const O: usize,
21    const S: usize,
22    const I: usize,
23> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
24{
25    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
26    where
27        F: DDE<L, T, Y> + ?Sized,
28    {
29        let mut evals = Evals::new();
30
31        // DDE requires at least one lag
32        if L == 0 {
33            return Err(Error::NoLags);
34        }
35
36        // Initialize solver state
37        self.t0 = t0;
38        self.t = t0;
39        self.y = y0.clone();
40        self.dydt = y0.zeros_like();
41        self.y_prev = y0.clone();
42        self.dydt_prev = y0.zeros_like();
43        self.k = core::array::from_fn(|_| y0.zeros_like());
44        self.cont = core::array::from_fn(|_| y0.zeros_like());
45        self.t_prev = self.t;
46        self.y_prev = self.y.clone();
47        self.status = Status::Initialized;
48        self.steps = 0;
49        self.stiffness_counter = 0;
50        self.non_stiffness_counter = 0;
51        self.history = VecDeque::new();
52
53        // Delay buffers
54        let mut delays = [T::zero(); L];
55        let mut y_delayed = core::array::from_fn(|_| y0.zeros_like());
56
57        // Evaluate initial delays and history
58        dde.lags(self.t, &self.y, &mut delays);
59        for i in 0..L {
60            let t_delayed = self.t - delays[i];
61            // Ensure delayed time is within history range
62            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
63                return Err(Error::BadInput {
64                    msg: format!("Delayed time {} is beyond initial time {}", t_delayed, t0),
65                });
66            }
67            y_delayed[i] = phi(t_delayed);
68        }
69
70        // Initial derivative
71        dde.diff(self.t, &self.y, &y_delayed, &mut self.k[0]);
72        self.dydt = self.k[0].clone();
73        evals.function += 1;
74        self.dydt_prev = self.dydt.clone();
75
76        // Seed history
77        self.history
78            .push_back((self.t, self.y.clone(), self.dydt.clone()));
79
80        // Initial step size
81        if self.h0 == T::zero() {
82            self.h0 = InitialStepSize::<Delay>::compute(
83                dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
84                &self.k[0], &mut evals,
85            );
86        }
87
88        // Validate and set initial step size h
89        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
90            Ok(h0) => self.h = (self.filter)(h0),
91            Err(status) => return Err(status),
92        }
93        Ok(evals)
94    }
95
96    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
97    where
98        F: DDE<L, T, Y> + ?Sized,
99    {
100        let mut evals = Evals::new();
101
102        // Validate step size
103        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
104            self.status = Status::Error(Error::StepSize {
105                t: self.t,
106                y: self.y.clone(),
107            });
108            return Err(Error::StepSize {
109                t: self.t,
110                y: self.y.clone(),
111            });
112        }
113
114        // Check maximum number of steps
115        if self.steps >= self.max_steps {
116            self.status = Status::Error(Error::MaxSteps {
117                t: self.t,
118                y: self.y.clone(),
119            });
120            return Err(Error::MaxSteps {
121                t: self.t,
122                y: self.y.clone(),
123            });
124        }
125        self.steps += 1;
126
127        // Step buffers
128        let mut delays = [T::zero(); L];
129        let mut y_delayed = core::array::from_fn(|_| self.y.zeros_like());
130
131        // Decide if delay iteration is needed
132        let mut min_delay_abs = T::infinity();
133        // Predict y(t+h) to estimate delays at t+h
134        let y_pred_for_lags = self.y.plus_scaled(self.h, &self.k[0]);
135        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
136        for i in 0..L {
137            min_delay_abs = min_delay_abs.min(delays[i].abs());
138        }
139
140        // Delay iteration count
141        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
142            5
143        } else {
144            1
145        };
146        let mut y_next_est = self.y.clone();
147        let mut y_next_est_prev = self.y.clone();
148        let mut dde_iter_failed = false;
149        let mut err_norm: T = T::zero();
150        let mut y_last_stage = self.y.zeros_like();
151
152        // DDE iteration loop
153        for it in 0..max_iter {
154            if it > 0 {
155                y_next_est_prev = y_next_est.clone();
156            }
157
158            // Compute stages
159            let mut y_stage = self.y.zeros_like();
160            for i in 1..self.stages {
161                y_stage = self.y.clone();
162                for j in 0..i {
163                    y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
164                }
165
166                // Delayed states for this stage
167                dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
168                if let Err(e) =
169                    self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
170                {
171                    self.status = Status::Error(e.clone());
172                    return Err(e);
173                }
174                dde.diff(
175                    self.t + self.c[i] * self.h,
176                    &y_stage,
177                    &y_delayed,
178                    &mut self.k[i],
179                );
180            }
181            evals.function += self.stages - 1;
182
183            // Keep last stage for stiffness detection
184            y_last_stage = y_stage.clone();
185
186            // RK combination
187            let mut yseg = self.y.zeros_like();
188            for i in 0..self.stages {
189                yseg.add_scaled(self.b[i], &self.k[i]);
190            }
191
192            let y_new = self.y.plus_scaled(self.h, &yseg);
193
194            // Dormand–Prince error estimation
195            let er = self.er.unwrap();
196            let n = self.y.len();
197            let mut err2 = T::zero();
198            let mut err_state = self.y.zeros_like();
199            for (j, coefficient) in er.iter().enumerate().take(self.stages) {
200                err_state.add_scaled(*coefficient, &self.k[j]);
201            }
202            let err_val = self
203                .y
204                .error_norm(&y_new, &err_state, &self.atol, &self.rtol);
205
206            if let Some(bh) = &self.bh {
207                let mut err2_state = yseg.clone();
208                for (j, coefficient) in bh.iter().enumerate().take(self.stages) {
209                    err2_state.add_scaled(-*coefficient, &self.k[j]);
210                }
211                err2 = self
212                    .y
213                    .error_norm(&y_new, &err2_state, &self.atol, &self.rtol);
214            }
215            let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
216            if deno <= T::zero() {
217                deno = T::one();
218            }
219            err_norm =
220                self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
221
222            // Convergence check (if iterating)
223            if max_iter > 1 && it > 0 {
224                let n_dim = self.y.len();
225                let iter_diff = y_new.minus(&y_next_est_prev);
226                let mut dde_iteration_error =
227                    y_next_est_prev.error_norm(&y_new, &iter_diff, &self.atol, &self.rtol);
228                if n_dim > 0 {
229                    dde_iteration_error =
230                        (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
231                }
232
233                if dde_iteration_error <= self.rtol.average() * T::from_f64(0.1).unwrap() {
234                    break;
235                }
236                if it == max_iter - 1 {
237                    dde_iter_failed =
238                        dde_iteration_error > self.rtol.average() * T::from_f64(0.1).unwrap();
239                }
240            }
241            y_next_est = y_new.clone();
242        }
243
244        // Iteration failed: reduce h and retry
245        if dde_iter_failed {
246            let sign = self.h.signum();
247            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
248            if L > 0
249                && min_delay_abs > T::zero()
250                && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
251            {
252                self.h = min_delay_abs * sign;
253            }
254            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
255            self.h = (self.filter)(self.h);
256            self.status = Status::RejectedStep;
257            return Ok(evals);
258        }
259
260        // Step size control
261        let order = T::from_usize(self.order).unwrap();
262        let error_exponent = T::one() / order;
263        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
264
265        // Clamp scale factor
266        scale = scale.max(self.min_scale).min(self.max_scale);
267
268        // Accept/reject
269        if err_norm <= T::one() {
270            let y_new = y_next_est.clone();
271            let t_new = self.t + self.h;
272
273            // Derivative at new point
274            dde.lags(t_new, &y_new, &mut delays);
275            if let Err(e) = self.lagvals(t_new, &delays, &mut y_delayed, phi) {
276                self.status = Status::Error(e.clone());
277                return Err(e);
278            }
279            dde.diff(t_new, &y_new, &y_delayed, &mut self.dydt);
280            evals.function += 1;
281            // Stiffness detection (every 100 steps)
282            let n_stiff_threshold = 100;
283            if self.steps.is_multiple_of(n_stiff_threshold) {
284                let mut yseg = self.y.zeros_like();
285                for i in 0..self.stages {
286                    yseg.add_scaled(self.b[i], &self.k[i]);
287                }
288                let stdnum = yseg.diff_norm_squared(&self.k[S - 1]);
289                let stden = self.dydt.diff_norm_squared(&y_last_stage);
290
291                if stden > T::zero() {
292                    let h_lamb = self.h * (stdnum / stden).sqrt();
293                    if h_lamb > T::from_f64(6.1).unwrap() {
294                        self.non_stiffness_counter = 0;
295                        self.stiffness_counter += 1;
296                        if self.stiffness_counter == 15 {
297                            self.status = Status::Error(Error::Stiffness {
298                                t: self.t,
299                                y: self.y.clone(),
300                            });
301                            return Err(Error::Stiffness {
302                                t: self.t,
303                                y: self.y.clone(),
304                            });
305                        }
306                    }
307                } else {
308                    self.non_stiffness_counter += 1;
309                    if self.non_stiffness_counter == 6 {
310                        self.stiffness_counter = 0;
311                    }
312                }
313            }
314
315            // Prepare dense output / interpolation
316            self.cont[0] = self.y.clone();
317            let ydiff = y_new.minus(&self.y);
318            self.cont[1] = ydiff.clone();
319            let mut bspl = ydiff.zeros_like();
320            bspl.add_scaled(self.h, &self.k[0]);
321            bspl.add_scaled(-T::one(), &ydiff);
322            self.cont[2] = bspl.clone();
323            let mut cont3 = ydiff;
324            cont3.add_scaled(-self.h, &self.dydt);
325            cont3.add_scaled(-T::one(), &bspl);
326            self.cont[3] = cont3;
327
328            // Dense output stages
329            if self.bi.is_some() {
330                if I > S {
331                    self.k[self.stages] = self.dydt.clone();
332                    for i in S + 1..I {
333                        let mut y_stage = self.y.clone();
334                        for j in 0..i {
335                            y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
336                        }
337
338                        let t_stage = self.t + self.c[i] * self.h;
339                        dde.lags(t_stage, &y_stage, &mut delays);
340                        if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
341                            self.status = Status::Error(e.clone());
342                            return Err(e);
343                        }
344                        dde.diff(t_stage, &y_stage, &y_delayed, &mut self.k[i]);
345                        evals.function += 1;
346                    }
347                }
348
349                // Dense output coefficients
350                for i in 4..self.order {
351                    self.cont[i].fill(T::zero());
352                    for j in 0..self.dense_stages {
353                        let bi = self.bi.as_ref().expect("dense output coefficients checked");
354                        self.cont[i].add_scaled(bi[i][j], &self.k[j]);
355                    }
356                    self.cont[i].scale_by(self.h);
357                }
358            }
359
360            // For interpolation
361            self.t_prev = self.t;
362            self.y_prev = self.y.clone();
363            self.dydt_prev = self.k[0].clone();
364            self.h_prev = self.h;
365
366            // Advance state
367            self.t = t_new;
368            self.y = y_new;
369            self.k[0] = self.dydt.clone();
370
371            // Append to history and prune
372            self.history
373                .push_back((self.t, self.y.clone(), self.dydt.clone()));
374            if let Some(max_delay) = self.max_delay {
375                let cutoff_time = self.t - max_delay;
376                while let Some((t_front, _, _)) = self.history.get(1) {
377                    if *t_front < cutoff_time {
378                        self.history.pop_front();
379                    } else {
380                        break;
381                    }
382                }
383            }
384            if let Status::RejectedStep = self.status {
385                self.status = Status::Solving;
386                scale = scale.min(T::one());
387            }
388        } else {
389            // Step rejected
390            self.status = Status::RejectedStep;
391        }
392
393        // Update step size
394        self.h *= scale;
395        // Enforce bounds
396        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
397        // Apply step size filter
398        self.h = (self.filter)(self.h);
399
400        Ok(evals)
401    }
402
403    fn t(&self) -> T {
404        self.t
405    }
406    fn y(&self) -> &Y {
407        &self.y
408    }
409    fn t_prev(&self) -> T {
410        self.t_prev
411    }
412    fn y_prev(&self) -> &Y {
413        &self.y_prev
414    }
415    fn h(&self) -> T {
416        self.h
417    }
418    fn set_h(&mut self, h: T) {
419        self.h = (self.filter)(h);
420    }
421    fn status(&self) -> &Status<T, Y> {
422        &self.status
423    }
424    fn set_status(&mut self, status: Status<T, Y>) {
425        self.status = status;
426    }
427}
428
429impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
430    ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
431{
432    fn lagvals<const L: usize, H>(
433        &mut self,
434        t_stage: T,
435        lags: &[T; L],
436        yd: &mut [Y; L],
437        phi: &H,
438    ) -> Result<(), Error<T, Y>>
439    where
440        H: Fn(T) -> Y,
441    {
442        for i in 0..L {
443            let t_delayed = t_stage - lags[i];
444
445            // Check if delayed time falls within the history period (t_delayed <= t0)
446            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
447                yd[i] = phi(t_delayed);
448            // If t_delayed is after t_prev then use interpolation function
449            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
450                if self.bi.is_some() {
451                    let theta = (t_delayed - self.t_prev) / self.h_prev;
452                    let one_minus_theta = T::one() - theta;
453
454                    // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
455                    let ilast = self.cont.len() - 1;
456                    let poly = (1..ilast)
457                        .rev()
458                        .fold(self.cont[ilast].clone(), |mut acc, i| {
459                            let factor = if i >= 4 {
460                                if (ilast - i) % 2 == 1 {
461                                    one_minus_theta
462                                } else {
463                                    theta
464                                }
465                            } else if i % 2 == 1 {
466                                one_minus_theta
467                            } else {
468                                theta
469                            };
470                            acc.scale_by(factor);
471                            acc.add_scaled(T::one(), &self.cont[i]);
472                            acc
473                        });
474
475                    // Final multiplication by theta for the outermost level
476                    let y_interp = self.cont[0].plus_scaled(theta, &poly);
477                    yd[i] = y_interp;
478                } else {
479                    yd[i] = cubic_hermite_interpolate(
480                        self.t_prev,
481                        self.t,
482                        &self.y_prev,
483                        &self.y,
484                        &self.dydt_prev,
485                        &self.dydt,
486                        t_delayed,
487                    );
488                }
489            // If t_delayed is before t_prev and after t0, we need to search in the history
490            } else {
491                // Search through history to find appropriate interpolation points
492                let mut found_interpolation = false;
493                let buffer = &self.history;
494                // Find two consecutive points that sandwich t_delayed using iterators
495                let mut buffer_iter = buffer.iter();
496                if let Some(mut prev_entry) = buffer_iter.next() {
497                    for curr_entry in buffer_iter {
498                        let (t_left, y_left, dydt_left) = prev_entry;
499                        let (t_right, y_right, dydt_right) = curr_entry;
500
501                        // Check if t_delayed is between these two points
502                        let is_between = if self.h.signum() > T::zero() {
503                            *t_left <= t_delayed && t_delayed <= *t_right
504                        } else {
505                            *t_right <= t_delayed && t_delayed <= *t_left
506                        };
507
508                        if is_between {
509                            yd[i] = cubic_hermite_interpolate(
510                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
511                                t_delayed,
512                            );
513                            found_interpolation = true;
514                            break;
515                        }
516                        prev_entry = curr_entry;
517                    }
518                }
519                // If not found in history, this indicates insufficient history in buffer
520                if !found_interpolation {
521                    return Err(Error::InsufficientHistory {
522                        t_delayed,
523                        t_prev: self.t_prev,
524                        t_curr: self.t,
525                    });
526                }
527            }
528        }
529        Ok(())
530    }
531}
532
533impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
534    for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
535{
536    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
537        // Check if interpolation is out of bounds
538        let dir = (self.t - self.t_prev).signum();
539        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
540            return Err(Error::OutOfBounds {
541                t_interp,
542                t_prev: self.t_prev,
543                t_curr: self.t,
544            });
545        }
546
547        // Evaluate the interpolation polynomial at the requested time
548        let theta = (t_interp - self.t_prev) / self.h_prev;
549        let one_minus_theta = T::one() - theta;
550
551        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
552        let ilast = self.cont.len() - 1;
553        let poly = (1..ilast)
554            .rev()
555            .fold(self.cont[ilast].clone(), |mut acc, i| {
556                let factor = if i >= 4 {
557                    if (ilast - i) % 2 == 1 {
558                        one_minus_theta
559                    } else {
560                        theta
561                    }
562                } else if i % 2 == 1 {
563                    one_minus_theta
564                } else {
565                    theta
566                };
567                acc.scale_by(factor);
568                acc.add_scaled(T::one(), &self.cont[i]);
569                acc
570            });
571
572        // Final multiplication by theta for the outermost level
573        let y_interp = self.cont[0].plus_scaled(theta, &poly);
574
575        Ok(y_interp)
576    }
577}