differential_equations/methods/erk/dormandprince/delay.rs

1//! Dormand–Prince explicit Runge–Kutta methods for Delay Differential Equations (DDEs)
2
3use std::collections::VecDeque;
4
5use crate::{
6    dde::{DDE, DelayNumericalMethod},
7    error::Error,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    methods::{Delay, DormandPrince, ExplicitRungeKutta, h_init::InitialStepSize},
10    stats::Evals,
11    status::Status,
12    traits::{Real, State},
13    utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<
17    const L: usize,
18    T: Real,
19    Y: State<T>,
20    H: Fn(T) -> Y,
21    const O: usize,
22    const S: usize,
23    const I: usize,
24> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
25{
26    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
27    where
28        F: DDE<L, T, Y>,
29    {
30        let mut evals = Evals::new();
31
32        // DDE requires at least one lag
33        if L <= 0 {
34            return Err(Error::NoLags);
35        }
36
37        // Initialize solver state
38        self.t0 = t0;
39        self.t = t0;
40        self.y = *y0;
41        self.t_prev = self.t;
42        self.y_prev = self.y;
43        self.status = Status::Initialized;
44        self.steps = 0;
45        self.stiffness_counter = 0;
46        self.non_stiffness_counter = 0;
47        self.history = VecDeque::new();
48
49        // Delay buffers
50        let mut delays = [T::zero(); L];
51        let mut y_delayed = [Y::zeros(); L];
52
53        // Evaluate initial delays and history
54        dde.lags(self.t, &self.y, &mut delays);
55        for i in 0..L {
56            let t_delayed = self.t - delays[i];
57            // Ensure delayed time is within history range
58            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
59                return Err(Error::BadInput {
60                    msg: format!("Delayed time {} is beyond initial time {}", t_delayed, t0),
61                });
62            }
63            y_delayed[i] = phi(t_delayed);
64        }
65
66        // Initial derivative
67        dde.diff(self.t, &self.y, &y_delayed, &mut self.k[0]);
68        self.dydt = self.k[0];
69        evals.function += 1;
70        self.dydt_prev = self.dydt;
71
72        // Seed history
73        self.history.push_back((self.t, self.y, self.dydt));
74
75        // Initial step size
76        if self.h0 == T::zero() {
77            self.h0 = InitialStepSize::<Delay>::compute(
78                dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
79                &self.k[0], &mut evals,
80            );
81        }
82
83        // Validate and set initial step size h
84        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
85            Ok(h0) => self.h = h0,
86            Err(status) => return Err(status),
87        }
88        Ok(evals)
89    }
90
91    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
92    where
93        F: DDE<L, T, Y>,
94    {
95        let mut evals = Evals::new();
96
97        // Validate step size
98        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
99            self.status = Status::Error(Error::StepSize {
100                t: self.t,
101                y: self.y,
102            });
103            return Err(Error::StepSize {
104                t: self.t,
105                y: self.y,
106            });
107        }
108
109        // Check maximum number of steps
110        if self.steps >= self.max_steps {
111            self.status = Status::Error(Error::MaxSteps {
112                t: self.t,
113                y: self.y,
114            });
115            return Err(Error::MaxSteps {
116                t: self.t,
117                y: self.y,
118            });
119        }
120        self.steps += 1;
121
122        // Step buffers
123        let mut delays = [T::zero(); L];
124        let mut y_delayed = [Y::zeros(); L];
125
126        // Decide if delay iteration is needed
127        let mut min_delay_abs = T::infinity();
128        // Predict y(t+h) to estimate delays at t+h
129        let y_pred_for_lags = self.y + self.k[0] * self.h;
130        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
131        for i in 0..L {
132            min_delay_abs = min_delay_abs.min(delays[i].abs());
133        }
134
135        // Delay iteration count
136        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
137            5
138        } else {
139            1
140        };
141        let mut y_next_est = self.y;
142        let mut y_next_est_prev = self.y;
143        let mut dde_iter_failed = false;
144        let mut err_norm: T = T::zero();
145        let mut y_last_stage = Y::zeros();
146
147        // DDE iteration loop
148        for it in 0..max_iter {
149            if it > 0 {
150                y_next_est_prev = y_next_est;
151            }
152
153            // Compute stages
154            let mut y_stage = Y::zeros();
155            for i in 1..self.stages {
156                y_stage = Y::zeros();
157                for j in 0..i {
158                    y_stage += self.k[j] * self.a[i][j];
159                }
160                y_stage = self.y + y_stage * self.h;
161
162                // Delayed states for this stage
163                dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
164                if let Err(e) =
165                    self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
166                {
167                    self.status = Status::Error(e.clone());
168                    return Err(e);
169                }
170                dde.diff(
171                    self.t + self.c[i] * self.h,
172                    &y_stage,
173                    &y_delayed,
174                    &mut self.k[i],
175                );
176            }
177            evals.function += self.stages - 1;
178
179            // Keep last stage for stiffness detection
180            y_last_stage = y_stage;
181
182            // RK combination
183            let mut yseg = Y::zeros();
184            for i in 0..self.stages {
185                yseg += self.k[i] * self.b[i];
186            }
187
188            let y_new = self.y + yseg * self.h;
189
190            // Dormand–Prince error estimation
191            let er = self.er.unwrap();
192            let n = self.y.len();
193            let mut err_val = T::zero();
194            let mut err2 = T::zero();
195            let mut erri;
196            for i in 0..n {
197                // Calculate the error scale
198                let sk = self.atol[i] + self.rtol[i] * self.y.get(i).abs().max(y_new.get(i).abs());
199
200                // Primary error term
201                erri = T::zero();
202                for j in 0..self.stages {
203                    erri += er[j] * self.k[j].get(i);
204                }
205                err_val += (erri / sk).powi(2);
206
207                // Optional secondary error term
208                if let Some(bh) = &self.bh {
209                    erri = yseg.get(i);
210                    for j in 0..self.stages {
211                        erri -= bh[j] * self.k[j].get(i);
212                    }
213                    err2 += (erri / sk).powi(2);
214                }
215            }
216            let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
217            if deno <= T::zero() {
218                deno = T::one();
219            }
220            err_norm =
221                self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
222
223            // Convergence check (if iterating)
224            if max_iter > 1 && it > 0 {
225                let mut dde_iteration_error = T::zero();
226                let n_dim = self.y.len();
227                for i_dim in 0..n_dim {
228                    let scale = self.atol[i_dim]
229                        + self.rtol[i_dim]
230                            * y_next_est_prev.get(i_dim).abs().max(y_new.get(i_dim).abs());
231                    if scale > T::zero() {
232                        let diff_val = y_new.get(i_dim) - y_next_est_prev.get(i_dim);
233                        dde_iteration_error += (diff_val / scale).powi(2);
234                    }
235                }
236                if n_dim > 0 {
237                    dde_iteration_error =
238                        (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
239                }
240
241                if dde_iteration_error <= self.rtol.average() * T::from_f64(0.1).unwrap() {
242                    break;
243                }
244                if it == max_iter - 1 {
245                    dde_iter_failed =
246                        dde_iteration_error > self.rtol.average() * T::from_f64(0.1).unwrap();
247                }
248            }
249            y_next_est = y_new;
250        }
251
252        // Iteration failed: reduce h and retry
253        if dde_iter_failed {
254            let sign = self.h.signum();
255            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
256            if L > 0
257                && min_delay_abs > T::zero()
258                && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
259            {
260                self.h = min_delay_abs * sign;
261            }
262            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
263            self.status = Status::RejectedStep;
264            return Ok(evals);
265        }
266
267        // Step size control
268        let order = T::from_usize(self.order).unwrap();
269        let error_exponent = T::one() / order;
270        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
271
272        // Clamp scale factor
273        scale = scale.max(self.min_scale).min(self.max_scale);
274
275        // Accept/reject
276        if err_norm <= T::one() {
277            let y_new = y_next_est;
278            let t_new = self.t + self.h;
279
280            // Derivative at new point
281            dde.lags(t_new, &y_new, &mut delays);
282            if let Err(e) = self.lagvals(t_new, &delays, &mut y_delayed, phi) {
283                self.status = Status::Error(e.clone());
284                return Err(e);
285            }
286            dde.diff(t_new, &y_new, &y_delayed, &mut self.dydt);
287            evals.function += 1;
288            // Stiffness detection (every 100 steps)
289            let n_stiff_threshold = 100;
290            if self.steps % n_stiff_threshold == 0 {
291                let mut stdnum = T::zero();
292                let mut stden = T::zero();
293                let sqr = {
294                    let mut yseg = Y::zeros();
295                    for i in 0..self.stages {
296                        yseg += self.k[i] * self.b[i];
297                    }
298                    yseg - self.k[S - 1]
299                };
300                for i in 0..sqr.len() {
301                    stdnum += sqr.get(i).powi(2);
302                }
303                let sqr = self.dydt - y_last_stage;
304                for i in 0..sqr.len() {
305                    stden += sqr.get(i).powi(2);
306                }
307
308                if stden > T::zero() {
309                    let h_lamb = self.h * (stdnum / stden).sqrt();
310                    if h_lamb > T::from_f64(6.1).unwrap() {
311                        self.non_stiffness_counter = 0;
312                        self.stiffness_counter += 1;
313                        if self.stiffness_counter == 15 {
314                            self.status = Status::Error(Error::Stiffness {
315                                t: self.t,
316                                y: self.y,
317                            });
318                            return Err(Error::Stiffness {
319                                t: self.t,
320                                y: self.y,
321                            });
322                        }
323                    }
324                } else {
325                    self.non_stiffness_counter += 1;
326                    if self.non_stiffness_counter == 6 {
327                        self.stiffness_counter = 0;
328                    }
329                }
330            }
331
332            // Prepare dense output / interpolation
333            self.cont[0] = self.y;
334            let ydiff = y_new - self.y;
335            self.cont[1] = ydiff;
336            let bspl = self.k[0] * self.h - ydiff;
337            self.cont[2] = bspl;
338            self.cont[3] = ydiff - self.dydt * self.h - bspl;
339
340            // Dense output stages
341            if let Some(bi) = &self.bi {
342                if I > S {
343                    self.k[self.stages] = self.dydt;
344                    for i in S + 1..I {
345                        let mut y_stage = Y::zeros();
346                        for j in 0..i {
347                            y_stage += self.k[j] * self.a[i][j];
348                        }
349                        y_stage = self.y + y_stage * self.h;
350
351                        dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
352                        for lag_idx in 0..L {
353                            let t_delayed = (self.t + self.c[i] * self.h) - delays[lag_idx];
354
355                            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
356                                y_delayed[lag_idx] = phi(t_delayed);
357                            } else if (t_delayed - self.t_prev) * self.h.signum()
358                                > T::default_epsilon()
359                            {
360                                if self.bi.is_some() {
361                                    let theta = (t_delayed - self.t_prev) / self.h_prev;
362                                    let one_minus_theta = T::one() - theta;
363                                    let ilast = self.cont.len() - 1;
364                                    let poly =
365                                        (1..ilast).rev().fold(self.cont[ilast], |acc, cont_i| {
366                                            let factor = if cont_i >= 4 {
367                                                if (ilast - cont_i) % 2 == 1 {
368                                                    one_minus_theta
369                                                } else {
370                                                    theta
371                                                }
372                                            } else if cont_i % 2 == 1 {
373                                                one_minus_theta
374                                            } else {
375                                                theta
376                                            };
377                                            acc * factor + self.cont[cont_i]
378                                        });
379                                    y_delayed[lag_idx] = self.cont[0] + poly * theta;
380                                } else {
381                                    y_delayed[lag_idx] = cubic_hermite_interpolate(
382                                        self.t_prev,
383                                        self.t,
384                                        &self.y_prev,
385                                        &self.y,
386                                        &self.dydt_prev,
387                                        &self.dydt,
388                                        t_delayed,
389                                    );
390                                }
391                            } else {
392                                let mut found_interpolation = false;
393                                let buffer = &self.history;
394                                let mut buffer_iter = buffer.iter();
395                                if let Some(mut prev_entry) = buffer_iter.next() {
396                                    for curr_entry in buffer_iter {
397                                        let (t_left, y_left, dydt_left) = prev_entry;
398                                        let (t_right, y_right, dydt_right) = curr_entry;
399
400                                        let is_between = if self.h.signum() > T::zero() {
401                                            *t_left <= t_delayed && t_delayed <= *t_right
402                                        } else {
403                                            *t_right <= t_delayed && t_delayed <= *t_left
404                                        };
405
406                                        if is_between {
407                                            y_delayed[lag_idx] = cubic_hermite_interpolate(
408                                                *t_left, *t_right, y_left, y_right, dydt_left,
409                                                dydt_right, t_delayed,
410                                            );
411                                            found_interpolation = true;
412                                            break;
413                                        }
414                                        prev_entry = curr_entry;
415                                    }
416                                }
417                                if !found_interpolation {
418                                    return Err(Error::InsufficientHistory {
419                                        t_delayed,
420                                        t_prev: self.t_prev,
421                                        t_curr: self.t,
422                                    });
423                                }
424                            }
425                        }
426                        dde.diff(
427                            self.t + self.c[i] * self.h,
428                            &y_stage,
429                            &y_delayed,
430                            &mut self.k[i],
431                        );
432                        evals.function += 1;
433                    }
434                }
435
436                // Dense output coefficients
437                for i in 4..self.order {
438                    self.cont[i] = Y::zeros();
439                    for j in 0..self.dense_stages {
440                        self.cont[i] += self.k[j] * bi[i][j];
441                    }
442                    self.cont[i] = self.cont[i] * self.h;
443                }
444            }
445
446            // For interpolation
447            self.t_prev = self.t;
448            self.y_prev = self.y;
449            self.dydt_prev = self.k[0];
450            self.h_prev = self.h;
451
452            // Advance state
453            self.t = t_new;
454            self.y = y_new;
455            self.k[0] = self.dydt;
456
457            // Append to history and prune
458            self.history.push_back((self.t, self.y, self.dydt));
459            if let Some(max_delay) = self.max_delay {
460                let cutoff_time = self.t - max_delay;
461                while let Some((t_front, _, _)) = self.history.get(1) {
462                    if *t_front < cutoff_time {
463                        self.history.pop_front();
464                    } else {
465                        break;
466                    }
467                }
468            }
469            if let Status::RejectedStep = self.status {
470                self.status = Status::Solving;
471                scale = scale.min(T::one());
472            }
473        } else {
474            // Step rejected
475            self.status = Status::RejectedStep;
476        }
477
478        // Update step size
479        self.h *= scale;
480        // Enforce bounds
481        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
482
483        Ok(evals)
484    }
485
486    fn t(&self) -> T {
487        self.t
488    }
489    fn y(&self) -> &Y {
490        &self.y
491    }
492    fn t_prev(&self) -> T {
493        self.t_prev
494    }
495    fn y_prev(&self) -> &Y {
496        &self.y_prev
497    }
498    fn h(&self) -> T {
499        self.h
500    }
501    fn set_h(&mut self, h: T) {
502        self.h = h;
503    }
504    fn status(&self) -> &Status<T, Y> {
505        &self.status
506    }
507    fn set_status(&mut self, status: Status<T, Y>) {
508        self.status = status;
509    }
510}
511
512impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
513    ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
514{
515    fn lagvals<const L: usize, H>(
516        &mut self,
517        t_stage: T,
518        lags: &[T; L],
519        yd: &mut [Y; L],
520        phi: &H,
521    ) -> Result<(), Error<T, Y>>
522    where
523        H: Fn(T) -> Y,
524    {
525        for i in 0..L {
526            let t_delayed = t_stage - lags[i];
527
528            // Check if delayed time falls within the history period (t_delayed <= t0)
529            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
530                yd[i] = phi(t_delayed);
531            // If t_delayed is after t_prev then use interpolation function
532            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
533                if self.bi.is_some() {
534                    let theta = (t_delayed - self.t_prev) / self.h_prev;
535                    let one_minus_theta = T::one() - theta;
536
537                    // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
538                    let ilast = self.cont.len() - 1;
539                    let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
540                        let factor = if i >= 4 {
541                            if (ilast - i) % 2 == 1 {
542                                one_minus_theta
543                            } else {
544                                theta
545                            }
546                        } else if i % 2 == 1 {
547                            one_minus_theta
548                        } else {
549                            theta
550                        };
551                        acc * factor + self.cont[i]
552                    });
553
554                    // Final multiplication by theta for the outermost level
555                    let y_interp = self.cont[0] + poly * theta;
556                    yd[i] = y_interp;
557                } else {
558                    yd[i] = cubic_hermite_interpolate(
559                        self.t_prev,
560                        self.t,
561                        &self.y_prev,
562                        &self.y,
563                        &self.dydt_prev,
564                        &self.dydt,
565                        t_delayed,
566                    );
567                }
568            // If t_delayed is before t_prev and after t0, we need to search in the history
569            } else {
570                // Search through history to find appropriate interpolation points
571                let mut found_interpolation = false;
572                let buffer = &self.history;
573                // Find two consecutive points that sandwich t_delayed using iterators
574                let mut buffer_iter = buffer.iter();
575                if let Some(mut prev_entry) = buffer_iter.next() {
576                    for curr_entry in buffer_iter {
577                        let (t_left, y_left, dydt_left) = prev_entry;
578                        let (t_right, y_right, dydt_right) = curr_entry;
579
580                        // Check if t_delayed is between these two points
581                        let is_between = if self.h.signum() > T::zero() {
582                            *t_left <= t_delayed && t_delayed <= *t_right
583                        } else {
584                            *t_right <= t_delayed && t_delayed <= *t_left
585                        };
586
587                        if is_between {
588                            yd[i] = cubic_hermite_interpolate(
589                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
590                                t_delayed,
591                            );
592                            found_interpolation = true;
593                            break;
594                        }
595                        prev_entry = curr_entry;
596                    }
597                }
598                // If not found in history, this indicates insufficient history in buffer
599                if !found_interpolation {
600                    return Err(Error::InsufficientHistory {
601                        t_delayed,
602                        t_prev: self.t_prev,
603                        t_curr: self.t,
604                    });
605                }
606            }
607        }
608        Ok(())
609    }
610}
611
612impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
613    for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
614{
615    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
616        // Check if interpolation is out of bounds
617        let dir = (self.t - self.t_prev).signum();
618        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
619            return Err(Error::OutOfBounds {
620                t_interp,
621                t_prev: self.t_prev,
622                t_curr: self.t,
623            });
624        }
625
626        // Evaluate the interpolation polynomial at the requested time
627        let theta = (t_interp - self.t_prev) / self.h_prev;
628        let one_minus_theta = T::one() - theta;
629
630        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
631        let ilast = self.cont.len() - 1;
632        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
633            let factor = if i >= 4 {
634                if (ilast - i) % 2 == 1 {
635                    one_minus_theta
636                } else {
637                    theta
638                }
639            } else if i % 2 == 1 {
640                one_minus_theta
641            } else {
642                theta
643            };
644            acc * factor + self.cont[i]
645        });
646
647        // Final multiplication by theta for the outermost level
648        let y_interp = self.cont[0] + poly * theta;
649
650        Ok(y_interp)
651    }
652}