use crate::{
error::DiffsolError,
error::OdeSolverError,
ode_solver::solution::{Solution, SolutionMode},
ode_solver_error, AugmentedOdeSolverMethod, Context, DefaultDenseMatrix, DenseMatrix,
MatrixCommon, NonLinearOp, NonLinearOpJacobian, NonLinearOpSens, OdeEquationsImplicitSens,
OdeSolverStopReason, Op, SensEquations, Vector, VectorViewMut,
};
use num_traits::{One, Zero};
use std::ops::AddAssign;
/// Dense-output helpers for solvers whose state is augmented with forward sensitivity
/// equations ([`SensEquations`]).
///
/// Both methods have default implementations built on the
/// [`AugmentedOdeSolverMethod`] supertrait.
pub trait SensitivitiesOdeSolverMethod<'a, Eqn>:
    AugmentedOdeSolverMethod<'a, Eqn, SensEquations<'a, Eqn>>
where
    Eqn: OdeEquationsImplicitSens + 'a,
{
    /// Integrates the problem, writing outputs and parameter sensitivities into `soln`.
    ///
    /// `soln` must be in dense mode (`SolutionMode::Tevals`). Writing resumes at the column
    /// stored in that mode. On success the mode is updated to the next unwritten column and
    /// `soln.stop_reason` is set. `Solution::ensure_sens_allocation` is called first to size
    /// the sensitivity storage.
    ///
    /// The solver is taken by value and handed back on success.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the problem has `integrate_out` set;
    /// - `soln` is in `SolutionMode::Tfinal` mode (i.e. not created with `Solution::new_dense`);
    /// - sensitivity allocation fails;
    /// - the solver fails while stepping or interpolating.
    fn solve_soln_sensitivities(mut self, soln: &mut Solution<Eqn::V>) -> Result<Self, DiffsolError>
    where
        Eqn::V: DefaultDenseMatrix,
        Self: Sized,
    {
        if self.problem().integrate_out {
            return Err(ode_solver_error!(
                Other,
                "Cannot integrate out when solving for sensitivities"
            ));
        }
        let start_col = match soln.mode {
            SolutionMode::Tevals(start_col) => start_col,
            SolutionMode::Tfinal(_) => {
                return Err(ode_solver_error!(
                    Other,
                    "solve_soln_sensitivities requires Solution::new_dense"
                ));
            }
        };
        let ctx = self.problem().context().clone();
        // Rows of the output: the output function's size if present, otherwise the RHS size.
        let nrows = self
            .problem()
            .eqn
            .out()
            .map(|out| out.nout())
            .unwrap_or_else(|| self.problem().eqn.rhs().nout());
        let nstates = self.problem().eqn.rhs().nstates();
        let nparams = self.problem().eqn.rhs().nparams();
        let nout = self.problem().eqn.out().map(|out| out.nout()).unwrap_or(0);
        let nout_params = self
            .problem()
            .eqn
            .out()
            .map(|out| out.nparams())
            .unwrap_or(0);
        soln.ensure_sens_allocation(&ctx, nrows, nout, nout_params, nstates, nparams)?;
        let (stop_reason, col) = solve_dense_sensitivities(
            &mut soln.ys,
            &mut soln.y_sens,
            &soln.ts,
            &mut soln.tmp_nout,
            &mut soln.tmp_nparams,
            &mut soln.tmp_nstates,
            &mut soln.tmp_nsens,
            &mut self,
            start_col,
        )?;
        soln.stop_reason = Some(stop_reason);
        soln.mode = SolutionMode::Tevals(col);
        Ok(self)
    }

    /// Integrates the problem to the last time in `t_eval`, returning dense outputs and
    /// parameter sensitivities at every `t_eval` point.
    ///
    /// Returns `(ys, y_sens, stop_reason)`:
    /// - `ys` has one row per output and one column per `t_eval` entry. The row count is the
    ///   output function's `nout()` if the problem has one, otherwise the RHS `nout()`.
    /// - `y_sens` holds one such matrix per parameter.
    ///
    /// If a root is found before all of `t_eval` is consumed, the column after the last
    /// written `t_eval` point holds the solution at the root time. All matrices are then
    /// truncated to end at that column.
    ///
    /// # Errors
    /// - Returns an error if the problem has `integrate_out` set.
    /// - Returns `InvalidTEval` if `t_eval` is empty, is not non-decreasing, or contains a
    ///   time before the solver's current time.
    /// - Propagates solver errors.
    #[allow(clippy::type_complexity)]
    fn solve_dense_sensitivities(
        &mut self,
        t_eval: &[Eqn::T],
    ) -> Result<
        (
            <Eqn::V as DefaultDenseMatrix>::M,
            Vec<<Eqn::V as DefaultDenseMatrix>::M>,
            OdeSolverStopReason<Eqn::T>,
        ),
        DiffsolError,
    >
    where
        Eqn: OdeEquationsImplicitSens,
        Eqn::V: DefaultDenseMatrix,
        Self: Sized,
    {
        if self.problem().integrate_out {
            return Err(ode_solver_error!(
                Other,
                "Cannot integrate out when solving for sensitivities"
            ));
        }
        // Validate before allocating. The empty case must be rejected explicitly: the solve
        // below sets the stop time from the last entry of `t_eval`. Every element (including
        // the last, and the only one of a single-element slice) is compared against `t0`.
        let t0 = self.state().t;
        let before_start = t_eval.iter().any(|&t| t < t0);
        let decreasing = t_eval.windows(2).any(|w| w[0] > w[1]);
        if t_eval.is_empty() || before_start || decreasing {
            return Err(ode_solver_error!(InvalidTEval));
        }
        let nrows = if let Some(out) = self.problem().eqn.out() {
            out.nout()
        } else {
            self.problem().eqn.rhs().nout()
        };
        let nstates = self.problem().eqn.rhs().nstates();
        let nparams = self.problem().eqn.rhs().nparams();
        let ctx = self.problem().context().clone();
        let mut ret = ctx.dense_mat_zeros::<Eqn::V>(nrows, t_eval.len());
        let mut ret_sens = vec![ctx.dense_mat_zeros::<Eqn::V>(nrows, t_eval.len()); nparams];
        // Scratch buffers reused for every written column.
        let mut tmp_nout = Eqn::V::zeros(
            self.problem().eqn.out().map(|out| out.nout()).unwrap_or(0),
            ctx.clone(),
        );
        let mut tmp_nparams = Eqn::V::zeros(
            self.problem()
                .eqn
                .out()
                .map(|out| out.nparams())
                .unwrap_or(0),
            ctx.clone(),
        );
        let mut tmp_nstates = Eqn::V::zeros(nstates, ctx.clone());
        let mut tmp_nsens = vec![Eqn::V::zeros(nstates, ctx); nparams];
        let (stop_reason, col) = solve_dense_sensitivities(
            &mut ret,
            &mut ret_sens,
            t_eval,
            &mut tmp_nout,
            &mut tmp_nparams,
            &mut tmp_nstates,
            &mut tmp_nsens,
            self,
            0,
        )?;
        if let OdeSolverStopReason::RootFound(_, _) = stop_reason {
            if col < t_eval.len() {
                // The solver state was moved back to the root time. Record the root as an
                // extra column and drop the unused trailing columns.
                let t = self.state().t;
                dense_write_out_sensitivities(
                    self,
                    &mut ret,
                    &mut ret_sens,
                    t,
                    col,
                    &mut tmp_nout,
                    &mut tmp_nparams,
                    &mut tmp_nstates,
                    &mut tmp_nsens,
                )?;
                if col + 1 < ret.ncols() {
                    ret.resize_cols(col + 1);
                    for rs in &mut ret_sens {
                        rs.resize_cols(col + 1);
                    }
                }
            }
        }
        Ok((ret, ret_sens, stop_reason))
    }
}
#[allow(clippy::too_many_arguments)]
fn solve_dense_sensitivities<'a, Eqn, S>(
ret: &mut <Eqn::V as DefaultDenseMatrix>::M,
ret_sens: &mut [<Eqn::V as DefaultDenseMatrix>::M],
t_eval: &[Eqn::T],
tmp_nout: &mut Eqn::V,
tmp_nparams: &mut Eqn::V,
tmp_nstates: &mut Eqn::V,
tmp_nsens: &mut [Eqn::V],
s: &mut S,
start_col: usize,
) -> Result<(OdeSolverStopReason<Eqn::T>, usize), DiffsolError>
where
Eqn: OdeEquationsImplicitSens + 'a,
Eqn::V: DefaultDenseMatrix,
S: SensitivitiesOdeSolverMethod<'a, Eqn>,
{
s.set_stop_time(t_eval[t_eval.len() - 1])?;
let mut stop_reason: OdeSolverStopReason<Eqn::T>;
let mut col = start_col;
loop {
stop_reason = s.step()?;
let t_current = if let OdeSolverStopReason::RootFound(t, _) = stop_reason {
t
} else {
s.state().t
};
while col < t_eval.len() && t_eval[col] <= t_current {
dense_write_out_sensitivities(
s,
ret,
ret_sens,
t_eval[col],
col,
tmp_nout,
tmp_nparams,
tmp_nstates,
tmp_nsens,
)?;
col += 1;
}
match stop_reason {
OdeSolverStopReason::InternalTimestep => {}
OdeSolverStopReason::TstopReached => {
assert!(
col == t_eval.len(),
"Solver reached stop time before consuming all t_eval points, this should not happen"
);
break;
}
OdeSolverStopReason::RootFound(t_root, _) => {
s.state_mut_back(t_root)?;
break;
}
}
}
Ok((stop_reason, col))
}
/// Writes the solution and its parameter sensitivities at time `t` into column `col` of
/// `ret` and of every matrix in `ret_sens`.
///
/// The state is interpolated into `tmp_nstates` and the sensitivities into `tmp_nsens`.
///
/// Without an output function, the state and each sensitivity vector are copied directly.
///
/// With an output function `out`:
/// - `ret` receives `out.call_inplace(y, t)`;
/// - sensitivity `j` receives `out.jac_mul_inplace` applied to `s_j`, plus
///   `out.sens_mul_inplace` applied to a unit vector built in `tmp_nparams`.
///
/// `tmp_nparams` is used as that unit vector: entry `j` is set to one and then reset to
/// zero, so it is presumably expected to be all zeros on entry. `tmp_nout` is scratch space
/// for the output evaluations.
#[allow(clippy::too_many_arguments)]
fn dense_write_out_sensitivities<'a, Eqn, S>(
    s: &S,
    ret: &mut <Eqn::V as DefaultDenseMatrix>::M,
    ret_sens: &mut [<Eqn::V as DefaultDenseMatrix>::M],
    t: Eqn::T,
    col: usize,
    tmp_nout: &mut Eqn::V,
    tmp_nparams: &mut Eqn::V,
    tmp_nstates: &mut Eqn::V,
    tmp_nsens: &mut [Eqn::V],
) -> Result<(), DiffsolError>
where
    Eqn: OdeEquationsImplicitSens + 'a,
    Eqn::V: DefaultDenseMatrix,
    S: SensitivitiesOdeSolverMethod<'a, Eqn>,
{
    s.interpolate_inplace(t, tmp_nstates)?;
    s.interpolate_sens_inplace(t, tmp_nsens)?;
    match s.problem().eqn.out() {
        None => {
            // No output map: the raw state and sensitivities are the outputs.
            ret.column_mut(col).copy_from(tmp_nstates);
            for (idx, sens) in tmp_nsens.iter().enumerate() {
                ret_sens[idx].column_mut(col).copy_from(sens);
            }
        }
        Some(out) => {
            out.call_inplace(tmp_nstates, t, tmp_nout);
            ret.column_mut(col).copy_from(tmp_nout);
            for idx in 0..tmp_nsens.len() {
                let sens = &tmp_nsens[idx];
                let mut dest = ret_sens[idx].column_mut(col);
                // Turn tmp_nparams into the idx-th unit vector for this parameter.
                tmp_nparams.set_index(idx, Eqn::T::one());
                // State-dependent part: Jacobian-vector product with the state sensitivity.
                out.jac_mul_inplace(tmp_nstates, t, sens, tmp_nout);
                dest.copy_from(&*tmp_nout);
                // Direct parameter dependence of the output function.
                out.sens_mul_inplace(tmp_nstates, t, tmp_nparams, tmp_nout);
                dest.add_assign(&*tmp_nout);
                tmp_nparams.set_index(idx, Eqn::T::zero());
            }
        }
    }
    Ok(())
}