Skip to main content

numra_ode/
auto.rs

1//! Automatic solver selection.
2//!
3//! Provides intelligent method selection based on problem characteristics.
4//!
5//! ## Usage
6//!
7//! ```rust
8//! use numra_ode::{OdeProblem, auto_solve, SolverOptions};
9//!
10//! let problem = OdeProblem::new(
11//!     |_t, y: &[f64], dydt: &mut [f64]| { dydt[0] = -y[0]; },
12//!     0.0, 1.0, vec![1.0],
13//! );
14//! let options = SolverOptions::default();
15//! let result = auto_solve(&problem, 0.0, 1.0, &[1.0], &options).unwrap();
16//! assert!(result.success);
17//! ```
18//!
19//! ## Selection Strategy
20//!
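//! The accuracy level comes from `SolverHints::accuracy` or, if unset, from `rtol`.
//!
//! - **Non-stiff problems**: Tsit5 for low/standard accuracy (rtol ≥ 1e-7), Vern6 for
//!   high accuracy (1e-11 ≤ rtol < 1e-7), Vern8 for very high accuracy (rtol < 1e-11)
//! - **Moderately stiff**: Uses Esdirk54 (L-stable, good efficiency)
//! - **Very stiff**: BDF for low/standard accuracy, Radau5 for high/very high accuracy
//! - **Unknown stiffness** (e.g. detection disabled): tries Tsit5 first and falls back
//!   to Esdirk54 when it errors or rejects at least as many steps as it accepts
//! - **`prefer_implicit` hint**: Esdirk54 unless the problem is classified as very stiff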
25//!
26//! Author: Moussa Leblouba
27//! Date: 5 March 2026
28//! Modified: 2 May 2026
29
30use faer::{ComplexField, Conjugate, SimpleEntity};
31use numra_core::Scalar;
32
33use crate::bdf::Bdf;
34use crate::error::SolverError;
35use crate::esdirk::Esdirk54;
36use crate::problem::OdeSystem;
37use crate::radau5::Radau5;
38use crate::solver::{Solver, SolverOptions, SolverResult};
39use crate::tsit5::Tsit5;
40use crate::verner::{Vern6, Vern8};
41
/// Stiffness class of an ODE problem, used by [`Auto`] to choose a method family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Stiffness {
    /// Not stiff: explicit Runge–Kutta methods are appropriate.
    NonStiff,
    /// Moderately stiff: an ESDIRK method is a good fit.
    ModeratelyStiff,
    /// Strongly stiff: BDF or Radau methods are recommended.
    VeryStiff,
    /// Not classified (for example when stiffness detection is disabled).
    Unknown,
}
54
/// Requested accuracy level, normally derived from the relative tolerance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Accuracy {
    /// Loose tolerances (rtol around 1e-3).
    Low,
    /// Typical tolerances (rtol around 1e-6).
    Standard,
    /// Tight tolerances (rtol around 1e-10).
    High,
    /// Very tight tolerances (rtol around 1e-12 or smaller).
    VeryHigh,
}
67
68/// Solver selection hints.
69#[derive(Clone, Debug, Default)]
70pub struct SolverHints {
71    /// Problem stiffness
72    pub stiffness: Option<Stiffness>,
73    /// Accuracy requirements
74    pub accuracy: Option<Accuracy>,
75    /// Prefer implicit methods (for conservation)
76    pub prefer_implicit: bool,
77    /// Enable stiffness detection
78    pub detect_stiffness: bool,
79}
80
81impl SolverHints {
82    /// Create default hints.
83    pub fn new() -> Self {
84        Self {
85            stiffness: None,
86            accuracy: None,
87            prefer_implicit: false,
88            detect_stiffness: true,
89        }
90    }
91
92    /// Set stiffness hint.
93    pub fn stiffness(mut self, stiffness: Stiffness) -> Self {
94        self.stiffness = Some(stiffness);
95        self
96    }
97
98    /// Set accuracy requirement.
99    pub fn accuracy(mut self, accuracy: Accuracy) -> Self {
100        self.accuracy = Some(accuracy);
101        self
102    }
103
104    /// Prefer implicit methods.
105    pub fn implicit(mut self) -> Self {
106        self.prefer_implicit = true;
107        self
108    }
109
110    /// Enable/disable stiffness detection.
111    pub fn detect_stiffness(mut self, detect: bool) -> Self {
112        self.detect_stiffness = detect;
113        self
114    }
115}
116
117/// Automatic solver selection.
118#[derive(Clone, Debug, Default)]
119pub struct Auto {
120    #[allow(dead_code)]
121    hints: SolverHints,
122}
123
124impl Auto {
125    /// Create auto-selector with default hints.
126    pub fn new() -> Self {
127        Self {
128            hints: SolverHints::new(),
129        }
130    }
131
132    /// Create auto-selector with custom hints.
133    pub fn with_hints(hints: SolverHints) -> Self {
134        Self { hints }
135    }
136
137    /// Determine accuracy level from options.
138    fn classify_accuracy<S: Scalar>(options: &SolverOptions<S>) -> Accuracy {
139        let rtol = options.rtol.to_f64();
140        if rtol >= 1e-3 {
141            Accuracy::Low
142        } else if rtol >= 1e-7 {
143            Accuracy::Standard
144        } else if rtol >= 1e-11 {
145            Accuracy::High
146        } else {
147            Accuracy::VeryHigh
148        }
149    }
150
151    /// Attempt stiffness detection.
152    fn detect_stiffness<S, Sys>(
153        problem: &Sys,
154        t: S,
155        y: &[S],
156        _options: &SolverOptions<S>,
157    ) -> Stiffness
158    where
159        S: Scalar,
160        Sys: OdeSystem<S>,
161    {
162        let dim = problem.dim();
163        if dim == 0 {
164            return Stiffness::Unknown;
165        }
166
167        // Compute Jacobian eigenvalues (approximate via power iteration)
168        let eps = S::from_f64(1e-8);
169        let mut f0 = vec![S::ZERO; dim];
170        let mut f1 = vec![S::ZERO; dim];
171        let _jv = vec![S::ZERO; dim];
172
173        problem.rhs(t, y, &mut f0);
174
175        // Simple stiffness indicator: ratio of max/min Jacobian elements
176        let mut max_jac = S::ZERO;
177        let mut min_jac = S::INFINITY;
178        let mut y_pert = y.to_vec();
179
180        for j in 0..dim.min(10) {
181            // Sample first 10 components for stiffness detection
182            let yj = y[j];
183            let h = eps * (S::ONE + yj.abs());
184            y_pert[j] = yj + h;
185            problem.rhs(t, &y_pert, &mut f1);
186            y_pert[j] = yj;
187
188            for i in 0..dim {
189                let jij = ((f1[i] - f0[i]) / h).abs();
190                if jij > S::from_f64(1e-15) {
191                    max_jac = max_jac.max(jij);
192                    min_jac = min_jac.min(jij);
193                }
194            }
195        }
196
197        // Stiffness ratio
198        if max_jac < S::from_f64(1e-10) {
199            return Stiffness::NonStiff;
200        }
201
202        let ratio = max_jac / min_jac.max(S::from_f64(1e-15));
203        let ratio_f64 = ratio.to_f64();
204
205        if ratio_f64 > 1e4 {
206            Stiffness::VeryStiff
207        } else if ratio_f64 > 100.0 {
208            Stiffness::ModeratelyStiff
209        } else {
210            Stiffness::NonStiff
211        }
212    }
213
214    /// Select and run appropriate solver.
215    pub fn solve_with_hints<S, Sys>(
216        problem: &Sys,
217        t0: S,
218        tf: S,
219        y0: &[S],
220        options: &SolverOptions<S>,
221        hints: &SolverHints,
222    ) -> Result<SolverResult<S>, SolverError>
223    where
224        S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
225        Sys: OdeSystem<S>,
226    {
227        // Determine accuracy
228        let accuracy = hints
229            .accuracy
230            .unwrap_or_else(|| Self::classify_accuracy(options));
231
232        // Determine stiffness
233        let stiffness = hints.stiffness.unwrap_or_else(|| {
234            if hints.detect_stiffness {
235                Self::detect_stiffness(problem, t0, y0, options)
236            } else {
237                Stiffness::Unknown
238            }
239        });
240
241        // Select solver based on characteristics
242        match (stiffness, accuracy, hints.prefer_implicit) {
243            // Non-stiff problems
244            (Stiffness::NonStiff, Accuracy::Low, false)
245            | (Stiffness::NonStiff, Accuracy::Standard, false) => {
246                Tsit5::solve(problem, t0, tf, y0, options)
247            }
248            (Stiffness::NonStiff, Accuracy::High, false) => {
249                Vern6::solve(problem, t0, tf, y0, options)
250            }
251            (Stiffness::NonStiff, Accuracy::VeryHigh, false) => {
252                Vern8::solve(problem, t0, tf, y0, options)
253            }
254
255            // Moderately stiff
256            (Stiffness::ModeratelyStiff, _, _) => Esdirk54::solve(problem, t0, tf, y0, options),
257
258            // Very stiff
259            (Stiffness::VeryStiff, Accuracy::Low, _)
260            | (Stiffness::VeryStiff, Accuracy::Standard, _) => {
261                Bdf::solve(problem, t0, tf, y0, options)
262            }
263            (Stiffness::VeryStiff, Accuracy::High, _)
264            | (Stiffness::VeryStiff, Accuracy::VeryHigh, _) => {
265                Radau5::solve(problem, t0, tf, y0, options)
266            }
267
268            // Prefer implicit
269            (_, _, true) => Esdirk54::solve(problem, t0, tf, y0, options),
270
271            // Unknown/default: try explicit first
272            (Stiffness::Unknown, _, _) => {
273                // Try Tsit5 first
274                match Tsit5::solve(problem, t0, tf, y0, options) {
275                    Ok(result) => {
276                        // Check if solution seems reasonable
277                        if result.stats.n_reject < result.stats.n_accept {
278                            return Ok(result);
279                        }
280                    }
281                    Err(_) => {}
282                }
283
284                // Fall back to implicit method
285                Esdirk54::solve(problem, t0, tf, y0, options)
286            }
287        }
288    }
289}
290
291impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Auto {
292    fn solve<Sys: OdeSystem<S>>(
293        problem: &Sys,
294        t0: S,
295        tf: S,
296        y0: &[S],
297        options: &SolverOptions<S>,
298    ) -> Result<SolverResult<S>, SolverError> {
299        let hints = SolverHints::new();
300        Self::solve_with_hints(problem, t0, tf, y0, options, &hints)
301    }
302}
303
304/// Convenience function for automatic solving.
305pub fn auto_solve<S, Sys>(
306    problem: &Sys,
307    t0: S,
308    tf: S,
309    y0: &[S],
310    options: &SolverOptions<S>,
311) -> Result<SolverResult<S>, SolverError>
312where
313    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
314    Sys: OdeSystem<S>,
315{
316    Auto::solve(problem, t0, tf, y0, options)
317}
318
319/// Convenience function for automatic solving with hints.
320pub fn auto_solve_with_hints<S, Sys>(
321    problem: &Sys,
322    t0: S,
323    tf: S,
324    y0: &[S],
325    options: &SolverOptions<S>,
326    hints: &SolverHints,
327) -> Result<SolverResult<S>, SolverError>
328where
329    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
330    Sys: OdeSystem<S>,
331{
332    Auto::solve_with_hints(problem, t0, tf, y0, options, hints)
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Absolute difference between a computed value and its reference.
    fn abs_err(got: f64, want: f64) -> f64 {
        (got - want).abs()
    }

    #[test]
    fn test_auto_nonstiff() {
        // Exponential decay y' = -y, y(0) = 1, so y(5) = e^-5.
        let decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let opts = SolverOptions::default().rtol(1e-6);

        let sol = Auto::solve(&decay, 0.0, 5.0, &[1.0], &opts).unwrap();
        assert!(sol.success);

        let end = sol.y_final().unwrap();
        assert!(abs_err(end[0], (-5.0_f64).exp()) < 1e-4);
    }

    #[test]
    fn test_auto_stiff() {
        // y' = -100 y on [0, 0.1]; exact y(0.1) = e^-10.
        // The ModeratelyStiff hint routes to ESDIRK, which is currently more
        // robust here than BDF (BDF still needs Newton iteration improvements).
        let fast_decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            0.0,
            0.1,
            vec![1.0],
        );
        let opts = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let hints = SolverHints::new().stiffness(Stiffness::ModeratelyStiff);

        let sol = Auto::solve_with_hints(&fast_decay, 0.0, 0.1, &[1.0], &opts, &hints).unwrap();
        assert!(sol.success);

        let end = sol.y_final().unwrap();
        let exact = (-10.0_f64).exp();
        assert!(
            abs_err(end[0], exact) < 0.05,
            "Auto stiff: got {}, expected {}",
            end[0],
            exact
        );
    }

    #[test]
    fn test_auto_high_accuracy() {
        // Harmonic oscillator x'' = -x, x(0) = 1, x'(0) = 0, so x(10) = cos(10).
        // Note: rtol = 1e-5 classifies as `Accuracy::Standard`, so with the
        // NonStiff hint this exercises the Tsit5 path, not a Verner method.
        let oscillator = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );
        let opts = SolverOptions::default().rtol(1e-5).atol(1e-7);
        let hints = SolverHints::new().stiffness(Stiffness::NonStiff);

        let sol =
            Auto::solve_with_hints(&oscillator, 0.0, 10.0, &[1.0, 0.0], &opts, &hints).unwrap();
        assert!(sol.success);

        // Allow 0.1% error for these moderate tolerances.
        let end = sol.y_final().unwrap();
        let exact = 10.0_f64.cos();
        assert!(
            abs_err(end[0], exact) < 1e-3,
            "Auto high accuracy: got {}, expected {}",
            end[0],
            exact
        );
    }

    #[test]
    fn test_auto_detect_stiffness() {
        let opts = SolverOptions::default();

        // Scalar decay: a single Jacobian entry, reported as non-stiff.
        let mild = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        assert_eq!(
            Auto::detect_stiffness(&mild, 0.0, &[1.0], &opts),
            Stiffness::NonStiff
        );

        // Widely separated Jacobian magnitudes: some flavor of stiff.
        let stiff = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -1000.0 * y[0] + 0.01 * y[1];
                dydt[1] = 0.01 * y[0] - y[1];
            },
            0.0,
            1.0,
            vec![1.0, 1.0],
        );
        let class = Auto::detect_stiffness(&stiff, 0.0, &[1.0, 1.0], &opts);
        assert!(matches!(
            class,
            Stiffness::VeryStiff | Stiffness::ModeratelyStiff
        ));
    }

    #[test]
    fn test_accuracy_classification() {
        let cases: [(f64, Accuracy); 4] = [
            (1e-2, Accuracy::Low),
            (1e-6, Accuracy::Standard),
            (1e-10, Accuracy::High),
            (1e-13, Accuracy::VeryHigh),
        ];
        for (rtol, want) in cases {
            let opts: SolverOptions<f64> = SolverOptions::default().rtol(rtol);
            assert_eq!(Auto::classify_accuracy(&opts), want);
        }
    }

    #[test]
    fn test_auto_convenience() {
        let decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            2.0,
            vec![1.0],
        );
        let opts = SolverOptions::default();

        let sol = auto_solve(&decay, 0.0, 2.0, &[1.0], &opts).unwrap();
        assert!(sol.success);
    }
}