Skip to main content

differential_equations/methods/erk/fixed/
delay.rs

1//! Fixed-step explicit Runge–Kutta methods for Delay Differential Equations (DDEs)
2use std::collections::VecDeque;
3
4use crate::{
5    dde::{DDE, DelayNumericalMethod},
6    error::Error,
7    interpolate::{Interpolation, cubic_hermite_interpolate},
8    methods::{Delay, ExplicitRungeKutta, Fixed},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::validate_step_size_parameters,
13};
14
15impl<
16    const L: usize,
17    T: Real,
18    Y: State<T>,
19    H: Fn(T) -> Y,
20    const O: usize,
21    const S: usize,
22    const I: usize,
23> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
24{
25    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
26    where
27        F: DDE<L, T, Y> + ?Sized,
28    {
29        // Initialize solver state
30        let mut evals = Evals::new();
31
32        // DDE requires at least one lag
33        if L == 0 {
34            return Err(Error::NoLags);
35        }
36        self.t0 = t0;
37        self.t = t0;
38        self.y = y0.clone();
39        self.dydt = y0.zeros_like();
40        self.y_prev = y0.clone();
41        self.dydt_prev = y0.zeros_like();
42        self.k = core::array::from_fn(|_| y0.zeros_like());
43        self.cont = core::array::from_fn(|_| y0.zeros_like());
44        self.t_prev = self.t;
45        self.y_prev = self.y.clone();
46        self.status = Status::Initialized;
47        self.steps = 0;
48        self.history = VecDeque::new();
49
50        // Delay buffers
51        let mut delays = [T::zero(); L];
52        let mut y_delayed = core::array::from_fn(|_| y0.zeros_like());
53
54        // Evaluate initial delays and history
55        dde.lags(self.t, &self.y, &mut delays);
56        for i in 0..L {
57            let t_delayed = self.t - delays[i];
58            // Ensure delayed time is within history range
59            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
60                return Err(Error::BadInput {
61                    msg: format!(
62                        "Initial delayed time {} is out of history range (t <= {}).",
63                        t_delayed, t0
64                    ),
65                });
66            }
67            y_delayed[i] = phi(t_delayed);
68        }
69
70        // Initial derivative
71        dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
72        evals.function += 1;
73        self.dydt_prev = self.dydt.clone(); // Store initial state in history
74        self.history
75            .push_back((self.t, self.y.clone(), self.dydt.clone()));
76
77        // Initial step size
78        if self.h0 == T::zero() {
79            let duration = (tf - t0).abs();
80            let default_steps = T::from_usize(100).unwrap();
81            self.h0 = duration / default_steps;
82        }
83
84        // Validate and set initial step size h
85        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
86            Ok(h0) => self.h = h0,
87            Err(status) => return Err(status),
88        }
89        Ok(evals)
90    }
91
92    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
93    where
94        F: DDE<L, T, Y> + ?Sized,
95    {
96        let mut evals = Evals::new();
97
98        // Check maximum number of steps
99        if self.steps >= self.max_steps {
100            self.status = Status::Error(Error::MaxSteps {
101                t: self.t,
102                y: self.y.clone(),
103            });
104            return Err(Error::MaxSteps {
105                t: self.t,
106                y: self.y.clone(),
107            });
108        }
109        self.steps += 1;
110
111        // Step buffers
112        let mut delays = [T::zero(); L];
113        let mut y_delayed = core::array::from_fn(|_| self.y.zeros_like());
114
115        // Store current derivative as k[0] for RK computations
116        // Seed k[0] with current derivative
117        self.k[0] = self.dydt.clone();
118        let mut min_delay_abs = T::infinity();
119        // Predict y(t+h) to estimate delays at t+h
120        let y_pred_for_lags = self.y.plus_scaled(self.h, &self.k[0]);
121        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
122        for i in 0..L {
123            min_delay_abs = min_delay_abs.min(delays[i].abs());
124        }
125
126        // Delay iteration count
127        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
128            5
129        } else {
130            1
131        };
132
133        let mut y_next_candidate_iter = self.y.clone(); // Approximated y at t+h, refined in DDE iterations
134        let mut dydt_next_candidate_iter = self.y.zeros_like(); // Derivative at t+h using y_next_candidate_iter
135        let mut y_prev_candidate_iter = self.y.clone(); // y_next_candidate_iter from previous DDE iteration
136        let mut dde_iteration_failed = false;
137
138        // DDE iteration loop
139        for iter_idx in 0..max_iter {
140            if iter_idx > 0 {
141                y_prev_candidate_iter = y_next_candidate_iter.clone();
142            }
143
144            // Compute stages
145            for i in 1..self.stages {
146                let mut y_stage = self.y.clone();
147                for j in 0..i {
148                    y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
149                }
150                // Delayed states for this stage
151                dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
152                if let Err(e) =
153                    self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
154                {
155                    self.status = Status::Error(e.clone());
156                    return Err(e);
157                }
158                dde.diff(
159                    self.t + self.c[i] * self.h,
160                    &y_stage,
161                    &y_delayed,
162                    &mut self.k[i],
163                );
164            }
165            evals.function += self.stages - 1;
166
167            // Combine stages
168            let mut y_next = self.y.clone();
169            for i in 0..self.stages {
170                y_next.add_scaled(self.b[i] * self.h, &self.k[i]);
171            }
172
173            // Convergence check (if iterating)
174            if max_iter > 1 && iter_idx > 0 {
175                let n_dim = self.y.len();
176                let mut dde_iteration_error = T::zero();
177                for i_dim in 0..n_dim {
178                    let scale = T::from_f64(1e-10).unwrap()
179                        + y_prev_candidate_iter
180                            .get_component(i_dim)
181                            .abs()
182                            .max(y_next.get_component(i_dim).abs());
183                    if scale > T::zero() {
184                        let diff_val = y_next.get_component(i_dim)
185                            - y_prev_candidate_iter.get_component(i_dim);
186                        let val = diff_val / scale;
187                        dde_iteration_error += val * val;
188                    }
189                }
190                if n_dim > 0 {
191                    dde_iteration_error =
192                        (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
193                }
194
195                if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
196                    break;
197                }
198                if iter_idx == max_iter - 1 {
199                    dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
200                }
201            }
202            y_next_candidate_iter = y_next.clone();
203
204            // Derivative at t+h for current candidate
205            dde.lags(self.t + self.h, &y_next_candidate_iter, &mut delays);
206            if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
207                self.status = Status::Error(e.clone());
208                return Err(e);
209            }
210            dde.diff(
211                self.t + self.h,
212                &y_next_candidate_iter,
213                &y_delayed,
214                &mut dydt_next_candidate_iter,
215            );
216            evals.function += 1;
217        }
218
219        // Iteration failed: reduce h and retry
220        if dde_iteration_failed {
221            let sign = self.h.signum();
222            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
223            if L > 0
224                && min_delay_abs > T::zero()
225                && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
226            {
227                self.h = min_delay_abs * sign;
228            }
229            self.status = Status::RejectedStep;
230            return Ok(evals);
231        }
232
233        // Store current state before update for interpolation
234        self.t_prev = self.t;
235        self.y_prev = self.y.clone();
236        self.dydt_prev = self.dydt.clone();
237
238        // Advance state
239        self.t += self.h;
240        self.y = y_next_candidate_iter;
241
242        // Derivative for next step
243        if self.fsal {
244            self.dydt = self.k[S - 1].clone();
245        } else {
246            dde.lags(self.t, &self.y, &mut delays);
247            if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
248                self.status = Status::Error(e.clone());
249                return Err(e);
250            }
251            dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
252            evals.function += 1;
253        }
254
255        // Dense output stages
256        if self.bi.is_some() {
257            for i in 0..(I - S) {
258                let mut y_stage_dense = self.y_prev.clone();
259                for j in 0..self.stages + i {
260                    y_stage_dense.add_scaled(self.a[self.stages + i][j] * self.h, &self.k[j]);
261                }
262                let t_stage = self.t_prev + self.c[self.stages + i] * self.h;
263                dde.lags(t_stage, &y_stage_dense, &mut delays);
264                if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
265                    self.status = Status::Error(e.clone());
266                    return Err(e);
267                }
268                dde.diff(
269                    self.t_prev + self.c[self.stages + i] * self.h,
270                    &y_stage_dense,
271                    &y_delayed,
272                    &mut self.k[self.stages + i],
273                );
274            }
275            evals.function += I - S;
276        }
277
278        // Append to history and prune
279        self.history
280            .push_back((self.t, self.y.clone(), self.dydt.clone()));
281        if let Some(max_delay) = self.max_delay {
282            let cutoff_time = self.t - max_delay;
283            while let Some((t_front, _, _)) = self.history.get(1) {
284                if *t_front < cutoff_time {
285                    self.history.pop_front();
286                } else {
287                    break;
288                }
289            }
290        }
291
292        self.status = Status::Solving;
293        Ok(evals)
294    }
295
296    fn t(&self) -> T {
297        self.t
298    }
299    fn y(&self) -> &Y {
300        &self.y
301    }
302    fn t_prev(&self) -> T {
303        self.t_prev
304    }
305    fn y_prev(&self) -> &Y {
306        &self.y_prev
307    }
308    fn h(&self) -> T {
309        self.h
310    }
311    fn set_h(&mut self, h: T) {
312        self.h = h;
313    }
314    fn status(&self) -> &Status<T, Y> {
315        &self.status
316    }
317    fn set_status(&mut self, status: Status<T, Y>) {
318        self.status = status;
319    }
320}
321
322impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
323    ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
324{
325    pub fn lagvals<const L: usize, H>(
326        &mut self,
327        t_stage: T,
328        delays: &[T; L],
329        y_delayed: &mut [Y; L],
330        phi: &H,
331    ) -> Result<(), Error<T, Y>>
332    where
333        H: Fn(T) -> Y,
334    {
335        for i in 0..L {
336            let t_delayed = t_stage - delays[i];
337
338            // Check if delayed time falls within the history period (t_delayed <= t0)
339            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
340                y_delayed[i] = phi(t_delayed);
341            // If t_delayed is after t_prev then use interpolation function
342            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
343                if let Some(bi_coeffs) = self.bi.as_ref() {
344                    let s = (t_delayed - self.t_prev) / self.h_prev;
345
346                    let mut cont = [T::zero(); I];
347                    for i in 0..I {
348                        if i < cont.len() && i < bi_coeffs.len() {
349                            cont[i] = bi_coeffs[i][self.dense_stages - 1];
350                            for j in (0..self.dense_stages - 1).rev() {
351                                cont[i] = cont[i] * s + bi_coeffs[i][j];
352                            }
353                            cont[i] *= s;
354                        }
355                    }
356
357                    let mut y_interp = self.y_prev.clone();
358                    for i in 0..I {
359                        if i < self.k.len() && i < cont.len() {
360                            y_interp.add_scaled(cont[i] * self.h_prev, &self.k[i]);
361                        }
362                    }
363                    y_delayed[i] = y_interp;
364                } else {
365                    y_delayed[i] = cubic_hermite_interpolate(
366                        self.t_prev,
367                        self.t,
368                        &self.y_prev,
369                        &self.y,
370                        &self.dydt_prev,
371                        &self.dydt,
372                        t_delayed,
373                    );
374                } // If t_delayed is before t_prev and after t0, we need to search in the history
375            } else {
376                // Search through history to find appropriate interpolation points
377                let mut found_interpolation = false;
378                let buffer = &self.history;
379                // Find two consecutive points that sandwich t_delayed using iterators
380                let mut buffer_iter = buffer.iter();
381                if let Some(mut prev_entry) = buffer_iter.next() {
382                    for curr_entry in buffer_iter {
383                        let (t_left, y_left, dydt_left) = prev_entry;
384                        let (t_right, y_right, dydt_right) = curr_entry;
385
386                        // Check if t_delayed is between these two points
387                        let is_between = if self.h.signum() > T::zero() {
388                            // Forward integration: t_left <= t_delayed <= t_right
389                            *t_left <= t_delayed && t_delayed <= *t_right
390                        } else {
391                            // Backward integration: t_right <= t_delayed <= t_left
392                            *t_right <= t_delayed && t_delayed <= *t_left
393                        };
394
395                        if is_between {
396                            // Use cubic Hermite interpolation between these points
397                            y_delayed[i] = cubic_hermite_interpolate(
398                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
399                                t_delayed,
400                            );
401                            found_interpolation = true;
402                            break;
403                        }
404                        prev_entry = curr_entry;
405                    }
406                } // If not found in history, this indicates insufficient history in buffer
407                if !found_interpolation {
408                    return Err(Error::InsufficientHistory {
409                        t_delayed,
410                        t_prev: self.t_prev,
411                        t_curr: self.t,
412                    });
413                }
414            }
415        }
416        Ok(())
417    }
418}
419
420impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
421    for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
422{
423    /// Interpolates the solution at time `t_interp` within the last accepted step.
424    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
425        let dir = self.h.signum();
426        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
427            return Err(Error::OutOfBounds {
428                t_interp,
429                t_prev: self.t_prev,
430                t_curr: self.t,
431            });
432        }
433
434        // If method has dense output coefficients, use them
435        if let Some(bi) = self.bi.as_ref() {
436            let s = (t_interp - self.t_prev) / self.h_prev;
437
438            let mut cont = [T::zero(); I];
439            for i in 0..self.dense_stages {
440                cont[i] = bi[i][self.order - 1];
441                for j in (0..self.order - 1).rev() {
442                    cont[i] = cont[i] * s + bi[i][j];
443                }
444                cont[i] *= s;
445            }
446
447            let mut y_interp = self.y_prev.clone();
448            for i in 0..I {
449                y_interp.add_scaled(cont[i] * self.h_prev, &self.k[i]);
450            }
451
452            Ok(y_interp)
453        } else {
454            // Otherwise use cubic Hermite interpolation
455            let y_interp = cubic_hermite_interpolate(
456                self.t_prev,
457                self.t,
458                &self.y_prev,
459                &self.y,
460                &self.dydt_prev,
461                &self.dydt,
462                t_interp,
463            );
464
465            Ok(y_interp)
466        }
467    }
468}