// differential_equations/methods/erk/dormandprince/delay.rs

//! Dormand-Prince Runge-Kutta methods for delay differential equations (DDEs).

3use super::{ExplicitRungeKutta, Delay, DormandPrince};
4use crate::{
5    Error, Status,
6    methods::h_init::InitialStepSize,
7    alias::Evals,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    dde::{DelayNumericalMethod, DDE},
10    traits::{CallBackData, Real, State},
11    utils::{constrain_step_size, validate_step_size_parameters},
12};
13use std::collections::VecDeque;
14
15impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
16    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
17    where
18        F: DDE<L, T, V, D>,
19    {
20        let mut evals = Evals::new();
21
22        // Initialize solver state
23        self.t0 = t0;
24        self.t = t0;
25        self.y = *y0;
26        self.t_prev = self.t;
27        self.y_prev = self.y;
28        self.status = Status::Initialized;        
29        self.steps = 0;
30        self.stiffness_counter = 0;
31        self.non_stiffness_counter = 0;
32        self.history = VecDeque::new();
33
34        // Initialize arrays for lags and delayed states
35        let mut lags = [T::zero(); L];
36        let mut yd = [V::zeros(); L];
37
38        // Evaluate initial lags and delayed states
39        if L > 0 {
40            dde.lags(self.t, &self.y, &mut lags);
41            for i in 0..L {
42                if lags[i] <= T::zero() {
43                    return Err(Error::BadInput {
44                        msg: "All lags must be positive.".to_string(),
45                    });
46                }
47                let t_delayed = self.t - lags[i];
48                // Ensure delayed time is within history range (t_delayed <= t0)
49                if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
50                    return Err(Error::BadInput {
51                        msg: format!(
52                            "Delayed time {} is beyond initial time {}",
53                            t_delayed, t0
54                        ),
55                    });
56                }
57                yd[i] = phi(t_delayed);
58            }
59        }
60
61        // Calculate initial derivative
62        dde.diff(self.t, &self.y, &yd, &mut self.k[0]);
63        self.dydt = self.k[0];
64        evals.fcn += 1;
65        self.dydt_prev = self.dydt;
66
67        // Store initial state in history
68        self.history.push_back((self.t, self.y, self.dydt));
69
70        // Calculate initial step size h0 if not provided
71        if self.h0 == T::zero() {
72            // Use Dormand-Prince specific step size calculation  
73            self.h0 = InitialStepSize::<Delay>::compute(dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi, &self.k[0], &mut evals);
74        }
75
76        // Validate and set initial step size h
77        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
78            Ok(h0) => self.h = h0,
79            Err(status) => return Err(status),
80        }
81        Ok(evals)
82    }
83
84    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
85    where
86        F: DDE<L, T, V, D>,
87    {
88        let mut evals = Evals::new();
89
90        // Validate step size
91        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
92            self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
93            return Err(Error::StepSize { t: self.t, y: self.y });
94        }
95
96        // Check maximum number of steps
97        if self.steps >= self.max_steps {
98            self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
99            return Err(Error::MaxSteps { t: self.t, y: self.y });
100        }
101        self.steps += 1;
102
103        // Initialize variables for the step
104        let mut lags = [T::zero(); L];
105        let mut yd = [V::zeros(); L];
106
107        // DDE: Determine if iterative approach for lag handling is needed
108        let mut min_lag_abs = T::infinity();
109        if L > 0 {
110            // Predict y at t+h using Euler step to estimate lags at t+h
111            let y_pred_for_lags = self.y + self.k[0] * self.h;
112            dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
113            for i in 0..L {
114                min_lag_abs = min_lag_abs.min(lags[i].abs());
115            }
116        }
117
118        // If lag values have to be extrapolated, we need to iterate for convergence
119        let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
120            5
121        } else {
122            1
123        };        let mut y_next_candidate_iter = self.y; // Approximated y at t+h, refined in DDE iterations
124        let mut y_prev_candidate_iter = self.y; // y_next_candidate_iter from previous DDE iteration
125        let mut dde_iteration_failed = false;
126        let mut err: T = T::zero(); // Error norm for step size control
127        let mut ysti = V::zeros(); // Store last stage for stiffness detection
128
129        // DDE iteration loop (for handling implicit lags or just one pass for explicit)
130        for iter_idx in 0..max_iter {
131            if iter_idx > 0 {
132                y_prev_candidate_iter = y_next_candidate_iter;
133            }
134
135            // Compute Runge-Kutta stages
136            let mut y_stage = V::zeros();
137            for i in 1..self.stages {
138                y_stage = V::zeros();
139                for j in 0..i {
140                    y_stage += self.k[j] * self.a[i][j];
141                }
142                y_stage = self.y + y_stage * self.h;
143
144                // Evaluate delayed states for the current stage
145                if L > 0 {
146                    dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
147                    self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
148                }
149                dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
150            }
151            evals.fcn += self.stages - 1; // k[0] was already available
152
153            // Store the last stage for stiffness detection
154            ysti = y_stage;
155
156            // Calculate the line segment for the new y value
157            let mut yseg = V::zeros();
158            for i in 0..self.stages {
159                yseg += self.k[i] * self.b[i];
160            }
161
162            // Calculate the new y value using the line segment
163            let y_new = self.y + yseg * self.h;
164
165            // Dormand-Prince error estimation
166            let er = self.er.unwrap();
167            let n = self.y.len();
168            let mut err_val = T::zero();
169            let mut err2 = T::zero();
170            let mut erri;
171            for i in 0..n {
172                // Calculate the error scale
173                let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
174
175                // Primary error term
176                erri = T::zero();
177                for j in 0..self.stages {
178                    erri += er[j] * self.k[j].get(i);
179                }
180                err_val += (erri / sk).powi(2);
181
182                // Optional secondary error term
183                if let Some(bh) = &self.bh {
184                    erri = yseg.get(i);
185                    for j in 0..self.stages {
186                        erri -= bh[j] * self.k[j].get(i);
187                    }
188                    err2 += (erri / sk).powi(2);
189                }
190            }
191            let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
192            if deno <= T::zero() {
193                deno = T::one();
194            }
195            err = self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
196
197            // DDE iteration convergence check (if max_iter > 1)
198            if max_iter > 1 && iter_idx > 0 {
199                let mut dde_iteration_error = T::zero();
200                let n_dim = self.y.len();
201                for i_dim in 0..n_dim {
202                    let scale = self.atol + self.rtol * y_prev_candidate_iter.get(i_dim).abs().max(y_new.get(i_dim).abs());
203                    if scale > T::zero() {
204                        let diff_val = y_new.get(i_dim) - y_prev_candidate_iter.get(i_dim);
205                        dde_iteration_error += (diff_val / scale).powi(2);
206                    }
207                }
208                if n_dim > 0 {
209                    dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
210                }
211
212                if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
213                    break; // DDE iteration converged
214                }
215                if iter_idx == max_iter - 1 { // Last iteration
216                    dde_iteration_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
217                }
218            }
219            y_next_candidate_iter = y_new; // Update candidate solution for t+h
220
221            // Store ysti for potential stiffness detection
222            if iter_idx == max_iter - 1 || max_iter == 1 {
223                // Keep ysti from the final iteration for stiffness detection
224            }
225        } // End of DDE iteration loop
226
227        // Handle DDE iteration failure: reduce step size and retry
228        if dde_iteration_failed {
229            let sign = self.h.signum();
230            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
231            // Ensure step size is not smaller than a fraction of the minimum lag, if applicable
232            if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
233                self.h = min_lag_abs * sign;
234            }
235            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
236            self.status = Status::RejectedStep;
237            return Ok(evals); // Return to retry step with smaller h
238        }
239
240        // Step size scale factor
241        let order = T::from_usize(self.order).unwrap();
242        let error_exponent = T::one() / order;
243        let mut scale = self.safety_factor * err.powf(-error_exponent);
244        
245        // Clamp scale factor to prevent extreme step size changes
246        scale = scale.max(self.min_scale).min(self.max_scale);
247
248        // Step acceptance/rejection logic
249        if err <= T::one() { // Step accepted
250            let y_new = y_next_candidate_iter;
251            let t_new = self.t + self.h;
252
253            // Calculate the new derivative at the new point
254            if L > 0 {
255                dde.lags(t_new, &y_new, &mut lags);
256                self.lagvals(t_new, &lags, &mut yd, phi);
257            }
258            dde.diff(t_new, &y_new, &yd, &mut self.dydt);
259            evals.fcn += 1;            // Stiffness detection (every 100 steps)
260            let n_stiff_threshold = 100;
261            if self.steps % n_stiff_threshold == 0 {
262                let mut stdnum = T::zero();
263                let mut stden = T::zero();
264                let sqr = {
265                    let mut yseg = V::zeros();
266                    for i in 0..self.stages {
267                        yseg += self.k[i] * self.b[i];
268                    }
269                    yseg - self.k[S-1]
270                };
271                for i in 0..sqr.len() {
272                    stdnum += sqr.get(i).powi(2);
273                }
274                let sqr = self.dydt - ysti;
275                for i in 0..sqr.len() {
276                    stden += sqr.get(i).powi(2);
277                }
278
279                if stden > T::zero() {
280                    let h_lamb = self.h * (stdnum / stden).sqrt();
281                    if h_lamb > T::from_f64(6.1).unwrap() {
282                        self.non_stiffness_counter = 0;
283                        self.stiffness_counter += 1;
284                        if self.stiffness_counter == 15 {
285                            self.status = Status::Error(Error::Stiffness {
286                                t: self.t,
287                                y: self.y,
288                            });
289                            return Err(Error::Stiffness {
290                                t: self.t,
291                                y: self.y,
292                            });
293                        }
294                    }
295                } else {
296                    self.non_stiffness_counter += 1;
297                    if self.non_stiffness_counter == 6 {
298                        self.stiffness_counter = 0;
299                    }
300                }
301            }
302
303            // Preparation for dense output / interpolation
304            self.cont[0] = self.y;
305            let ydiff = y_new - self.y;
306            self.cont[1] = ydiff;
307            let bspl = self.k[0] * self.h - ydiff;
308            self.cont[2] = bspl;
309            self.cont[3] = ydiff - self.dydt * self.h - bspl;
310
311            // If method has dense output stages, compute them
312            if let Some(bi) = &self.bi {
313                // Compute extra stages for dense output
314                if I > S {
315                    // First dense output coefficient, k{i=order+1}, is the derivative at the new point
316                    self.k[self.stages] = self.dydt;                    for i in S+1..I {
317                        let mut y_stage = V::zeros();
318                        for j in 0..i {
319                            y_stage += self.k[j] * self.a[i][j];
320                        }
321                        y_stage = self.y + y_stage * self.h;
322
323                        if L > 0 {
324                            dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
325                            // Manually inline the lagvals logic to avoid borrowing conflicts
326                            for lag_idx in 0..L {
327                                let t_delayed = (self.t + self.c[i] * self.h) - lags[lag_idx];
328                                
329                                // Check if delayed time falls within the history period (t_delayed <= t0)
330                                if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
331                                    yd[lag_idx] = phi(t_delayed);
332                                // If t_delayed is after t_prev then use interpolation function
333                                } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
334                                    if self.bi.is_some() {
335                                        let s = (t_delayed - self.t_prev) / self.h_prev;
336                                        let s1 = T::one() - s;        
337                                        let ilast = self.cont.len() - 1;
338                                        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, cont_i| {            
339                                            let factor = if cont_i >= 4 {
340                                                if (ilast - cont_i) % 2 == 1 { s1 } else { s }
341                                            } else {
342                                                if cont_i % 2 == 1 { s1 } else { s }
343                                            };
344                                            acc * factor + self.cont[cont_i]
345                                        });
346                                        yd[lag_idx] = self.cont[0] + poly * s;
347                                    } else {
348                                        yd[lag_idx] = cubic_hermite_interpolate(
349                                            self.t_prev, 
350                                            self.t, 
351                                            &self.y_prev, 
352                                            &self.y, 
353                                            &self.dydt_prev, 
354                                            &self.dydt, 
355                                            t_delayed
356                                        );
357                                    }
358                                } else {
359                                    // Search through history to find appropriate interpolation points
360                                    let mut found_interpolation = false;
361                                    let buffer = &self.history;
362                                    let mut buffer_iter = buffer.iter();
363                                    if let Some(mut prev_entry) = buffer_iter.next() {
364                                        for curr_entry in buffer_iter {
365                                            let (t_left, y_left, dydt_left) = prev_entry;
366                                            let (t_right, y_right, dydt_right) = curr_entry;
367                                            
368                                            let is_between = if self.h.signum() > T::zero() {
369                                                *t_left <= t_delayed && t_delayed <= *t_right
370                                            } else {
371                                                *t_right <= t_delayed && t_delayed <= *t_left
372                                            };
373                                            
374                                            if is_between {
375                                                yd[lag_idx] = cubic_hermite_interpolate(
376                                                    *t_left,
377                                                    *t_right,
378                                                    y_left,
379                                                    y_right,
380                                                    dydt_left,
381                                                    dydt_right,
382                                                    t_delayed
383                                                );
384                                                found_interpolation = true;
385                                                break;
386                                            }
387                                            prev_entry = curr_entry;
388                                        }
389                                    }
390                                    if !found_interpolation {
391                                        panic!("Insufficient history for t_delayed = {} (t_prev = {}, t = {})", t_delayed, self.t_prev, self.t);
392                                    }
393                                }
394                            }
395                        }
396                        dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
397                        evals.fcn += 1;
398                    }
399                }
400
401                // Compute dense output coefficients
402                for i in 4..self.order {
403                    self.cont[i] = V::zeros();
404                    for j in 0..self.dense_stages {
405                        self.cont[i] += self.k[j] * bi[i][j];
406                    }
407                    self.cont[i] = self.cont[i] * self.h;
408                }
409            }
410
411            // For interpolation
412            self.t_prev = self.t;
413            self.y_prev = self.y;
414            self.dydt_prev = self.k[0];
415            self.h_prev = self.h;
416
417            // Update state to t + h
418            self.t = t_new;
419            self.y = y_new;
420            self.k[0] = self.dydt;
421
422            // Update continuous output buffer and remove old entries if max_delay is set
423            self.history.push_back((self.t, self.y, self.dydt));
424            if let Some(max_delay) = self.max_delay {
425                let cutoff_time = self.t - max_delay;
426                while let Some((t_front, _, _)) = self.history.get(1){
427                    if *t_front < cutoff_time {
428                        self.history.pop_front();
429                    } else {
430                        break;
431                    }
432                }
433            }            // Check if previous step is rejected
434            if let Status::RejectedStep = self.status {
435                self.status = Status::Solving;
436
437                // Limit step size growth to avoid oscillations between accepted and rejected steps
438                scale = scale.min(T::one());
439            }
440        } else {
441            // Step Rejected
442            self.status = Status::RejectedStep;
443        }
444
445        // Update step size
446        self.h *= scale;
447
448        // Ensure step size is within bounds
449        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
450        
451        Ok(evals)
452    }
453
454    fn t(&self) -> T { self.t }
455    fn y(&self) -> &V { &self.y }
456    fn t_prev(&self) -> T { self.t_prev }
457    fn y_prev(&self) -> &V { &self.y_prev }
458    fn h(&self) -> T { self.h }
459    fn set_h(&mut self, h: T) { self.h = h; }
460    fn status(&self) -> &Status<T, V, D> { &self.status }
461    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
462}
463
464impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {    
465    fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H) 
466    where 
467        H: Fn(T) -> V,
468    {
469        for i in 0..L {
470            let t_delayed = t_stage - lags[i];
471            
472            // Check if delayed time falls within the history period (t_delayed <= t0)
473            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
474                yd[i] = phi(t_delayed);
475            // If t_delayed is after t_prev then use interpolation function
476            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
477                if self.bi.is_some() {
478                    let s = (t_delayed - self.t_prev) / self.h_prev;
479                    
480                    // Evaluate the interpolation polynomial at the requested time
481                    let s1 = T::one() - s;        
482                    
483                    // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
484                    let ilast = self.cont.len() - 1;
485                    let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {            
486                        let factor = if i >= 4 {
487                            // For the higher-order part (conpar), alternate s and s1 based on index parity
488                            if (ilast - i) % 2 == 1 { s1 } else { s }
489                        } else {
490                            // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
491                            if i % 2 == 1 { s1 } else { s }
492                        };
493                        acc * factor + self.cont[i]
494                    });
495                    
496                    // Final multiplication by s for the outermost level
497                    let y_interp = self.cont[0] + poly * s;
498                    yd[i] = y_interp;
499                } else {
500                    yd[i] = cubic_hermite_interpolate(
501                        self.t_prev, 
502                        self.t, 
503                        &self.y_prev, 
504                        &self.y, 
505                        &self.dydt_prev, 
506                        &self.dydt, 
507                        t_delayed
508                    );
509                }
510            // If t_delayed is before t_prev and after t0, we need to search in the history
511            } else {
512                // Search through history to find appropriate interpolation points
513                let mut found_interpolation = false;
514                let buffer = &self.history;
515                // Find two consecutive points that sandwich t_delayed using iterators
516                let mut buffer_iter = buffer.iter();
517                if let Some(mut prev_entry) = buffer_iter.next() {
518                    for curr_entry in buffer_iter {
519                        let (t_left, y_left, dydt_left) = prev_entry;
520                        let (t_right, y_right, dydt_right) = curr_entry;
521                        
522                        // Check if t_delayed is between these two points
523                        let is_between = if self.h.signum() > T::zero() {
524                            *t_left <= t_delayed && t_delayed <= *t_right
525                        } else {
526                            *t_right <= t_delayed && t_delayed <= *t_left
527                        };
528                        
529                        if is_between {
530                            yd[i] = cubic_hermite_interpolate(
531                                *t_left,
532                                *t_right,
533                                y_left,
534                                y_right,
535                                dydt_left,
536                                dydt_right,
537                                t_delayed
538                            );
539                            found_interpolation = true;
540                            break;
541                        }
542                        prev_entry = curr_entry;
543                    }
544                }
545                // If not found in history, this indicates insufficient history in buffer
546                if !found_interpolation {
547                    // Debug: show buffer contents
548                    let buffer = &self.history;
549                    println!("Buffer contents ({} entries):", buffer.len());
550                    for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
551                        println!("  [{}]: t = {}", idx, t_buf);
552                    }
553                    panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
554                }
555            }
556        }
557    }
558}
559
560impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
561    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {        
562        // Check if interpolation is out of bounds
563        let posneg = (self.t - self.t_prev).signum();
564        if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
565            return Err(Error::OutOfBounds {
566                t_interp,
567                t_prev: self.t_prev,
568                t_curr: self.t,
569            });
570        }        
571        
572        // Evaluate the interpolation polynomial at the requested time
573        let s = (t_interp - self.t_prev) / self.h_prev;
574        let s1 = T::one() - s;        
575        
576        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
577        let ilast = self.cont.len() - 1;
578        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {            
579            let factor = if i >= 4 {
580                // For the higher-order part (conpar), alternate s and s1 based on index parity
581                if (ilast - i) % 2 == 1 { s1 } else { s }
582            } else {
583                // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
584                if i % 2 == 1 { s1 } else { s }
585            };
586            acc * factor + self.cont[i]
587        });
588        
589        // Final multiplication by s for the outermost level
590        let y_interp = self.cont[0] + poly * s;
591
592        Ok(y_interp)
593    }
594}