use crate::ode::ODEError;
/// Problem definition for [`rk4_solve`]: integrate from `x0` to `x_end`,
/// starting from `y(x0) = y0`, using `steps` equally sized RK4 steps.
///
/// All fields are plain numbers, so the config is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Rk4Config {
    /// Initial value of the independent variable.
    pub x0: f64,
    /// Initial value of the solution, `y(x0)`.
    pub y0: f64,
    /// Value of the independent variable at which integration stops.
    pub x_end: f64,
    /// Number of equally sized steps; `rk4_solve` rejects zero.
    pub steps: usize,
}

impl Rk4Config {
    /// Builds a configuration from its four fields; no validation is performed here.
    #[must_use]
    pub fn new(x0: f64, y0: f64, x_end: f64, steps: usize) -> Self {
        Self {
            x0,
            y0,
            x_end,
            steps,
        }
    }
}
/// Result of [`rk4_solve`]: the final state plus every visited node.
#[derive(Debug, Clone, PartialEq)]
pub struct Rk4Solution {
    /// Value of the independent variable after the last step.
    pub x_final: f64,
    /// Approximate solution at `x_final`.
    pub y_final: f64,
    /// Every `(x, y)` node visited, starting with the initial point.
    pub trajectory: Vec<(f64, f64)>,
}
/// Integrates the scalar ODE `y' = f(x, y)` with the classical fourth-order
/// Runge–Kutta method.
///
/// Integration runs from `config.x0` to `config.x_end` in `config.steps`
/// equally sized steps. The step size is signed, so `x_end < x0` integrates
/// backwards.
///
/// The returned trajectory holds `config.steps + 1` nodes and starts with
/// `(x0, y0)`. Each node's `x` is computed from its step index rather than by
/// repeated addition. As a result the last node (and `x_final`) equals
/// `config.x_end` exactly.
///
/// # Errors
///
/// Returns `ODEError::CannotSolve` if `config.steps` is zero.
pub fn rk4_solve<F>(f: F, config: Rk4Config) -> Result<Rk4Solution, ODEError>
where
    F: Fn(f64, f64) -> f64,
{
    if config.steps == 0 {
        return Err(ODEError::CannotSolve(
            "RK4 requires at least one step".to_string(),
        ));
    }
    let h = (config.x_end - config.x0) / config.steps as f64;
    let mut x = config.x0;
    let mut y = config.y0;
    let mut trajectory = Vec::with_capacity(config.steps + 1);
    trajectory.push((x, y));
    for step in 1..=config.steps {
        let k1 = f(x, y);
        let k2 = f(x + 0.5 * h, y + 0.5 * h * k1);
        let k3 = f(x + 0.5 * h, y + 0.5 * h * k2);
        let k4 = f(x + h, y + h * k3);
        y += h * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0;
        // Accumulating `x += h` adds one rounding error per step. Recompute
        // from the step index instead, and pin the final node to `x_end`.
        x = if step == config.steps {
            config.x_end
        } else {
            config.x0 + step as f64 * h
        };
        trajectory.push((x, y));
    }
    Ok(Rk4Solution {
        x_final: x,
        y_final: y,
        trajectory,
    })
}
/// Integrates the ODE system `y' = f(x, y)` with the classical fourth-order
/// Runge–Kutta method and returns the state at `x_end`.
///
/// The system is integrated from `x0` to `x_end` in `steps` equally sized
/// steps, starting from the state `y0`. The step size is signed, so
/// `x_end < x0` integrates backwards. Every evaluation of `f` must return
/// exactly `y0.len()` components.
///
/// # Errors
///
/// Returns `ODEError::CannotSolve` if `steps` is zero, or if `f` returns a
/// vector whose length differs from `y0.len()`.
pub fn rk4_system_solve<F>(
    f: F,
    x0: f64,
    y0: Vec<f64>,
    x_end: f64,
    steps: usize,
) -> Result<Vec<f64>, ODEError>
where
    F: Fn(f64, &[f64]) -> Vec<f64>,
{
    /// Returns the RK4 stage input `y + scale * k`, component by component.
    fn stage_state(y: &[f64], scale: f64, k: &[f64]) -> Vec<f64> {
        y.iter().zip(k).map(|(yi, ki)| yi + scale * ki).collect()
    }

    if steps == 0 {
        return Err(ODEError::CannotSolve(
            "RK4 requires at least one step".to_string(),
        ));
    }
    let n = y0.len();
    let h = (x_end - x0) / steps as f64;
    let half_h = 0.5 * h;
    let mut y = y0;
    for step in 0..steps {
        // Derive x from the step index rather than accumulating `x += h`,
        // so rounding error does not build up over many steps.
        let x = x0 + step as f64 * h;
        let k1 = f(x, &y);
        validate_system_size(&k1, n)?;
        let k2 = f(x + half_h, &stage_state(&y, half_h, &k1));
        validate_system_size(&k2, n)?;
        let k3 = f(x + half_h, &stage_state(&y, half_h, &k2));
        validate_system_size(&k3, n)?;
        let k4 = f(x + h, &stage_state(&y, h, &k3));
        validate_system_size(&k4, n)?;
        // Every slope was checked to have length n == y.len(), so indexing is in bounds.
        for (i, yi) in y.iter_mut().enumerate() {
            *yi += h * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]) / 6.0;
        }
    }
    Ok(y)
}
/// Confirms that one evaluation of the system's right-hand side produced
/// exactly `expected` components.
///
/// # Errors
///
/// Returns `ODEError::CannotSolve` naming both the actual and the expected
/// length when they differ.
fn validate_system_size(output: &[f64], expected: usize) -> Result<(), ODEError> {
    let actual = output.len();
    if actual == expected {
        Ok(())
    } else {
        Err(ODEError::CannotSolve(format!(
            "System function returned {} values, expected {}",
            actual, expected
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// y' = y from y(0) = 1 reaches e at x = 1.
    #[test]
    fn test_rk4_exponential_growth() {
        let config = Rk4Config::new(0.0, 1.0, 1.0, 1000);
        let sol = rk4_solve(|_x, y| y, config).unwrap();
        assert!((sol.y_final - std::f64::consts::E).abs() < 1e-6);
    }

    /// The trajectory includes the initial point plus one node per step.
    #[test]
    fn test_rk4_trajectory_length() {
        let config = Rk4Config::new(0.0, 1.0, 1.0, 10);
        let sol = rk4_solve(|_x, y| y, config).unwrap();
        assert_eq!(sol.trajectory.len(), 11);
    }

    #[test]
    fn test_rk4_zero_steps_error() {
        let config = Rk4Config::new(0.0, 1.0, 1.0, 0);
        let result = rk4_solve(|_x, y| y, config);
        assert!(matches!(result, Err(ODEError::CannotSolve(_))));
    }

    /// RK4 integrates a constant slope exactly (up to rounding).
    #[test]
    fn test_rk4_simple_linear() {
        let config = Rk4Config::new(0.0, 0.0, 3.0, 300);
        let sol = rk4_solve(|_x, _y| 2.0, config).unwrap();
        assert!((sol.y_final - 6.0).abs() < 1e-10);
    }

    /// A negative step size (x_end < x0) integrates backwards: y' = y from
    /// y(1) = e gives y(0) = 1.
    #[test]
    fn test_rk4_backward_integration() {
        let config = Rk4Config::new(1.0, std::f64::consts::E, 0.0, 1000);
        let sol = rk4_solve(|_x, y| y, config).unwrap();
        assert_eq!(sol.trajectory[0], (1.0, std::f64::consts::E));
        assert!((sol.y_final - 1.0).abs() < 1e-6);
    }

    /// Harmonic oscillator u'' = -u: after half a period, u(pi) = -1.
    #[test]
    fn test_rk4_system_solve() {
        let y_final = rk4_system_solve(
            |_x, u| vec![u[1], -u[0]],
            0.0,
            vec![1.0, 0.0],
            std::f64::consts::PI,
            10_000,
        )
        .unwrap();
        assert!((y_final[0] - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_rk4_system_zero_steps_error() {
        let result = rk4_system_solve(|_x, u| vec![u[0]], 0.0, vec![1.0], 1.0, 0);
        assert!(matches!(result, Err(ODEError::CannotSolve(_))));
    }

    /// A right-hand side whose output length differs from the state length
    /// must be reported as an error, not silently truncated.
    #[test]
    fn test_rk4_system_size_mismatch_error() {
        let result = rk4_system_solve(|_x, _u| vec![1.0, 2.0], 0.0, vec![1.0], 1.0, 5);
        assert!(matches!(result, Err(ODEError::CannotSolve(_))));
    }

    /// An empty state is a valid (trivial) system.
    #[test]
    fn test_rk4_system_empty_state() {
        let y = rk4_system_solve(|_x, _u| Vec::new(), 0.0, Vec::new(), 1.0, 3).unwrap();
        assert!(y.is_empty());
    }
}