use crate::{
error::{DiffsolError, OdeSolverError},
ode_solver_error, AdjointEquations, AugmentedOdeEquations, AugmentedOdeSolverMethod,
CheckpointingPath, DefaultDenseMatrix, DefaultSolver, DenseMatrix, LinearSolver, Matrix,
MatrixCommon, MatrixOp, NonLinearOpAdjoint, NonLinearOpSensAdjoint, OdeEquations,
OdeEquationsImplicitAdjoint, OdeSolverMethod, OdeSolverState, OdeSolverStopReason, Op,
StateRef, Vector, VectorIndex,
};
use num_traits::{One, Zero};
use std::ops::{AddAssign, SubAssign};
/// Solver-side operations for integrating [`AdjointEquations`] backwards in time.
///
/// Implementors are augmented solvers whose augmented equations are the adjoint equations of
/// `Eqn`, driven by the forward checkpointing path stored in those equations.
pub trait AdjointOdeSolverMethod<'a, Eqn, Solver>:
    AugmentedOdeSolverMethod<'a, Eqn, AdjointEquations<'a, Eqn, Solver>>
where
    Eqn: OdeEquationsImplicitAdjoint + 'a,
    Solver: OdeSolverMethod<'a, Eqn>,
{
    /// Applies the adjoint of the reset operator associated with root `root_idx` to the
    /// current adjoint state.
    ///
    /// `fwd_state_minus` and `fwd_state_plus` are forward states on either side of the reset.
    /// The backwards pass below passes the last state of the earlier checkpointing segment
    /// and the first checkpoint of the later one, respectively.
    ///
    /// # Errors
    /// Returns an error if the problem has no reset operator, no root operator, or no
    /// augmented equations. Any error from
    /// `state_mut_op_with_adjoint_and_reset` is propagated.
    fn apply_reset_with_adjoint(
        &mut self,
        root_idx: usize,
        fwd_state_minus: StateRef<'_, Eqn::V>,
        fwd_state_plus: StateRef<'_, Eqn::V>,
    ) -> Result<(), DiffsolError> {
        // Fetch both operators up front so the immutable borrow of the problem ends before
        // the mutable borrow of the state below.
        let (reset, root) = {
            let eqn = &self.problem().eqn;
            (
                eqn.reset().ok_or_else(|| {
                    ode_solver_error!(Other, "No reset operator configured for this problem")
                })?,
                eqn.root().ok_or_else(|| {
                    ode_solver_error!(Other, "No root operator configured for this problem")
                })?,
            )
        };
        let integrate_out = self.problem().integrate_out;
        let (mut state, adj_eqn) = self
            .state_and_augmented_eqn_mut()
            .ok_or_else(|| ode_solver_error!(Other, "No augmented equations"))?;
        state.state_mut_op_with_adjoint_and_reset(
            adj_eqn.eqn(),
            &reset,
            &root,
            root_idx,
            fwd_state_minus,
            fwd_state_plus,
            integrate_out,
        )
    }

    /// Runs the complete backwards adjoint pass over every checkpointing segment.
    ///
    /// The solver must currently sit at the end time of the last checkpointing segment.
    /// Segments are processed from last to first:
    /// - Each segment is integrated back to its start time by `solve_adjoint_backwards_segment`.
    ///   On the way, the gradient columns in `dgdu_eval` are added at the matching `t_eval` times.
    /// - Between segments, the later segment is popped and the adjoint of the reset that
    ///   separated the two segments is applied.
    ///
    /// If the path starts at the problem's `t0`, the parameter sensitivities are finally
    /// corrected for the initial condition.
    ///
    /// Returns the final adjoint state together with the checkpointing path left in the
    /// adjoint equations. Only the first segment remains, because every later one is popped.
    ///
    /// # Errors
    /// - The inputs fail `validate_adjoint_backwards_inputs`.
    /// - There are no checkpointing segments.
    /// - The solver time does not match the end of the last segment.
    /// - Reset-root metadata is missing between two segments.
    /// - Any stepping or reset operation fails.
    #[allow(clippy::type_complexity)]
    fn solve_adjoint_backwards_pass(
        mut self,
        t_eval: &[Eqn::T],
        dgdu_eval: &[&<Eqn::V as DefaultDenseMatrix>::M],
    ) -> Result<(Self::State, CheckpointingPath<Eqn, Solver::State>), DiffsolError>
    where
        Eqn::V: DefaultDenseMatrix,
        Eqn::M: DefaultSolver,
    {
        let have_neqn = validate_adjoint_backwards_inputs(&self, t_eval, dgdu_eval)?;
        // The delta-g workspace is only needed when there are gradients to accumulate.
        let mut integrate_delta_g = if have_neqn > 0 && !dgdu_eval.is_empty() {
            let integrate_delta_g =
                IntegrateDeltaG::<_, <Eqn::M as DefaultSolver>::LS>::new(&self)?;
            Some(integrate_delta_g)
        } else {
            None
        };
        let problem_t0 = self.problem().t0;
        let solve_t1 = self.state().t;
        // `validate_adjoint_backwards_inputs` has already checked that the augmented
        // equations exist, so the `unwrap`s on `augmented_eqn()` below cannot fail.
        let (checkpointing_len, first_checkpoint_t, last_checkpoint_t) = {
            let aug_eqn = self.augmented_eqn().unwrap();
            let len = aug_eqn.checkpointing_len();
            // Guard against `len - 1` underflowing (panic in debug, wrap in release).
            if len == 0 {
                return Err(ode_solver_error!(
                    Other,
                    "Adjoint equations have no checkpointing segments"
                ));
            }
            let (first_t, _) = aug_eqn.checkpointing_bounds(0);
            let (_, last_t) = aug_eqn.checkpointing_bounds(len - 1);
            (len, first_t, last_t)
        };
        let path_starts_at_problem_t0 = problem_t0 == first_checkpoint_t;
        if solve_t1 != last_checkpoint_t {
            return Err(ode_solver_error!(
                Other,
                "Adjoint solver current time does not match the last checkpointing segment end time"
            ));
        }
        for segment_index in (0..checkpointing_len).rev() {
            let (segment_first_t, segment_end_t) = self
                .augmented_eqn()
                .unwrap()
                .checkpointing_bounds(segment_index);
            // Every segment but the last excludes its end time.
            // That time is shared with the start of the following segment,
            // which has already processed it.
            solve_adjoint_backwards_segment(
                &mut self,
                segment_first_t,
                segment_end_t,
                segment_index + 1 < checkpointing_len,
                t_eval,
                dgdu_eval,
                integrate_delta_g.as_mut(),
            )?;
            if segment_index > 0 {
                // At least two segments remain here, so there is always one to pop.
                let checkpointing = self
                    .augmented_eqn_mut()
                    .unwrap()
                    .pop_last_checkpointing()
                    .unwrap();
                let fwd_state_plus = checkpointing.first_checkpoint();
                let fwd_state_minus = self
                    .augmented_eqn()
                    .unwrap()
                    .checkpointing_last_state(segment_index - 1);
                let root_idx = self
                    .augmented_eqn()
                    .unwrap()
                    .checkpointing_terminal_reset_root_idx(segment_index - 1)
                    .ok_or_else(|| {
                        ode_solver_error!(
                            Other,
                            "Missing reset root metadata between checkpointing segments"
                        )
                    })?;
                self.apply_reset_with_adjoint(
                    root_idx,
                    fwd_state_minus.as_ref(),
                    fwd_state_plus.as_ref(),
                )?;
            }
        }
        let (mut state, aug_eqn) = self.into_state_and_eqn();
        let aug_eqn = aug_eqn.unwrap();
        if path_starts_at_problem_t0 {
            let state_mut = state.as_mut();
            aug_eqn.correct_sg_for_init(problem_t0, state_mut.s, state_mut.sg);
        }
        Ok((state, aug_eqn.into_checkpointing()))
    }
}
/// Checks the arguments of `solve_adjoint_backwards_pass` before any integration happens.
///
/// Verifies the following:
/// - The solver has augmented (adjoint) equations.
/// - `t_eval` is strictly increasing.
/// - The number of adjoint equations matches the model outputs (when `dgdu_eval` is empty)
///   or `dgdu_eval.len()` (otherwise).
/// - Every matrix in `dgdu_eval` has `eqn.nout()` rows and one column per `t_eval` entry.
///
/// Returns the number of adjoint equations (`max_index` of the augmented equations).
///
/// # Errors
/// Returns an `OdeSolverError::Other` describing the first check that fails.
fn validate_adjoint_backwards_inputs<'a, Eqn, Solver, AdjointSolver>(
    solver: &AdjointSolver,
    t_eval: &[Eqn::T],
    dgdu_eval: &[&<Eqn::V as DefaultDenseMatrix>::M],
) -> Result<usize, DiffsolError>
where
    Eqn: OdeEquationsImplicitAdjoint + 'a,
    Eqn::V: DefaultDenseMatrix,
    Solver: OdeSolverMethod<'a, Eqn>,
    AdjointSolver: AdjointOdeSolverMethod<'a, Eqn, Solver>,
{
    // Bind the augmented equations once instead of checking `is_none` and unwrapping later.
    let aug_eqn = solver
        .augmented_eqn()
        .ok_or_else(|| ode_solver_error!(Other, "No augmented equations"))?;
    if t_eval.windows(2).any(|w| w[0] >= w[1]) {
        return Err(ode_solver_error!(
            Other,
            "t_eval should be in increasing order"
        ));
    }
    let have_neqn = aug_eqn.max_index();
    if dgdu_eval.is_empty() {
        // Without explicit gradients, one adjoint equation per model output is expected
        // (none if the model has no output operator).
        let expected_neqn = solver.problem().eqn.out().map(|o| o.nout()).unwrap_or(0);
        if have_neqn != expected_neqn {
            return Err(ode_solver_error!(
                Other,
                format!("Number of augmented equations does not match number of model outputs: {} != {}", have_neqn, expected_neqn)
            ));
        }
    } else {
        let expected_neqn = dgdu_eval.len();
        if have_neqn != expected_neqn {
            return Err(ode_solver_error!(
                Other,
                format!("Number of outputs in augmented equations does not match number of outputs in dgdu_eval: {} != {}", have_neqn, expected_neqn)
            ));
        }
    }
    let nout = solver.problem().eqn.nout();
    if dgdu_eval.iter().any(|dgdu| dgdu.nrows() != nout) {
        return Err(ode_solver_error!(
            Other,
            "Number of outputs does not match number of rows in gradient"
        ));
    }
    if dgdu_eval.iter().any(|dgdu| dgdu.ncols() != t_eval.len()) {
        return Err(ode_solver_error!(
            Other,
            "Number of solution timepoints does not match number of columns in gradient"
        ));
    }
    Ok(have_neqn)
}
/// Integrates the adjoint backwards across one checkpointing segment `[solve_t0, solve_t1]`.
///
/// Each `t_eval` time inside the segment is visited from latest to earliest. The end point
/// `solve_t1` is skipped when `exclude_t1` is set. At each visited time:
/// - The solver is stepped back to that time.
/// - If an [`IntegrateDeltaG`] workspace is supplied, the matching column of every matrix in
///   `dgdu_eval` is added to the adjoint state.
///
/// The solver is finally stepped back to `solve_t0`.
///
/// # Errors
/// Propagates failures from `set_stop_time`, from stepping, and from
/// `IntegrateDeltaG::integrate_delta_g`. `StopTimeAtCurrentTime` is the exception: it only
/// means the solver is already at the requested time.
fn solve_adjoint_backwards_segment<'a, Eqn, Solver, AdjointSolver>(
    solver: &mut AdjointSolver,
    solve_t0: Eqn::T,
    solve_t1: Eqn::T,
    exclude_t1: bool,
    t_eval: &[Eqn::T],
    dgdu_eval: &[&<Eqn::V as DefaultDenseMatrix>::M],
    mut integrate_delta_g: Option<&mut IntegrateDeltaG<Eqn::M, <Eqn::M as DefaultSolver>::LS>>,
) -> Result<(), DiffsolError>
where
    Eqn: OdeEquationsImplicitAdjoint + 'a,
    Eqn::V: DefaultDenseMatrix,
    Eqn::M: DefaultSolver,
    Solver: OdeSolverMethod<'a, Eqn>,
    AdjointSolver: AdjointOdeSolverMethod<'a, Eqn, Solver>,
{
    // Steps `adj` until it reports reaching `stop_t`. A stop time equal to the current
    // time is not an error; there is simply nothing to integrate.
    let step_back_to = |adj: &mut AdjointSolver, stop_t: Eqn::T| -> Result<(), DiffsolError> {
        match adj.set_stop_time(stop_t) {
            Ok(_) => {
                while adj.step()? != OdeSolverStopReason::TstopReached {}
                Ok(())
            }
            Err(DiffsolError::OdeSolverError(OdeSolverError::StopTimeAtCurrentTime)) => Ok(()),
            Err(err) => Err(err),
        }
    };
    // True for evaluation times that belong to this segment.
    let in_segment =
        |t: Eqn::T| t <= solve_t1 && t >= solve_t0 && !(exclude_t1 && t == solve_t1);

    for (col, &t) in t_eval.iter().enumerate().rev() {
        if !in_segment(t) {
            continue;
        }
        step_back_to(&mut *solver, t)?;
        if let Some(delta_g) = integrate_delta_g.as_deref_mut() {
            let columns = dgdu_eval.iter().map(|dgdu| dgdu.column(col));
            delta_g.integrate_delta_g(solver, columns)?;
        }
    }
    step_back_to(solver, solve_t0)
}
/// A sub-block of a larger matrix, paired with a linear solver configured for that block.
struct BlockInfoSol<M, LS>
where
    M: Matrix,
    LS: LinearSolver<M>,
{
    /// The extracted block, wrapped as a linear operator.
    pub block: MatrixOp<M>,
    /// Indices used to `gather` the block's entries from the source matrix.
    pub src_indices: <M::V as Vector>::Index,
    /// Linear solver whose problem is `block`.
    pub solver: LS,
}
/// A sub-block of a larger matrix that is only multiplied with, never solved against.
struct BlockInfo<M>
where
    M: Matrix,
{
    /// The extracted block, wrapped as a linear operator.
    pub block: MatrixOp<M>,
    /// Indices used to `gather` the block's entries from the source matrix.
    pub src_indices: <M::V as Vector>::Index,
}
/// Split of the state indices into algebraic and differential sets.
struct PartitionInfo<Idx> {
    /// Indices of the algebraic states.
    pub algebraic_indices: Idx,
    /// Indices of the differential states.
    pub differential_indices: Idx,
}
/// Workspace for adding discrete output-gradient contributions (`dgdu`) to the adjoint state.
///
/// The matrix blocks are only present when the equations have a mass matrix. The Jacobian
/// blocks are additionally only present when that mass matrix yields algebraic states.
struct IntegrateDeltaG<M, LS>
where
    M: Matrix,
    LS: LinearSolver<M>,
{
    /// Algebraic-algebraic Jacobian block and its solver.
    pub rhs_jac_aa: Option<BlockInfoSol<M, LS>>,
    /// The "ad" Jacobian block. It is always set together with `rhs_jac_aa`.
    pub rhs_jac_ad: Option<BlockInfo<M>>,
    /// Differential-differential mass-matrix block and its solver.
    pub mass_dd: Option<BlockInfoSol<M, LS>>,
    /// Algebraic/differential index split derived from the mass matrix.
    pub partition: Option<PartitionInfo<<M::V as Vector>::Index>>,
    /// Scratch vector with one entry per algebraic state.
    pub tmp_algebraic: M::V,
    /// Scratch vector with one entry per differential state.
    pub tmp_differential: M::V,
    /// Second scratch vector with one entry per differential state.
    pub tmp_differential2: M::V,
    /// Scratch vector with one entry per parameter.
    pub tmp_nparams: M::V,
    /// Scratch vector with one entry per state (holds the interpolated forward state).
    pub tmp_nstates: M::V,
    /// Second scratch vector with one entry per state.
    pub tmp_nstates2: M::V,
    /// Scratch vector with one entry per output.
    pub tmp_nout: M::V,
}
impl<M, LS> IntegrateDeltaG<M, LS>
where
    M: Matrix,
    LS: LinearSolver<M>,
{
    /// Allocates the scratch vectors and block solvers used by `integrate_delta_g`.
    ///
    /// When the equations have a mass matrix:
    /// - The state indices are partitioned with `partition_indices_by_zero_diagonal`.
    /// - A linear solver is created for the `dd` block of the mass matrix.
    /// - If any algebraic indices exist, the `aa` block of the Jacobian (with a solver) and
    ///   the `ad` block are also extracted.
    ///
    /// # Errors
    /// Returns `OdeSolverError::JacobianNotAvailable` if algebraic indices exist but `solver`
    /// cannot provide a Jacobian.
    fn new<'a, Eqn, Solver>(solver: &Solver) -> Result<Self, DiffsolError>
    where
        Eqn: OdeEquations<M = M, V = M::V, T = M::T, C = M::C> + 'a,
        Solver: OdeSolverMethod<'a, Eqn>,
    {
        let eqn = &solver.problem().eqn;
        let ctx = eqn.context();
        let (partition, mass_dd, rhs_jac_aa, rhs_jac_ad) = if eqn.mass().is_some() {
            // `eqn.mass()` only tells us a mass operator exists; the matrix itself is taken
            // from the solver.
            let mass_matrix = solver.mass().unwrap();
            let (algebraic_indices, differential_indices) =
                mass_matrix.partition_indices_by_zero_diagonal();
            let [(dd, dd_idx), _, _, _] = mass_matrix.split(&algebraic_indices);
            let mut mass_dd = BlockInfoSol {
                block: MatrixOp::new(dd),
                src_indices: dd_idx,
                solver: LS::default(),
            };
            mass_dd.solver.set_problem(&mass_dd.block);
            let (rhs_jac_aa, rhs_jac_ad) = if algebraic_indices.len() > 0 {
                let jacobian = solver
                    .jacobian()
                    .ok_or(DiffsolError::from(OdeSolverError::JacobianNotAvailable))?;
                let [_, (ad, ad_idx), _, (aa, aa_idx)] = jacobian.split(&algebraic_indices);
                let mut rhs_jac_aa = BlockInfoSol {
                    block: MatrixOp::new(aa),
                    src_indices: aa_idx,
                    solver: LS::default(),
                };
                rhs_jac_aa.solver.set_problem(&rhs_jac_aa.block);
                let rhs_jac_ad = BlockInfo {
                    block: MatrixOp::new(ad),
                    src_indices: ad_idx,
                };
                (Some(rhs_jac_aa), Some(rhs_jac_ad))
            } else {
                (None, None)
            };
            let partition = PartitionInfo {
                algebraic_indices,
                differential_indices,
            };
            (Some(partition), Some(mass_dd), rhs_jac_aa, rhs_jac_ad)
        } else {
            (None, None, None, None)
        };
        let nparams = eqn.rhs().nparams();
        let nstates = eqn.rhs().nstates();
        // Without an output operator the outputs are the states themselves.
        let nout = eqn.out().map(|o| o.nout()).unwrap_or(nstates);
        let tmp_nstates = M::V::zeros(nstates, ctx.clone());
        let tmp_nstates2 = M::V::zeros(nstates, ctx.clone());
        let tmp_nparams = M::V::zeros(nparams, ctx.clone());
        let tmp_nout = M::V::zeros(nout, ctx.clone());
        let nalgebraic = partition
            .as_ref()
            .map(|p| p.algebraic_indices.len())
            .unwrap_or(0);
        let ndifferential = nstates - nalgebraic;
        let tmp_algebraic = M::V::zeros(nalgebraic, ctx.clone());
        let tmp_differential = M::V::zeros(ndifferential, ctx.clone());
        let tmp_differential2 = M::V::zeros(ndifferential, ctx.clone());
        Ok(Self {
            rhs_jac_aa,
            rhs_jac_ad,
            mass_dd,
            tmp_nparams,
            tmp_algebraic,
            partition,
            tmp_nstates,
            tmp_nout,
            tmp_differential,
            tmp_differential2,
            tmp_nstates2,
        })
    }

    /// Adds the contribution of one column of each output gradient to the adjoint state.
    ///
    /// `dgdus` yields one gradient vector (of length `nout`) per adjoint equation.
    /// - Each vector is mapped through the output Jacobian, if any.
    /// - It is then solved against the mass/Jacobian blocks set up in `new`, if any.
    /// - The result is added to the matching adjoint state `s_i`.
    /// - When an output operator exists, its parameter-sensitivity term is subtracted
    ///   from `sg_i`.
    ///
    /// # Errors
    /// Propagates failures from forward-state interpolation and from the block linear solves.
    fn integrate_delta_g<'a, 'b, Eqn, S1, Solver>(
        &mut self,
        solver: &mut Solver,
        dgdus: impl Iterator<Item = <M::V as Vector>::View<'b>>,
    ) -> Result<(), DiffsolError>
    where
        Eqn: OdeEquationsImplicitAdjoint<M = M, V = M::V, T = M::T> + 'a,
        Solver: AdjointOdeSolverMethod<'a, Eqn, S1>,
        S1: OdeSolverMethod<'a, Eqn>,
    {
        let t = solver.state().t;
        // Forward state at the current adjoint time, used by the output operator.
        solver
            .augmented_eqn()
            .unwrap()
            .interpolate_forward_state(t, &mut self.tmp_nstates)?;

        // Refresh the Jacobian blocks from the current Jacobian and hand the `aa` block to its
        // solver. `rhs_jac_ad` is always built together with `rhs_jac_aa` in `new`.
        if let Some(rhs_jac_aa) = self.rhs_jac_aa.as_mut() {
            let jacobian = solver.jacobian().unwrap();
            let rhs_jac_ad = self.rhs_jac_ad.as_mut().unwrap();
            rhs_jac_ad
                .block
                .m_mut()
                .gather(&jacobian, &rhs_jac_ad.src_indices);
            rhs_jac_aa
                .block
                .m_mut()
                .gather(&jacobian, &rhs_jac_aa.src_indices);
            rhs_jac_aa.solver.set_linearisation(
                &rhs_jac_aa.block,
                &self.tmp_algebraic,
                Eqn::T::zero(),
            );
        }
        // Same for the `dd` block of the mass matrix.
        if let Some(mass_dd) = self.mass_dd.as_mut() {
            let mass = solver.mass().unwrap();
            mass_dd.block.m_mut().gather(&mass, &mass_dd.src_indices);
            mass_dd.solver.set_linearisation(
                &mass_dd.block,
                &self.tmp_differential,
                Eqn::T::zero(),
            );
        }

        let out = solver.augmented_eqn().unwrap().eqn().out();
        let sol_mdd_opt = self.mass_dd.as_ref().map(|m| &m.solver);
        let sol_jaa_opt = self.rhs_jac_aa.as_ref().map(|m| &m.solver);
        let state_mut = solver.state_mut();
        for ((s_i, sg_i), dgdu) in state_mut
            .s
            .iter_mut()
            .zip(state_mut.sg.iter_mut())
            .zip(dgdus)
        {
            self.tmp_nout.copy_from_view(&dgdu);
            if let (Some(out), Some(sol_mdd), Some(sol_jaa)) =
                (out.as_ref(), sol_mdd_opt, sol_jaa_opt)
            {
                // Output operator + mass matrix with algebraic states.
                let p = self.partition.as_ref().unwrap();
                out.jac_transpose_mul_inplace(
                    &self.tmp_nstates,
                    t,
                    &self.tmp_nout,
                    &mut self.tmp_nstates2,
                );
                self.tmp_differential
                    .gather(&self.tmp_nstates2, &p.differential_indices);
                self.tmp_algebraic
                    .gather(&self.tmp_nstates2, &p.algebraic_indices);
                sol_mdd.solve_in_place(&mut self.tmp_differential)?;
                self.tmp_differential2.gather(s_i, &p.differential_indices);
                self.tmp_differential2.sub_assign(&self.tmp_differential);
                sol_jaa.solve_in_place(&mut self.tmp_algebraic)?;
                let rhs_jac_ad = self.rhs_jac_ad.as_ref().unwrap().block.m();
                rhs_jac_ad.gemv(
                    M::T::one(),
                    &self.tmp_algebraic,
                    M::T::one(),
                    &mut self.tmp_differential2,
                );
                self.tmp_differential2.scatter(&p.differential_indices, s_i);
            } else if let (Some(out), Some(sol_mdd)) = (out.as_ref(), sol_mdd_opt) {
                // Output operator + mass matrix without algebraic states.
                out.jac_transpose_mul_inplace(
                    &self.tmp_nstates,
                    t,
                    &self.tmp_nout,
                    &mut self.tmp_nstates2,
                );
                sol_mdd.solve_in_place(&mut self.tmp_nstates2)?;
                s_i.sub_assign(&self.tmp_nstates2);
            } else if let Some(out) = out.as_ref() {
                // Output operator, no mass matrix.
                out.jac_transpose_mul_inplace(
                    &self.tmp_nstates,
                    t,
                    &self.tmp_nout,
                    &mut self.tmp_nstates2,
                );
                s_i.sub_assign(&self.tmp_nstates2);
            } else if let (Some(sol_mdd), Some(sol_jaa)) = (sol_mdd_opt, sol_jaa_opt) {
                // No output operator; mass matrix with algebraic states.
                let p = self.partition.as_ref().unwrap();
                self.tmp_differential
                    .gather(&self.tmp_nout, &p.differential_indices);
                self.tmp_algebraic
                    .gather(&self.tmp_nout, &p.algebraic_indices);
                sol_mdd.solve_in_place(&mut self.tmp_differential)?;
                self.tmp_differential2.gather(s_i, &p.differential_indices);
                self.tmp_differential2.add_assign(&self.tmp_differential);
                sol_jaa.solve_in_place(&mut self.tmp_algebraic)?;
                // Must be the `ad` block: it maps the algebraic-sized `tmp_algebraic` into
                // the differential-sized `tmp_differential2`. The square `aa` block used
                // previously here does not have those dimensions.
                // NOTE(review): the `out` branch above subtracts the differential term but
                // adds this algebraic term. This branch adds both. Confirm the intended sign
                // of the algebraic correction against the adjoint derivation.
                let rhs_jac_ad = self.rhs_jac_ad.as_ref().unwrap().block.m();
                rhs_jac_ad.gemv(
                    M::T::one(),
                    &self.tmp_algebraic,
                    M::T::one(),
                    &mut self.tmp_differential2,
                );
                self.tmp_differential2.scatter(&p.differential_indices, s_i);
            } else if let Some(sol_mdd) = sol_mdd_opt {
                // No output operator; mass matrix without algebraic states.
                sol_mdd.solve_in_place(&mut self.tmp_nout)?;
                s_i.add_assign(&self.tmp_nout);
            } else {
                // No output operator and no mass matrix.
                s_i.add_assign(&self.tmp_nout);
            }
            // Parameter-sensitivity contribution of the output operator.
            if let Some(out) = out.as_ref() {
                out.sens_transpose_mul_inplace(
                    &self.tmp_nstates,
                    t,
                    &self.tmp_nout,
                    &mut self.tmp_nparams,
                );
                sg_i.sub_assign(&self.tmp_nparams);
            }
        }
        Ok(())
    }
}