use anyhow::anyhow;
use anyhow::Result;
use num_traits::abs;
use num_traits::One;
use num_traits::Pow;
use num_traits::Zero;
use std::ops::MulAssign;
use std::rc::Rc;
use crate::matrix::MatrixRef;
use crate::nonlinear_solver::convergence::Convergence;
use crate::nonlinear_solver::newton::newton_iteration;
use crate::vector::VectorRef;
use crate::LinearSolver;
use crate::NewtonNonlinearSolver;
use crate::OdeSolverStopReason;
use crate::RootFinder;
use crate::SensEquations;
use crate::Tableau;
use crate::{
nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, scale, solver::SolverProblem,
DenseMatrix, NonLinearOp, OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Op,
Scalar, Vector, VectorViewMut,
};
use super::bdf::BdfStatistics;
/// Singly diagonally implicit Runge-Kutta ODE solver (SDIRK, or ESDIRK when the first stage is
/// explicit), defined by a Butcher [`Tableau`].
///
/// Each implicit stage is solved with a Newton nonlinear solver. The stage increments of the
/// current step are kept in the columns of `diff`, where they are reused for the error estimate
/// and for dense-output interpolation.
pub struct Sdirk<M, Eqn, LS>
where
    M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
    LS: LinearSolver<SdirkCallable<Eqn>>,
    Eqn: OdeEquations,
    for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
    for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
{
    /// Butcher tableau defining the method (validated in `Sdirk::new`).
    tableau: Tableau<M>,
    /// Problem installed by `set_problem`; `None` until then.
    problem: Option<OdeSolverProblem<Eqn>>,
    /// Newton solver used for every implicit stage.
    nonlinear_solver: NewtonNonlinearSolver<SdirkCallable<Eqn>, LS>,
    /// Current solver state; `None` before `set_problem` or after `take_state`.
    state: Option<OdeSolverState<Eqn::V>>,
    /// Stage increments of the current step, one column per stage (column 0 is `h * dy` for
    /// methods with an explicit first stage).
    diff: M,
    /// Per-parameter counterpart of `diff` for the sensitivity equations.
    sdiff: Vec<M>,
    /// Common diagonal coefficient `a(i, i)` of the tableau.
    gamma: Eqn::T,
    /// `true` if `a(0, 0) == gamma` (SDIRK), `false` if `a(0, 0) == 0` (ESDIRK).
    is_sdirk: bool,
    /// Stage operator for the sensitivity equations; only set when the problem has sensitivities.
    s_op: Option<SdirkCallable<SensEquations<Eqn>>>,
    /// Start time of the last completed step (lower bound for interpolation).
    old_t: Eqn::T,
    /// Latest stage value during a step; after a step, the previous `state.y` (swapped out).
    old_y: Eqn::V,
    /// Sensitivity counterpart of `old_y`, one entry per parameter.
    old_y_sens: Vec<Eqn::V>,
    /// Current stage increment during a step; after a step, the previous `state.dy`.
    old_f: Eqn::V,
    /// Sensitivity counterpart of `old_f`, one entry per parameter.
    old_f_sens: Vec<Eqn::V>,
    /// Row `i` holds the strictly lower-triangular coefficients `a(i, 0..i)` of the tableau.
    a_rows: Vec<Eqn::V>,
    /// Counters reported by `get_statistics`.
    statistics: BdfStatistics<Eqn::T>,
    /// Root finder, initialised only when the problem defines a root function.
    root_finder: Option<RootFinder<Eqn::V>>,
    /// Optional stop time set by `set_stop_time`.
    tstop: Option<Eqn::T>,
    /// Set by `state_mut`, cleared after each successful step; while set, interpolation is only
    /// allowed at the current time.
    is_state_mutated: bool,
}
impl<M, Eqn, LS> Sdirk<M, Eqn, LS>
where
LS: LinearSolver<SdirkCallable<Eqn>>,
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
Eqn: OdeEquations,
for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
{
const NEWTON_MAXITER: usize = 10;
const MIN_FACTOR: f64 = 0.2;
const MAX_FACTOR: f64 = 10.0;
const MIN_TIMESTEP: f64 = 1e-13;
pub fn new(tableau: Tableau<M>, linear_solver: LS) -> Self {
let mut nonlinear_solver = NewtonNonlinearSolver::new(linear_solver);
nonlinear_solver.set_max_iter(Self::NEWTON_MAXITER);
let s = tableau.s();
for i in 0..s {
for j in (i + 1)..s {
assert_eq!(
tableau.a()[(i, j)],
Eqn::T::zero(),
"Invalid tableau, expected a(i, j) = 0 for i > j"
);
}
}
let gamma = tableau.a()[(1, 1)];
for i in 1..tableau.s() {
assert_eq!(
tableau.a()[(i, i)],
gamma,
"Invalid tableau, expected a(i, i) = gamma = {} for i = 1..s-1",
gamma
);
}
let zero = Eqn::T::zero();
if tableau.a()[(0, 0)] != zero && tableau.a()[(0, 0)] != gamma {
panic!("Invalid tableau, expected a(0, 0) = 0 or a(0, 0) = gamma");
}
let is_sdirk = tableau.a()[(0, 0)] == gamma;
let mut a_rows = Vec::with_capacity(s);
for i in 0..s {
let mut row = Vec::with_capacity(i);
for j in 0..i {
row.push(tableau.a()[(i, j)]);
}
a_rows.push(Eqn::V::from_vec(row));
}
for i in 0..s {
assert_eq!(
tableau.a()[(s - 1, i)],
tableau.b()[i],
"Invalid tableau, expected a(s-1, i) = b(i)"
);
}
assert_eq!(
tableau.c()[s - 1],
Eqn::T::one(),
"Invalid tableau, expected c(s-1) = 1"
);
if !is_sdirk {
assert_eq!(
tableau.c()[0],
Eqn::T::zero(),
"Invalid tableau, expected c(0) = 0 for esdirk methods"
);
}
let n = 1;
let s = tableau.s();
let diff = M::zeros(n, s);
let old_t = Eqn::T::zero();
let old_y = <Eqn::V as Vector>::zeros(n);
let old_f = <Eqn::V as Vector>::zeros(n);
let statistics = BdfStatistics::default();
let old_f_sens = Vec::new();
let sdiff = Vec::new();
let old_y_sens = Vec::new();
Self {
old_y_sens,
old_f_sens,
sdiff,
tableau,
nonlinear_solver,
state: None,
diff,
problem: None,
s_op: None,
gamma,
is_sdirk,
old_t,
old_y,
a_rows,
old_f,
statistics,
root_finder: None,
tstop: None,
is_state_mutated: false,
}
}
pub fn get_statistics(&self) -> &BdfStatistics<Eqn::T> {
&self.statistics
}
fn handle_tstop(&mut self, tstop: Eqn::T) -> Result<Option<OdeSolverStopReason<Eqn::T>>> {
let state = self.state.as_mut().unwrap();
let troundoff = Eqn::T::from(100.0) * Eqn::T::EPSILON * (abs(state.t) + abs(state.h));
if abs(state.t - tstop) <= troundoff {
self.tstop = None;
return Ok(Some(OdeSolverStopReason::TstopReached));
} else if tstop < state.t - troundoff {
return Err(anyhow::anyhow!(
"tstop = {} is less than current time t = {}",
tstop,
state.t
));
}
if state.t + state.h > tstop + troundoff {
let factor = (tstop - state.t) / state.h;
state.h *= factor;
self.nonlinear_solver.problem().f.set_h(state.h);
}
Ok(None)
}
fn predict_stage(i: usize, diff: &M, dy: &mut Eqn::V, tableau: &Tableau<M>) {
if i == 0 {
dy.fill(Eqn::T::zero());
} else if i == 1 {
dy.copy_from_view(&diff.column(i - 1));
} else {
let c =
(tableau.c()[i] - tableau.c()[i - 2]) / (tableau.c()[i - 1] - tableau.c()[i - 2]);
dy.copy_from_view(&diff.column(i - 1));
dy.axpy_v(-c, &diff.column(i - 2), Eqn::T::one() + c);
}
}
fn solve_for_sensitivities(&mut self, i: usize, t: Eqn::T) -> Result<()> {
{
self.problem()
.as_ref()
.unwrap()
.eqn_sens
.as_ref()
.unwrap()
.rhs()
.update_state(&self.old_y, &self.old_f, t);
}
let ls =
|x: &mut Eqn::V| -> Result<()> { self.nonlinear_solver.solve_linearised_in_place(x) };
let op = self.s_op.as_ref().unwrap();
op.set_h(self.state.as_ref().unwrap().h);
let fun = |x: &Eqn::V, y: &mut Eqn::V| op.call_inplace(x, t, y);
let rtol = self.problem().as_ref().unwrap().rtol;
let atol = self.problem().as_ref().unwrap().atol.clone();
let maxiter = self.nonlinear_solver.max_iter();
let mut convergence = Convergence::new(rtol, atol, maxiter);
let nparams = self.problem().as_ref().unwrap().eqn.rhs().nparams();
for j in 0..nparams {
let s0 = &self.state.as_ref().unwrap().s[j];
op.set_phi(&self.sdiff[j].columns(0, i), s0, &self.a_rows[i]);
op.eqn().as_ref().rhs().set_param_index(j);
let ds = &mut self.old_f_sens[j];
Self::predict_stage(i, &self.sdiff[j], ds, &self.tableau);
{
let niter = newton_iteration(ds, fun, ls, &mut convergence)?;
self.old_y_sens[j].copy_from(&op.get_last_f_eval());
self.statistics.number_of_nonlinear_solver_iterations += niter;
}
}
Ok(())
}
fn interpolate_from_diff(y0: &Eqn::V, beta_f: &Eqn::V, diff: &M) -> Eqn::V {
let mut ret = y0.clone();
diff.gemv(Eqn::T::one(), beta_f, Eqn::T::one(), &mut ret);
ret
}
fn interpolate_beta_function(theta: Eqn::T, beta: &M) -> Eqn::V {
let poly_order = beta.ncols();
let s_star = beta.nrows();
let mut thetav = Vec::with_capacity(poly_order);
thetav.push(theta);
for i in 1..poly_order {
thetav.push(theta * thetav[i - 1]);
}
let thetav = Eqn::V::from_vec(thetav);
let mut beta_f = <Eqn::V as Vector>::zeros(s_star);
beta.gemv(Eqn::T::one(), &thetav, Eqn::T::zero(), &mut beta_f);
beta_f
}
fn interpolate_hermite(theta: Eqn::T, u0: &Eqn::V, u1: &Eqn::V, diff: &M) -> Eqn::V {
let hf0 = diff.column(0);
let hf1 = diff.column(diff.ncols() - 1);
u0 * scale(Eqn::T::from(1.0) - theta)
+ u1 * scale(theta)
+ ((u1 - u0) * scale(Eqn::T::from(1.0) - Eqn::T::from(2.0) * theta)
+ hf0 * scale(theta - Eqn::T::from(1.0))
+ hf1 * scale(theta))
* scale(theta * (theta - Eqn::T::from(1.0)))
}
}
impl<M, Eqn, LS> OdeSolverMethod<Eqn> for Sdirk<M, Eqn, LS>
where
    LS: LinearSolver<SdirkCallable<Eqn>>,
    M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
    Eqn: OdeEquations,
    for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
    for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
{
    /// Returns the problem installed by `set_problem`, if any.
    fn problem(&self) -> Option<&OdeSolverProblem<Eqn>> {
        self.problem.as_ref()
    }

    /// Returns the order of the underlying tableau.
    fn order(&self) -> usize {
        self.tableau.order()
    }

    /// Moves the current state out of the solver, leaving `None` behind.
    fn take_state(&mut self) -> Option<OdeSolverState<Eqn::V>> {
        Option::take(&mut self.state)
    }

    /// Installs a new problem and initial state.
    ///
    /// Creates the stage operator and hands it to the nonlinear solver, resets the statistics,
    /// (re)allocates every per-problem buffer (including the sensitivity buffers) and initialises
    /// the root finder if the problem has a root function.
    fn set_problem(&mut self, state: OdeSolverState<<Eqn>::V>, problem: &OdeSolverProblem<Eqn>) {
        let callable = Rc::new(SdirkCallable::new(problem, self.gamma));
        callable.set_h(state.h);
        let nonlinear_problem = SolverProblem::new_from_ode_problem(callable, problem);
        self.nonlinear_solver.set_problem(&nonlinear_problem);
        self.statistics = BdfStatistics::default();
        self.statistics.initial_step_size = state.h;
        let nstates = state.y.len();
        let nparams = problem.eqn.rhs().nparams();
        if problem.eqn_sens.is_some() {
            self.sdiff = vec![M::zeros(nstates, self.tableau.s()); nparams];
            self.old_f_sens = vec![<Eqn::V as Vector>::zeros(nstates); nparams];
            self.old_y_sens = vec![<Eqn::V as Vector>::zeros(nstates); nparams];
            self.s_op = Some(SdirkCallable::from_eqn(
                problem.eqn_sens.as_ref().unwrap().clone(),
                self.gamma,
            ));
        } else {
            // Drop sensitivity buffers left over from a previously installed problem. `step`
            // loops over `self.sdiff` and indexes `state.ds` / `state.s` with it, so stale entries
            // could index past the end of the new state's sensitivity vectors.
            self.sdiff.clear();
            self.old_f_sens.clear();
            self.old_y_sens.clear();
            self.s_op = None;
        }
        self.diff = M::zeros(nstates, self.tableau.s());
        self.old_f = state.dy.clone();
        self.old_t = state.t;
        self.old_y = state.y.clone();
        self.state = Some(state);
        self.problem = Some(problem.clone());
        if let Some(root_fn) = problem.eqn.root() {
            let state = self.state.as_ref().unwrap();
            self.root_finder = Some(RootFinder::new(root_fn.nout()));
            self.root_finder
                .as_ref()
                .unwrap()
                .init(root_fn.as_ref(), &state.y, state.t);
        } else {
            // Do not keep a root finder initialised for a previous problem.
            self.root_finder = None;
        }
    }

    /// Advances the solution by one accepted step.
    ///
    /// Every stage is solved with Newton's method. After the first nonlinear failure the
    /// Jacobian is marked stale and the step is retried; later failures shrink the step by 0.3.
    /// The error estimate then decides whether to accept the step and sets the next step size.
    /// On acceptance the state is advanced, roots are checked, and then the stop time.
    ///
    /// # Errors
    ///
    /// Returns an error if no state is set, the step size falls below `MIN_TIMESTEP`, a linear
    /// solve fails, or the stop time lies before the current time.
    fn step(&mut self) -> Result<OdeSolverStopReason<Eqn::T>> {
        if self.state.is_none() {
            return Err(anyhow!("State not set"));
        }
        let n = self.state.as_ref().unwrap().y.len();
        // ESDIRK methods have an explicit first stage, so stage 0 is not solved.
        let start = if self.is_sdirk { 0 } else { 1 };
        let mut updated_jacobian = false;
        let mut second_step_attempt = false;
        let mut error = <Eqn::V as Vector>::zeros(n);
        let mut t1: Eqn::T;
        'step: loop {
            let t0 = self.state.as_ref().unwrap().t;
            let h = self.state.as_ref().unwrap().h;
            if start == 1 {
                // The explicit first stage increment is h * dy at the start of the step.
                {
                    let mut hf = self.diff.column_mut(0);
                    hf.copy_from(&self.state.as_ref().unwrap().dy);
                    hf *= scale(h);
                }
                if self.problem().as_ref().unwrap().eqn_sens.is_some() {
                    for (diff, dy) in self
                        .sdiff
                        .iter_mut()
                        .zip(self.state.as_ref().unwrap().ds.iter())
                    {
                        let mut hf = diff.column_mut(0);
                        hf.copy_from(dy);
                        hf *= scale(h);
                    }
                }
            }
            for i in start..self.tableau.s() {
                let t = t0 + self.tableau.c()[i] * h;
                self.nonlinear_solver.problem().f.set_phi(
                    &self.diff.columns(0, i),
                    &self.state.as_ref().unwrap().y,
                    &self.a_rows[i],
                );
                Self::predict_stage(i, &self.diff, &mut self.old_f, &self.tableau);
                // On a retried step, reset the Jacobian before the first solved stage.
                if i == start && second_step_attempt {
                    self.nonlinear_solver.reset_jacobian(&self.old_f, t);
                }
                second_step_attempt = true;
                let mut solve_result = self.nonlinear_solver.solve_in_place(&mut self.old_f, t);
                self.statistics.number_of_nonlinear_solver_iterations +=
                    self.nonlinear_solver.niter();
                if solve_result.is_ok() {
                    self.old_y
                        .copy_from(&self.nonlinear_solver.problem().f.get_last_f_eval());
                    if self.problem().as_ref().unwrap().eqn_sens.is_some() {
                        solve_result = self.solve_for_sensitivities(i, t);
                    }
                }
                if solve_result.is_err() {
                    if !updated_jacobian {
                        // First failure: retry the same step with a fresh Jacobian.
                        self.nonlinear_solver.problem().f.set_jacobian_is_stale();
                        updated_jacobian = true;
                        self.statistics.number_of_nonlinear_solver_fails += 1;
                    } else {
                        // Later failures: retry with a smaller step.
                        let state = self.state.as_mut().unwrap();
                        self.statistics.number_of_nonlinear_solver_fails += 1;
                        state.h *= Eqn::T::from(0.3);
                        if state.h < Eqn::T::from(Self::MIN_TIMESTEP) {
                            return Err(anyhow::anyhow!("Step size too small at t = {}", state.t));
                        }
                        self.nonlinear_solver.problem().f.set_h(state.h);
                    }
                    continue 'step;
                };
                self.diff.column_mut(i).copy_from(&self.old_f);
                if self.problem().as_ref().unwrap().eqn_sens.is_some() {
                    for (diff, old_f_sens) in self.sdiff.iter_mut().zip(self.old_f_sens.iter()) {
                        diff.column_mut(i).copy_from(old_f_sens);
                    }
                }
            }

            // Error estimate: combine the stage increments with the tableau's `d` weights, then
            // pass the result through the linearised stage system.
            self.diff
                .gemv(Eqn::T::one(), self.tableau.d(), Eqn::T::zero(), &mut error);
            self.nonlinear_solver
                .solve_linearised_in_place(&mut error)?;
            let atol = self.problem().as_ref().unwrap().atol.as_ref();
            let rtol = self.problem().as_ref().unwrap().rtol;
            let mut error_norm = error.squared_norm(&self.old_y, atol, rtol);
            if self.problem().as_ref().unwrap().eqn_sens.is_some()
                && self.problem().as_ref().unwrap().sens_error_control
            {
                for i in 0..self.sdiff.len() {
                    self.sdiff[i].gemv(Eqn::T::one(), self.tableau.d(), Eqn::T::zero(), &mut error);
                    self.nonlinear_solver
                        .solve_linearised_in_place(&mut error)?;
                    let sens_error_norm = error.squared_norm(&self.old_y_sens[i], atol, rtol);
                    error_norm += sens_error_norm;
                }
                error_norm /= Eqn::T::from((self.sdiff.len() + 1) as f64);
            }

            // Step-size controller. The safety factor is reduced when the last Newton solve
            // needed many iterations. The exponent is halved because `error_norm` is a squared
            // norm.
            let maxiter = self.nonlinear_solver.max_iter() as f64;
            let niter = self.nonlinear_solver.niter() as f64;
            let safety = Eqn::T::from(0.9 * (2.0 * maxiter + 1.0) / (2.0 * maxiter + niter));
            let order = self.tableau.order() as f64;
            let mut factor = safety * error_norm.pow(Eqn::T::from(-0.5 / (order + 1.0)));
            if factor < Eqn::T::from(Self::MIN_FACTOR) {
                factor = Eqn::T::from(Self::MIN_FACTOR);
            }
            if factor > Eqn::T::from(Self::MAX_FACTOR) {
                factor = Eqn::T::from(Self::MAX_FACTOR);
            }
            let state = self.state.as_mut().unwrap();
            t1 = state.t + state.h;
            state.h *= factor;
            if state.h < Eqn::T::from(Self::MIN_TIMESTEP) {
                return Err(anyhow::anyhow!("Step size too small at t = {}", state.t));
            }
            self.nonlinear_solver.problem().f.set_h(state.h);
            if error_norm <= Eqn::T::from(1.0) {
                break 'step;
            }
            self.statistics.number_of_error_test_failures += 1;
        }

        // Step accepted: advance the state. The previous values are swapped into the `old_*`
        // buffers so they remain available for interpolation.
        self.nonlinear_solver.reset_jacobian(&self.old_f, t1);
        let state = self.state.as_mut().unwrap();
        let dt = t1 - state.t;
        self.old_t = state.t;
        state.t = t1;
        // The last stage increment is dt * dy(t1); rescale it to the derivative itself.
        self.old_f.mul_assign(scale(Eqn::T::one() / dt));
        std::mem::swap(&mut self.old_f, &mut state.dy);
        std::mem::swap(&mut self.old_y, &mut state.y);
        for i in 0..self.sdiff.len() {
            self.old_f_sens[i].mul_assign(scale(Eqn::T::one() / dt));
            std::mem::swap(&mut self.old_f_sens[i], &mut state.ds[i]);
            std::mem::swap(&mut self.old_y_sens[i], &mut state.s[i]);
        }
        self.is_state_mutated = false;
        self.statistics.number_of_linear_solver_setups =
            self.nonlinear_solver.problem().f.number_of_jac_evals();
        self.statistics.number_of_steps += 1;
        self.statistics.final_step_size = self.state.as_ref().unwrap().h;
        if let Some(root_fn) = self.problem.as_ref().unwrap().eqn.root() {
            let ret = self.root_finder.as_ref().unwrap().check_root(
                &|t| self.interpolate(t),
                root_fn.as_ref(),
                &self.state.as_ref().unwrap().y,
                self.state.as_ref().unwrap().t,
            );
            if let Some(root) = ret {
                return Ok(OdeSolverStopReason::RootFound(root));
            }
        }
        if let Some(tstop) = self.tstop {
            // Propagate a tstop error instead of panicking inside a `Result`-returning method.
            if let Some(reason) = self.handle_tstop(tstop)? {
                return Ok(reason);
            }
        }
        Ok(OdeSolverStopReason::InternalTimestep)
    }

    /// Sets a stop time, adjusting the next step size so the solver lands on it exactly.
    ///
    /// # Errors
    ///
    /// Returns an error if the stop time is at or before the current time.
    fn set_stop_time(&mut self, tstop: <Eqn as OdeEquations>::T) -> Result<()> {
        self.tstop = Some(tstop);
        if let Some(OdeSolverStopReason::TstopReached) = self.handle_tstop(tstop)? {
            self.tstop = None;
            return Err(anyhow::anyhow!(
                "Stop time is at or before current time t = {}",
                self.state.as_ref().unwrap().t
            ));
        }
        Ok(())
    }

    /// Interpolates the parameter sensitivities at time `t` within the last step.
    ///
    /// The tableau's `beta` dense-output coefficients are used when available; otherwise cubic
    /// Hermite interpolation is used.
    ///
    /// # Errors
    ///
    /// Returns an error if no state is set or `t` lies outside `[old_t, state.t]`. After
    /// `state_mut` has been called, only `t == state.t` is accepted.
    fn interpolate_sens(
        &self,
        t: <Eqn as OdeEquations>::T,
    ) -> Result<Vec<<Eqn as OdeEquations>::V>> {
        let state = self
            .state
            .as_ref()
            .ok_or_else(|| anyhow!("State not set"))?;
        if self.is_state_mutated {
            if t == state.t {
                return Ok(state.s.clone());
            } else {
                return Err(anyhow::anyhow!("Interpolation time is not within the current step. Step size is zero after calling state_mut()"));
            }
        }
        if t > state.t || t < self.old_t {
            return Err(anyhow::anyhow!(
                "Interpolation time is not within the current step"
            ));
        }
        let dt = state.t - self.old_t;
        // A zero-length step (before the first step) collapses the interval to its end point.
        let theta = if dt == Eqn::T::zero() {
            Eqn::T::one()
        } else {
            (t - self.old_t) / dt
        };
        if let Some(beta) = self.tableau.beta() {
            let beta_f = Self::interpolate_beta_function(theta, beta);
            let ret = self
                .old_y_sens
                .iter()
                .zip(self.sdiff.iter())
                .map(|(y, diff)| Self::interpolate_from_diff(y, &beta_f, diff))
                .collect();
            Ok(ret)
        } else {
            let ret = self
                .old_y_sens
                .iter()
                .zip(state.s.iter())
                .zip(self.sdiff.iter())
                .map(|((s0, s1), diff)| Self::interpolate_hermite(theta, s0, s1, diff))
                .collect();
            Ok(ret)
        }
    }

    /// Interpolates the solution at time `t` within the last step.
    ///
    /// The tableau's `beta` dense-output coefficients are used when available; otherwise cubic
    /// Hermite interpolation is used.
    ///
    /// # Errors
    ///
    /// Returns an error if no state is set or `t` lies outside `[old_t, state.t]`. After
    /// `state_mut` has been called, only `t == state.t` is accepted.
    fn interpolate(&self, t: <Eqn>::T) -> anyhow::Result<<Eqn>::V> {
        let state = self
            .state
            .as_ref()
            .ok_or_else(|| anyhow!("State not set"))?;
        if self.is_state_mutated {
            if t == state.t {
                return Ok(state.y.clone());
            } else {
                return Err(anyhow::anyhow!("Interpolation time is not within the current step. Step size is zero after calling state_mut()"));
            }
        }
        if t > state.t || t < self.old_t {
            return Err(anyhow::anyhow!(
                "Interpolation time is not within the current step"
            ));
        }
        let dt = state.t - self.old_t;
        // A zero-length step (before the first step) collapses the interval to its end point.
        let theta = if dt == Eqn::T::zero() {
            Eqn::T::one()
        } else {
            (t - self.old_t) / dt
        };
        if let Some(beta) = self.tableau.beta() {
            let beta_f = Self::interpolate_beta_function(theta, beta);
            let ret = Self::interpolate_from_diff(&self.old_y, &beta_f, &self.diff);
            Ok(ret)
        } else {
            let ret = Self::interpolate_hermite(theta, &self.old_y, &state.y, &self.diff);
            Ok(ret)
        }
    }

    /// Returns the current state, if any.
    fn state(&self) -> Option<&OdeSolverState<Eqn::V>> {
        self.state.as_ref()
    }

    /// Returns mutable access to the state.
    ///
    /// Afterwards, interpolation is restricted to the current time until the next step.
    fn state_mut(&mut self) -> Option<&mut OdeSolverState<Eqn::V>> {
        self.is_state_mutated = true;
        self.state.as_mut()
    }
}
/// Tests for the `Sdirk` solver.
///
/// Most tests run the shared `test_ode_solver` driver on a model problem with either the
/// TR-BDF2 or the ESDIRK34 tableau. The resulting solver and right-hand-side statistics are
/// compared against inline insta snapshots, so any change in step or iteration counts shows up
/// as a snapshot diff.
#[cfg(test)]
mod test {
    use crate::{
        ode_solver::{
            test_models::{
                exponential_decay::{
                    exponential_decay_problem, exponential_decay_problem_sens,
                    exponential_decay_problem_with_root,
                },
                robertson::robertson,
                robertson_ode::robertson_ode,
                robertson_sens::robertson_sens,
            },
            tests::{
                test_interpolate, test_no_set_problem, test_ode_solver, test_state_mut,
                test_state_mut_on_problem,
            },
        },
        NalgebraLU, OdeEquations, Op, Sdirk, Tableau,
    };
    use num_traits::abs;
    // Dense nalgebra matrix type used by every test in this module.
    type M = nalgebra::DMatrix<f64>;
    // Runs the shared `test_no_set_problem` check on a TR-BDF2 solver.
    #[test]
    fn sdirk_no_set_problem() {
        let tableau = Tableau::<M>::tr_bdf2();
        test_no_set_problem::<M, _>(Sdirk::<M, _, _>::new(tableau, NalgebraLU::default()));
    }
    // Runs the shared `test_state_mut` check on a TR-BDF2 solver.
    #[test]
    fn sdirk_state_mut() {
        let tableau = Tableau::<M>::tr_bdf2();
        test_state_mut::<M, _>(Sdirk::<M, _, _>::new(tableau, NalgebraLU::default()));
    }
    // Runs the shared `test_interpolate` check on a TR-BDF2 solver.
    #[test]
    fn sdirk_test_interpolate() {
        let tableau = Tableau::<M>::tr_bdf2();
        test_interpolate::<M, _>(Sdirk::<M, _, _>::new(tableau, NalgebraLU::default()));
    }
    // Runs the shared `test_state_mut_on_problem` check on the exponential-decay problem.
    #[test]
    fn sdirk_test_state_mut_exponential_decay() {
        let (p, soln) = exponential_decay_problem::<M>(false);
        let tableau = Tableau::<M>::tr_bdf2();
        let s = Sdirk::<M, _, _>::new(tableau, NalgebraLU::default());
        test_state_mut_on_problem(s, p, soln);
    }
    // TR-BDF2 on exponential decay; snapshots the solver and rhs statistics.
    #[test]
    fn test_tr_bdf2_nalgebra_exponential_decay() {
        let tableau = Tableau::<M>::tr_bdf2();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = exponential_decay_problem::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 30
        number_of_steps: 29
        number_of_error_test_failures: 0
        number_of_nonlinear_solver_iterations: 116
        number_of_nonlinear_solver_fails: 0
        initial_step_size: 0.005848035476425734
        final_step_size: 0.3808530346209797
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 118
        number_of_jac_muls: 2
        number_of_matrix_evals: 1
        "###);
    }
    // TR-BDF2 on exponential decay with sensitivities; snapshots the statistics.
    #[test]
    fn test_tr_bdf2_nalgebra_exponential_decay_sens() {
        let tableau = Tableau::<M>::tr_bdf2();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = exponential_decay_problem_sens::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 59
        number_of_steps: 58
        number_of_error_test_failures: 0
        number_of_nonlinear_solver_iterations: 464
        number_of_nonlinear_solver_fails: 0
        initial_step_size: 0.005848035476425734
        final_step_size: 0.22851673033949357
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 234
        number_of_jac_muls: 235
        number_of_matrix_evals: 1
        "###);
    }
    // ESDIRK34 on exponential decay; snapshots the statistics.
    #[test]
    fn test_esdirk34_nalgebra_exponential_decay() {
        let tableau = Tableau::<M>::esdirk34();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = exponential_decay_problem::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 14
        number_of_steps: 13
        number_of_error_test_failures: 0
        number_of_nonlinear_solver_iterations: 78
        number_of_nonlinear_solver_fails: 0
        initial_step_size: 0.02114742526881128
        final_step_size: 0.9531112013867072
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 80
        number_of_jac_muls: 2
        number_of_matrix_evals: 1
        "###);
    }
    // ESDIRK34 on exponential decay with sensitivities; snapshots the statistics.
    #[test]
    fn test_esdirk34_nalgebra_exponential_decay_sens() {
        let tableau = Tableau::<M>::esdirk34();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = exponential_decay_problem_sens::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 23
        number_of_steps: 22
        number_of_error_test_failures: 0
        number_of_nonlinear_solver_iterations: 264
        number_of_nonlinear_solver_fails: 0
        initial_step_size: 0.02114742526881128
        final_step_size: 0.5893196907333161
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 134
        number_of_jac_muls: 135
        number_of_matrix_evals: 1
        "###);
    }
    // TR-BDF2 on the Robertson problem; snapshots the statistics.
    #[test]
    fn test_tr_bdf2_nalgebra_robertson() {
        let tableau = Tableau::<M>::tr_bdf2();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = robertson::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 429
        number_of_steps: 410
        number_of_error_test_failures: 6
        number_of_nonlinear_solver_iterations: 3032
        number_of_nonlinear_solver_fails: 12
        initial_step_size: 0.0005245814253712257
        final_step_size: 38234484245.73098
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 2993
        number_of_jac_muls: 42
        number_of_matrix_evals: 14
        "###);
    }
    // TR-BDF2 on the Robertson problem with sensitivities; snapshots the statistics.
    #[test]
    fn test_tr_bdf2_nalgebra_robertson_sens() {
        let tableau = Tableau::<M>::tr_bdf2();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = robertson_sens::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 914
        number_of_steps: 891
        number_of_error_test_failures: 7
        number_of_nonlinear_solver_iterations: 17062
        number_of_nonlinear_solver_fails: 15
        initial_step_size: 0.0005245814253712257
        final_step_size: 16695887030.215992
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 4809
        number_of_jac_muls: 12347
        number_of_matrix_evals: 20
        "###);
    }
    // ESDIRK34 on the Robertson problem; snapshots the statistics.
    #[test]
    fn test_esdirk34_nalgebra_robertson() {
        let tableau = Tableau::<M>::esdirk34();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = robertson::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 289
        number_of_steps: 266
        number_of_error_test_failures: 3
        number_of_nonlinear_solver_iterations: 2889
        number_of_nonlinear_solver_fails: 19
        initial_step_size: 0.0034662483959892352
        final_step_size: 47734821046.576515
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 2848
        number_of_jac_muls: 57
        number_of_matrix_evals: 19
        "###);
    }
    // ESDIRK34 on the Robertson problem with sensitivities; snapshots the statistics.
    #[test]
    fn test_esdirk34_nalgebra_robertson_sens() {
        let tableau = Tableau::<M>::esdirk34();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = robertson_sens::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 489
        number_of_steps: 461
        number_of_error_test_failures: 3
        number_of_nonlinear_solver_iterations: 13777
        number_of_nonlinear_solver_fails: 24
        initial_step_size: 0.0034662483959892352
        final_step_size: 23926664695.166664
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 3989
        number_of_jac_muls: 9909
        number_of_matrix_evals: 26
        "###);
    }
    // TR-BDF2 on the `robertson_ode` model; snapshots the statistics.
    #[test]
    fn test_tr_bdf2_nalgebra_robertson_ode() {
        let tableau = Tableau::<M>::tr_bdf2();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = robertson_ode::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, false);
        insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
        ---
        number_of_linear_solver_setups: 243
        number_of_steps: 230
        number_of_error_test_failures: 0
        number_of_nonlinear_solver_iterations: 2383
        number_of_nonlinear_solver_fails: 12
        initial_step_size: 0.00046734995811969143
        final_step_size: 59513072650.62326
        "###);
        insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###"
        ---
        number_of_calls: 2322
        number_of_jac_muls: 39
        number_of_matrix_evals: 13
        "###);
    }
    // Runs the shared driver with its final flag set to `true` (the tstop variant, judging by
    // the test name — confirm against `test_ode_solver`).
    #[test]
    fn test_tstop_tr_bdf2() {
        let tableau = Tableau::<M>::tr_bdf2();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = exponential_decay_problem::<M>(false);
        test_ode_solver(&mut s, &problem, soln, None, true);
    }
    // Exponential decay with a root function: the returned y[0] must be within 1e-6 of 0.6
    // (presumably the root's defining value — confirm against the test model).
    #[test]
    fn test_root_finder_tr_bdf2() {
        let tableau = Tableau::<M>::tr_bdf2();
        let mut s = Sdirk::new(tableau, NalgebraLU::default());
        let (problem, soln) = exponential_decay_problem_with_root::<M>(false);
        let y = test_ode_solver(&mut s, &problem, soln, None, false);
        assert!(abs(y[0] - 0.6) < 1e-6, "y[0] = {}", y[0]);
    }
}