mod bdf;
mod euler;
mod rk45;
pub use bdf::bdf2;
pub use euler::euler;
pub use rk45::rk45;
use scivex_core::Float;
use crate::error::Result;
/// Output of an ODE integration, as returned by [`solve_ivp`] and the
/// individual solvers re-exported from this module.
///
/// When the `serde-support` feature is enabled the type additionally derives
/// `serde::Serialize` and `serde::Deserialize`.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct OdeResult<T: Float> {
/// Recorded time points. The tests in this module treat the last entry as
/// the time at which integration stopped.
pub t: Vec<T>,
/// Recorded state vectors. The tests in this module index `y` and `t` with
/// the same index, so `y[i]` is presumably the state at `t[i]` -- confirm
/// against the solver implementations.
pub y: Vec<Vec<T>>,
/// Evaluation counter filled in by the solver.
/// NOTE(review): presumably the number of right-hand-side calls -- confirm.
pub n_evals: usize,
/// Step counter filled in by the solver.
/// NOTE(review): whether rejected steps are included is not visible here.
pub n_steps: usize,
/// Success flag filled in by the solver.
/// NOTE(review): the exact conditions that clear it are defined in the
/// solver modules, not in this file.
pub success: bool,
}
/// Boxed callback type stored in [`OdeOptions::event_fn`].
///
/// The callable receives a scalar `T` and a slice `&[T]` and returns a scalar
/// `T`. In this module's event test the arguments are named time and state,
/// the callback returns `y[0]`, and integration of `y' = 1` from `y = -1`
/// stops well before the end of the span.
/// NOTE(review): this suggests a sign change of the returned value terminates
/// the integration -- confirm against the solver implementations.
pub type EventFn<T> = Box<dyn Fn(T, &[T]) -> T>;
/// Tuning parameters passed by reference to [`solve_ivp`] and forwarded
/// unchanged to the selected solver.
///
/// `Default` is implemented for this type; use struct-update syntax
/// (`OdeOptions { max_steps: 500, ..OdeOptions::default() }`) to override
/// individual fields. How each solver interprets a field is defined in the
/// solver modules, not in this file.
pub struct OdeOptions<T: Float> {
/// Defaults to `1e-8`.
/// NOTE(review): presumably the absolute error tolerance -- confirm.
pub atol: T,
/// Defaults to `1e-6`.
/// NOTE(review): presumably the relative error tolerance -- confirm.
pub rtol: T,
/// Optional initial step size; defaults to `None`.
/// NOTE(review): presumably `None` lets the solver choose -- confirm.
pub first_step: Option<T>,
/// Defaults to `10_000`.
/// NOTE(review): presumably an upper bound on solver steps -- confirm what
/// happens when it is reached.
pub max_steps: usize,
/// Optional event callback; defaults to `None`. See [`EventFn`].
pub event_fn: Option<EventFn<T>>,
}
impl<T: Float> Default for OdeOptions<T> {
    /// Builds the default option set: `atol = 1e-8`, `rtol = 1e-6`, no
    /// explicit first step, at most `10_000` steps, and no event callback.
    fn default() -> Self {
        // Convert the tolerance literals into the solver's scalar type first.
        let atol = T::from_f64(1e-8);
        let rtol = T::from_f64(1e-6);

        Self {
            atol,
            rtol,
            first_step: None,
            max_steps: 10_000,
            event_fn: None,
        }
    }
}
/// Selects which solver [`solve_ivp`] dispatches to.
///
/// When the `serde-support` feature is enabled the type additionally derives
/// `serde::Serialize` and `serde::Deserialize`.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OdeMethod {
/// Dispatches to `euler::euler`.
Euler,
/// Dispatches to `rk45::rk45`.
/// NOTE(review): presumably an adaptive Runge-Kutta 4(5) scheme -- confirm
/// in the `rk45` module.
RK45,
/// Dispatches to `bdf::bdf2`.
/// NOTE(review): presumably a second-order backward differentiation
/// formula -- confirm in the `bdf` module.
BDF2,
}
/// Integrates an initial value problem with the solver chosen by `method`.
///
/// This function is a thin dispatcher: every argument is forwarded unchanged
/// to `euler::euler`, `rk45::rk45`, or `bdf::bdf2`, and that solver's return
/// value is handed back as-is.
///
/// # Arguments
///
/// * `f` - Right-hand side callable taking a scalar and a state slice and
///   returning a `Vec<T>`; the tests in this module use it as `f(t, y) -> dy/dt`.
/// * `t_span` - Two-element array; the tests in this module pass
///   `[t_start, t_end]`.
/// * `y0` - Initial state vector.
/// * `method` - Which solver to run.
/// * `options` - Solver options, see [`OdeOptions`].
///
/// # Errors
///
/// Returns whatever error the selected solver returns; no additional
/// validation is performed here.
pub fn solve_ivp<T: Float, F: Fn(T, &[T]) -> Vec<T>>(
    f: F,
    t_span: [T; 2],
    y0: &[T],
    method: OdeMethod,
    options: &OdeOptions<T>,
) -> Result<OdeResult<T>> {
    // Exhaustive on purpose: adding an `OdeMethod` variant must fail to
    // compile here until it is wired to a solver.
    match method {
        OdeMethod::Euler => euler::euler(f, t_span, y0, options),
        OdeMethod::RK45 => rk45::rk45(f, t_span, y0, options),
        OdeMethod::BDF2 => bdf::bdf2(f, t_span, y0, options),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Right-hand side of the scalar decay problem `y' = -y`.
    fn decay_rhs(_t: f64, y: &[f64]) -> Vec<f64> {
        vec![-y[0]]
    }

    /// Right-hand side with constant unit slope, `y' = 1`.
    fn unit_slope_rhs(_t: f64, _y: &[f64]) -> Vec<f64> {
        vec![1.0]
    }

    /// Classic Lotka-Volterra predator/prey right-hand side with unit
    /// coefficients: `y[0]` is the prey, `y[1]` the predator.
    fn lotka_volterra_rhs(_t: f64, y: &[f64]) -> Vec<f64> {
        vec![y[0] - y[0] * y[1], -y[1] + y[0] * y[1]]
    }

    /// Solves `y' = -y`, `y(0) = 1` on `[0, 1]` with `method` and default
    /// options, returning `|y_final - exp(-1)|`.
    fn decay_final_error(method: OdeMethod) -> f64 {
        let solution = solve_ivp(
            decay_rhs,
            [0.0, 1.0],
            &[1.0],
            method,
            &OdeOptions::default(),
        )
        .unwrap();
        let y_final = solution.y.last().unwrap()[0];
        let expected = (-1.0_f64).exp();
        (y_final - expected).abs()
    }

    /// RK45 must reproduce `exp(-1)` to within `1e-6`.
    #[test]
    fn test_solve_ivp_rk45() {
        assert!(decay_final_error(OdeMethod::RK45) < 1e-6);
    }

    /// Euler only needs to land within `0.02` of `exp(-1)`.
    #[test]
    fn test_solve_ivp_euler() {
        assert!(decay_final_error(OdeMethod::Euler) < 0.02);
    }

    /// BDF2 must reproduce `exp(-1)` to within `1e-3`.
    #[test]
    fn test_solve_ivp_bdf2() {
        assert!(decay_final_error(OdeMethod::BDF2) < 1e-3);
    }

    /// With `y' = 1` starting at `y = -1`, the event function `y[0]` reaches
    /// zero at `t = 1`; the run must end long before the span's end at `t = 5`.
    #[test]
    fn test_event_detection() {
        let options: OdeOptions<f64> = OdeOptions {
            event_fn: Some(Box::new(|_t: f64, y: &[f64]| y[0])),
            ..OdeOptions::default()
        };
        let solution = solve_ivp(
            unit_slope_rhs,
            [0.0, 5.0],
            &[-1.0],
            OdeMethod::RK45,
            &options,
        )
        .unwrap();

        let t_final = *solution.t.last().unwrap();
        assert!(
            t_final < 2.0,
            "Should have stopped early at event, t_final={t_final}"
        );
    }

    /// For `y' = 1` the stored states must never decrease and the stored
    /// times must strictly increase.
    #[test]
    fn test_ode_result_trajectory() {
        let solution = solve_ivp(
            unit_slope_rhs,
            [0.0, 1.0],
            &[0.0],
            OdeMethod::RK45,
            &OdeOptions::default(),
        )
        .unwrap();

        // Walk consecutive state pairs; `prev`/`next` index the matching
        // entries of `t`.
        for (prev, states) in solution.y.windows(2).enumerate() {
            let next = prev + 1;
            assert!(states[1][0] >= states[0][0]);
            assert!(solution.t[next] > solution.t[prev]);
        }
    }

    /// Both Lotka-Volterra populations must stay strictly positive and the
    /// solver must report success.
    #[test]
    fn test_lotka_volterra() {
        let solution = solve_ivp(
            lotka_volterra_rhs,
            [0.0, 10.0],
            &[1.0, 0.5],
            OdeMethod::RK45,
            &OdeOptions::default(),
        )
        .unwrap();

        assert!(solution.success);
        for state in &solution.y {
            let prey = state[0];
            assert!(prey > 0.0, "prey went negative");
            let predator = state[1];
            assert!(predator > 0.0, "predator went negative");
        }
    }
}