use num_traits::{One, Zero};
use std::{cell::RefCell, rc::Rc};
use crate::{ConstantOp, LinearOp, Matrix, MatrixSparsity, NonLinearOp, OdeEquations, Op, Vector};
/// Constant operator giving the initial condition of one parameter's
/// sensitivity vector for the wrapped equations `Eqn`.
pub struct SensInit<Eqn: OdeEquations> {
    /// The equations whose initial condition is being differentiated.
    eqn: Rc<Eqn>,
    /// Cached `nstates x nparams` matrix filled by `init().sens_inplace`.
    init_sens: RefCell<Eqn::M>,
    /// Column of `init_sens` (parameter index) returned by `call_inplace`.
    index: RefCell<usize>,
}
impl<Eqn: OdeEquations> SensInit<Eqn> {
    /// Creates the operator for `eqn`.
    ///
    /// The sensitivity buffer is an `nstates x nparams` matrix that uses the
    /// sparsity reported by `eqn.init().sparsity_sens()`. Parameter index 0 is
    /// selected initially.
    pub fn new(eqn: &Rc<Eqn>) -> Self {
        let rhs = eqn.rhs();
        let (rows, cols) = (rhs.nstates(), rhs.nparams());
        let buffer = Eqn::M::new_from_sparsity(rows, cols, eqn.init().sparsity_sens());
        Self {
            eqn: Rc::clone(eqn),
            init_sens: RefCell::new(buffer),
            index: RefCell::new(0),
        }
    }

    /// Recomputes the cached initial-condition sensitivities at time `t`.
    pub fn update_state(&self, t: Eqn::T) {
        let init = self.eqn.init();
        init.sens_inplace(t, &mut self.init_sens.borrow_mut());
    }

    /// Selects which parameter's column `call_inplace` will return.
    pub fn set_param_index(&self, index: usize) {
        *self.index.borrow_mut() = index;
    }
}
impl<Eqn> Op for SensInit<Eqn>
where
Eqn: OdeEquations,
{
type T = Eqn::T;
type V = Eqn::V;
type M = Eqn::M;
fn nstates(&self) -> usize {
self.eqn.rhs().nstates()
}
fn nout(&self) -> usize {
self.eqn.rhs().nstates()
}
fn nparams(&self) -> usize {
self.eqn.rhs().nparams()
}
}
impl<Eqn: OdeEquations> ConstantOp for SensInit<Eqn> {
    /// Writes column `index` of the cached sensitivity matrix into `y`.
    ///
    /// `y` is zeroed first. `_t` is unused: the matrix is only refreshed by
    /// `SensInit::update_state`.
    fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
        y.fill(Eqn::T::zero());
        let column = *self.index.borrow();
        self.init_sens.borrow().add_column_to_vector(column, y);
    }
}
/// Right-hand side of the forward sensitivity equations for one parameter
/// of the wrapped equations `Eqn`.
pub struct SensRhs<Eqn: OdeEquations> {
    /// The equations being differentiated.
    eqn: Rc<Eqn>,
    /// Combined `nstates x nparams` parameter-sensitivity matrix used by `call_inplace`.
    sens: RefCell<Eqn::M>,
    /// Rhs parameter sensitivities; only allocated when `eqn` has a mass matrix.
    rhs_sens: Option<RefCell<Eqn::M>>,
    /// Mass parameter sensitivities; only allocated when `eqn` has a mass matrix.
    mass_sens: Option<RefCell<Eqn::M>>,
    /// Copy of the state passed to the last `update_state` call.
    y: RefCell<Eqn::V>,
    /// Column of `sens` (parameter index) added by `call_inplace`.
    index: RefCell<usize>,
}
impl<Eqn> SensRhs<Eqn>
where
    Eqn: OdeEquations,
{
    /// Builds the sensitivity right-hand side operator for `eqn`.
    ///
    /// All sensitivity buffers are `nstates x nparams` matrices.
    /// - Without a mass matrix, the rhs parameter sensitivities are written
    ///   directly into `sens`.
    /// - With a mass matrix, separate `rhs_sens` and `mass_sens` buffers are
    ///   allocated. `sens` uses the union of their sparsity patterns, or no
    ///   sparsity pattern if either one is unavailable.
    ///
    /// The stored state starts as zeros and parameter index 0 is selected.
    ///
    /// # Panics
    ///
    /// Panics if taking the union of the two sparsity patterns returns an error.
    pub fn new(eqn: &Rc<Eqn>) -> Self {
        let nstates = eqn.rhs().nstates();
        let nparams = eqn.rhs().nparams();
        let rhs_sens = Eqn::M::new_from_sparsity(nstates, nparams, eqn.rhs().sparsity_sens());
        let y = RefCell::new(<Eqn::V as Vector>::zeros(nstates));
        let index = RefCell::new(0);
        if let Some(mass) = eqn.mass() {
            let mass_sens = Eqn::M::new_from_sparsity(nstates, nparams, mass.sparsity_sens());
            // `sens` receives entries from both contributions, so it needs a
            // pattern covering both (or a dense layout when either is unknown).
            let sens = match (rhs_sens.sparsity(), mass_sens.sparsity()) {
                (Some(rhs_sparsity), Some(mass_sparsity)) => {
                    let sparsity = rhs_sparsity.union(mass_sparsity).expect(
                        "failed to take the union of the rhs and mass sensitivity sparsity patterns",
                    );
                    Eqn::M::new_from_sparsity(nstates, nparams, Some(&sparsity))
                }
                _ => Eqn::M::new_from_sparsity(nstates, nparams, None),
            };
            Self {
                eqn: eqn.clone(),
                sens: RefCell::new(sens),
                rhs_sens: Some(RefCell::new(rhs_sens)),
                mass_sens: Some(RefCell::new(mass_sens)),
                y,
                index,
            }
        } else {
            Self {
                eqn: eqn.clone(),
                sens: RefCell::new(rhs_sens),
                rhs_sens: None,
                mass_sens: None,
                y,
                index,
            }
        }
    }

    /// Recomputes the parameter sensitivities at state `y`, derivative `dy`
    /// and time `t`. A copy of `y` is stored for later Jacobian evaluations.
    ///
    /// With a mass matrix, the rhs sensitivities (evaluated at `y`) and the
    /// mass sensitivities (evaluated at `dy`) are combined into `sens` via
    /// `scale_add_and_assign(&rhs_sens, -1, &mass_sens)`. This is presumably
    /// `rhs_sens - mass_sens`; confirm against the `Matrix` trait.
    ///
    /// # Panics
    ///
    /// Panics if a `RefCell` buffer is already borrowed.
    pub fn update_state(&self, y: &Eqn::V, dy: &Eqn::V, t: Eqn::T) {
        match (&self.rhs_sens, &self.mass_sens) {
            (Some(rhs_sens), Some(mass_sens)) => {
                let mut rhs_sens = rhs_sens.borrow_mut();
                let mut mass_sens = mass_sens.borrow_mut();
                let mut sens = self.sens.borrow_mut();
                self.eqn.rhs().sens_inplace(y, t, &mut rhs_sens);
                self.eqn
                    .mass()
                    .expect("mass sensitivity buffers are only allocated when the equations have a mass matrix")
                    .sens_inplace(dy, t, &mut mass_sens);
                sens.scale_add_and_assign(&rhs_sens, -Eqn::T::one(), &mass_sens);
            }
            _ => {
                // No mass matrix: `sens` holds the rhs sensitivities directly.
                let mut sens = self.sens.borrow_mut();
                self.eqn.rhs().sens_inplace(y, t, &mut sens);
            }
        }
        self.y.borrow_mut().copy_from(y);
    }

    /// Selects which parameter's column of `sens` `call_inplace` will add.
    pub fn set_param_index(&self, index: usize) {
        self.index.replace(index);
    }
}
impl<Eqn> Op for SensRhs<Eqn>
where
Eqn: OdeEquations,
{
type T = Eqn::T;
type V = Eqn::V;
type M = Eqn::M;
fn nstates(&self) -> usize {
self.eqn.rhs().nstates()
}
fn nout(&self) -> usize {
self.eqn.rhs().nstates()
}
fn nparams(&self) -> usize {
self.eqn.rhs().nparams()
}
}
impl<Eqn: OdeEquations> NonLinearOp for SensRhs<Eqn> {
    /// Computes the wrapped rhs Jacobian-vector product at the stored state
    /// applied to `x`, then adds column `index` of `sens` to the result in `y`.
    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
        let column = *self.index.borrow();
        let frozen_state = self.y.borrow();
        self.eqn.rhs().jac_mul_inplace(&frozen_state, t, x, y);
        self.sens.borrow().add_column_to_vector(column, y);
    }

    /// Jacobian-vector product of the wrapped rhs at the stored state;
    /// `_x` is ignored.
    fn jac_mul_inplace(&self, _x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
        self.eqn.rhs().jac_mul_inplace(&self.y.borrow(), t, v, y);
    }

    /// Jacobian of the wrapped rhs at the stored state; `_x` is ignored.
    fn jacobian_inplace(&self, _x: &Self::V, t: Self::T, y: &mut Self::M) {
        self.eqn.rhs().jacobian_inplace(&self.y.borrow(), t, y);
    }
}
/// Forward sensitivity equations built on top of the wrapped equations `Eqn`.
pub struct SensEquations<Eqn: OdeEquations> {
    /// The wrapped equations; they supply the mass and root operators.
    eqn: Rc<Eqn>,
    /// Sensitivity right-hand side operator.
    rhs: Rc<SensRhs<Eqn>>,
    /// Sensitivity initial-condition operator.
    init: Rc<SensInit<Eqn>>,
}
impl<Eqn: OdeEquations> SensEquations<Eqn> {
    /// Wraps `eqn`, creating its sensitivity rhs and initial-condition operators.
    pub fn new(eqn: &Rc<Eqn>) -> Self {
        Self {
            rhs: Rc::new(SensRhs::new(eqn)),
            init: Rc::new(SensInit::new(eqn)),
            eqn: Rc::clone(eqn),
        }
    }
}
impl<Eqn: OdeEquations> Op for SensEquations<Eqn> {
    type T = Eqn::T;
    type V = Eqn::V;
    type M = Eqn::M;

    // All dimensions are read from the wrapped equations' rhs operator.

    fn nstates(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nstates()
    }

    fn nout(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nout()
    }

    fn nparams(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nparams()
    }
}
impl<Eqn: OdeEquations> OdeEquations for SensEquations<Eqn> {
    type T = Eqn::T;
    type V = Eqn::V;
    type M = Eqn::M;
    // Operators specific to the sensitivity system.
    type Rhs = SensRhs<Eqn>;
    type Init = SensInit<Eqn>;
    // Operators forwarded unchanged from the wrapped equations.
    type Mass = Eqn::Mass;
    type Root = Eqn::Root;

    fn rhs(&self) -> &Rc<Self::Rhs> {
        &self.rhs
    }

    fn init(&self) -> &Rc<Self::Init> {
        &self.init
    }

    fn mass(&self) -> Option<&Rc<Self::Mass>> {
        self.eqn.mass()
    }

    fn root(&self) -> Option<&Rc<Self::Root>> {
        self.eqn.root()
    }

    /// Always panics: parameters cannot be changed through the sensitivity system.
    fn set_params(&mut self, _p: Self::V) {
        panic!("Not implemented for SensEquations");
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        ode_solver::test_models::{
            exponential_decay::exponential_decay_problem_sens,
            exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem_sens,
            robertson_sens::robertson_sens,
        },
        NonLinearOp, OdeSolverState, SensEquations, Vector,
    };
    type Mcpu = nalgebra::DMatrix<f64>;
    type Vcpu = nalgebra::DVector<f64>;

    /// Exponential decay model: checks `sens` after `update_state` and one
    /// evaluation of the sensitivity rhs for parameter 0.
    #[test]
    fn test_rhs_exponential() {
        let (problem, _soln) = exponential_decay_problem_sens::<Mcpu>(false);
        let sens_eqn = SensEquations::new(&problem.eqn);
        let state = OdeSolverState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 1.0]),
            dy: Vcpu::from_vec(vec![1.0, 1.0]),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };
        sens_eqn.rhs.update_state(&state.y, &state.dy, state.t);
        let sens = sens_eqn.rhs.sens.borrow();
        assert_eq!(sens.nrows(), 2);
        assert_eq!(sens.ncols(), 1);
        assert_eq!(sens[(0, 0)], -1.0);
        assert_eq!(sens[(1, 0)], -1.0);
        sens_eqn.rhs.set_param_index(0);
        let s = Vcpu::from_vec(vec![1.0, 2.0]);
        let f = sens_eqn.rhs.call(&s, state.t);
        let f_expect = Vcpu::from_vec(vec![-1.1, -1.2]);
        f.assert_eq_st(&f_expect, 1e-10);
    }

    /// Exponential decay with an algebraic component: checks `sens`, the
    /// stored state, the selected parameter index and one sensitivity rhs
    /// evaluation.
    #[test]
    fn test_rhs_exponential_algebraic() {
        let (problem, _soln) = exponential_decay_with_algebraic_problem_sens::<Mcpu>(false);
        let sens_eqn = SensEquations::new(&problem.eqn);
        let state = OdeSolverState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 1.0, 1.0]),
            dy: Vcpu::from_vec(vec![1.0, 1.0, 1.0]),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };
        sens_eqn.rhs.update_state(&state.y, &state.dy, state.t);
        let sens = sens_eqn.rhs.sens.borrow();
        assert_eq!(sens.nrows(), 3);
        assert_eq!(sens.ncols(), 1);
        assert_eq!(sens[(0, 0)], -1.0);
        assert_eq!(sens[(1, 0)], -1.0);
        assert_eq!(sens[(2, 0)], 0.0);
        sens_eqn.rhs.y.borrow().assert_eq_st(&state.y, 1e-10);
        sens_eqn.rhs.set_param_index(0);
        // `usize` is `Copy`: dereference the borrow instead of cloning it.
        assert_eq!(*sens_eqn.rhs.index.borrow(), 0);
        let s = Vcpu::from_vec(vec![1.0, 1.0, 1.0]);
        let f = sens_eqn.rhs.call(&s, state.t);
        let f_expect = Vcpu::from_vec(vec![-1.1, -1.1, 0.0]);
        f.assert_eq_st(&f_expect, 1e-10);
    }

    /// Robertson model: checks every entry of the 3x3 parameter-sensitivity
    /// matrix after `update_state`.
    #[test]
    fn test_rhs_robertson() {
        let (problem, _soln) = robertson_sens::<Mcpu>(false);
        let sens_eqn = SensEquations::new(&problem.eqn);
        let state = OdeSolverState {
            t: 0.0,
            y: Vcpu::from_vec(vec![1.0, 2.0, 3.0]),
            dy: Vcpu::from_vec(vec![1.0, 1.0, 1.0]),
            s: Vec::new(),
            ds: Vec::new(),
            h: 0.0,
        };
        sens_eqn.rhs.update_state(&state.y, &state.dy, state.t);
        let sens = sens_eqn.rhs.sens.borrow();
        assert_eq!(sens.nrows(), 3);
        assert_eq!(sens.ncols(), 3);
        assert_eq!(sens[(0, 0)], -state.y[0]);
        assert_eq!(sens[(0, 1)], state.y[1] * state.y[2]);
        assert_eq!(sens[(0, 2)], 0.0);
        assert_eq!(sens[(1, 0)], state.y[0]);
        assert_eq!(sens[(1, 1)], -state.y[1] * state.y[2]);
        assert_eq!(sens[(1, 2)], -state.y[1] * state.y[1]);
        assert_eq!(sens[(2, 0)], 0.0);
        assert_eq!(sens[(2, 1)], 0.0);
        assert_eq!(sens[(2, 2)], 0.0);
    }
}