Skip to main content

numra_ode/
solver.rs

//! ODE solver infrastructure.
//!
//! This module defines the types and traits shared by all ODE solvers:
//! solver options, run statistics, the integration result, and the
//! `Solver` trait.
//!
//! Author: Moussa Leblouba
//! Date: 30 April 2026
//! Modified: 2 May 2026
8
9use crate::dense::DenseOutput;
10use crate::error::SolverError;
11use crate::events::{Event, EventFunction};
12use crate::problem::OdeSystem;
13use core::fmt;
14use numra_core::Scalar;
15use std::sync::Arc;
16
17/// Solver options and tolerances.
18///
19/// Cloneable thanks to `Arc`-wrapped event functions.
20pub struct SolverOptions<S: Scalar> {
21    /// Relative tolerance
22    pub rtol: S,
23    /// Absolute tolerance (scalar)
24    pub atol: S,
25    /// Initial step size (None = auto)
26    pub h0: Option<S>,
27    /// Maximum step size
28    pub h_max: S,
29    /// Minimum step size
30    pub h_min: S,
31    /// Maximum number of steps
32    pub max_steps: usize,
33    /// Output grid in the integration direction. When `Some`, each solver
34    /// returns exactly these `(t, y)` pairs (Hermite cubic interpolated
35    /// from accepted step endpoints; endpoints are reproduced bit-exact).
36    /// When `None`, the natural adaptive step grid is returned.
37    pub t_eval: Option<Vec<S>>,
38    /// Enable dense output
39    pub dense_output: bool,
40    /// Maximum BDF order during adaptive order selection.
41    ///
42    /// `None` (default) = use the BDF solver's natural cap (order 5).
43    /// `Some(n)` clamps to `[1, 5]`. No effect on non-BDF solvers.
44    pub max_order: Option<usize>,
45    /// Minimum BDF order during adaptive order selection.
46    ///
47    /// `None` (default) = use the BDF solver's natural floor (order 1).
48    /// `Some(n)` clamps to `[1, 5]`. BDF always starts at order 1
49    /// (only single-step information is available at startup); the floor
50    /// is enforced during downward order adaptation, so combining
51    /// `max_order(n)` with `min_order(n)` of the same value pins the order
52    /// to `n` once it has risen there. No effect on non-BDF solvers.
53    pub min_order: Option<usize>,
54    /// Event functions for zero-crossing detection (Arc enables Clone)
55    pub events: Vec<Arc<dyn EventFunction<S>>>,
56}
57
58impl<S: Scalar> Clone for SolverOptions<S> {
59    fn clone(&self) -> Self {
60        Self {
61            rtol: self.rtol,
62            atol: self.atol,
63            h0: self.h0,
64            h_max: self.h_max,
65            h_min: self.h_min,
66            max_steps: self.max_steps,
67            t_eval: self.t_eval.clone(),
68            dense_output: self.dense_output,
69            max_order: self.max_order,
70            min_order: self.min_order,
71            events: self.events.clone(),
72        }
73    }
74}
75
76impl<S: Scalar> fmt::Debug for SolverOptions<S> {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        f.debug_struct("SolverOptions")
79            .field("rtol", &self.rtol)
80            .field("atol", &self.atol)
81            .field("h0", &self.h0)
82            .field("h_max", &self.h_max)
83            .field("h_min", &self.h_min)
84            .field("max_steps", &self.max_steps)
85            .field("t_eval", &self.t_eval)
86            .field("dense_output", &self.dense_output)
87            .field("max_order", &self.max_order)
88            .field("min_order", &self.min_order)
89            .field("events", &format!("[{} event(s)]", self.events.len()))
90            .finish()
91    }
92}
93
94impl<S: Scalar> Default for SolverOptions<S> {
95    fn default() -> Self {
96        Self {
97            rtol: S::from_f64(1e-6),
98            atol: S::from_f64(1e-9),
99            h0: None,
100            h_max: S::INFINITY,
101            // Scale h_min with machine epsilon to support both f32 and f64.
102            // f64: 100 * EPSILON ~ 2.2e-14 (close to previous fixed 1e-14)
103            // f32: 100 * EPSILON ~ 1.2e-5  (meaningful for f32 precision)
104            // A fixed 1e-14 was below f32 machine epsilon (~1.2e-7), making it useless.
105            h_min: S::EPSILON * S::from_f64(100.0),
106            max_steps: 100_000,
107            t_eval: None,
108            dense_output: false,
109            max_order: None,
110            min_order: None,
111            events: Vec::new(),
112        }
113    }
114}
115
116impl<S: Scalar> SolverOptions<S> {
117    /// Set relative tolerance.
118    pub fn rtol(mut self, rtol: S) -> Self {
119        self.rtol = rtol;
120        self
121    }
122
123    /// Set absolute tolerance.
124    pub fn atol(mut self, atol: S) -> Self {
125        self.atol = atol;
126        self
127    }
128
129    /// Set initial step size.
130    pub fn h0(mut self, h0: S) -> Self {
131        self.h0 = Some(h0);
132        self
133    }
134
135    /// Set maximum step size.
136    pub fn h_max(mut self, h_max: S) -> Self {
137        self.h_max = h_max;
138        self
139    }
140
141    /// Set evaluation times.
142    pub fn t_eval(mut self, t_eval: Vec<S>) -> Self {
143        self.t_eval = Some(t_eval);
144        self
145    }
146
147    /// Enable dense output.
148    pub fn dense(mut self) -> Self {
149        self.dense_output = true;
150        self
151    }
152
153    /// Set maximum number of steps.
154    pub fn max_steps(mut self, max_steps: usize) -> Self {
155        self.max_steps = max_steps;
156        self
157    }
158
159    /// Set minimum step size.
160    pub fn h_min(mut self, h_min: S) -> Self {
161        self.h_min = h_min;
162        self
163    }
164
165    /// Cap the maximum BDF order during adaptive order selection.
166    ///
167    /// Useful for keeping BDF L-stable (`max_order(2)`) on problems that
168    /// need strict L-stability. No effect on non-BDF solvers. Values are
169    /// clamped to `[1, 5]` (BDF's algorithmic limit).
170    pub fn max_order(mut self, n: usize) -> Self {
171        self.max_order = Some(n);
172        self
173    }
174
175    /// Pin the minimum BDF order during adaptive order selection.
176    ///
177    /// Combined with `max_order(n)` of the same value, pins the BDF order
178    /// to `n` once adaptive selection reaches it (BDF always starts at
179    /// order 1). No effect on non-BDF solvers. Values are clamped to
180    /// `[1, 5]` (BDF's algorithmic limit).
181    pub fn min_order(mut self, n: usize) -> Self {
182        self.min_order = Some(n);
183        self
184    }
185
186    /// Add an event function for zero-crossing detection.
187    ///
188    /// Internally converts to `Arc` to enable `Clone` on `SolverOptions`.
189    pub fn event(mut self, event: Box<dyn EventFunction<S>>) -> Self {
190        self.events.push(Arc::from(event));
191        self
192    }
193}
194
/// Work counters reported by a solver alongside its result.
///
/// All counters start at zero (`Default`). `PartialEq`/`Eq` compare every
/// counter, which makes it easy to assert on solver work in tests.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct SolverStats {
    /// Number of function evaluations.
    pub n_eval: usize,
    /// Number of Jacobian evaluations.
    pub n_jac: usize,
    /// Number of accepted steps.
    pub n_accept: usize,
    /// Number of rejected steps.
    pub n_reject: usize,
    /// Number of LU decompositions (implicit methods only).
    pub n_lu: usize,
}
209
210impl SolverStats {
211    pub fn new() -> Self {
212        Self::default()
213    }
214}
215
216/// Result of ODE integration.
217#[derive(Clone, Debug)]
218pub struct SolverResult<S: Scalar> {
219    /// Time points
220    pub t: Vec<S>,
221    /// Solution at each time point (row-major: y[i*dim + j] = y_j(t_i))
222    pub y: Vec<S>,
223    /// Dimension of the system
224    pub dim: usize,
225    /// Solver statistics
226    pub stats: SolverStats,
227    /// Was integration successful?
228    pub success: bool,
229    /// Message (error description if failed)
230    pub message: String,
231    /// Detected events during integration
232    pub events: Vec<Event<S>>,
233    /// Whether integration was terminated by a Stop event
234    pub terminated_by_event: bool,
235    /// Dense output for continuous interpolation (populated when `SolverOptions::dense()` was set).
236    pub dense_output: Option<DenseOutput<S>>,
237}
238
239impl<S: Scalar> SolverResult<S> {
240    /// Create a new successful result.
241    pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: SolverStats) -> Self {
242        Self {
243            t,
244            y,
245            dim,
246            stats,
247            success: true,
248            message: String::new(),
249            events: Vec::new(),
250            terminated_by_event: false,
251            dense_output: None,
252        }
253    }
254
255    /// Create a failed result.
256    pub fn failed(message: String, stats: SolverStats) -> Self {
257        Self {
258            t: Vec::new(),
259            y: Vec::new(),
260            dim: 0,
261            stats,
262            success: false,
263            message,
264            events: Vec::new(),
265            terminated_by_event: false,
266            dense_output: None,
267        }
268    }
269
270    /// Number of time points.
271    pub fn len(&self) -> usize {
272        self.t.len()
273    }
274
275    /// Is result empty?
276    pub fn is_empty(&self) -> bool {
277        self.t.is_empty()
278    }
279
280    /// Get final time.
281    pub fn t_final(&self) -> Option<S> {
282        self.t.last().copied()
283    }
284
285    /// Get final state.
286    pub fn y_final(&self) -> Option<Vec<S>> {
287        if self.t.is_empty() {
288            None
289        } else {
290            let start = (self.t.len() - 1) * self.dim;
291            Some(self.y[start..start + self.dim].to_vec())
292        }
293    }
294
295    /// Get state at index i.
296    pub fn y_at(&self, i: usize) -> &[S] {
297        let start = i * self.dim;
298        &self.y[start..start + self.dim]
299    }
300
301    /// Number of time steps in the solution.
302    pub fn n_steps(&self) -> usize {
303        self.y.len().checked_div(self.dim).unwrap_or(0)
304    }
305
306    /// Extract the j-th state variable as a time series.
307    ///
308    /// Returns `Some(Vec<S>)` containing `y_j(t_0), y_j(t_1), ..., y_j(t_N)`,
309    /// or `None` if `j >= self.dim`.
310    /// Useful for feeding a single component into FFT, statistics, or plotting.
311    pub fn component(&self, j: usize) -> Option<Vec<S>> {
312        if j >= self.dim {
313            return None;
314        }
315        Some(
316            (0..self.n_steps())
317                .map(|i| self.y[i * self.dim + j])
318                .collect(),
319        )
320    }
321
322    /// Iterate over (t, y) pairs.
323    pub fn iter(&self) -> impl Iterator<Item = (S, &[S])> {
324        self.t
325            .iter()
326            .enumerate()
327            .map(move |(i, &t)| (t, self.y_at(i)))
328    }
329}
330
331/// Trait for ODE solvers.
332pub trait Solver<S: Scalar> {
333    /// Solve the ODE problem.
334    fn solve<Sys: OdeSystem<S>>(
335        problem: &Sys,
336        t0: S,
337        tf: S,
338        y0: &[S],
339        options: &SolverOptions<S>,
340    ) -> Result<SolverResult<S>, SolverError>;
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Three samples of a 2-component system at t = [0.0, 0.5, 1.0]:
    /// y_0 = [1.0, 0.5, 0.2], y_1 = [2.0, 1.5, 1.0].
    fn sample_2d() -> SolverResult<f64> {
        SolverResult::new(
            vec![0.0, 0.5, 1.0],
            vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0],
            2,
            SolverStats::new(),
        )
    }

    #[test]
    fn test_solver_options_default() {
        let opts: SolverOptions<f64> = SolverOptions::default();
        assert!((opts.rtol - 1e-6).abs() < 1e-10);
        assert!((opts.atol - 1e-9).abs() < 1e-15);
        assert!(opts.h0.is_none());
        assert!(opts.h_max.is_infinite());
        assert!((opts.h_min - 100.0 * f64::EPSILON).abs() < 1e-20);
        assert_eq!(opts.max_steps, 100_000);
        assert!(opts.t_eval.is_none());
        assert!(!opts.dense_output);
        assert_eq!(opts.max_order, None);
        assert_eq!(opts.min_order, None);
        assert!(opts.events.is_empty());
    }

    #[test]
    fn test_solver_options_builder() {
        let opts: SolverOptions<f64> = SolverOptions::default().rtol(1e-8).atol(1e-10).h0(0.01);
        assert!((opts.rtol - 1e-8).abs() < 1e-15);
        assert!((opts.atol - 1e-10).abs() < 1e-15);
        assert!((opts.h0.unwrap() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_solver_options_builder_remaining_setters() {
        let opts: SolverOptions<f64> = SolverOptions::default()
            .h_max(0.5)
            .h_min(1e-12)
            .max_steps(42)
            .t_eval(vec![0.0, 0.25, 1.0])
            .dense()
            .max_order(2)
            .min_order(2);
        assert!((opts.h_max - 0.5).abs() < 1e-15);
        assert!((opts.h_min - 1e-12).abs() < 1e-25);
        assert_eq!(opts.max_steps, 42);
        assert_eq!(opts.t_eval, Some(vec![0.0, 0.25, 1.0]));
        assert!(opts.dense_output);
        assert_eq!(opts.max_order, Some(2));
        assert_eq!(opts.min_order, Some(2));
    }

    #[test]
    fn test_solver_options_clone_and_debug() {
        let opts: SolverOptions<f64> = SolverOptions::default()
            .rtol(1e-4)
            .t_eval(vec![0.0, 1.0]);
        let copy = opts.clone();
        assert!((copy.rtol - opts.rtol).abs() < 1e-15);
        assert_eq!(copy.t_eval, Some(vec![0.0, 1.0]));

        let text = format!("{:?}", copy);
        assert!(text.starts_with("SolverOptions"));
        assert!(text.contains("[0 event(s)]"));
    }

    #[test]
    fn test_solver_result() {
        let result = sample_2d();

        assert!(result.success);
        assert!(!result.terminated_by_event);
        assert!(result.events.is_empty());
        assert!(result.dense_output.is_none());
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
        assert!((result.t_final().unwrap() - 1.0).abs() < 1e-10);

        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 0.2).abs() < 1e-10);
        assert!((y_final[1] - 1.0).abs() < 1e-10);

        assert_eq!(result.y_at(0), &[1.0, 2.0]);
        assert_eq!(result.y_at(1), &[0.5, 1.5]);
    }

    #[test]
    fn test_failed_result_accessors() {
        let failed = SolverResult::<f64>::failed("boom".to_string(), SolverStats::new());
        assert!(!failed.success);
        assert_eq!(failed.message, "boom");
        assert!(failed.is_empty());
        assert_eq!(failed.len(), 0);
        assert!(failed.t_final().is_none());
        assert!(failed.y_final().is_none());
        assert!(failed.component(0).is_none());
    }

    #[test]
    fn test_iter_pairs() {
        let result = sample_2d();
        let pairs: Vec<(f64, Vec<f64>)> =
            result.iter().map(|(t, y)| (t, y.to_vec())).collect();
        assert_eq!(
            pairs,
            vec![
                (0.0, vec![1.0, 2.0]),
                (0.5, vec![0.5, 1.5]),
                (1.0, vec![0.2, 1.0]),
            ]
        );
    }

    #[test]
    fn test_n_steps() {
        assert_eq!(sample_2d().n_steps(), 3);

        let empty = SolverResult::<f64>::failed("err".to_string(), SolverStats::new());
        assert_eq!(empty.n_steps(), 0);
    }

    #[test]
    fn test_component() {
        let result = sample_2d();
        assert_eq!(result.component(0).unwrap(), vec![1.0, 0.5, 0.2]);
        assert_eq!(result.component(1).unwrap(), vec![2.0, 1.5, 1.0]);
    }

    #[test]
    fn test_component_out_of_bounds() {
        let result = SolverResult::new(vec![0.0], vec![1.0, 2.0], 2, SolverStats::new());
        assert!(result.component(2).is_none());
    }
}