Skip to main content

scivex_optim/ode/
mod.rs

1//! ODE initial value problem solvers.
2//!
3//! Provides multiple methods for solving systems of ordinary differential
4//! equations of the form `dy/dt = f(t, y)`:
5//!
6//! - [`euler`] — Forward Euler (1st order, simple)
7//! - [`rk45`] — Dormand-Prince RK4(5) (adaptive, general-purpose)
8//! - [`bdf2`] — BDF-2 (implicit, for stiff systems)
9//! - [`solve_ivp`] — Unified entry point with method selection
10//!
11//! ## Example
12//!
13//! ```ignore
14//! use scivex_optim::ode::{solve_ivp, OdeMethod, OdeOptions};
15//!
16//! // dy/dt = -y, y(0) = 1  =>  y(t) = e^(-t)
17//! let result = solve_ivp(
18//!     |_t, y: &[f64]| vec![-y[0]],
19//!     [0.0, 1.0],
20//!     &[1.0],
21//!     OdeMethod::RK45,
22//!     &OdeOptions::default(),
23//! ).unwrap();
24//!
25//! println!("y(1) = {}", result.y.last().unwrap()[0]);
26//! ```
27
28mod bdf;
29mod euler;
30mod rk45;
31
32pub use bdf::bdf2;
33pub use euler::euler;
34pub use rk45::rk45;
35
36use scivex_core::Float;
37
38use crate::error::Result;
39
40/// Result of an ODE integration.
41///
42/// # Examples
43///
44/// ```
45/// # use scivex_optim::ode::{euler, OdeOptions};
46/// let result = euler(|_t, y: &[f64]| vec![-y[0]], [0.0, 1.0], &[1.0], &OdeOptions::default()).unwrap();
47/// assert!(result.success);
48/// assert!(!result.t.is_empty());
49/// ```
50#[cfg_attr(
51    feature = "serde-support",
52    derive(serde::Serialize, serde::Deserialize)
53)]
54#[derive(Debug, Clone)]
55pub struct OdeResult<T: Float> {
56    /// Time values at each accepted step.
57    pub t: Vec<T>,
58    /// Solution vectors at each accepted step. `y[i]` corresponds to `t[i]`.
59    pub y: Vec<Vec<T>>,
60    /// Total number of function evaluations.
61    pub n_evals: usize,
62    /// Total number of accepted steps.
63    pub n_steps: usize,
64    /// Whether the integration completed successfully.
65    pub success: bool,
66}
67
/// Boxed event callback, invoked as `event(t, y)` and returning a scalar of type `T`.
/// Integration stops once the returned value crosses zero.
pub type EventFn<T> = Box<dyn Fn(T, &[T]) -> T + 'static>;
70
71/// Options for ODE solvers.
72///
73/// # Examples
74///
75/// ```
76/// # use scivex_optim::ode::OdeOptions;
77/// let opts = OdeOptions::<f64>::default();
78/// assert_eq!(opts.max_steps, 10_000);
79/// ```
80pub struct OdeOptions<T: Float> {
81    /// Absolute tolerance for adaptive methods.
82    pub atol: T,
83    /// Relative tolerance for adaptive methods.
84    pub rtol: T,
85    /// Initial step size. If `None`, a reasonable default is chosen.
86    pub first_step: Option<T>,
87    /// Maximum number of steps before giving up.
88    pub max_steps: usize,
89    /// Optional event function. Integration terminates when the return
90    /// value crosses zero.
91    pub event_fn: Option<EventFn<T>>,
92}
93
94impl<T: Float> Default for OdeOptions<T> {
95    fn default() -> Self {
96        Self {
97            atol: T::from_f64(1e-8),
98            rtol: T::from_f64(1e-6),
99            first_step: None,
100            max_steps: 10_000,
101            event_fn: None,
102        }
103    }
104}
105
/// Integration scheme selector for [`solve_ivp`].
///
/// # Examples
///
/// ```
/// # use scivex_optim::ode::OdeMethod;
/// let method = OdeMethod::RK45;
/// assert_eq!(method, OdeMethod::RK45);
/// ```
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
pub enum OdeMethod {
    /// Explicit forward Euler: first order with a fixed step. Easy to reason
    /// about, but the least accurate option.
    Euler,
    /// Dormand–Prince RK4(5) with adaptive step control. The general-purpose
    /// choice for non-stiff problems.
    RK45,
    /// Second-order backward differentiation formula: implicit with a fixed
    /// step, suited to stiff problems.
    BDF2,
}
128
129/// Solve an initial value problem (IVP) for a system of ODEs.
130///
131/// # Arguments
132///
133/// * `f` — right-hand side function: `f(t, y) -> dy/dt`
134/// * `t_span` — integration interval `[t0, tf]`
135/// * `y0` — initial state vector
136/// * `method` — solver method to use
137/// * `options` — solver options (tolerances, step size, etc.)
138///
139/// # Returns
140///
141/// An [`OdeResult`] containing the time values and solution trajectory.
142///
143/// # Examples
144///
145/// ```
146/// # use scivex_optim::ode::{solve_ivp, OdeMethod, OdeOptions};
147/// // dy/dt = -y, y(0) = 1  →  y(t) = e^(-t)
148/// let result = solve_ivp(
149///     |_t: f64, y: &[f64]| vec![-y[0]],
150///     [0.0, 1.0],
151///     &[1.0],
152///     OdeMethod::RK45,
153///     &OdeOptions::default(),
154/// ).unwrap();
155/// let y_final = result.y.last().unwrap()[0];
156/// assert!((y_final - (-1.0_f64).exp()).abs() < 1e-6);
157/// ```
158pub fn solve_ivp<T, F>(
159    f: F,
160    t_span: [T; 2],
161    y0: &[T],
162    method: OdeMethod,
163    options: &OdeOptions<T>,
164) -> Result<OdeResult<T>>
165where
166    T: Float,
167    F: Fn(T, &[T]) -> Vec<T>,
168{
169    match method {
170        OdeMethod::Euler => euler::euler(f, t_span, y0, options),
171        OdeMethod::RK45 => rk45::rk45(f, t_span, y0, options),
172        OdeMethod::BDF2 => bdf::bdf2(f, t_span, y0, options),
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Solves `dy/dt = -y`, `y(0) = 1` on `[0, 1]` with `method`.
    /// Returns `(computed y(1), exact e^-1)`.
    fn decay_endpoint(method: OdeMethod) -> (f64, f64) {
        let result = solve_ivp(
            |_t: f64, y: &[f64]| vec![-y[0]],
            [0.0, 1.0],
            &[1.0],
            method,
            &OdeOptions::default(),
        )
        .unwrap();
        let computed = result.y.last().unwrap()[0];
        (computed, (-1.0_f64).exp())
    }

    #[test]
    fn test_solve_ivp_rk45() {
        let (computed, exact) = decay_endpoint(OdeMethod::RK45);
        assert!((computed - exact).abs() < 1e-6);
    }

    #[test]
    fn test_solve_ivp_euler() {
        let (computed, exact) = decay_endpoint(OdeMethod::Euler);
        assert!((computed - exact).abs() < 0.02);
    }

    #[test]
    fn test_solve_ivp_bdf2() {
        let (computed, exact) = decay_endpoint(OdeMethod::BDF2);
        assert!((computed - exact).abs() < 1e-3);
    }

    #[test]
    fn test_event_detection() {
        // y(t) = t - 1 from y(0) = -1, so the event y = 0 is reached at t = 1.
        let options: OdeOptions<f64> = OdeOptions {
            event_fn: Some(Box::new(|_t: f64, y: &[f64]| y[0])),
            ..OdeOptions::default()
        };
        let result = solve_ivp(
            |_t: f64, _y: &[f64]| vec![1.0],
            [0.0, 5.0],
            &[-1.0],
            OdeMethod::RK45,
            &options,
        )
        .unwrap();

        // The run should end well before tf = 5, near t = 1.
        let t_final = *result.t.last().unwrap();
        assert!(
            t_final < 2.0,
            "Should have stopped early at event, t_final={t_final}"
        );
    }

    #[test]
    fn test_ode_result_trajectory() {
        let result = solve_ivp(
            |_t: f64, _y: &[f64]| vec![1.0],
            [0.0, 1.0],
            &[0.0],
            OdeMethod::RK45,
            &OdeOptions::default(),
        )
        .unwrap();

        // With a constant positive slope, both the state and time must increase step to step.
        for (k, pair) in result.y.windows(2).enumerate() {
            assert!(pair[1][0] >= pair[0][0]);
            assert!(result.t[k + 1] > result.t[k]);
        }
    }

    #[test]
    fn test_lotka_volterra() {
        // Predator-prey model: dx/dt = x - x*y, dy/dt = -y + x*y.
        // The system is oscillatory and conservative.
        let rhs = |_t: f64, s: &[f64]| vec![s[0] - s[0] * s[1], -s[1] + s[0] * s[1]];
        let result = solve_ivp(
            rhs,
            [0.0, 10.0],
            &[1.0, 0.5],
            OdeMethod::RK45,
            &OdeOptions::default(),
        )
        .unwrap();

        assert!(result.success);
        // Neither population may ever drop to zero or below.
        for state in &result.y {
            assert!(state[0] > 0.0, "prey went negative");
            assert!(state[1] > 0.0, "predator went negative");
        }
    }
}