differential_equations/methods/erk/adaptive/delay.rs

//! Adaptive Runge-Kutta methods for DDEs
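//!
//! Delayed states are taken from the initial history function `phi` for times at or
//! before `t0`, and from the solver's own continuous-output history afterwards. When
//! the smallest lag is shorter than the current step, the step is iterated until the
//! candidate solution at `t + h` stops changing; the embedded lower-order solution
//! drives the adaptive step-size control.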

use super::{ExplicitRungeKutta, Delay, Adaptive};
use crate::{
    Error, Status,
    methods::h_init::InitialStepSize,
    alias::Evals,
    interpolate::{Interpolation, cubic_hermite_interpolate},
    dde::{DelayNumericalMethod, DDE},
    traits::{CallBackData, Real, State},
    utils::{constrain_step_size, validate_step_size_parameters},
};
use std::collections::VecDeque;

impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
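    /// Prepares the solver: validates the initial lags against the history range,
    /// evaluates the initial derivative, seeds the continuous-output history, and
    /// computes an initial step size if none was supplied.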
    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
    where
        F: DDE<L, T, V, D>,
    {
        let mut evals = Evals::new();

        // Initialize solver state
        self.t0 = t0;
        self.t = t0;
        self.y = *y0;
        self.t_prev = self.t;
        self.y_prev = self.y;
        self.status = Status::Initialized;
        self.steps = 0;
        self.stiffness_counter = 0;
        self.history = VecDeque::new();

        // Initialize arrays for lags and delayed states
        let mut lags = [T::zero(); L];
        let mut yd = [V::zeros(); L];

        // Evaluate initial lags and delayed states
        if L > 0 {
            dde.lags(self.t, &self.y, &mut lags);
            for i in 0..L {
                if lags[i] <= T::zero() {
                    return Err(Error::BadInput {
                        msg: "All lags must be positive.".to_string(),
                    });
                }
                let t_delayed = self.t - lags[i];
                // Ensure delayed time is within history range (t_delayed <= t0)
                if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
                    return Err(Error::BadInput {
                        msg: format!(
                            "Initial delayed time {} is out of history range (t <= {}).",
                            t_delayed, t0
                        ),
                    });
                }
                yd[i] = phi(t_delayed);
            }
        }

        // Calculate initial derivative
        dde.diff(self.t, &self.y, &yd, &mut self.dydt);
        evals.fcn += 1;
        self.dydt_prev = self.dydt;

        // Store initial state in history
        self.history.push_back((self.t, self.y, self.dydt));

        // Calculate initial step size h0 if not provided
        if self.h0 == T::zero() {
            // Adaptive step method
            self.h0 = InitialStepSize::<Delay>::compute(
                dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max,
                phi, &self.k[0], &mut evals,
            );
            evals.fcn += 2; // h_init performs 2 function evaluations
        }

        // Validate and set initial step size h
        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
            Ok(h0) => self.h = h0,
            Err(status) => return Err(status),
        }
        Ok(evals)
    }

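    /// Attempts a single adaptive step: computes the Runge-Kutta stages (iterating them
    /// when a lag falls inside the current step), estimates the local error from the
    /// embedded solution, and either accepts the step and advances the state or rejects
    /// it and reduces the step size.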
    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
    where
        F: DDE<L, T, V, D>,
    {
        let mut evals = Evals::new();

        // Validate step size
        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
            self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
            return Err(Error::StepSize { t: self.t, y: self.y });
        }

        // Check maximum number of steps
        if self.steps >= self.max_steps {
            self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
            return Err(Error::MaxSteps { t: self.t, y: self.y });
        }
        self.steps += 1;

        // Initialize variables for the step
        let mut lags = [T::zero(); L];
        let mut yd = [V::zeros(); L];

        // Store current derivative as k[0] for RK computations
        self.k[0] = self.dydt;

        // DDE: Determine if iterative approach for lag handling is needed
        let mut min_lag_abs = T::infinity();
        if L > 0 {
            // Predict y at t+h using Euler step to estimate lags at t+h
            let y_pred_for_lags = self.y + self.k[0] * self.h;
            dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
            for i in 0..L {
                min_lag_abs = min_lag_abs.min(lags[i].abs());
            }
        }

        // If lag values have to be extrapolated, we need to iterate for convergence
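        // (When the smallest lag is shorter than the step, the delayed argument
        // t + h - lag lies beyond the last accepted point, so the delayed state must be
        // extrapolated and the step iterated to convergence.)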
        let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
            5
        } else {
            1
        };

        let mut y_next_candidate_iter = self.y; // Approximated y at t+h, refined in DDE iterations
        let mut dydt_next_candidate_iter = V::zeros(); // Derivative at t+h using y_next_candidate_iter
        let mut y_prev_candidate_iter = self.y; // y_next_candidate_iter from previous DDE iteration
        let mut dde_iteration_failed = false;
        let mut err_norm: T = T::zero(); // Error norm for step size control

        // DDE iteration loop (for handling implicit lags or just one pass for explicit)
        for iter_idx in 0..max_iter {
            if iter_idx > 0 {
                y_prev_candidate_iter = y_next_candidate_iter;
            }

            // Compute Runge-Kutta stages
            for i in 1..self.stages {
                let mut y_stage = self.y;
                for j in 0..i {
                    y_stage += self.k[j] * (self.a[i][j] * self.h);
                }
                // Evaluate delayed states for the current stage
                if L > 0 {
                    dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
                    self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
                }
                dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
            }
            evals.fcn += self.stages - 1; // k[0] was already available

            // Adaptive methods: compute high and low order solutions for error estimation
            let mut y_high = self.y; // Higher order solution
            for i in 0..self.stages {
                y_high += self.k[i] * (self.b[i] * self.h);
            }
            let mut y_low = self.y; // Lower order solution (for error estimation)
            if let Some(bh_coeffs) = &self.bh {
                for i in 0..self.stages {
                    y_low += self.k[i] * (bh_coeffs[i] * self.h);
                }
            }
            let err_vec: V = y_high - y_low; // Error vector

            // Calculate error norm (||err||)
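            // Mixed tolerance scaling: tol_n = atol + rtol * max(|y_n|, |y_high_n|);
            // the norm is the maximum of |err_n| / tol_n over all components.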
            err_norm = T::zero();
            for n in 0..self.y.len() {
                let tol = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
                err_norm = err_norm.max((err_vec.get(n) / tol).abs());
            }

            // DDE iteration convergence check (if max_iter > 1)
            if max_iter > 1 && iter_idx > 0 {
                let mut dde_iteration_error = T::zero();
                let n_dim = self.y.len();
                for i_dim in 0..n_dim {
                    let scale = self.atol + self.rtol * y_prev_candidate_iter.get(i_dim).abs().max(y_high.get(i_dim).abs());
                    if scale > T::zero() {
                        let diff_val = y_high.get(i_dim) - y_prev_candidate_iter.get(i_dim);
                        dde_iteration_error += (diff_val / scale).powi(2);
                    }
                }
                if n_dim > 0 {
                    dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
                }

                if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
                    break; // DDE iteration converged
                }
                if iter_idx == max_iter - 1 { // Last iteration
                    dde_iteration_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
                }
            }
            y_next_candidate_iter = y_high; // Update candidate solution for t+h

            // Compute derivative at t+h with the current candidate y_next_candidate_iter
            if L > 0 {
                dde.lags(self.t + self.h, &y_next_candidate_iter, &mut lags);
                self.lagvals(self.t + self.h, &lags, &mut yd, phi);
            }
            dde.diff(self.t + self.h, &y_next_candidate_iter, &yd, &mut dydt_next_candidate_iter);
            evals.fcn += 1;
        } // End of DDE iteration loop

        // Handle DDE iteration failure: reduce step size and retry
        if dde_iteration_failed {
            let sign = self.h.signum();
            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
            // If the reduced step is within twice the minimum lag, snap it down to the
            // minimum lag so the delayed arguments fall inside already-computed history.
            if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
                self.h = min_lag_abs * sign; // Or some factor of min_lag_abs
            }
            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
            self.status = Status::RejectedStep; // Indicate step rejection due to DDE iteration
            // self.k[0] = self.dydt; // k[0] is already self.dydt, no need to reset
            return Ok(evals); // Return to retry step with smaller h
        }

        // Step size controller for adaptive methods (based on err_norm)
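        // h_new = h * clamp(safety_factor * err_norm^(-1/order), min_scale, max_scale)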
        let order_t = T::from_usize(self.order).unwrap();
        let err_order_inv = T::one() / order_t;
        let mut scale_factor = self.safety_factor * err_norm.powf(-err_order_inv);
        scale_factor = scale_factor.max(self.min_scale).min(self.max_scale);

        let h_new = self.h * scale_factor;

        // Step acceptance/rejection logic
        if err_norm <= T::one() { // Step accepted
            self.t_prev = self.t;
            self.y_prev = self.y;
            self.dydt_prev = self.dydt; // Derivative at t_prev
            self.h_prev = self.h; // Store accepted step size

            if let Status::RejectedStep = self.status { // If previous step was rejected
                self.stiffness_counter = 0;
            }
            self.status = Status::Solving;

            // Compute additional stages for dense output if available
            if self.bi.is_some() {
                for i in 0..(I - S) { // I is total stages, S is main method stages
                    let mut y_stage_dense = self.y; // Base for dense stage calculation
                    // Sum up contributions from previous k values for this dense stage
                    for j in 0..self.stages + i { // self.stages is S
                        y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
                    }
                    // Evaluate lags and derivative for the dense stage
                    if L > 0 {
                        dde.lags(self.t + self.c[self.stages + i] * self.h, &y_stage_dense, &mut lags);
                        self.lagvals(self.t + self.c[self.stages + i] * self.h, &lags, &mut yd, phi);
                    }
                    dde.diff(self.t + self.c[self.stages + i] * self.h, &y_stage_dense, &yd, &mut self.k[self.stages + i]);
                }
                evals.fcn += I - S; // Account for function evaluations for dense stages
            }

            // Update state to t + h
            self.t += self.h;
            self.y = y_next_candidate_iter;

            // Compute the derivative at the new state (self.t, self.y) for the next step
            if self.fsal {
                // If FSAL (First Same As Last) is enabled, reuse the last stage derivative
                self.dydt = self.k[S - 1];
            } else {
                // Otherwise, compute the new derivative with freshly evaluated delayed states
                if L > 0 {
                    dde.lags(self.t, &self.y, &mut lags);
                    self.lagvals(self.t, &lags, &mut yd, phi);
                }
                dde.diff(self.t, &self.y, &yd, &mut self.dydt);
                evals.fcn += 1;
            }

            // Update continuous output buffer and remove old entries if max_delay is set
            self.history.push_back((self.t, self.y, self.dydt));
            if let Some(max_delay) = self.max_delay {
                let cutoff_time = self.t - max_delay;
                // Inspect the second entry so at least one point at or before the cutoff
                // is always retained for interpolation.
                while let Some((t_second, _, _)) = self.history.get(1) {
                    if *t_second < cutoff_time {
                        self.history.pop_front();
                    } else {
                        break; // Stop pruning once the cutoff time is reached
                    }
                }
            }

            self.h = constrain_step_size(h_new, self.h_min, self.h_max); // Set next step size
        } else { // Step rejected
            self.status = Status::RejectedStep;
            self.stiffness_counter += 1;

            // Check for excessive rejections (potential stiffness)
            if self.stiffness_counter >= self.max_rejects {
                self.status = Status::Error(Error::Stiffness { t: self.t, y: self.y });
                return Err(Error::Stiffness { t: self.t, y: self.y });
            }
            // Reduce step size for next attempt
            self.h = constrain_step_size(h_new, self.h_min, self.h_max);
        }
        Ok(evals)
    }

    fn t(&self) -> T { self.t }
    fn y(&self) -> &V { &self.y }
    fn t_prev(&self) -> T { self.t_prev }
    fn y_prev(&self) -> &V { &self.y_prev }
    fn h(&self) -> T { self.h }
    fn set_h(&mut self, h: T) { self.h = h; }
    fn status(&self) -> &Status<T, V, D> { &self.status }
    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
}

impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
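    /// Evaluates the delayed states `yd[i] = y(t_stage - lags[i])`, dispatching on where
    /// each delayed time falls: the history function `phi` for times at or before `t0`,
    /// dense output / cubic Hermite interpolation over the most recent step, or a search
    /// through the stored history for earlier times.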
    fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H)
    where
        H: Fn(T) -> V,
    {
        for i in 0..L {
            let t_delayed = t_stage - lags[i];

            // Check if delayed time falls within the history period (t_delayed <= t0)
            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
                yd[i] = phi(t_delayed);
            // If t_delayed is after t_prev then use interpolation function
            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
                if self.bi.is_some() {
                    let s = (t_delayed - self.t_prev) / self.h_prev;

                    let bi_coeffs = self.bi.as_ref().unwrap();

                    let mut cont = [T::zero(); I];
                    for stage in 0..I {
                        if stage < self.cont.len() && stage < bi_coeffs.len() {
                            cont[stage] = bi_coeffs[stage][self.dense_stages - 1];
                            for j in (0..self.dense_stages - 1).rev() {
                                cont[stage] = cont[stage] * s + bi_coeffs[stage][j];
                            }
                            cont[stage] *= s;
                        }
                    }

                    let mut y_interp = self.y_prev;
                    for stage in 0..I {
                        if stage < self.k.len() && stage < self.cont.len() {
                            y_interp += self.k[stage] * (cont[stage] * self.h_prev);
                        }
                    }
                    yd[i] = y_interp;
                } else {
                    yd[i] = cubic_hermite_interpolate(
                        self.t_prev,
                        self.t,
                        &self.y_prev,
                        &self.y,
                        &self.dydt_prev,
                        &self.dydt,
                        t_delayed,
                    );
                }
            // If t_delayed is before t_prev and after t0, search the stored history
            } else {
                // Search through history to find appropriate interpolation points
                let mut found_interpolation = false;
                let buffer = &self.history;
                // Find two consecutive points that sandwich t_delayed using iterators
                let mut buffer_iter = buffer.iter();
                if let Some(mut prev_entry) = buffer_iter.next() {
                    for curr_entry in buffer_iter {
                        let (t_left, y_left, dydt_left) = prev_entry;
                        let (t_right, y_right, dydt_right) = curr_entry;

                        // Check if t_delayed is between these two points
                        let is_between = if self.h.signum() > T::zero() {
                            // Forward integration: t_left <= t_delayed <= t_right
                            *t_left <= t_delayed && t_delayed <= *t_right
                        } else {
                            // Backward integration: t_right <= t_delayed <= t_left
                            *t_right <= t_delayed && t_delayed <= *t_left
                        };

                        if is_between {
                            // Use cubic Hermite interpolation between these points
                            yd[i] = cubic_hermite_interpolate(
                                *t_left,
                                *t_right,
                                y_left,
                                y_right,
                                dydt_left,
                                dydt_right,
                                t_delayed,
                            );
                            found_interpolation = true;
                            break;
                        }
                        prev_entry = curr_entry;
                    }
                }

                // If not found in history, this indicates insufficient history in the buffer
                if !found_interpolation {
                    // Debug: show buffer contents
                    let buffer = &self.history;
                    println!("Buffer contents ({} entries):", buffer.len());
                    for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
                        if idx < 5 || idx >= buffer.len() - 5 {
                            println!("  [{}] t = {}", idx, t_buf);
                        } else if idx == 5 {
                            println!("  ... ({} more entries) ...", buffer.len() - 10);
                        }
                    }
                    panic!("Insufficient history in buffer for t_delayed = {} (t_prev = {}, t = {}). The buffer may need to retain more points, or there is a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
                }
            }
        }
    }
}

impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, Adaptive, T, V, D, O, S, I> {
    /// Interpolates the solution at a given time `t_interp`.
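    ///
    /// Uses the method's dense-output coefficients when available, evaluated with
    /// Horner's scheme in the normalized step variable `s = (t_interp - t_prev) / h_prev`,
    /// and falls back to cubic Hermite interpolation otherwise.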
    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
        let posneg = (self.t - self.t_prev).signum();
        if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_prev,
                t_curr: self.t,
            });
        }

        // If method has dense output coefficients, use them
        if self.bi.is_some() {
            // Calculate the normalized distance within the step [0, 1]
            let s = (t_interp - self.t_prev) / self.h_prev;

            // Get the interpolation coefficients
            let bi = self.bi.as_ref().unwrap();

            let mut cont = [T::zero(); I];
            // Compute the interpolation coefficients using Horner's method
            for i in 0..self.dense_stages {
                // Start with the highest-order term
                cont[i] = bi[i][self.order - 1];

                // Apply Horner's method
                for j in (0..self.order - 1).rev() {
                    cont[i] = cont[i] * s + bi[i][j];
                }

                // Multiply by s
                cont[i] *= s;
            }

            // Compute the interpolated value
            let mut y_interp = self.y_prev;
            for i in 0..I {
                y_interp += self.k[i] * cont[i] * self.h_prev;
            }

            Ok(y_interp)
        } else {
            // Otherwise use cubic Hermite interpolation
            let y_interp = cubic_hermite_interpolate(
                self.t_prev,
                self.t,
                &self.y_prev,
                &self.y,
                &self.dydt_prev,
                &self.dydt,
                t_interp,
            );

            Ok(y_interp)
        }
    }
}