diffsol/ode_solver/
diffsl.rs

1use core::panic;
2use std::cell::RefCell;
3use std::ops::MulAssign;
4
5use diffsl::{
6    execution::module::{CodegenModule, CodegenModuleCompile, CodegenModuleJit},
7    Compiler,
8};
9
10use crate::{
11    error::DiffsolError, find_jacobian_non_zeros, find_matrix_non_zeros,
12    jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity,
13    op::nonlinear_op::NonLinearOpJacobian, ConstantOp, ConstantOpSens, ConstantOpSensAdjoint,
14    LinearOp, LinearOpTranspose, Matrix, MatrixHost, NonLinearOp, NonLinearOpAdjoint,
15    NonLinearOpSens, NonLinearOpSensAdjoint, OdeEquations, OdeEquationsRef, Op, Scale, Vector,
16    VectorHost,
17};
18
/// Scalar type used throughout this module; every matrix type `M` used with DiffSL must
/// have `M::T = T`.
pub type T = f64;

/// Context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/).
///
/// This contains the compiled code and the data structures needed to evaluate the ODE equations.
pub struct DiffSlContext<M: Matrix<T = T>, CG: CodegenModule> {
    // compiled DiffSL model; every operator evaluation calls into it
    compiler: Compiler<CG>,
    // primal data buffer (from `get_new_data`); parameters are written here by `set_inputs`
    data: RefCell<Vec<M::T>>,
    // tangent/adjoint buffer for Jacobian products, allocated like `data`
    ddata: RefCell<Vec<M::T>>,
    // buffer for parameter-sensitivity products, allocated like `data`
    sens_data: RefCell<Vec<M::T>>,
    // scratch vectors of length `nstates`
    tmp: RefCell<M::V>,
    tmp2: RefCell<M::V>,
    // scratch vectors of length `nout`
    tmp_out: RefCell<M::V>,
    tmp2_out: RefCell<M::V>,
    // model dimensions as reported by `Compiler::get_dims`
    nstates: usize,
    nroots: usize,
    nparams: usize,
    // whether the model defines a mass matrix (reported by `get_dims`)
    has_mass: bool,
    // `nroots > 0`
    has_root: bool,
    // `nout > 0`
    has_out: bool,
    nout: usize,
    // thread count requested by the caller (0 = automatic, 1 = single-threaded)
    nthreads: usize,
    // context used to allocate the scratch vectors
    ctx: M::C,
}
43
44impl<M: Matrix<T = T>, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContext<M, CG> {
45    /// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/).
46    /// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE.
47    ///
48    /// # Arguments
49    ///
50    /// * `text` - The text of the ODE equations in the DiffSL language.
51    /// * `nthreads` - The number of threads to use for code generation (0 for automatic, 1 for single-threaded).
52    ///
53    pub fn new(text: &str, nthreads: usize, ctx: M::C) -> Result<Self, DiffsolError> {
54        let mode = match nthreads {
55            0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None),
56            1 => diffsl::execution::compiler::CompilerMode::SingleThreaded,
57            _ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
58        };
59        let compiler = Compiler::from_discrete_str(text, mode)
60            .map_err(|e| DiffsolError::Other(e.to_string()))?;
61        let (nstates, _nparams, _nout, _ndata, _nroots, _has_mass) = compiler.get_dims();
62
63        let compiler = if nthreads == 0 {
64            let num_cpus = std::thread::available_parallelism().unwrap().get();
65            let nthreads = num_cpus.min(nstates / 1000).max(1);
66            Compiler::from_discrete_str(
67                text,
68                diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
69            )
70            .map_err(|e| DiffsolError::Other(e.to_string()))?
71        } else {
72            compiler
73        };
74
75        let (nstates, nparams, nout, _ndata, nroots, has_mass) = compiler.get_dims();
76
77        let has_root = nroots > 0;
78        let has_out = nout > 0;
79        let data = RefCell::new(compiler.get_new_data());
80        let ddata = RefCell::new(compiler.get_new_data());
81        let sens_data = RefCell::new(compiler.get_new_data());
82        let tmp = RefCell::new(M::V::zeros(nstates, ctx.clone()));
83        let tmp2 = RefCell::new(M::V::zeros(nstates, ctx.clone()));
84        let tmp_out = RefCell::new(M::V::zeros(nout, ctx.clone()));
85        let tmp2_out = RefCell::new(M::V::zeros(nout, ctx.clone()));
86
87        Ok(Self {
88            compiler,
89            data,
90            ddata,
91            sens_data,
92            nparams,
93            nstates,
94            tmp,
95            tmp2,
96            tmp_out,
97            tmp2_out,
98            nroots,
99            nout,
100            has_mass,
101            has_root,
102            has_out,
103            nthreads,
104            ctx,
105        })
106    }
107
108    pub fn recompile(&mut self, text: &str) -> Result<(), DiffsolError> {
109        let mode = match self.nthreads {
110            0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None),
111            1 => diffsl::execution::compiler::CompilerMode::SingleThreaded,
112            _ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(self.nthreads)),
113        };
114        self.compiler = Compiler::from_discrete_str(text, mode)
115            .map_err(|e| DiffsolError::Other(e.to_string()))?;
116        let (nstates, nparams, nout, _ndata, nroots, has_mass) = self.compiler.get_dims();
117        self.data = RefCell::new(self.compiler.get_new_data());
118        self.ddata = RefCell::new(self.compiler.get_new_data());
119        self.tmp = RefCell::new(M::V::zeros(nstates, self.ctx.clone()));
120        self.nparams = nparams;
121        self.nstates = nstates;
122        self.nout = nout;
123        self.nroots = nroots;
124        self.has_mass = has_mass;
125        Ok(())
126    }
127}
128
129impl<M: Matrix<T = T>, CG: CodegenModuleJit + CodegenModuleCompile> Default
130    for DiffSlContext<M, CG>
131{
132    fn default() -> Self {
133        Self::new(
134            "
135            u { y = 1 }
136            F { -y }
137            out { y }
138        ",
139            1,
140            M::C::default(),
141        )
142        .unwrap()
143    }
144}
145
/// ODE equations backed by a compiled DiffSL model.
///
/// For sparse matrix types, [DiffSl::from_context] fills in the sparsity patterns and
/// colorings below. For dense matrix types they are all left as `None`.
pub struct DiffSl<M: Matrix<T = T>, CG: CodegenModule> {
    context: DiffSlContext<M, CG>,
    // pattern / coloring of the mass matrix (only set when the model has one)
    mass_sparsity: Option<M::Sparsity>,
    mass_coloring: Option<JacobianColoring<M>>,
    // pattern / coloring of the transposed mass matrix
    mass_transpose_sparsity: Option<M::Sparsity>,
    mass_transpose_coloring: Option<JacobianColoring<M>>,
    // pattern / coloring of the rhs Jacobian
    rhs_sparsity: Option<M::Sparsity>,
    rhs_coloring: Option<JacobianColoring<M>>,
    // pattern / coloring of the transposed rhs Jacobian
    rhs_adjoint_sparsity: Option<M::Sparsity>,
    rhs_adjoint_coloring: Option<JacobianColoring<M>>,
}
157
impl<M: MatrixHost<T = T>, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl<M, CG> {
    /// Compile DiffSL `code` into a set of ODE equations.
    ///
    /// The model is compiled single-threaded (`nthreads = 1` in [DiffSlContext::new]).
    ///
    /// # Errors
    ///
    /// Propagates any error from [DiffSlContext::new], e.g. when `code` does not compile.
    pub fn compile(code: &str, ctx: M::C) -> Result<Self, DiffsolError> {
        let context = DiffSlContext::<M, CG>::new(code, 1, ctx)?;
        Ok(Self::from_context(context))
    }
    /// Build the equations from an already compiled [DiffSlContext].
    ///
    /// When `M` is a sparse matrix type, this detects several sparsity patterns and builds a
    /// [JacobianColoring] for each: the rhs Jacobian, its transpose, and (if the model has a
    /// mass matrix) the mass matrix and its transpose. For dense `M` they all stay `None`.
    ///
    /// # Panics
    ///
    /// Panics with "invalid sparsity pattern" if `M::Sparsity::try_from_indices` rejects a
    /// detected pattern.
    pub fn from_context(context: DiffSlContext<M, CG>) -> Self {
        let mut ret = Self {
            context,
            mass_coloring: None,
            mass_sparsity: None,
            mass_transpose_coloring: None,
            mass_transpose_sparsity: None,
            rhs_coloring: None,
            rhs_sparsity: None,
            rhs_adjoint_coloring: None,
            rhs_adjoint_sparsity: None,
        };
        if M::is_sparse() {
            let op = ret.rhs();
            // `op` borrows `ret`; clone the context now so it is still available after that
            // borrow ends and `ret`'s fields are assigned below
            let ctx = op.context().clone();
            // the Jacobian pattern is probed once, at t = 0 with a zero state vector
            // NOTE(review): assumes find_jacobian_non_zeros reports entries that are
            // structurally non-zero even if they vanish at this point — confirm
            let t0 = 0.0;
            let x0 = M::V::zeros(op.nstates(), op.context().clone());
            let non_zeros = find_jacobian_non_zeros(&op, &x0, t0);
            let n = op.nstates();

            let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.clone())
                .expect("invalid sparsity pattern");
            let coloring = JacobianColoring::new(&sparsity, &non_zeros, op.context().clone());
            ret.rhs_coloring = Some(coloring);
            ret.rhs_sparsity = Some(sparsity);

            // the adjoint pattern is the Jacobian pattern with rows and columns swapped
            let non_zeros = non_zeros
                .into_iter()
                .map(|(i, j)| (j, i))
                .collect::<Vec<_>>();
            let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.clone())
                .expect("invalid sparsity pattern");
            let coloring = JacobianColoring::new(&sparsity, &non_zeros, ctx);
            ret.rhs_adjoint_sparsity = Some(sparsity);
            ret.rhs_adjoint_coloring = Some(coloring);

            if let Some(op) = ret.mass() {
                // same borrow consideration as for the rhs above
                let ctx = op.context().clone();
                let non_zeros = find_matrix_non_zeros(&op, t0);
                // the mass pattern is built as n x n, reusing the state size from the rhs
                let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.clone())
                    .expect("invalid sparsity pattern");
                let coloring = JacobianColoring::new(&sparsity, &non_zeros, op.context().clone());
                ret.mass_coloring = Some(coloring);
                ret.mass_sparsity = Some(sparsity);

                // transposed mass pattern: swap rows and columns
                let non_zeros = non_zeros
                    .into_iter()
                    .map(|(i, j)| (j, i))
                    .collect::<Vec<_>>();
                let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.clone())
                    .expect("invalid sparsity pattern");
                let coloring = JacobianColoring::new(&sparsity, &non_zeros, ctx);
                ret.mass_transpose_sparsity = Some(sparsity);
                ret.mass_transpose_coloring = Some(coloring);
            }
        }
        ret
    }
}
222
// Borrowed views over a [DiffSl]; each one implements the operator traits for one part of
// the model.
/// Root (stop) function operator of a [DiffSl] model; evaluated via `calc_stop`.
pub struct DiffSlRoot<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
/// Output operator of a [DiffSl] model; evaluated via `calc_out`.
pub struct DiffSlOut<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
/// Right-hand-side operator of a [DiffSl] model; evaluated via `rhs`.
pub struct DiffSlRhs<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
/// Mass-matrix operator of a [DiffSl] model; evaluated via `mass`.
pub struct DiffSlMass<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
/// Initial-condition operator of a [DiffSl] model; evaluated via `set_u0`.
pub struct DiffSlInit<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
228
229macro_rules! impl_op_for_diffsl {
230    ($name:ident) => {
231        impl<M: Matrix<T = T>, CG: CodegenModule> Op for $name<'_, M, CG> {
232            type M = M;
233            type T = T;
234            type V = M::V;
235            type C = M::C;
236
237            fn nstates(&self) -> usize {
238                self.0.context.nstates
239            }
240            #[allow(clippy::misnamed_getters)]
241            fn nout(&self) -> usize {
242                self.0.context.nstates
243            }
244            fn nparams(&self) -> usize {
245                self.0.context.nparams
246            }
247            fn context(&self) -> &Self::C {
248                &self.0.context.ctx
249            }
250        }
251    };
252}
253
254impl_op_for_diffsl!(DiffSlRhs);
255impl_op_for_diffsl!(DiffSlMass);
256
257impl<M: Matrix<T = T>, CG: CodegenModule> Op for DiffSlInit<'_, M, CG> {
258    type M = M;
259    type T = T;
260    type V = M::V;
261    type C = M::C;
262
263    fn nstates(&self) -> usize {
264        self.0.context.nstates
265    }
266    #[allow(clippy::misnamed_getters)]
267    fn nout(&self) -> usize {
268        self.0.context.nstates
269    }
270    fn nparams(&self) -> usize {
271        self.0.context.nparams
272    }
273    fn context(&self) -> &Self::C {
274        &self.0.context.ctx
275    }
276}
277
278impl<M: Matrix<T = T>, CG: CodegenModule> Op for DiffSlRoot<'_, M, CG> {
279    type M = M;
280    type T = T;
281    type V = M::V;
282    type C = M::C;
283
284    fn nstates(&self) -> usize {
285        self.0.context.nstates
286    }
287    #[allow(clippy::misnamed_getters)]
288    fn nout(&self) -> usize {
289        self.0.context.nroots
290    }
291    fn nparams(&self) -> usize {
292        self.0.context.nparams
293    }
294    fn context(&self) -> &Self::C {
295        &self.0.context.ctx
296    }
297}
298
299impl<M: Matrix<T = T>, CG: CodegenModule> Op for DiffSlOut<'_, M, CG> {
300    type M = M;
301    type T = T;
302    type V = M::V;
303    type C = M::C;
304
305    fn nstates(&self) -> usize {
306        self.0.context.nstates
307    }
308    fn nout(&self) -> usize {
309        self.0.context.nout
310    }
311    fn nparams(&self) -> usize {
312        self.0.context.nparams
313    }
314    fn context(&self) -> &Self::C {
315        &self.0.context.ctx
316    }
317}
318
319impl<M: MatrixHost<T = T>, CG: CodegenModule> ConstantOp for DiffSlInit<'_, M, CG> {
320    fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
321        self.0.context.compiler.set_u0(
322            y.as_mut_slice(),
323            self.0.context.data.borrow_mut().as_mut_slice(),
324        );
325    }
326}
327
328impl<M: MatrixHost<T = T>, CG: CodegenModule> ConstantOpSens for DiffSlInit<'_, M, CG> {
329    fn sens_mul_inplace(&self, _t: Self::T, v: &Self::V, y: &mut Self::V) {
330        self.0.context.compiler.set_inputs(
331            v.as_slice(),
332            self.0.context.sens_data.borrow_mut().as_mut_slice(),
333        );
334        self.0.context.compiler.set_u0_grad(
335            self.0.context.tmp.borrow().as_slice(),
336            y.as_mut_slice(),
337            self.0.context.data.borrow_mut().as_mut_slice(),
338            self.0.context.sens_data.borrow_mut().as_mut_slice(),
339        );
340    }
341}
342
343impl<M: MatrixHost<T = T>, CG: CodegenModule> ConstantOpSensAdjoint for DiffSlInit<'_, M, CG> {
344    fn sens_transpose_mul_inplace(&self, _t: Self::T, v: &Self::V, y: &mut Self::V) {
345        // copy v to tmp2
346        let mut tmp2 = self.0.context.tmp2.borrow_mut();
347        tmp2.copy_from(v);
348        // zero out sens_data
349        self.0.context.sens_data.borrow_mut().fill(0.0);
350        self.0.context.compiler.set_u0_rgrad(
351            self.0.context.tmp.borrow().as_slice(),
352            tmp2.as_mut_slice(),
353            self.0.context.data.borrow().as_slice(),
354            self.0.context.sens_data.borrow_mut().as_mut_slice(),
355        );
356        self.0.context.compiler.get_inputs(
357            y.as_mut_slice(),
358            self.0.context.sens_data.borrow().as_slice(),
359        );
360        // negate y
361        y.mul_assign(Scale(-1.0));
362    }
363}
364
365impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOp for DiffSlRoot<'_, M, CG> {
366    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
367        self.0.context.compiler.calc_stop(
368            t,
369            x.as_slice(),
370            self.0.context.data.borrow_mut().as_mut_slice(),
371            y.as_mut_slice(),
372        );
373    }
374}
375
376impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpJacobian for DiffSlRoot<'_, M, CG> {
377    fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
378        y.fill(0.0);
379    }
380}
381
382impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOp for DiffSlOut<'_, M, CG> {
383    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
384        self.0.context.compiler.calc_out(
385            t,
386            x.as_slice(),
387            self.0.context.data.borrow_mut().as_mut_slice(),
388            y.as_mut_slice(),
389        );
390    }
391}
392
393impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpJacobian for DiffSlOut<'_, M, CG> {
394    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
395        // init ddata with all zero except for out
396        let mut ddata = self.0.context.ddata.borrow_mut();
397        ddata.fill(0.0);
398        self.0.context.compiler.calc_out_grad(
399            t,
400            x.as_slice(),
401            v.as_slice(),
402            self.0.context.data.borrow_mut().as_mut_slice(),
403            ddata.as_mut_slice(),
404            self.0.context.tmp_out.borrow().as_slice(),
405            y.as_mut_slice(),
406        );
407    }
408}
409
410impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpAdjoint for DiffSlOut<'_, M, CG> {
411    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
412        // init ddata with all zero except for out
413        let mut ddata = self.0.context.ddata.borrow_mut();
414        ddata.fill(0.0);
415        let mut tmp2_out = self.0.context.tmp2_out.borrow_mut();
416        tmp2_out.copy_from(v);
417        // zero y
418        y.fill(0.0);
419        self.0.context.compiler.calc_out_rgrad(
420            t,
421            x.as_slice(),
422            y.as_mut_slice(),
423            self.0.context.data.borrow_mut().as_slice(),
424            ddata.as_mut_slice(),
425            self.0.context.tmp_out.borrow().as_slice(),
426            tmp2_out.as_mut_slice(),
427        );
428        // negate y
429        y.mul_assign(Scale(-1.0));
430    }
431}
432
433impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpSens for DiffSlOut<'_, M, CG> {
434    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
435        // set inputs for sens_data
436        self.0.context.compiler.set_inputs(
437            v.as_slice(),
438            self.0.context.sens_data.borrow_mut().as_mut_slice(),
439        );
440        self.0.context.compiler.calc_out_sgrad(
441            t,
442            x.as_slice(),
443            self.0.context.data.borrow_mut().as_mut_slice(),
444            self.0.context.sens_data.borrow_mut().as_mut_slice(),
445            self.0.context.tmp_out.borrow().as_slice(),
446            y.as_mut_slice(),
447        );
448    }
449}
450
451impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpSensAdjoint for DiffSlOut<'_, M, CG> {
452    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
453        let mut sens_data = self.0.context.sens_data.borrow_mut();
454        // set outputs for sens_data (zero everything except for out)
455        sens_data.fill(0.0);
456        let mut tmp2_out = self.0.context.tmp2_out.borrow_mut();
457        tmp2_out.copy_from(v);
458        self.0.context.compiler.calc_out_srgrad(
459            t,
460            x.as_slice(),
461            self.0.context.data.borrow_mut().as_mut_slice(),
462            sens_data.as_mut_slice(),
463            self.0.context.tmp_out.borrow().as_slice(),
464            tmp2_out.as_mut_slice(),
465        );
466        // set y to the result in inputs
467        self.0
468            .context
469            .compiler
470            .get_inputs(y.as_mut_slice(), sens_data.as_slice());
471        // negate y
472        y.mul_assign(Scale(-1.0));
473    }
474}
475
476impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOp for DiffSlRhs<'_, M, CG> {
477    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
478        self.0.context.compiler.rhs(
479            t,
480            x.as_slice(),
481            self.0.context.data.borrow_mut().as_mut_slice(),
482            y.as_mut_slice(),
483        );
484    }
485}
486
487impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpJacobian for DiffSlRhs<'_, M, CG> {
488    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
489        let tmp = self.0.context.tmp.borrow();
490        self.0.context.compiler.rhs_grad(
491            t,
492            x.as_slice(),
493            v.as_slice(),
494            self.0.context.data.borrow_mut().as_slice(),
495            self.0.context.ddata.borrow_mut().as_mut_slice(),
496            tmp.as_slice(),
497            y.as_mut_slice(),
498        );
499    }
500
501    fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
502        if let Some(coloring) = &self.0.rhs_coloring {
503            coloring.jacobian_inplace(self, x, t, y);
504        } else {
505            self._default_jacobian_inplace(x, t, y);
506        }
507    }
508    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
509        self.0.rhs_sparsity.clone()
510    }
511}
512
513impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpAdjoint for DiffSlRhs<'_, M, CG> {
514    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
515        // copy v to tmp2
516        let mut tmp2 = self.0.context.tmp2.borrow_mut();
517        tmp2.copy_from(v);
518        // zero out ddata
519        self.0.context.ddata.borrow_mut().fill(0.0);
520        // zero y
521        y.fill(0.0);
522        self.0.context.compiler.rhs_rgrad(
523            t,
524            x.as_slice(),
525            y.as_mut_slice(),
526            self.0.context.data.borrow().as_slice(),
527            self.0.context.ddata.borrow_mut().as_mut_slice(),
528            self.0.context.tmp.borrow().as_slice(),
529            tmp2.as_mut_slice(),
530        );
531        // negate y
532        y.mul_assign(Scale(-1.0));
533    }
534    fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
535        // if we have a rhs_coloring and no rhs_adjoint_coloring, user has not called prep_adjoint
536        // fail here
537        if self.0.rhs_coloring.is_some() && self.0.rhs_adjoint_coloring.is_none() {
538            panic!("Adjoint not prepared. Call prep_adjoint before calling adjoint_inplace");
539        }
540        if let Some(coloring) = &self.0.rhs_adjoint_coloring {
541            coloring.jacobian_inplace(self, x, t, y);
542        } else {
543            self._default_adjoint_inplace(x, t, y);
544        }
545    }
546    fn adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
547        self.0.rhs_adjoint_sparsity.clone()
548    }
549}
550
551impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpSens for DiffSlRhs<'_, M, CG> {
552    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
553        let tmp = self.0.context.tmp.borrow();
554        self.0.context.compiler.set_inputs(
555            v.as_slice(),
556            self.0.context.sens_data.borrow_mut().as_mut_slice(),
557        );
558        self.0.context.compiler.rhs_sgrad(
559            t,
560            x.as_slice(),
561            self.0.context.data.borrow_mut().as_slice(),
562            self.0.context.sens_data.borrow_mut().as_mut_slice(),
563            tmp.as_slice(),
564            y.as_mut_slice(),
565        );
566    }
567}
568
569impl<M: MatrixHost<T = T>, CG: CodegenModule> NonLinearOpSensAdjoint for DiffSlRhs<'_, M, CG> {
570    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
571        // todo: would rhs_srgrad ever use rr? I don't think so, but need to check
572        let tmp = self.0.context.tmp.borrow();
573        // copy v to tmp2
574        let mut tmp2 = self.0.context.tmp2.borrow_mut();
575        tmp2.copy_from(v);
576        // zero out sens_data
577        self.0.context.sens_data.borrow_mut().fill(0.0);
578        self.0.context.compiler.rhs_srgrad(
579            t,
580            x.as_slice(),
581            self.0.context.data.borrow_mut().as_mut_slice(),
582            self.0.context.sens_data.borrow_mut().as_mut_slice(),
583            tmp.as_slice(),
584            tmp2.as_mut_slice(),
585        );
586        // get inputs
587        self.0.context.compiler.get_inputs(
588            y.as_mut_slice(),
589            self.0.context.sens_data.borrow().as_slice(),
590        );
591        // negate y
592        y.mul_assign(Scale(-1.0));
593    }
594}
595
596impl<M: MatrixHost<T = T>, CG: CodegenModule> LinearOp for DiffSlMass<'_, M, CG> {
597    fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
598        let mut tmp = self.0.context.tmp.borrow_mut();
599        self.0.context.compiler.mass(
600            t,
601            x.as_slice(),
602            self.0.context.data.borrow_mut().as_mut_slice(),
603            tmp.as_mut_slice(),
604        );
605
606        // y = tmp + beta * y
607        y.axpy(1.0, &tmp, beta);
608    }
609
610    fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
611        if let Some(coloring) = &self.0.mass_coloring {
612            coloring.matrix_inplace(self, t, y);
613        } else {
614            self._default_matrix_inplace(t, y);
615        }
616    }
617    fn sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
618        self.0.mass_sparsity.clone()
619    }
620}
621
622impl<M: MatrixHost<T = T>, CG: CodegenModule> LinearOpTranspose for DiffSlMass<'_, M, CG> {
623    fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
624        // scale y by beta
625        y.mul_assign(Scale(beta));
626
627        // copy x to tmp
628        let mut tmp = self.0.context.tmp.borrow_mut();
629        tmp.copy_from(x);
630
631        // zero out ddata
632        self.0.context.ddata.borrow_mut().fill(0.0);
633
634        // y += M^T x + beta * y
635        self.0.context.compiler.mass_rgrad(
636            t,
637            y.as_mut_slice(),
638            self.0.context.data.borrow_mut().as_slice(),
639            self.0.context.ddata.borrow_mut().as_mut_slice(),
640            tmp.as_mut_slice(),
641        );
642    }
643
644    fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
645        if let Some(coloring) = &self.0.mass_transpose_coloring {
646            coloring.matrix_inplace(self, t, y);
647        } else {
648            self._default_matrix_inplace(t, y);
649        }
650    }
651    fn transpose_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
652        self.0.mass_transpose_sparsity.clone()
653    }
654}
655
656impl<M: MatrixHost<T = T>, CG: CodegenModule> Op for DiffSl<M, CG> {
657    type M = M;
658    type T = T;
659    type V = M::V;
660    type C = M::C;
661
662    fn nstates(&self) -> usize {
663        self.context.nstates
664    }
665    fn nout(&self) -> usize {
666        if self.context.has_out {
667            self.context.nout
668        } else {
669            self.context.nstates
670        }
671    }
672    fn nparams(&self) -> usize {
673        self.context.nparams
674    }
675    fn context(&self) -> &Self::C {
676        &self.context.ctx
677    }
678}
679
680impl<'a, M: MatrixHost<T = T>, CG: CodegenModule> OdeEquationsRef<'a> for DiffSl<M, CG> {
681    type Mass = DiffSlMass<'a, M, CG>;
682    type Rhs = DiffSlRhs<'a, M, CG>;
683    type Root = DiffSlRoot<'a, M, CG>;
684    type Init = DiffSlInit<'a, M, CG>;
685    type Out = DiffSlOut<'a, M, CG>;
686}
687
688impl<M: MatrixHost<T = T>, CG: CodegenModule> OdeEquations for DiffSl<M, CG> {
689    fn rhs(&self) -> DiffSlRhs<'_, M, CG> {
690        DiffSlRhs(self)
691    }
692
693    fn mass(&self) -> Option<DiffSlMass<'_, M, CG>> {
694        self.context.has_mass.then_some(DiffSlMass(self))
695    }
696
697    fn root(&self) -> Option<DiffSlRoot<'_, M, CG>> {
698        self.context.has_root.then_some(DiffSlRoot(self))
699    }
700
701    fn init(&self) -> DiffSlInit<'_, M, CG> {
702        DiffSlInit(self)
703    }
704
705    fn out(&self) -> Option<DiffSlOut<'_, M, CG>> {
706        self.context.has_out.then_some(DiffSlOut(self))
707    }
708
709    fn set_params(&mut self, p: &Self::V) {
710        // set the parameters in data
711        self.context
712            .compiler
713            .set_inputs(p.as_slice(), self.context.data.borrow_mut().as_mut_slice());
714
715        // set_u0 will calculate all the constants in the equations based on the params
716        let mut dummy = M::V::zeros(self.context.nstates, self.context().clone());
717        self.context.compiler.set_u0(
718            dummy.as_mut_slice(),
719            self.context.data.borrow_mut().as_mut_slice(),
720        );
721    }
722
723    fn get_params(&self, p: &mut Self::V) {
724        self.context
725            .compiler
726            .get_inputs(p.as_mut_slice(), self.context.data.borrow().as_slice());
727    }
728}
729
#[cfg(test)]
mod tests {
    use diffsl::execution::module::{CodegenModuleCompile, CodegenModuleJit};

    use crate::{
        matrix::dense_nalgebra_serial::NalgebraMat, ConstantOp, Context, DenseMatrix, LinearOp,
        NalgebraContext, NalgebraLU, NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations,
        OdeSolverMethod, Vector, VectorView,
    };

    use super::{DiffSl, DiffSlContext};

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn diffsl_logistic_growth_cranelift() {
        diffsl_logistic_growth::<diffsl::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn diffsl_logistic_growth_llvm() {
        diffsl_logistic_growth::<diffsl::LlvmModule>();
    }

    /// Logistic growth in `y` coupled to an algebraic state `z` (`0 = 2y - z`) through a
    /// singular mass matrix, observed via the outputs `3y` and `4z`.
    fn diffsl_logistic_growth<CG: CodegenModuleJit + CodegenModuleCompile>() {
        let text = "
            in = [r, k]
            r { 1 }
            k { 1 }
            u_i {
                y = 0.1,
                z = 0,
            }
            dudt_i {
                dydt = 0,
                dzdt = 0,
            }
            M_i {
                dydt,
                0,
            }
            F_i {
                (r * y) * (1 - (y / k)),
                (2 * y) - z,
            }
            out_i {
                3 * y,
                4 * z,
            }
        ";

        let (r, k) = (1.0, 1.0);
        let y0 = 0.1;
        let ctx = NalgebraContext;

        // closed-form logistic solution and the outputs it implies (z = 2y)
        let y_exact = |t: f64| k / (1.0 + (k - y0) * (-r * t).exp() / y0);
        let out_exact = |t: f64| -> Vec<f64> {
            let y = y_exact(t);
            let z = 2.0 * y;
            vec![3.0 * y, 4.0 * z]
        };

        let context = DiffSlContext::<NalgebraMat<f64>, CG>::new(text, 1, ctx.clone()).unwrap();
        let params = ctx.vector_from_vec(vec![r, k]);
        let mut eqn = DiffSl::from_context(context);
        eqn.set_params(&params);

        // initial state
        let init = eqn.init().call(0.0);
        let init_expect = ctx.vector_from_vec(vec![y0, 0.0]);
        init.assert_eq_st(&init_expect, 1e-10);

        // rhs at the initial state
        let f0 = eqn.rhs().call(&init, 0.0);
        let f0_expect = ctx.vector_from_vec(vec![r * y0 * (1.0 - y0 / k), 2.0 * y0]);
        f0.assert_eq_st(&f0_expect, 1e-10);

        // Jacobian-vector product with a vector of ones
        let ones = ctx.vector_from_vec(vec![1.0, 1.0]);
        let jv = eqn.rhs().jac_mul(&init, 0.0, &ones);
        let jv_expect = ctx.vector_from_vec(vec![r * (1.0 - y0 / k) - r * y0 / k, 1.0]);
        jv.assert_eq_st(&jv_expect, 1e-10);

        // mass matrix applied to a vector of ones
        let mut mv = ctx.vector_from_vec(vec![0.0, 0.0]);
        eqn.mass().unwrap().call_inplace(&ones, 0.0, &mut mv);
        let mv_expect = ctx.vector_from_vec(vec![1.0, 0.0]);
        mv.assert_eq_st(&mv_expect, 1e-10);

        // solve a bit and check the outputs at the solver's own time points
        let problem = OdeBuilder::<NalgebraMat<f64>>::new()
            .p([r, k])
            .build_from_eqn(eqn)
            .unwrap();
        let mut solver = problem.bdf::<NalgebraLU<f64>>().unwrap();
        let t_final = 1.0;
        let (ys, ts) = solver.solve(t_final).unwrap();
        for (col, &t) in ts.iter().enumerate() {
            let expected = ctx.vector_from_vec(out_exact(t));
            ys.column(col).into_owned().assert_eq_st(&expected, 1e-4);
        }

        // and again at explicitly requested times
        let t_evals = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0];
        let mut solver = problem.bdf::<NalgebraLU<f64>>().unwrap();
        let ys = solver.solve_dense(&t_evals).unwrap();
        for (col, &t) in t_evals.iter().enumerate() {
            let expected = ctx.vector_from_vec(out_exact(t));
            ys.column(col).into_owned().assert_eq_st(&expected, 1e-4);
        }
    }
}