Skip to main content

numerics_ode/
rk4.rs

//! Classical fourth-order Runge-Kutta method (RK4).
//!
//! Each step evaluates four stages (Butcher tableau with nodes
//! `c = (0, 1/2, 1/2, 1)` and weights `b = (1/6, 1/3, 1/3, 1/6)`):
//! ```text
//! k1 = f(x_n,       y_n)
//! k2 = f(x_n + h/2, y_n + h·k1/2)
//! k3 = f(x_n + h/2, y_n + h·k2/2)
//! k4 = f(x_n + h,   y_n + h·k3)
//! y_{n+1} = y_n + (h/6)(k1 + 2k2 + 2k3 + k4)
//! ```
//!
//! Local truncation error per step is O(h⁵); global truncation error is O(h⁴).
13
/// Solve the scalar initial-value problem `dy/dx = f(x, y)`, `y(x0) = y0`,
/// with the classical RK4 method on a uniform grid of `n` steps.
///
/// # Arguments
///
/// * `f`     — Right-hand side `dy/dx = f(x, y)`.
/// * `x0`    — Initial value of the independent variable.
/// * `y0`    — Initial value `y(x0)`.
/// * `x_end` — Terminal value of the independent variable. It may be smaller
///   than `x0`, in which case the step `h` is negative and integration runs
///   backwards.
/// * `n`     — Number of steps (≥ 1).
///
/// # Returns
///
/// `(xs, ys)` — two vectors of length `n + 1` holding the grid points and the
/// approximate solution at each of them. `xs[0] == x0` and `xs[n] == x_end`
/// exactly.
///
/// # Panics
///
/// Panics if `n == 0`.
pub fn solve(
    f: &dyn Fn(f64, f64) -> f64,
    x0: f64,
    y0: f64,
    x_end: f64,
    n: usize,
) -> (Vec<f64>, Vec<f64>) {
    assert!(n >= 1, "number of steps must be ≥ 1");
    let h = (x_end - x0) / n as f64;
    // Scaling by 0.5 is exact, so this matches `h / 2.0` bit-for-bit.
    let half_h = 0.5 * h;
    let mut xs = Vec::with_capacity(n + 1);
    let mut ys = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0);
    let mut x = x0;
    let mut y = y0;
    for i in 1..=n {
        let k1 = f(x, y);
        let k2 = f(x + half_h, y + half_h * k1);
        let k3 = f(x + half_h, y + half_h * k2);
        let k4 = f(x + h, y + h * k3);
        y += (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
        // Derive x from the step index instead of accumulating `x += h`, so
        // rounding error does not drift across many steps; the final grid
        // point is pinned to `x_end` so callers get the exact terminal abscissa.
        x = if i == n { x_end } else { x0 + i as f64 * h };
        xs.push(x);
        ys.push(y);
    }
    (xs, ys)
}
48
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exponential_growth() {
        let f = |_x: f64, y: f64| y;
        let (_, ys) = solve(&f, 0.0, 1.0, 1.0, 10);
        let exact = 1.0_f64.exp();
        assert!((ys[10] - exact).abs() < 1e-5, "got {}, exact {}", ys[10], exact);
    }

    #[test]
    fn exponential_growth_high_accuracy() {
        let f = |_x: f64, y: f64| y;
        let (_, ys) = solve(&f, 0.0, 1.0, 1.0, 100);
        let exact = 1.0_f64.exp();
        assert!((ys[100] - exact).abs() < 1e-8);
    }

    /// RK4 is fourth-order: error ∝ h⁴
    #[test]
    fn convergence_order_rk4() {
        let f = |_x: f64, y: f64| y;
        let exact = 1.0_f64.exp();
        let e1 = (solve(&f, 0.0, 1.0, 1.0, 10).1[10] - exact).abs();
        let e2 = (solve(&f, 0.0, 1.0, 1.0, 20).1[20] - exact).abs();
        let ratio = e1 / e2;
        // ratio should be ~16 (2⁴) for fourth-order method
        assert!(ratio > 10.0 && ratio < 25.0, "convergence ratio = {ratio}, expected ~16");
    }

    /// The returned grid has `n + 1` points running from `x0` to `x_end`.
    #[test]
    fn grid_shape() {
        let f = |_x: f64, y: f64| y;
        let (xs, ys) = solve(&f, 0.0, 1.0, 1.0, 10);
        assert_eq!(xs.len(), 11);
        assert_eq!(ys.len(), 11);
        assert_eq!(xs[0], 0.0);
        assert_eq!(ys[0], 1.0);
        assert!((xs[10] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn constant_rhs() {
        let f = |_x, _y| 0.0;
        let (_, ys) = solve(&f, 0.0, 42.0, 1.0, 10);
        for &y in &ys {
            assert!((y - 42.0).abs() < 1e-12);
        }
    }

    #[test]
    fn linear_rhs() {
        let f = |_x, _y| 2.0;
        let (_, ys) = solve(&f, 0.0, 0.0, 3.0, 100);
        // exact: y = 2x, y(3) = 6
        assert!((ys[100] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn quadratic_rhs() {
        // dy/dx = 2x → y = x²
        let f = |x: f64, _y: f64| 2.0 * x;
        let (_, ys) = solve(&f, 0.0, 0.0, 5.0, 100);
        assert!((ys[100] - 25.0).abs() < 1e-8);
    }

    #[test]
    fn sinusoidal_rhs() {
        // dy/dx = cos(x) → y = sin(x)
        let f = |x: f64, _y: f64| x.cos();
        let (_, ys) = solve(&f, 0.0, 0.0, std::f64::consts::PI / 2.0, 100);
        assert!((ys[100] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn exponential_decay() {
        let f = |_x: f64, y: f64| -y;
        let (_, ys) = solve(&f, 0.0, 1.0, 2.0, 100);
        assert!((ys[100] - (-2.0_f64).exp()).abs() < 1e-9);
    }

    #[test]
    #[should_panic(expected = "number of steps must be ≥ 1")]
    fn panics_on_zero_steps() {
        let f = |_x, _y| 0.0;
        solve(&f, 0.0, 1.0, 1.0, 0);
    }

    #[test]
    fn single_step() {
        let f = |_x, y| y;
        let (xs, ys) = solve(&f, 0.0, 1.0, 1.0, 1);
        assert_eq!(xs.len(), 2);
        assert_eq!(ys.len(), 2);
        let k1 = 1.0_f64;
        let k2 = 1.5_f64; // f(0.5, 1.5) = 1.5
        let k3 = 1.75_f64; // f(0.5, 1.75) = 1.75
        let k4 = 2.75_f64; // f(1.0, 2.75) = 2.75
        let expected = 1.0 + (1.0 / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4);
        assert!((ys[1] - expected).abs() < 1e-12);
    }

    #[test]
    fn negative_direction() {
        let f = |_x, y| y;
        let (_, ys) = solve(&f, 1.0, 1.0_f64.exp(), 0.0, 100);
        assert!((ys[100] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn stiff_problem_moderate() {
        // dy/dx = -15y. With h = 1e-4 the product h·λ = -0.0015 lies deep inside
        // RK4's real-axis stability interval (≈ [-2.79, 0]), so this is really a
        // fast-decay accuracy test. The exact value e^-15 ≈ 3.06e-7 is far below
        // any useful absolute tolerance, so compare the relative error instead.
        let f = |_x, y| -15.0 * y;
        let (_, ys) = solve(&f, 0.0, 1.0, 1.0, 10_000);
        let exact = (-15.0_f64).exp();
        let rel_err = ((ys[10_000] - exact) / exact).abs();
        assert!(rel_err < 1e-9, "relative error = {rel_err}");
    }
}