// diffsol/ode_equations/diffsl.rs

1use core::panic;
2#[cfg(feature = "diffsl-external-dynamic")]
3use diffsl::ExternalDynModule;
4use num_traits::{One, Zero};
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::any::TypeId;
7use std::cell::RefCell;
8use std::ops::MulAssign;
9#[cfg(feature = "diffsl-external-dynamic")]
10use std::path::PathBuf;
11
12#[cfg(feature = "diffsl-external")]
13use diffsl::execution::external::{ExternSymbols, ExternalModule};
14use diffsl::{
15    discretise::DiscreteModel,
16    execution::{
17        module::{
18            CodegenModule, CodegenModuleCompile, CodegenModuleEmit, CodegenModuleJit,
19            CodegenModuleLink,
20        },
21        scalar::Scalar as DiffSlScalar,
22    },
23    parser::parse_ds_string,
24    Compiler, ObjectModule,
25};
26
27use crate::{
28    error::DiffsolError, jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity,
29    op::nonlinear_op::NonLinearOpJacobian, ConstantOp, ConstantOpSens, ConstantOpSensAdjoint,
30    LinearOp, LinearOpTranspose, Matrix, MatrixHost, NonLinearOp, NonLinearOpAdjoint,
31    NonLinearOpSens, NonLinearOpSensAdjoint, OdeEquations, OdeEquationsRef, Op, Scale, Vector,
32    VectorHost,
33};
34
35/// Context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/).
36///
37/// This contains the compiled code and the data structures needed to evaluate the ODE equations.
38pub struct DiffSlContext<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> {
39    compiler: Compiler<CG, M::T>,
40    data: RefCell<Vec<M::T>>,
41    ddata: RefCell<Vec<M::T>>,
42    sens_data: RefCell<Vec<M::T>>,
43    tmp: RefCell<M::V>,
44    tmp2: RefCell<M::V>,
45    tmp_root: RefCell<M::V>,
46    tmp2_root: RefCell<M::V>,
47    tmp_out: RefCell<M::V>,
48    tmp2_out: RefCell<M::V>,
49    nstates: usize,
50    nroots: usize,
51    nparams: usize,
52    model_index: u32,
53    has_mass: bool,
54    has_root: bool,
55    has_reset: bool,
56    has_out: bool,
57    nout: usize,
58    ctx: M::C,
59    rhs_state_deps: Vec<(usize, usize)>,
60    rhs_input_deps: Vec<(usize, usize)>,
61    mass_state_deps: Vec<(usize, usize)>,
62}
63
64#[derive(Clone, Debug, Serialize, Deserialize)]
65struct DiffSlExternalObject {
66    scalar_type: DiffSlExternalScalarType,
67    object: Vec<u8>,
68    rhs_state_deps: Vec<(usize, usize)>,
69    rhs_input_deps: Vec<(usize, usize)>,
70    mass_state_deps: Vec<(usize, usize)>,
71    include_sensitivities: bool,
72}
73
74#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
75#[serde(rename_all = "snake_case")]
76enum DiffSlExternalScalarType {
77    F32,
78    F64,
79}
80
81fn diffsl_external_scalar_type<T: DiffSlScalar>() -> Result<DiffSlExternalScalarType, DiffsolError>
82{
83    if TypeId::of::<T>() == TypeId::of::<f32>() {
84        Ok(DiffSlExternalScalarType::F32)
85    } else if TypeId::of::<T>() == TypeId::of::<f64>() {
86        Ok(DiffSlExternalScalarType::F64)
87    } else {
88        Err(DiffsolError::Other(format!(
89            "DiffSl external object does not support scalar type {}",
90            std::any::type_name::<T>()
91        )))
92    }
93}
94
95impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> DiffSlContext<M, CG> {
96    fn new_common(
97        compiler: Compiler<CG, M::T>,
98        rhs_state_deps: Vec<(usize, usize)>,
99        rhs_input_deps: Vec<(usize, usize)>,
100        mass_state_deps: Vec<(usize, usize)>,
101        ctx: M::C,
102    ) -> Result<Self, DiffsolError> {
103        let (nstates, nparams, nout, _ndata, nroots, has_mass, has_reset) = compiler.get_dims();
104        let has_root = nroots > 0;
105        let has_out = nout > 0;
106        let data = RefCell::new(compiler.get_new_data());
107        let ddata = RefCell::new(compiler.get_new_data());
108        let sens_data = RefCell::new(compiler.get_new_data());
109        let tmp = RefCell::new(M::V::zeros(nstates, ctx.clone()));
110        let tmp2 = RefCell::new(M::V::zeros(nstates, ctx.clone()));
111        let tmp_root = RefCell::new(M::V::zeros(nroots, ctx.clone()));
112        let tmp2_root = RefCell::new(M::V::zeros(nroots, ctx.clone()));
113        let tmp_out = RefCell::new(M::V::zeros(nout, ctx.clone()));
114        let tmp2_out = RefCell::new(M::V::zeros(nout, ctx.clone()));
115        let model_index = 0;
116
117        Ok(Self {
118            compiler,
119            data,
120            ddata,
121            sens_data,
122            nparams,
123            nstates,
124            tmp,
125            tmp2,
126            tmp_root,
127            tmp2_root,
128            tmp_out,
129            tmp2_out,
130            nroots,
131            nout,
132            has_mass,
133            has_root,
134            has_reset,
135            has_out,
136            ctx,
137            rhs_state_deps,
138            rhs_input_deps,
139            mass_state_deps,
140            model_index,
141        })
142    }
143}
144
#[cfg(feature = "diffsl-external-dynamic")]
impl<M: Matrix<T: DiffSlScalar>> DiffSlContext<M, ExternalDynModule<M::T>> {
    /// Create a context whose model functions are loaded from an external dynamic library.
    ///
    /// # Arguments
    ///
    /// * `path` - Location of the dynamic library.
    /// * `nthreads` - `0` for multi-threaded with an automatic thread count, `1` for
    ///   single-threaded, any other value for multi-threaded with exactly that many threads.
    /// * `rhs_state_deps`, `rhs_input_deps`, `mass_state_deps` - `(row, col)` dependency lists
    ///   stored in the context for later sparsity-pattern construction.
    /// * `ctx` - Context used to allocate vectors.
    ///
    /// # Errors
    ///
    /// Returns [`DiffsolError::DiffslCompilerError`] if the library cannot be loaded or the
    /// compiler cannot be created from it.
    pub fn new_external_dynamic(
        path: impl Into<PathBuf>,
        nthreads: usize,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        ctx: M::C,
    ) -> Result<Self, DiffsolError> {
        use diffsl::execution::compiler::CompilerMode;

        let mode = if nthreads == 1 {
            CompilerMode::SingleThreaded
        } else {
            // `0` maps to `None`, i.e. an automatically chosen thread count.
            CompilerMode::MultiThreaded((nthreads != 0).then_some(nthreads))
        };
        let compiler = {
            let module = ExternalDynModule::new(path)
                .map_err(|err| DiffsolError::DiffslCompilerError(err.to_string()))?;
            Compiler::from_codegen_module(module, mode)
                .map_err(|err| DiffsolError::DiffslCompilerError(err.to_string()))?
        };
        Self::new_common(compiler, rhs_state_deps, rhs_input_deps, mass_state_deps, ctx)
    }
}
174
#[cfg(feature = "diffsl-external")]
impl<M: Matrix<T: DiffSlScalar + ExternSymbols>> DiffSlContext<M, ExternalModule<M::T>> {
    /// Create a context whose model functions come from externally provided symbols.
    ///
    /// # Arguments
    ///
    /// * `nthreads` - `0` for multi-threaded with an automatic thread count, `1` for
    ///   single-threaded, any other value for multi-threaded with exactly that many threads.
    /// * `rhs_state_deps`, `rhs_input_deps`, `mass_state_deps` - `(row, col)` dependency lists
    ///   stored in the context for later sparsity-pattern construction.
    /// * `ctx` - Context used to allocate vectors.
    ///
    /// # Errors
    ///
    /// Returns [`DiffsolError::DiffslCompilerError`] if the compiler cannot be created from the
    /// external module.
    pub fn new_external(
        nthreads: usize,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        ctx: M::C,
    ) -> Result<Self, DiffsolError> {
        use diffsl::execution::compiler::CompilerMode;

        let mode = if nthreads == 1 {
            CompilerMode::SingleThreaded
        } else {
            // `0` maps to `None`, i.e. an automatically chosen thread count.
            CompilerMode::MultiThreaded((nthreads != 0).then_some(nthreads))
        };
        let compiler = Compiler::from_codegen_module(ExternalModule::default(), mode)
            .map_err(|err| DiffsolError::DiffslCompilerError(err.to_string()))?;
        Self::new_common(compiler, rhs_state_deps, rhs_input_deps, mass_state_deps, ctx)
    }
}
202
203impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModuleLink + CodegenModuleJit> DiffSlContext<M, CG> {
204    fn new_from_object(
205        object: Vec<u8>,
206        nthreads: usize,
207        rhs_state_deps: Vec<(usize, usize)>,
208        rhs_input_deps: Vec<(usize, usize)>,
209        mass_state_deps: Vec<(usize, usize)>,
210        ctx: M::C,
211    ) -> Result<Self, DiffsolError> {
212        let mode = match nthreads {
213            0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None),
214            1 => diffsl::execution::compiler::CompilerMode::SingleThreaded,
215            _ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
216        };
217        let compiler = Compiler::from_object_file(object.clone(), mode)
218            .map_err(|e| DiffsolError::DiffslCompilerError(e.to_string()))?;
219
220        Self::new_common(
221            compiler,
222            rhs_state_deps,
223            rhs_input_deps,
224            mass_state_deps,
225            ctx,
226        )
227    }
228}
229
230impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContext<M, CG> {
231    /// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/).
232    /// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE.
233    ///
234    /// # Arguments
235    ///
236    /// * `text` - The text of the ODE equations in the DiffSL language.
237    /// * `nthreads` - The number of threads to use for code generation (0 for automatic, 1 for single-threaded).
238    ///
239    pub fn new(text: &str, nthreads: usize, ctx: M::C) -> Result<Self, DiffsolError> {
240        let mode = match nthreads {
241            0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None),
242            1 => diffsl::execution::compiler::CompilerMode::SingleThreaded,
243            _ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
244        };
245        let options = diffsl::execution::compiler::CompilerOptions {
246            mode,
247            ..Default::default()
248        };
249        let model =
250            parse_ds_string(text).map_err(|e| DiffsolError::DiffslParserError(e.to_string()))?;
251        let mut model = DiscreteModel::build("diffsol", &model)
252            .map_err(|e| DiffsolError::DiffslCompilerError(e.as_error_message(text)))?;
253        let compiler = Compiler::from_discrete_model(&model, options, Some(text))
254            .map_err(|e| DiffsolError::DiffslCompilerError(e.to_string()))?;
255        let rhs_state_deps = model.take_rhs_state_deps();
256        let rhs_input_deps = model.take_rhs_input_deps();
257        let mass_state_deps = model.take_mass_state_deps();
258
259        Self::new_common(
260            compiler,
261            rhs_state_deps,
262            rhs_input_deps,
263            mass_state_deps,
264            ctx,
265        )
266    }
267}
268
269impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModuleJit + CodegenModuleCompile> Default
270    for DiffSlContext<M, CG>
271{
272    fn default() -> Self {
273        Self::new(
274            "
275            u { y = 1 }
276            F { -y }
277            out { y }
278        ",
279            1,
280            M::C::default(),
281        )
282        .unwrap()
283    }
284}
285
286/// DiffSl implementation of ODE equations. This uses the [DiffSL language](https://martinjrobins.github.io/diffsl/) to specify the ODE equations.
287///
288/// The DiffSL code is compiled into the [DiffSlContext] which is used to evaluate the ODE equations. After compilation,
289/// if the matrix type is sparse, the sparsity patterns of the Jacobians are extracted from the compiled code for use in the ODE solver.
290pub struct DiffSl<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> {
291    context: DiffSlContext<M, CG>,
292    include_sensitivities: bool,
293    mass_sparsity: Option<M::Sparsity>,
294    mass_coloring: Option<JacobianColoring<M>>,
295    mass_transpose_sparsity: Option<M::Sparsity>,
296    mass_transpose_coloring: Option<JacobianColoring<M>>,
297    rhs_sparsity: Option<M::Sparsity>,
298    rhs_coloring: Option<JacobianColoring<M>>,
299    rhs_adjoint_sparsity: Option<M::Sparsity>,
300    rhs_adjoint_coloring: Option<JacobianColoring<M>>,
301    rhs_sens_sparsity: Option<M::Sparsity>,
302    rhs_sens_coloring: Option<JacobianColoring<M>>,
303    rhs_sens_adjoint_sparsity: Option<M::Sparsity>,
304    rhs_sens_adjoint_coloring: Option<JacobianColoring<M>>,
305}
306
307impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> DiffSl<M, CG> {
308    /// Create a `DiffSl` instance from a pre-compiled `DiffSlContext`.
309    ///
310    /// This function extracts the sparsity patterns and Jacobian colorings from the compiled
311    /// context if the matrix type is sparse. The sparsity patterns are used by ODE solvers
312    /// to efficiently compute Jacobians using finite differences with coloring.
313    ///
314    /// # Arguments
315    ///
316    /// * `context` - A pre-compiled DiffSL context containing the compiled code
317    /// * `include_sensitivities` - Whether to extract sparsity patterns for sensitivity computations.
318    ///   If `true`, extracts sparsity patterns for forward and adjoint sensitivities. Set to `true`
319    ///   if you plan to compute sensitivities or adjoints.
320    ///
321    /// # Returns
322    ///
323    /// A new `DiffSl` instance with sparsity patterns extracted (if applicable).
324    ///
325    /// # Note
326    ///
327    /// For dense matrices, this function simply wraps the context without extracting sparsity patterns.
328    pub fn from_context(context: DiffSlContext<M, CG>, include_sensitivities: bool) -> Self {
329        let mut ret = Self {
330            context,
331            include_sensitivities,
332            mass_coloring: None,
333            mass_sparsity: None,
334            mass_transpose_coloring: None,
335            mass_transpose_sparsity: None,
336            rhs_coloring: None,
337            rhs_sparsity: None,
338            rhs_adjoint_coloring: None,
339            rhs_adjoint_sparsity: None,
340            rhs_sens_coloring: None,
341            rhs_sens_sparsity: None,
342            rhs_sens_adjoint_coloring: None,
343            rhs_sens_adjoint_sparsity: None,
344        };
345        if M::is_sparse() {
346            let op = ret.rhs();
347            let ctx = op.context().clone();
348            let n = op.nstates();
349            let nparams = op.nparams();
350
351            let non_zeros = ret.context.rhs_state_deps.as_slice();
352
353            let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.to_vec())
354                .expect("invalid sparsity pattern");
355            let coloring = JacobianColoring::new(&sparsity, non_zeros, ctx.clone());
356            ret.rhs_coloring = Some(coloring);
357            ret.rhs_sparsity = Some(sparsity);
358
359            let non_zeros = non_zeros.iter().map(|(i, j)| (*j, *i)).collect::<Vec<_>>();
360            let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.clone())
361                .expect("invalid sparsity pattern");
362            let coloring = JacobianColoring::new(&sparsity, &non_zeros, ctx.clone());
363            ret.rhs_adjoint_sparsity = Some(sparsity);
364            ret.rhs_adjoint_coloring = Some(coloring);
365
366            if nparams > 0 && include_sensitivities {
367                let non_zeros = ret.context.rhs_input_deps.as_slice();
368
369                let sparsity = M::Sparsity::try_from_indices(n, nparams, non_zeros.to_vec())
370                    .expect("invalid sparsity pattern");
371                let coloring = JacobianColoring::new(&sparsity, non_zeros, ctx.clone());
372                ret.rhs_sens_coloring = Some(coloring);
373                ret.rhs_sens_sparsity = Some(sparsity);
374
375                let non_zeros = non_zeros.iter().map(|(i, j)| (*j, *i)).collect::<Vec<_>>();
376                let sparsity = M::Sparsity::try_from_indices(nparams, n, non_zeros.clone())
377                    .expect("invalid sparsity pattern");
378                let coloring = JacobianColoring::new(&sparsity, &non_zeros, ctx.clone());
379                ret.rhs_sens_adjoint_sparsity = Some(sparsity);
380                ret.rhs_sens_adjoint_coloring = Some(coloring);
381            }
382
383            let non_zeros = ret.context.mass_state_deps.as_slice();
384            if let Some(op) = ret.mass() {
385                let ctx = op.context().clone();
386                let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.to_vec())
387                    .expect("invalid sparsity pattern");
388                let coloring = JacobianColoring::new(&sparsity, non_zeros, op.context().clone());
389                ret.mass_coloring = Some(coloring);
390                ret.mass_sparsity = Some(sparsity);
391
392                let non_zeros = non_zeros.iter().map(|(i, j)| (*j, *i)).collect::<Vec<_>>();
393                let sparsity = M::Sparsity::try_from_indices(n, n, non_zeros.clone())
394                    .expect("invalid sparsity pattern");
395                let coloring = JacobianColoring::new(&sparsity, &non_zeros, ctx);
396                ret.mass_transpose_sparsity = Some(sparsity);
397                ret.mass_transpose_coloring = Some(coloring);
398            }
399        }
400        ret
401    }
402
403    /// Set the active DiffSL model index together with parameters.
404    ///
405    /// This updates the compiler input block and then recomputes constants via `set_u0`.
406    pub fn set_params_and_model(&mut self, p: &M::V, model_index: u32) {
407        self.context.model_index = model_index;
408        self.context.compiler.set_inputs(
409            p.as_slice(),
410            self.context.data.borrow_mut().as_mut_slice(),
411            self.context.model_index,
412        );
413        let mut dummy = M::V::zeros(self.context.nstates, self.context.ctx.clone());
414        self.context.compiler.set_u0(
415            dummy.as_mut_slice(),
416            self.context.data.borrow_mut().as_mut_slice(),
417        );
418    }
419}
420
#[cfg(feature = "diffsl-external-dynamic")]
impl<M: MatrixHost<T: DiffSlScalar>> DiffSl<M, ExternalDynModule<M::T>> {
    /// Create a `DiffSl` instance whose model functions live in an external dynamic library,
    /// with the sparsity information supplied by the caller.
    ///
    /// # Arguments
    ///
    /// * `path` - Location of the external dynamic library.
    /// * `ctx` - The computational context for vector and matrix operations (e.g., CPU, GPU).
    /// * `rhs_state_deps` - `(row, col)` sparsity pattern of the RHS Jacobian (∂f/∂y). May be
    ///   empty if `M` is dense.
    /// * `rhs_input_deps` - `(row, col)` sparsity pattern of the RHS sensitivity matrix (∂f/∂p).
    ///   May be empty if `M` is dense or there are no parameters.
    /// * `mass_state_deps` - `(row, col)` sparsity pattern of the mass matrix Jacobian (∂M/∂y).
    ///   May be empty if there is no mass matrix or `M` is dense.
    /// * `include_sensitivities` - Whether to also build the patterns needed for forward and
    ///   adjoint sensitivity analysis. Passing `false` skips that setup.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the underlying [`DiffSlContext`].
    pub fn from_external_dynamic(
        path: impl Into<PathBuf>,
        ctx: M::C,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolError> {
        // The context is always created in single-threaded mode here.
        DiffSlContext::<M, ExternalDynModule<M::T>>::new_external_dynamic(
            path,
            1,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            ctx,
        )
        .map(|context| Self::from_context(context, include_sensitivities))
    }
}
459
#[cfg(feature = "diffsl-external")]
impl<M: MatrixHost<T: DiffSlScalar + ExternSymbols>> DiffSl<M, ExternalModule<M::T>> {
    /// Create a `DiffSl` instance whose model functions are provided by externally linked
    /// symbols, with the sparsity information supplied by the caller.
    ///
    /// # Arguments
    ///
    /// * `ctx` - The computational context for vector and matrix operations (e.g., CPU, GPU).
    /// * `rhs_state_deps` - `(row, col)` sparsity pattern of the RHS Jacobian (∂f/∂y). May be
    ///   empty if `M` is dense.
    /// * `rhs_input_deps` - `(row, col)` sparsity pattern of the RHS sensitivity matrix (∂f/∂p).
    ///   May be empty if `M` is dense or there are no parameters.
    /// * `mass_state_deps` - `(row, col)` sparsity pattern of the mass matrix Jacobian (∂M/∂y).
    ///   May be empty if there is no mass matrix or `M` is dense.
    /// * `include_sensitivities` - Whether to also build the patterns needed for forward and
    ///   adjoint sensitivity analysis. Passing `false` skips that setup.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the underlying [`DiffSlContext`].
    pub fn from_external(
        ctx: M::C,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolError> {
        // The context is always created in single-threaded mode here.
        DiffSlContext::<M, ExternalModule<M::T>>::new_external(
            1,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            ctx,
        )
        .map(|context| Self::from_context(context, include_sensitivities))
    }
}
495
496impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl<M, CG> {
497    /// Compile DiffSL code into ODE equations.
498    ///
499    /// This is a convenience function that creates a new `DiffSlContext` from the provided code
500    /// and then calls `from_context` to create the `DiffSl` instance. For more control over
501    /// the compilation process (e.g., number of threads), create the context directly using
502    /// `DiffSlContext::new` and then call `from_context`.
503    ///
504    /// # Arguments
505    ///
506    /// * `code` - The DiffSL code defining the ODE system
507    /// * `ctx` - The context for creating vectors and matrices (typically `Default::default()`)
508    /// * `include_sensitivities` - Whether to extract sparsity patterns for sensitivity computations.
509    ///   Set to `true` if you plan to compute sensitivities or adjoints.
510    ///
511    /// # Returns
512    ///
513    /// A new `DiffSl` instance that implements `OdeEquations` and can be used with ODE solvers.
514    ///
515    /// # Errors
516    ///
517    /// Returns an error if the DiffSL code cannot be parsed or compiled.
518    pub fn compile(
519        code: &str,
520        ctx: M::C,
521        include_sensitivities: bool,
522    ) -> Result<Self, DiffsolError> {
523        let context = DiffSlContext::<M, CG>::new(code, 1, ctx)?;
524        Ok(Self::from_context(context, include_sensitivities))
525    }
526}
527
528impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule + CodegenModuleEmit> DiffSl<M, CG> {
529    fn to_external_object(&self) -> Result<DiffSlExternalObject, DiffsolError> {
530        let object = self
531            .context
532            .compiler
533            .module()
534            .to_object()
535            .map_err(|e| DiffsolError::DiffslCompilerError(e.to_string()))?;
536        Ok(DiffSlExternalObject {
537            scalar_type: diffsl_external_scalar_type::<M::T>()?,
538            object,
539            rhs_state_deps: self.context.rhs_state_deps.clone(),
540            rhs_input_deps: self.context.rhs_input_deps.clone(),
541            mass_state_deps: self.context.mass_state_deps.clone(),
542            include_sensitivities: self.include_sensitivities,
543        })
544    }
545}
546
547impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> DiffSl<M, CG>
548where
549    CG: CodegenModuleLink + CodegenModuleJit,
550{
551    fn from_external_object(
552        external_object: DiffSlExternalObject,
553        ctx: M::C,
554    ) -> Result<Self, DiffsolError> {
555        let expected_scalar_type = diffsl_external_scalar_type::<M::T>()?;
556        if external_object.scalar_type != expected_scalar_type {
557            return Err(DiffsolError::Other(format!(
558                "DiffSl external object scalar type mismatch: object is {:?}, requested {:?}",
559                external_object.scalar_type, expected_scalar_type
560            )));
561        }
562        let context = DiffSlContext::<M, CG>::new_from_object(
563            external_object.object,
564            1,
565            external_object.rhs_state_deps,
566            external_object.rhs_input_deps,
567            external_object.mass_state_deps,
568            ctx,
569        )?;
570        Ok(Self::from_context(
571            context,
572            external_object.include_sensitivities,
573        ))
574    }
575}
576
#[cfg(feature = "diffsl-llvm")]
impl<M: MatrixHost<T: DiffSlScalar>> Serialize for DiffSl<M, crate::LlvmModule> {
    /// Serialise the model as emitted object code plus its dependency metadata.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let payload = self
            .to_external_object()
            .map_err(<S::Error as serde::ser::Error>::custom)?;
        payload.serialize(serializer)
    }
}
588
589impl<M: MatrixHost<T: DiffSlScalar>> Serialize for DiffSl<M, ObjectModule> {
590    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
591    where
592        S: Serializer,
593    {
594        self.to_external_object()
595            .map_err(serde::ser::Error::custom)?
596            .serialize(serializer)
597    }
598}
599
600impl<'de, M: MatrixHost<T: DiffSlScalar>> Deserialize<'de> for DiffSl<M, ObjectModule> {
601    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
602    where
603        D: Deserializer<'de>,
604    {
605        let payload = DiffSlExternalObject::deserialize(deserializer)?;
606        Self::from_external_object(payload, M::C::default()).map_err(serde::de::Error::custom)
607    }
608}
609
610pub struct DiffSlRoot<'a, M: Matrix<T: DiffSlScalar>, CG: CodegenModule>(&'a DiffSl<M, CG>);
611pub struct DiffSlReset<'a, M: Matrix<T: DiffSlScalar>, CG: CodegenModule>(&'a DiffSl<M, CG>);
612pub struct DiffSlOut<'a, M: Matrix<T: DiffSlScalar>, CG: CodegenModule>(&'a DiffSl<M, CG>);
613pub struct DiffSlRhs<'a, M: Matrix<T: DiffSlScalar>, CG: CodegenModule>(&'a DiffSl<M, CG>);
614pub struct DiffSlMass<'a, M: Matrix<T: DiffSlScalar>, CG: CodegenModule>(&'a DiffSl<M, CG>);
615pub struct DiffSlInit<'a, M: Matrix<T: DiffSlScalar>, CG: CodegenModule>(&'a DiffSl<M, CG>);
616
// Implements `Op` for a DiffSL operator wrapper whose output is state-sized
// (`nout == nstates`).
macro_rules! impl_op_for_diffsl {
    ($wrapper:ident) => {
        impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> Op for $wrapper<'_, M, CG> {
            type T = M::T;
            type V = M::V;
            type M = M;
            type C = M::C;

            fn context(&self) -> &Self::C {
                &self.0.context.ctx
            }

            fn nparams(&self) -> usize {
                self.0.context.nparams
            }

            fn nstates(&self) -> usize {
                self.0.context.nstates
            }

            // Deliberately `nstates`: the context's own `nout` field is the length of the
            // `out` operator, which is why clippy would otherwise flag this getter.
            #[allow(clippy::misnamed_getters)]
            fn nout(&self) -> usize {
                self.0.context.nstates
            }
        }
    };
}
641
642impl_op_for_diffsl!(DiffSlRhs);
643impl_op_for_diffsl!(DiffSlMass);
644
645impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> Op for DiffSlInit<'_, M, CG> {
646    type M = M;
647    type T = M::T;
648    type V = M::V;
649    type C = M::C;
650
651    fn nstates(&self) -> usize {
652        self.0.context.nstates
653    }
654    #[allow(clippy::misnamed_getters)]
655    fn nout(&self) -> usize {
656        self.0.context.nstates
657    }
658    fn nparams(&self) -> usize {
659        self.0.context.nparams
660    }
661    fn context(&self) -> &Self::C {
662        &self.0.context.ctx
663    }
664}
665
666impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> Op for DiffSlRoot<'_, M, CG> {
667    type M = M;
668    type T = M::T;
669    type V = M::V;
670    type C = M::C;
671
672    fn nstates(&self) -> usize {
673        self.0.context.nstates
674    }
675    #[allow(clippy::misnamed_getters)]
676    fn nout(&self) -> usize {
677        self.0.context.nroots
678    }
679    fn nparams(&self) -> usize {
680        self.0.context.nparams
681    }
682    fn context(&self) -> &Self::C {
683        &self.0.context.ctx
684    }
685}
686
687impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> Op for DiffSlReset<'_, M, CG> {
688    type M = M;
689    type T = M::T;
690    type V = M::V;
691    type C = M::C;
692
693    fn nstates(&self) -> usize {
694        self.0.context.nstates
695    }
696    #[allow(clippy::misnamed_getters)]
697    fn nout(&self) -> usize {
698        self.0.context.nstates
699    }
700    fn nparams(&self) -> usize {
701        self.0.context.nparams
702    }
703    fn context(&self) -> &Self::C {
704        &self.0.context.ctx
705    }
706}
707
708impl<M: Matrix<T: DiffSlScalar>, CG: CodegenModule> Op for DiffSlOut<'_, M, CG> {
709    type M = M;
710    type T = M::T;
711    type V = M::V;
712    type C = M::C;
713
714    fn nstates(&self) -> usize {
715        self.0.context.nstates
716    }
717    fn nout(&self) -> usize {
718        self.0.context.nout
719    }
720    fn nparams(&self) -> usize {
721        self.0.context.nparams
722    }
723    fn context(&self) -> &Self::C {
724        &self.0.context.ctx
725    }
726}
727
728impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> ConstantOp for DiffSlInit<'_, M, CG> {
729    fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
730        self.0.context.compiler.set_u0(
731            y.as_mut_slice(),
732            self.0.context.data.borrow_mut().as_mut_slice(),
733        );
734    }
735}
736
737impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> ConstantOpSens for DiffSlInit<'_, M, CG> {
738    fn sens_mul_inplace(&self, _t: Self::T, v: &Self::V, y: &mut Self::V) {
739        self.0.context.compiler.set_inputs(
740            v.as_slice(),
741            self.0.context.sens_data.borrow_mut().as_mut_slice(),
742            self.0.context.model_index,
743        );
744        self.0.context.compiler.set_u0_sgrad(
745            self.0.context.tmp.borrow().as_slice(),
746            y.as_mut_slice(),
747            self.0.context.data.borrow_mut().as_mut_slice(),
748            self.0.context.sens_data.borrow_mut().as_mut_slice(),
749        );
750    }
751}
752
753impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> ConstantOpSensAdjoint
754    for DiffSlInit<'_, M, CG>
755{
756    fn sens_transpose_mul_inplace(&self, _t: Self::T, v: &Self::V, y: &mut Self::V) {
757        // copy v to tmp2
758        let mut tmp2 = self.0.context.tmp2.borrow_mut();
759        tmp2.copy_from(v);
760        // zero out sens_data
761        self.0.context.sens_data.borrow_mut().fill(M::T::zero());
762        self.0.context.compiler.set_u0_rgrad(
763            self.0.context.tmp.borrow().as_slice(),
764            tmp2.as_mut_slice(),
765            self.0.context.data.borrow().as_slice(),
766            self.0.context.sens_data.borrow_mut().as_mut_slice(),
767        );
768        self.0.context.compiler.get_inputs(
769            y.as_mut_slice(),
770            self.0.context.sens_data.borrow().as_slice(),
771        );
772        // negate y
773        y.mul_assign(Scale(-M::T::one()));
774    }
775}
776
777impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOp for DiffSlRoot<'_, M, CG> {
778    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
779        self.0.context.compiler.calc_stop(
780            t,
781            x.as_slice(),
782            self.0.context.data.borrow_mut().as_mut_slice(),
783            y.as_mut_slice(),
784        );
785    }
786}
787
788impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpJacobian
789    for DiffSlRoot<'_, M, CG>
790{
791    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
792        let stop = self.0.context.tmp_root.borrow();
793        self.0.context.compiler.calc_stop_grad(
794            t,
795            x.as_slice(),
796            v.as_slice(),
797            self.0.context.data.borrow().as_slice(),
798            self.0.context.ddata.borrow_mut().as_mut_slice(),
799            stop.as_slice(),
800            y.as_mut_slice(),
801        );
802    }
803}
804
805impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpAdjoint
806    for DiffSlRoot<'_, M, CG>
807{
808    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
809        let stop = self.0.context.tmp_root.borrow();
810        let mut tmp2_root = self.0.context.tmp2_root.borrow_mut();
811        tmp2_root.copy_from(v);
812        self.0.context.ddata.borrow_mut().fill(M::T::zero());
813        y.fill(M::T::zero());
814        self.0.context.compiler.calc_stop_rgrad(
815            t,
816            x.as_slice(),
817            y.as_mut_slice(),
818            self.0.context.data.borrow().as_slice(),
819            self.0.context.ddata.borrow_mut().as_mut_slice(),
820            stop.as_slice(),
821            tmp2_root.as_mut_slice(),
822        );
823        y.mul_assign(Scale(-M::T::one()));
824    }
825}
826
827impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSens for DiffSlRoot<'_, M, CG> {
828    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
829        let stop = self.0.context.tmp_root.borrow();
830        self.0.context.compiler.set_inputs(
831            v.as_slice(),
832            self.0.context.sens_data.borrow_mut().as_mut_slice(),
833            self.0.context.model_index,
834        );
835        self.0.context.compiler.calc_stop_sgrad(
836            t,
837            x.as_slice(),
838            self.0.context.data.borrow().as_slice(),
839            self.0.context.sens_data.borrow_mut().as_mut_slice(),
840            stop.as_slice(),
841            y.as_mut_slice(),
842        );
843    }
844}
845
846impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSensAdjoint
847    for DiffSlRoot<'_, M, CG>
848{
849    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
850        let stop = self.0.context.tmp_root.borrow();
851        let mut tmp2_root = self.0.context.tmp2_root.borrow_mut();
852        tmp2_root.copy_from(v);
853        self.0.context.sens_data.borrow_mut().fill(M::T::zero());
854        self.0.context.compiler.calc_stop_srgrad(
855            t,
856            x.as_slice(),
857            self.0.context.data.borrow().as_slice(),
858            self.0.context.sens_data.borrow_mut().as_mut_slice(),
859            stop.as_slice(),
860            tmp2_root.as_mut_slice(),
861        );
862        self.0.context.compiler.get_inputs(
863            y.as_mut_slice(),
864            self.0.context.sens_data.borrow().as_slice(),
865        );
866        y.mul_assign(Scale(-M::T::one()));
867    }
868}
869
870impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOp for DiffSlReset<'_, M, CG> {
871    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
872        self.0.context.compiler.reset(
873            t,
874            x.as_slice(),
875            self.0.context.data.borrow_mut().as_mut_slice(),
876            y.as_mut_slice(),
877        );
878    }
879}
880
881impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpJacobian
882    for DiffSlReset<'_, M, CG>
883{
884    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
885        self.0.context.ddata.borrow_mut().fill(M::T::zero());
886        let tmp = self.0.context.tmp.borrow();
887        self.0.context.compiler.reset_grad(
888            t,
889            x.as_slice(),
890            v.as_slice(),
891            self.0.context.data.borrow_mut().as_slice(),
892            self.0.context.ddata.borrow_mut().as_mut_slice(),
893            tmp.as_slice(),
894            y.as_mut_slice(),
895        );
896    }
897}
898
899impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpAdjoint
900    for DiffSlReset<'_, M, CG>
901{
902    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
903        // copy v to tmp2
904        let mut tmp2 = self.0.context.tmp2.borrow_mut();
905        tmp2.copy_from(v);
906        // zero out ddata
907        self.0.context.ddata.borrow_mut().fill(M::T::zero());
908        // zero y
909        y.fill(M::T::zero());
910        self.0.context.compiler.reset_rgrad(
911            t,
912            x.as_slice(),
913            y.as_mut_slice(),
914            self.0.context.data.borrow().as_slice(),
915            self.0.context.ddata.borrow_mut().as_mut_slice(),
916            self.0.context.tmp.borrow().as_slice(),
917            tmp2.as_mut_slice(),
918        );
919        // negate y
920        y.mul_assign(Scale(-M::T::one()));
921    }
922}
923
924impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSens for DiffSlReset<'_, M, CG> {
925    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
926        let tmp = self.0.context.tmp.borrow();
927        self.0.context.compiler.set_inputs(
928            v.as_slice(),
929            self.0.context.sens_data.borrow_mut().as_mut_slice(),
930            self.0.context.model_index,
931        );
932        self.0.context.compiler.reset_sgrad(
933            t,
934            x.as_slice(),
935            self.0.context.data.borrow_mut().as_slice(),
936            self.0.context.sens_data.borrow_mut().as_mut_slice(),
937            tmp.as_slice(),
938            y.as_mut_slice(),
939        );
940    }
941}
942
943impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSensAdjoint
944    for DiffSlReset<'_, M, CG>
945{
946    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
947        let tmp = self.0.context.tmp.borrow();
948        // copy v to tmp2
949        let mut tmp2 = self.0.context.tmp2.borrow_mut();
950        tmp2.copy_from(v);
951        // zero out sens_data
952        self.0.context.sens_data.borrow_mut().fill(M::T::zero());
953        self.0.context.compiler.reset_srgrad(
954            t,
955            x.as_slice(),
956            self.0.context.data.borrow_mut().as_mut_slice(),
957            self.0.context.sens_data.borrow_mut().as_mut_slice(),
958            tmp.as_slice(),
959            tmp2.as_mut_slice(),
960        );
961        // get inputs
962        self.0.context.compiler.get_inputs(
963            y.as_mut_slice(),
964            self.0.context.sens_data.borrow().as_slice(),
965        );
966        // negate y
967        y.mul_assign(Scale(-M::T::one()));
968    }
969}
970
971impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOp for DiffSlOut<'_, M, CG> {
972    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
973        self.0.context.compiler.calc_out(
974            t,
975            x.as_slice(),
976            self.0.context.data.borrow_mut().as_mut_slice(),
977            y.as_mut_slice(),
978        );
979    }
980}
981
982impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpJacobian
983    for DiffSlOut<'_, M, CG>
984{
985    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
986        // init ddata with all zero except for out
987        let mut ddata = self.0.context.ddata.borrow_mut();
988        ddata.fill(M::T::zero());
989        self.0.context.compiler.calc_out_grad(
990            t,
991            x.as_slice(),
992            v.as_slice(),
993            self.0.context.data.borrow_mut().as_mut_slice(),
994            ddata.as_mut_slice(),
995            self.0.context.tmp_out.borrow().as_slice(),
996            y.as_mut_slice(),
997        );
998    }
999}
1000
1001impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpAdjoint
1002    for DiffSlOut<'_, M, CG>
1003{
1004    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
1005        // init ddata with all zero except for out
1006        let mut ddata = self.0.context.ddata.borrow_mut();
1007        ddata.fill(M::T::zero());
1008        let mut tmp2_out = self.0.context.tmp2_out.borrow_mut();
1009        tmp2_out.copy_from(v);
1010        // zero y
1011        y.fill(M::T::zero());
1012        self.0.context.compiler.calc_out_rgrad(
1013            t,
1014            x.as_slice(),
1015            y.as_mut_slice(),
1016            self.0.context.data.borrow_mut().as_slice(),
1017            ddata.as_mut_slice(),
1018            self.0.context.tmp_out.borrow().as_slice(),
1019            tmp2_out.as_mut_slice(),
1020        );
1021        // negate y
1022        y.mul_assign(Scale(-M::T::one()));
1023    }
1024}
1025
1026impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSens for DiffSlOut<'_, M, CG> {
1027    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
1028        // set inputs for sens_data
1029        self.0.context.compiler.set_inputs(
1030            v.as_slice(),
1031            self.0.context.sens_data.borrow_mut().as_mut_slice(),
1032            self.0.context.model_index,
1033        );
1034        self.0.context.compiler.calc_out_sgrad(
1035            t,
1036            x.as_slice(),
1037            self.0.context.data.borrow_mut().as_mut_slice(),
1038            self.0.context.sens_data.borrow_mut().as_mut_slice(),
1039            self.0.context.tmp_out.borrow().as_slice(),
1040            y.as_mut_slice(),
1041        );
1042    }
1043}
1044
1045impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSensAdjoint
1046    for DiffSlOut<'_, M, CG>
1047{
1048    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
1049        let mut sens_data = self.0.context.sens_data.borrow_mut();
1050        // set outputs for sens_data (zero everything except for out)
1051        sens_data.fill(M::T::zero());
1052        let mut tmp2_out = self.0.context.tmp2_out.borrow_mut();
1053        tmp2_out.copy_from(v);
1054        self.0.context.compiler.calc_out_srgrad(
1055            t,
1056            x.as_slice(),
1057            self.0.context.data.borrow_mut().as_mut_slice(),
1058            sens_data.as_mut_slice(),
1059            self.0.context.tmp_out.borrow().as_slice(),
1060            tmp2_out.as_mut_slice(),
1061        );
1062        // set y to the result in inputs
1063        self.0
1064            .context
1065            .compiler
1066            .get_inputs(y.as_mut_slice(), sens_data.as_slice());
1067        // negate y
1068        y.mul_assign(Scale(-M::T::one()));
1069    }
1070}
1071
1072impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOp for DiffSlRhs<'_, M, CG> {
1073    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
1074        self.0.context.compiler.rhs(
1075            t,
1076            x.as_slice(),
1077            self.0.context.data.borrow_mut().as_mut_slice(),
1078            y.as_mut_slice(),
1079        );
1080    }
1081}
1082
1083impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpJacobian
1084    for DiffSlRhs<'_, M, CG>
1085{
1086    fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
1087        self.0.context.ddata.borrow_mut().fill(M::T::zero());
1088        let tmp = self.0.context.tmp.borrow();
1089        self.0.context.compiler.rhs_grad(
1090            t,
1091            x.as_slice(),
1092            v.as_slice(),
1093            self.0.context.data.borrow_mut().as_slice(),
1094            self.0.context.ddata.borrow_mut().as_mut_slice(),
1095            tmp.as_slice(),
1096            y.as_mut_slice(),
1097        );
1098    }
1099
1100    fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
1101        if let Some(coloring) = &self.0.rhs_coloring {
1102            coloring.jacobian_inplace(self, x, t, y);
1103        } else {
1104            self._default_jacobian_inplace(x, t, y);
1105        }
1106    }
1107    fn jacobian_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
1108        self.0.rhs_sparsity.clone()
1109    }
1110}
1111
1112impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpAdjoint
1113    for DiffSlRhs<'_, M, CG>
1114{
1115    fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
1116        // copy v to tmp2
1117        let mut tmp2 = self.0.context.tmp2.borrow_mut();
1118        tmp2.copy_from(v);
1119        // zero out ddata
1120        self.0.context.ddata.borrow_mut().fill(M::T::zero());
1121        // zero y
1122        y.fill(M::T::zero());
1123        self.0.context.compiler.rhs_rgrad(
1124            t,
1125            x.as_slice(),
1126            y.as_mut_slice(),
1127            self.0.context.data.borrow().as_slice(),
1128            self.0.context.ddata.borrow_mut().as_mut_slice(),
1129            self.0.context.tmp.borrow().as_slice(),
1130            tmp2.as_mut_slice(),
1131        );
1132        // negate y
1133        y.mul_assign(Scale(-M::T::one()));
1134    }
1135    fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
1136        // if we have a rhs_coloring and no rhs_adjoint_coloring, user has not called prep_adjoint
1137        // fail here
1138        if self.0.rhs_coloring.is_some() && self.0.rhs_adjoint_coloring.is_none() {
1139            panic!("Adjoint not prepared. Call prep_adjoint before calling adjoint_inplace");
1140        }
1141        if let Some(coloring) = &self.0.rhs_adjoint_coloring {
1142            coloring.jacobian_inplace(self, x, t, y);
1143        } else {
1144            self._default_adjoint_inplace(x, t, y);
1145        }
1146    }
1147    fn adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
1148        self.0.rhs_adjoint_sparsity.clone()
1149    }
1150}
1151
1152impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSens for DiffSlRhs<'_, M, CG> {
1153    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
1154        let tmp = self.0.context.tmp.borrow();
1155        self.0.context.compiler.set_inputs(
1156            v.as_slice(),
1157            self.0.context.sens_data.borrow_mut().as_mut_slice(),
1158            self.0.context.model_index,
1159        );
1160        self.0.context.compiler.rhs_sgrad(
1161            t,
1162            x.as_slice(),
1163            self.0.context.data.borrow_mut().as_slice(),
1164            self.0.context.sens_data.borrow_mut().as_mut_slice(),
1165            tmp.as_slice(),
1166            y.as_mut_slice(),
1167        );
1168    }
1169    fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
1170        if let Some(coloring) = &self.0.rhs_sens_coloring {
1171            coloring.sens_inplace(self, x, t, y);
1172        } else {
1173            self._default_sens_inplace(x, t, y);
1174        }
1175    }
1176    fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
1177        self.0.rhs_sens_sparsity.clone()
1178    }
1179}
1180
1181impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> NonLinearOpSensAdjoint
1182    for DiffSlRhs<'_, M, CG>
1183{
1184    fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) {
1185        // todo: would rhs_srgrad ever use rr? I don't think so, but need to check
1186        let tmp = self.0.context.tmp.borrow();
1187        // copy v to tmp2
1188        let mut tmp2 = self.0.context.tmp2.borrow_mut();
1189        tmp2.copy_from(v);
1190        // zero out sens_data
1191        self.0.context.sens_data.borrow_mut().fill(M::T::zero());
1192        self.0.context.compiler.rhs_srgrad(
1193            t,
1194            x.as_slice(),
1195            self.0.context.data.borrow_mut().as_mut_slice(),
1196            self.0.context.sens_data.borrow_mut().as_mut_slice(),
1197            tmp.as_slice(),
1198            tmp2.as_mut_slice(),
1199        );
1200        // get inputs
1201        self.0.context.compiler.get_inputs(
1202            y.as_mut_slice(),
1203            self.0.context.sens_data.borrow().as_slice(),
1204        );
1205        // negate y
1206        y.mul_assign(Scale(-M::T::one()));
1207    }
1208    fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
1209        if let Some(coloring) = &self.0.rhs_sens_adjoint_coloring {
1210            coloring.sens_adjoint_inplace(self, x, t, y);
1211        } else {
1212            self._default_adjoint_inplace(x, t, y);
1213        }
1214    }
1215    fn sens_adjoint_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
1216        self.0.rhs_sens_adjoint_sparsity.clone()
1217    }
1218}
1219
1220impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> LinearOp for DiffSlMass<'_, M, CG> {
1221    fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
1222        let mut tmp = self.0.context.tmp.borrow_mut();
1223        self.0.context.compiler.mass(
1224            t,
1225            x.as_slice(),
1226            self.0.context.data.borrow_mut().as_mut_slice(),
1227            tmp.as_mut_slice(),
1228        );
1229
1230        // y = tmp + beta * y
1231        y.axpy(M::T::one(), &tmp, beta);
1232    }
1233
1234    fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
1235        if let Some(coloring) = &self.0.mass_coloring {
1236            coloring.matrix_inplace(self, t, y);
1237        } else {
1238            self._default_matrix_inplace(t, y);
1239        }
1240    }
1241    fn sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
1242        self.0.mass_sparsity.clone()
1243    }
1244}
1245
1246impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> LinearOpTranspose
1247    for DiffSlMass<'_, M, CG>
1248{
1249    fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) {
1250        // scale y by beta
1251        y.mul_assign(Scale(beta));
1252
1253        // copy x to tmp
1254        let mut tmp = self.0.context.tmp.borrow_mut();
1255        tmp.copy_from(x);
1256
1257        // zero out ddata
1258        self.0.context.ddata.borrow_mut().fill(M::T::zero());
1259
1260        // y += M^T x + beta * y
1261        self.0.context.compiler.mass_rgrad(
1262            t,
1263            y.as_mut_slice(),
1264            self.0.context.data.borrow_mut().as_slice(),
1265            self.0.context.ddata.borrow_mut().as_mut_slice(),
1266            tmp.as_mut_slice(),
1267        );
1268    }
1269
1270    fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
1271        if let Some(coloring) = &self.0.mass_transpose_coloring {
1272            coloring.matrix_inplace(self, t, y);
1273        } else {
1274            self._default_matrix_inplace(t, y);
1275        }
1276    }
1277    fn transpose_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
1278        self.0.mass_transpose_sparsity.clone()
1279    }
1280}
1281
1282impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> Op for DiffSl<M, CG> {
1283    type M = M;
1284    type T = M::T;
1285    type V = M::V;
1286    type C = M::C;
1287
1288    fn nstates(&self) -> usize {
1289        self.context.nstates
1290    }
1291    fn nout(&self) -> usize {
1292        if self.context.has_out {
1293            self.context.nout
1294        } else {
1295            self.context.nstates
1296        }
1297    }
1298    fn nparams(&self) -> usize {
1299        self.context.nparams
1300    }
1301    fn context(&self) -> &Self::C {
1302        &self.context.ctx
1303    }
1304}
1305
1306impl<'a, M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> OdeEquationsRef<'a> for DiffSl<M, CG> {
1307    type Mass = DiffSlMass<'a, M, CG>;
1308    type Rhs = DiffSlRhs<'a, M, CG>;
1309    type Root = DiffSlRoot<'a, M, CG>;
1310    type Init = DiffSlInit<'a, M, CG>;
1311    type Out = DiffSlOut<'a, M, CG>;
1312    type Reset = DiffSlReset<'a, M, CG>;
1313}
1314
1315impl<M: MatrixHost<T: DiffSlScalar>, CG: CodegenModule> OdeEquations for DiffSl<M, CG> {
1316    fn rhs(&self) -> DiffSlRhs<'_, M, CG> {
1317        DiffSlRhs(self)
1318    }
1319
1320    fn mass(&self) -> Option<DiffSlMass<'_, M, CG>> {
1321        self.context.has_mass.then_some(DiffSlMass(self))
1322    }
1323
1324    fn root(&self) -> Option<DiffSlRoot<'_, M, CG>> {
1325        self.context.has_root.then_some(DiffSlRoot(self))
1326    }
1327
1328    fn init(&self) -> DiffSlInit<'_, M, CG> {
1329        DiffSlInit(self)
1330    }
1331
1332    fn out(&self) -> Option<DiffSlOut<'_, M, CG>> {
1333        self.context.has_out.then_some(DiffSlOut(self))
1334    }
1335
1336    fn reset(&self) -> Option<DiffSlReset<'_, M, CG>> {
1337        self.context.has_reset.then_some(DiffSlReset(self))
1338    }
1339
1340    fn set_params(&mut self, p: &Self::V) {
1341        // `set_params` preserves the current model index.
1342        self.set_params_and_model(p, self.context.model_index);
1343    }
1344
1345    fn set_model_index(&mut self, m: usize) {
1346        self.context.model_index = m as u32;
1347        let mut p = M::V::zeros(self.nparams(), self.context.ctx.clone());
1348        self.get_params(&mut p);
1349        self.set_params_and_model(&p, self.context.model_index);
1350    }
1351
1352    fn get_params(&self, p: &mut Self::V) {
1353        self.context
1354            .compiler
1355            .get_inputs(p.as_mut_slice(), self.context.data.borrow().as_slice());
1356    }
1357}
1358
1359#[cfg(test)]
1360mod tests {
1361    use diffsl::execution::{
1362        module::{CodegenModuleCompile, CodegenModuleJit},
1363        scalar::Scalar as DiffSlScalar,
1364    };
1365    #[cfg(feature = "diffsl-llvm")]
1366    use diffsl::ObjectModule;
1367
1368    use crate::{
1369        matrix::MatrixRef,
1370        op::{
1371            linear_op::LinearOp,
1372            nonlinear_op::{NonLinearOp, NonLinearOpJacobian},
1373        },
1374        scalar::Scalar,
1375        ConstantOp, Context, DefaultDenseMatrix, DefaultSolver, DenseMatrix, DiffSlContext,
1376        DiffsolError, Matrix, NonLinearOpAdjoint, NonLinearOpSens, NonLinearOpSensAdjoint,
1377        OdeBuilder, OdeEquations, OdeSolverMethod, Vector, VectorHost, VectorRef, VectorView,
1378    };
1379    use num_traits::ToPrimitive;
1380
1381    use super::DiffSl;
1382    use num_traits::{FromPrimitive, One, Zero};
1383    use paste::paste;
1384    #[cfg(feature = "diffsl-llvm")]
1385    use serde_json;
1386
1387    /// Macro to generate test functions for all combinations of backend (cranelift/llvm) and scalar type (f32/f64)
1388    ///
1389    /// Usage: `generate_tests!(test_name, generic_test_function);`
1390    ///
1391    /// This will generate 4 test functions:
1392    /// - {test_name}_cranelift_f64
1393    /// - {test_name}_cranelift_f32
1394    /// - {test_name}_llvm_f64
1395    /// - {test_name}_llvm_f32
1396    ///
1397    /// Example:
1398    /// ```
1399    /// fn my_test<M: CodegenModuleCompile + CodegenModuleJit, T: Scalar>() { ... }
1400    /// generate_tests!(my_test);
1401    /// ```
1402    macro_rules! generate_tests {
1403        ($test_fn:ident) => {
1404            generate_tests!(@impl $test_fn, cranelift_dense_f64, crate::CraneliftJitModule, crate::NalgebraMat<f64>, "diffsl-cranelift");
1405            generate_tests!(@impl $test_fn, cranelift_sparse_f64, crate::CraneliftJitModule, crate::FaerSparseMat<f64>, "diffsl-cranelift");
1406            generate_tests!(@impl $test_fn, cranelift_dense_f32, crate::CraneliftJitModule, crate::NalgebraMat<f32>, "diffsl-cranelift");
1407            generate_tests!(@impl $test_fn, cranelift_sparse_f32, crate::CraneliftJitModule, crate::FaerSparseMat<f32>, "diffsl-cranelift");
1408            generate_tests!(@impl $test_fn, llvm_dense_f64, crate::LlvmModule, crate::NalgebraMat<f64>, "diffsl-llvm");
1409            generate_tests!(@impl $test_fn, llvm_sparse_f64, crate::LlvmModule, crate::FaerSparseMat<f64>, "diffsl-llvm");
1410            generate_tests!(@impl $test_fn, llvm_dense_f32, crate::LlvmModule, crate::NalgebraMat<f32>, "diffsl-llvm");
1411            generate_tests!(@impl $test_fn, llvm_sparse_f32, crate::LlvmModule, crate::FaerSparseMat<f32>, "diffsl-llvm");
1412        };
1413        (@impl $test_fn:ident, $variant:ident, $module:ty, $matrix:ty, $feature:literal) => {
1414            paste! {
1415                #[cfg(feature = $feature)]
1416                #[test]
1417                fn [<$test_fn _ $variant>]() {
1418                    $test_fn::<$module, $matrix>();
1419                }
1420            }
1421        };
1422    }
1423
1424    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1425    generate_tests!(diffsl_logistic_growth);
1426    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1427    generate_tests!(diffsl_logistic_growth_with_model_index);
1428    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1429    generate_tests!(diffsl_reset_call_and_jac_mul);
1430    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1431    generate_tests!(diffsl_context_handles_thread_modes);
1432    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1433    generate_tests!(diffsl_context_reports_parser_and_compiler_errors);
1434    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1435    generate_tests!(diffsl_root_and_output_operators_work);
1436
1437    // Sensitivity and reverse-mode (adjoint) require LLVM — Cranelift supports neither.
1438    macro_rules! generate_tests_llvm_only {
1439        ($test_fn:ident) => {
1440            generate_tests!(@impl $test_fn, llvm_dense_f64, crate::LlvmModule, crate::NalgebraMat<f64>, "diffsl-llvm");
1441            generate_tests!(@impl $test_fn, llvm_sparse_f64, crate::LlvmModule, crate::FaerSparseMat<f64>, "diffsl-llvm");
1442            generate_tests!(@impl $test_fn, llvm_dense_f32, crate::LlvmModule, crate::NalgebraMat<f32>, "diffsl-llvm");
1443            generate_tests!(@impl $test_fn, llvm_sparse_f32, crate::LlvmModule, crate::FaerSparseMat<f32>, "diffsl-llvm");
1444        };
1445    }
1446
1447    generate_tests_llvm_only!(diffsl_reset_sens_and_adjoint_gradients);
1448    generate_tests_llvm_only!(diffsl_root_sens_gradients);
1449    /// Tests forward evaluation and Jacobian-vector product for DiffSlReset.
1450    /// Runs on all backends (Cranelift + LLVM).
1451    ///
1452    /// Model: reset_i { 2 * y + a, z + a }  with a=3, (y,z)=(3,2), t=0.
1453    ///   J = d(reset)/d(x) = [[2, 0], [0, 1]]
1454    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1455    fn diffsl_reset_call_and_jac_mul<
1456        CG: CodegenModuleJit + CodegenModuleCompile,
1457        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1458    >()
1459    where
1460        for<'b> &'b M::V: VectorRef<M::V>,
1461        for<'b> &'b M: MatrixRef<M>,
1462    {
1463        let text = "
1464            in { a = 1 }
1465            u_i {
1466                y = a,
1467                z = 2,
1468            }
1469            F_i {
1470                y,
1471                z,
1472            }
1473            reset_i {
1474                2 * y + a,
1475                z + a,
1476            }
1477            stop_i {
1478                y - 0.5,
1479            }
1480            out_i {
1481                y,
1482                z,
1483            }
1484        ";
1485
1486        let ctx = M::C::default();
1487        let a = M::T::from_f64(3.0).unwrap();
1488        let p = ctx.vector_from_vec(vec![a]);
1489        let mut eqn = DiffSl::<M, CG>::compile(text, ctx.clone(), false).unwrap();
1490        eqn.set_params(&p);
1491
1492        // x = (y, z) = (a, 2) = (3, 2) after set_params
1493        let x = eqn.init().call(M::T::zero());
1494        let t = M::T::zero();
1495        let reset_op = eqn.reset().expect("model must have a reset operator");
1496
1497        // reset(x, t) = [2*3+3, 2+3] = [9, 5]
1498        let reset_val = reset_op.call(&x, t);
1499        let reset_expected = ctx.vector_from_vec(vec![
1500            M::T::from_f64(9.0).unwrap(),
1501            M::T::from_f64(5.0).unwrap(),
1502        ]);
1503        reset_val.assert_eq_st(&reset_expected, M::T::from_f64(1e-10).unwrap());
1504
1505        // jac_mul: J*v, J=[[2,0],[0,1]], v=[3,-1] => [6,-1]
1506        let v = ctx.vector_from_vec(vec![M::T::from_f64(3.0).unwrap(), -M::T::one()]);
1507        let mut y = ctx.vector_from_vec(vec![M::T::zero(), M::T::zero()]);
1508        reset_op.jac_mul_inplace(&x, t, &v, &mut y);
1509        let jac_mul_expected =
1510            ctx.vector_from_vec(vec![M::T::from_f64(6.0).unwrap(), -M::T::one()]);
1511        y.assert_eq_st(&jac_mul_expected, M::T::from_f64(1e-10).unwrap());
1512    }
1513
1514    /// Tests sensitivity and adjoint gradient products for DiffSlReset.
1515    /// Requires LLVM — Cranelift does not compile sensitivity or reverse-mode autograd.
1516    ///
1517    /// Model: reset_i { 2 * y + a, z + a }  with a=3, (y,z)=(3,2), t=0.
1518    ///   d(reset)/d(a) = [1, 1]
1519    ///   J^T = [[2, 0], [0, 1]] (diagonal, same as J)
1520    ///
1521    /// Note: jac_transpose_mul and sens_transpose_mul return negated values
1522    ///       (same convention as rhs adjoint).
1523    #[allow(dead_code)]
1524    fn diffsl_reset_sens_and_adjoint_gradients<
1525        CG: CodegenModuleJit + CodegenModuleCompile,
1526        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1527    >()
1528    where
1529        for<'b> &'b M::V: VectorRef<M::V>,
1530        for<'b> &'b M: MatrixRef<M>,
1531    {
1532        let text = "
1533            in { a = 1 }
1534            u_i {
1535                y = a,
1536                z = 2,
1537            }
1538            F_i {
1539                y,
1540                z,
1541            }
1542            reset_i {
1543                2 * y + a,
1544                z + a,
1545            }
1546            stop_i {
1547                y - 0.5,
1548            }
1549            out_i {
1550                y,
1551                z,
1552            }
1553        ";
1554
1555        let ctx = M::C::default();
1556        let a = M::T::from_f64(3.0).unwrap();
1557        let p = ctx.vector_from_vec(vec![a]);
1558        let mut eqn = DiffSl::<M, CG>::compile(text, ctx.clone(), false).unwrap();
1559        eqn.set_params(&p);
1560
1561        let x = eqn.init().call(M::T::zero());
1562        let t = M::T::zero();
1563        let reset_op = eqn.reset().expect("model must have a reset operator");
1564
1565        let v = ctx.vector_from_vec(vec![M::T::from_f64(3.0).unwrap(), -M::T::one()]);
1566
1567        // sens_mul: (d_reset/d_a)*vp, d/da=[1,1], vp=[2] => [2,2]
1568        let vp = ctx.vector_from_vec(vec![M::T::from_f64(2.0).unwrap()]);
1569        let mut y = ctx.vector_from_vec(vec![M::T::zero(), M::T::zero()]);
1570        reset_op.sens_mul_inplace(&x, t, &vp, &mut y);
1571        let sens_expected = ctx.vector_from_vec(vec![
1572            M::T::from_f64(2.0).unwrap(),
1573            M::T::from_f64(2.0).unwrap(),
1574        ]);
1575        y.assert_eq_st(&sens_expected, M::T::from_f64(1e-10).unwrap());
1576
1577        // jac_transpose_mul: -J^T*v, J=[[2,0],[0,1]], v=[3,-1] => -[6,-1] = [-6,1]
1578        let mut y = ctx.vector_from_vec(vec![M::T::zero(), M::T::zero()]);
1579        reset_op.jac_transpose_mul_inplace(&x, t, &v, &mut y);
1580        let jac_adj_expected =
1581            ctx.vector_from_vec(vec![M::T::from_f64(-6.0).unwrap(), M::T::one()]);
1582        y.assert_eq_st(&jac_adj_expected, M::T::from_f64(1e-10).unwrap());
1583
1584        // sens_transpose_mul: -(d_reset/d_a)^T*v = -(1*3 + 1*(-1)) = -2
1585        let mut y_p = ctx.vector_from_vec(vec![M::T::zero()]);
1586        reset_op.sens_transpose_mul_inplace(&x, t, &v, &mut y_p);
1587        let sens_adj_expected = ctx.vector_from_vec(vec![M::T::from_f64(-2.0).unwrap()]);
1588        y_p.assert_eq_st(&sens_adj_expected, M::T::from_f64(1e-10).unwrap());
1589    }
1590
1591    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1592    fn diffsl_logistic_growth<
1593        CG: CodegenModuleJit + CodegenModuleCompile,
1594        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1595    >()
1596    where
1597        for<'b> &'b M::V: VectorRef<M::V>,
1598        for<'b> &'b M: MatrixRef<M>,
1599    {
1600        let text = "
1601            in_i { r = 1, k = 1 }
1602            u_i {
1603                y = 0.1,
1604                z = 0,
1605            }
1606            dudt_i {
1607                dydt = 0,
1608                dzdt = 0,
1609            }
1610            M_i {
1611                dydt,
1612                0,
1613            }
1614            F_i {
1615                (r * y) * (1 - (y / k)),
1616                (2 * y) - z,
1617            }
1618            out_i {
1619                3 * y,
1620                4 * z,
1621            }
1622        ";
1623
1624        let k = M::T::one();
1625        let r = M::T::one();
1626        let ctx = M::C::default();
1627        let context = DiffSlContext::<M, CG>::new(text, 1, ctx.clone()).unwrap();
1628        let p = ctx.vector_from_vec(vec![r, k]);
1629        let mut eqn = DiffSl::from_context(context, false);
1630        eqn.set_params(&p);
1631
1632        // test that the initial values look ok
1633        let y0 = M::T::from_f64(0.1).unwrap();
1634        let init = eqn.init().call(M::T::zero());
1635        let init_expect = ctx.vector_from_vec(vec![y0, M::T::zero()]);
1636        init.assert_eq_st(&init_expect, M::T::from_f64(1e-10).unwrap());
1637        let rhs = eqn.rhs().call(&init, M::T::zero());
1638        let rhs_expect = ctx.vector_from_vec(vec![
1639            r * y0 * (M::T::one() - y0 / k),
1640            M::T::from_f64(2.0).unwrap() * y0,
1641        ]);
1642        rhs.assert_eq_st(&rhs_expect, M::T::from_f64(1e-10).unwrap());
1643        let v = ctx.vector_from_vec(vec![M::T::one(), M::T::one()]);
1644        let rhs_jac = eqn.rhs().jac_mul(&init, M::T::zero(), &v);
1645        let rhs_jac_expect =
1646            ctx.vector_from_vec(vec![r * (M::T::one() - y0 / k) - r * y0 / k, M::T::one()]);
1647        rhs_jac.assert_eq_st(&rhs_jac_expect, M::T::from_f64(1e-10).unwrap());
1648        let mut mass_y = ctx.vector_from_vec(vec![M::T::zero(), M::T::zero()]);
1649        let v = ctx.vector_from_vec(vec![M::T::one(), M::T::one()]);
1650        eqn.mass()
1651            .unwrap()
1652            .call_inplace(&v, M::T::zero(), &mut mass_y);
1653        let mass_y_expect = ctx.vector_from_vec(vec![M::T::one(), M::T::zero()]);
1654        mass_y.assert_eq_st(&mass_y_expect, M::T::from_f64(1e-10).unwrap());
1655
1656        // solver a bit and check the state and output
1657        let atol = 1e-4;
1658        let rtol = 1e-4;
1659        let problem = OdeBuilder::<M>::new()
1660            .p([r.to_f64().unwrap(), k.to_f64().unwrap()])
1661            .atol([atol])
1662            .rtol(rtol)
1663            .build_from_eqn(eqn)
1664            .unwrap();
1665        let mut solver = problem.bdf::<<M as DefaultSolver>::LS>().unwrap();
1666        let t = M::T::one();
1667        let (ys, ts, _stop_reason) = solver.solve(t).unwrap();
1668        for (i, t) in ts.iter().enumerate() {
1669            let y_expect = k / (M::T::one() + (k - y0) * (-r * *t).exp() / y0);
1670            let z_expect = M::T::from_f64(2.0).unwrap() * y_expect;
1671            let expected_out = ctx.vector_from_vec(vec![
1672                M::T::from_f64(3.0).unwrap() * y_expect,
1673                M::T::from_f64(4.0).unwrap() * z_expect,
1674            ]);
1675            ys.column(i).into_owned().assert_eq_norm(
1676                &expected_out,
1677                &problem.atol,
1678                problem.rtol,
1679                M::T::from_f64(10.0).unwrap(),
1680            );
1681        }
1682
1683        // do it again with some explicit t_evals
1684        let t_evals = vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0];
1685        let t_evals = t_evals
1686            .into_iter()
1687            .map(|t| M::T::from_f64(t).unwrap())
1688            .collect::<Vec<_>>();
1689        let mut solver = problem.bdf::<<M as DefaultSolver>::LS>().unwrap();
1690        let (ys, _stop_reason) = solver.solve_dense(&t_evals).unwrap();
1691        for (i, t) in t_evals.iter().enumerate() {
1692            let y_expect = k / (M::T::one() + (k - y0) * (-r * *t).exp() / y0);
1693            let z_expect = M::T::from_f64(2.0).unwrap() * y_expect;
1694            let expected_out = ctx.vector_from_vec(vec![
1695                M::T::from_f64(3.0).unwrap() * y_expect,
1696                M::T::from_f64(4.0).unwrap() * z_expect,
1697            ]);
1698            ys.column(i).into_owned().assert_eq_norm(
1699                &expected_out,
1700                &problem.atol,
1701                problem.rtol,
1702                M::T::from_f64(10.0).unwrap(),
1703            );
1704        }
1705    }
1706
1707    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1708    fn diffsl_context_handles_thread_modes<
1709        CG: CodegenModuleJit + CodegenModuleCompile,
1710        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1711    >() {
1712        let text = "
1713            in_i { r = 1 }
1714            u_i { y = 0.1 }
1715            F_i { r * y }
1716            out_i { y }
1717        ";
1718
1719        for nthreads in [0, 1, 4] {
1720            let context = DiffSlContext::<M, CG>::new(text, nthreads, M::C::default()).unwrap();
1721            assert_eq!(context.nstates, 1);
1722            assert_eq!(context.nparams, 1);
1723            assert_eq!(context.nout, 1);
1724            assert!(!context.has_mass);
1725            assert!(!context.has_root);
1726            assert!(!context.has_reset);
1727            assert!(context.has_out);
1728        }
1729    }
1730
1731    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1732    fn diffsl_context_reports_parser_and_compiler_errors<
1733        CG: CodegenModuleJit + CodegenModuleCompile,
1734        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1735    >() {
1736        let parser_err = match DiffSlContext::<M, CG>::new("this is not diffsl", 1, M::C::default())
1737        {
1738            Ok(_) => panic!("expected parser error"),
1739            Err(err) => err,
1740        };
1741        assert!(matches!(parser_err, DiffsolError::DiffslParserError(_)));
1742
1743        let compiler_err = match DiffSlContext::<M, CG>::new(
1744            "
1745                u_i { y = 1 }
1746                F_i { missing_symbol }
1747            ",
1748            1,
1749            M::C::default(),
1750        ) {
1751            Ok(_) => panic!("expected compiler error"),
1752            Err(err) => err,
1753        };
1754        assert!(matches!(compiler_err, DiffsolError::DiffslCompilerError(_)));
1755    }
1756
1757    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1758    fn diffsl_root_and_output_operators_work<
1759        CG: CodegenModuleJit + CodegenModuleCompile,
1760        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1761    >()
1762    where
1763        for<'b> &'b M::V: VectorRef<M::V>,
1764        for<'b> &'b M: MatrixRef<M>,
1765    {
1766        let text = "
1767            in_i { a = 1 }
1768            u_i {
1769                y = a,
1770                z = 2,
1771            }
1772            F_i {
1773                y,
1774                z,
1775            }
1776            stop_i {
1777                y - 0.5,
1778            }
1779            reset_i {
1780                2 * y + a,
1781                z + a,
1782            }
1783            out_i {
1784                3 * y,
1785                4 * z,
1786            }
1787        ";
1788
1789        let ctx = M::C::default();
1790        let mut eqn = DiffSl::<M, CG>::compile(text, ctx.clone(), false).unwrap();
1791        let p = ctx.vector_from_vec(vec![M::T::from_f64(3.0).unwrap()]);
1792        eqn.set_params(&p);
1793
1794        let x = eqn.init().call(M::T::zero());
1795        let root = eqn.root().unwrap().call(&x, M::T::zero());
1796        root.assert_eq_st(
1797            &ctx.vector_from_vec(vec![M::T::from_f64(2.5).unwrap()]),
1798            M::T::from_f64(1e-10).unwrap(),
1799        );
1800
1801        let root_op = eqn.root().unwrap();
1802        let v = ctx.vector_from_vec(vec![M::T::from_f64(2.0).unwrap(), -M::T::one()]);
1803        let mut root_jvp = ctx.vector_from_vec(vec![M::T::zero()]);
1804        root_op.jac_mul_inplace(&x, M::T::zero(), &v, &mut root_jvp);
1805        root_jvp.assert_eq_st(
1806            &ctx.vector_from_vec(vec![M::T::from_f64(2.0).unwrap()]),
1807            M::T::from_f64(1e-10).unwrap(),
1808        );
1809
1810        let out = eqn.out().unwrap().call(&x, M::T::zero());
1811        out.assert_eq_st(
1812            &ctx.vector_from_vec(vec![
1813                M::T::from_f64(9.0).unwrap(),
1814                M::T::from_f64(8.0).unwrap(),
1815            ]),
1816            M::T::from_f64(1e-10).unwrap(),
1817        );
1818    }
1819
1820    #[allow(dead_code)]
1821    fn diffsl_root_sens_gradients<
1822        CG: CodegenModuleJit + CodegenModuleCompile,
1823        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1824    >()
1825    where
1826        for<'b> &'b M::V: VectorRef<M::V>,
1827        for<'b> &'b M: MatrixRef<M>,
1828    {
1829        let text = "
1830            in_i { a = 1 }
1831            u_i {
1832                y = a,
1833                z = 2,
1834            }
1835            F_i {
1836                y,
1837                z,
1838            }
1839            stop_i {
1840                y + a - 0.5,
1841            }
1842        ";
1843
1844        let ctx = M::C::default();
1845        let mut eqn = DiffSl::<M, CG>::compile(text, ctx.clone(), false).unwrap();
1846        let p = ctx.vector_from_vec(vec![M::T::from_f64(3.0).unwrap()]);
1847        eqn.set_params(&p);
1848
1849        let x = eqn.init().call(M::T::zero());
1850        let root_op = eqn.root().expect("model must have a root operator");
1851        let vp = ctx.vector_from_vec(vec![M::T::from_f64(2.0).unwrap()]);
1852        let mut y = ctx.vector_from_vec(vec![M::T::zero()]);
1853
1854        root_op.sens_mul_inplace(&x, M::T::zero(), &vp, &mut y);
1855        y.assert_eq_st(
1856            &ctx.vector_from_vec(vec![M::T::from_f64(2.0).unwrap()]),
1857            M::T::from_f64(1e-10).unwrap(),
1858        );
1859
1860        let v = ctx.vector_from_vec(vec![M::T::from_f64(3.0).unwrap()]);
1861        let mut y_x = ctx.vector_from_vec(vec![M::T::zero(), M::T::zero()]);
1862        root_op.jac_transpose_mul_inplace(&x, M::T::zero(), &v, &mut y_x);
1863        y_x.assert_eq_st(
1864            &ctx.vector_from_vec(vec![-M::T::from_f64(3.0).unwrap(), M::T::zero()]),
1865            M::T::from_f64(1e-10).unwrap(),
1866        );
1867
1868        let mut y_p = ctx.vector_from_vec(vec![M::T::zero()]);
1869        root_op.sens_transpose_mul_inplace(&x, M::T::zero(), &v, &mut y_p);
1870        y_p.assert_eq_st(
1871            &ctx.vector_from_vec(vec![-M::T::from_f64(3.0).unwrap()]),
1872            M::T::from_f64(1e-10).unwrap(),
1873        );
1874    }
1875
1876    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
1877    fn diffsl_logistic_growth_with_model_index<
1878        CG: CodegenModuleJit + CodegenModuleCompile,
1879        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1880    >()
1881    where
1882        for<'b> &'b M::V: VectorRef<M::V>,
1883        for<'b> &'b M: MatrixRef<M>,
1884    {
1885        let text = "
1886            r_i {
1887                1,
1888                2,
1889                4,
1890            }
1891            u_i {
1892                y = 0.1,
1893            }
1894            reset_i {
1895                y,
1896            }
1897            stop_i {
1898                y - 0.5,
1899            }
1900            F_i {
1901                r_i[N] * y,
1902            }
1903        ";
1904
1905        let ctx = M::C::default();
1906        let mut eqn = DiffSl::<M, CG>::compile(text, ctx.clone(), false).unwrap();
1907        let t = M::T::zero();
1908        let y = eqn.init().call(t);
1909        let tol = M::T::from_f64(1e-10).unwrap();
1910        let one_tenth = M::T::from_f64(0.1).unwrap();
1911        let p = ctx.vector_from_vec(Vec::<M::T>::new());
1912
1913        let rhs_model_0 = eqn.rhs().call(&y, t);
1914        let rhs_model_0_expected =
1915            ctx.vector_from_vec(vec![M::T::from_f64(1.0).unwrap() * one_tenth]);
1916        rhs_model_0.assert_eq_st(&rhs_model_0_expected, tol);
1917
1918        eqn.set_model_index(1);
1919        let rhs_model_1 = eqn.rhs().call(&y, t);
1920        let rhs_model_1_expected =
1921            ctx.vector_from_vec(vec![M::T::from_f64(2.0).unwrap() * one_tenth]);
1922        rhs_model_1.assert_eq_st(&rhs_model_1_expected, tol);
1923
1924        eqn.set_model_index(2);
1925        let rhs_model_2 = eqn.rhs().call(&y, t);
1926        let rhs_model_2_expected =
1927            ctx.vector_from_vec(vec![M::T::from_f64(4.0).unwrap() * one_tenth]);
1928        rhs_model_2.assert_eq_st(&rhs_model_2_expected, tol);
1929
1930        // set_params preserves the current model index.
1931        eqn.set_params(&p);
1932        let rhs_after_set_params = eqn.rhs().call(&y, t);
1933        rhs_after_set_params.assert_eq_st(&rhs_model_2_expected, tol);
1934    }
1935
1936    #[cfg(feature = "diffsl-llvm")]
1937    fn serialization_test_model() -> &'static str {
1938        "
1939            in_i { a = 1, b = 2 }
1940            u_i {
1941                y = a,
1942                z = 2,
1943            }
1944            dudt_i {
1945                dydt = 0,
1946                dzdt = 0,
1947            }
1948            M_i {
1949                dydt,
1950                0,
1951            }
1952            F_i {
1953                a * y + b,
1954                z + a,
1955            }
1956            stop_i {
1957                y + a - 0.5,
1958            }
1959            reset_i {
1960                2 * y + a,
1961                z + a,
1962            }
1963            out_i {
1964                3 * y,
1965                4 * z,
1966            }
1967        "
1968    }
1969
1970    #[cfg(feature = "diffsl-llvm")]
1971    fn assert_object_roundtrip<M>(include_sensitivities: bool)
1972    where
1973        M: Matrix<V: VectorHost + DefaultDenseMatrix, T: DiffSlScalar> + DefaultSolver,
1974        for<'b> &'b M::V: VectorRef<M::V>,
1975        for<'b> &'b M: MatrixRef<M>,
1976    {
1977        let ctx = M::C::default();
1978        let p = ctx.vector_from_vec(vec![
1979            M::T::from_f64(3.0).unwrap(),
1980            M::T::from_f64(5.0).unwrap(),
1981        ]);
1982        let mut compiled = DiffSl::<M, crate::LlvmModule>::compile(
1983            serialization_test_model(),
1984            ctx.clone(),
1985            include_sensitivities,
1986        )
1987        .unwrap();
1988        compiled.set_params(&p);
1989        let rhs_state_deps = compiled.context.rhs_state_deps.clone();
1990        let rhs_input_deps = compiled.context.rhs_input_deps.clone();
1991        let mass_state_deps = compiled.context.mass_state_deps.clone();
1992
1993        let t = M::T::zero();
1994        let x_compiled = compiled.init().call(t);
1995        let rhs_compiled = compiled.rhs().call(&x_compiled, t);
1996        let v = ctx.vector_from_vec(vec![M::T::one(), M::T::one()]);
1997        let mut mass_compiled = ctx.vector_from_vec(vec![M::T::zero(), M::T::zero()]);
1998        compiled
1999            .mass()
2000            .unwrap()
2001            .call_inplace(&v, t, &mut mass_compiled);
2002        let root_compiled = compiled.root().unwrap().call(&x_compiled, t);
2003        let out_compiled = compiled.out().unwrap().call(&x_compiled, t);
2004        let reset_compiled = compiled.reset().unwrap().call(&x_compiled, t);
2005        let external_object = compiled.to_external_object().unwrap();
2006        let mut imported =
2007            DiffSl::<M, ObjectModule>::from_external_object(external_object, ctx.clone()).unwrap();
2008        imported.set_params(&p);
2009
2010        let x_imported = imported.init().call(t);
2011        x_imported.assert_eq_st(&x_compiled, M::T::from_f64(1e-10).unwrap());
2012        let rhs_imported = imported.rhs().call(&x_imported, t);
2013        rhs_imported.assert_eq_st(&rhs_compiled, M::T::from_f64(1e-10).unwrap());
2014        let mut mass_imported = ctx.vector_from_vec(vec![M::T::zero(), M::T::zero()]);
2015        imported
2016            .mass()
2017            .unwrap()
2018            .call_inplace(&v, t, &mut mass_imported);
2019        mass_imported.assert_eq_st(&mass_compiled, M::T::from_f64(1e-10).unwrap());
2020        let root_imported = imported.root().unwrap().call(&x_imported, t);
2021        root_imported.assert_eq_st(&root_compiled, M::T::from_f64(1e-10).unwrap());
2022        let out_imported = imported.out().unwrap().call(&x_imported, t);
2023        out_imported.assert_eq_st(&out_compiled, M::T::from_f64(1e-10).unwrap());
2024        let reset_imported = imported.reset().unwrap().call(&x_imported, t);
2025        reset_imported.assert_eq_st(&reset_compiled, M::T::from_f64(1e-10).unwrap());
2026
2027        assert_eq!(imported.context.rhs_state_deps, rhs_state_deps);
2028        assert_eq!(imported.context.rhs_input_deps, rhs_input_deps);
2029        assert_eq!(imported.context.mass_state_deps, mass_state_deps);
2030        assert_eq!(imported.include_sensitivities, include_sensitivities);
2031    }
2032
2033    #[cfg(feature = "diffsl-llvm")]
2034    #[cfg_attr(
2035        all(target_os = "macos", target_arch = "x86_64"),
2036        ignore = "from_external_object is unsupported on Intel macOS due to unsupported relocations"
2037    )]
2038    #[test]
2039    fn diffsl_external_object_roundtrip_sparse_f64() {
2040        type M = crate::FaerSparseMat<f64>;
2041
2042        let ctx = <M as crate::matrix::MatrixCommon>::C::default();
2043        let compiled =
2044            DiffSl::<M, crate::LlvmModule>::compile(serialization_test_model(), ctx, true).unwrap();
2045        let external_object = compiled.to_external_object().unwrap();
2046        let rhs_state_deps = external_object.rhs_state_deps.clone();
2047        let mass_state_deps = external_object.mass_state_deps.clone();
2048        let include_sensitivities = external_object.include_sensitivities;
2049
2050        assert!(!rhs_state_deps.is_empty());
2051        assert!(!mass_state_deps.is_empty());
2052        assert!(include_sensitivities);
2053
2054        assert_object_roundtrip::<M>(include_sensitivities);
2055
2056        let mut imported = DiffSl::<M, ObjectModule>::from_external_object(
2057            external_object,
2058            <M as crate::matrix::MatrixCommon>::C::default(),
2059        )
2060        .unwrap();
2061        let p = <M as crate::matrix::MatrixCommon>::C::default().vector_from_vec(vec![3.0, 5.0]);
2062        imported.set_params(&p);
2063        assert!(imported.rhs().jacobian_sparsity().is_some());
2064        assert!(imported.mass().unwrap().sparsity().is_some());
2065        assert!(imported.rhs_sens_sparsity.is_some());
2066    }
2067
2068    #[cfg(feature = "diffsl-llvm")]
2069    #[cfg_attr(
2070        all(target_os = "macos", target_arch = "x86_64"),
2071        ignore = "from_external_object is unsupported on Intel macOS due to unsupported relocations"
2072    )]
2073    #[test]
2074    fn diffsl_external_object_roundtrip_dense_f64() {
2075        type M = crate::NalgebraMat<f64>;
2076
2077        assert_object_roundtrip::<M>(false);
2078    }
2079
2080    #[cfg(feature = "diffsl-llvm")]
2081    #[cfg_attr(
2082        all(target_os = "macos", target_arch = "x86_64"),
2083        ignore = "from_external_object is unsupported on Intel macOS due to unsupported relocations"
2084    )]
2085    #[test]
2086    fn diffsl_serde_roundtrip_object_module_f64() {
2087        type M = crate::FaerSparseMat<f64>;
2088
2089        let ctx = <M as crate::matrix::MatrixCommon>::C::default();
2090        let p = ctx.vector_from_vec(vec![3.0, 5.0]);
2091        let compiled =
2092            DiffSl::<M, crate::LlvmModule>::compile(serialization_test_model(), ctx, true).unwrap();
2093        let external_object = compiled.to_external_object().unwrap();
2094        let rhs_state_deps = external_object.rhs_state_deps.clone();
2095        let rhs_input_deps = external_object.rhs_input_deps.clone();
2096        let mass_state_deps = external_object.mass_state_deps.clone();
2097
2098        let mut imported =
2099            DiffSl::<M, ObjectModule>::from_external_object(external_object, ctx).unwrap();
2100        imported.set_params(&p);
2101
2102        let encoded = serde_json::to_string(&imported).unwrap();
2103        let mut decoded: DiffSl<M, ObjectModule> = serde_json::from_str(&encoded).unwrap();
2104        decoded.set_params(&p);
2105
2106        let t = 0.0;
2107        let x_imported = imported.init().call(t);
2108        let x_decoded = decoded.init().call(t);
2109        x_decoded.assert_eq_st(&x_imported, 1e-10);
2110
2111        let rhs_imported = imported.rhs().call(&x_imported, t);
2112        let rhs_decoded = decoded.rhs().call(&x_decoded, t);
2113        rhs_decoded.assert_eq_st(&rhs_imported, 1e-10);
2114
2115        assert_eq!(decoded.context.rhs_state_deps, rhs_state_deps);
2116        assert_eq!(decoded.context.rhs_input_deps, rhs_input_deps);
2117        assert_eq!(decoded.context.mass_state_deps, mass_state_deps);
2118        assert!(decoded.rhs().jacobian_sparsity().is_some());
2119    }
2120
2121    #[cfg(feature = "diffsl-llvm")]
2122    #[test]
2123    fn diffsl_to_external_object_preserves_deps_after_from_context_f64() {
2124        type M = crate::FaerSparseMat<f64>;
2125
2126        let context = DiffSlContext::<M, crate::LlvmModule>::new(
2127            serialization_test_model(),
2128            1,
2129            <M as crate::matrix::MatrixCommon>::C::default(),
2130        )
2131        .unwrap();
2132        let expected_rhs_state_deps = context.rhs_state_deps.clone();
2133        let expected_rhs_input_deps = context.rhs_input_deps.clone();
2134        let expected_mass_state_deps = context.mass_state_deps.clone();
2135        let eqn = DiffSl::from_context(context, true);
2136
2137        let external_object = eqn.to_external_object().unwrap();
2138        let rhs_state_deps = external_object.rhs_state_deps;
2139        let rhs_input_deps = external_object.rhs_input_deps;
2140        let mass_state_deps = external_object.mass_state_deps;
2141        let include_sensitivities = external_object.include_sensitivities;
2142
2143        assert_eq!(rhs_state_deps, expected_rhs_state_deps);
2144        assert_eq!(rhs_input_deps, expected_rhs_input_deps);
2145        assert_eq!(mass_state_deps, expected_mass_state_deps);
2146        assert!(include_sensitivities);
2147    }
2148
2149    #[cfg(feature = "diffsl-llvm")]
2150    #[test]
2151    fn diffsl_from_external_object_rejects_scalar_type_mismatch() {
2152        type Mf64 = crate::FaerSparseMat<f64>;
2153        type Mf32 = crate::FaerSparseMat<f32>;
2154
2155        let external_object = DiffSl::<Mf64, crate::LlvmModule>::compile(
2156            serialization_test_model(),
2157            <Mf64 as crate::matrix::MatrixCommon>::C::default(),
2158            true,
2159        )
2160        .unwrap()
2161        .to_external_object()
2162        .unwrap();
2163
2164        let err = match DiffSl::<Mf32, ObjectModule>::from_external_object(
2165            external_object,
2166            <Mf32 as crate::matrix::MatrixCommon>::C::default(),
2167        ) {
2168            Ok(_) => panic!("expected scalar type mismatch"),
2169            Err(err) => err,
2170        };
2171
2172        assert!(matches!(err, DiffsolError::Other(_)));
2173        assert!(err.to_string().contains("scalar type mismatch"));
2174    }
2175}