Skip to main content

numerics_ode/
dormand_prince.rs

//! Dormand-Prince method (DOPRI5) — embedded RK5(4) pair with adaptive step-size control.
//!
//! Uses the 7-stage FSAL (First Same As Last) Dormand-Prince coefficients.
//! The 5th-order solution is propagated. The difference between it and the embedded
//! 4th-order solution gives the local error estimate that drives step-size adaptation.
5
/// Result of an adaptive ODE solve.
#[derive(Debug, Clone)]
pub struct AdaptiveResult {
    /// x values: the initial `x0` followed by the x reached after every accepted step.
    pub xs: Vec<f64>,
    /// y values: the initial `y0` followed by the y reached after every accepted step.
    pub ys: Vec<f64>,
    /// Number of rejected steps.
    pub rejected_steps: usize,
    /// Number of accepted steps.
    pub accepted_steps: usize,
}

/// Solve the scalar IVP `dy/dx = f(x, y)`, `y(x0) = y0`, with the Dormand-Prince RK5(4)
/// method and adaptive step-size control.
///
/// The 5th-order solution is propagated. The embedded 4th-order solution only provides
/// the local error estimate `|err|`, which is compared directly against `tol` (absolute
/// error control). The tableau is FSAL, so the last stage of an accepted step is reused
/// as the first stage of the next. Each step attempt therefore costs six evaluations of
/// `f`, plus one evaluation at the initial point.
///
/// Integration runs forward when `x_end >= x0` and backward otherwise. The solver never
/// steps past `x_end`: the final step is shortened to land on it, even if that makes it
/// shorter than `h_min`. If `x0` already coincides with `x_end` (within the round-off
/// tolerance used to detect arrival), no step is taken.
///
/// # Arguments
///
/// * `f`         — Right-hand side `dy/dx = f(x, y)`.
/// * `x0`        — Initial x.
/// * `y0`        — Initial y.
/// * `x_end`     — Terminal x.
/// * `h_init`    — Initial step size guess. The absolute value is used and it is clamped
///                 to `[h_min, h_max]`.
/// * `tol`       — Absolute local error tolerance per step.
/// * `h_min`     — Minimum step size (absolute value used). A step of this size is
///                 accepted even if its error estimate exceeds `tol`. If `h_min` and
///                 `h_max` conflict, `h_min` wins.
/// * `h_max`     — Maximum step size (absolute value used).
/// * `max_steps` — Safety limit on total step attempts (accepted + rejected). When the
///                 limit is hit, the returned solution stops short of `x_end`.
///
/// # Returns
///
/// [`AdaptiveResult`] containing the initial point, every accepted point, and step
/// statistics.
// The public signature predates a config struct; kept as-is for backward compatibility.
#[allow(clippy::too_many_arguments)]
pub fn solve_adaptive(
    f: &dyn Fn(f64, f64) -> f64,
    x0: f64,
    y0: f64,
    x_end: f64,
    h_init: f64,
    tol: f64,
    h_min: f64,
    h_max: f64,
    max_steps: usize,
) -> AdaptiveResult {
    // Dormand-Prince Butcher tableau.
    // C: nodes c2..c7 (exact values rather than floating row sums of A).
    const C: [f64; 5] = [1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0];
    // A: stage coefficients a_ij for stages 2..6.
    const A: [[f64; 5]; 5] = [
        [1.0 / 5.0, 0.0, 0.0, 0.0, 0.0],
        [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0],
        [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0],
        [19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0, 0.0],
        [9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0],
    ];
    // 5th-order weights. They equal row 7 of A, which is what makes the tableau FSAL.
    const B5: [f64; 6] = [
        35.0 / 384.0,
        0.0,
        500.0 / 1113.0,
        125.0 / 192.0,
        -2187.0 / 6784.0,
        11.0 / 84.0,
    ];
    // Error weights: 5th-order minus embedded 4th-order weights.
    const E: [f64; 7] = [
        71.0 / 57600.0,
        0.0,
        -71.0 / 16695.0,
        71.0 / 1920.0,
        -17253.0 / 339200.0,
        22.0 / 525.0,
        -1.0 / 40.0,
    ];

    // Step-size controller: h_new = h * SAFETY * (tol / err)^EXPONENT, with the factor
    // limited to at most MAX_GROWTH after an accepted step and at least MAX_SHRINK after a
    // rejection.
    const SAFETY: f64 = 0.9;
    const EXPONENT: f64 = 0.2;
    const MAX_GROWTH: f64 = 5.0;
    const MAX_SHRINK: f64 = 0.1;

    // +1 for forward integration, -1 for backward.
    let sign = if x_end >= x0 { 1.0 } else { -1.0 };
    let h_min_abs = h_min.abs();
    let h_max_abs = h_max.abs();
    // Distance below which x counts as having reached x_end. It absorbs the round-off
    // left by the final `x + h`.
    let end_tol = 1e-14 * x_end.abs().max(1.0);

    let mut xs = vec![x0];
    let mut ys = vec![y0];
    let mut x = x0;
    let mut y = y0;
    // Proposed step magnitude. Its direction is applied per attempt via `sign`.
    let mut h_abs = h_init.abs();
    let mut accepted = 0usize;
    let mut rejected = 0usize;
    let mut total = 0usize;
    // FSAL: first stage f(x, y) at the current point. It is replaced by the last stage k7
    // after every accepted step and left untouched on rejection (x and y do not change).
    let mut k1 = f(x, y);

    loop {
        let remaining = x_end - x;
        if remaining.abs() < end_tol || total >= max_steps {
            break;
        }

        // Keep the proposal within [h_min, h_max]. h_min is applied last, so it wins
        // when the bounds conflict.
        h_abs = h_abs.min(h_max_abs).max(h_min_abs);
        // Clamp to the remaining distance *after* the h_min floor. The final step may be
        // shorter than h_min, but it must never carry x past x_end.
        let h = if sign * remaining <= h_abs {
            remaining
        } else {
            sign * h_abs
        };

        let k2 = f(x + C[0] * h, y + h * (A[0][0] * k1));
        let k3 = f(x + C[1] * h, y + h * (A[1][0] * k1 + A[1][1] * k2));
        let k4 = f(
            x + C[2] * h,
            y + h * (A[2][0] * k1 + A[2][1] * k2 + A[2][2] * k3),
        );
        let k5 = f(
            x + C[3] * h,
            y + h * (A[3][0] * k1 + A[3][1] * k2 + A[3][2] * k3 + A[3][3] * k4),
        );
        let k6 = f(
            x + C[4] * h,
            y + h * (A[4][0] * k1 + A[4][1] * k2 + A[4][2] * k3 + A[4][3] * k4 + A[4][4] * k5),
        );

        let x_new = x + h;
        let y_new = y
            + h * (B5[0] * k1 + B5[1] * k2 + B5[2] * k3 + B5[3] * k4 + B5[4] * k5 + B5[5] * k6);
        let k7 = f(x_new, y_new);
        let err = h
            * (E[0] * k1
                + E[1] * k2
                + E[2] * k3
                + E[3] * k4
                + E[4] * k5
                + E[5] * k6
                + E[6] * k7);
        let err_norm = err.abs();
        total += 1;

        let h_used = h.abs();
        // Accept when within tolerance, or unconditionally once the step cannot shrink
        // any further.
        if err_norm <= tol || h_used <= h_min_abs {
            x = x_new;
            y = y_new;
            k1 = k7;
            xs.push(x);
            ys.push(y);
            accepted += 1;

            h_abs = if err_norm > 0.0 {
                h_used * (SAFETY * (tol / err_norm).powf(EXPONENT)).min(MAX_GROWTH)
            } else {
                // A zero error estimate gives no scale information; just double.
                h_used * 2.0
            };
        } else {
            rejected += 1;
            h_abs = h_used * (SAFETY * (tol / err_norm).powf(EXPONENT)).max(MAX_SHRINK);
        }
    }

    AdaptiveResult {
        xs,
        ys,
        rejected_steps: rejected,
        accepted_steps: accepted,
    }
}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Last y value stored in a solve result.
    fn final_y(res: &AdaptiveResult) -> f64 {
        *res.ys.last().unwrap()
    }

    #[test]
    fn exponential_growth() {
        let rhs = |_x: f64, y: f64| y;
        let exact = 1.0_f64.exp();
        let last = final_y(&solve_adaptive(&rhs, 0.0, 1.0, 1.0, 0.1, 1e-8, 1e-10, 0.5, 10_000));
        assert!((last - exact).abs() < 1e-6, "got {last}, exact {exact}");
    }

    #[test]
    fn constant_rhs() {
        let rhs = |_x: f64, _y: f64| 0.0;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 42.0, 1.0, 0.1, 1e-6, 1e-10, 0.5, 1000));
        assert!((last - 42.0).abs() < 1e-12);
    }

    #[test]
    fn linear_rhs() {
        // y' = 5 on [0, 2] starting from 0 ends at 10.
        let rhs = |_x: f64, _y: f64| 5.0;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 0.0, 2.0, 0.1, 1e-8, 1e-10, 0.5, 1000));
        assert!((last - 10.0).abs() < 1e-8);
    }

    #[test]
    fn sinusoidal_rhs() {
        // Integral of cos over [0, pi/2] is 1.
        let rhs = |x: f64, _y: f64| x.cos();
        let upper = std::f64::consts::FRAC_PI_2;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 0.0, upper, 0.1, 1e-10, 1e-12, 0.2, 10_000));
        assert!((last - 1.0).abs() < 1e-8);
    }

    #[test]
    fn exponential_decay() {
        let rhs = |_x: f64, y: f64| -y;
        let expected = (-2.0_f64).exp();
        let last = final_y(&solve_adaptive(&rhs, 0.0, 1.0, 2.0, 0.1, 1e-8, 1e-10, 0.5, 10_000));
        assert!((last - expected).abs() < 1e-6);
    }

    #[test]
    fn adaptive_reduces_steps_for_easy_problems() {
        let rhs = |_x: f64, _y: f64| 0.0;
        let outcome = solve_adaptive(&rhs, 0.0, 1.0, 10.0, 0.1, 1e-6, 1e-10, 1.0, 10_000);
        assert!(
            outcome.accepted_steps < 20,
            "took {} steps for trivial problem",
            outcome.accepted_steps
        );
    }

    #[test]
    fn quadratic_rhs() {
        // y' = 2x on [0, 3] starting from 0 ends at 9.
        let rhs = |x: f64, _y: f64| 2.0 * x;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 0.0, 3.0, 0.1, 1e-10, 1e-12, 0.5, 10_000));
        assert!((last - 9.0).abs() < 1e-8);
    }

    #[test]
    fn result_contains_initial_point() {
        let rhs = |_x: f64, y: f64| y;
        let outcome = solve_adaptive(&rhs, 0.0, 1.0, 1.0, 0.1, 1e-6, 1e-10, 0.5, 1000);
        assert_eq!(outcome.xs[0], 0.0);
        assert_eq!(outcome.ys[0], 1.0);
    }

    #[test]
    fn tol_affects_accuracy() {
        let rhs = |_x: f64, y: f64| y;
        let exact = 1.0_f64.exp();
        let error_for = |tol: f64, h_min: f64| {
            let outcome = solve_adaptive(&rhs, 0.0, 1.0, 1.0, 0.1, tol, h_min, 0.5, 10_000);
            (outcome.ys.last().unwrap() - exact).abs()
        };
        let e_loose = error_for(1e-4, 1e-10);
        let e_tight = error_for(1e-12, 1e-14);
        assert!(e_tight < e_loose, "tight tol should be more accurate: {e_tight} vs {e_loose}");
    }

    #[test]
    fn stiff_problem() {
        let rhs = |_x: f64, y: f64| -50.0 * y;
        let exact = (-5.0_f64).exp();
        let last = final_y(&solve_adaptive(&rhs, 0.0, 1.0, 0.1, 0.001, 1e-6, 1e-8, 0.01, 100_000));
        assert!((last - exact).abs() < 1e-4, "got {last}, exact {exact}");
    }

    #[test]
    fn negative_direction() {
        // Integrate y' = y backward from x = 1 (y = e) to x = 0, where y = 1.
        let rhs = |_x: f64, y: f64| y;
        let start = 1.0_f64.exp();
        let last = final_y(&solve_adaptive(&rhs, 1.0, start, 0.0, 0.1, 1e-8, 1e-12, 0.5, 10_000));
        assert!((last - 1.0).abs() < 1e-5, "got {last}");
    }
}