differential_equations/methods/erk/fixed/
delay.rs

//! Fixed-step explicit Runge-Kutta methods for delay differential equations (DDEs).
2
3use super::{ExplicitRungeKutta, Delay, Fixed};
4use crate::{
5    Error, Status,
6    alias::Evals,
7    interpolate::{Interpolation, cubic_hermite_interpolate},
8    dde::{DelayNumericalMethod, DDE},
9    traits::{CallBackData, Real, State},
10    utils::validate_step_size_parameters,
11};
12use std::collections::VecDeque;
13
14impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, Fixed, T, V, D, O, S, I> {
15    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
16    where
17        F: DDE<L, T, V, D>,
18    {        
19        // Initialize solver state
20        let mut evals = Evals::new();
21        self.t0 = t0;
22        self.t = t0;
23        self.y = *y0;
24        self.t_prev = self.t;
25        self.y_prev = self.y;
26        self.status = Status::Initialized;
27        self.steps = 0;
28        self.history = VecDeque::new();
29
30        // Initialize arrays for lags and delayed states
31        let mut lags = [T::zero(); L];
32        let mut yd = [V::zeros(); L];
33
34        // Evaluate initial lags and delayed states
35        if L > 0 {
36            dde.lags(self.t, &self.y, &mut lags);
37            for i in 0..L {
38                if lags[i] <= T::zero() {
39                    return Err(Error::BadInput {
40                        msg: "All lags must be positive.".to_string(),
41                    });
42                }
43                let t_delayed = self.t - lags[i];
44                // Ensure delayed time is within history range (t_delayed <= t0)
45                if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
46                    return Err(Error::BadInput {
47                        msg: format!(
48                            "Initial delayed time {} is out of history range (t <= {}).",
49                            t_delayed, t0
50                        ),
51                    });
52                }
53                yd[i] = phi(t_delayed);
54            }
55        }
56
57        // Calculate initial derivative
58        dde.diff(self.t, &self.y, &yd, &mut self.dydt);
59        evals.fcn += 1;
60        self.dydt_prev = self.dydt;        // Store initial state in history
61        self.history.push_back((self.t, self.y, self.dydt));
62
63        // Calculate initial step size h0 if not provided
64        if self.h0 == T::zero() {
65            // Simple default step size for fixed-step methods
66            let duration = (tf - t0).abs();
67            let default_steps = T::from_usize(100).unwrap();
68            self.h0 = duration / default_steps;
69        }
70
71        // Validate and set initial step size h
72        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
73            Ok(h0) => self.h = h0,
74            Err(status) => return Err(status),
75        }
76        Ok(evals)
77    }
78
79    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
80    where
81        F: DDE<L, T, V, D>,
82    {
83        let mut evals = Evals::new();
84
85        // Check maximum number of steps
86        if self.steps >= self.max_steps {
87            self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
88            return Err(Error::MaxSteps { t: self.t, y: self.y });
89        }
90        self.steps += 1;
91
92        // Initialize variables for the step
93        let mut lags = [T::zero(); L];
94        let mut yd = [V::zeros(); L];
95
96        // Store current derivative as k[0] for RK computations
97        self.k[0] = self.dydt;        // DDE: Determine if iterative approach for lag handling is needed
98        let mut min_lag_abs = T::infinity();
99        if L > 0 {
100            // Predict y at t+h using Euler step to estimate lags at t+h
101            let y_pred_for_lags = self.y + self.k[0] * self.h;
102            dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
103            for i in 0..L {
104                min_lag_abs = min_lag_abs.min(lags[i].abs());
105            }
106        }
107
108        // If lag values have to be extrapolated, we need to iterate for convergence
109        let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
110            5
111        } else {
112            1
113        };
114
115        let mut y_next_candidate_iter = self.y; // Approximated y at t+h, refined in DDE iterations
116        let mut dydt_next_candidate_iter = V::zeros(); // Derivative at t+h using y_next_candidate_iter
117        let mut y_prev_candidate_iter = self.y; // y_next_candidate_iter from previous DDE iteration
118        let mut dde_iteration_failed = false;
119
120        // DDE iteration loop (for handling implicit lags or just one pass for explicit)
121        for iter_idx in 0..max_iter {
122            if iter_idx > 0 {
123                y_prev_candidate_iter = y_next_candidate_iter;
124            }
125
126            // Compute Runge-Kutta stages
127            for i in 1..self.stages {
128                let mut y_stage = self.y;
129                for j in 0..i {
130                    y_stage += self.k[j] * (self.a[i][j] * self.h);
131                }
132                // Evaluate delayed states for the current stage
133                if L > 0 {
134                    dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
135                    self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
136                }
137                dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
138            }
139            evals.fcn += self.stages - 1; // k[0] was already available
140
141            // Compute solution
142            let mut y_next = self.y;
143            for i in 0..self.stages {
144                y_next += self.k[i] * (self.b[i] * self.h);
145            }
146
147            // DDE iteration convergence check (if max_iter > 1)
148            if max_iter > 1 && iter_idx > 0 {
149                let mut dde_iteration_error = T::zero();
150                let n_dim = self.y.len();
151                for i_dim in 0..n_dim {
152                    let scale = T::from_f64(1e-10).unwrap() + y_prev_candidate_iter.get(i_dim).abs().max(y_next.get(i_dim).abs());
153                    if scale > T::zero() {
154                        let diff_val = y_next.get(i_dim) - y_prev_candidate_iter.get(i_dim);
155                        dde_iteration_error += (diff_val / scale).powi(2);
156                    }
157                }
158                if n_dim > 0 {
159                    dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
160                }
161
162                if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
163                    break; // DDE iteration converged
164                }
165                if iter_idx == max_iter - 1 { // Last iteration
166                    dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
167                }
168            }
169            y_next_candidate_iter = y_next; // Update candidate solution for t+h
170
171            // Compute derivative at t+h with the current candidate y_next_candidate_iter
172            if L > 0 {
173                dde.lags(self.t + self.h, &y_next_candidate_iter, &mut lags);
174                self.lagvals(self.t + self.h, &lags, &mut yd, phi);
175            }
176            dde.diff(self.t + self.h, &y_next_candidate_iter, &yd, &mut dydt_next_candidate_iter);
177            evals.fcn += 1;
178        } // End of DDE iteration loop
179
180        // Handle DDE iteration failure: reduce step size and retry
181        if dde_iteration_failed {
182            let sign = self.h.signum();
183            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
184            // Ensure step size is not smaller than a fraction of the minimum lag, if applicable
185            if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
186                self.h = min_lag_abs * sign; // Or some factor of min_lag_abs
187            }
188            self.status = Status::RejectedStep; // Indicate step rejection due to DDE iteration
189            return Ok(evals); // Return to retry step with smaller h
190        }
191
192        // Store current state before update for interpolation
193        self.t_prev = self.t;
194        self.y_prev = self.y;
195        self.dydt_prev = self.dydt;
196
197        // Update state to t + h
198        self.t += self.h;
199        self.y = y_next_candidate_iter;
200        
201        // Calculate new derivative for next step
202        if self.fsal {
203            // If FSAL (First Same As Last) is enabled, we can reuse the last derivative
204            self.dydt = self.k[S - 1];
205        } else {
206            // Otherwise, compute the new derivative
207            if L > 0 {
208                dde.lags(self.t, &self.y, &mut lags);
209                self.lagvals(self.t, &lags, &mut yd, phi);
210            }
211            dde.diff(self.t, &self.y, &yd, &mut self.dydt);
212            evals.fcn += 1;
213        }
214
215        // Compute additional stages for dense output if available
216        if self.bi.is_some() {
217            for i in 0..(I - S) { // I is total stages, S is main method stages
218                let mut y_stage_dense = self.y_prev; // Use previous state as base
219                // Sum up contributions from previous k values for this dense stage
220                for j in 0..self.stages + i { // self.stages is S
221                    y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
222                }
223                // Evaluate lags and derivative for the dense stage
224                if L > 0 {
225                    dde.lags(self.t_prev + self.c[self.stages + i] * self.h, &y_stage_dense, &mut lags);
226                    self.lagvals(self.t_prev + self.c[self.stages + i] * self.h, &lags, &mut yd, phi);
227                }
228                dde.diff(self.t_prev + self.c[self.stages + i] * self.h, &y_stage_dense, &yd, &mut self.k[self.stages + i]);
229            }
230            evals.fcn += I - S; // Account for function evaluations for dense stages
231        }
232
233        // Update continuous output buffer and remove old entries if max_delay is set
234        self.history.push_back((self.t, self.y, self.dydt));
235        if let Some(max_delay) = self.max_delay {
236            let cutoff_time = self.t - max_delay;
237            while let Some((t_front, _, _)) = self.history.get(1){
238                if *t_front < cutoff_time {
239                    self.history.pop_front();
240                } else {
241                    break; // Stop pruning when we reach the cutoff time
242                }
243            }
244        }
245
246        self.status = Status::Solving;
247        Ok(evals)
248    }
249
250    fn t(&self) -> T { self.t }
251    fn y(&self) -> &V { &self.y }
252    fn t_prev(&self) -> T { self.t_prev }
253    fn y_prev(&self) -> &V { &self.y_prev }
254    fn h(&self) -> T { self.h }
255    fn set_h(&mut self, h: T) { self.h = h; }
256    fn status(&self) -> &Status<T, V, D> { &self.status }
257    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
258}
259
260impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, Fixed, T, V, D, O, S, I> {    
261    pub fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H) 
262    where 
263        H: Fn(T) -> V,
264    {
265        for i in 0..L {
266            let t_delayed = t_stage - lags[i];
267            
268            // Check if delayed time falls within the history period (t_delayed <= t0)
269            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
270                yd[i] = phi(t_delayed);
271            // If t_delayed is after t_prev then use interpolation function
272            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
273                if self.bi.is_some() {
274                    let s = (t_delayed - self.t_prev) / self.h;
275                    
276                    let bi_coeffs = self.bi.as_ref().unwrap();
277
278                    let mut cont = [T::zero(); I];
279                    for i in 0..I {
280                        if i < cont.len() && i < bi_coeffs.len() {
281                            cont[i] = bi_coeffs[i][self.dense_stages - 1];
282                            for j in (0..self.dense_stages - 1).rev() {
283                                cont[i] = cont[i] * s + bi_coeffs[i][j];
284                            }
285                            cont[i] *= s;
286                        }
287                    }
288
289                    let mut y_interp = self.y_prev;
290                    for i in 0..I {
291                        if i < self.k.len() && i < cont.len() {
292                            y_interp += self.k[i] * (cont[i] * self.h);
293                        }
294                    }
295                    yd[i] = y_interp;
296                } else {
297                    yd[i] = cubic_hermite_interpolate(
298                        self.t_prev, 
299                        self.t, 
300                        &self.y_prev, 
301                        &self.y, 
302                        &self.dydt_prev, 
303                        &self.dydt, 
304                        t_delayed
305                    );
306                }            // If t_delayed is before t_prev and after t0, we need to search in the history
307            } else {                // Search through history to find appropriate interpolation points
308                let mut found_interpolation = false;
309                let buffer = &self.history;
310                // Find two consecutive points that sandwich t_delayed using iterators
311                let mut buffer_iter = buffer.iter();
312                if let Some(mut prev_entry) = buffer_iter.next() {
313                    for curr_entry in buffer_iter {
314                        let (t_left, y_left, dydt_left) = prev_entry;
315                        let (t_right, y_right, dydt_right) = curr_entry;
316                        
317                        // Check if t_delayed is between these two points
318                        let is_between = if self.h.signum() > T::zero() {
319                            // Forward integration: t_left <= t_delayed <= t_right
320                            *t_left <= t_delayed && t_delayed <= *t_right
321                        } else {
322                            // Backward integration: t_right <= t_delayed <= t_left
323                            *t_right <= t_delayed && t_delayed <= *t_left
324                        };
325                        
326                        if is_between {
327                            // Use cubic Hermite interpolation between these points
328                            yd[i] = cubic_hermite_interpolate(
329                                *t_left,
330                                *t_right,
331                                y_left,
332                                y_right,
333                                dydt_left,
334                                dydt_right,
335                                t_delayed
336                            );
337                            found_interpolation = true;
338                            break;
339                        }
340                        prev_entry = curr_entry;
341                    }
342                }// If not found in history, this indicates insufficient history in buffer
343                if !found_interpolation {
344                    // Debug: show buffer contents
345                    let buffer = &self.history;
346                    println!("Buffer contents ({} entries):", buffer.len());
347                    for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
348                        if idx < 5 || idx >= buffer.len() - 5 {
349                            println!("  [{}] t = {}", idx, t_buf);
350                        } else if idx == 5 {
351                            println!("  ... ({} more entries) ...", buffer.len() - 10);
352                        }
353                    }
354                    panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
355                }
356            }
357        }
358    }
359}
360
361impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, Fixed, T, V, D, O, S, I> {
362    /// Interpolates the solution at a given time `t_interp`.
363    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
364        let posneg = self.h.signum();
365        if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
366            return Err(Error::OutOfBounds {
367                t_interp,
368                t_prev: self.t_prev,
369                t_curr: self.t,
370            });
371        }
372
373        // If method has dense output coefficients, use them
374        if self.bi.is_some() {
375            // Calculate the normalized distance within the step [0, 1]
376            let s = (t_interp - self.t_prev) / self.h_prev;
377            
378            // Get the interpolation coefficients
379            let bi = self.bi.as_ref().unwrap();
380
381            let mut cont = [T::zero(); I];
382            // Compute the interpolation coefficients using Horner's method
383            for i in 0..self.dense_stages {
384                // Start with the highest-order term
385                cont[i] = bi[i][self.order - 1];
386
387                // Apply Horner's method
388                for j in (0..self.order - 1).rev() {
389                    cont[i] = cont[i] * s + bi[i][j];
390                }
391
392                // Multiply by s
393                cont[i] *= s;
394            }
395
396            // Compute the interpolated value
397            let mut y_interp = self.y_prev;
398            for i in 0..I {
399                y_interp += self.k[i] * cont[i] * self.h_prev;
400            }
401
402            Ok(y_interp)
403        } else {
404            // Otherwise use cubic Hermite interpolation
405            let y_interp = cubic_hermite_interpolate(
406                self.t_prev, 
407                self.t, 
408                &self.y_prev, 
409                &self.y, 
410                &self.dydt_prev, 
411                &self.dydt, 
412                t_interp
413            );
414
415            Ok(y_interp)
416        }
417    }
418
419}