differential_equations/methods/erk/dormandprince/delay.rs

1//! Dormand-Prince Runge-Kutta methods for DDEs
2
3use super::{ExplicitRungeKutta, Delay, DormandPrince};
4use crate::{
5    Error, Status,
6    methods::h_init::InitialStepSize,
7    alias::Evals,
8    interpolate::{Interpolation, cubic_hermite_interpolate},
9    dde::{DelayNumericalMethod, DDE},
10    traits::{CallBackData, Real, State},
11    utils::{constrain_step_size, validate_step_size_parameters},
12};
13use std::collections::VecDeque;
14
15impl<const L: usize, T: Real, V: State<T>, H: Fn(T) -> V, D: CallBackData, const O: usize, const S: usize, const I: usize> DelayNumericalMethod<L, T, V, H, D> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
16    fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &V, phi: &H) -> Result<Evals, Error<T, V>>
17    where
18        F: DDE<L, T, V, D>,
19    {
20        let mut evals = Evals::new();
21
22        // Initialize solver state
23        self.t0 = t0;
24        self.t = t0;
25        self.y = *y0;
26        self.t_prev = self.t;
27        self.y_prev = self.y;
28        self.status = Status::Initialized;        self.steps = 0;
29        self.stiffness_counter = 0;
30        self.non_stiffness_counter = 0;
31        self.history = VecDeque::new();
32
33        // Initialize arrays for lags and delayed states
34        let mut lags = [T::zero(); L];
35        let mut yd = [V::zeros(); L];
36
37        // Evaluate initial lags and delayed states
38        if L > 0 {
39            dde.lags(self.t, &self.y, &mut lags);
40            for i in 0..L {
41                if lags[i] <= T::zero() {
42                    return Err(Error::BadInput {
43                        msg: "All lags must be positive.".to_string(),
44                    });
45                }
46                let t_delayed = self.t - lags[i];
47                // Ensure delayed time is within history range (t_delayed <= t0)
48                if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
49                    return Err(Error::BadInput {
50                        msg: format!(
51                            "Delayed time {} is beyond initial time {}",
52                            t_delayed, t0
53                        ),
54                    });
55                }
56                yd[i] = phi(t_delayed);
57            }
58        }
59
60        // Calculate initial derivative
61        dde.diff(self.t, &self.y, &yd, &mut self.k[0]);
62        self.dydt = self.k[0];
63        evals.fcn += 1;
64        self.dydt_prev = self.dydt;
65
66        // Store initial state in history
67        self.history.push_back((self.t, self.y, self.dydt));
68
69        // Calculate initial step size h0 if not provided
70        if self.h0 == T::zero() {
71            // Use Dormand-Prince specific step size calculation  
72            self.h0 = InitialStepSize::<Delay>::compute(dde, t0, tf, y0, self.order, self.rtol, self.atol, self.h_min, self.h_max, phi, &self.k[0], &mut evals);
73        }
74
75        // Validate and set initial step size h
76        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
77            Ok(h0) => self.h = h0,
78            Err(status) => return Err(status),
79        }
80        Ok(evals)
81    }
82
83    fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, V>>
84    where
85        F: DDE<L, T, V, D>,
86    {
87        let mut evals = Evals::new();
88
89        // Validate step size
90        if self.h.abs() < T::default_epsilon() {
91            self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
92            return Err(Error::StepSize { t: self.t, y: self.y });
93        }
94
95        // Check maximum number of steps
96        if self.steps >= self.max_steps {
97            self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
98            return Err(Error::MaxSteps { t: self.t, y: self.y });
99        }
100        self.steps += 1;
101
102        // Initialize variables for the step
103        let mut lags = [T::zero(); L];
104        let mut yd = [V::zeros(); L];
105
106        // DDE: Determine if iterative approach for lag handling is needed
107        let mut min_lag_abs = T::infinity();
108        if L > 0 {
109            // Predict y at t+h using Euler step to estimate lags at t+h
110            let y_pred_for_lags = self.y + self.k[0] * self.h;
111            dde.lags(self.t + self.h, &y_pred_for_lags, &mut lags);
112            for i in 0..L {
113                min_lag_abs = min_lag_abs.min(lags[i].abs());
114            }
115        }
116
117        // If lag values have to be extrapolated, we need to iterate for convergence
118        let max_iter: usize = if L > 0 && min_lag_abs < self.h.abs() && min_lag_abs > T::zero() {
119            5
120        } else {
121            1
122        };        let mut y_next_candidate_iter = self.y; // Approximated y at t+h, refined in DDE iterations
123        let mut y_prev_candidate_iter = self.y; // y_next_candidate_iter from previous DDE iteration
124        let mut dde_iteration_failed = false;
125        let mut err: T = T::zero(); // Error norm for step size control
126        let mut ysti = V::zeros(); // Store last stage for stiffness detection
127
128        // DDE iteration loop (for handling implicit lags or just one pass for explicit)
129        for iter_idx in 0..max_iter {
130            if iter_idx > 0 {
131                y_prev_candidate_iter = y_next_candidate_iter;
132            }
133
134            // Compute Runge-Kutta stages
135            let mut y_stage = V::zeros();
136            for i in 1..self.stages {
137                y_stage = V::zeros();
138                for j in 0..i {
139                    y_stage += self.k[j] * self.a[i][j];
140                }
141                y_stage = self.y + y_stage * self.h;
142
143                // Evaluate delayed states for the current stage
144                if L > 0 {
145                    dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
146                    self.lagvals(self.t + self.c[i] * self.h, &lags, &mut yd, phi);
147                }
148                dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
149            }
150            evals.fcn += self.stages - 1; // k[0] was already available
151
152            // Store the last stage for stiffness detection
153            ysti = y_stage;
154
155            // Calculate the line segment for the new y value
156            let mut yseg = V::zeros();
157            for i in 0..self.stages {
158                yseg += self.k[i] * self.b[i];
159            }
160
161            // Calculate the new y value using the line segment
162            let y_new = self.y + yseg * self.h;
163
164            // Dormand-Prince error estimation
165            let er = self.er.unwrap();
166            let n = self.y.len();
167            let mut err_val = T::zero();
168            let mut err2 = T::zero();
169            let mut erri;
170            for i in 0..n {
171                // Calculate the error scale
172                let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
173
174                // Primary error term
175                erri = T::zero();
176                for j in 0..self.stages {
177                    erri += er[j] * self.k[j].get(i);
178                }
179                err_val += (erri / sk).powi(2);
180
181                // Optional secondary error term
182                if let Some(bh) = &self.bh {
183                    erri = yseg.get(i);
184                    for j in 0..self.stages {
185                        erri -= bh[j] * self.k[j].get(i);
186                    }
187                    err2 += (erri / sk).powi(2);
188                }
189            }
190            let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
191            if deno <= T::zero() {
192                deno = T::one();
193            }
194            err = self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
195
196            // DDE iteration convergence check (if max_iter > 1)
197            if max_iter > 1 && iter_idx > 0 {
198                let mut dde_iteration_error = T::zero();
199                let n_dim = self.y.len();
200                for i_dim in 0..n_dim {
201                    let scale = self.atol + self.rtol * y_prev_candidate_iter.get(i_dim).abs().max(y_new.get(i_dim).abs());
202                    if scale > T::zero() {
203                        let diff_val = y_new.get(i_dim) - y_prev_candidate_iter.get(i_dim);
204                        dde_iteration_error += (diff_val / scale).powi(2);
205                    }
206                }
207                if n_dim > 0 {
208                    dde_iteration_error = (dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
209                }
210
211                if dde_iteration_error <= self.rtol * T::from_f64(0.1).unwrap() {
212                    break; // DDE iteration converged
213                }
214                if iter_idx == max_iter - 1 { // Last iteration
215                    dde_iteration_failed = dde_iteration_error > self.rtol * T::from_f64(0.1).unwrap();
216                }
217            }
218            y_next_candidate_iter = y_new; // Update candidate solution for t+h
219
220            // Store ysti for potential stiffness detection
221            if iter_idx == max_iter - 1 || max_iter == 1 {
222                // Keep ysti from the final iteration for stiffness detection
223            }
224        } // End of DDE iteration loop
225
226        // Handle DDE iteration failure: reduce step size and retry
227        if dde_iteration_failed {
228            let sign = self.h.signum();
229            self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
230            // Ensure step size is not smaller than a fraction of the minimum lag, if applicable
231            if L > 0 && min_lag_abs > T::zero() && self.h.abs() < T::from_f64(2.0).unwrap() * min_lag_abs {
232                self.h = min_lag_abs * sign;
233            }
234            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
235            self.status = Status::RejectedStep;
236            return Ok(evals); // Return to retry step with smaller h
237        }
238
239        // Step acceptance/rejection logic
240        if err <= T::one() { // Step accepted
241            let y_new = y_next_candidate_iter;
242            let t_new = self.t + self.h;
243
244            // Calculate the new derivative at the new point
245            if L > 0 {
246                dde.lags(t_new, &y_new, &mut lags);
247                self.lagvals(t_new, &lags, &mut yd, phi);
248            }
249            dde.diff(t_new, &y_new, &yd, &mut self.dydt);
250            evals.fcn += 1;            // Stiffness detection (every 100 steps)
251            let n_stiff_threshold = 100;
252            if self.steps % n_stiff_threshold == 0 {
253                let mut stdnum = T::zero();
254                let mut stden = T::zero();
255                let sqr = {
256                    let mut yseg = V::zeros();
257                    for i in 0..self.stages {
258                        yseg += self.k[i] * self.b[i];
259                    }
260                    yseg - self.k[S-1]
261                };
262                for i in 0..sqr.len() {
263                    stdnum += sqr.get(i).powi(2);
264                }
265                let sqr = self.dydt - ysti;
266                for i in 0..sqr.len() {
267                    stden += sqr.get(i).powi(2);
268                }
269
270                if stden > T::zero() {
271                    let h_lamb = self.h * (stdnum / stden).sqrt();
272                    if h_lamb > T::from_f64(6.1).unwrap() {
273                        self.non_stiffness_counter = 0;
274                        self.stiffness_counter += 1;
275                        if self.stiffness_counter == 15 {
276                            self.status = Status::Error(Error::Stiffness {
277                                t: self.t,
278                                y: self.y,
279                            });
280                            return Err(Error::Stiffness {
281                                t: self.t,
282                                y: self.y,
283                            });
284                        }
285                    }
286                } else {
287                    self.non_stiffness_counter += 1;
288                    if self.non_stiffness_counter == 6 {
289                        self.stiffness_counter = 0;
290                    }
291                }
292            }
293
294            // Preparation for dense output / interpolation
295            self.cont[0] = self.y;
296            let ydiff = y_new - self.y;
297            self.cont[1] = ydiff;
298            let bspl = self.k[0] * self.h - ydiff;
299            self.cont[2] = bspl;
300            self.cont[3] = ydiff - self.dydt * self.h - bspl;
301
302            // If method has dense output stages, compute them
303            if let Some(bi) = &self.bi {
304                // Compute extra stages for dense output
305                if I > S {
306                    // First dense output coefficient, k{i=order+1}, is the derivative at the new point
307                    self.k[self.stages] = self.dydt;                    for i in S+1..I {
308                        let mut y_stage = V::zeros();
309                        for j in 0..i {
310                            y_stage += self.k[j] * self.a[i][j];
311                        }
312                        y_stage = self.y + y_stage * self.h;
313
314                        if L > 0 {
315                            dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut lags);
316                            // Manually inline the lagvals logic to avoid borrowing conflicts
317                            for lag_idx in 0..L {
318                                let t_delayed = (self.t + self.c[i] * self.h) - lags[lag_idx];
319                                
320                                // Check if delayed time falls within the history period (t_delayed <= t0)
321                                if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
322                                    yd[lag_idx] = phi(t_delayed);
323                                // If t_delayed is after t_prev then use interpolation function
324                                } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
325                                    if self.bi.is_some() {
326                                        let s = (t_delayed - self.t_prev) / self.h_prev;
327                                        let s1 = T::one() - s;        
328                                        let ilast = self.cont.len() - 1;
329                                        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, cont_i| {            
330                                            let factor = if cont_i >= 4 {
331                                                if (ilast - cont_i) % 2 == 1 { s1 } else { s }
332                                            } else {
333                                                if cont_i % 2 == 1 { s1 } else { s }
334                                            };
335                                            acc * factor + self.cont[cont_i]
336                                        });
337                                        yd[lag_idx] = self.cont[0] + poly * s;
338                                    } else {
339                                        yd[lag_idx] = cubic_hermite_interpolate(
340                                            self.t_prev, 
341                                            self.t, 
342                                            &self.y_prev, 
343                                            &self.y, 
344                                            &self.dydt_prev, 
345                                            &self.dydt, 
346                                            t_delayed
347                                        );
348                                    }
349                                } else {
350                                    // Search through history to find appropriate interpolation points
351                                    let mut found_interpolation = false;
352                                    let buffer = &self.history;
353                                    let mut buffer_iter = buffer.iter();
354                                    if let Some(mut prev_entry) = buffer_iter.next() {
355                                        for curr_entry in buffer_iter {
356                                            let (t_left, y_left, dydt_left) = prev_entry;
357                                            let (t_right, y_right, dydt_right) = curr_entry;
358                                            
359                                            let is_between = if self.h.signum() > T::zero() {
360                                                *t_left <= t_delayed && t_delayed <= *t_right
361                                            } else {
362                                                *t_right <= t_delayed && t_delayed <= *t_left
363                                            };
364                                            
365                                            if is_between {
366                                                yd[lag_idx] = cubic_hermite_interpolate(
367                                                    *t_left,
368                                                    *t_right,
369                                                    y_left,
370                                                    y_right,
371                                                    dydt_left,
372                                                    dydt_right,
373                                                    t_delayed
374                                                );
375                                                found_interpolation = true;
376                                                break;
377                                            }
378                                            prev_entry = curr_entry;
379                                        }
380                                    }
381                                    if !found_interpolation {
382                                        panic!("Insufficient history for t_delayed = {} (t_prev = {}, t = {})", t_delayed, self.t_prev, self.t);
383                                    }
384                                }
385                            }
386                        }
387                        dde.diff(self.t + self.c[i] * self.h, &y_stage, &yd, &mut self.k[i]);
388                        evals.fcn += 1;
389                    }
390                }
391
392                // Compute dense output coefficients
393                for i in 4..self.order {
394                    self.cont[i] = V::zeros();
395                    for j in 0..self.dense_stages {
396                        self.cont[i] += self.k[j] * bi[i][j];
397                    }
398                    self.cont[i] = self.cont[i] * self.h;
399                }
400            }
401
402            // For interpolation
403            self.t_prev = self.t;
404            self.y_prev = self.y;
405            self.dydt_prev = self.k[0];
406            self.h_prev = self.h;
407
408            // Update state to t + h
409            self.t = t_new;
410            self.y = y_new;
411            self.k[0] = self.dydt;
412
413            // Update continuous output buffer and remove old entries if max_delay is set
414            self.history.push_back((self.t, self.y, self.dydt));
415            if let Some(max_delay) = self.max_delay {
416                let cutoff_time = self.t - max_delay;
417                while let Some((t_front, _, _)) = self.history.get(1){
418                    if *t_front < cutoff_time {
419                        self.history.pop_front();
420                    } else {
421                        break;
422                    }
423                }
424            }            // Check if previous step is rejected
425            if let Status::RejectedStep = self.status {
426                self.status = Status::Solving;
427            }
428        } else {
429            // Step Rejected
430            self.status = Status::RejectedStep;
431        }
432
433        // Calculate new step size for adaptive methods
434        let order = T::from_usize(self.order).unwrap();
435        let err_order = T::one() / order;
436
437        // Step size controller
438        let scale = self.safety_factor * err.powf(-err_order);
439        let scale = scale.max(self.min_scale).min(self.max_scale);
440        self.h *= scale;
441
442        // Ensure step size is within bounds
443        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
444        
445        Ok(evals)
446    }
447
448    fn t(&self) -> T { self.t }
449    fn y(&self) -> &V { &self.y }
450    fn t_prev(&self) -> T { self.t_prev }
451    fn y_prev(&self) -> &V { &self.y_prev }
452    fn h(&self) -> T { self.h }
453    fn set_h(&mut self, h: T) { self.h = h; }
454    fn status(&self) -> &Status<T, V, D> { &self.status }
455    fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
456}
457
458impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {    
459    fn lagvals<const L: usize, H>(&mut self, t_stage: T, lags: &[T; L], yd: &mut [V; L], phi: &H) 
460    where 
461        H: Fn(T) -> V,
462    {
463        for i in 0..L {
464            let t_delayed = t_stage - lags[i];
465            
466            // Check if delayed time falls within the history period (t_delayed <= t0)
467            if (t_delayed - self.t0) * self.h.signum() <= T::default_epsilon() {
468                yd[i] = phi(t_delayed);
469            // If t_delayed is after t_prev then use interpolation function
470            } else if (t_delayed - self.t_prev) * self.h.signum() > T::default_epsilon() {
471                if self.bi.is_some() {
472                    let s = (t_delayed - self.t_prev) / self.h_prev;
473                    
474                    // Evaluate the interpolation polynomial at the requested time
475                    let s1 = T::one() - s;        
476                    
477                    // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
478                    let ilast = self.cont.len() - 1;
479                    let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {            
480                        let factor = if i >= 4 {
481                            // For the higher-order part (conpar), alternate s and s1 based on index parity
482                            if (ilast - i) % 2 == 1 { s1 } else { s }
483                        } else {
484                            // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
485                            if i % 2 == 1 { s1 } else { s }
486                        };
487                        acc * factor + self.cont[i]
488                    });
489                    
490                    // Final multiplication by s for the outermost level
491                    let y_interp = self.cont[0] + poly * s;
492                    yd[i] = y_interp;
493                } else {
494                    yd[i] = cubic_hermite_interpolate(
495                        self.t_prev, 
496                        self.t, 
497                        &self.y_prev, 
498                        &self.y, 
499                        &self.dydt_prev, 
500                        &self.dydt, 
501                        t_delayed
502                    );
503                }
504            // If t_delayed is before t_prev and after t0, we need to search in the history
505            } else {
506                // Search through history to find appropriate interpolation points
507                let mut found_interpolation = false;
508                let buffer = &self.history;
509                // Find two consecutive points that sandwich t_delayed using iterators
510                let mut buffer_iter = buffer.iter();
511                if let Some(mut prev_entry) = buffer_iter.next() {
512                    for curr_entry in buffer_iter {
513                        let (t_left, y_left, dydt_left) = prev_entry;
514                        let (t_right, y_right, dydt_right) = curr_entry;
515                        
516                        // Check if t_delayed is between these two points
517                        let is_between = if self.h.signum() > T::zero() {
518                            *t_left <= t_delayed && t_delayed <= *t_right
519                        } else {
520                            *t_right <= t_delayed && t_delayed <= *t_left
521                        };
522                        
523                        if is_between {
524                            yd[i] = cubic_hermite_interpolate(
525                                *t_left,
526                                *t_right,
527                                y_left,
528                                y_right,
529                                dydt_left,
530                                dydt_right,
531                                t_delayed
532                            );
533                            found_interpolation = true;
534                            break;
535                        }
536                        prev_entry = curr_entry;
537                    }
538                }
539                // If not found in history, this indicates insufficient history in buffer
540                if !found_interpolation {
541                    // Debug: show buffer contents
542                    let buffer = &self.history;
543                    println!("Buffer contents ({} entries):", buffer.len());
544                    for (idx, (t_buf, _, _)) in buffer.iter().enumerate() {
545                        println!("  [{}]: t = {}", idx, t_buf);
546                    }
547                    panic!("Insufficient history in history for t_delayed = {} (t_prev = {}, t = {}). Buffer may need to retain more points or there's a logic error in determining interpolation intervals.", t_delayed, self.t_prev, self.t);
548                }
549            }
550        }
551    }
552}
553
554impl<T: Real, V: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize> Interpolation<T, V> for ExplicitRungeKutta<Delay, DormandPrince, T, V, D, O, S, I> {
555    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {        
556        // Check if interpolation is out of bounds
557        let posneg = (self.t - self.t_prev).signum();
558        if (t_interp - self.t_prev) * posneg < T::zero() || (t_interp - self.t) * posneg > T::zero() {
559            return Err(Error::OutOfBounds {
560                t_interp,
561                t_prev: self.t_prev,
562                t_curr: self.t,
563            });
564        }        
565        
566        // Evaluate the interpolation polynomial at the requested time
567        let s = (t_interp - self.t_prev) / self.h_prev;
568        let s1 = T::one() - s;        
569        
570        // Functional implementation of: cont[0] + (cont[1] + (cont[2] + (cont[3] + conpar*s1)*s)*s1)*s
571        let ilast = self.cont.len() - 1;
572        let poly = (1..ilast).rev().fold(self.cont[ilast], |acc, i| {            
573            let factor = if i >= 4 {
574                // For the higher-order part (conpar), alternate s and s1 based on index parity
575                if (ilast - i) % 2 == 1 { s1 } else { s }
576            } else {
577                // For the main polynomial part, pattern is [s1, s, s1] for indices [3, 2, 1]
578                if i % 2 == 1 { s1 } else { s }
579            };
580            acc * factor + self.cont[i]
581        });
582        
583        // Final multiplication by s for the outermost level
584        let y_interp = self.cont[0] + poly * s;
585
586        Ok(y_interp)
587    }
588}