Skip to main content

numra_ode/
solver.rs

//! ODE solver infrastructure.
//!
//! This module defines the common traits and types for ODE solvers.
//!
//! Author: Moussa Leblouba
//! Date: 30 April 2026
//! Modified: 2 May 2026

9use crate::dense::DenseOutput;
10use crate::error::SolverError;
11use crate::events::{Event, EventFunction};
12use crate::problem::OdeSystem;
13use core::fmt;
14use numra_core::Scalar;
15use std::sync::Arc;
16
17/// Solver options and tolerances.
18///
19/// Cloneable thanks to `Arc`-wrapped event functions.
20pub struct SolverOptions<S: Scalar> {
21    /// Relative tolerance
22    pub rtol: S,
23    /// Absolute tolerance (scalar)
24    pub atol: S,
25    /// Initial step size (None = auto)
26    pub h0: Option<S>,
27    /// Maximum step size
28    pub h_max: S,
29    /// Minimum step size
30    pub h_min: S,
31    /// Maximum number of steps
32    pub max_steps: usize,
33    /// Output grid in the integration direction. When `Some`, each solver
34    /// returns exactly these `(t, y)` pairs (Hermite cubic interpolated
35    /// from accepted step endpoints; endpoints are reproduced bit-exact).
36    /// When `None`, the natural adaptive step grid is returned.
37    pub t_eval: Option<Vec<S>>,
38    /// Enable dense output
39    pub dense_output: bool,
40    /// Event functions for zero-crossing detection (Arc enables Clone)
41    pub events: Vec<Arc<dyn EventFunction<S>>>,
42}
43
44impl<S: Scalar> Clone for SolverOptions<S> {
45    fn clone(&self) -> Self {
46        Self {
47            rtol: self.rtol,
48            atol: self.atol,
49            h0: self.h0,
50            h_max: self.h_max,
51            h_min: self.h_min,
52            max_steps: self.max_steps,
53            t_eval: self.t_eval.clone(),
54            dense_output: self.dense_output,
55            events: self.events.clone(),
56        }
57    }
58}
59
60impl<S: Scalar> fmt::Debug for SolverOptions<S> {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        f.debug_struct("SolverOptions")
63            .field("rtol", &self.rtol)
64            .field("atol", &self.atol)
65            .field("h0", &self.h0)
66            .field("h_max", &self.h_max)
67            .field("h_min", &self.h_min)
68            .field("max_steps", &self.max_steps)
69            .field("t_eval", &self.t_eval)
70            .field("dense_output", &self.dense_output)
71            .field("events", &format!("[{} event(s)]", self.events.len()))
72            .finish()
73    }
74}
75
76impl<S: Scalar> Default for SolverOptions<S> {
77    fn default() -> Self {
78        Self {
79            rtol: S::from_f64(1e-6),
80            atol: S::from_f64(1e-9),
81            h0: None,
82            h_max: S::INFINITY,
83            // Scale h_min with machine epsilon to support both f32 and f64.
84            // f64: 100 * EPSILON ~ 2.2e-14 (close to previous fixed 1e-14)
85            // f32: 100 * EPSILON ~ 1.2e-5  (meaningful for f32 precision)
86            // A fixed 1e-14 was below f32 machine epsilon (~1.2e-7), making it useless.
87            h_min: S::EPSILON * S::from_f64(100.0),
88            max_steps: 100_000,
89            t_eval: None,
90            dense_output: false,
91            events: Vec::new(),
92        }
93    }
94}
95
96impl<S: Scalar> SolverOptions<S> {
97    /// Set relative tolerance.
98    pub fn rtol(mut self, rtol: S) -> Self {
99        self.rtol = rtol;
100        self
101    }
102
103    /// Set absolute tolerance.
104    pub fn atol(mut self, atol: S) -> Self {
105        self.atol = atol;
106        self
107    }
108
109    /// Set initial step size.
110    pub fn h0(mut self, h0: S) -> Self {
111        self.h0 = Some(h0);
112        self
113    }
114
115    /// Set maximum step size.
116    pub fn h_max(mut self, h_max: S) -> Self {
117        self.h_max = h_max;
118        self
119    }
120
121    /// Set evaluation times.
122    pub fn t_eval(mut self, t_eval: Vec<S>) -> Self {
123        self.t_eval = Some(t_eval);
124        self
125    }
126
127    /// Enable dense output.
128    pub fn dense(mut self) -> Self {
129        self.dense_output = true;
130        self
131    }
132
133    /// Set maximum number of steps.
134    pub fn max_steps(mut self, max_steps: usize) -> Self {
135        self.max_steps = max_steps;
136        self
137    }
138
139    /// Set minimum step size.
140    pub fn h_min(mut self, h_min: S) -> Self {
141        self.h_min = h_min;
142        self
143    }
144
145    /// Add an event function for zero-crossing detection.
146    ///
147    /// Internally converts to `Arc` to enable `Clone` on `SolverOptions`.
148    pub fn event(mut self, event: Box<dyn EventFunction<S>>) -> Self {
149        self.events.push(Arc::from(event));
150        self
151    }
152}
153
/// Work counters collected during an integration.
///
/// Every counter starts at zero (see the derived `Default`).
#[derive(Clone, Debug, Default)]
pub struct SolverStats {
    /// Function evaluations performed.
    pub n_eval: usize,
    /// Jacobian evaluations performed.
    pub n_jac: usize,
    /// Steps that were accepted.
    pub n_accept: usize,
    /// Steps that were rejected.
    pub n_reject: usize,
    /// LU decompositions performed (relevant to implicit methods).
    pub n_lu: usize,
}

169impl SolverStats {
170    pub fn new() -> Self {
171        Self::default()
172    }
173}
174
175/// Result of ODE integration.
176#[derive(Clone, Debug)]
177pub struct SolverResult<S: Scalar> {
178    /// Time points
179    pub t: Vec<S>,
180    /// Solution at each time point (row-major: y[i*dim + j] = y_j(t_i))
181    pub y: Vec<S>,
182    /// Dimension of the system
183    pub dim: usize,
184    /// Solver statistics
185    pub stats: SolverStats,
186    /// Was integration successful?
187    pub success: bool,
188    /// Message (error description if failed)
189    pub message: String,
190    /// Detected events during integration
191    pub events: Vec<Event<S>>,
192    /// Whether integration was terminated by a Stop event
193    pub terminated_by_event: bool,
194    /// Dense output for continuous interpolation (populated when `SolverOptions::dense()` was set).
195    pub dense_output: Option<DenseOutput<S>>,
196}
197
198impl<S: Scalar> SolverResult<S> {
199    /// Create a new successful result.
200    pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: SolverStats) -> Self {
201        Self {
202            t,
203            y,
204            dim,
205            stats,
206            success: true,
207            message: String::new(),
208            events: Vec::new(),
209            terminated_by_event: false,
210            dense_output: None,
211        }
212    }
213
214    /// Create a failed result.
215    pub fn failed(message: String, stats: SolverStats) -> Self {
216        Self {
217            t: Vec::new(),
218            y: Vec::new(),
219            dim: 0,
220            stats,
221            success: false,
222            message,
223            events: Vec::new(),
224            terminated_by_event: false,
225            dense_output: None,
226        }
227    }
228
229    /// Number of time points.
230    pub fn len(&self) -> usize {
231        self.t.len()
232    }
233
234    /// Is result empty?
235    pub fn is_empty(&self) -> bool {
236        self.t.is_empty()
237    }
238
239    /// Get final time.
240    pub fn t_final(&self) -> Option<S> {
241        self.t.last().copied()
242    }
243
244    /// Get final state.
245    pub fn y_final(&self) -> Option<Vec<S>> {
246        if self.t.is_empty() {
247            None
248        } else {
249            let start = (self.t.len() - 1) * self.dim;
250            Some(self.y[start..start + self.dim].to_vec())
251        }
252    }
253
254    /// Get state at index i.
255    pub fn y_at(&self, i: usize) -> &[S] {
256        let start = i * self.dim;
257        &self.y[start..start + self.dim]
258    }
259
260    /// Number of time steps in the solution.
261    pub fn n_steps(&self) -> usize {
262        self.y.len().checked_div(self.dim).unwrap_or(0)
263    }
264
265    /// Extract the j-th state variable as a time series.
266    ///
267    /// Returns `Some(Vec<S>)` containing `y_j(t_0), y_j(t_1), ..., y_j(t_N)`,
268    /// or `None` if `j >= self.dim`.
269    /// Useful for feeding a single component into FFT, statistics, or plotting.
270    pub fn component(&self, j: usize) -> Option<Vec<S>> {
271        if j >= self.dim {
272            return None;
273        }
274        Some(
275            (0..self.n_steps())
276                .map(|i| self.y[i * self.dim + j])
277                .collect(),
278        )
279    }
280
281    /// Iterate over (t, y) pairs.
282    pub fn iter(&self) -> impl Iterator<Item = (S, &[S])> {
283        self.t
284            .iter()
285            .enumerate()
286            .map(move |(i, &t)| (t, self.y_at(i)))
287    }
288}
289
290/// Trait for ODE solvers.
291pub trait Solver<S: Scalar> {
292    /// Solve the ODE problem.
293    fn solve<Sys: OdeSystem<S>>(
294        problem: &Sys,
295        t0: S,
296        tf: S,
297        y0: &[S],
298        options: &SolverOptions<S>,
299    ) -> Result<SolverResult<S>, SolverError>;
300}
301
/// Unit tests for the option, statistics and result types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_solver_options_default() {
        let opts: SolverOptions<f64> = SolverOptions::default();
        assert!((opts.rtol - 1e-6).abs() < 1e-10);
        assert!((opts.atol - 1e-9).abs() < 1e-15);
    }

    #[test]
    fn test_solver_options_default_limits() {
        let opts: SolverOptions<f64> = SolverOptions::default();
        assert!(opts.h0.is_none());
        assert!(opts.h_max.is_infinite() && opts.h_max > 0.0);
        // h_min scales with machine epsilon: small but strictly positive.
        assert!(opts.h_min > 0.0 && opts.h_min < 1e-6);
        assert_eq!(opts.max_steps, 100_000);
        assert!(opts.t_eval.is_none());
        assert!(!opts.dense_output);
        assert!(opts.events.is_empty());
    }

    #[test]
    fn test_solver_options_builder() {
        let opts: SolverOptions<f64> = SolverOptions::default().rtol(1e-8).atol(1e-10).h0(0.01);
        assert!((opts.rtol - 1e-8).abs() < 1e-15);
        assert!((opts.atol - 1e-10).abs() < 1e-15);
        assert!((opts.h0.unwrap() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_solver_options_remaining_setters() {
        let opts: SolverOptions<f64> = SolverOptions::default()
            .h_max(0.5)
            .h_min(1e-8)
            .max_steps(42)
            .t_eval(vec![0.0, 0.25, 1.0])
            .dense();
        assert_eq!(opts.h_max, 0.5);
        assert_eq!(opts.h_min, 1e-8);
        assert_eq!(opts.max_steps, 42);
        assert_eq!(opts.t_eval.as_deref(), Some(&[0.0, 0.25, 1.0][..]));
        assert!(opts.dense_output);
    }

    #[test]
    fn test_solver_options_clone_and_debug() {
        let opts: SolverOptions<f64> = SolverOptions::default().t_eval(vec![0.0, 1.0]).dense();
        let copy = opts.clone();
        assert_eq!(copy.t_eval, opts.t_eval);
        assert!(copy.dense_output);
        assert_eq!(copy.events.len(), 0);

        let text = format!("{:?}", opts);
        assert!(text.starts_with("SolverOptions"));
        assert!(text.contains("[0 event(s)]"));
    }

    #[test]
    fn test_solver_stats_new_is_zeroed() {
        let stats = SolverStats::new();
        assert_eq!(stats.n_eval, 0);
        assert_eq!(stats.n_jac, 0);
        assert_eq!(stats.n_accept, 0);
        assert_eq!(stats.n_reject, 0);
        assert_eq!(stats.n_lu, 0);
    }

    #[test]
    fn test_solver_result() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0]; // 2D system
        let result = SolverResult::new(t, y, 2, SolverStats::new());

        assert_eq!(result.len(), 3);
        assert!((result.t_final().unwrap() - 1.0).abs() < 1e-10);

        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 0.2).abs() < 1e-10);
        assert!((y_final[1] - 1.0).abs() < 1e-10);

        assert_eq!(result.y_at(0), &[1.0, 2.0]);
        assert_eq!(result.y_at(1), &[0.5, 1.5]);
    }

    #[test]
    fn test_failed_result() {
        let result = SolverResult::<f64>::failed("boom".to_string(), SolverStats::new());
        assert!(!result.success);
        assert_eq!(result.message, "boom");
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert!(result.t_final().is_none());
        assert!(result.y_final().is_none());
        assert!(result.component(0).is_none());
        assert_eq!(result.iter().count(), 0);
    }

    #[test]
    fn test_iter_pairs() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());

        let pairs: Vec<(f64, Vec<f64>)> = result.iter().map(|(t, y)| (t, y.to_vec())).collect();
        assert_eq!(
            pairs,
            vec![
                (0.0, vec![1.0, 2.0]),
                (0.5, vec![0.5, 1.5]),
                (1.0, vec![0.2, 1.0]),
            ]
        );
        assert!(result.success);
        assert!(!result.terminated_by_event);
        assert!(result.dense_output.is_none());
    }

    #[test]
    fn test_n_steps() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());
        assert_eq!(result.n_steps(), 3);

        let empty = SolverResult::<f64>::failed("err".to_string(), SolverStats::new());
        assert_eq!(empty.n_steps(), 0);
    }

    #[test]
    fn test_component() {
        let t = vec![0.0, 0.5, 1.0];
        // 2D system: y0 = [1.0, 0.5, 0.2], y1 = [2.0, 1.5, 1.0]
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());

        let comp0 = result.component(0).unwrap();
        assert_eq!(comp0, vec![1.0, 0.5, 0.2]);

        let comp1 = result.component(1).unwrap();
        assert_eq!(comp1, vec![2.0, 1.5, 1.0]);
    }

    #[test]
    fn test_component_out_of_bounds() {
        let t = vec![0.0];
        let y = vec![1.0, 2.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());
        assert!(result.component(2).is_none());
    }
}