use crate::{
    scale, LinearOp, Matrix, MatrixSparsityRef, NonLinearOpJacobian, OdeEquationsImplicit, Vector,
    VectorIndex,
};
use num_traits::{One, Zero};
use std::cell::RefCell;

use super::{NonLinearOp, Op};
9
10pub struct InitOp<'a, Eqn: OdeEquationsImplicit> {
15 eqn: &'a Eqn,
16 jac: Eqn::M,
17 pub y0: RefCell<Eqn::V>,
18 pub algebraic_indices: <Eqn::V as Vector>::Index,
19 neg_mass: Eqn::M,
20}
21
22impl<'a, Eqn: OdeEquationsImplicit> InitOp<'a, Eqn> {
23 pub fn new(eqn: &'a Eqn, t0: Eqn::T, y0: &Eqn::V) -> Self {
24 let n = eqn.rhs().nstates();
25 let (algebraic_indices, _) = eqn
26 .mass()
27 .unwrap()
28 .matrix(t0)
29 .partition_indices_by_zero_diagonal();
30
31 let rhs_jac = eqn.rhs().jacobian(y0, t0);
32 let mass = eqn.mass().unwrap().matrix(t0);
33
34 let [(m_u, _), _, _, _] = mass.split(&algebraic_indices);
46 let m_u = m_u * scale(-Eqn::T::one());
47 let [_, (dfdv, _), _, (dgdv, _)] = rhs_jac.split(&algebraic_indices);
48 let zero_ll = <Eqn::M as Matrix>::zeros(
49 algebraic_indices.len(),
50 n - algebraic_indices.len(),
51 eqn.context().clone(),
52 );
53 let zero_ur = <Eqn::M as Matrix>::zeros(
54 n - algebraic_indices.len(),
55 algebraic_indices.len(),
56 eqn.context().clone(),
57 );
58 let zero_lr = <Eqn::M as Matrix>::zeros(
59 algebraic_indices.len(),
60 algebraic_indices.len(),
61 eqn.context().clone(),
62 );
63 let jac = Eqn::M::combine(&m_u, &dfdv, &zero_ll, &dgdv, &algebraic_indices);
64 let neg_mass = Eqn::M::combine(&m_u, &zero_ur, &zero_ll, &zero_lr, &algebraic_indices);
65
66 let y0 = y0.clone();
67 let y0 = RefCell::new(y0);
68 Self {
69 eqn,
70 jac,
71 y0,
72 neg_mass,
73 algebraic_indices,
74 }
75 }
76
77 pub fn scatter_soln(&self, soln: &Eqn::V, y: &mut Eqn::V, dy: &mut Eqn::V) {
78 let tmp = dy.clone();
79 dy.copy_from(soln);
80 dy.copy_from_indices(&tmp, &self.algebraic_indices);
81 y.copy_from_indices(soln, &self.algebraic_indices);
82 }
83}
84
85impl<Eqn: OdeEquationsImplicit> Op for InitOp<'_, Eqn> {
86 type V = Eqn::V;
87 type T = Eqn::T;
88 type M = Eqn::M;
89 type C = Eqn::C;
90 fn nstates(&self) -> usize {
91 self.eqn.rhs().nstates()
92 }
93 fn nout(&self) -> usize {
94 self.eqn.rhs().nstates()
95 }
96 fn nparams(&self) -> usize {
97 self.eqn.rhs().nparams()
98 }
99 fn context(&self) -> &Self::C {
100 self.eqn.context()
101 }
102}
103
104impl<Eqn: OdeEquationsImplicit> NonLinearOp for InitOp<'_, Eqn> {
105 fn call_inplace(&self, x: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) {
108 let mut y0 = self.y0.borrow_mut();
111 y0.copy_from_indices(x, &self.algebraic_indices);
112
113 self.eqn.rhs().call_inplace(&y0, t, y);
115
116 self.neg_mass.gemv(Eqn::T::one(), x, Eqn::T::one(), y);
118 }
119}
120
121impl<Eqn: OdeEquationsImplicit> NonLinearOpJacobian for InitOp<'_, Eqn> {
122 fn jac_mul_inplace(&self, _x: &Eqn::V, _t: Eqn::T, v: &Eqn::V, y: &mut Eqn::V) {
124 self.jac.gemv(Eqn::T::one(), v, Eqn::T::one(), y);
125 }
126
127 fn jacobian_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::M) {
129 y.copy_from(&self.jac);
130 }
131
132 fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
133 self.jac.sparsity().map(|x| x.to_owned())
134 }
135}
136
#[cfg(test)]
mod tests {
    use crate::ode_solver::test_models::exponential_decay_with_algebraic::exponential_decay_with_algebraic_problem;
    use crate::op::init::InitOp;
    use crate::vector::Vector;
    use crate::{DenseMatrix, NalgebraMat, NalgebraVec, NonLinearOp, NonLinearOpJacobian};

    type Mcpu = NalgebraMat<f64>;
    type Vcpu = NalgebraVec<f64>;

    /// Checks the residual and the frozen Jacobian of `InitOp` on the exponential-decay
    /// test model with an algebraic component.
    #[test]
    fn test_initop() {
        let (problem, _soln) = exponential_decay_with_algebraic_problem::<Mcpu>(false);
        let ctx = problem.context().clone();
        let t = 0.0;

        let y0 = Vcpu::from_vec(vec![1.0, 2.0, 3.0], ctx.clone());
        let dy0 = Vcpu::from_vec(vec![4.0, 5.0, 6.0], ctx.clone());
        let initop = InitOp::new(&problem.eqn, t, &y0);

        // Unknown vector: derivatives for the first two entries, the state value for the third.
        let du_v = Vcpu::from_vec(vec![dy0[0], dy0[1], y0[2]], ctx.clone());

        let mut residual = Vcpu::from_vec(vec![0.0; 3], ctx.clone());
        initop.call_inplace(&du_v, t, &mut residual);
        let expected_residual = Vcpu::from_vec(vec![-4.1, -5.2, 1.0], ctx);
        residual.assert_eq_st(&expected_residual, 1e-10);

        let jac = initop.jacobian(&du_v, t);
        let expected_jac = [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 1.0]];
        for (i, row) in expected_jac.iter().enumerate() {
            for (j, &entry) in row.iter().enumerate() {
                assert_eq!(jac.get_index(i, j), entry);
            }
        }
    }
}