differential_equations/methods/erk/dormandprince/
delay.rs

1//! Dormand–Prince explicit Runge–Kutta methods for Delay Differential Equations (DDEs)
2
3use std::collections::VecDeque;
4
5use crate::{
6    dde::{DDE, DelayNumericalMethod},
7    error::Error,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    methods::{Delay, DormandPrince, ExplicitRungeKutta, h_init::InitialStepSize},
10    stats::Evals,
11    status::Status,
12    traits::{CallBackData, Real, State},
13    utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<
17    const L: usize,
18    T: Real,
19    Y: State<T>,
20    H: Fn(T) -> Y,
21    D: CallBackData,
22    const O: usize,
23    const S: usize,
24    const I: usize,
25> DelayNumericalMethod<L, T, Y, H, D>
26    for ExplicitRungeKutta<Delay, DormandPrince, T, Y, D, O, S, I>
27{
28    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
29    where
30        F: DDE<L, T, Y, D>,
31    {
32        let mut evals = Evals::new();
33
34        // DDE requires at least one lag
35        if L <= 0 {
36            return Err(Error::NoLags);
37        }
38
39        // Initialize solver state
40        self.t0 = t0;
41        self.t = t0;
42        self.y = *y0;
43        self.t_prev = self.t;
44        self.y_prev = self.y;
45        self.status = Status::Initialized;
46        self.steps = 0;
47        self.stiffness_counter = 0;
48        self.non_stiffness_counter = 0;
49        self.history = VecDeque::new();
50
51        // Delay buffers
52        let mut delays = [T::zero(); L];
53        let mut y_delayed = [Y::zeros(); L];
54
55        // Evaluate initial delays and history
56        dde.lags(self.t, &self.y, &mut delays);
57        for i in 0..L {
58            let t_delayed = self.t - delays[i];
59            // Ensure delayed time is within history range
60            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
61                return Err(Error::BadInput {
62                    msg: format!("Delayed time {} is beyond initial time {}", t_delayed, t0),
63                });
64            }
65            y_delayed[i] = phi(t_delayed);
66        }
67
68        // Initial derivative
69        dde.diff(self.t, &self.y, &y_delayed, &mut self.k[0]);
70        self.dydt = self.k[0];
71        evals.function += 1;
72        self.dydt_prev = self.dydt;
73
74        // Seed history
75        self.history.push_back((self.t, self.y, self.dydt));
76
77        // Initial step size
78        if self.h0 == T::zero() {
79            self.h0 = InitialStepSize::<Delay>::compute(
80                dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi,
81                &self.k[0], &mut evals,
82            );
83        }
84
85        // Validate and set initial step size h
86        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
87            Ok(h0) => self.h = h0,
88            Err(status) => return Err(status),
89        }
90        Ok(evals)
91    }
92
93    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
94    where
95        F: DDE<L, T, Y, D>,
96    {
97        let mut evals = Evals::new();
98
99        // Validate step size
100        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
101            self.status = Status::Error(Error::StepSize {
102                t: self.t,
103                y: self.y,
104            });
105            return Err(Error::StepSize {
106                t: self.t,
107                y: self.y,
108            });
109        }
110
111        // Check maximum number of steps
112        if self.steps >= self.max_steps {
113            self.status = Status::Error(Error::MaxSteps {
114                t: self.t,
115                y: self.y,
116            });
117            return Err(Error::MaxSteps {
118                t: self.t,
119                y: self.y,
120            });
121        }
122        self.steps += 1;
123
124        // Step buffers
125        let mut delays = [T::zero(); L];
126        let mut y_delayed = [Y::zeros(); L];
127
128        // Decide if delay iteration is needed
129        let mut min_delay_abs = T::infinity();
130        // Predict y(t+h) to estimate delays at t+h
131        let y_pred_for_lags = self.y + self.k[0] * self.h;
132        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
133        for i in 0..L {
134            min_delay_abs = min_delay_abs.min(delays[i].abs());
135        }
136
137        // Delay iteration count
138        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
139            5
140        } else {
141            1
142        };
143        let mut y_next_est = self.y;
144        let mut y_next_est_prev = self.y;
145        let mut dde_iter_failed = false;
146        let mut err_norm: T = T::zero();
147        let mut y_last_stage = Y::zeros();
148
149        // DDE iteration loop
150        for it in 0..max_iter {
151            if it > 0 {
152                y_next_est_prev = y_next_est;
153            }
154
155            // Compute stages
156            let mut y_stage = Y::zeros();
157            for i in 1..self.stages {
158                y_stage = Y::zeros();
159                for j in 0..i {
160                    y_stage += self.k[j] * self.a[i][j];
161                }
162                y_stage = self.y + y_stage * self.h;
163
164                // Delayed states for this stage
165                dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
166                if let Err(e) =
167                    self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
168                {
169                    self.status = Status::Error(e.clone());
170                    return Err(e);
171                }
172                dde.diff(
173                    self.t + self.c[i] * self.h,
174                    &y_stage,
175                    &y_delayed,
176                    &mut self.k[i],
177                );
178            }
179            evals.function += self.stages - 1;
180
181            // Keep last stage for stiffness detection
182            y_last_stage = y_stage;
183
184            // RK combination
185            let mut yseg = Y::zeros();
186            for i in 0..self.stages {
187                yseg += self.k[i] * self.b[i];
188            }
189
190            let y_new = self.y + yseg * self.h;
191
192            // Dormand–Prince error estimation
193            let er = self.er.unwrap();
194            let n = self.y.len();
195            let mut err_val = T::zero();
196            let mut err2 = T::zero();
197            let mut erri;
198            for i in 0..n {
199                // Calculate the error scale
200                let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
201
202                // Primary error term
203                erri = T::zero();
204                for j in 0..self.stages {
205                    erri += er[j] * self.k[j].get(i);
206                }
207                err_val += (erri / sk).powi(2);
208
209                // Optional secondary error term
210                if let Some(bh) = &self.bh {
211                    erri = yseg.get(i);
212                    for j in 0..self.stages {
213                        erri -= bh[j] * self.k[j].get(i);
214                    }
215                    err2 += (erri / sk).powi(2);
216                }
217            }
218            let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
219            if deno <= T::zero() {
220                deno = T::one();
221            }
222            err_norm =
223                self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
224
225            // Convergence check (if iterating)
226            if max_iter > 1 && it > 0 {
227                let mut dde_iteration_error = T::zero();
228                let n_dim = self.y.len();
229                for i_dim in 0..n_dim {
230                    let scale = self.atol
231                        + self.rtol * y_next_est_prev.get(i_dim).abs().max(y_new.get(i_dim).abs());
232                    if scale > T::zero() {
233                        let diff_val = y_new.get(i_dim) - y_next_est_prev.get(i_dim);
234                        dde_iteration_error += (diff_val / scale).powi(2);
235                    }
236                }
237                if n_dim > 0 {
238                    dde_iteration_error =
239                        (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
240                }
241
242                if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
243                    break;
244                }
245                if it == max_iter - 1 {
246                    dde_iter_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
247                }
248            }
249            y_next_est = y_new;
250        }
251
252        // Iteration failed: reduce h and retry
253        if dde_iter_failed {
254            let sign = self.h.signum();
255            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
256            if L > 0
257                && min_delay_abs > T::zero()
258                && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
259            {
260                self.h = min_delay_abs * sign;
261            }
262            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
263            self.status = Status::RejectedStep;
264            return Ok(evals);
265        }
266
267        // Step size control
268        let order = T::from_usize(self.order).unwrap();
269        let error_exponent = T::one() / order;
270        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
271
272        // Clamp scale factor
273        scale = scale.max(self.min_scale).min(self.max_scale);
274
275        // Accept/reject
276        if err_norm <= T::one() {
277            let y_new = y_next_est;
278            let t_new = self.t + self.h;
279
280            // Derivative at new point
281            dde.lags(t_new, &y_new, &mut delays);
282            if let Err(e) = self.lagvals(t_new, &delays, &mut y_delayed, phi) {
283                self.status = Status::Error(e.clone());
284                return Err(e);
285            }
286            dde.diff(t_new, &y_new, &y_delayed, &mut self.dydt);
287            evals.function += 1;
288            // Stiffness detection (every 100 steps)
289            let n_stiff_threshold = 100;
290            if self.steps % n_stiff_threshold == 0 {
291                let mut stdnum = T::zero();
292                let mut stden = T::zero();
293                let sqr = {
294                    let mut yseg = Y::zeros();
295                    for i in 0..self.stages {
296                        yseg += self.k[i] * self.b[i];
297                    }
298                    yseg - self.k[S - 1]
299                };
300                for i in 0..sqr.len() {
301                    stdnum += sqr.get(i).powi(2);
302                }
303                let sqr = self.dydt - y_last_stage;
304                for i in 0..sqr.len() {
305                    stden += sqr.get(i).powi(2);
306                }
307
308                if stden > T::zero() {
309                    let h_lamb = self.h * (stdnum / stden).sqrt();
310                    if h_lamb > T::from_f64(6.1).unwrap() {
311                        self.non_stiffness_counter = 0;
312                        self.stiffness_counter += 1;
313                        if self.stiffness_counter == 15 {
314                            self.status = Status::Error(Error::Stiffness {
315                                t: self.t,
316                                y: self.y,
317                            });
318                            return Err(Error::Stiffness {
319                                t: self.t,
320                                y: self.y,
321                            });
322                        }
323                    }
324                } else {
325                    self.non_stiffness_counter += 1;
326                    if self.non_stiffness_counter == 6 {
327                        self.stiffness_counter = 0;
328                    }
329                }
330            }
331
332            // Prepare dense output / interpolation
333            self.cont[0] = self.y;
334            let ydiff = y_new - self.y;
335            self.cont[1] = ydiff;
336            let bspl = self.k[0] * self.h - ydiff;
337            self.cont[2] = bspl;
338            self.cont[3] = ydiff - self.dydt * self.h - bspl;
339
340            // Dense output stages
341            if let Some(bi) = &self.bi {
342                if I > S {
343                    self.k[self.stages] = self.dydt;
344                    for i in S + 1..I {
345                        let mut y_stage = Y::zeros();
346                        for j in 0..i {
347                            y_stage += self.k[j] * self.a[i][j];
348                        }
349                        y_stage = self.y + y_stage * self.h;
350
351                        dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
352                        for lag_idx in 0..L {
353                            let t_delayed = (self.t + self.c[i] * self.h) - delays[lag_idx];
354
355                            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
356                                y_delayed[lag_idx] = phi(t_delayed);
357                            } else if (t_delayed - self.t_prev) * self.h.signum()
358                                > T::default_epsilon()
359                            {
360                                if self.bi.is_some() {
361                                    let theta = (t_delayed - self.t_prev) / self.h_prev;
362                                    let one_minus_theta = T::one() - theta;
363                                    let ilast = self.cont.len() - 1;
364                                    let poly =
365                                        (1..ilast).rev().fold(self.cont[ilast], |acc, cont_i| {
366                                            let factor = if cont_i >= 4 {
367                                                if (ilast - cont_i) % 2 == 1 {
368                                                    one_minus_theta
369                                                } else {
370                                                    theta
371                                                }
372                                            } else if cont_i % 2 == 1 {
373                                                one_minus_theta
374                                            } else {
375                                                theta
376                                            };
377                                            acc * factor + self.cont[cont_i]
378                                        });
379                                    y_delayed[lag_idx] = self.cont[0] + poly * theta;
380                                } else {
381                                    y_delayed[lag_idx] = cubic_hermite_interpolate(
382                                        self.t_prev,
383                                        self.t,
384                                        &self.y_prev,
385                                        &self.y,
386                                        &self.dydt_prev,
387                                        &self.dydt,
388                                        t_delayed,
389                                    );
390                                }
391                            } else {
392                                let mut found_interpolation = false;
393                                let buffer = &self.history;
394                                let mut buffer_iter = buffer.iter();
395                                if let Some(mut prev_entry) = buffer_iter.next() {
396                                    for curr_entry in buffer_iter {
397                                        let (t_left, y_left, dydt_left) = prev_entry;
398                                        let (t_right, y_right, dydt_right) = curr_entry;
399
400                                        let is_between = if self.h.signum() > T::zero() {
401                                            *t_left <= t_delayed && t_delayed <= *t_right
402                                        } else {
403                                            *t_right <= t_delayed && t_delayed <= *t_left
404                                        };
405
406                                        if is_between {
407                                            y_delayed[lag_idx] = cubic_hermite_interpolate(
408                                                *t_left, *t_right, y_left, y_right, dydt_left,
409                                                dydt_right, t_delayed,
410                                            );
411                                            found_interpolation = true;
412                                            break;
413                                        }
414                                        prev_entry = curr_entry;
415                                    }
416                                }
417                                if !found_interpolation {
418                                    return Err(Error::InsufficientHistory {
419                                        t_delayed,
420                                        t_prev: self.t_prev,
421                                        t_curr: self.t,
422                                    });
423                                }
424                            }
425                        }
426                        dde.diff(
427                            self.t + self.c[i] * self.h,
428                            &y_stage,
429                            &y_delayed,
430                            &mut self.k[i],
431                        );
432                        evals.function += 1;
433                    }
434                }
435
436                // Dense output coefficients
437                for i in 4..self.order {
438                    self.cont[i] = Y::zeros();
439                    for j in 0..self.dense_stages {
440                        self.cont[i] += self.k[j] * bi[i][j];
441                    }
442                    self.cont[i] = self.cont[i] * self.h;
443                }
444            }
445
446            // For interpolation
447            self.t_prev = self.t;
448            self.y_prev = self.y;
449            self.dydt_prev = self.k[0];
450            self.h_prev = self.h;
451
452            // Advance state
453            self.t = t_new;
454            self.y = y_new;
455            self.k[0] = self.dydt;
456
457            // Append to history and prune
458            self.history.push_back((self.t, self.y, self.dydt));
459            if let Some(max_delay) = self.max_delay {
460                let cutoff_time = self.t - max_delay;
461                while let Some((t_front, _, _)) = self.history.get(1) {
462                    if *t_front < cutoff_time {
463                        self.history.pop_front();
464                    } else {
465                        break;
466                    }
467                }
468            }
469            if let Status::RejectedStep = self.status {
470                self.status = Status::Solving;
471                scale = scale.min(T::one());
472            }
473        } else {
474            // Step rejected
475            self.status = Status::RejectedStep;
476        }
477
478        // Update step size
479        self.h *= scale;
480        // Enforce bounds
481        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
482
483        Ok(evals)
484    }
485
    /// Returns the current value of the independent variable.
    fn t(&self) -> T {
        self.t
    }
    /// Returns a reference to the current state.
    fn y(&self) -> &Y {
        &self.y
    }
    /// Returns the independent variable at the start of the last accepted step.
    fn t_prev(&self) -> T {
        self.t_prev
    }
    /// Returns a reference to the state at the start of the last accepted step.
    fn y_prev(&self) -> &Y {
        &self.y_prev
    }
    /// Returns the step size that the next call to `step` will attempt.
    fn h(&self) -> T {
        self.h
    }
    /// Overrides the step size used by the next call to `step`.
    fn set_h(&mut self, h: T) {
        self.h = h;
    }
    /// Returns a reference to the solver's current status.
    fn status(&self) -> &Status<T, Y, D> {
        &self.status
    }
    /// Replaces the solver's status.
    fn set_status(&mut self, status: Status<T, Y, D>) {
        self.status = status;
    }
510}
511
512impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
513    ExplicitRungeKutta<Delay, DormandPrince, T, Y, D, O, S, I>
514{
515    fn lagvals<const L: usize, H>(
516        &mut self,
517        t_stage: T,
518        lags: &[T; L],
519        yd: &mut [Y; L],
520        phi: &H,
521    ) -> Result<(), Error<T, Y>>
522    where
523        H: Fn(T) -> Y,
524    {
525        for i in 0..L {
526            let t_delayed = t_stage - lags[i];
527
528            // Check if delayed time falls within the history period (t_delayed <= t0)
529            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
530                yd[i] = phi(t_delayed);
531            // If t_delayed is after t_prev then use interpolation function
532            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
533                if self.bi.is_some() {
534                    let theta = (t_delayed - self.t_prev) / self.h_prev;
535                    let one_minus_theta = T::one() - theta;
536
537                    // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
538                    let ilast = self.cont.len() - 1;
539                    let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
540                        let factor = if i >= 4 {
541                            if (ilast - i) % 2 == 1 {
542                                one_minus_theta
543                            } else {
544                                theta
545                            }
546                        } else if i % 2 == 1 {
547                            one_minus_theta
548                        } else {
549                            theta
550                        };
551                        acc * factor + self.cont[i]
552                    });
553
554                    // Final multiplication by theta for the outermost level
555                    let y_interp = self.cont[0] + poly * theta;
556                    yd[i] = y_interp;
557                } else {
558                    yd[i] = cubic_hermite_interpolate(
559                        self.t_prev,
560                        self.t,
561                        &self.y_prev,
562                        &self.y,
563                        &self.dydt_prev,
564                        &self.dydt,
565                        t_delayed,
566                    );
567                }
568            // If t_delayed is before t_prev and after t0, we need to search in the history
569            } else {
570                // Search through history to find appropriate interpolation points
571                let mut found_interpolation = false;
572                let buffer = &self.history;
573                // Find two consecutive points that sandwich t_delayed using iterators
574                let mut buffer_iter = buffer.iter();
575                if let Some(mut prev_entry) = buffer_iter.next() {
576                    for curr_entry in buffer_iter {
577                        let (t_left, y_left, dydt_left) = prev_entry;
578                        let (t_right, y_right, dydt_right) = curr_entry;
579
580                        // Check if t_delayed is between these two points
581                        let is_between = if self.h.signum() > T::zero() {
582                            *t_left <= t_delayed && t_delayed <= *t_right
583                        } else {
584                            *t_right <= t_delayed && t_delayed <= *t_left
585                        };
586
587                        if is_between {
588                            yd[i] = cubic_hermite_interpolate(
589                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
590                                t_delayed,
591                            );
592                            found_interpolation = true;
593                            break;
594                        }
595                        prev_entry = curr_entry;
596                    }
597                }
598                // If not found in history, this indicates insufficient history in buffer
599                if !found_interpolation {
600                    return Err(Error::InsufficientHistory {
601                        t_delayed,
602                        t_prev: self.t_prev,
603                        t_curr: self.t,
604                    });
605                }
606            }
607        }
608        Ok(())
609    }
610}
611
612impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
613    Interpolation<T, Y> for ExplicitRungeKutta<Delay, DormandPrince, T, Y, D, O, S, I>
614{
615    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
616        // Check if interpolation is out of bounds
617        let dir = (self.t - self.t_prev).signum();
618        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
619            return Err(Error::OutOfBounds {
620                t_interp,
621                t_prev: self.t_prev,
622                t_curr: self.t,
623            });
624        }
625
626        // Evaluate the interpolation polynomial at the requested time
627        let theta = (t_interp - self.t_prev) / self.h_prev;
628        let one_minus_theta = T::one() - theta;
629
630        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
631        let ilast = self.cont.len() - 1;
632        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {
633            let factor = if i >= 4 {
634                if (ilast - i) % 2 == 1 {
635                    one_minus_theta
636                } else {
637                    theta
638                }
639            } else if i % 2 == 1 {
640                one_minus_theta
641            } else {
642                theta
643            };
644            acc * factor + self.cont[i]
645        });
646
647        // Final multiplication by theta for the outermost level
648        let y_interp = self.cont[0] + poly * theta;
649
650        Ok(y_interp)
651    }
652}