differential_equations/methods/erk/fixed/
delay.rs

//! Fixed-step explicit Runge–Kutta methods for Delay Differential Equations (DDEs)
2
3use std::collections::VecDeque;
4
5use crate::{
6    dde::{DDE, DelayNumericalMethod},
7    error::Error,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    methods::{Delay, ExplicitRungeKutta, Fixed},
10    stats::Evals,
11    status::Status,
12    traits::{Real, State},
13    utils::validate_step_size_parameters,
14};
15
16impl<
17    const L: usize,
18    T: Real,
19    Y: State<T>,
20    H: Fn(T) -> Y,
21    const O: usize,
22    const S: usize,
23    const I: usize,
24> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
25{
26    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
27    where
28        F: DDE<L, T, Y>,
29    {
30        // Initialize solver state
31        let mut evals = Evals::new();
32
33        // DDE requires at least one lag
34        if L <= 0 {
35            return Err(Error::NoLags);
36        }
37        self.t0 = t0;
38        self.t = t0;
39        self.y = *y0;
40        self.t_prev = self.t;
41        self.y_prev = self.y;
42        self.status = Status::Initialized;
43        self.steps = 0;
44        self.history = VecDeque::new();
45
46        // Delay buffers
47        let mut delays = [T::zero(); L];
48        let mut y_delayed = [Y::zeros(); L];
49
50        // Evaluate initial delays and history
51        dde.lags(self.t, &self.y, &mut delays);
52        for i in 0..L {
53            let t_delayed = self.t - delays[i];
54            // Ensure delayed time is within history range
55            if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
56                return Err(Error::BadInput {
57                    msg: format!(
58                        "Initial delayed time {} is out of history range (t <= {}).",
59                        t_delayed, t0
60                    ),
61                });
62            }
63            y_delayed[i] = phi(t_delayed);
64        }
65
66        // Initial derivative
67        dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
68        evals.function += 1;
69        self.dydt_prev = self.dydt; // Store initial state in history
70        self.history.push_back((self.t, self.y, self.dydt));
71
72        // Initial step size
73        if self.h0 == T::zero() {
74            let duration = (tf - t0).abs();
75            let default_steps = T::from_usize(100).unwrap();
76            self.h0 = duration / default_steps;
77        }
78
79        // Validate and set initial step size h
80        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
81            Ok(h0) => self.h = h0,
82            Err(status) => return Err(status),
83        }
84        Ok(evals)
85    }
86
87    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
88    where
89        F: DDE<L, T, Y>,
90    {
91        let mut evals = Evals::new();
92
93        // Check maximum number of steps
94        if self.steps >= self.max_steps {
95            self.status = Status::Error(Error::MaxSteps {
96                t: self.t,
97                y: self.y,
98            });
99            return Err(Error::MaxSteps {
100                t: self.t,
101                y: self.y,
102            });
103        }
104        self.steps += 1;
105
106        // Step buffers
107        let mut delays = [T::zero(); L];
108        let mut y_delayed = [Y::zeros(); L];
109
110        // Store current derivative as k[0] for RK computations
111        // Seed k[0] with current derivative
112        self.k[0] = self.dydt;
113        let mut min_delay_abs = T::infinity();
114        // Predict y(t+h) to estimate delays at t+h
115        let y_pred_for_lags = self.y + self.k[0] * self.h;
116        dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
117        for i in 0..L {
118            min_delay_abs = min_delay_abs.min(delays[i].abs());
119        }
120
121        // Delay iteration count
122        let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
123            5
124        } else {
125            1
126        };
127
128        let mut y_next_candidate_iter = self.y; // Approximated y at t+h, refined in DDE iterations
129        let mut dydt_next_candidate_iter = Y::zeros(); // Derivative at t+h using y_next_candidate_iter
130        let mut y_prev_candidate_iter = self.y; // y_next_candidate_iter from previous DDE iteration
131        let mut dde_iteration_failed = false;
132
133        // DDE iteration loop
134        for iter_idx in 0..max_iter {
135            if iter_idx > 0 {
136                y_prev_candidate_iter = y_next_candidate_iter;
137            }
138
139            // Compute stages
140            for i in 1..self.stages {
141                let mut y_stage = self.y;
142                for j in 0..i {
143                    y_stage += self.k[j] * (self.a[i][j] * self.h);
144                }
145                // Delayed states for this stage
146                dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
147                if let Err(e) =
148                    self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
149                {
150                    self.status = Status::Error(e.clone());
151                    return Err(e);
152                }
153                dde.diff(
154                    self.t + self.c[i] * self.h,
155                    &y_stage,
156                    &y_delayed,
157                    &mut self.k[i],
158                );
159            }
160            evals.function += self.stages - 1;
161
162            // Combine stages
163            let mut y_next = self.y;
164            for i in 0..self.stages {
165                y_next += self.k[i] * (self.b[i] * self.h);
166            }
167
168            // Convergence check (if iterating)
169            if max_iter > 1 && iter_idx > 0 {
170                let mut dde_iteration_error = T::zero();
171                let n_dim = self.y.len();
172                for i_dim in 0..n_dim {
173                    let scale = T::from_f64(1e-10).unwrap()
174                        + y_prev_candidate_iter
175                            .get(i_dim)
176                            .abs()
177                            .max(y_next.get(i_dim).abs());
178                    if scale > T::zero() {
179                        let diff_val = y_next.get(i_dim) - y_prev_candidate_iter.get(i_dim);
180                        dde_iteration_error += (diff_val / scale).powi(2);
181                    }
182                }
183                if n_dim > 0 {
184                    dde_iteration_error =
185                        (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
186                }
187
188                if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
189                    break;
190                }
191                if iter_idx == max_iter - 1 {
192                    dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
193                }
194            }
195            y_next_candidate_iter = y_next;
196
197            // Derivative at t+h for current candidate
198            dde.lags(self.t + self.h, &y_next_candidate_iter, &mut delays);
199            if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
200                self.status = Status::Error(e.clone());
201                return Err(e);
202            }
203            dde.diff(
204                self.t + self.h,
205                &y_next_candidate_iter,
206                &y_delayed,
207                &mut dydt_next_candidate_iter,
208            );
209            evals.function += 1;
210        }
211
212        // Iteration failed: reduce h and retry
213        if dde_iteration_failed {
214            let sign = self.h.signum();
215            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
216            if L > 0
217                && min_delay_abs > T::zero()
218                && self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
219            {
220                self.h = min_delay_abs * sign;
221            }
222            self.status = Status::RejectedStep;
223            return Ok(evals);
224        }
225
226        // Store current state before update for interpolation
227        self.t_prev = self.t;
228        self.y_prev = self.y;
229        self.dydt_prev = self.dydt;
230
231        // Advance state
232        self.t += self.h;
233        self.y = y_next_candidate_iter;
234
235        // Derivative for next step
236        if self.fsal {
237            self.dydt = self.k[S - 1];
238        } else {
239            dde.lags(self.t, &self.y, &mut delays);
240            if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
241                self.status = Status::Error(e.clone());
242                return Err(e);
243            }
244            dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
245            evals.function += 1;
246        }
247
248        // Dense output stages
249        if self.bi.is_some() {
250            for i in 0..(I - S) {
251                let mut y_stage_dense = self.y_prev;
252                for j in 0..self.stages + i {
253                    y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
254                }
255                let t_stage = self.t_prev + self.c[self.stages + i] * self.h;
256                dde.lags(t_stage, &y_stage_dense, &mut delays);
257                if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
258                    self.status = Status::Error(e.clone());
259                    return Err(e);
260                }
261                dde.diff(
262                    self.t_prev + self.c[self.stages + i] * self.h,
263                    &y_stage_dense,
264                    &y_delayed,
265                    &mut self.k[self.stages + i],
266                );
267            }
268            evals.function += I - S;
269        }
270
271        // Append to history and prune
272        self.history.push_back((self.t, self.y, self.dydt));
273        if let Some(max_delay) = self.max_delay {
274            let cutoff_time = self.t - max_delay;
275            while let Some((t_front, _, _)) = self.history.get(1) {
276                if *t_front < cutoff_time {
277                    self.history.pop_front();
278                } else {
279                    break;
280                }
281            }
282        }
283
284        self.status = Status::Solving;
285        Ok(evals)
286    }
287
288    fn t(&self) -> T {
289        self.t
290    }
291    fn y(&self) -> &Y {
292        &self.y
293    }
294    fn t_prev(&self) -> T {
295        self.t_prev
296    }
297    fn y_prev(&self) -> &Y {
298        &self.y_prev
299    }
300    fn h(&self) -> T {
301        self.h
302    }
303    fn set_h(&mut self, h: T) {
304        self.h = h;
305    }
306    fn status(&self) -> &Status<T, Y> {
307        &self.status
308    }
309    fn set_status(&mut self, status: Status<T, Y>) {
310        self.status = status;
311    }
312}
313
314impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
315    ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
316{
317    pub fn lagvals<const L: usize, H>(
318        &mut self,
319        t_stage: T,
320        delays: &[T; L],
321        y_delayed: &mut [Y; L],
322        phi: &H,
323    ) -> Result<(), Error<T, Y>>
324    where
325        H: Fn(T) -> Y,
326    {
327        for i in 0..L {
328            let t_delayed = t_stage - delays[i];
329
330            // Check if delayed time falls within the history period (t_delayed <= t0)
331            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
332                y_delayed[i] = phi(t_delayed);
333            // If t_delayed is after t_prev then use interpolation function
334            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
335                if self.bi.is_some() {
336                    let s = (t_delayed - self.t_prev) / self.h_prev;
337
338                    let bi_coeffs = self.bi.as_ref().unwrap();
339
340                    let mut cont = [T::zero(); I];
341                    for i in 0..I {
342                        if i < cont.len() && i < bi_coeffs.len() {
343                            cont[i] = bi_coeffs[i][self.dense_stages - 1];
344                            for j in (0..self.dense_stages - 1).rev() {
345                                cont[i] = cont[i] * s + bi_coeffs[i][j];
346                            }
347                            cont[i] *= s;
348                        }
349                    }
350
351                    let mut y_interp = self.y_prev;
352                    for i in 0..I {
353                        if i < self.k.len() && i < cont.len() {
354                            y_interp += self.k[i] * (cont[i] * self.h_prev);
355                        }
356                    }
357                    y_delayed[i] = y_interp;
358                } else {
359                    y_delayed[i] = cubic_hermite_interpolate(
360                        self.t_prev,
361                        self.t,
362                        &self.y_prev,
363                        &self.y,
364                        &self.dydt_prev,
365                        &self.dydt,
366                        t_delayed,
367                    );
368                } // If t_delayed is before t_prev and after t0, we need to search in the history
369            } else {
370                // Search through history to find appropriate interpolation points
371                let mut found_interpolation = false;
372                let buffer = &self.history;
373                // Find two consecutive points that sandwich t_delayed using iterators
374                let mut buffer_iter = buffer.iter();
375                if let Some(mut prev_entry) = buffer_iter.next() {
376                    for curr_entry in buffer_iter {
377                        let (t_left, y_left, dydt_left) = prev_entry;
378                        let (t_right, y_right, dydt_right) = curr_entry;
379
380                        // Check if t_delayed is between these two points
381                        let is_between = if self.h.signum() > T::zero() {
382                            // Forward integration: t_left <= t_delayed <= t_right
383                            *t_left <= t_delayed && t_delayed <= *t_right
384                        } else {
385                            // Backward integration: t_right <= t_delayed <= t_left
386                            *t_right <= t_delayed && t_delayed <= *t_left
387                        };
388
389                        if is_between {
390                            // Use cubic Hermite interpolation between these points
391                            y_delayed[i] = cubic_hermite_interpolate(
392                                *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
393                                t_delayed,
394                            );
395                            found_interpolation = true;
396                            break;
397                        }
398                        prev_entry = curr_entry;
399                    }
400                } // If not found in history, this indicates insufficient history in buffer
401                if !found_interpolation {
402                    return Err(Error::InsufficientHistory {
403                        t_delayed,
404                        t_prev: self.t_prev,
405                        t_curr: self.t,
406                    });
407                }
408            }
409        }
410        Ok(())
411    }
412}
413
414impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
415    for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
416{
417    /// Interpolates the solution at time `t_interp` within the last accepted step.
418    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
419        let dir = self.h.signum();
420        if (t_interp - self.t_prev) * dir < T::zero() || (t_interp - self.t) * dir > T::zero() {
421            return Err(Error::OutOfBounds {
422                t_interp,
423                t_prev: self.t_prev,
424                t_curr: self.t,
425            });
426        }
427
428        // If method has dense output coefficients, use them
429        if self.bi.is_some() {
430            let s = (t_interp - self.t_prev) / self.h_prev;
431
432            let bi = self.bi.as_ref().unwrap();
433
434            let mut cont = [T::zero(); I];
435            for i in 0..self.dense_stages {
436                cont[i] = bi[i][self.order - 1];
437                for j in (0..self.order - 1).rev() {
438                    cont[i] = cont[i] * s + bi[i][j];
439                }
440                cont[i] *= s;
441            }
442
443            let mut y_interp = self.y_prev;
444            for i in 0..I {
445                y_interp += self.k[i] * cont[i] * self.h_prev;
446            }
447
448            Ok(y_interp)
449        } else {
450            // Otherwise use cubic Hermite interpolation
451            let y_interp = cubic_hermite_interpolate(
452                self.t_prev,
453                self.t,
454                &self.y_prev,
455                &self.y,
456                &self.dydt_prev,
457                &self.dydt,
458                t_interp,
459            );
460
461            Ok(y_interp)
462        }
463    }
464}