use crate::{Method, StrError};
use russell_sparse::{Genie, LinSolParams};
/// Holds the error tolerances used by the solver
///
/// The values are produced by `calc_tolerances`. When the method is Radau5,
/// `abs` and `rel` hold the transformed tolerances, not the user-supplied ones.
#[derive(Clone, Copy, Debug)]
pub(crate) struct ParamsTol {
/// Absolute tolerance (default computed from 1e-4)
pub(crate) abs: f64,
/// Relative tolerance (default computed from 1e-4)
pub(crate) rel: f64,
/// Tolerance for the Newton iterations (derived from `rel` unless set explicitly)
pub(crate) newton: f64,
}
/// Holds parameters for the Newton iterations and the linear solver they use
#[derive(Clone, Copy, Debug)]
pub struct ParamsNewton {
/// Max number of iterations (default 7; `validate` requires ≥ 1)
pub n_iteration_max: usize,
/// Requests a numerical Jacobian (default false)
/// NOTE(review): presumably replaces the user-supplied analytical Jacobian — confirm in the solver
pub use_numerical_jacobian: bool,
/// Linear solver selection from `russell_sparse` (default `Genie::Umfpack`)
pub genie: Genie,
/// Optional parameters forwarded to the linear solver (default None)
pub lin_sol_params: Option<LinSolParams>,
/// Optional step count (default None)
/// NOTE(review): by its name, writes the matrix after this many steps and stops;
/// the behavior is implemented outside this file — confirm
pub write_matrix_after_nstep_and_stop: Option<usize>,
}
/// Holds parameters that control the variable step size
///
/// NOTE(review): the `m_*` fields appear to be multipliers applied to the step size;
/// the code that applies them is not in this file.
#[derive(Clone, Copy, Debug)]
pub struct ParamsStep {
/// Min multiplier (default 0.125 for Radau5, 0.333 for DoPri8, 0.2 otherwise)
pub m_min: f64,
/// Max multiplier (default 5 for Radau5, 6 for DoPri8, 10 otherwise)
pub m_max: f64,
/// Safety factor (default 0.9)
pub m_safety: f64,
/// Multiplier used on the first rejection, as the name suggests (default 0.1)
pub m_first_reject: f64,
/// Initial step size (default 1e-4)
pub h_ini: f64,
/// Max number of steps (default 100000)
pub n_step_max: usize,
/// Lower bound for the previous relative error (default 1e-2 for Radau5, 1e-4 otherwise)
pub rel_error_prev_min: f64,
}
/// Holds parameters for stiffness detection
///
/// NOTE(review): the detection itself is implemented outside this file; the
/// meanings below are taken from the field names — confirm with the solver.
#[derive(Clone, Copy, Debug)]
pub struct ParamsStiffness {
/// Enables the check (default false)
pub enabled: bool,
/// Stops with an error when stiffness is detected (default true)
pub stop_with_error: bool,
/// Saves the results of the check (default false)
pub save_results: bool,
/// Number of steps after which stiffness is ratified (default 15)
pub ratified_after_nstep: usize,
/// Number of steps after which a stiffness indication is ignored (default 6)
pub ignored_after_nstep: usize,
/// Number of initial accepted steps that are skipped (default 10)
pub skip_first_n_accepted_step: usize,
/// Threshold for h·ρ: 3.25 for DoPri5, 6.1 for DoPri8, -∞ otherwise
/// (presumably ρ is an estimate of the spectral radius); read via `get_h_times_rho_max`
pub(crate) h_times_rho_max: f64,
}
/// Holds parameters specific to the Backward Euler method
#[derive(Clone, Copy, Debug)]
pub struct ParamsBwEuler {
/// Uses the modified Newton method (default false)
pub use_modified_newton: bool,
}
/// Holds parameters specific to the Radau5 method
#[derive(Clone, Copy, Debug)]
pub struct ParamsRadau5 {
/// Uses a zero trial value (default false)
/// NOTE(review): presumably the initial guess of the Newton iterations — confirm
pub zero_trial: bool,
/// Threshold parameter (default 1e-3; `validate` requires ≥ 1e-7)
/// NOTE(review): presumably decides when the Jacobian is recomputed, as `THET` in Hairer's RADAU5
pub theta_max: f64,
/// Lower step-ratio bound (default 1.0; must satisfy 0.5 ≤ c1h ≤ 1.5 and c1h < c2h)
pub c1h: f64,
/// Upper step-ratio bound (default 1.2; must satisfy 1 ≤ c2h ≤ 2)
pub c2h: f64,
/// Enables concurrent computations (default true)
/// NOTE(review): presumably the real and complex linear systems — confirm
pub concurrent: bool,
/// Uses the predictive step-size control (default true)
pub use_pred_control: bool,
}
/// Holds parameters for the explicit Runge-Kutta methods
#[derive(Clone, Copy, Debug)]
pub struct ParamsERK {
/// Lund stabilization coefficient β (default 0.04 for DoPri5, 0 otherwise; must be in [0, 0.1])
pub lund_beta: f64,
/// Lund stabilization factor m (default 0.75 for DoPri5, 0.2 for DoPri8, 0 otherwise; must be in [0, 1])
pub lund_m: f64,
}
/// Holds all parameters of the ODE solver
///
/// Create it with `Params::new(method)`. Tolerances are private and must be changed
/// through `set_tolerances`; all other groups may be edited directly.
#[derive(Clone, Copy, Debug)]
pub struct Params {
/// Method the defaults were selected for
pub(crate) method: Method,
/// Tolerances (see `set_tolerances`)
pub(crate) tol: ParamsTol,
/// Newton iteration parameters
pub newton: ParamsNewton,
/// Step-size control parameters
pub step: ParamsStep,
/// Stiffness detection parameters
pub stiffness: ParamsStiffness,
/// Backward Euler parameters
pub bweuler: ParamsBwEuler,
/// Radau5 parameters
pub radau5: ParamsRadau5,
/// Explicit Runge-Kutta parameters
pub erk: ParamsERK,
/// Enables debug output (default false)
pub debug: bool,
}
impl ParamsTol {
    /// Allocates a new instance with the default tolerances (abs = rel = 1e-4)
    ///
    /// For Radau5, the tolerances are transformed by `calc_tolerances`.
    ///
    /// # Panics
    ///
    /// Does not panic in practice: the hard-coded defaults (1e-4) are far above the
    /// `10 · EPSILON` lower bound enforced by `calc_tolerances`.
    pub(crate) fn new(method: Method) -> Self {
        let radau5 = method == Method::Radau5;
        let (abs, rel, newton) =
            calc_tolerances(radau5, 1e-4, 1e-4).expect("default tolerances (1e-4) must be valid");
        ParamsTol { abs, rel, newton }
    }
}
impl ParamsNewton {
    /// Allocates a new instance with the default settings
    ///
    /// Defaults: at most 7 iterations, analytical Jacobian, UMFPACK linear solver,
    /// no extra linear-solver parameters, and no matrix output.
    pub(crate) fn new() -> Self {
        let n_iteration_max = 7;
        let genie = Genie::Umfpack;
        ParamsNewton {
            n_iteration_max,
            use_numerical_jacobian: false,
            genie,
            lin_sol_params: None,
            write_matrix_after_nstep_and_stop: None,
        }
    }

    /// Checks that at least one iteration is allowed
    ///
    /// # Errors
    ///
    /// Fails when `n_iteration_max` is zero.
    pub(crate) fn validate(&self) -> Result<(), StrError> {
        match self.n_iteration_max {
            0 => Err("parameter must satisfy: n_iteration_max ≥ 1"),
            _ => Ok(()),
        }
    }
}
impl ParamsStep {
    /// Allocates a new instance with method-specific defaults
    ///
    /// Radau5 and DoPri8 have their own multiplier bounds; every other method
    /// (including DoPri5) uses `m_min = 0.2`, `m_max = 10`.
    pub(crate) fn new(method: Method) -> Self {
        // (m_min, m_max, m_safety, rel_error_prev_min)
        let (m_min, m_max, m_safety, rel_error_prev_min) = match method {
            Method::Radau5 => (0.125, 5.0, 0.9, 1e-2),
            Method::DoPri8 => (0.333, 6.0, 0.9, 1e-4),
            // DoPri5 shares the generic defaults
            _ => (0.2, 10.0, 0.9, 1e-4),
        };
        ParamsStep {
            m_min,
            m_max,
            m_safety,
            m_first_reject: 0.1,
            h_ini: 1e-4,
            n_step_max: 100000,
            rel_error_prev_min,
        }
    }

    /// Checks that all parameters are within their admissible ranges
    ///
    /// Range checks use `contains`, which is `false` for NaN, so NaN values are
    /// rejected (plain `<`/`>` comparisons would let them through).
    ///
    /// # Errors
    ///
    /// Returns a message describing the first violated constraint.
    pub(crate) fn validate(&self) -> Result<(), StrError> {
        // half-open range: the message promises m_min < 0.5
        if !(0.001..0.5).contains(&self.m_min) || self.m_min >= self.m_max {
            return Err("parameter must satisfy: 0.001 ≤ m_min < 0.5 and m_min < m_max");
        }
        if !(0.01..=20.0).contains(&self.m_max) {
            return Err("parameter must satisfy: 0.01 ≤ m_max ≤ 20 and m_max > m_min");
        }
        if !(0.1..=1.0).contains(&self.m_safety) {
            return Err("parameter must satisfy: 0.1 ≤ m_safety ≤ 1");
        }
        if !(0.0..).contains(&self.m_first_reject) {
            return Err("parameter must satisfy: m_first_rejection ≥ 0");
        }
        if !(1e-8..).contains(&self.h_ini) {
            return Err("parameter must satisfy: h_ini ≥ 1e-8");
        }
        if self.n_step_max == 0 {
            return Err("parameter must satisfy: n_step_max ≥ 1");
        }
        if !(1e-8..).contains(&self.rel_error_prev_min) {
            return Err("parameter must satisfy: rel_error_prev_min ≥ 1e-8");
        }
        Ok(())
    }
}
impl ParamsStiffness {
    /// Allocates a new instance with the default stiffness-check settings
    ///
    /// The `h·ρ` threshold is 3.25 for DoPri5 and 6.1 for DoPri8; all other
    /// methods get negative infinity.
    pub(crate) fn new(method: Method) -> Self {
        let h_times_rho_max = if matches!(method, Method::DoPri5) {
            3.25
        } else if matches!(method, Method::DoPri8) {
            6.1
        } else {
            f64::NEG_INFINITY
        };
        let ratified_after_nstep = 15;
        let ignored_after_nstep = 6;
        let skip_first_n_accepted_step = 10;
        ParamsStiffness {
            enabled: false,
            stop_with_error: true,
            save_results: false,
            ratified_after_nstep,
            ignored_after_nstep,
            skip_first_n_accepted_step,
            h_times_rho_max,
        }
    }

    /// Returns the method-dependent `h·ρ` threshold
    pub fn get_h_times_rho_max(&self) -> f64 {
        let threshold = self.h_times_rho_max;
        threshold
    }
}
impl ParamsBwEuler {
    /// Allocates a new instance; the full Newton method is used by default
    pub(crate) fn new() -> Self {
        let use_modified_newton = false;
        ParamsBwEuler { use_modified_newton }
    }
}
impl ParamsRadau5 {
    /// Allocates a new instance with the default Radau5 settings
    ///
    /// Defaults: no zero trial, `theta_max = 1e-3`, `c1h = 1.0`, `c2h = 1.2`,
    /// concurrent computations and predictive control enabled.
    pub(crate) fn new() -> Self {
        ParamsRadau5 {
            zero_trial: false,
            theta_max: 1e-3,
            c1h: 1.0,
            c2h: 1.2,
            concurrent: true,
            use_pred_control: true,
        }
    }

    /// Checks that all parameters are within their admissible ranges
    ///
    /// Range checks use `contains`, which is `false` for NaN, so NaN values are rejected.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first violated constraint.
    pub(crate) fn validate(&self) -> Result<(), StrError> {
        if !(1e-7..).contains(&self.theta_max) {
            return Err("parameter must satisfy: theta_max ≥ 1e-7");
        }
        if !(0.5..=1.5).contains(&self.c1h) || self.c1h >= self.c2h {
            return Err("parameter must satisfy: 0.5 ≤ c1h ≤ 1.5 and c1h < c2h");
        }
        // a NaN c2h passes the `c1h >= c2h` test above and is caught here
        if !(1.0..=2.0).contains(&self.c2h) {
            return Err("parameter must satisfy: 1 ≤ c2h ≤ 2 and c2h > c1h");
        }
        Ok(())
    }
}
impl ParamsERK {
    /// Allocates a new instance with method-specific Lund stabilization defaults
    ///
    /// DoPri5: β = 0.04, m = 0.75; DoPri8: β = 0, m = 0.2; any other method: β = m = 0.
    pub(crate) fn new(method: Method) -> Self {
        let (lund_beta, lund_m) = match method {
            Method::DoPri5 => (0.04, 0.75),
            Method::DoPri8 => (0.0, 0.2),
            _ => (0.0, 0.0),
        };
        ParamsERK { lund_beta, lund_m }
    }

    /// Checks that all parameters are within their admissible ranges
    ///
    /// Range checks use `contains`, which is `false` for NaN, so NaN values are rejected.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first violated constraint.
    pub(crate) fn validate(&self) -> Result<(), StrError> {
        if !(0.0..=0.1).contains(&self.lund_beta) {
            return Err("parameter must satisfy: 0 ≤ lund_beta ≤ 0.1");
        }
        if !(0.0..=1.0).contains(&self.lund_m) {
            return Err("parameter must satisfy: 0 ≤ lund_m ≤ 1");
        }
        Ok(())
    }
}
impl Params {
    /// Allocates a new instance with the default parameters for the given method
    pub fn new(method: Method) -> Self {
        Params {
            method,
            tol: ParamsTol::new(method),
            newton: ParamsNewton::new(),
            step: ParamsStep::new(method),
            stiffness: ParamsStiffness::new(method),
            bweuler: ParamsBwEuler::new(),
            radau5: ParamsRadau5::new(),
            erk: ParamsERK::new(method),
            debug: false,
        }
    }

    /// Sets the tolerances
    ///
    /// # Input
    ///
    /// * `absolute` -- absolute tolerance (must be > 10 · EPSILON)
    /// * `relative` -- relative tolerance (must be > 10 · EPSILON)
    /// * `newton` -- optional tolerance for the Newton iterations (must be > 10 · EPSILON);
    ///   when `None`, it is derived from the (possibly transformed) relative tolerance
    ///
    /// For Radau5, the absolute and relative tolerances are transformed by
    /// `calc_tolerances` before being stored.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` unchanged, if `calc_tolerances` rejects the
    /// absolute/relative tolerances or if the explicit Newton tolerance is not > 10 · EPSILON.
    pub fn set_tolerances(&mut self, absolute: f64, relative: f64, newton: Option<f64>) -> Result<(), StrError> {
        let radau5 = self.method == Method::Radau5;
        let (abs, rel, default_newton) = calc_tolerances(radau5, absolute, relative)?;
        if let Some(tol_newton) = newton {
            // `is_nan` is needed because `NaN <= x` is false
            if tol_newton.is_nan() || tol_newton <= 10.0 * f64::EPSILON {
                return Err("the Newton tolerance must be > 10 · EPSILON");
            }
        }
        // every input is validated above, so `self` is only modified on success
        self.tol.abs = abs;
        self.tol.rel = rel;
        self.tol.newton = newton.unwrap_or(default_newton);
        Ok(())
    }

    /// Validates the parameter groups that have constraints (newton, step, radau5, erk)
    ///
    /// # Errors
    ///
    /// Returns the first error found, checking the groups in that order.
    pub(crate) fn validate(&self) -> Result<(), StrError> {
        self.newton.validate()?;
        self.step.validate()?;
        self.radau5.validate()?;
        self.erk.validate()?;
        Ok(())
    }
}
/// Computes the absolute, relative, and Newton tolerances
///
/// For Radau5 the user tolerances are transformed as in Hairer's RADAU5 code:
/// `rel ← 0.1 · rel^(2/3)` and `abs ← rel_new · (abs / rel)`.
/// The Newton tolerance is `max(10·EPSILON/rel, min(0.03, √rel))`, computed with
/// the (possibly transformed) relative tolerance.
///
/// # Output
///
/// Returns `(abs_tol, rel_tol, tol_newton)`.
///
/// # Errors
///
/// Fails if either input tolerance is NaN or not greater than `10 · EPSILON`.
fn calc_tolerances(radau5: bool, abs_tol: f64, rel_tol: f64) -> Result<(f64, f64, f64), StrError> {
    const MIN_TOL: f64 = 10.0 * f64::EPSILON;
    // `is_nan` is required because every comparison with NaN is false,
    // so `NaN <= MIN_TOL` alone would accept a NaN tolerance
    if abs_tol.is_nan() || abs_tol <= MIN_TOL {
        return Err("the absolute tolerance must be > 10 · EPSILON");
    }
    if rel_tol.is_nan() || rel_tol <= MIN_TOL {
        return Err("the relative tolerance must be > 10 · EPSILON");
    }
    let (abs_tol, rel_tol) = if radau5 {
        const BETA: f64 = 2.0 / 3.0;
        let quot = abs_tol / rel_tol;
        let rel = 0.1 * f64::powf(rel_tol, BETA);
        (rel * quot, rel)
    } else {
        (abs_tol, rel_tol)
    };
    let tol_newton = f64::max(MIN_TOL / rel_tol, f64::min(0.03, f64::sqrt(rel_tol)));
    Ok((abs_tol, rel_tol, tol_newton))
}
#[cfg(test)]
mod tests {
    use super::*;
    use russell_lab::approx_eq;

    /// Exercises the derived `Clone` and `Debug` implementations of every parameter struct
    #[test]
    // the explicit `.clone()` calls on `Copy` types are intentional: they exercise the derived `Clone`
    #[allow(clippy::clone_on_copy)]
    fn derive_methods_work() {
        let tol = ParamsTol::new(Method::Radau5);
        let newton = ParamsNewton::new();
        let step = ParamsStep::new(Method::Radau5);
        let stiffness = ParamsStiffness::new(Method::Radau5);
        let bweuler = ParamsBwEuler::new();
        let radau5 = ParamsRadau5::new();
        let erk = ParamsERK::new(Method::DoPri5);
        let params = Params::new(Method::Radau5);
        let clone_tol = tol.clone();
        let clone_newton = newton.clone();
        let clone_step = step.clone();
        let clone_stiffness = stiffness.clone();
        let clone_bweuler = bweuler.clone();
        let clone_radau5 = radau5.clone();
        let clone_erk = erk.clone();
        let clone_params = params.clone();
        assert_eq!(format!("{:?}", tol), format!("{:?}", clone_tol));
        assert_eq!(format!("{:?}", newton), format!("{:?}", clone_newton));
        assert_eq!(format!("{:?}", step), format!("{:?}", clone_step));
        assert_eq!(format!("{:?}", stiffness), format!("{:?}", clone_stiffness));
        assert_eq!(format!("{:?}", bweuler), format!("{:?}", clone_bweuler));
        assert_eq!(format!("{:?}", radau5), format!("{:?}", clone_radau5));
        assert_eq!(format!("{:?}", erk), format!("{:?}", clone_erk));
        assert_eq!(format!("{:?}", params), format!("{:?}", clone_params));
    }

    /// Checks that too-small absolute/relative tolerances are rejected for both method families
    #[test]
    fn set_tolerances_captures_errors() {
        for method in [Method::Radau5, Method::DoPri5] {
            let mut params = Params::new(method);
            assert_eq!(
                params.set_tolerances(0.0, 1e-4, None).err(),
                Some("the absolute tolerance must be > 10 · EPSILON")
            );
            assert_eq!(
                params.set_tolerances(1e-4, 0.0, None).err(),
                Some("the relative tolerance must be > 10 · EPSILON")
            );
        }
    }

    /// Checks the Radau5 tolerance transformation, the explicit Newton tolerance, and the
    /// untransformed tolerances of DoPri5
    #[test]
    fn set_tolerances_works() {
        let mut params = Params::new(Method::Radau5);
        params.set_tolerances(0.1, 0.1, None).unwrap();
        approx_eq(params.tol.abs, 2.154434690031884E-02, 1e-17);
        approx_eq(params.tol.rel, 2.154434690031884E-02, 1e-17);
        assert_eq!(params.tol.newton, 0.03);
        params.set_tolerances(0.1, 0.1, Some(0.05)).unwrap();
        approx_eq(params.tol.abs, 2.154434690031884E-02, 1e-17);
        approx_eq(params.tol.rel, 2.154434690031884E-02, 1e-17);
        assert_eq!(params.tol.newton, 0.05);
        let mut params = Params::new(Method::DoPri5);
        params.set_tolerances(0.2, 0.3, None).unwrap();
        assert_eq!(params.tol.abs, 0.2);
        assert_eq!(params.tol.rel, 0.3);
    }

    /// Checks the method-dependent h·ρ threshold
    #[test]
    fn params_stiffness_returns_h_times_rho() {
        let params = ParamsStiffness::new(Method::DoPri5);
        assert_eq!(params.get_h_times_rho_max(), 3.25);
        let params = ParamsStiffness::new(Method::DoPri8);
        assert_eq!(params.get_h_times_rho_max(), 6.1);
        let params = ParamsStiffness::new(Method::Radau5);
        assert_eq!(params.get_h_times_rho_max(), f64::NEG_INFINITY);
    }

    /// Checks the lower bound on the number of Newton iterations
    #[test]
    fn params_newton_validate_works() {
        let mut params = ParamsNewton::new();
        params.n_iteration_max = 0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: n_iteration_max ≥ 1")
        );
        params.n_iteration_max = 10;
        assert!(params.validate().is_ok());
    }

    /// Walks through every constraint of `ParamsStep::validate`
    #[test]
    fn params_step_validate_works() {
        let mut params = ParamsStep::new(Method::Radau5);
        params.m_min = 0.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.001 ≤ m_min < 0.5 and m_min < m_max")
        );
        params.m_min = 0.6;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.001 ≤ m_min < 0.5 and m_min < m_max")
        );
        params.m_min = 0.02;
        params.m_max = 0.01;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.001 ≤ m_min < 0.5 and m_min < m_max")
        );
        params.m_min = 0.001;
        params.m_max = 0.005;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.01 ≤ m_max ≤ 20 and m_max > m_min")
        );
        params.m_max = 30.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.01 ≤ m_max ≤ 20 and m_max > m_min")
        );
        params.m_max = 10.0;
        params.m_safety = 0.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.1 ≤ m_safety ≤ 1")
        );
        params.m_safety = 1.2;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.1 ≤ m_safety ≤ 1")
        );
        params.m_safety = 0.9;
        params.m_first_reject = -1.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: m_first_rejection ≥ 0")
        );
        params.m_first_reject = 0.0;
        params.h_ini = 0.0;
        assert_eq!(params.validate().err(), Some("parameter must satisfy: h_ini ≥ 1e-8"));
        params.h_ini = 1e-4;
        params.n_step_max = 0;
        assert_eq!(params.validate().err(), Some("parameter must satisfy: n_step_max ≥ 1"));
        params.n_step_max = 10;
        params.rel_error_prev_min = 0.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: rel_error_prev_min ≥ 1e-8")
        );
        params.rel_error_prev_min = 1e-6;
        assert!(params.validate().is_ok());
    }

    /// Walks through every constraint of `ParamsRadau5::validate`
    #[test]
    fn params_radau5_validate_works() {
        let mut params = ParamsRadau5::new();
        params.theta_max = 0.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: theta_max ≥ 1e-7")
        );
        params.theta_max = 1e-7;
        params.c1h = 0.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.5 ≤ c1h ≤ 1.5 and c1h < c2h")
        );
        params.c1h = 2.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.5 ≤ c1h ≤ 1.5 and c1h < c2h")
        );
        params.c1h = 1.3;
        params.c2h = 1.2;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.5 ≤ c1h ≤ 1.5 and c1h < c2h")
        );
        params.c1h = 1.0;
        params.c2h = 3.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 1 ≤ c2h ≤ 2 and c2h > c1h")
        );
        params.c2h = 1.2;
        assert!(params.validate().is_ok());
    }

    /// Walks through every constraint of `ParamsERK::validate`
    #[test]
    fn params_erk_validate_works() {
        let mut params = ParamsERK::new(Method::DoPri5);
        params.lund_beta = -1.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0 ≤ lund_beta ≤ 0.1")
        );
        params.lund_beta = 0.2;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0 ≤ lund_beta ≤ 0.1")
        );
        params.lund_beta = 0.1;
        params.lund_m = -1.0;
        assert_eq!(params.validate().err(), Some("parameter must satisfy: 0 ≤ lund_m ≤ 1"));
        params.lund_m = 1.1;
        assert_eq!(params.validate().err(), Some("parameter must satisfy: 0 ≤ lund_m ≤ 1"));
        params.lund_m = 0.75;
        assert!(params.validate().is_ok());
    }

    /// Checks that `Params::validate` forwards errors from each validated group
    #[test]
    fn params_validate_works() {
        let mut params = Params::new(Method::Radau5);
        params.newton.n_iteration_max = 0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: n_iteration_max ≥ 1")
        );
        params.newton.n_iteration_max = 10;
        params.step.m_min = 0.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0.001 ≤ m_min < 0.5 and m_min < m_max")
        );
        params.step.m_min = 0.001;
        params.radau5.theta_max = 0.0;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: theta_max ≥ 1e-7")
        );
        params.radau5.theta_max = 1e-7;
        params.erk.lund_beta = -0.1;
        assert_eq!(
            params.validate().err(),
            Some("parameter must satisfy: 0 ≤ lund_beta ≤ 0.1")
        );
        params.erk.lund_beta = 0.1;
        assert!(params.validate().is_ok());
    }
}