use num_traits::{One, Zero};
use std::cell::RefCell;
use crate::{
op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations, ConstantOp, ConstantOpSens,
Matrix, NonLinearOp, NonLinearOpSens, OdeEquations, OdeEquationsImplicitSens, OdeEquationsRef,
OdeSolverProblem, Op, Vector,
};
/// Initial-condition operator of the sensitivity equations, evaluated for one selected
/// parameter index at a time.
///
/// When called it forwards to `eqn.init().sens_mul_inplace(t0, &tmp, y)`.
pub struct SensInit<'a, Eqn: OdeEquations> {
    /// Equations whose initial-condition operator is queried.
    eqn: &'a Eqn,
    /// Parameter index most recently selected with `set_param_index` (0 after construction).
    index: usize,
    /// Vector with one entry per parameter, handed to `sens_mul_inplace` as the multiplicand.
    tmp: Eqn::V,
    /// Time passed to `sens_mul_inplace` on every call.
    t0: Eqn::T,
}
impl<'a, Eqn> SensInit<'a, Eqn>
where
    Eqn: OdeEquationsImplicitSens,
{
    /// Creates the operator for `eqn`, evaluated at time `t0`, with parameter index 0 selected.
    ///
    /// The seed vector `tmp` has one entry per parameter of `eqn.rhs()`. It starts out as the
    /// unit vector for index 0 (when there is at least one parameter), so that calling the
    /// operator before [`Self::set_param_index`] agrees with the stored `index` instead of
    /// multiplying by an all-zero seed.
    pub fn new(eqn: &'a Eqn, t0: Eqn::T) -> Self {
        let index = 0;
        let nparams = eqn.rhs().nparams();
        let mut tmp = Eqn::V::zeros(nparams, eqn.context().clone());
        // Keep `tmp` in sync with `index` from the start; with no parameters there is no
        // entry to write, so the vector stays empty.
        if nparams > 0 {
            tmp.set_index(index, Eqn::T::one());
        }
        Self {
            eqn,
            index,
            tmp,
            t0,
        }
    }

    /// Selects the parameter the operator refers to.
    ///
    /// Zeroes the seed entry of the previously selected parameter and then writes a one at
    /// `index`, so `tmp` has a single non-zero entry afterwards.
    ///
    /// NOTE(review): `index` is presumably required to be smaller than the number of
    /// parameters; how `set_index` reacts to an out-of-range index is not visible here.
    pub fn set_param_index(&mut self, index: usize) {
        // Clear the old entry first so that re-selecting the same index still ends with a one.
        self.tmp.set_index(self.index, Eqn::T::zero());
        self.index = index;
        self.tmp.set_index(self.index, Eqn::T::one());
    }
}
/// Dimension and context queries, all answered by the wrapped equations.
impl<Eqn: OdeEquations> Op for SensInit<'_, Eqn> {
    type T = Eqn::T;
    type V = Eqn::V;
    type M = Eqn::M;
    type C = Eqn::C;

    /// Number of states reported by the wrapped right-hand side.
    fn nstates(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nstates()
    }

    /// The output has the same length as the state vector.
    fn nout(&self) -> usize {
        self.nstates()
    }

    /// Number of parameters reported by the wrapped right-hand side.
    fn nparams(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nparams()
    }

    /// Context of the wrapped equations.
    fn context(&self) -> &Self::C {
        self.eqn.context()
    }
}
impl<Eqn: OdeEquationsImplicitSens> ConstantOp for SensInit<'_, Eqn> {
    /// Writes the result of `eqn.init().sens_mul_inplace(t0, &tmp, y)` into `y`.
    ///
    /// The time argument is ignored; the stored `t0` is used instead.
    fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
        let init = self.eqn.init();
        init.sens_mul_inplace(self.t0, &self.tmp, y);
    }
}
/// Right-hand-side operator of the sensitivity equations.
///
/// Keeps a copy of the state and the `sens` matrix written by `update_state`, plus the
/// currently selected parameter index. All three caches live in `RefCell`s.
pub struct SensRhs<'a, Eqn: OdeEquations> {
    /// Equations whose right-hand side is wrapped.
    eqn: &'a Eqn,
    /// Matrix filled by `eqn.rhs().sens_inplace` in `update_state`; `nstates x nparams`
    /// when allocated, `0 x 0` otherwise.
    sens: RefCell<Eqn::M>,
    /// Copy of the state vector last passed to `update_state`.
    y: RefCell<Eqn::V>,
    /// Column of `sens` that `call_inplace` adds to its output.
    index: RefCell<usize>,
}
impl<'a, Eqn> SensRhs<'a, Eqn>
where
    Eqn: OdeEquationsImplicitSens,
{
    /// Builds the operator for `eqn` with parameter index 0 selected.
    ///
    /// With `allocate == true` the cached matrix is created from the right-hand side's
    /// `sens_sparsity()` with shape `nstates x nparams` and the cached state has `nstates`
    /// entries. With `allocate == false` both caches are created empty (`0 x 0` and length 0).
    pub fn new(eqn: &'a Eqn, allocate: bool) -> Self {
        let ctx = eqn.context();
        let (sens, y) = if allocate {
            let nstates = eqn.rhs().nstates();
            let nparams = eqn.rhs().nparams();
            (
                Eqn::M::new_from_sparsity(
                    nstates,
                    nparams,
                    eqn.rhs().sens_sparsity().map(|s| s.to_owned()),
                    ctx.clone(),
                ),
                <Eqn::V as Vector>::zeros(nstates, ctx.clone()),
            )
        } else {
            (
                <Eqn::M as Matrix>::zeros(0, 0, ctx.clone()),
                <Eqn::V as Vector>::zeros(0, ctx.clone()),
            )
        };
        Self {
            eqn,
            sens: RefCell::new(sens),
            y: RefCell::new(y),
            index: RefCell::new(0),
        }
    }

    /// Refreshes the caches for state `y` at time `t`.
    ///
    /// Calls `eqn.rhs().sens_inplace(y, t, ..)` on the cached matrix and then copies `y`
    /// into the cached state. `_dy` is unused.
    pub fn update_state(&mut self, y: &Eqn::V, _dy: &Eqn::V, t: Eqn::T) {
        let rhs = self.eqn.rhs();
        rhs.sens_inplace(y, t, &mut self.sens.borrow_mut());
        self.y.borrow_mut().copy_from(y);
    }

    /// Selects which column of the cached matrix `call_inplace` adds to its output.
    pub fn set_param_index(&self, index: usize) {
        *self.index.borrow_mut() = index;
    }
}
/// Dimension and context queries, all answered by the wrapped equations.
impl<Eqn: OdeEquations> Op for SensRhs<'_, Eqn> {
    type T = Eqn::T;
    type V = Eqn::V;
    type M = Eqn::M;
    type C = Eqn::C;

    /// Number of states reported by the wrapped right-hand side.
    fn nstates(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nstates()
    }

    /// The output has the same length as the state vector.
    fn nout(&self) -> usize {
        self.nstates()
    }

    /// Number of parameters reported by the wrapped right-hand side.
    fn nparams(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nparams()
    }

    /// Context of the wrapped equations.
    fn context(&self) -> &Self::C {
        self.eqn.context()
    }
}
impl<Eqn: OdeEquationsImplicitSens> NonLinearOp for SensRhs<'_, Eqn> {
    /// Evaluates the sensitivity right-hand side for the selected parameter.
    ///
    /// First applies the wrapped right-hand side's `jac_mul_inplace` at the cached state
    /// with `x` as the multiplied vector, writing into `y`; then adds column `index` of the
    /// cached `sens` matrix to `y` via `add_column_to_vector`.
    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
        // Take all three cache borrows up front, before any work is done.
        let (base_state, dfdp) = (self.y.borrow(), self.sens.borrow());
        let column = *self.index.borrow();

        let rhs = self.eqn.rhs();
        rhs.jac_mul_inplace(&base_state, t, x, y);
        dfdp.add_column_to_vector(column, y);
    }
}
/// Jacobian queries are forwarded to the wrapped right-hand side, always evaluated at the
/// cached state rather than at the `_x` argument.
impl<Eqn: OdeEquationsImplicitSens> NonLinearOpJacobian for SensRhs<'_, Eqn> {
    /// Forwards to the wrapped `jac_mul_inplace` at the cached state; `_x` is ignored.
    fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
        let base_state = self.y.borrow();
        let rhs = self.eqn.rhs();
        rhs.jac_mul_inplace(&base_state, t, v, y);
    }

    /// Forwards to the wrapped `jacobian_inplace` at the cached state; `_x` is ignored.
    fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) {
        let base_state = self.y.borrow();
        let rhs = self.eqn.rhs();
        rhs.jacobian_inplace(&base_state, t, y);
    }

    /// Sparsity pattern reported by the wrapped right-hand side.
    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
        let rhs = self.eqn.rhs();
        rhs.jacobian_sparsity()
    }
}
/// Sensitivity equations built on top of a borrowed set of ODE equations.
///
/// Bundles the sensitivity right-hand side and initial-condition operators together with
/// the optional tolerances used for the sensitivities.
pub struct SensEquations<'a, Eqn: OdeEquations> {
    /// The underlying equations.
    eqn: &'a Eqn,
    /// Sensitivity right-hand-side operator.
    rhs: SensRhs<'a, Eqn>,
    /// Sensitivity initial-condition operator.
    init: SensInit<'a, Eqn>,
    /// Absolute tolerance for the sensitivities, if one was supplied.
    atol: Option<&'a Eqn::V>,
    /// Relative tolerance for the sensitivities, if one was supplied.
    rtol: Option<Eqn::T>,
}
impl<Eqn: OdeEquationsImplicitSens> Clone for SensEquations<'_, Eqn> {
    /// Creates a fresh instance over the same equations and tolerances.
    ///
    /// The cached state is not copied: the new right-hand side is built with
    /// `allocate == false` and the new initial-condition operator is rebuilt from the
    /// stored `t0`.
    fn clone(&self) -> Self {
        let eqn = self.eqn;
        let t0 = self.init.t0;
        Self {
            eqn,
            rhs: SensRhs::new(eqn, false),
            init: SensInit::new(eqn, t0),
            atol: self.atol,
            rtol: self.rtol,
        }
    }
}
impl<Eqn: OdeEquations> std::fmt::Debug for SensEquations<'_, Eqn> {
    /// Prints only the type name; no field is rendered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SensEquations")
    }
}
impl<'a, Eqn> SensEquations<'a, Eqn>
where
    Eqn: OdeEquationsImplicitSens,
{
    /// Builds the sensitivity equations for `problem`.
    ///
    /// The right-hand side is created with its caches allocated, the initial-condition
    /// operator uses `problem.t0`, and the tolerances come from `problem.sens_rtol` and
    /// `problem.sens_atol`.
    pub(crate) fn new(problem: &'a OdeSolverProblem<Eqn>) -> Self {
        let eqn = &problem.eqn;
        Self {
            eqn,
            rhs: SensRhs::new(eqn, true),
            init: SensInit::new(eqn, problem.t0),
            atol: problem.sens_atol.as_ref(),
            rtol: problem.sens_rtol,
        }
    }
}
/// Dimension and context queries, all answered by the wrapped equations.
impl<Eqn: OdeEquations> Op for SensEquations<'_, Eqn> {
    type T = Eqn::T;
    type V = Eqn::V;
    type M = Eqn::M;
    type C = Eqn::C;

    /// Number of states reported by the wrapped right-hand side.
    fn nstates(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nstates()
    }

    /// Number of outputs reported by the wrapped right-hand side.
    fn nout(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nout()
    }

    /// Number of parameters reported by the wrapped right-hand side.
    fn nparams(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nparams()
    }

    /// Context of the wrapped equations.
    fn context(&self) -> &Self::C {
        self.eqn.context()
    }
}
/// Associated operator types of the sensitivity system.
impl<'a, 'b, Eqn: OdeEquationsImplicitSens> OdeEquationsRef<'a> for SensEquations<'b, Eqn> {
    // Operators provided by this module.
    type Rhs = &'a SensRhs<'b, Eqn>;
    type Init = &'a SensInit<'b, Eqn>;

    // Operator types taken over unchanged from the wrapped equations.
    type Mass = <Eqn as OdeEquationsRef<'a>>::Mass;
    type Root = <Eqn as OdeEquationsRef<'a>>::Root;
    type Out = <Eqn as OdeEquationsRef<'a>>::Out;
    type Reset = <Eqn as OdeEquationsRef<'a>>::Reset;
}
/// `OdeEquations` view of the sensitivity system: the right-hand side and initial condition
/// are the operators owned by this struct, the mass operator and the parameter handling are
/// forwarded to the wrapped equations, and no root or output operator is exposed.
impl<'a, Eqn> OdeEquations for SensEquations<'a, Eqn>
where
Eqn: OdeEquationsImplicitSens,
{
/// Returns the sensitivity right-hand-side operator owned by this struct.
fn rhs(&self) -> &SensRhs<'a, Eqn> {
&self.rhs
}
/// Returns whatever mass operator the wrapped equations report.
fn mass(&self) -> Option<<Eqn as OdeEquationsRef<'_>>::Mass> {
self.eqn.mass()
}
/// Always `None`: no root operator is exposed for the sensitivity system.
fn root(&self) -> Option<<Eqn as OdeEquationsRef<'_>>::Root> {
None
}
/// Returns the sensitivity initial-condition operator owned by this struct.
fn init(&self) -> &SensInit<'a, Eqn> {
&self.init
}
/// Always `None`: no output operator is exposed for the sensitivity system.
fn out(&self) -> Option<<Eqn as OdeEquationsRef<'_>>::Out> {
None
}
/// Forwards `p` to `set_params` on the wrapped equations.
// NOTE(review): `self.eqn` is a shared reference; confirm which `set_params` impl this
// call resolves to.
fn set_params(&mut self, p: &Self::V) {
self.eqn.set_params(p);
}
/// Forwards `m` to `set_model_index` on the wrapped equations.
// NOTE(review): same shared-reference caveat as `set_params` — confirm the resolved impl.
fn set_model_index(&mut self, m: usize) {
self.eqn.set_model_index(m);
}
/// Forwards to `get_params` on the wrapped equations, passing `p` through.
fn get_params(&self, p: &mut Self::V) {
self.eqn.get_params(p);
}
}
impl<Eqn> AugmentedOdeEquations<Eqn> for SensEquations<'_, Eqn>
where
    Eqn: OdeEquationsImplicitSens,
{
    /// True only when both a relative and an absolute tolerance are stored.
    fn include_in_error_control(&self) -> bool {
        matches!((&self.rtol, &self.atol), (Some(_), Some(_)))
    }

    /// Always false.
    fn include_out_in_error_control(&self) -> bool {
        false
    }

    /// Stored relative tolerance, if any.
    fn rtol(&self) -> Option<Eqn::T> {
        self.rtol
    }

    /// Stored absolute tolerance, if any.
    fn atol(&self) -> Option<&Eqn::V> {
        self.atol
    }

    /// Always `None`.
    fn out_atol(&self) -> Option<&Eqn::V> {
        None
    }

    /// Always `None`.
    fn out_rtol(&self) -> Option<Eqn::T> {
        None
    }

    /// Number of selectable indices: one per parameter.
    fn max_index(&self) -> usize {
        self.nparams()
    }

    /// Refreshes the cached state of the right-hand-side operator.
    fn update_rhs_out_state(&mut self, y: &Eqn::V, dy: &Eqn::V, t: Eqn::T) {
        let rhs = &mut self.rhs;
        rhs.update_state(y, dy, t);
    }

    /// Selects parameter `index` on both the right-hand-side and the initial-condition
    /// operator.
    fn set_index(&mut self, index: usize) {
        self.rhs.set_param_index(index);
        self.init.set_param_index(index);
    }

    /// Always true.
    fn integrate_main_eqn(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat,
        ode_equations::test_models::{
            exponential_decay::exponential_decay_problem_sens,
            exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem_sens,
            robertson::robertson_sens,
        },
        AugmentedOdeEquations, DenseMatrix, MatrixCommon, NalgebraVec, NonLinearOp, RkState,
        SensEquations, Vector,
    };

    type Mcpu = NalgebraMat<f64>;
    type Vcpu = NalgebraVec<f64>;

    /// Exponential decay model: checks the cached `sens` matrix and one evaluation of the
    /// sensitivity right-hand side for parameter 0.
    #[test]
    fn test_rhs_exponential() {
        let (problem, _soln) = exponential_decay_problem_sens::<Mcpu>(false);
        let ctx = *problem.context();
        let mut sens_eqn = SensEquations::new(&problem);
        let state = RkState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 1.0], ctx),
            dy: Vcpu::from_vec(vec![1.0, 1.0], ctx),
            g: Vcpu::zeros(0, ctx),
            dg: Vcpu::zeros(0, ctx),
            sg: Vec::new(),
            dsg: Vec::new(),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };

        sens_eqn.update_rhs_out_state(&state.y, &state.dy, state.t);

        let dfdp = sens_eqn.rhs.sens.borrow();
        assert_eq!(dfdp.nrows(), 2);
        assert_eq!(dfdp.ncols(), 2);
        assert_eq!(dfdp.get_index(0, 0), -1.0);
        assert_eq!(dfdp.get_index(1, 0), -1.0);

        sens_eqn.rhs.set_param_index(0);
        let seed = Vcpu::from_vec(vec![1.0, 2.0], ctx);
        let actual = sens_eqn.rhs.call(&seed, state.t);
        let expected = Vcpu::from_vec(vec![-1.1, -1.2], ctx);
        actual.assert_eq_st(&expected, 1e-10);
    }

    /// Exponential decay with an algebraic equation: additionally checks the cached state
    /// and the selected parameter index.
    #[test]
    fn test_rhs_exponential_algebraic() {
        let (problem, _soln) = exponential_decay_with_algebraic_problem_sens::<Mcpu>();
        let ctx = *problem.context();
        let mut sens_eqn = SensEquations::new(&problem);
        let state = RkState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 1.0, 1.0], ctx),
            dy: Vcpu::from_vec(vec![1.0, 1.0, 1.0], ctx),
            g: Vcpu::zeros(0, ctx),
            dg: Vcpu::zeros(0, ctx),
            sg: Vec::new(),
            dsg: Vec::new(),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };

        sens_eqn.update_rhs_out_state(&state.y, &state.dy, state.t);

        let dfdp = sens_eqn.rhs.sens.borrow();
        assert_eq!(dfdp.nrows(), 3);
        assert_eq!(dfdp.ncols(), 1);
        assert_eq!(dfdp.get_index(0, 0), -1.0);
        assert_eq!(dfdp.get_index(1, 0), -1.0);
        assert_eq!(dfdp.get_index(2, 0), 0.0);

        sens_eqn.rhs.y.borrow().assert_eq_st(&state.y, 1e-10);

        sens_eqn.rhs.set_param_index(0);
        assert_eq!(*sens_eqn.rhs.index.borrow(), 0);

        let seed = Vcpu::from_vec(vec![1.0, 1.0, 1.0], ctx);
        let actual = sens_eqn.rhs.call(&seed, state.t);
        let expected = Vcpu::from_vec(vec![-1.1, -1.1, 0.0], ctx);
        actual.assert_eq_st(&expected, 1e-10);
    }

    /// Robertson model: checks every entry of the cached `sens` matrix against expressions
    /// in the state components.
    #[test]
    fn test_rhs_robertson() {
        let (problem, _soln) = robertson_sens::<Mcpu>();
        let ctx = *problem.context();
        let mut sens_eqn = SensEquations::new(&problem);
        let state = RkState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 2.0, 3.0], ctx),
            dy: Vcpu::from_vec(vec![1.0, 1.0, 1.0], ctx),
            g: Vcpu::zeros(0, ctx),
            dg: Vcpu::zeros(0, ctx),
            sg: Vec::new(),
            dsg: Vec::new(),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };

        sens_eqn.update_rhs_out_state(&state.y, &state.dy, state.t);

        let (y0, y1, y2) = (state.y[0], state.y[1], state.y[2]);
        let dfdp = sens_eqn.rhs.sens.borrow();
        assert_eq!(dfdp.nrows(), 3);
        assert_eq!(dfdp.ncols(), 3);
        // Row 0
        assert_eq!(dfdp.get_index(0, 0), -y0);
        assert_eq!(dfdp.get_index(0, 1), y1 * y2);
        assert_eq!(dfdp.get_index(0, 2), 0.0);
        // Row 1
        assert_eq!(dfdp.get_index(1, 0), y0);
        assert_eq!(dfdp.get_index(1, 1), -y1 * y2);
        assert_eq!(dfdp.get_index(1, 2), -y1 * y1);
        // Row 2
        assert_eq!(dfdp.get_index(2, 0), 0.0);
        assert_eq!(dfdp.get_index(2, 1), 0.0);
        assert_eq!(dfdp.get_index(2, 2), 0.0);
    }
}