Skip to main content

diffsol_c/
ode.rs

1use std::sync::{Arc, Mutex};
2
3use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer};
4
5use crate::jit::JitBackendType;
6use crate::{
7    adjoint_checkpoint::AdjointCheckpointWrapper,
8    error::DiffsolRtError,
9    host_array::{FromHostArray, HostArray},
10    initial_condition_options::{
11        InitialConditionSolverOptions, InitialConditionSolverOptionsSnapshot,
12    },
13    linear_solver_type::LinearSolverType,
14    matrix_type::MatrixType,
15    ode_options::{OdeSolverOptions, OdeSolverOptionsSnapshot},
16    ode_solver_type::OdeSolverType,
17    scalar_type::ScalarType,
18    solution_wrapper::SolutionWrapper,
19    solve::Solve,
20};
21
/// Shared solver state guarded by the mutex inside [`OdeWrapper`].
///
/// Holds the type-erased solver plus the construction-time settings needed to
/// report configuration back to callers and to build serialization snapshots.
pub struct Ode {
    /// Type-erased solver implementation that performs the actual work.
    pub(crate) solve: Box<dyn Solve>,
    /// DiffSL source the solver was built from; empty for the external
    /// constructors, which pass `String::new()`.
    code: String,
    /// Scalar type the solver was instantiated with.
    scalar_type: ScalarType,
    /// JIT backend used to compile `code`; `None` for externally-provided symbols.
    jit_backend: Option<JitBackendType>,
    /// Linear solver used during solves; validated with `Solve::check` before it is stored.
    linear_solver: LinearSolverType,
    /// ODE integration method used during solves.
    ode_solver: OdeSolverType,
}
30
// SAFETY: `Ode` holds a `Box<dyn Solve>`, and `Solve` is not declared `Send`/`Sync`
// here, so these markers are asserted manually. Every access visible in this
// file goes through the `Arc<Mutex<Ode>>` inside `OdeWrapper`, which serializes use
// of the solver.
// NOTE(review): this soundness argument assumes every concrete `Solve`
// implementation (JIT-compiled code, dynamically loaded symbols) may be moved to
// and used from another thread — confirm against the `solve` module.
unsafe impl Send for Ode {}
// NOTE(review): `Mutex<Ode>: Sync` only requires `Ode: Send`, so this impl looks
// unnecessary if `Ode` is never shared by reference outside the mutex — confirm
// no caller relies on `&Ode` crossing threads before removing it.
unsafe impl Sync for Ode {}
33
/// Cheaply clonable handle to a shared, mutex-protected [`Ode`] solver.
///
/// Cloning copies the `Arc`, so all clones observe and mutate the same solver state.
#[derive(Clone)]
pub struct OdeWrapper(Arc<Mutex<Ode>>);
36
/// Serializable description of an [`OdeWrapper`], sufficient to rebuild it.
///
/// Fields are emitted in declaration order by the serde derive, so reordering
/// them may break previously written snapshots in positional formats.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct OdeWrapperSnapshot {
    /// Original DiffSL source text.
    code: String,
    /// Serialized DiffSL equation produced by `Solve::serialized_diffsl`.
    equation: Vec<u8>,
    /// JIT backend; only JIT-backed solvers can be snapshotted.
    jit_backend: JitBackendType,
    scalar_type: ScalarType,
    matrix_type: MatrixType,
    linear_solver: LinearSolverType,
    ode_solver: OdeSolverType,
    // Main solver tolerances and integration settings.
    rtol: f64,
    atol: f64,
    t0: f64,
    h0: f64,
    integrate_out: bool,
    // Optional tolerances for sensitivity, output and parameter equations.
    sens_rtol: Option<f64>,
    sens_atol: Option<f64>,
    out_rtol: Option<f64>,
    out_atol: Option<f64>,
    param_rtol: Option<f64>,
    param_atol: Option<f64>,
    /// Initial-condition solver options captured from the solver.
    ic_options: InitialConditionSolverOptionsSnapshot,
    /// ODE solver options captured from the solver.
    ode_options: OdeSolverOptionsSnapshot,
}
60
61impl OdeWrapper {
62    fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
63        self.0.lock().map_err(|_| {
64            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
65                "Failed to acquire lock on ODE solver".to_string(),
66            ))
67        })
68    }
69}
70
71impl OdeWrapper {
72    fn snapshot(&self) -> Result<OdeWrapperSnapshot, DiffsolRtError> {
73        let ode = self.guard()?;
74        let jit_backend = ode.jit_backend.ok_or_else(|| {
75            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
76                "OdeWrapper serialization is only supported for JIT-backed solvers".to_string(),
77            ))
78        })?;
79        Ok(OdeWrapperSnapshot {
80            code: ode.code.clone(),
81            equation: ode.solve.serialized_diffsl()?,
82            jit_backend,
83            scalar_type: ode.scalar_type,
84            matrix_type: ode.solve.matrix_type(),
85            linear_solver: ode.linear_solver,
86            ode_solver: ode.ode_solver,
87            rtol: ode.solve.rtol(),
88            atol: ode.solve.atol(),
89            t0: ode.solve.t0(),
90            h0: ode.solve.h0(),
91            integrate_out: ode.solve.integrate_out(),
92            sens_rtol: ode.solve.sens_rtol(),
93            sens_atol: ode.solve.sens_atol(),
94            out_rtol: ode.solve.out_rtol(),
95            out_atol: ode.solve.out_atol(),
96            param_rtol: ode.solve.param_rtol(),
97            param_atol: ode.solve.param_atol(),
98            ic_options: InitialConditionSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
99            ode_options: OdeSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
100        })
101    }
102
103    fn build(
104        code: String,
105        scalar_type: ScalarType,
106        solve: Box<dyn Solve>,
107        jit_backend: Option<JitBackendType>,
108        linear_solver: LinearSolverType,
109        ode_solver: OdeSolverType,
110    ) -> Result<Self, DiffsolRtError> {
111        solve.check(linear_solver)?;
112        Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
113            code,
114            scalar_type,
115            solve,
116            jit_backend,
117            linear_solver,
118            ode_solver,
119        }))))
120    }
121
122    fn from_snapshot(snapshot: OdeWrapperSnapshot) -> Result<Self, DiffsolRtError> {
123        let solve = crate::solve::solve_factory_from_serialized_diffsl(
124            snapshot.equation.as_slice(),
125            snapshot.matrix_type,
126            snapshot.scalar_type,
127        )?;
128        let wrapper = Self::build(
129            snapshot.code,
130            snapshot.scalar_type,
131            solve,
132            Some(snapshot.jit_backend),
133            snapshot.linear_solver,
134            snapshot.ode_solver,
135        )?;
136        {
137            let mut ode = wrapper.guard()?;
138            ode.solve.set_rtol(snapshot.rtol);
139            ode.solve.set_atol(snapshot.atol);
140            ode.solve.set_t0(snapshot.t0);
141            ode.solve.set_h0(snapshot.h0);
142            ode.solve.set_integrate_out(snapshot.integrate_out);
143            ode.solve.set_sens_rtol(snapshot.sens_rtol);
144            ode.solve.set_sens_atol(snapshot.sens_atol);
145            ode.solve.set_out_rtol(snapshot.out_rtol);
146            ode.solve.set_out_atol(snapshot.out_atol);
147            ode.solve.set_param_rtol(snapshot.param_rtol);
148            ode.solve.set_param_atol(snapshot.param_atol);
149            snapshot.ic_options.apply_to_solve(ode.solve.as_mut());
150            snapshot.ode_options.apply_to_solve(ode.solve.as_mut());
151        }
152        Ok(wrapper)
153    }
154
155    /// Construct an ODE solver backed by externally-provided DiffSL symbols.
156    #[cfg(feature = "external")]
157    pub fn new_external(
158        rhs_state_deps: Vec<(usize, usize)>,
159        rhs_input_deps: Vec<(usize, usize)>,
160        mass_state_deps: Vec<(usize, usize)>,
161        scalar_type: ScalarType,
162        matrix_type: MatrixType,
163        linear_solver: LinearSolverType,
164        ode_solver: OdeSolverType,
165    ) -> Result<Self, DiffsolRtError> {
166        let solve = crate::solve::solve_factory_external(
167            rhs_state_deps,
168            rhs_input_deps,
169            mass_state_deps,
170            matrix_type,
171            scalar_type,
172        )?;
173        Self::build(
174            String::new(),
175            scalar_type,
176            solve,
177            None,
178            linear_solver,
179            ode_solver,
180        )
181    }
182
183    /// Construct an ODE solver backed by DiffSL symbols loaded from a dynamic library.
184    #[cfg(feature = "diffsl-external-dynamic")]
185    #[allow(clippy::too_many_arguments)]
186    pub fn new_external_dynamic(
187        path: impl Into<std::path::PathBuf>,
188        rhs_state_deps: Vec<(usize, usize)>,
189        rhs_input_deps: Vec<(usize, usize)>,
190        mass_state_deps: Vec<(usize, usize)>,
191        scalar_type: ScalarType,
192        matrix_type: MatrixType,
193        linear_solver: LinearSolverType,
194        ode_solver: OdeSolverType,
195    ) -> Result<Self, DiffsolRtError> {
196        let solve = crate::solve::solve_factory_external_dynamic(
197            path.into(),
198            rhs_state_deps,
199            rhs_input_deps,
200            mass_state_deps,
201            matrix_type,
202            scalar_type,
203        )?;
204        Self::build(
205            String::new(),
206            scalar_type,
207            solve,
208            None,
209            linear_solver,
210            ode_solver,
211        )
212    }
213
214    /// Construct an ODE solver by JIT-compiling DiffSL code immediately.
215    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
216    pub fn new_jit(
217        code: &str,
218        jit_backend: JitBackendType,
219        scalar_type: ScalarType,
220        matrix_type: MatrixType,
221        linear_solver: LinearSolverType,
222        ode_solver: OdeSolverType,
223    ) -> Result<Self, DiffsolRtError> {
224        let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
225        Self::build(
226            code.to_owned(),
227            scalar_type,
228            solve,
229            Some(jit_backend),
230            linear_solver,
231            ode_solver,
232        )
233    }
234
235    /// Matrix type used in the ODE solver. This is fixed after construction.
236    pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolRtError> {
237        Ok(self.guard()?.solve.matrix_type())
238    }
239
240    pub fn get_nstates(&self) -> Result<usize, DiffsolRtError> {
241        Ok(self.guard()?.solve.nstates())
242    }
243
244    pub fn get_nparams(&self) -> Result<usize, DiffsolRtError> {
245        Ok(self.guard()?.solve.nparams())
246    }
247
248    pub fn get_nout(&self) -> Result<usize, DiffsolRtError> {
249        Ok(self.guard()?.solve.nout())
250    }
251
252    pub fn has_stop(&self) -> Result<bool, DiffsolRtError> {
253        Ok(self.guard()?.solve.has_stop())
254    }
255
256    /// Ode solver method, default Bdf (backward differentiation formula).
257    pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolRtError> {
258        Ok(self.guard()?.ode_solver)
259    }
260
261    pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolRtError> {
262        self.guard()?.ode_solver = value;
263        Ok(())
264    }
265
266    /// Linear solver type used in the ODE solver. Set to default to use the
267    /// solver's default choice, which is typically an LU solver.
268    pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolRtError> {
269        Ok(self.guard()?.linear_solver)
270    }
271
272    pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolRtError> {
273        self.guard()?.solve.check(value)?;
274        self.guard()?.linear_solver = value;
275        Ok(())
276    }
277
278    /// Relative tolerance for the solver, default 1e-6. Governs the error relative to the solution size.
279    pub fn get_rtol(&self) -> Result<f64, DiffsolRtError> {
280        Ok(self.guard()?.solve.rtol())
281    }
282
283    pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolRtError> {
284        self.guard()?.solve.set_rtol(value);
285        Ok(())
286    }
287
288    /// Absolute tolerance for the solver, default 1e-6. Governs the error as the solution goes to zero.
289    pub fn get_atol(&self) -> Result<f64, DiffsolRtError> {
290        Ok(self.guard()?.solve.atol())
291    }
292
293    pub fn set_atol(&self, value: f64) -> Result<(), DiffsolRtError> {
294        self.guard()?.solve.set_atol(value);
295        Ok(())
296    }
297
298    /// Initial time for the ODE solve, default 0.0.
299    pub fn get_t0(&self) -> Result<f64, DiffsolRtError> {
300        Ok(self.guard()?.solve.t0())
301    }
302
303    pub fn set_t0(&self, value: f64) -> Result<(), DiffsolRtError> {
304        self.guard()?.solve.set_t0(value);
305        Ok(())
306    }
307
308    /// Initial step size for the ODE solver, default 1.0.
309    pub fn get_h0(&self) -> Result<f64, DiffsolRtError> {
310        Ok(self.guard()?.solve.h0())
311    }
312
313    pub fn set_h0(&self, value: f64) -> Result<(), DiffsolRtError> {
314        self.guard()?.solve.set_h0(value);
315        Ok(())
316    }
317
318    /// Whether to integrate output equations alongside state equations.
319    pub fn get_integrate_out(&self) -> Result<bool, DiffsolRtError> {
320        Ok(self.guard()?.solve.integrate_out())
321    }
322
323    pub fn set_integrate_out(&self, value: bool) -> Result<(), DiffsolRtError> {
324        self.guard()?.solve.set_integrate_out(value);
325        Ok(())
326    }
327
328    /// Relative tolerance for forward sensitivity or adjoint equations.
329    pub fn get_sens_rtol(&self) -> Result<Option<f64>, DiffsolRtError> {
330        Ok(self.guard()?.solve.sens_rtol())
331    }
332
333    pub fn set_sens_rtol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
334        self.guard()?.solve.set_sens_rtol(value);
335        Ok(())
336    }
337
338    /// Absolute tolerance for forward sensitivity or adjoint equations.
339    pub fn get_sens_atol(&self) -> Result<Option<f64>, DiffsolRtError> {
340        Ok(self.guard()?.solve.sens_atol())
341    }
342
343    pub fn set_sens_atol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
344        self.guard()?.solve.set_sens_atol(value);
345        Ok(())
346    }
347
348    /// Relative tolerance for integrated output equations.
349    pub fn get_out_rtol(&self) -> Result<Option<f64>, DiffsolRtError> {
350        Ok(self.guard()?.solve.out_rtol())
351    }
352
353    pub fn set_out_rtol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
354        self.guard()?.solve.set_out_rtol(value);
355        Ok(())
356    }
357
358    /// Absolute tolerance for integrated output equations.
359    pub fn get_out_atol(&self) -> Result<Option<f64>, DiffsolRtError> {
360        Ok(self.guard()?.solve.out_atol())
361    }
362
363    pub fn set_out_atol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
364        self.guard()?.solve.set_out_atol(value);
365        Ok(())
366    }
367
368    /// Relative tolerance for adjoint parameter gradient equations.
369    pub fn get_param_rtol(&self) -> Result<Option<f64>, DiffsolRtError> {
370        Ok(self.guard()?.solve.param_rtol())
371    }
372
373    pub fn set_param_rtol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
374        self.guard()?.solve.set_param_rtol(value);
375        Ok(())
376    }
377
378    /// Absolute tolerance for adjoint parameter gradient equations.
379    pub fn get_param_atol(&self) -> Result<Option<f64>, DiffsolRtError> {
380        Ok(self.guard()?.solve.param_atol())
381    }
382
383    pub fn set_param_atol(&self, value: Option<f64>) -> Result<(), DiffsolRtError> {
384        self.guard()?.solve.set_param_atol(value);
385        Ok(())
386    }
387
388    pub fn get_code(&self) -> Result<String, DiffsolRtError> {
389        Ok(self.guard()?.code.clone())
390    }
391
392    pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolRtError> {
393        Ok(self.guard()?.scalar_type)
394    }
395
396    pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolRtError> {
397        Ok(self.guard()?.jit_backend)
398    }
399
400    pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
401        InitialConditionSolverOptions::new(self.0.clone())
402    }
403
404    pub fn get_options(&self) -> OdeSolverOptions {
405        OdeSolverOptions::new(self.0.clone())
406    }
407
408    /// Get the initial condition vector y0 as a 1D numpy array.
409    pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolRtError> {
410        let mut self_guard = self.guard()?;
411        self_guard.solve.y0(params.as_slice()?)
412    }
413
414    /// evaluate the right-hand side function at time `t` and state `y`.
415    pub fn rhs(
416        &self,
417        params: HostArray,
418        t: f64,
419        y: HostArray,
420    ) -> Result<HostArray, DiffsolRtError> {
421        let mut self_guard = self.guard()?;
422        self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
423    }
424
425    /// evaluate the right-hand side Jacobian-vector product `Jv`` at time `t` and state `y`.
426    pub fn rhs_jac_mul(
427        &self,
428        params: HostArray,
429        t: f64,
430        y: HostArray,
431        v: HostArray,
432    ) -> Result<HostArray, DiffsolRtError> {
433        let mut self_guard = self.guard()?;
434        self_guard
435            .solve
436            .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
437    }
438
439    /// Using the provided state, solve the problem up to time `final_time`.
440    ///
441    /// The number of params must match the expected params in the diffsl code.
442    /// If specified, the config can be used to override the solver method
443    /// (Bdf by default) and SolverType (Lu by default) along with other solver
444    /// params like `rtol`.
445    ///
446    /// :param params: 1D array of solver parameters
447    /// :type params: numpy.ndarray
448    /// :param final_time: end time of solver
449    /// :type final_time: float
450    /// :return: `(ys, ts)` tuple where `ys` is a 2D array of values at times
451    ///     `ts` chosen by the solver
452    /// :rtype: Tuple[numpy.ndarray, numpy.ndarray]
453    ///
454    /// Example:
455    ///     >>> print(ode.solve(np.array([]), 0.5))
456    #[allow(clippy::type_complexity)]
457    pub fn solve(
458        &self,
459        params: HostArray,
460        final_time: f64,
461    ) -> Result<SolutionWrapper, DiffsolRtError> {
462        let mut self_guard = self.guard()?;
463        let params = params.as_slice()?;
464        let linear_solver = self_guard.linear_solver;
465        let method = self_guard.ode_solver;
466        let solution = self_guard
467            .solve
468            .solve(method, linear_solver, params, final_time)?;
469        Ok(SolutionWrapper::new(solution))
470    }
471
472    /// Using the provided state, solve the problem up to time
473    /// `t_eval[t_eval.len()-1]`. Returns 2D array of solution values at
474    /// timepoints given by `t_eval`.
475    ///
476    /// The number of params must match the expected params in the diffsl code.
477    /// The config may be optionally specified to override solver settings.
478    ///
479    /// :param params: 1D array of solver parameters
480    /// :type params: numpy.ndarray
481    /// :param t_eval: 1D array of solver times
482    /// :type params: numpy.ndarray
483    /// :return: 2D array of values at times `t_eval`
484    /// :rtype: numpy.ndarray
485    pub fn solve_dense(
486        &self,
487        params: HostArray,
488        t_eval: HostArray,
489    ) -> Result<SolutionWrapper, DiffsolRtError> {
490        let mut self_guard = self.guard()?;
491        let params = params.as_slice()?;
492        let t_eval = t_eval.as_slice()?;
493        let linear_solver = self_guard.linear_solver;
494        let method = self_guard.ode_solver;
495        let solution = self_guard
496            .solve
497            .solve_dense(method, linear_solver, params, t_eval)?;
498        Ok(SolutionWrapper::new(solution))
499    }
500
501    /// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]`.
502    /// Returns 2D array of solution values at timepoints given by `t_eval`.
503    /// Also returns a list of 2D arrays of sensitivities at the same timepoints
504    /// as the solution.
505    /// The number of params must match the expected params in the diffsl code.
506    /// The config may be optionally specified to override solver settings.
507    /// :param params: 1D array of solver parameters
508    /// :type params: numpy.ndarray
509    /// :param t_eval: 1D array of solver times
510    /// :type params: numpy.ndarray
511    /// :return: 2D array of values at times `t_eval` and a list of 2D arrays of sensitivities at the same timepoints
512    /// :rtype: (numpy.ndarray, List[numpy.ndarray])
513    #[allow(clippy::type_complexity)]
514    pub fn solve_fwd_sens(
515        &self,
516        params: HostArray,
517        t_eval: HostArray,
518    ) -> Result<SolutionWrapper, DiffsolRtError> {
519        let mut self_guard = self.guard()?;
520        let params = params.as_slice()?;
521        let t_eval = t_eval.as_slice()?;
522        let linear_solver = self_guard.linear_solver;
523        let method = self_guard.ode_solver;
524        let solution = self_guard
525            .solve
526            .solve_fwd_sens(method, linear_solver, params, t_eval)?;
527        Ok(SolutionWrapper::new(solution))
528    }
529
530    /// Solve the continuous adjoint problem for the integral of the model output
531    /// from the initial time to `final_time`.
532    ///
533    /// Returns `(integral, gradient)`, where `integral` is a vector of length
534    /// `nout` and `gradient` is an `(nparams, nout)` matrix.
535    pub fn solve_continuous_adjoint(
536        &self,
537        params: HostArray,
538        final_time: f64,
539    ) -> Result<(HostArray, HostArray), DiffsolRtError> {
540        let mut self_guard = self.guard()?;
541        let linear_solver = self_guard.linear_solver;
542        let ode_solver = self_guard.ode_solver;
543        self_guard.solve.solve_continuous_adjoint(
544            ode_solver,
545            linear_solver,
546            params.as_slice()?,
547            final_time,
548        )
549    }
550
551    /// Solve the forward problem at `t_eval` and retain checkpoint data for a
552    /// later discrete adjoint backward pass.
553    pub fn solve_adjoint_fwd(
554        &self,
555        params: HostArray,
556        t_eval: HostArray,
557    ) -> Result<(SolutionWrapper, AdjointCheckpointWrapper), DiffsolRtError> {
558        let mut self_guard = self.guard()?;
559        let params = params.as_slice()?;
560        let t_eval = t_eval.as_slice()?;
561        let linear_solver = self_guard.linear_solver;
562        let method = self_guard.ode_solver;
563        let (solution, checkpoint) =
564            self_guard
565                .solve
566                .solve_adjoint_fwd(method, linear_solver, params, t_eval)?;
567        Ok((SolutionWrapper::new(solution), checkpoint))
568    }
569
570    /// Solve the discrete adjoint backward pass using a prior forward adjoint
571    /// checkpoint and the gradient of a scalar objective with respect to model
572    /// outputs at each saved evaluation time.
573    ///
574    /// Returns an `(nparams, 1)` gradient matrix.
575    pub fn solve_adjoint_bkwd(
576        &self,
577        solution: &SolutionWrapper,
578        checkpoint: &AdjointCheckpointWrapper,
579        dgdu_eval: HostArray,
580    ) -> Result<HostArray, DiffsolRtError> {
581        let t_eval_host = solution.get_ts()?;
582        let t_eval = Vec::<f64>::from_host_array(t_eval_host)?;
583        let mut self_guard = self.guard()?;
584        let linear_solver = self_guard.linear_solver;
585        let method = self_guard.ode_solver;
586        self_guard
587            .solve
588            .solve_adjoint_bkwd(method, linear_solver, checkpoint, &t_eval, dgdu_eval)
589    }
590}
591
592impl Serialize for OdeWrapper {
593    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
594    where
595        S: Serializer,
596    {
597        self.snapshot()
598            .map_err(serde::ser::Error::custom)?
599            .serialize(serializer)
600    }
601}
602
603impl<'de> Deserialize<'de> for OdeWrapper {
604    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
605    where
606        D: Deserializer<'de>,
607    {
608        let snapshot = OdeWrapperSnapshot::deserialize(deserializer)?;
609        Self::from_snapshot(snapshot).map_err(DeError::custom)
610    }
611}
612
#[cfg(all(test, feature = "diffsl-external-f64"))]
mod tests {
    use super::*;
    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, logistic_state, logistic_state_dr, mass_state_deps,
        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
    };

    /// Every ODE solver method exercised by the cross-solver tests.
    fn all_ode_solvers() -> [OdeSolverType; 4] {
        [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ]
    }

    /// Build an externally-backed f64 solver for the shared test fixture.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external(
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Check model metadata and direct `y0`/`rhs`/`rhs_jac_mul` evaluation for a matrix type.
    fn assert_runtime_dispatch(matrix_type: MatrixType) {
        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
        assert_eq!(ode.get_nstates().unwrap(), 1);
        assert_eq!(ode.get_nparams().unwrap(), 1);
        assert_eq!(ode.get_nout().unwrap(), 1);
        assert!(ode.has_stop().unwrap());

        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        let rhs = ode
            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "rhs(0.25)",
        );

        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[2.0]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "rhs_jac_mul(0.25, 3.0)",
        );
    }

    /// Solve densely with tight tolerances and compare against the logistic solution.
    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
        let ode = make_ode(matrix_type, ode_solver);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
    }

    /// Expected time of the fixture's stop/root event (0.5 * ln 9).
    /// NOTE(review): presumably derived from the logistic fixture in
    /// `test_support` — confirm if the fixture changes.
    fn hybrid_root_time() -> f64 {
        0.5 * 9.0_f64.ln()
    }

    /// Adaptive solve across the root: output must span both sides and end at the reset value.
    fn assert_hybrid_solution_applies_reset_after_root(ode_solver: OdeSolverType) {
        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();

        let final_time = 2.0;
        let solution = ode.solve(vector_host(&[2.0]), final_time).unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
        let root_time = hybrid_root_time();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), ts.len());
        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
        assert_close(
            *ts.last().unwrap(),
            final_time,
            ASSERT_TOL,
            "hybrid final time",
        );
        assert_close(ys[(0, ys.ncols() - 1)], 1.0, 5e-4, "hybrid final value");
        assert!(
            ts.iter().any(|&t| t < root_time),
            "expected pre-root samples"
        );
        assert!(
            ts.iter().any(|&t| t > root_time),
            "expected post-root samples after reset"
        );
    }

    /// Dense solve across the root: pre-root values follow the logistic curve, post-root values the reset.
    fn assert_hybrid_dense_solution_continues_after_reset(ode_solver: OdeSolverType) {
        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            5e-4,
            "hybrid dense pre-root value",
        );
        assert_close(
            ys[(0, 1)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[1]),
            5e-4,
            "hybrid dense near-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid dense post-root value");
        }
    }

    /// Forward-sensitivity solve across the root: shapes match and post-root sensitivities are finite.
    fn assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver: OdeSolverType) {
        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            5e-4,
            "hybrid sens pre-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid sens post-root value");
            assert!(
                sens_values[(0, col)].is_finite(),
                "expected finite post-root sensitivity at column {col}"
            );
        }
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            assert_runtime_dispatch(matrix_type);
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Bdf);
    }

    #[test]
    fn esdirk34_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
    }

    #[test]
    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
    }

    #[test]
    fn tsit45_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
    }

    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative() {
        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                ASSERT_TOL,
                &format!("sensitivity[{i}]"),
            );
        }
    }

    #[test]
    fn hybrid_solution_applies_reset_after_root_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_solution_applies_reset_after_root(ode_solver);
        }
    }

    #[test]
    fn hybrid_dense_solution_continues_after_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_dense_solution_continues_after_reset(ode_solver);
        }
    }

    #[test]
    fn hybrid_forward_sensitivities_complete_across_reset_for_all_solvers() {
        for ode_solver in all_ode_solvers() {
            assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver);
        }
    }
}
885
#[cfg(all(test, feature = "diffsl-external-dynamic"))]
mod dynamic_tests {
    //! Tests for ODEs constructed through `OdeWrapper::new_external_dynamic` from the
    //! external dynamic fixture. These ODEs carry no DiffSL source code and no JIT
    //! backend. The module is compiled only with the `diffsl-external-dynamic` feature.

    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        assert_close, assert_solution_tail, external_dynamic_fixture_path, mass_state_deps,
        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
    };

    use super::*;

    /// Builds an `f64` ODE from the external dynamic fixture, using the fixture's
    /// dependency descriptions and the default linear solver.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external_dynamic(
            external_dynamic_fixture_path(),
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// For each supported matrix type, the wrapper must report the requested type.
    /// It must also expose the fixture's metadata and evaluate `rhs` and `rhs_jac_mul`
    /// correctly.
    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            // Every assertion names the backend under test, so a failure inside this
            // loop identifies which dispatch path broke.
            let ctx = format!("matrix_type = {matrix_type:?}");

            let ode = make_ode(matrix_type, OdeSolverType::Bdf);
            assert_eq!(ode.get_matrix_type().unwrap(), matrix_type, "{ctx}");
            // External-dynamic ODEs have no source text and no JIT backend.
            assert_eq!(ode.get_code().unwrap(), "", "{ctx}");
            assert_eq!(ode.get_jit_backend().unwrap(), None, "{ctx}");
            assert_eq!(ode.get_nstates().unwrap(), 1, "{ctx}");
            assert_eq!(ode.get_nparams().unwrap(), 1, "{ctx}");
            assert_eq!(ode.get_nout().unwrap(), 1, "{ctx}");
            assert!(ode.has_stop().unwrap(), "{ctx}");

            let y0 = ode.y0(vector_host(&[2.0])).unwrap();
            assert_eq!(
                Vec::<f64>::from_host_array(y0).unwrap(),
                vec![LOGISTIC_X0],
                "{ctx}"
            );

            // 0.375 is consistent with a logistic rhs r*x*(1 - x) at r = 2, x = 0.25.
            let rhs = ode
                .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
                .unwrap();
            assert_close(
                Vec::<f64>::from_host_array(rhs).unwrap()[0],
                0.375,
                ASSERT_TOL,
                &format!("rhs(0.25) [{ctx}]"),
            );

            // 3.0 is consistent with the logistic Jacobian r*(1 - 2x) = 1 at r = 2,
            // x = 0.25, multiplied by v = 3.
            let rhs_jac_mul = ode
                .rhs_jac_mul(
                    vector_host(&[2.0]),
                    0.0,
                    vector_host(&[0.25]),
                    vector_host(&[3.0]),
                )
                .unwrap();
            assert_close(
                Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
                3.0,
                ASSERT_TOL,
                &format!("rhs_jac_mul(0.25, 3.0) [{ctx}]"),
            );
        }
    }

    /// BDF on an external-dynamic ODE reproduces the analytic logistic solution.
    #[test]
    fn dense_solution_matches_logistic_solution() {
        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
    }

    /// Serializing an ODE that was not built by the JIT must fail, and the error
    /// message must mention "JIT-backed".
    #[test]
    fn non_jit_serialization_is_rejected() {
        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
        let err = serde_json::to_string(&ode).unwrap_err().to_string();
        // Surface the actual message on failure instead of a bare `assert!` panic.
        assert!(
            err.contains("JIT-backed"),
            "unexpected serialization error: {err}"
        );
    }
}