use crate::{
    scale, LinearOp, Matrix, MatrixSparsityRef, NonLinearOpJacobian, OdeEquationsImplicit, Vector,
    VectorIndex,
};
use num_traits::{One, Zero};
use std::cell::RefCell;
use super::{NonLinearOp, Op};
/// Nonlinear operator used to compute consistent initial conditions for a DAE.
///
/// The unknown of this operator is the vector `[du, v]`: derivatives of the
/// differential states, with the algebraic states `v` stored at the positions
/// listed in `algebraic_indices`.
pub struct InitOp<'eqn, Eqn>
where
    Eqn: OdeEquationsImplicit,
{
    /// Equations whose initial state is being made consistent.
    eqn: &'eqn Eqn,
    /// Jacobian of the residual with respect to `[du, v]`, assembled once at construction.
    jac: Eqn::M,
    /// Working state; its algebraic entries are overwritten on every residual evaluation.
    pub y0: RefCell<Eqn::V>,
    /// Positions of the algebraic states within the state vector.
    pub algebraic_indices: <Eqn::V as Vector>::Index,
    /// Negated differential block of the mass matrix, padded with zero blocks.
    neg_mass: Eqn::M,
}
impl<'a, Eqn: OdeEquationsImplicit> InitOp<'a, Eqn> {
    /// Creates the initialisation operator for `eqn` around the state `y0` at time `t0`.
    ///
    /// The operator's unknown is `[du, v]`: derivatives of the differential states, with the
    /// algebraic states `v` stored at `algebraic_indices`. The Jacobian is assembled once here,
    /// from the mass matrix and the right-hand-side Jacobian evaluated at `(y0, t0)`:
    ///
    /// ```text
    /// J = [ -M_uu  df/dv ]
    ///     [   0    dg/dv ]
    /// ```
    ///
    /// The first block row/column corresponds to the differential states and the second to
    /// the algebraic states. `J` is not re-evaluated when the operator is called later.
    ///
    /// # Panics
    ///
    /// Panics if `eqn` has no mass matrix.
    pub fn new(
        eqn: &'a Eqn,
        t0: Eqn::T,
        y0: &Eqn::V,
        algebraic_indices: <Eqn::V as Vector>::Index,
    ) -> Self {
        let n = eqn.rhs().nstates();
        let n_alg = algebraic_indices.len();
        let n_diff = n - n_alg;

        let rhs_jac = eqn.rhs().jacobian(y0, t0);
        let mass = eqn
            .mass()
            .expect("InitOp requires equations with a mass matrix")
            .matrix(t0);

        // Only the differential/differential block of the mass matrix is used. It is negated
        // because the residual evaluated in `call_inplace` is `rhs + neg_mass * [du, v]`.
        let [(m_u, _), _, _, _] = mass.split(&algebraic_indices);
        let m_u = m_u * scale(-Eqn::T::one());
        // Derivatives of the differential (f) and algebraic (g) equations with respect to
        // the algebraic states v.
        let [_, (dfdv, _), _, (dgdv, _)] = rhs_jac.split(&algebraic_indices);

        let zero_ll = <Eqn::M as Matrix>::zeros(n_alg, n_diff, eqn.context().clone());
        let zero_ur = <Eqn::M as Matrix>::zeros(n_diff, n_alg, eqn.context().clone());
        let zero_lr = <Eqn::M as Matrix>::zeros(n_alg, n_alg, eqn.context().clone());

        let jac = Eqn::M::combine(&m_u, &dfdv, &zero_ll, &dgdv, &algebraic_indices);
        let neg_mass = Eqn::M::combine(&m_u, &zero_ur, &zero_ll, &zero_lr, &algebraic_indices);

        Self {
            eqn,
            jac,
            y0: RefCell::new(y0.clone()),
            neg_mass,
            algebraic_indices,
        }
    }

    /// Writes a solution `soln = [du, v]` of this operator back into the state `y` and its
    /// derivative `dy`.
    ///
    /// The differential entries of `dy` are taken from `soln`, and its algebraic entries are
    /// left unchanged. Only the algebraic entries of `y` are overwritten, also from `soln`.
    pub fn scatter_soln(&self, soln: &Eqn::V, y: &mut Eqn::V, dy: &mut Eqn::V) {
        // Keep the old dy so its algebraic entries can be restored after the bulk copy.
        let saved_dy = dy.clone();
        dy.copy_from(soln);
        dy.copy_from_indices(&saved_dy, &self.algebraic_indices);
        y.copy_from_indices(soln, &self.algebraic_indices);
    }
}
impl<Eqn> Op for InitOp<'_, Eqn>
where
    Eqn: OdeEquationsImplicit,
{
    type T = Eqn::T;
    type V = Eqn::V;
    type M = Eqn::M;
    type C = Eqn::C;

    /// Number of unknowns `[du, v]`, equal to the number of states of the underlying equations.
    fn nstates(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nstates()
    }

    /// Number of residual components. The residual is square, so this is also the number of
    /// states of the underlying equations.
    fn nout(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nstates()
    }

    /// Number of parameters, forwarded from the right-hand side.
    fn nparams(&self) -> usize {
        let rhs = self.eqn.rhs();
        rhs.nparams()
    }

    /// Context shared with the underlying equations.
    fn context(&self) -> &Self::C {
        self.eqn.context()
    }
}
impl<Eqn> NonLinearOp for InitOp<'_, Eqn>
where
    Eqn: OdeEquationsImplicit,
{
    /// Evaluates the initialisation residual at `x = [du, v]`.
    ///
    /// First, the algebraic entries of the stored state `self.y0` are replaced by those of `x`.
    /// This update persists after the call. The residual is then
    /// `y = rhs(y0, t) + neg_mass * x`, where `neg_mass` holds the negated differential block
    /// of the mass matrix and zeros elsewhere.
    fn call_inplace(&self, x: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) {
        let one = Eqn::T::one();
        {
            let mut state = self.y0.borrow_mut();
            state.copy_from_indices(x, &self.algebraic_indices);
            self.eqn.rhs().call_inplace(&state, t, y);
        }
        self.neg_mass.gemv(one, x, one, y);
    }
}
impl<Eqn: OdeEquationsImplicit> NonLinearOpJacobian for InitOp<'_, Eqn> {
    /// Computes `y = J * v`, where `J` is the Jacobian assembled in [`InitOp::new`].
    ///
    /// `x` and `t` are ignored because `J` is not re-evaluated.
    fn jac_mul_inplace(&self, _x: &Eqn::V, _t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) {
        // beta = 0 so that `y` is overwritten, not accumulated into. The previous
        // `beta = 1` computed `J * v + y`, which depends on whatever `y` held before.
        // This is inconsistent with `jacobian_inplace`, which overwrites its output.
        self.jac.gemv(Eqn::T::one(), v, Eqn::T::zero(), y);
    }

    /// Copies the Jacobian assembled in [`InitOp::new`] into `y`; `x` and `t` are ignored.
    fn jacobian_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::M) {
        y.copy_from(&self.jac);
    }

    /// Sparsity pattern of the assembled Jacobian, if the matrix type has one.
    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
        self.jac.sparsity().map(|x| x.to_owned())
    }
}
#[cfg(test)]
mod tests {
    use crate::ode_equations::test_models::exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem;
    use crate::op::init::InitOp;
    use crate::vector::Vector;
    use crate::{
        DenseMatrix, LinearOp, Matrix, NalgebraMat, NalgebraVec, NonLinearOp, NonLinearOpJacobian,
        OdeEquations,
    };

    type Mcpu = NalgebraMat<f64>;
    type Vcpu = NalgebraVec<f64>;

    /// Checks the residual and the Jacobian of `InitOp` on the exponential-decay model with an
    /// algebraic state.
    #[test]
    fn test_initop() {
        let (problem, _soln) = exponential_decay_with_algebraic_problem::<Mcpu>(false);
        let ctx = *problem.context();
        let t = 0.0;

        let y0 = Vcpu::from_vec(vec![1.0, 2.0, 3.0], ctx);
        let dy0 = Vcpu::from_vec(vec![4.0, 5.0, 6.0], ctx);

        let mass = problem.eqn().mass().unwrap().matrix(t);
        let (algebraic_indices, _) = mass.partition_indices_by_zero_diagonal();
        let initop = InitOp::new(&problem.eqn, t, &y0, algebraic_indices);

        // Entries 0 and 1 are derivatives of the differential states; entry 2 is the
        // algebraic state.
        let du_v = Vcpu::from_vec(vec![dy0[0], dy0[1], y0[2]], ctx);
        let mut residual = Vcpu::from_vec(vec![0.0; 3], ctx);
        initop.call_inplace(&du_v, t, &mut residual);
        let expected_residual = Vcpu::from_vec(vec![-4.1, -5.2, 1.0], ctx);
        residual.assert_eq_st(&expected_residual, 1e-10);

        let jac = initop.jacobian(&du_v, t);
        let expected_jac = [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0]];
        for (i, row) in expected_jac.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(jac.get_index(i, j), value);
            }
        }
    }
}