Skip to main content

diffsol_c/
ode.rs

1use std::sync::{Arc, Mutex};
2
3use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer};
4
5use crate::jit::JitBackendType;
6use crate::{
7    error::DiffsolRtError,
8    host_array::HostArray,
9    initial_condition_options::{
10        InitialConditionSolverOptions, InitialConditionSolverOptionsSnapshot,
11    },
12    linear_solver_type::LinearSolverType,
13    matrix_type::MatrixType,
14    ode_options::{OdeSolverOptions, OdeSolverOptionsSnapshot},
15    ode_solver_type::OdeSolverType,
16    scalar_type::ScalarType,
17    solution_wrapper::SolutionWrapper,
18    solve::Solve,
19};
20
/// Solver state held behind an [`OdeWrapper`]: the type-erased solver backend
/// together with the settings it was constructed with.
pub struct Ode {
    /// Type-erased backend that every evaluate/solve call is forwarded to.
    pub(crate) solve: Box<dyn Solve>,
    /// DiffSL source text; the external constructors store an empty string.
    code: String,
    /// Scalar type the backend was built with.
    scalar_type: ScalarType,
    /// JIT backend used to compile `code`; `None` for non-JIT (external) backends.
    jit_backend: Option<JitBackendType>,
    /// Linear solver handed to the backend on each solve call.
    linear_solver: LinearSolverType,
    /// ODE solver method handed to the backend on each solve call.
    ode_solver: OdeSolverType,
}
29
// SAFETY: NOTE(review) — these impls assert that `Box<dyn Solve>` may be moved
// to and referenced from other threads. In this file an `Ode` is only stored
// inside `Arc<Mutex<Ode>>`, so access through `OdeWrapper` is serialized by the
// mutex; confirm that no `Solve` implementation holds thread-affine state.
unsafe impl Send for Ode {}
unsafe impl Sync for Ode {}
32
/// Cloneable handle to an [`Ode`]. Clones share the same underlying solver
/// through `Arc<Mutex<_>>`, so a change made via one handle is seen by all.
#[derive(Clone)]
pub struct OdeWrapper(Arc<Mutex<Ode>>);
35
/// Serializable form of an [`OdeWrapper`]; produced by `OdeWrapper::snapshot`
/// and turned back into a live solver by `OdeWrapper::from_snapshot`.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct OdeWrapperSnapshot {
    /// DiffSL source text the solver was created from.
    code: String,
    /// Bytes returned by the backend's `serialized_diffsl()`.
    equation: Vec<u8>,
    /// JIT backend of the original solver (snapshots require one).
    jit_backend: JitBackendType,
    /// Scalar type of the original solver.
    scalar_type: ScalarType,
    /// Matrix type reported by the backend.
    matrix_type: MatrixType,
    /// Linear solver selected at snapshot time.
    linear_solver: LinearSolverType,
    /// ODE solver method selected at snapshot time.
    ode_solver: OdeSolverType,
    /// Relative tolerance read from the backend.
    rtol: f64,
    /// Absolute tolerance read from the backend.
    atol: f64,
    /// Initial-condition solver options captured from the backend.
    ic_options: InitialConditionSolverOptionsSnapshot,
    /// ODE solver options captured from the backend.
    ode_options: OdeSolverOptionsSnapshot,
}
50
51impl OdeWrapper {
52    fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
53        self.0.lock().map_err(|_| {
54            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
55                "Failed to acquire lock on ODE solver".to_string(),
56            ))
57        })
58    }
59}
60
61impl OdeWrapper {
62    fn snapshot(&self) -> Result<OdeWrapperSnapshot, DiffsolRtError> {
63        let ode = self.guard()?;
64        let jit_backend = ode.jit_backend.ok_or_else(|| {
65            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
66                "OdeWrapper serialization is only supported for JIT-backed solvers".to_string(),
67            ))
68        })?;
69        Ok(OdeWrapperSnapshot {
70            code: ode.code.clone(),
71            equation: ode.solve.serialized_diffsl()?,
72            jit_backend,
73            scalar_type: ode.scalar_type,
74            matrix_type: ode.solve.matrix_type(),
75            linear_solver: ode.linear_solver,
76            ode_solver: ode.ode_solver,
77            rtol: ode.solve.rtol(),
78            atol: ode.solve.atol(),
79            ic_options: InitialConditionSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
80            ode_options: OdeSolverOptionsSnapshot::from_solve(ode.solve.as_ref()),
81        })
82    }
83
84    fn build(
85        code: String,
86        scalar_type: ScalarType,
87        solve: Box<dyn Solve>,
88        jit_backend: Option<JitBackendType>,
89        linear_solver: LinearSolverType,
90        ode_solver: OdeSolverType,
91    ) -> Result<Self, DiffsolRtError> {
92        solve.check(linear_solver)?;
93        Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
94            code,
95            scalar_type,
96            solve,
97            jit_backend,
98            linear_solver,
99            ode_solver,
100        }))))
101    }
102
103    fn from_snapshot(snapshot: OdeWrapperSnapshot) -> Result<Self, DiffsolRtError> {
104        let solve = crate::solve::solve_factory_from_serialized_diffsl(
105            snapshot.equation.as_slice(),
106            snapshot.matrix_type,
107            snapshot.scalar_type,
108        )?;
109        let wrapper = Self::build(
110            snapshot.code,
111            snapshot.scalar_type,
112            solve,
113            Some(snapshot.jit_backend),
114            snapshot.linear_solver,
115            snapshot.ode_solver,
116        )?;
117        {
118            let mut ode = wrapper.guard()?;
119            ode.solve.set_rtol(snapshot.rtol);
120            ode.solve.set_atol(snapshot.atol);
121            snapshot.ic_options.apply_to_solve(ode.solve.as_mut());
122            snapshot.ode_options.apply_to_solve(ode.solve.as_mut());
123        }
124        Ok(wrapper)
125    }
126
127    /// Construct an ODE solver backed by externally-provided DiffSL symbols.
128    #[cfg(feature = "external")]
129    pub fn new_external(
130        rhs_state_deps: Vec<(usize, usize)>,
131        rhs_input_deps: Vec<(usize, usize)>,
132        mass_state_deps: Vec<(usize, usize)>,
133        scalar_type: ScalarType,
134        matrix_type: MatrixType,
135        linear_solver: LinearSolverType,
136        ode_solver: OdeSolverType,
137    ) -> Result<Self, DiffsolRtError> {
138        let solve = crate::solve::solve_factory_external(
139            rhs_state_deps,
140            rhs_input_deps,
141            mass_state_deps,
142            matrix_type,
143            scalar_type,
144        )?;
145        Self::build(
146            String::new(),
147            scalar_type,
148            solve,
149            None,
150            linear_solver,
151            ode_solver,
152        )
153    }
154
155    /// Construct an ODE solver backed by DiffSL symbols loaded from a dynamic library.
156    #[cfg(feature = "diffsl-external-dynamic")]
157    #[allow(clippy::too_many_arguments)]
158    pub fn new_external_dynamic(
159        path: impl Into<std::path::PathBuf>,
160        rhs_state_deps: Vec<(usize, usize)>,
161        rhs_input_deps: Vec<(usize, usize)>,
162        mass_state_deps: Vec<(usize, usize)>,
163        scalar_type: ScalarType,
164        matrix_type: MatrixType,
165        linear_solver: LinearSolverType,
166        ode_solver: OdeSolverType,
167    ) -> Result<Self, DiffsolRtError> {
168        let solve = crate::solve::solve_factory_external_dynamic(
169            path.into(),
170            rhs_state_deps,
171            rhs_input_deps,
172            mass_state_deps,
173            matrix_type,
174            scalar_type,
175        )?;
176        Self::build(
177            String::new(),
178            scalar_type,
179            solve,
180            None,
181            linear_solver,
182            ode_solver,
183        )
184    }
185
186    /// Construct an ODE solver by JIT-compiling DiffSL code immediately.
187    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
188    pub fn new_jit(
189        code: &str,
190        jit_backend: JitBackendType,
191        scalar_type: ScalarType,
192        matrix_type: MatrixType,
193        linear_solver: LinearSolverType,
194        ode_solver: OdeSolverType,
195    ) -> Result<Self, DiffsolRtError> {
196        let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
197        Self::build(
198            code.to_owned(),
199            scalar_type,
200            solve,
201            Some(jit_backend),
202            linear_solver,
203            ode_solver,
204        )
205    }
206
207    /// Matrix type used in the ODE solver. This is fixed after construction.
208    pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolRtError> {
209        Ok(self.guard()?.solve.matrix_type())
210    }
211
212    pub fn get_nstates(&self) -> Result<usize, DiffsolRtError> {
213        Ok(self.guard()?.solve.nstates())
214    }
215
216    pub fn get_nparams(&self) -> Result<usize, DiffsolRtError> {
217        Ok(self.guard()?.solve.nparams())
218    }
219
220    pub fn get_nout(&self) -> Result<usize, DiffsolRtError> {
221        Ok(self.guard()?.solve.nout())
222    }
223
224    pub fn has_stop(&self) -> Result<bool, DiffsolRtError> {
225        Ok(self.guard()?.solve.has_stop())
226    }
227
228    /// Ode solver method, default Bdf (backward differentiation formula).
229    pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolRtError> {
230        Ok(self.guard()?.ode_solver)
231    }
232
233    pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolRtError> {
234        self.guard()?.ode_solver = value;
235        Ok(())
236    }
237
238    /// Linear solver type used in the ODE solver. Set to default to use the
239    /// solver's default choice, which is typically an LU solver.
240    pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolRtError> {
241        Ok(self.guard()?.linear_solver)
242    }
243
244    pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolRtError> {
245        self.guard()?.solve.check(value)?;
246        self.guard()?.linear_solver = value;
247        Ok(())
248    }
249
250    /// Relative tolerance for the solver, default 1e-6. Governs the error relative to the solution size.
251    pub fn get_rtol(&self) -> Result<f64, DiffsolRtError> {
252        Ok(self.guard()?.solve.rtol())
253    }
254
255    pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolRtError> {
256        self.guard()?.solve.set_rtol(value);
257        Ok(())
258    }
259
260    /// Absolute tolerance for the solver, default 1e-6. Governs the error as the solution goes to zero.
261    pub fn get_atol(&self) -> Result<f64, DiffsolRtError> {
262        Ok(self.guard()?.solve.atol())
263    }
264
265    pub fn set_atol(&self, value: f64) -> Result<(), DiffsolRtError> {
266        self.guard()?.solve.set_atol(value);
267        Ok(())
268    }
269
270    pub fn get_code(&self) -> Result<String, DiffsolRtError> {
271        Ok(self.guard()?.code.clone())
272    }
273
274    pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolRtError> {
275        Ok(self.guard()?.scalar_type)
276    }
277
278    pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolRtError> {
279        Ok(self.guard()?.jit_backend)
280    }
281
282    pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
283        InitialConditionSolverOptions::new(self.0.clone())
284    }
285
286    pub fn get_options(&self) -> OdeSolverOptions {
287        OdeSolverOptions::new(self.0.clone())
288    }
289
290    /// Get the initial condition vector y0 as a 1D numpy array.
291    pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolRtError> {
292        let mut self_guard = self.guard()?;
293        self_guard.solve.y0(params.as_slice()?)
294    }
295
296    /// evaluate the right-hand side function at time `t` and state `y`.
297    pub fn rhs(
298        &self,
299        params: HostArray,
300        t: f64,
301        y: HostArray,
302    ) -> Result<HostArray, DiffsolRtError> {
303        let mut self_guard = self.guard()?;
304        self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
305    }
306
307    /// evaluate the right-hand side Jacobian-vector product `Jv`` at time `t` and state `y`.
308    pub fn rhs_jac_mul(
309        &self,
310        params: HostArray,
311        t: f64,
312        y: HostArray,
313        v: HostArray,
314    ) -> Result<HostArray, DiffsolRtError> {
315        let mut self_guard = self.guard()?;
316        self_guard
317            .solve
318            .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
319    }
320
321    /// Using the provided state, solve the problem up to time `final_time`.
322    ///
323    /// The number of params must match the expected params in the diffsl code.
324    /// If specified, the config can be used to override the solver method
325    /// (Bdf by default) and SolverType (Lu by default) along with other solver
326    /// params like `rtol`.
327    ///
328    /// :param params: 1D array of solver parameters
329    /// :type params: numpy.ndarray
330    /// :param final_time: end time of solver
331    /// :type final_time: float
332    /// :return: `(ys, ts)` tuple where `ys` is a 2D array of values at times
333    ///     `ts` chosen by the solver
334    /// :rtype: Tuple[numpy.ndarray, numpy.ndarray]
335    ///
336    /// Example:
337    ///     >>> print(ode.solve(np.array([]), 0.5))
338    #[allow(clippy::type_complexity)]
339    pub fn solve(
340        &self,
341        params: HostArray,
342        final_time: f64,
343    ) -> Result<SolutionWrapper, DiffsolRtError> {
344        let mut self_guard = self.guard()?;
345        let params = params.as_slice()?;
346        let linear_solver = self_guard.linear_solver;
347        let method = self_guard.ode_solver;
348        let solution = self_guard
349            .solve
350            .solve(method, linear_solver, params, final_time)?;
351        Ok(SolutionWrapper::new(solution))
352    }
353
354    /// Solve a hybrid ODE up to `final_time`, automatically applying reset
355    /// functions and continuing after root events until the solution completes.
356    pub fn solve_hybrid(
357        &self,
358        params: HostArray,
359        final_time: f64,
360    ) -> Result<SolutionWrapper, DiffsolRtError> {
361        let mut self_guard = self.guard()?;
362        let params = params.as_slice()?;
363        let linear_solver = self_guard.linear_solver;
364        let method = self_guard.ode_solver;
365        let solution = self_guard
366            .solve
367            .solve_hybrid(method, linear_solver, params, final_time)?;
368        Ok(SolutionWrapper::new(solution))
369    }
370
371    /// Using the provided state, solve the problem up to time
372    /// `t_eval[t_eval.len()-1]`. Returns 2D array of solution values at
373    /// timepoints given by `t_eval`.
374    ///
375    /// The number of params must match the expected params in the diffsl code.
376    /// The config may be optionally specified to override solver settings.
377    ///
378    /// :param params: 1D array of solver parameters
379    /// :type params: numpy.ndarray
380    /// :param t_eval: 1D array of solver times
381    /// :type params: numpy.ndarray
382    /// :return: 2D array of values at times `t_eval`
383    /// :rtype: numpy.ndarray
384    pub fn solve_dense(
385        &self,
386        params: HostArray,
387        t_eval: HostArray,
388    ) -> Result<SolutionWrapper, DiffsolRtError> {
389        let mut self_guard = self.guard()?;
390        let params = params.as_slice()?;
391        let t_eval = t_eval.as_slice()?;
392        let linear_solver = self_guard.linear_solver;
393        let method = self_guard.ode_solver;
394        let solution = self_guard
395            .solve
396            .solve_dense(method, linear_solver, params, t_eval)?;
397        Ok(SolutionWrapper::new(solution))
398    }
399
400    /// Solve a hybrid ODE at dense evaluation times, automatically applying
401    /// reset functions and continuing after root events until all requested
402    /// output points are filled.
403    pub fn solve_hybrid_dense(
404        &self,
405        params: HostArray,
406        t_eval: HostArray,
407    ) -> Result<SolutionWrapper, DiffsolRtError> {
408        let mut self_guard = self.guard()?;
409        let params = params.as_slice()?;
410        let t_eval = t_eval.as_slice()?;
411        let linear_solver = self_guard.linear_solver;
412        let method = self_guard.ode_solver;
413        let solution =
414            self_guard
415                .solve
416                .solve_hybrid_dense(method, linear_solver, params, t_eval)?;
417        Ok(SolutionWrapper::new(solution))
418    }
419
420    /// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]`.
421    /// Returns 2D array of solution values at timepoints given by `t_eval`.
422    /// Also returns a list of 2D arrays of sensitivities at the same timepoints
423    /// as the solution.
424    /// The number of params must match the expected params in the diffsl code.
425    /// The config may be optionally specified to override solver settings.
426    /// :param params: 1D array of solver parameters
427    /// :type params: numpy.ndarray
428    /// :param t_eval: 1D array of solver times
429    /// :type params: numpy.ndarray
430    /// :return: 2D array of values at times `t_eval` and a list of 2D arrays of sensitivities at the same timepoints
431    /// :rtype: (numpy.ndarray, List[numpy.ndarray])
432    #[allow(clippy::type_complexity)]
433    pub fn solve_fwd_sens(
434        &self,
435        params: HostArray,
436        t_eval: HostArray,
437    ) -> Result<SolutionWrapper, DiffsolRtError> {
438        let mut self_guard = self.guard()?;
439        let params = params.as_slice()?;
440        let t_eval = t_eval.as_slice()?;
441        let linear_solver = self_guard.linear_solver;
442        let method = self_guard.ode_solver;
443        let solution = self_guard
444            .solve
445            .solve_fwd_sens(method, linear_solver, params, t_eval)?;
446        Ok(SolutionWrapper::new(solution))
447    }
448
449    /// Solve a hybrid ODE with forward sensitivities at dense evaluation times,
450    /// automatically applying sensitivity-aware reset functions and continuing
451    /// after root events until all requested output points are filled.
452    #[allow(clippy::type_complexity)]
453    pub fn solve_hybrid_fwd_sens(
454        &self,
455        params: HostArray,
456        t_eval: HostArray,
457    ) -> Result<SolutionWrapper, DiffsolRtError> {
458        let mut self_guard = self.guard()?;
459        let params = params.as_slice()?;
460        let t_eval = t_eval.as_slice()?;
461        let linear_solver = self_guard.linear_solver;
462        let method = self_guard.ode_solver;
463        let solution =
464            self_guard
465                .solve
466                .solve_hybrid_fwd_sens(method, linear_solver, params, t_eval)?;
467        Ok(SolutionWrapper::new(solution))
468    }
469
470    /// Using the provided state, solve the adjoint problem for the sum of squares
471    /// objective given data at timepoints `t_eval`.
472    /// Returns the objective value and a list of 1D arrays of adjoint sensitivities
473    /// for each parameter.
474    #[allow(clippy::type_complexity)]
475    pub fn solve_sum_squares_adj(
476        &self,
477        params: HostArray,
478        data: HostArray,
479        t_eval: HostArray,
480    ) -> Result<(f64, HostArray), DiffsolRtError> {
481        let mut self_guard = self.guard()?;
482        let linear_solver = self_guard.linear_solver;
483        let ode_solver = self_guard.ode_solver;
484
485        self_guard.solve.solve_sum_squares_adj(
486            ode_solver,
487            linear_solver,
488            ode_solver,
489            linear_solver,
490            params.as_slice()?,
491            data,
492            t_eval.as_slice()?,
493        )
494    }
495}
496
497impl Serialize for OdeWrapper {
498    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
499    where
500        S: Serializer,
501    {
502        self.snapshot()
503            .map_err(serde::ser::Error::custom)?
504            .serialize(serializer)
505    }
506}
507
508impl<'de> Deserialize<'de> for OdeWrapper {
509    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
510    where
511        D: Deserializer<'de>,
512    {
513        let snapshot = OdeWrapperSnapshot::deserialize(deserializer)?;
514        Self::from_snapshot(snapshot).map_err(DeError::custom)
515    }
516}
517
518#[cfg(all(test, feature = "diffsl-external-f64"))]
519mod tests {
520    use super::*;
521    use crate::host_array::FromHostArray;
522    use crate::linear_solver_type::LinearSolverType;
523    use crate::scalar_type::ScalarType;
524    use crate::test_support::{
525        assert_close, assert_solution_tail, logistic_state, logistic_state_dr, mass_state_deps,
526        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
527    };
528
529    fn all_ode_solvers() -> [OdeSolverType; 4] {
530        [
531            OdeSolverType::Bdf,
532            OdeSolverType::Esdirk34,
533            OdeSolverType::TrBdf2,
534            OdeSolverType::Tsit45,
535        ]
536    }
537
538    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
539        OdeWrapper::new_external(
540            rhs_state_deps(),
541            rhs_input_deps(),
542            mass_state_deps(),
543            ScalarType::F64,
544            matrix_type,
545            LinearSolverType::Default,
546            ode_solver,
547        )
548        .unwrap()
549    }
550
551    fn assert_runtime_dispatch(matrix_type: MatrixType) {
552        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
553        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
554        assert_eq!(ode.get_nstates().unwrap(), 1);
555        assert_eq!(ode.get_nparams().unwrap(), 1);
556        assert_eq!(ode.get_nout().unwrap(), 1);
557        assert!(ode.has_stop().unwrap());
558
559        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
560        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);
561
562        let rhs = ode
563            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
564            .unwrap();
565        assert_close(
566            Vec::<f64>::from_host_array(rhs).unwrap()[0],
567            0.375,
568            ASSERT_TOL,
569            "rhs(0.25)",
570        );
571
572        let rhs_jac_mul = ode
573            .rhs_jac_mul(
574                vector_host(&[2.0]),
575                0.0,
576                vector_host(&[0.25]),
577                vector_host(&[3.0]),
578            )
579            .unwrap();
580        assert_close(
581            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
582            3.0,
583            ASSERT_TOL,
584            "rhs_jac_mul(0.25, 3.0)",
585        );
586    }
587
588    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
589        let ode = make_ode(matrix_type, ode_solver);
590        ode.set_rtol(1e-8).unwrap();
591        ode.set_atol(1e-8).unwrap();
592
593        let t_eval = [0.25, 0.5, 1.0];
594        let solution = ode
595            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
596            .unwrap();
597
598        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
599    }
600
601    fn hybrid_root_time() -> f64 {
602        0.5 * 9.0_f64.ln()
603    }
604
605    fn assert_hybrid_solution_applies_reset_after_root(ode_solver: OdeSolverType) {
606        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
607        ode.set_rtol(1e-8).unwrap();
608        ode.set_atol(1e-8).unwrap();
609
610        let final_time = 2.0;
611        let solution = ode.solve_hybrid(vector_host(&[2.0]), final_time).unwrap();
612        let ys = solution.get_ys().unwrap();
613        let ys = ys.as_array::<f64>().unwrap();
614        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
615        let root_time = hybrid_root_time();
616
617        assert_eq!(ys.nrows(), 1);
618        assert_eq!(ys.ncols(), ts.len());
619        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
620        assert_close(
621            *ts.last().unwrap(),
622            final_time,
623            ASSERT_TOL,
624            "hybrid final time",
625        );
626        assert_close(ys[(0, ys.ncols() - 1)], 1.0, 5e-4, "hybrid final value");
627        assert!(
628            ts.iter().any(|&t| t < root_time),
629            "expected pre-root samples"
630        );
631        assert!(
632            ts.iter().any(|&t| t > root_time),
633            "expected post-root samples after reset"
634        );
635    }
636
637    fn assert_hybrid_dense_solution_continues_after_reset(ode_solver: OdeSolverType) {
638        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
639        ode.set_rtol(1e-8).unwrap();
640        ode.set_atol(1e-8).unwrap();
641
642        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
643        let solution = ode
644            .solve_hybrid_dense(vector_host(&[2.0]), vector_host(&t_eval))
645            .unwrap();
646        let ys = solution.get_ys().unwrap();
647        let ys = ys.as_array::<f64>().unwrap();
648
649        assert_eq!(ys.nrows(), 1);
650        assert_eq!(ys.ncols(), t_eval.len());
651        assert_close(
652            ys[(0, 0)],
653            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
654            5e-4,
655            "hybrid dense pre-root value",
656        );
657        assert_close(
658            ys[(0, 1)],
659            logistic_state(LOGISTIC_X0, 2.0, t_eval[1]),
660            5e-4,
661            "hybrid dense near-root value",
662        );
663        for col in 2..t_eval.len() {
664            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid dense post-root value");
665        }
666    }
667
668    fn assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver: OdeSolverType) {
669        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
670        ode.set_rtol(1e-8).unwrap();
671        ode.set_atol(1e-8).unwrap();
672
673        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
674        let solution = ode
675            .solve_hybrid_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
676            .unwrap();
677        let ys = solution.get_ys().unwrap();
678        let ys = ys.as_array::<f64>().unwrap();
679        let sens = solution.get_sens().unwrap();
680
681        assert_eq!(ys.nrows(), 1);
682        assert_eq!(ys.ncols(), t_eval.len());
683        assert_eq!(sens.len(), 1);
684        let sens_values = sens[0].as_array::<f64>().unwrap();
685        assert_eq!(sens_values.nrows(), 1);
686        assert_eq!(sens_values.ncols(), t_eval.len());
687        assert_close(
688            ys[(0, 0)],
689            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
690            5e-4,
691            "hybrid sens pre-root value",
692        );
693        for col in 2..t_eval.len() {
694            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid sens post-root value");
695            assert!(
696                sens_values[(0, col)].is_finite(),
697                "expected finite post-root sensitivity at column {col}"
698            );
699        }
700    }
701
702    #[test]
703    fn runtime_dispatch_matches_requested_matrix_type() {
704        for matrix_type in [
705            MatrixType::NalgebraDense,
706            MatrixType::FaerDense,
707            MatrixType::FaerSparse,
708        ] {
709            assert_runtime_dispatch(matrix_type);
710        }
711    }
712
713    #[test]
714    fn bdf_dense_solution_matches_logistic_solution() {
715        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
716        ode.set_rtol(1e-8).unwrap();
717        ode.set_atol(1e-8).unwrap();
718
719        let t_eval = [0.25, 0.5, 1.0];
720        let solution = ode
721            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
722            .unwrap();
723
724        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
725    }
726
727    #[test]
728    fn esdirk34_dense_solution_matches_logistic_solution() {
729        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
730    }
731
732    #[test]
733    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
734        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
735    }
736
737    #[test]
738    fn tsit45_dense_solution_matches_logistic_solution() {
739        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
740    }
741
742    #[test]
743    fn bdf_forward_sensitivities_match_logistic_derivative() {
744        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
745        ode.set_rtol(1e-8).unwrap();
746        ode.set_atol(1e-8).unwrap();
747
748        let t_eval = [0.25, 0.5, 1.0];
749        let solution = ode
750            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
751            .unwrap();
752
753        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
754        let sens = solution.get_sens().unwrap();
755        assert_eq!(sens.len(), 1);
756        let sens_values = sens[0].as_array::<f64>().unwrap();
757        assert_eq!(sens_values.nrows(), 1);
758        assert_eq!(sens_values.ncols(), t_eval.len());
759        for (i, &t) in t_eval.iter().enumerate() {
760            assert_close(
761                sens_values[(0, i)],
762                logistic_state_dr(LOGISTIC_X0, 2.0, t),
763                ASSERT_TOL,
764                &format!("sensitivity[{i}]"),
765            );
766        }
767    }
768
769    #[test]
770    fn bdf_sum_squares_adjoint_matches_external_logistic_model() {
771        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
772        ode.set_rtol(1e-8).unwrap();
773        ode.set_atol(1e-8).unwrap();
774
775        let t_eval = [0.0, 0.25, 0.5, 1.0];
776        let data_values: Vec<f64> = t_eval
777            .iter()
778            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
779            .collect();
780        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
781        let (value, sens) = ode
782            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
783            .unwrap();
784        let grad = Vec::<f64>::from_host_array(sens).unwrap();
785
786        assert_close(value, 0.0, ASSERT_TOL, "sum_squares objective");
787        assert_eq!(grad.len(), 1);
788        assert_close(grad[0], 0.0, ASSERT_TOL, "sum_squares gradient");
789    }
790
791    #[test]
792    fn hybrid_solution_applies_reset_after_root_for_all_solvers() {
793        for ode_solver in all_ode_solvers() {
794            assert_hybrid_solution_applies_reset_after_root(ode_solver);
795        }
796    }
797
798    #[test]
799    fn hybrid_dense_solution_continues_after_reset_for_all_solvers() {
800        for ode_solver in all_ode_solvers() {
801            assert_hybrid_dense_solution_continues_after_reset(ode_solver);
802        }
803    }
804
805    #[test]
806    fn hybrid_forward_sensitivities_complete_across_reset_for_all_solvers() {
807        for ode_solver in all_ode_solvers() {
808            assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver);
809        }
810    }
811}
812
813#[cfg(all(test, feature = "diffsl-external-dynamic"))]
814mod dynamic_tests {
815    use crate::host_array::FromHostArray;
816    use crate::linear_solver_type::LinearSolverType;
817    use crate::scalar_type::ScalarType;
818    use crate::test_support::{
819        assert_close, assert_solution_tail, external_dynamic_fixture_path, mass_state_deps,
820        rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
821    };
822
823    use super::*;
824
825    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
826        OdeWrapper::new_external_dynamic(
827            external_dynamic_fixture_path(),
828            rhs_state_deps(),
829            rhs_input_deps(),
830            mass_state_deps(),
831            ScalarType::F64,
832            matrix_type,
833            LinearSolverType::Default,
834            ode_solver,
835        )
836        .unwrap()
837    }
838
839    #[test]
840    fn runtime_dispatch_matches_requested_matrix_type() {
841        for matrix_type in [
842            MatrixType::NalgebraDense,
843            MatrixType::FaerDense,
844            MatrixType::FaerSparse,
845        ] {
846            let ode = make_ode(matrix_type, OdeSolverType::Bdf);
847            assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
848            assert_eq!(ode.get_code().unwrap(), "");
849            assert_eq!(ode.get_jit_backend().unwrap(), None);
850            assert_eq!(ode.get_nstates().unwrap(), 1);
851            assert_eq!(ode.get_nparams().unwrap(), 1);
852            assert_eq!(ode.get_nout().unwrap(), 1);
853            assert!(ode.has_stop().unwrap());
854
855            let y0 = ode.y0(vector_host(&[2.0])).unwrap();
856            assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);
857
858            let rhs = ode
859                .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
860                .unwrap();
861            assert_close(
862                Vec::<f64>::from_host_array(rhs).unwrap()[0],
863                0.375,
864                ASSERT_TOL,
865                "rhs(0.25)",
866            );
867
868            let rhs_jac_mul = ode
869                .rhs_jac_mul(
870                    vector_host(&[2.0]),
871                    0.0,
872                    vector_host(&[0.25]),
873                    vector_host(&[3.0]),
874                )
875                .unwrap();
876            assert_close(
877                Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
878                3.0,
879                ASSERT_TOL,
880                "rhs_jac_mul(0.25, 3.0)",
881            );
882        }
883    }
884
885    #[test]
886    fn dense_solution_matches_logistic_solution() {
887        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
888        ode.set_rtol(1e-8).unwrap();
889        ode.set_atol(1e-8).unwrap();
890
891        let t_eval = [0.25, 0.5, 1.0];
892        let solution = ode
893            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
894            .unwrap();
895
896        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
897    }
898
899    #[test]
900    fn non_jit_serialization_is_rejected() {
901        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
902        let err = serde_json::to_string(&ode).unwrap_err().to_string();
903        assert!(err.contains("JIT-backed"));
904    }
905}
906
907#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
908mod jit_tests {
909    use crate::host_array::FromHostArray;
910    use crate::jit::JitBackendType;
911    use crate::linear_solver_type::LinearSolverType;
912    use crate::scalar_type::ScalarType;
913    use crate::test_support::{
914        assert_close, assert_solution_tail, available_jit_backends, hybrid_logistic_diffsl_code,
915        hybrid_logistic_period, hybrid_logistic_state, logistic_diffsl_code, logistic_state,
916        vector_host, ASSERT_TOL, LOGISTIC_X0,
917    };
918    #[cfg(feature = "diffsl-llvm")]
919    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
920    #[cfg(any(
921        all(feature = "diffsl-llvm", not(feature = "diffsl-cranelift")),
922        all(feature = "diffsl-cranelift", not(feature = "diffsl-llvm"))
923    ))]
924    use serde_json::Value;
925    use serde_json::{self};
926
927    use super::*;
928
929    fn all_ode_solvers() -> [OdeSolverType; 4] {
930        [
931            OdeSolverType::Bdf,
932            OdeSolverType::Esdirk34,
933            OdeSolverType::TrBdf2,
934            OdeSolverType::Tsit45,
935        ]
936    }
937
938    fn make_ode(
939        jit_backend: JitBackendType,
940        scalar_type: ScalarType,
941        matrix_type: MatrixType,
942        ode_solver: OdeSolverType,
943    ) -> OdeWrapper {
944        OdeWrapper::new_jit(
945            logistic_diffsl_code(),
946            jit_backend,
947            scalar_type,
948            matrix_type,
949            LinearSolverType::Default,
950            ode_solver,
951        )
952        .unwrap()
953    }
954
955    fn make_hybrid_ode(
956        jit_backend: JitBackendType,
957        matrix_type: MatrixType,
958        ode_solver: OdeSolverType,
959    ) -> OdeWrapper {
960        OdeWrapper::new_jit(
961            hybrid_logistic_diffsl_code(),
962            jit_backend,
963            ScalarType::F64,
964            matrix_type,
965            LinearSolverType::Default,
966            ode_solver,
967        )
968        .unwrap()
969    }
970
971    fn serialized_linear_solver(matrix_type: MatrixType) -> LinearSolverType {
972        match matrix_type {
973            MatrixType::NalgebraDense | MatrixType::FaerDense => LinearSolverType::Lu,
974            MatrixType::FaerSparse => LinearSolverType::Default,
975        }
976    }
977
978    fn configure_serialized_ode(ode: &OdeWrapper, matrix_type: MatrixType) {
979        ode.set_linear_solver(serialized_linear_solver(matrix_type))
980            .unwrap();
981        ode.set_ode_solver(OdeSolverType::TrBdf2).unwrap();
982        ode.set_rtol(1e-7).unwrap();
983        ode.set_atol(1e-9).unwrap();
984
985        let ic_options = ode.get_ic_options();
986        ic_options.set_use_linesearch(true).unwrap();
987        ic_options.set_max_linesearch_iterations(13).unwrap();
988        ic_options.set_max_newton_iterations(17).unwrap();
989        ic_options.set_max_linear_solver_setups(19).unwrap();
990        ic_options.set_step_reduction_factor(0.5).unwrap();
991        ic_options.set_armijo_constant(1e-4).unwrap();
992
993        let options = ode.get_options();
994        options.set_max_nonlinear_solver_iterations(23).unwrap();
995        options.set_max_error_test_failures(29).unwrap();
996        options.set_update_jacobian_after_steps(31).unwrap();
997        options.set_update_rhs_jacobian_after_steps(37).unwrap();
998        options.set_threshold_to_update_jacobian(1e-3).unwrap();
999        options.set_threshold_to_update_rhs_jacobian(2e-3).unwrap();
1000        options.set_min_timestep(1e-4).unwrap();
1001    }
1002
1003    fn scalar_value(value: f64, scalar_type: ScalarType) -> f64 {
1004        match scalar_type {
1005            ScalarType::F32 => (value as f32) as f64,
1006            ScalarType::F64 => value,
1007        }
1008    }
1009
1010    fn assert_serialization_roundtrip(
1011        jit_backend: JitBackendType,
1012        scalar_type: ScalarType,
1013        matrix_type: MatrixType,
1014    ) {
1015        let ode = make_ode(jit_backend, scalar_type, matrix_type, OdeSolverType::Bdf);
1016        configure_serialized_ode(&ode, matrix_type);
1017
1018        #[cfg(feature = "diffsl-cranelift")]
1019        if jit_backend == JitBackendType::Cranelift {
1020            let err = serde_json::to_string(&ode).unwrap_err().to_string();
1021            assert!(err.contains("not supported for Cranelift"));
1022            return;
1023        }
1024
1025        let y0_before = Vec::<f64>::from_host_array(ode.y0(vector_host(&[2.0])).unwrap()).unwrap();
1026        let rhs_before = Vec::<f64>::from_host_array(
1027            ode.rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
1028                .unwrap(),
1029        )
1030        .unwrap();
1031
1032        let encoded = serde_json::to_string(&ode).unwrap();
1033        let decoded: OdeWrapper = serde_json::from_str(&encoded).unwrap();
1034
1035        assert_eq!(decoded.get_jit_backend().unwrap(), Some(jit_backend));
1036        assert_eq!(decoded.get_code().unwrap(), logistic_diffsl_code());
1037        assert_eq!(decoded.get_scalar_type().unwrap(), scalar_type);
1038        assert_eq!(decoded.get_matrix_type().unwrap(), matrix_type);
1039        assert_eq!(
1040            decoded.get_linear_solver().unwrap(),
1041            serialized_linear_solver(matrix_type)
1042        );
1043        assert_eq!(decoded.get_ode_solver().unwrap(), OdeSolverType::TrBdf2);
1044        assert_close(
1045            decoded.get_rtol().unwrap(),
1046            scalar_value(1e-7, scalar_type),
1047            1e-12,
1048            "serialized rtol",
1049        );
1050        assert_close(
1051            decoded.get_atol().unwrap(),
1052            scalar_value(1e-9, scalar_type),
1053            1e-12,
1054            "serialized atol",
1055        );
1056
1057        let ic_options = decoded.get_ic_options();
1058        assert!(ic_options.get_use_linesearch().unwrap());
1059        assert_eq!(ic_options.get_max_linesearch_iterations().unwrap(), 13);
1060        assert_eq!(ic_options.get_max_newton_iterations().unwrap(), 17);
1061        assert_eq!(ic_options.get_max_linear_solver_setups().unwrap(), 19);
1062        assert_close(
1063            ic_options.get_step_reduction_factor().unwrap(),
1064            scalar_value(0.5, scalar_type),
1065            1e-12,
1066            "serialized step_reduction_factor",
1067        );
1068        assert_close(
1069            ic_options.get_armijo_constant().unwrap(),
1070            scalar_value(1e-4, scalar_type),
1071            1e-12,
1072            "serialized armijo_constant",
1073        );
1074
1075        let options = decoded.get_options();
1076        assert_eq!(options.get_max_nonlinear_solver_iterations().unwrap(), 23);
1077        assert_eq!(options.get_max_error_test_failures().unwrap(), 29);
1078        assert_eq!(options.get_update_jacobian_after_steps().unwrap(), 31);
1079        assert_eq!(options.get_update_rhs_jacobian_after_steps().unwrap(), 37);
1080        assert_close(
1081            options.get_threshold_to_update_jacobian().unwrap(),
1082            scalar_value(1e-3, scalar_type),
1083            1e-12,
1084            "serialized threshold_to_update_jacobian",
1085        );
1086        assert_close(
1087            options.get_threshold_to_update_rhs_jacobian().unwrap(),
1088            scalar_value(2e-3, scalar_type),
1089            1e-12,
1090            "serialized threshold_to_update_rhs_jacobian",
1091        );
1092        assert_close(
1093            options.get_min_timestep().unwrap(),
1094            scalar_value(1e-4, scalar_type),
1095            1e-12,
1096            "serialized min_timestep",
1097        );
1098
1099        let y0_after =
1100            Vec::<f64>::from_host_array(decoded.y0(vector_host(&[2.0])).unwrap()).unwrap();
1101        let rhs_after = Vec::<f64>::from_host_array(
1102            decoded
1103                .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
1104                .unwrap(),
1105        )
1106        .unwrap();
1107        assert_eq!(y0_after, y0_before);
1108        assert_close(
1109            rhs_after[0],
1110            rhs_before[0],
1111            ASSERT_TOL,
1112            "serialized rhs matches",
1113        );
1114
1115        decoded
1116            .set_linear_solver(serialized_linear_solver(matrix_type))
1117            .unwrap();
1118        let t_eval = [0.25, 0.5, 1.0];
1119        let solution = decoded
1120            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
1121            .unwrap();
1122        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1123    }
1124
1125    fn assert_runtime_dispatch(jit_backend: JitBackendType, matrix_type: MatrixType) {
1126        let ode = make_ode(
1127            jit_backend,
1128            ScalarType::F64,
1129            matrix_type,
1130            OdeSolverType::Bdf,
1131        );
1132        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
1133        assert_eq!(ode.get_code().unwrap(), logistic_diffsl_code());
1134        assert_eq!(ode.get_nstates().unwrap(), 1);
1135        assert_eq!(ode.get_nparams().unwrap(), 1);
1136        assert_eq!(ode.get_nout().unwrap(), 1);
1137        assert!(!ode.has_stop().unwrap());
1138
1139        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
1140        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);
1141
1142        let rhs = ode
1143            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
1144            .unwrap();
1145        assert_close(
1146            Vec::<f64>::from_host_array(rhs).unwrap()[0],
1147            0.375,
1148            ASSERT_TOL,
1149            "jit rhs(0.25)",
1150        );
1151
1152        let rhs_jac_mul = ode
1153            .rhs_jac_mul(
1154                vector_host(&[2.0]),
1155                0.0,
1156                vector_host(&[0.25]),
1157                vector_host(&[3.0]),
1158            )
1159            .unwrap();
1160        assert_close(
1161            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
1162            3.0,
1163            ASSERT_TOL,
1164            "jit rhs_jac_mul(0.25, 3.0)",
1165        );
1166    }
1167
1168    fn assert_solver_dense_solution(
1169        jit_backend: JitBackendType,
1170        scalar_type: ScalarType,
1171        matrix_type: MatrixType,
1172        ode_solver: OdeSolverType,
1173    ) {
1174        let ode = make_ode(jit_backend, scalar_type, matrix_type, ode_solver);
1175        ode.set_rtol(1e-8).unwrap();
1176        ode.set_atol(1e-8).unwrap();
1177
1178        let t_eval = [0.25, 0.5, 1.0];
1179        let solution = ode
1180            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
1181            .unwrap();
1182
1183        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1184    }
1185
1186    fn hybrid_t_eval() -> [f64; 7] {
1187        [0.5, 1.0, 2.0, 2.5, 3.0, 4.0, 4.8]
1188    }
1189
1190    fn assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
1191        jit_backend: JitBackendType,
1192        ode_solver: OdeSolverType,
1193    ) {
1194        let r = 2.0;
1195        let final_time = 5.0;
1196        let tau = hybrid_logistic_period(r);
1197        let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
1198        ode.set_rtol(1e-8).unwrap();
1199        ode.set_atol(1e-8).unwrap();
1200        assert_eq!(ode.get_nstates().unwrap(), 1);
1201        assert_eq!(ode.get_nparams().unwrap(), 1);
1202        assert_eq!(ode.get_nout().unwrap(), 1);
1203        assert!(ode.has_stop().unwrap());
1204
1205        let solution = ode.solve_hybrid(vector_host(&[r]), final_time).unwrap();
1206        let ys = solution.get_ys().unwrap();
1207        let ys = ys.as_array::<f64>().unwrap();
1208        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
1209
1210        assert_eq!(ys.nrows(), 1);
1211        assert_eq!(ys.ncols(), ts.len());
1212        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
1213        assert_close(
1214            *ts.last().unwrap(),
1215            final_time,
1216            ASSERT_TOL,
1217            "jit hybrid final time",
1218        );
1219        assert_close(
1220            ys[(0, ys.ncols() - 1)],
1221            hybrid_logistic_state(r, final_time),
1222            5e-4,
1223            "jit hybrid final value",
1224        );
1225        assert!(ts.iter().any(|&t| (t - tau).abs() < 1e-3));
1226        assert!(ts.iter().any(|&t| (t - 2.0 * tau).abs() < 1e-3));
1227        for (col, &t) in ts.iter().enumerate() {
1228            if ((t / tau).round() * tau - t).abs() < 1e-3 {
1229                continue;
1230            }
1231            assert_close(
1232                ys[(0, col)],
1233                hybrid_logistic_state(r, t),
1234                5e-4,
1235                &format!("jit hybrid value[{col}]"),
1236            );
1237        }
1238    }
1239
1240    fn assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
1241        jit_backend: JitBackendType,
1242        ode_solver: OdeSolverType,
1243    ) {
1244        let r = 2.0;
1245        let t_eval = hybrid_t_eval();
1246        let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
1247        ode.set_rtol(1e-8).unwrap();
1248        ode.set_atol(1e-8).unwrap();
1249
1250        let solution = ode
1251            .solve_hybrid_dense(vector_host(&[r]), vector_host(&t_eval))
1252            .unwrap();
1253        let ys = solution.get_ys().unwrap();
1254        let ys = ys.as_array::<f64>().unwrap();
1255        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
1256
1257        assert_eq!(ys.nrows(), 1);
1258        assert_eq!(ys.ncols(), t_eval.len());
1259        assert_eq!(ts, t_eval);
1260        for (col, &t) in t_eval.iter().enumerate() {
1261            assert_close(
1262                ys[(0, col)],
1263                hybrid_logistic_state(r, t),
1264                5e-4,
1265                &format!("jit hybrid dense value[{col}]"),
1266            );
1267        }
1268    }
1269
1270    #[cfg(feature = "diffsl-llvm")]
1271    fn assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(
1272        ode_solver: OdeSolverType,
1273    ) {
1274        let r = 2.0;
1275        let t_eval = hybrid_t_eval();
1276        let ode = make_hybrid_ode(JitBackendType::Llvm, MatrixType::NalgebraDense, ode_solver);
1277        ode.set_rtol(1e-8).unwrap();
1278        ode.set_atol(1e-8).unwrap();
1279
1280        let solution = ode
1281            .solve_hybrid_fwd_sens(vector_host(&[r]), vector_host(&t_eval))
1282            .unwrap();
1283        let ys = solution.get_ys().unwrap();
1284        let ys = ys.as_array::<f64>().unwrap();
1285        let sens = solution.get_sens().unwrap();
1286
1287        assert_eq!(ys.nrows(), 1);
1288        assert_eq!(ys.ncols(), t_eval.len());
1289        assert_eq!(sens.len(), 1);
1290        let sens_values = sens[0].as_array::<f64>().unwrap();
1291        assert_eq!(sens_values.nrows(), 1);
1292        assert_eq!(sens_values.ncols(), t_eval.len());
1293        for (col, &t) in t_eval.iter().enumerate() {
1294            assert_close(
1295                ys[(0, col)],
1296                hybrid_logistic_state(r, t),
1297                5e-4,
1298                &format!("jit hybrid sens value[{col}]"),
1299            );
1300            assert_close(
1301                sens_values[(0, col)],
1302                hybrid_logistic_state_dr(r, t),
1303                5e-4,
1304                &format!("jit hybrid sensitivity[{col}]"),
1305            );
1306        }
1307    }
1308
1309    #[test]
1310    fn runtime_dispatch_matches_requested_matrix_type_from_diffsl() {
1311        for jit_backend in available_jit_backends() {
1312            for matrix_type in [
1313                MatrixType::NalgebraDense,
1314                MatrixType::FaerDense,
1315                MatrixType::FaerSparse,
1316            ] {
1317                assert_runtime_dispatch(jit_backend, matrix_type);
1318            }
1319        }
1320    }
1321
1322    #[test]
1323    fn dense_solution_matches_logistic_solution_from_diffsl() {
1324        for jit_backend in available_jit_backends() {
1325            for scalar_type in [ScalarType::F64, ScalarType::F32] {
1326                for (matrix_type, solver) in [
1327                    (MatrixType::FaerDense, OdeSolverType::Esdirk34),
1328                    (MatrixType::FaerSparse, OdeSolverType::TrBdf2),
1329                    (MatrixType::NalgebraDense, OdeSolverType::Tsit45),
1330                ] {
1331                    assert_solver_dense_solution(jit_backend, scalar_type, matrix_type, solver);
1332                }
1333            }
1334        }
1335    }
1336
1337    #[test]
1338    fn bdf_dense_solution_matches_logistic_diffsl_model() {
1339        for jit_backend in available_jit_backends() {
1340            let ode = make_ode(
1341                jit_backend,
1342                ScalarType::F64,
1343                MatrixType::NalgebraDense,
1344                OdeSolverType::Bdf,
1345            );
1346            ode.set_rtol(1e-8).unwrap();
1347            ode.set_atol(1e-8).unwrap();
1348
1349            let t_eval = [0.25, 0.5, 1.0];
1350            let solution = ode
1351                .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
1352                .unwrap();
1353
1354            assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1355        }
1356    }
1357
1358    #[test]
1359    fn bdf_solution_matches_logistic_diffsl_model() {
1360        for jit_backend in available_jit_backends() {
1361            let x0 = LOGISTIC_X0;
1362            let r = 2.0;
1363            let ode = make_ode(
1364                jit_backend,
1365                ScalarType::F64,
1366                MatrixType::NalgebraDense,
1367                OdeSolverType::Bdf,
1368            );
1369            ode.set_rtol(1e-8).unwrap();
1370            ode.set_atol(1e-8).unwrap();
1371
1372            let final_time = 1.0;
1373            let solution = ode.solve(vector_host(&[r]), final_time).unwrap();
1374
1375            let ys = solution.get_ys().unwrap();
1376            let ys = ys.as_array::<f64>().unwrap();
1377            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
1378
1379            assert_eq!(ys.nrows(), 1);
1380            assert_eq!(ys.ncols(), ts.len());
1381            assert!(
1382                !ts.is_empty(),
1383                "expected solve() to record at least one time point"
1384            );
1385            assert_close(
1386                *ts.last().unwrap(),
1387                final_time,
1388                ASSERT_TOL,
1389                "solve final time",
1390            );
1391            for (i, &t) in ts.iter().enumerate() {
1392                assert_close(
1393                    ys[(0, i)],
1394                    logistic_state(x0, r, t),
1395                    5e-4,
1396                    &format!("solve value[{i}]"),
1397                );
1398            }
1399        }
1400    }
1401
1402    #[cfg_attr(
1403        all(target_os = "macos", target_arch = "x86_64"),
1404        ignore = "from_external_object is unsupported on Intel macOS due to unsupported relocations"
1405    )]
1406    #[test]
1407    fn serialization_roundtrip_restores_full_solver_state() {
1408        for jit_backend in available_jit_backends() {
1409            for scalar_type in [ScalarType::F64, ScalarType::F32] {
1410                for matrix_type in [MatrixType::NalgebraDense, MatrixType::FaerSparse] {
1411                    assert_serialization_roundtrip(jit_backend, scalar_type, matrix_type);
1412                }
1413            }
1414        }
1415    }
1416
1417    #[cfg(all(feature = "diffsl-llvm", not(feature = "diffsl-cranelift")))]
1418    #[test]
1419    fn deserialization_rejects_unavailable_jit_backend() {
1420        let ode = make_ode(
1421            JitBackendType::Llvm,
1422            ScalarType::F64,
1423            MatrixType::NalgebraDense,
1424            OdeSolverType::Bdf,
1425        );
1426        let mut value = serde_json::to_value(&ode).unwrap();
1427        value["jit_backend"] = Value::String("cranelift".to_string());
1428        let err = serde_json::from_value::<OdeWrapper>(value)
1429            .err()
1430            .unwrap()
1431            .to_string();
1432        assert!(err.contains("unknown variant"));
1433    }
1434
1435    #[cfg(all(feature = "diffsl-cranelift", not(feature = "diffsl-llvm")))]
1436    #[test]
1437    fn deserialization_rejects_unavailable_jit_backend() {
1438        let ode = make_ode(
1439            JitBackendType::Cranelift,
1440            ScalarType::F64,
1441            MatrixType::NalgebraDense,
1442            OdeSolverType::Bdf,
1443        );
1444        let mut value = serde_json::to_value(&ode).unwrap();
1445        value["jit_backend"] = Value::String("llvm".to_string());
1446        let err = serde_json::from_value::<OdeWrapper>(value)
1447            .err()
1448            .unwrap()
1449            .to_string();
1450        assert!(err.contains("unknown variant"));
1451    }
1452
1453    #[test]
1454    fn hybrid_solution_matches_piecewise_logistic_diffsl_model() {
1455        for jit_backend in available_jit_backends() {
1456            for ode_solver in all_ode_solvers() {
1457                assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
1458                    jit_backend,
1459                    ode_solver,
1460                );
1461            }
1462        }
1463    }
1464
1465    #[test]
1466    fn hybrid_dense_solution_matches_piecewise_logistic_diffsl_model() {
1467        for jit_backend in available_jit_backends() {
1468            for ode_solver in all_ode_solvers() {
1469                assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
1470                    jit_backend,
1471                    ode_solver,
1472                );
1473            }
1474        }
1475    }
1476
1477    #[cfg(feature = "diffsl-llvm")]
1478    #[test]
1479    fn bdf_forward_sensitivities_match_logistic_derivative_from_diffsl() {
1480        let ode = make_ode(
1481            JitBackendType::Llvm,
1482            ScalarType::F64,
1483            MatrixType::NalgebraDense,
1484            OdeSolverType::Bdf,
1485        );
1486        ode.set_rtol(1e-8).unwrap();
1487        ode.set_atol(1e-8).unwrap();
1488
1489        let t_eval = [0.25, 0.5, 1.0];
1490        let solution = ode
1491            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
1492            .unwrap();
1493
1494        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1495        let sens = solution.get_sens().unwrap();
1496        assert_eq!(sens.len(), 1);
1497        let sens_values = sens[0].as_array::<f64>().unwrap();
1498        assert_eq!(sens_values.nrows(), 1);
1499        assert_eq!(sens_values.ncols(), t_eval.len());
1500        for (i, &t) in t_eval.iter().enumerate() {
1501            assert_close(
1502                sens_values[(0, i)],
1503                logistic_state_dr(LOGISTIC_X0, 2.0, t),
1504                ASSERT_TOL,
1505                &format!("jit sensitivity[{i}]"),
1506            );
1507        }
1508    }
1509
1510    #[cfg(feature = "diffsl-llvm")]
1511    #[test]
1512    fn hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model() {
1513        for ode_solver in all_ode_solvers() {
1514            assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(ode_solver);
1515        }
1516    }
1517
1518    #[cfg(feature = "diffsl-llvm")]
1519    #[test]
1520    fn bdf_sum_squares_adjoint_matches_logistic_diffsl_model() {
1521        let ode = make_ode(
1522            JitBackendType::Llvm,
1523            ScalarType::F64,
1524            MatrixType::NalgebraDense,
1525            OdeSolverType::Bdf,
1526        );
1527        ode.set_rtol(1e-8).unwrap();
1528        ode.set_atol(1e-8).unwrap();
1529
1530        let t_eval = [0.0, 0.25, 0.5, 1.0];
1531        let data_values: Vec<f64> = t_eval
1532            .iter()
1533            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
1534            .collect();
1535        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
1536        let (value, sens) = ode
1537            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
1538            .unwrap();
1539        let grad = Vec::<f64>::from_host_array(sens).unwrap();
1540
1541        assert_close(value, 0.0, ASSERT_TOL, "jit sum_squares objective");
1542        assert_eq!(grad.len(), 1);
1543        assert!(
1544            grad[0].is_finite(),
1545            "jit sum_squares gradient should be finite"
1546        );
1547    }
1548
1549    #[cfg(feature = "diffsl-llvm")]
1550    #[test]
1551    fn bdf_sum_squares_adjoint_matches_finite_difference_gradient_for_logistic_model() {
1552        let logistic_model = r#"
1553            in_i { r = 1, k = 1, y0 = 0.1 }
1554            u { y0 }
1555            F { r * u * (1.0 - u / k) }
1556        "#;
1557        let ode = OdeWrapper::new_jit(
1558            logistic_model,
1559            JitBackendType::Llvm,
1560            ScalarType::F64,
1561            MatrixType::NalgebraDense,
1562            LinearSolverType::Default,
1563            OdeSolverType::Bdf,
1564        )
1565        .unwrap();
1566        ode.set_rtol(1e-8).unwrap();
1567        ode.set_atol(1e-8).unwrap();
1568
1569        let t_eval = [0.0, 0.1, 0.3, 0.7, 1.0];
1570        let data_params = [1.2, 0.9, 0.2];
1571        let fit_params = [0.8, 1.3, 0.12];
1572        let fd_step = 1e-6;
1573
1574        let data_solution = ode
1575            .solve_dense(vector_host(&data_params), vector_host(&t_eval))
1576            .unwrap();
1577        let data_ys = data_solution.get_ys().unwrap();
1578        let data_ys = data_ys.as_array::<f64>().unwrap();
1579        let data_values: Vec<f64> = (0..t_eval.len()).map(|col| data_ys[(0, col)]).collect();
1580
1581        let objective_from_dense = |params: [f64; 3]| -> f64 {
1582            let solution = ode
1583                .solve_dense(vector_host(&params), vector_host(&t_eval))
1584                .unwrap();
1585            let ys = solution.get_ys().unwrap();
1586            let ys = ys.as_array::<f64>().unwrap();
1587            (0..t_eval.len())
1588                .map(|col| {
1589                    let residual = ys[(0, col)] - data_values[col];
1590                    residual * residual
1591                })
1592                .sum()
1593        };
1594
1595        let objective_fd = objective_from_dense(fit_params);
1596        let mut finite_difference_gradient = [0.0; 3];
1597        for i in 0..fit_params.len() {
1598            let mut plus = fit_params;
1599            let mut minus = fit_params;
1600            let step = fd_step * fit_params[i].abs().max(1.0);
1601            plus[i] += step;
1602            minus[i] -= step;
1603            finite_difference_gradient[i] =
1604                (objective_from_dense(plus) - objective_from_dense(minus)) / (2.0 * step);
1605        }
1606
1607        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
1608        let ode_adj = OdeWrapper::new_jit(
1609            logistic_model,
1610            JitBackendType::Llvm,
1611            ScalarType::F64,
1612            MatrixType::NalgebraDense,
1613            LinearSolverType::Default,
1614            OdeSolverType::Bdf,
1615        )
1616        .unwrap();
1617        ode_adj.set_rtol(1e-8).unwrap();
1618        ode_adj.set_atol(1e-8).unwrap();
1619
1620        let (objective_adj, sens) = ode_adj
1621            .solve_sum_squares_adj(vector_host(&fit_params), data, vector_host(&t_eval))
1622            .unwrap();
1623        let adjoint_gradient = Vec::<f64>::from_host_array(sens).unwrap();
1624
1625        assert_eq!(adjoint_gradient.len(), 3);
1626        assert_close(
1627            objective_adj,
1628            objective_fd,
1629            1e-5,
1630            "sum_squares objective from dense finite differences",
1631        );
1632        for i in 0..adjoint_gradient.len() {
1633            assert_close(
1634                adjoint_gradient[i],
1635                finite_difference_gradient[i],
1636                5e-4,
1637                &format!("sum_squares gradient component {i}"),
1638            );
1639        }
1640    }
1641}