Skip to main content

numra_ode/
auto.rs

1//! Automatic solver selection.
2//!
3//! Provides intelligent method selection based on problem characteristics.
4//!
5//! ## Usage
6//!
7//! ```rust
8//! use numra_ode::{OdeProblem, auto_solve, SolverOptions};
9//!
10//! let problem = OdeProblem::new(
11//!     |_t, y: &[f64], dydt: &mut [f64]| { dydt[0] = -y[0]; },
12//!     0.0, 1.0, vec![1.0],
13//! );
14//! let options = SolverOptions::default();
15//! let result = auto_solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();
16//! assert!(result.success);
17//! ```
18//!
19//! ## Selection Strategy
20//!
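//! - **Non-stiff problems**: Uses Tsit5 (efficient, accurate, FSAL) at low/standard accuracy
//! - **Moderately stiff**: Uses Esdirk54 (L-stable, good efficiency)
//! - **Very stiff**: Uses BDF (low/standard accuracy) or Radau5 (high/very high accuracy)
//! - **High accuracy** (non-stiff): Uses Vern6 for `rtol` down to ~1e-11 and Vern8 (8th order) for tighter tolerances
//! - **Unknown stiffness**: Tries Tsit5 first, falling back to Esdirk54 if Tsit5 returns an error or rejects at least as many steps as it accepts
//! - **`prefer_implicit` hint**: Uses Esdirk54 instead of an explicit method for non-stiff or unclassified problems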
25//!
26//! Author: Moussa Leblouba
27//! Date: 5 March 2026
28//! Modified: 2 May 2026
29
30use faer::{ComplexField, Conjugate, SimpleEntity};
31use numra_core::Scalar;
32
33use crate::bdf::Bdf;
34use crate::error::SolverError;
35use crate::esdirk::Esdirk54;
36use crate::problem::OdeSystem;
37use crate::radau5::Radau5;
38use crate::solver::{Solver, SolverOptions, SolverResult};
39use crate::tsit5::Tsit5;
40use crate::verner::{Vern6, Vern8};
41
/// How stiff an ODE problem is; drives the explicit-vs-implicit solver choice.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Stiffness {
    /// No significant stiffness: an explicit Runge–Kutta method is appropriate.
    NonStiff,
    /// Some stiffness: an ESDIRK integrator is a good fit.
    ModeratelyStiff,
    /// Severe stiffness: BDF or Radau integrators are recommended.
    VeryStiff,
    /// Not classified yet; the automatic driver decides at solve time.
    Unknown,
}
54
/// Requested accuracy tier, roughly tied to the relative tolerance `rtol`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Accuracy {
    /// Coarse results (rtol around 1e-3).
    Low,
    /// Typical engineering accuracy (rtol around 1e-6).
    Standard,
    /// Tight tolerances (rtol around 1e-10).
    High,
    /// Near machine precision (rtol around 1e-12 or tighter).
    VeryHigh,
}
67
68/// Solver selection hints.
69#[derive(Clone, Debug, Default)]
70pub struct SolverHints {
71    /// Problem stiffness
72    pub stiffness: Option<Stiffness>,
73    /// Accuracy requirements
74    pub accuracy: Option<Accuracy>,
75    /// Prefer implicit methods (for conservation)
76    pub prefer_implicit: bool,
77    /// Enable stiffness detection
78    pub detect_stiffness: bool,
79}
80
81impl SolverHints {
82    /// Create default hints.
83    pub fn new() -> Self {
84        Self {
85            stiffness: None,
86            accuracy: None,
87            prefer_implicit: false,
88            detect_stiffness: true,
89        }
90    }
91
92    /// Set stiffness hint.
93    pub fn stiffness(mut self, stiffness: Stiffness) -> Self {
94        self.stiffness = Some(stiffness);
95        self
96    }
97
98    /// Set accuracy requirement.
99    pub fn accuracy(mut self, accuracy: Accuracy) -> Self {
100        self.accuracy = Some(accuracy);
101        self
102    }
103
104    /// Prefer implicit methods.
105    pub fn implicit(mut self) -> Self {
106        self.prefer_implicit = true;
107        self
108    }
109
110    /// Enable/disable stiffness detection.
111    pub fn detect_stiffness(mut self, detect: bool) -> Self {
112        self.detect_stiffness = detect;
113        self
114    }
115}
116
117/// Determine accuracy level from options.
118fn classify_accuracy<S: Scalar>(options: &SolverOptions<S>) -> Accuracy {
119    let rtol = options.rtol.to_f64();
120    if rtol >= 1e-3 {
121        Accuracy::Low
122    } else if rtol >= 1e-7 {
123        Accuracy::Standard
124    } else if rtol >= 1e-11 {
125        Accuracy::High
126    } else {
127        Accuracy::VeryHigh
128    }
129}
130
131/// Attempt stiffness detection.
132fn detect_stiffness<S, Sys>(problem: &Sys, t: S, y: &[S], _options: &SolverOptions<S>) -> Stiffness
133where
134    S: Scalar,
135    Sys: OdeSystem<S>,
136{
137    let dim = problem.dim();
138    if dim == 0 {
139        return Stiffness::Unknown;
140    }
141
142    // Compute Jacobian eigenvalues (approximate via power iteration)
143    let h_factor = S::EPSILON.sqrt();
144    let mut f0 = vec![S::ZERO; dim];
145    let mut f1 = vec![S::ZERO; dim];
146    let _jv = vec![S::ZERO; dim];
147
148    problem.rhs(t, y, &mut f0);
149
150    // Simple stiffness indicator: ratio of max/min Jacobian elements
151    let mut max_jac = S::ZERO;
152    let mut min_jac = S::INFINITY;
153    let mut y_pert = y.to_vec();
154
155    for j in 0..dim.min(10) {
156        // Sample first 10 components for stiffness detection
157        let yj = y[j];
158        let h = h_factor * (S::ONE + yj.abs());
159        y_pert[j] = yj + h;
160        problem.rhs(t, &y_pert, &mut f1);
161        y_pert[j] = yj;
162
163        for i in 0..dim {
164            let jij = ((f1[i] - f0[i]) / h).abs();
165            if jij > S::from_f64(1e-15) {
166                max_jac = max_jac.max(jij);
167                min_jac = min_jac.min(jij);
168            }
169        }
170    }
171
172    // Stiffness ratio
173    if max_jac < S::from_f64(1e-10) {
174        return Stiffness::NonStiff;
175    }
176
177    let ratio = max_jac / min_jac.max(S::from_f64(1e-15));
178    let ratio_f64 = ratio.to_f64();
179
180    if ratio_f64 > 1e4 {
181        Stiffness::VeryStiff
182    } else if ratio_f64 > 100.0 {
183        Stiffness::ModeratelyStiff
184    } else {
185        Stiffness::NonStiff
186    }
187}
188
189/// Convenience function for automatic solving.
190pub fn auto_solve<S, Sys>(
191    problem: &Sys,
192    t0: S,
193    tf: S,
194    y0: &[S],
195    options: &SolverOptions<S>,
196) -> Result<SolverResult<S>, SolverError>
197where
198    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
199    Sys: OdeSystem<S>,
200{
201    auto_solve_with_hints(problem, t0, tf, y0, options, &SolverHints::new())
202}
203
204/// Convenience function for automatic solving with hints.
205pub fn auto_solve_with_hints<S, Sys>(
206    problem: &Sys,
207    t0: S,
208    tf: S,
209    y0: &[S],
210    options: &SolverOptions<S>,
211    hints: &SolverHints,
212) -> Result<SolverResult<S>, SolverError>
213where
214    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
215    Sys: OdeSystem<S>,
216{
217    // Determine accuracy
218    let accuracy = hints.accuracy.unwrap_or_else(|| classify_accuracy(options));
219
220    // Determine stiffness
221    let stiffness = hints.stiffness.unwrap_or_else(|| {
222        if hints.detect_stiffness {
223            detect_stiffness(problem, t0, y0, options)
224        } else {
225            Stiffness::Unknown
226        }
227    });
228
229    // Select solver based on characteristics
230    match (stiffness, accuracy, hints.prefer_implicit) {
231        // Non-stiff problems
232        (Stiffness::NonStiff, Accuracy::Low, false)
233        | (Stiffness::NonStiff, Accuracy::Standard, false) => {
234            Tsit5::solve(problem, t0, tf, y0, options)
235        }
236        (Stiffness::NonStiff, Accuracy::High, false) => Vern6::solve(problem, t0, tf, y0, options),
237        (Stiffness::NonStiff, Accuracy::VeryHigh, false) => {
238            Vern8::solve(problem, t0, tf, y0, options)
239        }
240
241        // Moderately stiff
242        (Stiffness::ModeratelyStiff, _, _) => Esdirk54::solve(problem, t0, tf, y0, options),
243
244        // Very stiff
245        (Stiffness::VeryStiff, Accuracy::Low, _)
246        | (Stiffness::VeryStiff, Accuracy::Standard, _) => Bdf::solve(problem, t0, tf, y0, options),
247        (Stiffness::VeryStiff, Accuracy::High, _)
248        | (Stiffness::VeryStiff, Accuracy::VeryHigh, _) => {
249            Radau5::solve(problem, t0, tf, y0, options)
250        }
251
252        // Prefer implicit
253        (_, _, true) => Esdirk54::solve(problem, t0, tf, y0, options),
254
255        // Unknown/default: try explicit first
256        (Stiffness::Unknown, _, _) => {
257            // Try Tsit5 first
258            if let Ok(result) = Tsit5::solve(problem, t0, tf, y0, options) {
259                // Check if solution seems reasonable
260                if result.stats.n_reject < result.stats.n_accept {
261                    return Ok(result);
262                }
263            }
264
265            // Fall back to implicit method
266            Esdirk54::solve(problem, t0, tf, y0, options)
267        }
268    }
269}
270
/// Unit tests for automatic solver selection, accuracy classification, and
/// the stiffness heuristic.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Exponential decay `y' = -y` on `[0, 5]` with default hints.
    ///
    /// `rtol = 1e-6` classifies as `Standard`. The scalar Jacobian makes the
    /// detector report `NonStiff`, so this goes through Tsit5.
    #[test]
    fn test_auto_nonstiff() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6);
        let result = auto_solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        // Exact solution: y(5) = e^{-5}.
        let expected = (-5.0_f64).exp();
        assert!((y_final[0] - expected).abs() < 1e-4);
    }

    /// Fast decay `y' = -100 y` on `[0, 0.1]` with an explicit
    /// `ModeratelyStiff` hint, which routes to Esdirk54.
    #[test]
    fn test_auto_stiff() {
        // Moderately stiff problem - use ESDIRK instead of BDF for now
        // since BDF still needs Newton iteration improvements
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            0.0,
            0.1,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        // Use moderately stiff hint which selects ESDIRK (more robust than BDF currently)
        let hints = SolverHints::new().stiffness(Stiffness::ModeratelyStiff);

        let result = auto_solve_with_hints(&problem, 0.0, 0.1, &[1.0], &options, &hints).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        // Exact solution: y(0.1) = e^{-100 * 0.1} = e^{-10}.
        let expected = (-10.0_f64).exp();
        assert!(
            (y_final[0] - expected).abs() < 0.05,
            "stiff: got {}, expected {}",
            y_final[0],
            expected
        );
    }

    /// Harmonic oscillator `y0' = y1, y1' = -y0` on `[0, 10]` with a `NonStiff` hint.
    ///
    /// NOTE(review): despite the name, `rtol = 1e-5` classifies as `Standard`,
    /// so this selects Tsit5, not the Vern6/Vern8 high-accuracy path. Adding
    /// `.accuracy(Accuracy::High)` to the hints would exercise Vern6.
    #[test]
    fn test_auto_high_accuracy() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );
        // Use moderate tolerances for reliable testing
        let options = SolverOptions::default().rtol(1e-5).atol(1e-7);
        let hints = SolverHints::new().stiffness(Stiffness::NonStiff);

        let result =
            auto_solve_with_hints(&problem, 0.0, 10.0, &[1.0, 0.0], &options, &hints).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        // Exact solution: y0(10) = cos(10). Allow 1e-3 absolute error for moderate tolerances.
        assert!(
            (y_final[0] - 10.0_f64.cos()).abs() < 1e-3,
            "high accuracy: got {}, expected {}",
            y_final[0],
            10.0_f64.cos()
        );
    }

    /// Direct checks of the Jacobian-ratio stiffness heuristic.
    #[test]
    fn test_auto_detect_stiffness() {
        // Non-stiff problem: single Jacobian entry, so max/min ratio is 1.
        let problem1 = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options = SolverOptions::default();
        let stiffness1 = detect_stiffness(&problem1, 0.0, &[1.0], &options);
        assert_eq!(stiffness1, Stiffness::NonStiff);

        // Stiff problem: Jacobian magnitudes {1000, 0.01, 0.01, 1}, so the
        // max/min ratio is about 1e5 (above the VeryStiff threshold of 1e4).
        // The assertion also tolerates ModeratelyStiff.
        let problem2 = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -1000.0 * y[0] + 0.01 * y[1];
                dydt[1] = 0.01 * y[0] - y[1];
            },
            0.0,
            1.0,
            vec![1.0, 1.0],
        );
        let stiffness2 = detect_stiffness(&problem2, 0.0, &[1.0, 1.0], &options);
        assert!(stiffness2 == Stiffness::VeryStiff || stiffness2 == Stiffness::ModeratelyStiff);
    }

    /// One representative `rtol` per accuracy band of `classify_accuracy`.
    #[test]
    fn test_accuracy_classification() {
        let opts_low: SolverOptions<f64> = SolverOptions::default().rtol(1e-2);
        let opts_std: SolverOptions<f64> = SolverOptions::default().rtol(1e-6);
        let opts_high: SolverOptions<f64> = SolverOptions::default().rtol(1e-10);
        let opts_vhigh: SolverOptions<f64> = SolverOptions::default().rtol(1e-13);

        assert_eq!(classify_accuracy(&opts_low), Accuracy::Low);
        assert_eq!(classify_accuracy(&opts_std), Accuracy::Standard);
        assert_eq!(classify_accuracy(&opts_high), Accuracy::High);
        assert_eq!(classify_accuracy(&opts_vhigh), Accuracy::VeryHigh);
    }

    /// Smoke test: `auto_solve` with default options succeeds on a simple decay problem.
    #[test]
    fn test_auto_convenience() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            2.0,
            vec![1.0],
        );
        let options = SolverOptions::default();

        let result = auto_solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
        assert!(result.success);
    }
}
410}