differential_equations/methods/erk/fixed/
delay.rs

1//! Fixed-step explicit Runge–Kutta methods for Delay Differential Equations (DDEs)
2
3use std::collections::VecDeque;
4
5use crate::{
6    dde::{DDE, DelayNumericalMethod},
7    error::Error,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    methods::{Delay, ExplicitRungeKutta, Fixed},
10    stats::Evals,
11    status::Status,
12    traits::{CallBackData, Real, State},
13    utils::validate_step_size_parameters,
14};
15
16impl<
17    const L: usize,
18    T: Real,
19    Y: State<T>,
20    H: Fn(T) -> Y,
21    D: CallBackData,
22    const O: usize,
23    const S: usize,
24    const I: usize,
25> DelayNumericalMethod<L, T, Y, H, D> for ExplicitRungeKutta<Delay, Fixed, T, Y, D, O, S, I>
26{
27    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
28    where
29        F: DDE<L, T, Y, D>,
30    {
31        // Initialize solver state
32        let mut evals = Evals::new();
33
34        // DDE requires at least one lag
35        if L <= 0 {
36            return Err(Error::NoLags);
37        }
38        self.t0 = t0;
39        self.t = t0;
40        self.y = *y0;
41        self.t_prev = self.t;
42        self.y_prev = self.y;
43        self.status = Status::Initialized;
44        self.steps = 0;
45        self.history = VecDeque::new();
46
47        // Delay buffers
48        let mut delays = [T::zero(); L];
49        let mut y_delayed = [Y::zeros(); L];
50
51        // Evaluate initial delays and history
52        dde.lags(self.t, &self.y, &mut delays);
53        for i in 0..L {
54            let t_delayed = self.t - delays[i];
55            // Ensure delayed time is within history range
56            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
57                return Err(Error::BadInput {
58                    msg: format!(
59                        "Initial delayed time {} is out of history range (t <= {}).",
60                        t_delayed, t0
61                    ),
62                });
63            }
64            y_delayed[i] = phi(t_delayed);
65        }
66
67        // Initial derivative
68        dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
69        evals.function += 1;
70        self.dydt_prev = self.dydt; // Store initial state in history
71        self.history.push_back((self.t, self.y, self.dydt));
72
73        // Initial step size
74        if self.h0 == T::zero() {
75            let duration = (tf - t0).abs();
76            let default_steps = T::from_usize(100).unwrap();
77            self.h0 = duration / default_steps;
78        }
79
80        // Validate and set initial step size h
81        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
82            Ok(h0) => self.h = h0,
83            Err(status) => return Err(status),
84        }
85        Ok(evals)
86    }
87
88    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
89    where
90        F: DDE<L, T, Y, D>,
91    {
92        let mut evals = Evals::new();
93
94        // Check maximum number of steps
95        if self.steps >= self.max_steps {
96            self.status = Status::Error(Error::MaxSteps {
97                t: self.t,
98                y: self.y,
99            });
100            return Err(Error::MaxSteps {
101                t: self.t,
102                y: self.y,
103            });
104        }
105        self.steps += 1;
106
107        // Step buffers
108        let mut delays = [T::zero(); L];
109        let mut y_delayed = [Y::zeros(); L];
110
111        // Store current derivative as k[0] for RK computations
112        // Seed k[0] with current derivative
113        self.k[0] = self.dydt;
114        let mut min_delay_abs = T::infinity();
115        // Predict y(t+h) to estimate delays at t+h
116        let y_pred_for_lags = self.y + self.k[0] * self.h;
117        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
118        for i in 0..L {
119            min_delay_abs = min_delay_abs.min(delays[i].abs());
120        }
121
122        // Delay iteration count
123        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
124            5
125        } else {
126            1
127        };
128
129        let mut y_next_candidate_iter = self.y; // Approximated y at t+h, refined in DDE iterations
130        let mut dydt_next_candidate_iter = Y::zeros(); // Derivative at t+h using y_next_candidate_iter
131        let mut y_prev_candidate_iter = self.y; // y_next_candidate_iter from previous DDE iteration
132        let mut dde_iteration_failed = false;
133
134        // DDE iteration loop
135        for iter_idx in 0..max_iter {
136            if iter_idx > 0 {
137                y_prev_candidate_iter = y_next_candidate_iter;
138            }
139
140            // Compute stages
141            for i in 1..self.stages {
142                let mut y_stage = self.y;
143                for j in 0..i {
144                    y_stage += self.k[j] * (self.a[i][j] * self.h);
145                }
146                // Delayed states for this stage
147                dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
148                if let Err(e) =
149                    self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
150                {
151                    self.status = Status::Error(e.clone());
152                    return Err(e);
153                }
154                dde.diff(
155                    self.t + self.c[i] * self.h,
156                    &y_stage,
157                    &y_delayed,
158                    &mut self.k[i],
159                );
160            }
161            evals.function += self.stages - 1;
162
163            // Combine stages
164            let mut y_next = self.y;
165            for i in 0..self.stages {
166                y_next += self.k[i] * (self.b[i] * self.h);
167            }
168
169            // Convergence check (if iterating)
170            if max_iter > 1 && iter_idx > 0 {
171                let mut dde_iteration_error = T::zero();
172                let n_dim = self.y.len();
173                for i_dim in 0..n_dim {
174                    let scale = T::from_f64(1e-10).unwrap()
175                        + y_prev_candidate_iter
176                            .get(i_dim)
177                            .abs()
178                            .max(y_next.get(i_dim).abs());
179                    if scale > T::zero() {
180                        let diff_val = y_next.get(i_dim) - y_prev_candidate_iter.get(i_dim);
181                        dde_iteration_error += (diff_val / scale).powi(2);
182                    }
183                }
184                if n_dim > 0 {
185                    dde_iteration_error =
186                        (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
187                }
188
189                if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
190                    break;
191                }
192                if iter_idx == max_iter - 1 {
193                    dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
194                }
195            }
196            y_next_candidate_iter = y_next;
197
198            // Derivative at t+h for current candidate
199            dde.lags(self.t + self.h, &y_next_candidate_iter, &mut delays);
200            if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
201                self.status = Status::Error(e.clone());
202                return Err(e);
203            }
204            dde.diff(
205                self.t + self.h,
206                &y_next_candidate_iter,
207                &y_delayed,
208                &mut dydt_next_candidate_iter,
209            );
210            evals.function += 1;
211        }
212
213        // Iteration failed: reduce h and retry
214        if dde_iteration_failed {
215            let sign = self.h.signum();
216            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
217            if L > 0
218                && min_delay_abs > T::zero()
219                && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
220            {
221                self.h = min_delay_abs * sign;
222            }
223            self.status = Status::RejectedStep;
224            return Ok(evals);
225        }
226
227        // Store current state before update for interpolation
228        self.t_prev = self.t;
229        self.y_prev = self.y;
230        self.dydt_prev = self.dydt;
231
232        // Advance state
233        self.t += self.h;
234        self.y = y_next_candidate_iter;
235
236        // Derivative for next step
237        if self.fsal {
238            self.dydt = self.k[S - 1];
239        } else {
240            dde.lags(self.t, &self.y, &mut delays);
241            if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
242                self.status = Status::Error(e.clone());
243                return Err(e);
244            }
245            dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
246            evals.function += 1;
247        }
248
249        // Dense output stages
250        if self.bi.is_some() {
251            for i in 0..(I - S) {
252                let mut y_stage_dense = self.y_prev;
253                for j in 0..self.stages + i {
254                    y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
255                }
256                let t_stage = self.t_prev + self.c[self.stages + i] * self.h;
257                dde.lags(t_stage, &y_stage_dense, &mut delays);
258                if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
259                    self.status = Status::Error(e.clone());
260                    return Err(e);
261                }
262                dde.diff(
263                    self.t_prev + self.c[self.stages + i] * self.h,
264                    &y_stage_dense,
265                    &y_delayed,
266                    &mut self.k[self.stages + i],
267                );
268            }
269            evals.function += I - S;
270        }
271
272        // Append to history and prune
273        self.history.push_back((self.t, self.y, self.dydt));
274        if let Some(max_delay) = self.max_delay {
275            let cutoff_time = self.t - max_delay;
276            while let Some((t_front, _, _)) = self.history.get(1) {
277                if *t_front < cutoff_time {
278                    self.history.pop_front();
279                } else {
280                    break;
281                }
282            }
283        }
284
285        self.status = Status::Solving;
286        Ok(evals)
287    }
288
289    fn t(&self) -> T {
290        self.t
291    }
292    fn y(&self) -> &Y {
293        &self.y
294    }
295    fn t_prev(&self) -> T {
296        self.t_prev
297    }
298    fn y_prev(&self) -> &Y {
299        &self.y_prev
300    }
301    fn h(&self) -> T {
302        self.h
303    }
304    fn set_h(&mut self, h: T) {
305        self.h = h;
306    }
307    fn status(&self) -> &Status<T, Y, D> {
308        &self.status
309    }
310    fn set_status(&mut self, status: Status<T, Y, D>) {
311        self.status = status;
312    }
313}
314
315impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
316    ExplicitRungeKutta<Delay, Fixed, T, Y, D, O, S, I>
317{
318    pub fn lagvals<const L: usize, H>(
319        &mut self,
320        t_stage: T,
321        delays: &[T; L],
322        y_delayed: &mut [Y; L],
323        phi: &H,
324    ) -> Result<(), Error<T, Y>>
325    where
326        H: Fn(T) -> Y,
327    {
328        for i in 0..L {
329            let t_delayed = t_stage - delays[i];
330
331            // Check if delayed time falls within the history period (t_delayed <= t0)
332            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
333                y_delayed[i] = phi(t_delayed);
334            // If t_delayed is after t_prev then use interpolation function
335            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
336                if self.bi.is_some() {
337                    let s = (t_delayed - self.t_prev) / self.h_prev;
338
339                    let bi_coeffs = self.bi.as_ref().unwrap();
340
341                    let mut cont = [T::zero(); I];
342                    for i in 0..I {
343                        if i < cont.len() && i < bi_coeffs.len() {
344                            cont[i] = bi_coeffs[i][self.dense_stages - 1];
345                            for j in (0..self.dense_stages - 1).rev() {
346                                cont[i] = cont[i] * s + bi_coeffs[i][j];
347                            }
348                            cont[i] *= s;
349                        }
350                    }
351
352                    let mut y_interp = self.y_prev;
353                    for i in 0..I {
354                        if i < self.k.len() && i < cont.len() {
355                            y_interp += self.k[i] * (cont[i] * self.h_prev);
356                        }
357                    }
358                    y_delayed[i] = y_interp;
359                } else {
360                    y_delayed[i] = cubic_hermite_interpolate(
361                        self.t_prev,
362                        self.t,
363                        &self.y_prev,
364                        &self.y,
365                        &self.dydt_prev,
366                        &self.dydt,
367                        t_delayed,
368                    );
369                } // If t_delayed is before t_prev and after t0, we need to search in the history
370            } else {
371                // Search through history to find appropriate interpolation points
372                let mut found_interpolation = false;
373                let buffer = &self.history;
374                // Find two consecutive points that sandwich t_delayed using iterators
375                let mut buffer_iter = buffer.iter();
376                if let Some(mut prev_entry) = buffer_iter.next() {
377                    for curr_entry in buffer_iter {
378                        let (t_left, y_left, dydt_left) = prev_entry;
379                        let (t_right, y_right, dydt_right) = curr_entry;
380
381                        // Check if t_delayed is between these two points
382                        let is_between = if self.h.signum() > T::zero() {
383                            // Forward integration: t_left <= t_delayed <= t_right
384                            *t_left <= t_delayed && t_delayed <= *t_right
385                        } else {
386                            // Backward integration: t_right <= t_delayed <= t_left
387                            *t_right <= t_delayed && t_delayed <= *t_left
388                        };
389
390                        if is_between {
391                            // Use cubic Hermite interpolation between these points
392                            y_delayed[i] = cubic_hermite_interpolate(
393                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
394                                t_delayed,
395                            );
396                            found_interpolation = true;
397                            break;
398                        }
399                        prev_entry = curr_entry;
400                    }
401                } // If not found in history, this indicates insufficient history in buffer
402                if !found_interpolation {
403                    return Err(Error::InsufficientHistory {
404                        t_delayed,
405                        t_prev: self.t_prev,
406                        t_curr: self.t,
407                    });
408                }
409            }
410        }
411        Ok(())
412    }
413}
414
415impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
416    Interpolation<T, Y> for ExplicitRungeKutta<Delay, Fixed, T, Y, D, O, S, I>
417{
418    /// Interpolates the solution at time `t_interp` within the last accepted step.
419    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
420        let dir = self.h.signum();
421        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
422            return Err(Error::OutOfBounds {
423                t_interp,
424                t_prev: self.t_prev,
425                t_curr: self.t,
426            });
427        }
428
429        // If method has dense output coefficients, use them
430        if self.bi.is_some() {
431            let s = (t_interp - self.t_prev) / self.h_prev;
432
433            let bi = self.bi.as_ref().unwrap();
434
435            let mut cont = [T::zero(); I];
436            for i in 0..self.dense_stages {
437                cont[i] = bi[i][self.order - 1];
438                for j in (0..self.order - 1).rev() {
439                    cont[i] = cont[i] * s + bi[i][j];
440                }
441                cont[i] *= s;
442            }
443
444            let mut y_interp = self.y_prev;
445            for i in 0..I {
446                y_interp += self.k[i] * cont[i] * self.h_prev;
447            }
448
449            Ok(y_interp)
450        } else {
451            // Otherwise use cubic Hermite interpolation
452            let y_interp = cubic_hermite_interpolate(
453                self.t_prev,
454                self.t,
455                &self.y_prev,
456                &self.y,
457                &self.dydt_prev,
458                &self.dydt,
459                t_interp,
460            );
461
462            Ok(y_interp)
463        }
464    }
465}