mathhook_core/calculus/ode/numerical/
runge_kutta.rs

//! Runge-Kutta 4th order method for numerical ODE solving
//!
//! Implements the classic fourth-order Runge-Kutta (RK4) method for
//! first-order ODEs `y' = f(x, y)`. One step of size `h` from `(x_n, y_n)` is:
//!
//! ```text
//! k1 = f(x_n, y_n)
//! k2 = f(x_n + h/2, y_n + h*k1/2)
//! k3 = f(x_n + h/2, y_n + h*k2/2)
//! k4 = f(x_n + h, y_n + h*k3)
//! y_{n+1} = y_n + h/6 * (k1 + 2*k2 + 2*k3 + k4)
//! ```
10
11use crate::calculus::ode::first_order::ODEError;
12
/// Solves a first-order ODE `dy/dx = f(x, y)` with the classic fourth-order
/// Runge-Kutta (RK4) method using a fixed step size.
///
/// Integration runs from `x0` towards `x_end` in whichever direction `x_end`
/// lies, so `x_end < x0` integrates backwards. Full steps of size `step` are
/// taken while they stay short of `x_end`; the last step is resized so that it
/// ends exactly at `x_end`.
///
/// # Arguments
///
/// * `f` - The derivative function f(x, y) where dy/dx = f(x, y)
/// * `x0` - Initial x value
/// * `y0` - Initial y value
/// * `x_end` - Final x value
/// * `step` - Step size h (must be positive)
///
/// # Returns
///
/// Vector of (x, y) solution points. The first point is always `(x0, y0)`.
/// When integration proceeds, the last point's x is `x_end`. A remaining gap of
/// at most `1e-10` is treated as already reached.
///
/// Degenerate inputs return just `[(x0, y0)]` without calling `f`. These are:
/// a non-positive or NaN `step`, a non-finite `x0` or `x_end`, and an `x_end`
/// within `1e-10` of `x0`. See [`solve_rk4`] for a variant that reports invalid
/// input as an error.
///
/// # Examples
///
/// ```
/// use mathhook_core::calculus::ode::numerical::runge_kutta::rk4_method;
///
/// let solution = rk4_method(
///     |x, _y| x,
///     0.0,
///     0.0,
///     1.0,
///     0.1
/// );
///
/// assert!(solution.len() > 0);
/// assert_eq!(solution[0], (0.0, 0.0));
/// let (_, y_final) = solution.last().unwrap();
/// assert!((y_final - 0.5).abs() < 0.001);
/// ```
pub fn rk4_method<F>(f: F, x0: f64, y0: f64, x_end: f64, step: f64) -> Vec<(f64, f64)>
where
    F: Fn(f64, f64) -> f64,
{
    // Remaining distances at or below this count as "x_end reached".
    const ENDPOINT_TOLERANCE: f64 = 1e-10;

    // NaN fails every comparison, so `step <= 0.0` alone would let it through.
    // Non-finite bounds also make the termination test below unsatisfiable.
    // Either way the loop would never end, so bail out early.
    if step.is_nan() || step <= 0.0 || !x0.is_finite() || !x_end.is_finite() {
        return vec![(x0, y0)];
    }

    let direction = if x_end > x0 { 1.0 } else { -1.0 };
    let h = direction * step;

    let mut solution = vec![(x0, y0)];
    let mut x = x0;
    let mut y = y0;
    let mut full_steps: usize = 0;

    loop {
        // Distance still to cover, measured along the integration direction.
        let remaining = direction * (x_end - x);
        if remaining <= ENDPOINT_TOLERANCE {
            break;
        }

        if step >= remaining - ENDPOINT_TOLERANCE {
            // A full step would reach (or overshoot) x_end. Take a final step
            // sized to land exactly on x_end.
            y += rk4_increment(&f, x, y, x_end - x);
            solution.push((x_end, y));
            break;
        }

        y += rk4_increment(&f, x, y, h);
        full_steps += 1;
        // Recompute x from x0 rather than accumulating `x += h`, so rounding
        // error does not build up over many steps.
        x = x0 + full_steps as f64 * h;
        solution.push((x, y));
    }

    solution
}

/// Returns the RK4 increment `Δy = h/6 * (k1 + 2*k2 + 2*k3 + k4)` for one step
/// of signed size `h` starting at `(x, y)`.
fn rk4_increment<F>(f: &F, x: f64, y: f64, h: f64) -> f64
where
    F: Fn(f64, f64) -> f64,
{
    let k1 = f(x, y);
    let k2 = f(x + h / 2.0, y + h * k1 / 2.0);
    let k3 = f(x + h / 2.0, y + h * k2 / 2.0);
    let k4 = f(x + h, y + h * k3);
    h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
}
111
112/// Solves a first-order ODE using RK4 method with Result type
113///
114/// # Arguments
115///
116/// * `f` - The derivative function f(x, y)
117/// * `x0` - Initial x value
118/// * `y0` - Initial y value
119/// * `x_end` - Final x value
120/// * `step` - Step size
121///
122/// # Returns
123///
124/// Result containing vector of (x, y) solution points
125pub fn solve_rk4<F>(
126    f: F,
127    x0: f64,
128    y0: f64,
129    x_end: f64,
130    step: f64,
131) -> Result<Vec<(f64, f64)>, ODEError>
132where
133    F: Fn(f64, f64) -> f64,
134{
135    if step <= 0.0 {
136        return Err(ODEError::InvalidInput {
137            message: "Step size must be positive".to_owned(),
138        });
139    }
140
141    if !x0.is_finite() || !y0.is_finite() || !x_end.is_finite() {
142        return Err(ODEError::InvalidInput {
143            message: "Initial values and endpoints must be finite".to_owned(),
144        });
145    }
146
147    Ok(rk4_method(f, x0, y0, x_end, step))
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Final `(x, y)` pair produced by the solver.
    fn last_point(points: &[(f64, f64)]) -> (f64, f64) {
        *points.last().unwrap()
    }

    #[test]
    fn test_rk4_constant_derivative() {
        // dy/dx = 2, y(0) = 0  =>  y(1) = 2
        let points = rk4_method(|_x, _y| 2.0, 0.0, 0.0, 1.0, 0.1);

        assert!(points.len() >= 11);
        assert_eq!(points[0], (0.0, 0.0));

        let (x_last, y_last) = last_point(&points);
        assert!((x_last - 1.0).abs() < 1e-10);
        assert!((y_last - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_rk4_linear_ode() {
        // dy/dx = x, y(0) = 0  =>  y = x^2 / 2, so y(1) = 0.5
        let (x_last, y_last) = last_point(&rk4_method(|x, _y| x, 0.0, 0.0, 1.0, 0.1));

        assert!((x_last - 1.0).abs() < 1e-10);
        assert!((y_last - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_rk4_exponential_growth() {
        // dy/dx = y, y(0) = 1  =>  y = e^x
        let (x_last, y_last) = last_point(&rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1));
        assert!((x_last - 1.0).abs() < 1e-10);

        let e = 1.0_f64.exp();
        assert!((y_last - e).abs() / e < 1e-4);
    }

    #[test]
    fn test_rk4_trigonometric() {
        // dy/dx = cos(x), y(0) = 0  =>  y = sin(x); at x = π, y = sin(π) ≈ 0
        let (x_last, y_last) = last_point(&rk4_method(|x, _y| x.cos(), 0.0, 0.0, PI, 0.1));

        // The last point must be π itself, not an overshoot past it.
        assert!(
            (x_last - PI).abs() < 1e-10,
            "Expected x_final ≈ {}, got {}",
            PI,
            x_last
        );

        let target = PI.sin();
        assert!(
            (y_last - target).abs() < 1e-4,
            "Expected y_final ≈ {}, got {}",
            target,
            y_last
        );
    }

    #[test]
    fn test_rk4_backward_integration() {
        // dy/dx = x integrated from x = 1 (y = 0.5) back to x = 0, where y = 0.
        let points = rk4_method(|x, _y| x, 1.0, 0.5, 0.0, 0.1);

        assert!(points.len() > 1);
        assert_eq!(points[0], (1.0, 0.5));

        let (x_last, y_last) = last_point(&points);
        assert!(x_last.abs() < 1e-10);
        assert!(y_last.abs() < 1e-6);
    }

    #[test]
    fn test_rk4_zero_step_size() {
        let points = rk4_method(|x, _y| x, 0.0, 0.0, 1.0, 0.0);

        assert_eq!(points.len(), 1);
        assert_eq!(points[0], (0.0, 0.0));
    }

    #[test]
    fn test_solve_rk4_invalid_input() {
        // Negative step size.
        assert!(solve_rk4(|x, _y| x, 0.0, 0.0, 1.0, -0.1).is_err());
        // Non-finite initial x.
        assert!(solve_rk4(|x, _y| x, f64::NAN, 0.0, 1.0, 0.1).is_err());
    }

    #[test]
    fn test_rk4_variable_step() {
        // dy/dx = y, y(0) = 1  =>  y = e^x. Unlike dy/dx = x (which RK4
        // integrates exactly), this exposes RK4's truncation error, so halving
        // the step must visibly improve accuracy.
        let coarse = rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1);
        let fine = rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.05);

        // The finer grid produces more points.
        assert!(fine.len() > coarse.len());

        let exact = 1.0_f64.exp();
        let coarse_error = (last_point(&coarse).1 - exact).abs();
        let fine_error = (last_point(&fine).1 - exact).abs();

        assert!(
            fine_error < coarse_error,
            "Smaller step should be more accurate: error(h=0.05)={} should be < error(h=0.1)={}",
            fine_error,
            coarse_error
        );
    }

    #[test]
    fn test_rk4_better_than_euler() {
        use crate::calculus::ode::numerical::euler::euler_method;

        let (_, y_rk4) = last_point(&rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1));
        let euler_points = euler_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1);
        let (_, y_euler) = euler_points.last().unwrap();

        let exact = 1.0_f64.exp();
        assert!((y_rk4 - exact).abs() < (y_euler - exact).abs());
    }
}