use ndarray::{Array1, ArrayView1};
/// Fixed-step explicit integration scheme used by `integrate_fixed`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OdeMethod {
    /// Forward Euler: one derivative evaluation per step (`x += dt * f(x, t)`).
    Euler,
    /// Heun's method: an Euler predictor step, then the state is advanced by the
    /// average of the slopes at the start and at the predicted end of the step.
    Heun,
}
pub fn integrate_fixed(
method: OdeMethod,
x0: &Array1<f32>,
t0: f32,
dt: f32,
steps: usize,
mut f: impl FnMut(&ArrayView1<f32>, f32) -> crate::Result<Array1<f32>>,
) -> crate::Result<Array1<f32>> {
if steps < 1 {
return Err(crate::Error::Domain("steps must be >= 1"));
}
if !dt.is_finite() {
return Err(crate::Error::Domain("dt must be finite"));
}
let mut x = x0.clone();
let mut t = t0;
match method {
OdeMethod::Euler => {
for _ in 0..steps {
let v = f(&x.view(), t)?;
for i in 0..x.len() {
x[i] += dt * v[i];
}
t += dt;
}
}
OdeMethod::Heun => {
for _ in 0..steps {
let v0 = f(&x.view(), t)?;
let mut x_pred = x.clone();
for i in 0..x.len() {
x_pred[i] += dt * v0[i];
}
let v1 = f(&x_pred.view(), t + dt)?;
for i in 0..x.len() {
x[i] += 0.5 * dt * (v0[i] + v1[i]);
}
t += dt;
}
}
}
Ok(x)
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Integrates dx/dt = -x from x(0) = 1 up to t = 1 using `steps` uniform steps
    /// and returns the computed x(1).
    fn decay_to_t1(method: OdeMethod, steps: usize) -> f32 {
        let x0 = Array1::from_vec(vec![1.0f32]);
        let dt = 1.0f32 / (steps as f32);
        let out = integrate_fixed(method, &x0, 0.0, dt, steps, |x, _t| {
            Ok(Array1::from_vec(vec![-x[0]]))
        })
        .unwrap();
        out[0]
    }

    /// Absolute error of `decay_to_t1` against the exact solution x(1) = e^{-1}.
    fn decay_error(method: OdeMethod, steps: usize) -> f32 {
        (decay_to_t1(method, steps) - (-1.0f32).exp()).abs()
    }

    #[test]
    fn heun_is_more_accurate_than_euler_on_dx_dt_eq_minus_x() {
        let err_euler = decay_error(OdeMethod::Euler, 20);
        let err_heun = decay_error(OdeMethod::Heun, 20);
        assert!(
            err_heun < err_euler,
            "expected Heun to be more accurate: err_heun={err_heun} err_euler={err_euler}"
        );
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 64,
            .. ProptestConfig::default()
        })]
        #[test]
        fn prop_constant_field_is_exact_for_euler_and_heun(
            len in 1usize..16,
            steps in 1usize..200,
            dt in 1e-3f32..1.0f32,
            t0 in -2.0f32..2.0f32,
            x0 in prop::collection::vec(-10.0f32..10.0f32, 16),
            c in prop::collection::vec(-10.0f32..10.0f32, 16),
        ) {
            let x0 = Array1::from_vec(x0[..len].to_vec());
            let c = Array1::from_vec(c[..len].to_vec());
            // With a constant field both schemes reduce to x0 + steps * dt * c.
            let expected = {
                let mut out = x0.clone();
                let scale = dt * (steps as f32);
                for i in 0..len {
                    out[i] += scale * c[i];
                }
                out
            };
            let euler = integrate_fixed(OdeMethod::Euler, &x0, t0, dt, steps, |_x, _t| Ok(c.clone())).unwrap();
            let heun = integrate_fixed(OdeMethod::Heun, &x0, t0, dt, steps, |_x, _t| Ok(c.clone())).unwrap();
            for i in 0..len {
                // The absolute slack covers f32 rounding accumulated over up to 199 steps.
                let tol = 2e-2 + 1e-6 * expected[i].abs();
                prop_assert!((euler[i] - expected[i]).abs() <= tol, "euler mismatch at {i}");
                prop_assert!((heun[i] - expected[i]).abs() <= tol, "heun mismatch at {i}");
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 64,
            .. ProptestConfig::default()
        })]
        #[test]
        fn prop_error_decreases_with_more_steps_for_dx_dt_eq_minus_x(
            steps in 5usize..80,
        ) {
            let err_e1 = decay_error(OdeMethod::Euler, steps);
            let err_e2 = decay_error(OdeMethod::Euler, 2 * steps);
            let err_h1 = decay_error(OdeMethod::Heun, steps);
            let err_h2 = decay_error(OdeMethod::Heun, 2 * steps);
            prop_assert!(err_e2 <= err_e1 + 1e-6, "euler error did not decrease: {err_e1} -> {err_e2}");
            prop_assert!(err_h2 <= err_h1 + 1e-6, "heun error did not decrease: {err_h1} -> {err_h2}");
            prop_assert!(err_h1 <= err_e1 + 1e-6, "expected Heun <= Euler at steps={steps}");
        }
    }

    #[test]
    fn euler_on_exponential_decay_converges_to_exact() {
        let exact = (-1.0f32).exp();
        let result = decay_to_t1(OdeMethod::Euler, 1000);
        let err = (result - exact).abs();
        assert!(
            err < 1e-3,
            "Euler with 1000 steps should be within 1e-3 of exact: got {}, exact {}, err {}",
            result,
            exact,
            err
        );
    }

    #[test]
    fn heun_on_exponential_decay_is_very_accurate() {
        let exact = (-1.0f32).exp();
        let result = decay_to_t1(OdeMethod::Heun, 100);
        let err = (result - exact).abs();
        assert!(
            err < 1e-5,
            "Heun with 100 steps should be within 1e-5 of exact: got {}, exact {}, err {}",
            result,
            exact,
            err
        );
    }

    #[test]
    fn euler_and_heun_on_2d_rotation_preserve_radius() {
        let x0 = Array1::from_vec(vec![1.0f32, 0.0]);
        let r0 = 1.0f32;
        let steps = 200usize;
        let total_t = std::f32::consts::PI;
        let dt = total_t / (steps as f32);
        let euler = integrate_fixed(OdeMethod::Euler, &x0, 0.0, dt, steps, |x, _t| {
            Ok(Array1::from_vec(vec![-x[1], x[0]]))
        })
        .unwrap();
        let heun = integrate_fixed(OdeMethod::Heun, &x0, 0.0, dt, steps, |x, _t| {
            Ok(Array1::from_vec(vec![-x[1], x[0]]))
        })
        .unwrap();
        let r_euler = (euler[0] * euler[0] + euler[1] * euler[1]).sqrt();
        let r_heun = (heun[0] * heun[0] + heun[1] * heun[1]).sqrt();
        let err_euler = (r_euler - r0).abs();
        let err_heun = (r_heun - r0).abs();
        assert!(
            err_heun < err_euler,
            "Heun should preserve radius better: err_heun={err_heun} err_euler={err_euler}"
        );
        assert!(
            err_heun < 0.01,
            "Heun radius error should be < 0.01, got {err_heun}"
        );
    }

    #[test]
    fn single_step_euler_matches_manual() {
        let x0 = Array1::from_vec(vec![2.0f32, 3.0]);
        let dt = 0.1f32;
        let result = integrate_fixed(OdeMethod::Euler, &x0, 0.0, dt, 1, |x, _t| {
            Ok(Array1::from_vec(vec![x[1], -x[0]]))
        })
        .unwrap();
        assert!((result[0] - 2.3).abs() < 1e-6);
        assert!((result[1] - 2.8).abs() < 1e-6);
    }

    #[test]
    fn time_argument_is_passed_correctly() {
        // dx/dt = t, x(0) = 0, so x(1) = 0.5. The trapezoidal (Heun) update is exact
        // for a field linear in t. Euler sums dt * t_k for k = 0..steps-1.
        let x0 = Array1::from_vec(vec![0.0f32]);
        let steps = 10usize;
        let dt = 1.0f32 / (steps as f32);
        let heun = integrate_fixed(OdeMethod::Heun, &x0, 0.0, dt, steps, |_x, t| {
            Ok(Array1::from_vec(vec![t]))
        })
        .unwrap();
        let euler = integrate_fixed(OdeMethod::Euler, &x0, 0.0, dt, steps, |_x, t| {
            Ok(Array1::from_vec(vec![t]))
        })
        .unwrap();
        assert!((heun[0] - 0.5).abs() < 1e-5, "heun got {}", heun[0]);
        assert!((euler[0] - 0.45).abs() < 1e-5, "euler got {}", euler[0]);
    }

    #[test]
    fn rejects_zero_steps() {
        let x0 = Array1::from_vec(vec![1.0f32]);
        for method in [OdeMethod::Euler, OdeMethod::Heun] {
            let res = integrate_fixed(method, &x0, 0.0, 0.1, 0, |x, _t| {
                Ok(Array1::from_vec(vec![-x[0]]))
            });
            assert!(res.is_err(), "steps=0 should be rejected for {method:?}");
        }
    }

    #[test]
    fn rejects_non_finite_dt() {
        let x0 = Array1::from_vec(vec![1.0f32]);
        for dt in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            for method in [OdeMethod::Euler, OdeMethod::Heun] {
                let res = integrate_fixed(method, &x0, 0.0, dt, 1, |x, _t| {
                    Ok(Array1::from_vec(vec![-x[0]]))
                });
                assert!(res.is_err(), "dt={dt} should be rejected for {method:?}");
            }
        }
    }

    #[test]
    fn callback_error_is_propagated_and_stops_integration() {
        let x0 = Array1::from_vec(vec![1.0f32]);
        let mut calls = 0usize;
        let res = integrate_fixed(OdeMethod::Heun, &x0, 0.0, 0.1, 10, |_x, _t| {
            calls += 1;
            Err(crate::Error::Domain("boom"))
        });
        assert!(res.is_err());
        assert_eq!(calls, 1, "integration must stop at the first callback error");
    }
}