use faer::{ComplexField, Conjugate, SimpleEntity};
use numra_core::Scalar;
use crate::bdf::Bdf;
use crate::error::SolverError;
use crate::esdirk::Esdirk54;
use crate::problem::OdeSystem;
use crate::radau5::Radau5;
use crate::solver::{Solver, SolverOptions, SolverResult};
use crate::tsit5::Tsit5;
use crate::verner::{Vern6, Vern8};
/// Coarse stiffness class used by [`auto_solve_with_hints`] to pick an integrator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Stiffness {
    /// No stiffness expected; explicit methods (Tsit5 / Vern6 / Vern8) are selected
    /// unless implicit integration is explicitly preferred.
    NonStiff,
    /// Mild stiffness; dispatched to the ESDIRK integrator (`Esdirk54`).
    ModeratelyStiff,
    /// Strong stiffness; dispatched to `Bdf` or `Radau5` depending on the accuracy tier.
    VeryStiff,
    /// Stiffness not known up front; the dispatcher tries an explicit method first and
    /// may fall back to an implicit one.
    Unknown,
}
/// Requested accuracy tier. When no hint is supplied it is derived from the relative
/// tolerance by `classify_accuracy`, using the bands listed on each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Accuracy {
    /// `rtol >= 1e-3`.
    Low,
    /// `1e-7 <= rtol < 1e-3`.
    Standard,
    /// `1e-11 <= rtol < 1e-7`.
    High,
    /// `rtol < 1e-11`.
    VeryHigh,
}
/// Optional guidance for [`auto_solve_with_hints`] when choosing an integrator.
///
/// Construct with [`SolverHints::new`] (or `Default::default()`, which is identical) and
/// refine with the builder methods.
#[derive(Clone, Debug)]
pub struct SolverHints {
    /// Known stiffness class. When `Some`, automatic stiffness detection is skipped.
    pub stiffness: Option<Stiffness>,
    /// Desired accuracy tier. When `None`, it is derived from `options.rtol`.
    pub accuracy: Option<Accuracy>,
    /// Route non-stiff and unknown-stiffness problems to the implicit `Esdirk54` solver.
    pub prefer_implicit: bool,
    /// When `stiffness` is `None`, probe the right-hand side at the initial state to
    /// estimate stiffness; otherwise stiffness is treated as unknown.
    pub detect_stiffness: bool,
}

impl SolverHints {
    /// Creates hints with no stiffness/accuracy overrides, explicit methods allowed,
    /// and automatic stiffness detection enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            stiffness: None,
            accuracy: None,
            prefer_implicit: false,
            detect_stiffness: true,
        }
    }

    /// Declares the stiffness class, bypassing automatic detection.
    #[must_use]
    pub fn stiffness(mut self, stiffness: Stiffness) -> Self {
        self.stiffness = Some(stiffness);
        self
    }

    /// Declares the accuracy tier instead of deriving it from `rtol`.
    #[must_use]
    pub fn accuracy(mut self, accuracy: Accuracy) -> Self {
        self.accuracy = Some(accuracy);
        self
    }

    /// Prefers an implicit integrator for problems not classified as stiff.
    #[must_use]
    pub fn implicit(mut self) -> Self {
        self.prefer_implicit = true;
        self
    }

    /// Enables or disables automatic stiffness detection.
    #[must_use]
    pub fn detect_stiffness(mut self, detect: bool) -> Self {
        self.detect_stiffness = detect;
        self
    }
}

impl Default for SolverHints {
    /// Same as [`SolverHints::new`].
    // Written by hand so `default()` agrees with `new()`: a derived impl would set
    // `detect_stiffness` to `false`, silently disabling stiffness detection.
    fn default() -> Self {
        Self::new()
    }
}
/// Maps the relative tolerance in `options` onto an [`Accuracy`] tier.
///
/// Bands: `rtol >= 1e-3` → `Low`, `>= 1e-7` → `Standard`, `>= 1e-11` → `High`,
/// anything smaller (including a NaN tolerance, which fails every comparison) → `VeryHigh`.
fn classify_accuracy<S: Scalar>(options: &SolverOptions<S>) -> Accuracy {
    match options.rtol.to_f64() {
        tol if tol >= 1e-3 => Accuracy::Low,
        tol if tol >= 1e-7 => Accuracy::Standard,
        tol if tol >= 1e-11 => Accuracy::High,
        _ => Accuracy::VeryHigh,
    }
}
/// Heuristically classifies the stiffness of `problem` near `(t, y)`.
///
/// Approximates up to the first ten columns of the Jacobian by forward differences
/// (one extra right-hand-side evaluation per column). It then compares the largest and
/// smallest absolute entries above `1e-15`:
/// - largest entry below `1e-10` → `NonStiff`;
/// - largest/smallest ratio above `1e4` → `VeryStiff`;
/// - ratio above `1e2` → `ModeratelyStiff`;
/// - otherwise → `NonStiff`.
///
/// Returns `Unknown` for an empty system, or when `y` is shorter than `problem.dim()`
/// (probing would index out of bounds). `_options` is currently unused.
///
/// NOTE(review): the ratio mixes diagonal and off-diagonal entries and ignores the
/// integration interval, so weak couplings can inflate it. Treat the result as a rough
/// hint only.
fn detect_stiffness<S, Sys>(problem: &Sys, t: S, y: &[S], _options: &SolverOptions<S>) -> Stiffness
where
    S: Scalar,
    Sys: OdeSystem<S>,
{
    // Caps the cost of detection at this many extra RHS evaluations.
    const MAX_PROBED_COLUMNS: usize = 10;

    let dim = problem.dim();
    if dim == 0 || y.len() < dim {
        return Stiffness::Unknown;
    }

    let entry_floor = S::from_f64(1e-15);
    let h_factor = S::EPSILON.sqrt();
    let mut f0 = vec![S::ZERO; dim];
    let mut f1 = vec![S::ZERO; dim];
    problem.rhs(t, y, &mut f0);

    let mut max_jac = S::ZERO;
    let mut min_jac = S::INFINITY;
    let mut y_pert = y.to_vec();
    for j in 0..dim.min(MAX_PROBED_COLUMNS) {
        let yj = y[j];
        y_pert[j] = yj + h_factor * (S::ONE + yj.abs());
        // Divide by the step that was actually applied after rounding, not the nominal
        // one; this removes representation error from the difference quotient.
        let h = y_pert[j] - yj;
        problem.rhs(t, &y_pert, &mut f1);
        y_pert[j] = yj;

        for (&perturbed, &base) in f1.iter().zip(&f0) {
            let jij = ((perturbed - base) / h).abs();
            if jij > entry_floor {
                max_jac = max_jac.max(jij);
                min_jac = min_jac.min(jij);
            }
        }
    }

    if max_jac < S::from_f64(1e-10) {
        return Stiffness::NonStiff;
    }

    let ratio = (max_jac / min_jac.max(entry_floor)).to_f64();
    if ratio > 1e4 {
        Stiffness::VeryStiff
    } else if ratio > 100.0 {
        Stiffness::ModeratelyStiff
    } else {
        Stiffness::NonStiff
    }
}
/// Solves `problem` on `[t0, tf]` from `y0`, choosing an integrator automatically.
///
/// Equivalent to [`auto_solve_with_hints`] with [`SolverHints::new`]: no stiffness or
/// accuracy overrides, explicit methods allowed, and stiffness detection enabled.
///
/// # Errors
/// Propagates the [`SolverError`] returned by the selected integrator.
pub fn auto_solve<S, Sys>(
    problem: &Sys,
    t0: S,
    tf: S,
    y0: &[S],
    options: &SolverOptions<S>,
) -> Result<SolverResult<S>, SolverError>
where
    S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
    Sys: OdeSystem<S>,
{
    let default_hints = SolverHints::new();
    auto_solve_with_hints(problem, t0, tf, y0, options, &default_hints)
}
pub fn auto_solve_with_hints<S, Sys>(
problem: &Sys,
t0: S,
tf: S,
y0: &[S],
options: &SolverOptions<S>,
hints: &SolverHints,
) -> Result<SolverResult<S>, SolverError>
where
S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
Sys: OdeSystem<S>,
{
let accuracy = hints.accuracy.unwrap_or_else(|| classify_accuracy(options));
let stiffness = hints.stiffness.unwrap_or_else(|| {
if hints.detect_stiffness {
detect_stiffness(problem, t0, y0, options)
} else {
Stiffness::Unknown
}
});
match (stiffness, accuracy, hints.prefer_implicit) {
(Stiffness::NonStiff, Accuracy::Low, false)
| (Stiffness::NonStiff, Accuracy::Standard, false) => {
Tsit5::solve(problem, t0, tf, y0, options)
}
(Stiffness::NonStiff, Accuracy::High, false) => Vern6::solve(problem, t0, tf, y0, options),
(Stiffness::NonStiff, Accuracy::VeryHigh, false) => {
Vern8::solve(problem, t0, tf, y0, options)
}
(Stiffness::ModeratelyStiff, _, _) => Esdirk54::solve(problem, t0, tf, y0, options),
(Stiffness::VeryStiff, Accuracy::Low, _)
| (Stiffness::VeryStiff, Accuracy::Standard, _) => Bdf::solve(problem, t0, tf, y0, options),
(Stiffness::VeryStiff, Accuracy::High, _)
| (Stiffness::VeryStiff, Accuracy::VeryHigh, _) => {
Radau5::solve(problem, t0, tf, y0, options)
}
(_, _, true) => Esdirk54::solve(problem, t0, tf, y0, options),
(Stiffness::Unknown, _, _) => {
if let Ok(result) = Tsit5::solve(problem, t0, tf, y0, options) {
if result.stats.n_reject < result.stats.n_accept {
return Ok(result);
}
}
Esdirk54::solve(problem, t0, tf, y0, options)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Exponential decay with default hints is solved accurately.
    #[test]
    fn test_auto_nonstiff() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6);
        let result = auto_solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let expected = (-5.0_f64).exp();
        assert!((y_final[0] - expected).abs() < 1e-4);
    }

    /// A fast-decaying scalar problem routed to the implicit path via a stiffness hint.
    #[test]
    fn test_auto_stiff() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            0.0,
            0.1,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let hints = SolverHints::new().stiffness(Stiffness::ModeratelyStiff);
        let result = auto_solve_with_hints(&problem, 0.0, 0.1, &[1.0], &options, &hints).unwrap();
        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let expected = (-10.0_f64).exp();
        assert!(
            (y_final[0] - expected).abs() < 0.05,
            "stiff: got {}, expected {}",
            y_final[0],
            expected
        );
    }

    /// Harmonic oscillator with tolerances tight enough to land in `Accuracy::High`, so
    /// that, with a non-stiff hint, the high-order explicit path is exercised.
    #[test]
    fn test_auto_high_accuracy() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );
        let options: SolverOptions<f64> = SolverOptions::default().rtol(1e-9).atol(1e-11);
        // Guard the premise of this test: these tolerances must select the high tier.
        assert_eq!(classify_accuracy(&options), Accuracy::High);
        let hints = SolverHints::new().stiffness(Stiffness::NonStiff);
        let result =
            auto_solve_with_hints(&problem, 0.0, 10.0, &[1.0, 0.0], &options, &hints).unwrap();
        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!(
            (y_final[0] - 10.0_f64.cos()).abs() < 1e-6,
            "high accuracy: got {}, expected {}",
            y_final[0],
            10.0_f64.cos()
        );
    }

    #[test]
    fn test_auto_detect_stiffness() {
        let problem1 = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options = SolverOptions::default();
        let stiffness1 = detect_stiffness(&problem1, 0.0, &[1.0], &options);
        assert_eq!(stiffness1, Stiffness::NonStiff);

        let problem2 = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -1000.0 * y[0] + 0.01 * y[1];
                dydt[1] = 0.01 * y[0] - y[1];
            },
            0.0,
            1.0,
            vec![1.0, 1.0],
        );
        let stiffness2 = detect_stiffness(&problem2, 0.0, &[1.0, 1.0], &options);
        assert!(stiffness2 == Stiffness::VeryStiff || stiffness2 == Stiffness::ModeratelyStiff);
    }

    #[test]
    fn test_accuracy_classification() {
        let opts_low: SolverOptions<f64> = SolverOptions::default().rtol(1e-2);
        let opts_std: SolverOptions<f64> = SolverOptions::default().rtol(1e-6);
        let opts_high: SolverOptions<f64> = SolverOptions::default().rtol(1e-10);
        let opts_vhigh: SolverOptions<f64> = SolverOptions::default().rtol(1e-13);
        assert_eq!(classify_accuracy(&opts_low), Accuracy::Low);
        assert_eq!(classify_accuracy(&opts_std), Accuracy::Standard);
        assert_eq!(classify_accuracy(&opts_high), Accuracy::High);
        assert_eq!(classify_accuracy(&opts_vhigh), Accuracy::VeryHigh);
    }

    /// `SolverHints::new()` defaults and each builder setter.
    #[test]
    fn test_hints_builder() {
        let hints = SolverHints::new();
        assert_eq!(hints.stiffness, None);
        assert_eq!(hints.accuracy, None);
        assert!(!hints.prefer_implicit);
        assert!(hints.detect_stiffness);

        let hints = hints
            .stiffness(Stiffness::VeryStiff)
            .accuracy(Accuracy::High)
            .implicit()
            .detect_stiffness(false);
        assert_eq!(hints.stiffness, Some(Stiffness::VeryStiff));
        assert_eq!(hints.accuracy, Some(Accuracy::High));
        assert!(hints.prefer_implicit);
        assert!(!hints.detect_stiffness);
    }

    #[test]
    fn test_auto_convenience() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            2.0,
            vec![1.0],
        );
        let options = SolverOptions::default();
        let result = auto_solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
        assert!(result.success);
    }
}