mathhook_core/calculus/ode/numerical/
runge_kutta.rs1use crate::calculus::ode::first_order::ODEError;
12
/// Integrates the scalar first-order ODE `dy/dx = f(x, y)` from `(x0, y0)` to `x_end`
/// using the classical fourth-order Runge–Kutta method with a fixed step size.
///
/// Integration runs forward when `x_end > x0` and backward otherwise. Full steps of
/// magnitude `step` are taken until the next one would pass `x_end`. A final, shortened
/// step then lands exactly on `x_end`; it is skipped when the remaining distance is
/// negligible.
///
/// # Arguments
///
/// * `f` - right-hand side `f(x, y)` of the ODE
/// * `x0`, `y0` - initial condition
/// * `x_end` - abscissa at which integration stops
/// * `step` - magnitude of the step size; must be a positive number
///
/// # Returns
///
/// The `(x, y)` points visited, starting with `(x0, y0)`.
///
/// Only the initial point is returned in these cases, because integration could not
/// terminate:
/// * `step` is not a positive number (zero, negative, or NaN);
/// * `x0` or `x_end` is NaN or infinite.
///
/// Integration also stops early if `step` is too small to change `x` at
/// floating-point precision.
pub fn rk4_method<F>(f: F, x0: f64, y0: f64, x_end: f64, step: f64) -> Vec<(f64, f64)>
where
    F: Fn(f64, f64) -> f64,
{
    // Largest distance to `x_end` that is treated as "already arrived".
    const ABS_TOL: f64 = 1e-10;
    // The tolerance is also capped at this fraction of the effective step. This lets
    // very short intervals (or very small steps) still reach `x_end` instead of being
    // swallowed by the absolute tolerance.
    const REL_TOL: f64 = 1e-6;

    // `step <= 0.0` is false for NaN, so NaN must be rejected explicitly. A NaN or
    // infinite endpoint would also make the loop below run forever.
    if step.is_nan() || step <= 0.0 || !x0.is_finite() || !x_end.is_finite() {
        return vec![(x0, y0)];
    }

    let direction = if x_end > x0 { 1.0 } else { -1.0 };
    let h = direction * step;
    let tol = ABS_TOL.min(REL_TOL * step.min((x_end - x0).abs()));

    // Takes one classical RK4 step of signed size `dh` from `(xn, yn)` and returns
    // the new y value.
    let advance = |xn: f64, yn: f64, dh: f64| -> f64 {
        let k1 = f(xn, yn);
        let k2 = f(xn + dh / 2.0, yn + dh * k1 / 2.0);
        let k3 = f(xn + dh / 2.0, yn + dh * k2 / 2.0);
        let k4 = f(xn + dh, yn + dh * k3);
        yn + dh / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    };

    let mut solution = vec![(x0, y0)];
    let (mut x, mut y) = (x0, y0);

    loop {
        let overshoots = if direction > 0.0 { x + h > x_end } else { x + h < x_end };
        if overshoots {
            // Shorten the last step so the solution ends exactly at `x_end`,
            // unless the remaining distance is negligible.
            let final_h = x_end - x;
            if final_h.abs() > tol {
                y = advance(x, y, final_h);
                x = x_end;
                solution.push((x, y));
            }
            break;
        }

        let x_next = x + h;
        if x_next == x {
            // `step` is below the floating-point resolution at `x`: no further
            // progress is possible, so stop instead of looping forever.
            break;
        }

        y = advance(x, y, h);
        x = x_next;
        solution.push((x, y));

        // Accumulated rounding in `x` can leave it a hair short of `x_end`; treat
        // that as arrival rather than taking a vanishingly small extra step.
        let reached = if direction > 0.0 { x >= x_end - tol } else { x <= x_end + tol };
        if reached {
            break;
        }
    }

    solution
}

112pub fn solve_rk4<F>(
126 f: F,
127 x0: f64,
128 y0: f64,
129 x_end: f64,
130 step: f64,
131) -> Result<Vec<(f64, f64)>, ODEError>
132where
133 F: Fn(f64, f64) -> f64,
134{
135 if step <= 0.0 {
136 return Err(ODEError::InvalidInput {
137 message: "Step size must be positive".to_owned(),
138 });
139 }
140
141 if !x0.is_finite() || !y0.is_finite() || !x_end.is_finite() {
142 return Err(ODEError::InvalidInput {
143 message: "Initial values and endpoints must be finite".to_owned(),
144 });
145 }
146
147 Ok(rk4_method(f, x0, y0, x_end, step))
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    // Returns the last `(x, y)` point of a solution. Every solver call in these
    // tests yields at least the initial point, so the slice is never empty.
    fn final_point(points: &[(f64, f64)]) -> (f64, f64) {
        *points.last().unwrap()
    }

    #[test]
    fn test_rk4_constant_derivative() {
        // y' = 2, y(0) = 0  =>  y = 2x, so y(1) = 2.
        let points = rk4_method(|_x, _y| 2.0, 0.0, 0.0, 1.0, 0.1);

        // Initial point plus ten steps of 0.1.
        assert!(points.len() >= 11);
        assert_eq!(points[0], (0.0, 0.0));

        let (x_last, y_last) = final_point(&points);
        assert!((x_last - 1.0).abs() < 1e-10);
        assert!((y_last - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_rk4_linear_ode() {
        // y' = x, y(0) = 0  =>  y = x^2 / 2, so y(1) = 0.5.
        let (x_last, y_last) = final_point(&rk4_method(|x, _y| x, 0.0, 0.0, 1.0, 0.1));
        assert!((x_last - 1.0).abs() < 1e-10);
        assert!((y_last - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_rk4_exponential_growth() {
        // y' = y, y(0) = 1  =>  y = e^x.
        let (x_last, y_last) = final_point(&rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1));
        assert!((x_last - 1.0).abs() < 1e-10);

        let exact = 1.0_f64.exp();
        assert!((y_last - exact).abs() / exact < 1e-4);
    }

    #[test]
    fn test_rk4_trigonometric() {
        // y' = cos(x), y(0) = 0  =>  y = sin(x); integrate over [0, pi].
        let pi = std::f64::consts::PI;
        let (x_last, y_last) = final_point(&rk4_method(|x, _y| x.cos(), 0.0, 0.0, pi, 0.1));

        assert!(
            (x_last - pi).abs() < 1e-10,
            "Expected x_final ≈ {}, got {}",
            pi,
            x_last
        );

        let exact = pi.sin();
        assert!(
            (y_last - exact).abs() < 1e-4,
            "Expected y_final ≈ {}, got {}",
            exact,
            y_last
        );
    }

    #[test]
    fn test_rk4_backward_integration() {
        // Integrate y' = x backwards from (1, 0.5) to x = 0; the exact y(0) is 0.
        let points = rk4_method(|x, _y| x, 1.0, 0.5, 0.0, 0.1);
        assert!(points.len() > 1);

        assert_eq!(points[0], (1.0, 0.5));
        let (x_last, y_last) = final_point(&points);
        assert!(x_last.abs() < 1e-10);
        assert!(y_last.abs() < 1e-6);
    }

    #[test]
    fn test_rk4_zero_step_size() {
        // A zero step cannot make progress, so only the initial point comes back.
        let points = rk4_method(|x, _y| x, 0.0, 0.0, 1.0, 0.0);
        assert_eq!(points, [(0.0, 0.0)]);
    }

    #[test]
    fn test_solve_rk4_invalid_input() {
        // Negative step size.
        assert!(solve_rk4(|x, _y| x, 0.0, 0.0, 1.0, -0.1).is_err());
        // Non-finite initial abscissa.
        assert!(solve_rk4(|x, _y| x, f64::NAN, 0.0, 1.0, 0.1).is_err());
    }

    #[test]
    fn test_rk4_variable_step() {
        let exact = 1.0_f64.exp();
        let coarse = rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1);
        let fine = rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.05);

        assert!(fine.len() > coarse.len());

        let error1 = (final_point(&coarse).1 - exact).abs();
        let error2 = (final_point(&fine).1 - exact).abs();

        assert!(
            error2 < error1,
            "Smaller step should be more accurate: error(h=0.05)={} should be < error(h=0.1)={}",
            error2,
            error1
        );
    }

    #[test]
    fn test_rk4_better_than_euler() {
        use crate::calculus::ode::numerical::euler::euler_method;

        let exact = 1.0_f64.exp();

        let rk4_points = rk4_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1);
        let rk4_error = (final_point(&rk4_points).1 - exact).abs();

        let euler_points = euler_method(|_x, y| y, 0.0, 1.0, 1.0, 0.1);
        let (_, y_euler) = euler_points.last().unwrap();
        let euler_error = (y_euler - exact).abs();

        assert!(rk4_error < euler_error);
    }
}