use num_traits::FromPrimitive;
use std::{cell::RefCell, rc::Rc};
use crate::{
error::DiffsolError, vector::Vector, AdjointContext, AdjointEquations, AugmentedOdeEquations,
AugmentedOdeEquationsImplicit, Bdf, BdfState, Checkpointing, DefaultDenseMatrix, DenseMatrix,
ExplicitRk, LinearSolver, MatrixRef, NewtonNonlinearSolver, NoLineSearch, OdeEquations,
OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsImplicitAdjoint,
OdeEquationsImplicitSens, OdeSolverMethod, OdeSolverState, RkState, Scalar, Sdirk,
SensEquations, Tableau, VectorRef,
};
/// Options controlling the Newton solver used to compute consistent
/// initial conditions (e.g. when the problem has a mass matrix).
pub struct InitialConditionSolverOptions<T: Scalar> {
    /// Enable a backtracking line search inside the Newton iteration.
    pub use_linesearch: bool,
    /// Maximum number of backtracking steps per line search.
    pub max_linesearch_iterations: usize,
    /// Maximum number of Newton iterations before giving up.
    pub max_newton_iterations: usize,
    /// Maximum number of linear-solver setups (re-factorisations).
    pub max_linear_solver_setups: usize,
    /// Factor by which the step is shrunk on each line-search backtrack.
    pub step_reduction_factor: T,
    /// Armijo sufficient-decrease constant used by the line search.
    pub armijo_constant: T,
}
impl<T: Scalar> Default for InitialConditionSolverOptions<T> {
    /// Default initial-condition solver settings: line search enabled,
    /// 10 Newton / 10 line-search iterations, at most 4 linear-solver
    /// setups, step halving, and the classic 1e-4 Armijo constant.
    fn default() -> Self {
        // Fields are initialised in declaration order for consistency
        // with the struct definition (values are unchanged).
        Self {
            use_linesearch: true,
            max_linesearch_iterations: 10,
            max_newton_iterations: 10,
            max_linear_solver_setups: 4,
            // from_f64 cannot fail for these small literals for any
            // sensible Scalar, hence the unwrap.
            step_reduction_factor: T::from_f64(0.5).unwrap(),
            armijo_constant: T::from_f64(1e-4).unwrap(),
        }
    }
}
/// Tuning options for the time-stepping ODE solvers.
///
/// The exact interpretation of each field is defined by the individual
/// solver implementations (BDF, SDIRK, ...).
pub struct OdeSolverOptions<T: Scalar> {
    /// Maximum iterations of the nonlinear solver.
    pub max_nonlinear_solver_iterations: usize,
    /// Maximum number of error-test failures tolerated.
    pub max_error_test_failures: usize,
    /// Maximum number of nonlinear-solver failures tolerated.
    pub max_nonlinear_solver_failures: usize,
    /// Convergence tolerance for the nonlinear solver.
    pub nonlinear_solver_tolerance: T,
    /// Smallest step size allowed before the solver gives up.
    pub min_timestep: T,
    // Optional clamps on step-size adaptation; `None` leaves the
    // solver's built-in limits in place.
    pub max_timestep_growth: Option<T>,
    pub min_timestep_growth: Option<T>,
    pub max_timestep_shrink: Option<T>,
    pub min_timestep_shrink: Option<T>,
    /// Steps between refreshes of the iteration Jacobian.
    pub update_jacobian_after_steps: usize,
    /// Steps between refreshes of the RHS Jacobian.
    pub update_rhs_jacobian_after_steps: usize,
    /// Threshold used to decide when to update the iteration Jacobian.
    pub threshold_to_update_jacobian: T,
    /// Threshold used to decide when to update the RHS Jacobian.
    pub threshold_to_update_rhs_jacobian: T,
}
impl<T: Scalar> Default for OdeSolverOptions<T> {
    /// Default solver tuning; the adaptation clamps are left unset so the
    /// solvers fall back to their built-in limits.
    fn default() -> Self {
        // Small conversion helper so the numeric literals stay readable.
        let t = |x: f64| T::from_f64(x).unwrap();
        Self {
            max_nonlinear_solver_iterations: 10,
            max_error_test_failures: 40,
            max_nonlinear_solver_failures: 50,
            nonlinear_solver_tolerance: t(0.2),
            min_timestep: t(1e-13),
            max_timestep_growth: None,
            min_timestep_growth: None,
            max_timestep_shrink: None,
            min_timestep_shrink: None,
            update_jacobian_after_steps: 20,
            update_rhs_jacobian_after_steps: 50,
            threshold_to_update_jacobian: t(0.3),
            threshold_to_update_rhs_jacobian: t(0.2),
        }
    }
}
/// A fully-specified ODE problem: the equations plus tolerances, initial
/// time/step and solver options. Solver instances are created from a
/// problem via the constructor methods in the impl blocks below.
pub struct OdeSolverProblem<Eqn>
where
    Eqn: OdeEquations,
{
    /// The ODE equations defining the problem.
    pub eqn: Eqn,
    /// Relative tolerance for the state.
    pub rtol: Eqn::T,
    /// Absolute tolerance vector for the state.
    pub atol: Eqn::V,
    /// Initial time.
    pub t0: Eqn::T,
    /// Initial step size.
    pub h0: Eqn::T,
    /// Whether output equations are integrated alongside the state.
    pub integrate_out: bool,
    // Optional tolerances for forward sensitivities.
    pub sens_rtol: Option<Eqn::T>,
    pub sens_atol: Option<Eqn::V>,
    // Optional tolerances for the output equations; both must be set
    // (with `integrate_out`) for outputs to join error control.
    pub out_rtol: Option<Eqn::T>,
    pub out_atol: Option<Eqn::V>,
    // Optional tolerances for the parameters.
    pub param_rtol: Option<Eqn::T>,
    pub param_atol: Option<Eqn::V>,
    /// Options for the initial-condition (consistency) solver.
    pub ic_options: InitialConditionSolverOptions<Eqn::T>,
    /// Options for the time-stepping solver.
    pub ode_options: OdeSolverOptions<Eqn::T>,
}
// Generates, for one built-in Butcher tableau `$tableau`, the family of
// SDIRK constructor methods on `OdeSolverProblem`:
// - `$method_solver` / `$method_solver_sens`: solver from an existing state,
// - `$method_state_adjoint` / `$method_solver_adjoint` /
//   `$method_solver_adjoint_from_state`: adjoint (backward) variants,
// - `$method` / `$method_sens`: convenience constructors that also build a
//   consistent initial state.
// Must be invoked inside an impl block of `OdeSolverProblem` where the
// `sdirk_*` helper methods are in scope.
macro_rules! sdirk_solver_from_tableau {
    ($method:ident, $method_sens:ident, $method_solver:ident, $method_solver_sens:ident, $method_state_adjoint:ident, $method_solver_adjoint:ident, $method_solver_adjoint_from_state:ident, $tableau:ident) => {
        #[doc = concat!("Create a new ", stringify!($tableau), " SDIRK solver instance with the given initial state.\n\n",
        "This method uses the built-in ", stringify!($tableau), " Butcher tableau.\n\n",
        "# Type Parameters\n",
        "- `LS`: The linear solver type\n\n",
        "# Arguments\n",
        "- `state`: The initial state for the solver\n\n",
        "# Returns\n",
        "An SDIRK solver instance configured with the ", stringify!($tableau), " method")]
        pub fn $method_solver<LS: LinearSolver<Eqn::M>>(
            &self,
            state: RkState<Eqn::V>,
        ) -> Result<Sdirk<'_, Eqn, LS>, DiffsolError>
        where
            Eqn: OdeEquationsImplicit,
        {
            self.sdirk_solver(
                state,
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " SDIRK solver instance with forward sensitivities, given the initial state.\n\n",
        "This method uses the built-in ", stringify!($tableau), " Butcher tableau and simultaneously solves\n",
        "the state equations and forward sensitivity equations.\n\n",
        "# Type Parameters\n",
        "- `LS`: The linear solver type\n\n",
        "# Arguments\n",
        "- `state`: The initial state for the solver (including sensitivities)\n\n",
        "# Returns\n",
        "An SDIRK solver instance configured for forward sensitivity analysis using ", stringify!($tableau))]
        pub fn $method_solver_sens<LS: LinearSolver<Eqn::M>>(
            &self,
            state: RkState<Eqn::V>,
        ) -> Result<
            Sdirk<'_, Eqn, LS, <Eqn::V as DefaultDenseMatrix>::M, SensEquations<'_, Eqn>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsImplicitSens,
        {
            self.sdirk_solver_sens(
                state,
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " SDIRK solver instance for adjoint sensitivity analysis.\n\n",
        "This method creates a solver for the backward adjoint equations using the ", stringify!($tableau), " method.\n",
        "Requires a checkpointer to provide the forward solution during the backward solve.\n\n",
        "# Type Parameters\n",
        "- `LS`: The linear solver type\n",
        "- `S`: The forward solver method type used for checkpointing\n\n",
        "# Arguments\n",
        "- `checkpointer`: The checkpointing object containing the forward solution\n",
        "- `nout_override`: Optional override for the number of output equations\n\n",
        "# Returns\n",
        "An SDIRK solver instance configured for adjoint sensitivity analysis using ", stringify!($tableau))]
        pub fn $method_solver_adjoint<'a, LS: LinearSolver<Eqn::M>, S: OdeSolverMethod<'a, Eqn>>(
            &'a self,
            checkpointer: Checkpointing<'a, Eqn, S>,
            nout_override: Option<usize>,
        ) -> Result<
            Sdirk<'a, Eqn, LS, <Eqn::V as DefaultDenseMatrix>::M, AdjointEquations<'a, Eqn, S>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsImplicitAdjoint,
        {
            self.sdirk_solver_adjoint(
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
                checkpointer,
                nout_override,
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " SDIRK adjoint initial state.")]
        pub fn $method_state_adjoint<'a, LS: LinearSolver<Eqn::M>, S: OdeSolverMethod<'a, Eqn>>(
            &'a self,
            adjoint_eqn: &mut AdjointEquations<'a, Eqn, S>,
        ) -> Result<RkState<Eqn::V>, DiffsolError>
        where
            Eqn: OdeEquationsImplicitAdjoint,
        {
            self.sdirk_state_adjoint::<LS, _, _>(
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
                adjoint_eqn,
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " SDIRK adjoint solver instance from an existing state.")]
        pub fn $method_solver_adjoint_from_state<
            'a,
            LS: LinearSolver<Eqn::M>,
            S: OdeSolverMethod<'a, Eqn>,
        >(
            &'a self,
            state: RkState<Eqn::V>,
            adjoint_eqn: AdjointEquations<'a, Eqn, S>,
        ) -> Result<
            Sdirk<'a, Eqn, LS, <Eqn::V as DefaultDenseMatrix>::M, AdjointEquations<'a, Eqn, S>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsImplicitAdjoint,
        {
            self.sdirk_solver_adjoint_from_state(
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
                state,
                adjoint_eqn,
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " SDIRK solver instance with a consistent initial state.\n\n",
        "This convenience method combines state creation and solver initialization using the\n",
        "built-in ", stringify!($tableau), " Butcher tableau. It will create a consistent initial state,\n",
        "which may require solving a nonlinear system if a mass matrix is present.\n\n",
        "# Type Parameters\n",
        "- `LS`: The linear solver type\n\n",
        "# Returns\n",
        "An SDIRK solver instance configured with the ", stringify!($tableau), " method and consistent initial state")]
        pub fn $method<LS: LinearSolver<Eqn::M>>(&self) -> Result<Sdirk<'_, Eqn, LS>, DiffsolError>
        where
            Eqn: OdeEquationsImplicit,
        {
            let tableau =
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone());
            let state = self.rk_state_and_consistent::<LS, _>(&tableau)?;
            self.sdirk_solver(state, tableau)
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " SDIRK solver instance with forward sensitivities and consistent initial state.\n\n",
        "This convenience method combines state creation and solver initialization for forward\n",
        "sensitivity analysis using the built-in ", stringify!($tableau), " Butcher tableau. It will create\n",
        "a consistent initial state, which may require solving a nonlinear system if a mass matrix is present.\n\n",
        "# Type Parameters\n",
        "- `LS`: The linear solver type\n\n",
        "# Returns\n",
        "An SDIRK solver instance configured for forward sensitivity analysis using ", stringify!($tableau))]
        pub fn $method_sens<LS: LinearSolver<Eqn::M>>(
            &self,
        ) -> Result<
            Sdirk<'_, Eqn, LS, <Eqn::V as DefaultDenseMatrix>::M, SensEquations<'_, Eqn>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsImplicitSens,
        {
            let tableau =
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone());
            let state = self.rk_state_sens_and_consistent::<LS, _>(&tableau)?;
            self.sdirk_solver_sens(state, tableau)
        }
    };
}
// Generates, for one built-in Butcher tableau `$tableau`, the family of
// explicit Runge-Kutta constructor methods on `OdeSolverProblem`:
// - `$method_solver` / `$method_solver_sens`: solver from an existing state,
// - `$method_state_adjoint` / `$method_solver_adjoint` /
//   `$method_solver_adjoint_from_state`: adjoint (backward) variants,
// - `$method` / `$method_sens`: convenience constructors that also build
//   the initial state.
// Must be invoked inside an impl block of `OdeSolverProblem` where the
// `explicit_rk_*` helper methods are in scope.
macro_rules! rk_solver_from_tableau {
    ($method:ident, $method_sens:ident, $method_solver:ident, $method_solver_sens:ident, $method_state_adjoint:ident, $method_solver_adjoint:ident, $method_solver_adjoint_from_state:ident, $tableau:ident) => {
        #[doc = concat!("Create a new ", stringify!($tableau), " explicit Runge-Kutta solver instance with the given initial state.\n\n",
        "This method uses the built-in ", stringify!($tableau), " Butcher tableau.\n\n",
        "# Arguments\n",
        "- `state`: The initial state for the solver\n\n",
        "# Returns\n",
        "An explicit RK solver instance configured with the ", stringify!($tableau), " method")]
        pub fn $method_solver(
            &self,
            state: RkState<Eqn::V>,
        ) -> Result<ExplicitRk<'_, Eqn>, DiffsolError>
        where
            Eqn: OdeEquations,
        {
            self.explicit_rk_solver(
                state,
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " explicit Runge-Kutta solver instance with forward sensitivities, given the initial state.\n\n",
        "This method uses the built-in ", stringify!($tableau), " Butcher tableau and simultaneously solves\n",
        "the state equations and forward sensitivity equations.\n\n",
        "# Arguments\n",
        "- `state`: The initial state for the solver (including sensitivities)\n\n",
        "# Returns\n",
        "An explicit RK solver instance configured for forward sensitivity analysis using ", stringify!($tableau))]
        pub fn $method_solver_sens(
            &self,
            state: RkState<Eqn::V>,
        ) -> Result<
            ExplicitRk<'_, Eqn, <Eqn::V as DefaultDenseMatrix>::M, SensEquations<'_, Eqn>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsImplicitSens,
        {
            self.explicit_rk_solver_sens(
                state,
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
            )
        }
        // Fixed doc typo below: "fromt the `checkpointer`" -> "from the
        // `checkpointer`)" (also closes the parenthesis).
        #[doc = concat!("Create a new ", stringify!($tableau), " explicit Runge-Kutta solver instance for adjoint sensitivity analysis.\n\n",
        "This method creates a solver for the backward adjoint equations using the ", stringify!($tableau), " method.\n",
        "Requires a checkpointer to provide the forward solution during the backward solve.\n\n",
        "# Type Parameters\n",
        "- `S`: The forward solver method type used for checkpointing (this can be auto-deduced from the `checkpointer`)\n\n",
        "# Arguments\n",
        "- `checkpointer`: The checkpointing object containing the forward solution\n",
        "- `nout_override`: Optional override for the number of output equations\n\n",
        "# Returns\n",
        "An explicit RK solver instance configured for adjoint sensitivity analysis using ", stringify!($tableau))]
        pub fn $method_solver_adjoint<'a, S: OdeSolverMethod<'a, Eqn>>(
            &'a self,
            checkpointer: Checkpointing<'a, Eqn, S>,
            nout_override: Option<usize>,
        ) -> Result<
            ExplicitRk<'a, Eqn, <Eqn::V as DefaultDenseMatrix>::M, AdjointEquations<'a, Eqn, S>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsAdjoint,
        {
            self.explicit_rk_solver_adjoint(
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
                checkpointer,
                nout_override,
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " explicit Runge-Kutta adjoint initial state.")]
        pub fn $method_state_adjoint<'a, S: OdeSolverMethod<'a, Eqn>>(
            &'a self,
            adjoint_eqn: &mut AdjointEquations<'a, Eqn, S>,
        ) -> Result<RkState<Eqn::V>, DiffsolError>
        where
            Eqn: OdeEquationsAdjoint,
        {
            self.explicit_rk_state_adjoint(
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
                adjoint_eqn,
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " explicit Runge-Kutta adjoint solver instance from an existing state.")]
        pub fn $method_solver_adjoint_from_state<'a, S: OdeSolverMethod<'a, Eqn>>(
            &'a self,
            state: RkState<Eqn::V>,
            adjoint_eqn: AdjointEquations<'a, Eqn, S>,
        ) -> Result<
            ExplicitRk<'a, Eqn, <Eqn::V as DefaultDenseMatrix>::M, AdjointEquations<'a, Eqn, S>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsAdjoint,
        {
            self.explicit_rk_solver_adjoint_from_state(
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone()),
                state,
                adjoint_eqn,
            )
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " explicit Runge-Kutta solver instance with initial state.\n\n",
        "This convenience method combines state creation and solver initialization using the\n",
        "built-in ", stringify!($tableau), " Butcher tableau.\n\n",
        "# Returns\n",
        "An explicit RK solver instance configured with the ", stringify!($tableau), " method")]
        pub fn $method(&self) -> Result<ExplicitRk<'_, Eqn>, DiffsolError>
        where
            Eqn: OdeEquations,
        {
            let tableau =
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone());
            let state = self.rk_state(&tableau)?;
            self.explicit_rk_solver(state, tableau)
        }
        #[doc = concat!("Create a new ", stringify!($tableau), " explicit Runge-Kutta solver instance with forward sensitivities.\n\n",
        "This convenience method combines state creation and solver initialization for forward\n",
        "sensitivity analysis using the built-in ", stringify!($tableau), " Butcher tableau.\n\n",
        "# Returns\n",
        "An explicit RK solver instance configured for forward sensitivity analysis using ", stringify!($tableau))]
        pub fn $method_sens(
            &self,
        ) -> Result<
            ExplicitRk<'_, Eqn, <Eqn::V as DefaultDenseMatrix>::M, SensEquations<'_, Eqn>>,
            DiffsolError,
        >
        where
            Eqn: OdeEquationsImplicitSens,
        {
            let tableau =
                Tableau::<<Eqn::V as DefaultDenseMatrix>::M>::$tableau(self.context().clone());
            let state = self.rk_state_sens(&tableau)?;
            self.explicit_rk_solver_sens(state, tableau)
        }
    };
}
impl<Eqn> OdeSolverProblem<Eqn>
where
    Eqn: OdeEquations,
{
    /// Whether the output equations take part in error control: requires
    /// output integration to be enabled, the equations to define an
    /// output, and both output tolerances to be set.
    pub fn output_in_error_control(&self) -> bool {
        if !self.integrate_out {
            return false;
        }
        self.eqn.out().is_some() && self.out_rtol.is_some() && self.out_atol.is_some()
    }
    /// Assemble a problem from its parts.
    ///
    /// Infallible today; the `Result` return leaves room for adding
    /// validation without breaking callers.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        eqn: Eqn,
        rtol: Eqn::T,
        atol: Eqn::V,
        sens_rtol: Option<Eqn::T>,
        sens_atol: Option<Eqn::V>,
        out_rtol: Option<Eqn::T>,
        out_atol: Option<Eqn::V>,
        param_rtol: Option<Eqn::T>,
        param_atol: Option<Eqn::V>,
        t0: Eqn::T,
        h0: Eqn::T,
        integrate_out: bool,
        ic_options: InitialConditionSolverOptions<Eqn::T>,
        ode_options: OdeSolverOptions<Eqn::T>,
    ) -> Result<Self, DiffsolError> {
        // Fields listed in struct declaration order.
        Ok(Self {
            eqn,
            rtol,
            atol,
            t0,
            h0,
            integrate_out,
            sens_rtol,
            sens_atol,
            out_rtol,
            out_atol,
            param_rtol,
            param_atol,
            ic_options,
            ode_options,
        })
    }
    /// Shared reference to the underlying equations.
    pub fn eqn(&self) -> &Eqn {
        &self.eqn
    }
    /// Mutable reference to the underlying equations.
    pub fn eqn_mut(&mut self) -> &mut Eqn {
        &mut self.eqn
    }
    /// The computation context of the underlying equations.
    pub fn context(&self) -> &Eqn::C {
        self.eqn.context()
    }
}
impl<Eqn> OdeSolverProblem<Eqn>
where
    Eqn: OdeEquationsAdjoint,
{
    /// Build the adjoint (backward) equations for this problem.
    ///
    /// `checkpointer` supplies the forward solution during the backward
    /// solve; `nout_override` overrides the number of output equations
    /// (defaults to `eqn.nout()`).
    pub fn adjoint_equations<'a, S: OdeSolverMethod<'a, Eqn>>(
        &'a self,
        checkpointer: Checkpointing<'a, Eqn, S>,
        nout_override: Option<usize>,
    ) -> AdjointEquations<'a, Eqn, S> {
        let nout = match nout_override {
            Some(n) => n,
            None => self.eqn.nout(),
        };
        // The adjoint context is shared via Rc<RefCell<..>>.
        let ctx = Rc::new(RefCell::new(AdjointContext::new(checkpointer, nout)));
        AdjointEquations::new(self, ctx, self.integrate_out)
    }
}
impl<Eqn> OdeSolverProblem<Eqn>
where
Eqn: OdeEquations,
Eqn::V: DefaultDenseMatrix<T = Eqn::T, C = Eqn::C>,
for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
{
    /// Create a BDF initial state (order 1) with consistent initial conditions.
    pub fn bdf_state<LS: LinearSolver<Eqn::M>>(&self) -> Result<BdfState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsImplicit,
    {
        BdfState::new_and_consistent::<LS, Eqn>(self, 1)
    }
    /// Create a BDF initial state (order 1) with forward sensitivities and
    /// consistent initial conditions.
    pub fn bdf_state_sens<LS: LinearSolver<Eqn::M>>(&self) -> Result<BdfState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsImplicitSens,
    {
        BdfState::new_with_sensitivities_and_consistent::<LS, Eqn>(self, 1)
    }
#[allow(clippy::type_complexity)]
pub fn bdf_solver<LS: LinearSolver<Eqn::M>>(
&self,
state: BdfState<Eqn::V>,
) -> Result<Bdf<'_, Eqn, NewtonNonlinearSolver<Eqn::M, LS, NoLineSearch>>, DiffsolError>
where
Eqn: OdeEquationsImplicit,
{
let newton_solver = NewtonNonlinearSolver::new(LS::default(), NoLineSearch);
Bdf::new(self, state, newton_solver)
}
#[allow(clippy::type_complexity)]
pub fn bdf<LS: LinearSolver<Eqn::M>>(
&self,
) -> Result<Bdf<'_, Eqn, NewtonNonlinearSolver<Eqn::M, LS, NoLineSearch>>, DiffsolError>
where
Eqn: OdeEquationsImplicit,
{
let state = self.bdf_state::<LS>()?;
self.bdf_solver(state)
}
#[allow(clippy::type_complexity)]
pub(crate) fn bdf_solver_aug<
LS: LinearSolver<Eqn::M>,
Aug: AugmentedOdeEquationsImplicit<Eqn>,
>(
&self,
state: BdfState<Eqn::V>,
aug_eqn: Aug,
) -> Result<
Bdf<
'_,
Eqn,
NewtonNonlinearSolver<Eqn::M, LS, NoLineSearch>,
<Eqn::V as DefaultDenseMatrix>::M,
Aug,
>,
DiffsolError,
>
where
Eqn: OdeEquationsImplicit,
{
let newton_solver = NewtonNonlinearSolver::new(LS::default(), NoLineSearch);
Bdf::new_augmented(state, self, aug_eqn, newton_solver)
}
    /// Create a consistent BDF initial state for the *backward* adjoint
    /// solve, starting at the final time of the forward solution.
    #[allow(clippy::type_complexity)]
    pub fn bdf_state_adjoint<'a, LS: LinearSolver<Eqn::M>, S: OdeSolverMethod<'a, Eqn>>(
        &'a self,
        augmented_eqn: &mut AdjointEquations<'a, Eqn, S>,
    ) -> Result<BdfState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsImplicitAdjoint,
    {
        // Start the backward solve where the forward solve finished.
        let h = augmented_eqn.last_h();
        let t = augmented_eqn.last_t();
        let mut newton_solver = NewtonNonlinearSolver::new(LS::default(), NoLineSearch);
        let mut state = BdfState::new_without_initialise_augmented_at(self, augmented_eqn, t)?;
        *state.as_mut().t = t;
        // Negate the last forward step so integration proceeds backwards
        // in time.
        if let Some(h) = h {
            *state.as_mut().h = -h;
        }
        // Make both the state and the adjoint variables consistent before
        // choosing the initial step size.
        state.set_consistent(self, &mut newton_solver)?;
        state.set_consistent_augmented(self, augmented_eqn, &mut newton_solver)?;
        // BDF starts at order 1, hence the final argument.
        state.set_step_size(
            state.h,
            augmented_eqn.atol().unwrap(),
            augmented_eqn.rtol().unwrap(),
            augmented_eqn,
            1,
        );
        Ok(state)
    }
    /// Create a BDF adjoint solver from an already-built backward state,
    /// making the state consistent with the adjoint equations first.
    #[allow(clippy::type_complexity)]
    pub fn bdf_solver_adjoint_from_state<
        'a,
        LS: LinearSolver<Eqn::M>,
        S: OdeSolverMethod<'a, Eqn>,
    >(
        &'a self,
        state: BdfState<Eqn::V>,
        mut augmented_eqn: AdjointEquations<'a, Eqn, S>,
    ) -> Result<
        Bdf<
            'a,
            Eqn,
            NewtonNonlinearSolver<Eqn::M, LS, NoLineSearch>,
            <Eqn::V as DefaultDenseMatrix>::M,
            AdjointEquations<'a, Eqn, S>,
        >,
        DiffsolError,
    >
    where
        Eqn: OdeEquationsImplicitAdjoint,
    {
        // This solver instance is used only to enforce consistency.
        let mut newton_solver = NewtonNonlinearSolver::new(LS::default(), NoLineSearch);
        let mut state = state;
        state.set_consistent(self, &mut newton_solver)?;
        state.set_consistent_augmented(self, &mut augmented_eqn, &mut newton_solver)?;
        // A fresh solver instance is handed to the BDF integrator itself.
        let newton_solver = NewtonNonlinearSolver::new(LS::default(), NoLineSearch);
        Bdf::new_augmented(state, self, augmented_eqn, newton_solver)
    }
#[allow(clippy::type_complexity)]
pub fn bdf_solver_adjoint<'a, LS: LinearSolver<Eqn::M>, S: OdeSolverMethod<'a, Eqn>>(
&'a self,
checkpointer: Checkpointing<'a, Eqn, S>,
nout_override: Option<usize>,
) -> Result<
Bdf<
'a,
Eqn,
NewtonNonlinearSolver<Eqn::M, LS, NoLineSearch>,
<Eqn::V as DefaultDenseMatrix>::M,
AdjointEquations<'a, Eqn, S>,
>,
DiffsolError,
>
where
Eqn: OdeEquationsImplicitAdjoint,
{
let mut augmented_eqn = self.adjoint_equations(checkpointer, nout_override);
let state = self.bdf_state_adjoint::<LS, _>(&mut augmented_eqn)?;
self.bdf_solver_adjoint_from_state::<LS, _>(state, augmented_eqn)
}
#[allow(clippy::type_complexity)]
pub fn bdf_solver_sens<LS: LinearSolver<Eqn::M>>(
&self,
state: BdfState<Eqn::V>,
) -> Result<
Bdf<
'_,
Eqn,
NewtonNonlinearSolver<Eqn::M, LS, NoLineSearch>,
<Eqn::V as DefaultDenseMatrix>::M,
SensEquations<'_, Eqn>,
>,
DiffsolError,
>
where
Eqn: OdeEquationsImplicitSens,
{
let sens_eqn = SensEquations::new(self);
self.bdf_solver_aug(state, sens_eqn)
}
#[allow(clippy::type_complexity)]
pub fn bdf_sens<LS: LinearSolver<Eqn::M>>(
&self,
) -> Result<
Bdf<
'_,
Eqn,
NewtonNonlinearSolver<Eqn::M, LS, NoLineSearch>,
<Eqn::V as DefaultDenseMatrix>::M,
SensEquations<'_, Eqn>,
>,
DiffsolError,
>
where
Eqn: OdeEquationsImplicitSens,
{
let state = self.bdf_state_sens::<LS>()?;
self.bdf_solver_sens(state)
}
    /// Create an RK initial state using the given tableau's order.
    pub fn rk_state<DM: DenseMatrix>(
        &self,
        tableau: &Tableau<DM>,
    ) -> Result<RkState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquations,
    {
        RkState::new(self, tableau.order())
    }
    /// Create an RK initial state with consistent initial conditions
    /// (may solve a nonlinear system via linear solver `LS`).
    pub fn rk_state_and_consistent<LS: LinearSolver<Eqn::M>, DM: DenseMatrix>(
        &self,
        tableau: &Tableau<DM>,
    ) -> Result<RkState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsImplicit,
    {
        RkState::new_and_consistent::<LS, _>(self, tableau.order())
    }
    /// Create an RK initial state including forward sensitivities.
    pub fn rk_state_sens<DM: DenseMatrix>(
        &self,
        tableau: &Tableau<DM>,
    ) -> Result<RkState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsImplicitSens,
    {
        RkState::new_with_sensitivities(self, tableau.order())
    }
    /// Create an RK initial state with forward sensitivities and
    /// consistent initial conditions.
    pub fn rk_state_sens_and_consistent<LS: LinearSolver<Eqn::M>, DM: DenseMatrix>(
        &self,
        tableau: &Tableau<DM>,
    ) -> Result<RkState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsImplicitSens,
    {
        RkState::new_with_sensitivities_and_consistent::<LS, _>(self, tableau.order())
    }
pub fn sdirk_solver<
LS: LinearSolver<Eqn::M>,
DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
>(
&self,
state: RkState<Eqn::V>,
tableau: Tableau<DM>,
) -> Result<Sdirk<'_, Eqn, LS, DM>, DiffsolError>
where
Eqn: OdeEquationsImplicit,
{
let linear_solver = LS::default();
Sdirk::new(self, state, tableau, linear_solver)
}
    /// Create an SDIRK solver that also integrates the augmented
    /// equations `aug_eqn` alongside the state.
    pub(crate) fn sdirk_solver_aug<
        LS: LinearSolver<Eqn::M>,
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        Aug: AugmentedOdeEquationsImplicit<Eqn>,
    >(
        &self,
        state: RkState<Eqn::V>,
        tableau: Tableau<DM>,
        aug_eqn: Aug,
    ) -> Result<Sdirk<'_, Eqn, LS, DM, Aug>, DiffsolError>
    where
        Eqn: OdeEquationsImplicit,
    {
        Sdirk::new_augmented(self, state, tableau, LS::default(), aug_eqn)
    }
    /// Create an SDIRK adjoint solver from a checkpointer: builds the
    /// adjoint equations, a consistent backward state, then the solver.
    pub(crate) fn sdirk_solver_adjoint<
        'a,
        LS: LinearSolver<Eqn::M>,
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        S: OdeSolverMethod<'a, Eqn>,
    >(
        &'a self,
        tableau: Tableau<DM>,
        checkpointer: Checkpointing<'a, Eqn, S>,
        nout_override: Option<usize>,
    ) -> Result<Sdirk<'a, Eqn, LS, DM, AdjointEquations<'a, Eqn, S>>, DiffsolError>
    where
        Eqn: OdeEquationsImplicitAdjoint,
    {
        let mut augmented_eqn = self.adjoint_equations(checkpointer, nout_override);
        // The tableau is cloned because it is needed both for the state
        // (step-size selection uses its order) and for the solver itself.
        let state = self.sdirk_state_adjoint::<LS, _, _>(tableau.clone(), &mut augmented_eqn)?;
        self.sdirk_solver_adjoint_from_state::<LS, DM, _>(tableau, state, augmented_eqn)
    }
    /// Create a consistent RK initial state for the *backward* SDIRK
    /// adjoint solve, starting at the final time of the forward solution.
    pub(crate) fn sdirk_state_adjoint<
        'a,
        LS: LinearSolver<Eqn::M>,
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        S: OdeSolverMethod<'a, Eqn>,
    >(
        &'a self,
        tableau: Tableau<DM>,
        augmented_eqn: &mut AdjointEquations<'a, Eqn, S>,
    ) -> Result<RkState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsImplicitAdjoint,
    {
        // Start the backward solve where the forward solve finished.
        let t = augmented_eqn.last_t();
        let h = augmented_eqn.last_h();
        let mut state = RkState::new_without_initialise_augmented_at(self, augmented_eqn, t)?;
        *state.as_mut().t = t;
        // Negate the last forward step so integration proceeds backwards
        // in time.
        if let Some(h) = h {
            *state.as_mut().h = -h;
        }
        // Make both the state and the adjoint variables consistent before
        // choosing the initial step size from the tableau's order.
        let mut newton_solver = NewtonNonlinearSolver::new(LS::default(), NoLineSearch);
        state.set_consistent(self, &mut newton_solver)?;
        state.set_consistent_augmented(self, augmented_eqn, &mut newton_solver)?;
        state.set_step_size(
            state.h,
            augmented_eqn.atol().unwrap(),
            augmented_eqn.rtol().unwrap(),
            augmented_eqn,
            tableau.order(),
        );
        Ok(state)
    }
    /// Create an SDIRK adjoint solver from an already-built backward
    /// state, making the state consistent with the adjoint equations first.
    pub(crate) fn sdirk_solver_adjoint_from_state<
        'a,
        LS: LinearSolver<Eqn::M>,
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        S: OdeSolverMethod<'a, Eqn>,
    >(
        &'a self,
        tableau: Tableau<DM>,
        mut state: RkState<Eqn::V>,
        mut augmented_eqn: AdjointEquations<'a, Eqn, S>,
    ) -> Result<Sdirk<'a, Eqn, LS, DM, AdjointEquations<'a, Eqn, S>>, DiffsolError>
    where
        Eqn: OdeEquationsImplicitAdjoint,
    {
        // This solver instance is used only to enforce consistency; the
        // SDIRK integrator gets its own `LS::default()` below.
        let mut newton_solver = NewtonNonlinearSolver::new(LS::default(), NoLineSearch);
        state.set_consistent(self, &mut newton_solver)?;
        state.set_consistent_augmented(self, &mut augmented_eqn, &mut newton_solver)?;
        Sdirk::new_augmented(self, state, tableau, LS::default(), augmented_eqn)
    }
pub fn sdirk_solver_sens<
LS: LinearSolver<Eqn::M>,
DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
>(
&self,
state: RkState<Eqn::V>,
tableau: Tableau<DM>,
) -> Result<Sdirk<'_, Eqn, LS, DM, SensEquations<'_, Eqn>>, DiffsolError>
where
Eqn: OdeEquationsImplicitSens,
{
let sens_eqn = SensEquations::new(self);
self.sdirk_solver_aug::<LS, DM, _>(state, tableau, sens_eqn)
}
    // Generate the tr_bdf2 SDIRK constructor family (see the macro
    // definition for the full list of generated methods).
    sdirk_solver_from_tableau!(
        tr_bdf2,
        tr_bdf2_sens,
        tr_bdf2_solver,
        tr_bdf2_solver_sens,
        tr_bdf2_state_adjoint,
        tr_bdf2_solver_adjoint,
        tr_bdf2_solver_adjoint_from_state,
        tr_bdf2
    );
    // Generate the esdirk34 SDIRK constructor family.
    sdirk_solver_from_tableau!(
        esdirk34,
        esdirk34_sens,
        esdirk34_solver,
        esdirk34_solver_sens,
        esdirk34_state_adjoint,
        esdirk34_solver_adjoint,
        esdirk34_solver_adjoint_from_state,
        esdirk34
    );
    /// Create an explicit Runge-Kutta solver from an existing state and
    /// tableau.
    pub fn explicit_rk_solver<DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>>(
        &self,
        state: RkState<Eqn::V>,
        tableau: Tableau<DM>,
    ) -> Result<ExplicitRk<'_, Eqn, DM>, DiffsolError>
    where
        Eqn: OdeEquations,
    {
        ExplicitRk::new(self, state, tableau)
    }
    /// Create an explicit RK solver that also integrates the augmented
    /// equations `aug_eqn` alongside the state.
    pub(crate) fn explicit_rk_solver_aug<
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        Aug: AugmentedOdeEquations<Eqn>,
    >(
        &self,
        state: RkState<Eqn::V>,
        tableau: Tableau<DM>,
        aug_eqn: Aug,
    ) -> Result<ExplicitRk<'_, Eqn, DM, Aug>, DiffsolError>
    where
        Eqn: OdeEquations,
    {
        ExplicitRk::new_augmented(self, state, tableau, aug_eqn)
    }
    /// Create an explicit RK adjoint solver from a checkpointer: builds
    /// the adjoint equations, a backward state, then the solver.
    pub(crate) fn explicit_rk_solver_adjoint<
        'a,
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        S: OdeSolverMethod<'a, Eqn>,
    >(
        &'a self,
        tableau: Tableau<DM>,
        checkpointer: Checkpointing<'a, Eqn, S>,
        nout_override: Option<usize>,
    ) -> Result<ExplicitRk<'a, Eqn, DM, AdjointEquations<'a, Eqn, S>>, DiffsolError>
    where
        Eqn: OdeEquationsAdjoint,
    {
        let mut augmented_eqn = self.adjoint_equations(checkpointer, nout_override);
        // The tableau is cloned because it is needed both for the state
        // (step-size selection uses its order) and for the solver itself.
        let state = self.explicit_rk_state_adjoint(tableau.clone(), &mut augmented_eqn)?;
        self.explicit_rk_solver_adjoint_from_state(tableau, state, augmented_eqn)
    }
    /// Create the initial state for a *backward* explicit RK adjoint
    /// solve, starting at the final time of the forward solution.
    pub(crate) fn explicit_rk_state_adjoint<
        'a,
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        S: OdeSolverMethod<'a, Eqn>,
    >(
        &'a self,
        tableau: Tableau<DM>,
        augmented_eqn: &mut AdjointEquations<'a, Eqn, S>,
    ) -> Result<RkState<Eqn::V>, DiffsolError>
    where
        Eqn: OdeEquationsAdjoint,
    {
        // Start the backward solve where the forward solve finished.
        let t = augmented_eqn.last_t();
        let h = augmented_eqn.last_h();
        let mut state = RkState::new_without_initialise_augmented_at(self, augmented_eqn, t)?;
        *state.as_mut().t = t;
        // Negate the last forward step so integration proceeds backwards
        // in time.
        if let Some(h) = h {
            *state.as_mut().h = -h;
        }
        // NOTE(review): unlike the SDIRK variant, no consistency solve is
        // performed here before selecting the step size.
        state.set_step_size(
            state.h,
            augmented_eqn.atol().unwrap(),
            augmented_eqn.rtol().unwrap(),
            augmented_eqn,
            tableau.order(),
        );
        Ok(state)
    }
    /// Create an explicit RK adjoint solver from an already-built backward
    /// state and adjoint equations.
    pub(crate) fn explicit_rk_solver_adjoint_from_state<
        'a,
        DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>,
        S: OdeSolverMethod<'a, Eqn>,
    >(
        &'a self,
        tableau: Tableau<DM>,
        state: RkState<Eqn::V>,
        augmented_eqn: AdjointEquations<'a, Eqn, S>,
    ) -> Result<ExplicitRk<'a, Eqn, DM, AdjointEquations<'a, Eqn, S>>, DiffsolError>
    where
        Eqn: OdeEquationsAdjoint,
    {
        ExplicitRk::new_augmented(self, state, tableau, augmented_eqn)
    }
pub fn explicit_rk_solver_sens<DM: DenseMatrix<V = Eqn::V, T = Eqn::T, C = Eqn::C>>(
&self,
state: RkState<Eqn::V>,
tableau: Tableau<DM>,
) -> Result<ExplicitRk<'_, Eqn, DM, SensEquations<'_, Eqn>>, DiffsolError>
where
Eqn: OdeEquationsImplicitSens,
{
let sens_eqn = SensEquations::new(self);
self.explicit_rk_solver_aug::<DM, _>(state, tableau, sens_eqn)
}
    // Generate the tsit45 explicit RK constructor family.
    rk_solver_from_tableau!(
        tsit45,
        tsit45_sens,
        tsit45_solver,
        tsit45_solver_sens,
        tsit45_state_adjoint,
        tsit45_solver_adjoint,
        tsit45_solver_adjoint_from_state,
        tsit45
    );
}
/// A single solution sample: a state vector and the time it was taken at.
#[derive(Debug, Clone)]
pub struct OdeSolverSolutionPoint<V: Vector> {
    /// State vector at time `t`.
    pub state: V,
    /// Sample time.
    pub t: V::T,
}
/// A time-ordered collection of solution points (and, optionally, their
/// forward sensitivities), together with associated tolerances.
pub struct OdeSolverSolution<V: Vector> {
    /// Solution samples, kept ordered by time (descending when
    /// `negative_time` is set, ascending otherwise).
    pub solution_points: Vec<OdeSolverSolutionPoint<V>>,
    /// One list of points per sensitivity, allocated lazily on the first
    /// call to `push_sens`.
    pub sens_solution_points: Option<Vec<Vec<OdeSolverSolutionPoint<V>>>>,
    /// Relative tolerance associated with this solution.
    pub rtol: V::T,
    /// Absolute tolerance vector associated with this solution.
    pub atol: V,
    /// True when the solution is traversed backwards in time.
    pub negative_time: bool,
}
impl<V: Vector> OdeSolverSolution<V> {
pub fn push(&mut self, state: V, t: V::T) {
let index = self.get_index(t);
self.solution_points
.insert(index, OdeSolverSolutionPoint { state, t });
}
fn get_index(&self, t: V::T) -> usize {
if self.negative_time {
self.solution_points
.iter()
.position(|x| x.t < t)
.unwrap_or(self.solution_points.len())
} else {
self.solution_points
.iter()
.position(|x| x.t > t)
.unwrap_or(self.solution_points.len())
}
}
pub fn push_sens(&mut self, state: V, t: V::T, sens: &[V]) {
let index = self.get_index(t);
self.solution_points
.insert(index, OdeSolverSolutionPoint { state, t });
if self.sens_solution_points.is_none() {
self.sens_solution_points = Some(vec![vec![]; sens.len()]);
}
for (i, s) in sens.iter().enumerate() {
self.sens_solution_points.as_mut().unwrap()[i].insert(
index,
OdeSolverSolutionPoint {
state: s.clone(),
t,
},
);
}
}
}
impl<V: Vector> Default for OdeSolverSolution<V> {
    /// Empty solution with rtol/atol of 1e-6 and forward (positive) time.
    fn default() -> Self {
        Self {
            solution_points: Vec::new(),
            sens_solution_points: None,
            rtol: V::T::from_f64(1e-6).unwrap(),
            // Single-element absolute tolerance vector.
            atol: V::from_element(1, V::T::from_f64(1e-6).unwrap(), V::C::default()),
            negative_time: false,
        }
    }
}