differential_equations/ode/methods/runge_kutta/explicit/
dopri5.rs

//! DOPRI5 (Dormand-Prince 5(4)) explicit Runge-Kutta method for ordinary differential equations.
2
3use crate::{
4    Error, Status,
5    alias::Evals,
6    interpolate::Interpolation,
7    ode::{ODENumericalMethod, ODE, methods::h_init},
8    traits::{CallBackData, Real, State},
9    utils::{constrain_step_size, validate_step_size_parameters},
10};
11
12/// Dormand Prince 5(4) Method for solving ordinary differential equations.
13/// 5th order Dormand Prince method with embedded 4th order error estimation and
14/// dense output interpolation.
15///
16/// # Example
17/// ```
18/// use differential_equations::prelude::*;
19/// use nalgebra::{SVector, vector};
20///
21/// let mut dopri5 = DOPRI5::new()
22///    .rtol(1e-6)
23///    .atol(1e-6);
24///
25/// let t0 = 0.0;
26/// let tf = 10.0;
27/// let y0 = vector![1.0, 0.0];
28/// struct Example;
29/// impl ODE<f64, SVector<f64, 2>> for Example {
30///    fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
31///       dydt[0] = y[1];
32///       dydt[1] = -y[0];
33///   }
34/// }
35/// let solution = ODEProblem::new(Example, t0, tf, y0).solve(&mut dopri5).unwrap();
36///
37/// let (t, y) = solution.last().unwrap();
38/// println!("Solution: ({}, {})", t, y);
39/// ```
40///
41/// # Settings
42/// * `rtol`   - Relative tolerance for the solver.
43/// * `atol`   - Absolute tolerance for the solver.
44/// * `h0`     - Initial step size.
45/// * `h_max`   - Maximum step size for the solver.
46/// * `max_steps` - Maximum number of steps for the solver.
47/// * `n_stiff` - Number of steps to check for stiffness.
48/// * `safe`   - Safety factor for step size prediction.
49/// * `fac1`   - Parameter for step size selection.
50/// * `fac2`   - Parameter for step size selection.
51/// * `beta`   - Beta for stabilized step size control.
52///
53/// # Default Settings
54/// * `rtol`   - 1e-3
55/// * `atol`   - 1e-6
56/// * `h0`     - None (Calculated by solver if None)
57/// * `h_max`   - None (Calculated by tf - t0 if None)
58/// * `h_min`   - 0.0
59/// * `max_steps` - 100_000
60/// * `n_stiff` - 1000
61/// * `safe`   - 0.9
62/// * `fac1`   - 0.2
63/// * `fac2`   - 10.0
64/// * `beta`   - 0.04
65///
66pub struct DOPRI5<T: Real, V: State<T>, D: CallBackData> {
67    // Initial Conditions
68    pub h0: T, // Initial Step Size
69
70    // Current iteration
71    t: T,
72    y: V,
73    h: T,
74
75    // Tolerances
76    pub rtol: T,
77    pub atol: T,
78
79    // Settings
80    pub h_max: T,
81    pub h_min: T,
82    pub max_steps: usize,
83    pub n_stiff: usize,
84
85    // DOPRI5 Specific Settings
86    pub safe: T,
87    pub fac1: T,
88    pub fac2: T,
89    pub beta: T,
90
91    // Derived Settings
92    expo1: T,
93    facc1: T,
94    facc2: T,
95    facold: T,
96    fac11: T,
97    fac: T,
98
99    // Iteration Tracking
100    status: Status<T, V, D>,
101    steps: usize,      // Number of Steps
102    n_accepted: usize, // Number of Accepted Steps
103
104    // Stiffness Detection
105    h_lamb: T,
106    non_stiff_counter: usize,
107    stiffness_counter: usize,
108
109    // Butcher tableau coefficients (converted to type T)
110    a: [[T; 7]; 7],
111    b: [T; 7],
112    c: [T; 7],
113    er: [T; 7],
114
115    // Dense output coefficients
116    d: [T; 7],
117
118    // Derivatives - using array instead of individually numbered variables
119    k: [V; 7], // k[0] is derivative at t, others are stage derivatives
120
121    // For Interpolation - using array instead of individually numbered variables
122    y_old: V,     // State at Previous Step
123    t_old: T,     // Time of Previous Step
124    h_old: T,     // Step Size of Previous Step
125    cont: [V; 5], // Interpolation coefficients
126}
127
128impl<T: Real, V: State<T>, D: CallBackData> ODENumericalMethod<T, V, D> for DOPRI5<T, V, D> {
129    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
130    where
131        F: ODE<T, V, D>,
132    {
133        let mut evals = Evals::new();
134
135        // Set Current State as Initial State
136        self.t = t0;
137        self.y = *y0;
138
139        // Calculate derivative at t0
140        ode.diff(t0, y0, &mut self.k[0]);
141        evals.fcn += 1; // Increment function evaluations for initial derivative calculation
142
143        // Initialize Previous State
144        self.t_old = self.t;
145        self.y_old = self.y;
146
147        // Calculate Initial Step
148        if self.h0 == T::zero() {
149            self.h0 = h_init(
150                ode, t0, tf, y0, 5, self.rtol, self.atol, self.h_min, self.h_max,
151            );
152            evals.fcn += 1; // Increment function evaluations for initial step size calculation
153
154            // Adjust h0 to be within bounds
155            let posneg = (tf - t0).signum();
156            if self.h0.abs() < self.h_min.abs() {
157                self.h0 = self.h_min.abs() * posneg;
158            } else if self.h0.abs() > self.h_max.abs() {
159                self.h0 = self.h_max.abs() * posneg;
160            }
161        }
162
163        // Check if h0 is within bounds, and h_min and h_max are valid
164        match validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf) {
165            Ok(h0) => self.h = h0,
166            Err(status) => return Err(status),
167        }
168
169        // Make sure iteration variables are reset
170        self.h_lamb = T::zero();
171        self.non_stiff_counter = 0;
172        self.stiffness_counter = 0;
173
174        // ODENumericalMethod is ready to go
175        self.status = Status::Initialized;
176
177        Ok(evals)
178    }
179
180    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
181    where
182        F: ODE<T, V, D>,
183    {
184        let mut evals = Evals::new();
185
186        // Check if Max Steps Reached
187        if self.steps >= self.max_steps {
188            self.status = Status::Error(Error::MaxSteps {
189                t: self.t,
190                y: self.y,
191            });
192            return Err(Error::MaxSteps {
193                t: self.t,
194                y: self.y,
195            });
196        }
197
198        // Check if Step Size is too smaller then machine default_epsilon
199        if self.h.abs() < T::default_epsilon() {
200            self.status = Status::Error(Error::StepSize {
201                t: self.t,
202                y: self.y,
203            });
204            return Err(Error::StepSize {
205                t: self.t,
206                y: self.y,
207            });
208        }
209
210        // The six stages
211        ode.diff(
212            self.t + self.c[1] * self.h,
213            &(self.y + self.k[0] * (self.a[1][0] * self.h)),
214            &mut self.k[1],
215        );
216        ode.diff(
217            self.t + self.c[2] * self.h,
218            &(self.y + self.k[0] * (self.a[2][0] * self.h) + self.k[1] * (self.a[2][1] * self.h)),
219            &mut self.k[2],
220        );
221        ode.diff(
222            self.t + self.c[3] * self.h,
223            &(self.y
224                + self.k[0] * (self.a[3][0] * self.h)
225                + self.k[1] * (self.a[3][1] * self.h)
226                + self.k[2] * (self.a[3][2] * self.h)),
227            &mut self.k[3],
228        );
229        ode.diff(
230            self.t + self.c[4] * self.h,
231            &(self.y
232                + self.k[0] * (self.a[4][0] * self.h)
233                + self.k[1] * (self.a[4][1] * self.h)
234                + self.k[2] * (self.a[4][2] * self.h)
235                + self.k[3] * (self.a[4][3] * self.h)),
236            &mut self.k[4],
237        );
238        ode.diff(
239            self.t + self.c[5] * self.h,
240            &(self.y
241                + self.k[0] * (self.a[5][0] * self.h)
242                + self.k[1] * (self.a[5][1] * self.h)
243                + self.k[2] * (self.a[5][2] * self.h)
244                + self.k[3] * (self.a[5][3] * self.h)
245                + self.k[4] * (self.a[5][4] * self.h)),
246            &mut self.k[5],
247        );
248
249        let ysti = self.y
250            + self.k[0] * (self.a[6][0] * self.h)
251            + self.k[2] * (self.a[6][2] * self.h)
252            + self.k[3] * (self.a[6][3] * self.h)
253            + self.k[4] * (self.a[6][4] * self.h)
254            + self.k[5] * (self.a[6][5] * self.h);
255
256        let t_new = self.t + self.h;
257        ode.diff(t_new, &ysti, &mut self.k[6]);
258
259        let y_new = self.y
260            + self.k[0] * (self.b[0] * self.h)
261            + self.k[2] * (self.b[2] * self.h)
262            + self.k[3] * (self.b[3] * self.h)
263            + self.k[4] * (self.b[4] * self.h)
264            + self.k[5] * (self.b[5] * self.h)
265            + self.k[6] * (self.b[6] * self.h);
266
267        ode.diff(t_new, &y_new, &mut self.k[1]);
268
269        evals.fcn += 7; // Increment function evaluations for all derivatives
270
271        // Calculate error using embedded method
272        let mut err = T::zero();
273
274        let n = self.y.len();
275        for i in 0..n {
276            let sk = self.atol + self.rtol * self.y.get(i).abs().max(y_new.get(i).abs());
277            let erri = self.h
278                * (self.er[0] * self.k[0].get(i)
279                    + self.er[2] * self.k[2].get(i)
280                    + self.er[3] * self.k[3].get(i)
281                    + self.er[4] * self.k[4].get(i)
282                    + self.er[5] * self.k[5].get(i)
283                    + self.er[6] * self.k[6].get(i));
284            err += (erri / sk).powi(2);
285        }
286        err = (err / T::from_usize(n).unwrap()).sqrt();
287
288        // Computation of h_new
289        self.fac11 = err.powf(self.expo1);
290        // Lund-stabilization
291        self.fac = self.fac11 / self.facold.powf(self.beta);
292        // Requirement that fac1 <= h_new/h <= fac2
293        self.fac = self.facc2.max(self.facc1.min(self.fac / self.safe));
294        let mut h_new = self.h / self.fac;
295
296        if err <= T::one() {
297            // Step Accepted
298            self.facold = err.max(T::from_f64(1.0e-4).unwrap());
299            self.n_accepted += 1;
300
301            // stiffness detection
302            if self.n_accepted % self.n_stiff == 0 || self.stiffness_counter > 0 {
303                let mut stnum = T::zero();
304                let mut stden = T::zero();
305
306                for i in 0..n {
307                    let stnum_i = self.k[1].get(i) - self.k[6].get(i);
308                    stnum += stnum_i * stnum_i;
309
310                    let stden_i = y_new.get(i) - ysti.get(i);
311                    stden += stden_i * stden_i;
312                }
313
314                if stden > T::zero() {
315                    self.h_lamb = self.h * (stnum / stden).sqrt();
316                }
317
318                if self.h_lamb > T::from_f64(3.25).unwrap() {
319                    self.non_stiff_counter = 0;
320                    self.stiffness_counter += 1;
321                    if self.stiffness_counter == 15 {
322                        // Early Exit Stiffness Detected
323                        self.status = Status::Error(Error::Stiffness {
324                            t: self.t,
325                            y: self.y,
326                        });
327                        return Err(Error::Stiffness {
328                            t: self.t,
329                            y: self.y,
330                        });
331                    }
332                } else {
333                    self.non_stiff_counter += 1;
334                    if self.non_stiff_counter == 6 {
335                        self.stiffness_counter = 0;
336                    }
337                }
338            }
339
340            // Prepare for dense output / interpolation
341            // Store data for interpolation
342            let ydiff = y_new - self.y;
343            let bspl = self.k[0] * self.h - ydiff;
344
345            self.cont[0] = self.y;
346            self.cont[1] = ydiff;
347            self.cont[2] = bspl;
348            self.cont[3] = ydiff - self.k[1] * self.h - bspl;
349
350            // Compute the dense output coefficient
351            self.cont[4] = (self.k[0] * self.d[0]
352                + self.k[2] * self.d[2]
353                + self.k[3] * self.d[3]
354                + self.k[4] * self.d[4]
355                + self.k[5] * self.d[5]
356                + self.k[6] * self.d[6])
357                * self.h;
358
359            // For Interpolation
360            self.y_old = self.y;
361            self.t_old = self.t;
362            self.h_old = self.h;
363
364            // Update State
365            self.k[0] = self.k[1];
366            self.y = y_new;
367            self.t = t_new;
368
369            // Check if previous step rejected
370            if let Status::RejectedStep = self.status {
371                h_new = self.h.min(h_new);
372                self.status = Status::Solving;
373            }
374        } else {
375            // Step Rejected
376            h_new = self.h / self.facc1.min(self.fac11 / self.safe);
377            self.status = Status::RejectedStep;
378        }
379
380        // Step Complete
381        self.h = constrain_step_size(h_new, self.h_min, self.h_max);
382        Ok(evals)
383    }
384
385    fn t(&self) -> T {
386        self.t
387    }
388
389    fn y(&self) -> &V {
390        &self.y
391    }
392
393    fn t_prev(&self) -> T {
394        self.t_old
395    }
396
397    fn y_prev(&self) -> &V {
398        &self.y_old
399    }
400
401    fn h(&self) -> T {
402        self.h
403    }
404
405    fn set_h(&mut self, h: T) {
406        self.h = h;
407    }
408
409    fn status(&self) -> &Status<T, V, D> {
410        &self.status
411    }
412
413    fn set_status(&mut self, status: Status<T, V, D>) {
414        self.status = status;
415    }
416}
417
418impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for DOPRI5<T, V, D> {
419    fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
420        // Check if interpolation is out of bounds
421        if t_interp < self.t_old || t_interp > self.t {
422            return Err(Error::OutOfBounds {
423                t_interp,
424                t_prev: self.t_old,
425                t_curr: self.t,
426            });
427        }
428
429        // Evaluate the interpolation polynomial at the requested time
430        let s = (t_interp - self.t_old) / self.h_old;
431        let s1 = T::one() - s;
432
433        // Use the provided dense output formula
434        let y_interp = self.cont[0]
435            + (self.cont[1] + (self.cont[2] + (self.cont[3] + self.cont[4] * s1) * s) * s1) * s;
436
437        Ok(y_interp)
438    }
439}
440
441impl<T: Real, V: State<T>, D: CallBackData> DOPRI5<T, V, D> {
442    /// Creates a new DOPRI5 ODENumericalMethod.
443    ///
444    /// # Returns
445    /// * DOPRI5 Struct ready to go for solving.
446    ///  
447    pub fn new() -> Self {
448        DOPRI5 {
449            ..Default::default()
450        }
451    }
452
453    // Builder Functions
454    pub fn rtol(mut self, rtol: T) -> Self {
455        self.rtol = rtol;
456        self
457    }
458
459    pub fn atol(mut self, atol: T) -> Self {
460        self.atol = atol;
461        self
462    }
463
464    pub fn h0(mut self, h0: T) -> Self {
465        self.h0 = h0;
466        self
467    }
468
469    pub fn h_max(mut self, h_max: T) -> Self {
470        self.h_max = h_max;
471        self
472    }
473
474    pub fn h_min(mut self, h_min: T) -> Self {
475        self.h_min = h_min;
476        self
477    }
478
479    pub fn max_steps(mut self, max_steps: usize) -> Self {
480        self.max_steps = max_steps;
481        self
482    }
483
484    pub fn n_stiff(mut self, n_stiff: usize) -> Self {
485        self.n_stiff = n_stiff;
486        self
487    }
488
489    pub fn safe(mut self, safe: T) -> Self {
490        self.safe = safe;
491        self
492    }
493
494    pub fn beta(mut self, beta: T) -> Self {
495        self.beta = beta;
496        self
497    }
498
499    pub fn fac1(mut self, fac1: T) -> Self {
500        self.fac1 = fac1;
501        self
502    }
503
504    pub fn fac2(mut self, fac2: T) -> Self {
505        self.fac2 = fac2;
506        self
507    }
508
509    pub fn expo1(mut self, expo1: T) -> Self {
510        self.expo1 = expo1;
511        self
512    }
513
514    pub fn facc1(mut self, facc1: T) -> Self {
515        self.facc1 = facc1;
516        self
517    }
518
519    pub fn facc2(mut self, facc2: T) -> Self {
520        self.facc2 = facc2;
521        self
522    }
523}
524
525impl<T: Real, V: State<T>, D: CallBackData> Default for DOPRI5<T, V, D> {
526    fn default() -> Self {
527        // Convert coefficient arrays from f64 to type T
528        let a = DOPRI5_A.map(|row| row.map(|x| T::from_f64(x).unwrap()));
529        let b = DOPRI5_B.map(|x| T::from_f64(x).unwrap());
530        let c = DOPRI5_C.map(|x| T::from_f64(x).unwrap());
531        let er = DOPRI5_E.map(|x| T::from_f64(x).unwrap());
532        let d = DOPRI5_D.map(|x| T::from_f64(x).unwrap());
533
534        // Create arrays of zeros for k and cont matrices
535        let k_zeros = [V::zeros(); 7];
536        let cont_zeros = [V::zeros(); 5];
537
538        DOPRI5 {
539            // State Variables
540            t: T::zero(),
541            y: V::zeros(),
542            h: T::zero(),
543
544            // Settings
545            h0: T::zero(),
546            rtol: T::from_f64(1e-3).unwrap(),
547            atol: T::from_f64(1e-6).unwrap(),
548            h_max: T::infinity(),
549            h_min: T::zero(),
550            max_steps: 100_000,
551            n_stiff: 1000,
552            safe: T::from_f64(0.9).unwrap(),
553            fac1: T::from_f64(0.2).unwrap(),
554            fac2: T::from_f64(10.0).unwrap(),
555            beta: T::from_f64(0.04).unwrap(),
556            expo1: T::from_f64(1.0 / 5.0).unwrap(),
557            facc1: T::from_f64(1.0 / 0.2).unwrap(),
558            facc2: T::from_f64(1.0 / 10.0).unwrap(),
559            facold: T::from_f64(1.0e-4).unwrap(),
560            fac11: T::zero(),
561            fac: T::zero(),
562
563            // Butcher Tableau Coefficients
564            a,
565            b,
566            c,
567            er,
568            d,
569
570            // Status and Counters
571            status: Status::Uninitialized,
572            h_lamb: T::zero(),
573            non_stiff_counter: 0,
574            stiffness_counter: 0,
575            steps: 0,
576            n_accepted: 0,
577
578            // Coefficents and temporary storage
579            k: k_zeros,
580            y_old: V::zeros(),
581            t_old: T::zero(),
582            h_old: T::zero(),
583            cont: cont_zeros,
584        }
585    }
586}
587
// Butcher tableau of the Dormand-Prince 5(4) pair.
//
// Coupling coefficients a[i][j], stored as a dense, strictly lower triangular
// 7x7 matrix. Row i holds the weights used to build the argument of stage
// derivative k[i]. The last row equals the 5th order weights b.
const DOPRI5_A: [[f64; 7]; 7] = [
    [0.0; 7],
    [0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0, 0.0, 0.0],
    [19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0, 0.0, 0.0, 0.0],
    [9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0, 0.0, 0.0],
    [35.0 / 384.0, 0.0, 500.0 / 1113.0, 125.0 / 192.0, -2187.0 / 6784.0, 11.0 / 84.0, 0.0],
];
624
// Nodes c_i: stage derivative k[i] is evaluated at time t + c_i * h.
const DOPRI5_C: [f64; 7] = [0.0, 0.2, 0.3, 0.8, 8.0 / 9.0, 1.0, 1.0];
635
// Weights b_i of the 5th order solution. Entries 1 and 6 vanish, so the
// second and the last stage derivative do not enter the update.
const DOPRI5_B: [f64; 7] = [
    35.0 / 384.0,
    0.0,
    500.0 / 1113.0,
    125.0 / 192.0,
    -2187.0 / 6784.0,
    11.0 / 84.0,
    0.0,
];
646
// Weights of the embedded error estimate: the difference between the 5th
// order and the embedded 4th order weights, hence they sum to zero.
const DOPRI5_E: [f64; 7] = [
    71.0 / 57600.0,
    0.0,
    -71.0 / 16695.0,
    71.0 / 1920.0,
    -17253.0 / 339200.0,
    22.0 / 525.0,
    -1.0 / 40.0,
];
657
// Dense output weights d_i: the solver combines them with the stage
// derivatives and the step size to form the last interpolation coefficient.
const DOPRI5_D: [f64; 7] = [
    -12715105075.0 / 11282082432.0,
    0.0,
    87487479700.0 / 32700410799.0,
    -10690763975.0 / 1880347072.0,
    701980252875.0 / 199316789632.0,
    -1453857185.0 / 822651844.0,
    69997945.0 / 29380423.0,
];