Skip to main content

diffsol_c/
ode.rs

1use std::sync::{Arc, Mutex};
2
3use crate::jit::JitBackendType;
4use crate::{
5    error::DiffsolJsError, host_array::HostArray,
6    initial_condition_options::InitialConditionSolverOptions, linear_solver_type::LinearSolverType,
7    matrix_type::MatrixType, ode_options::OdeSolverOptions, ode_solver_type::OdeSolverType,
8    scalar_type::ScalarType, solution_wrapper::SolutionWrapper, solve::Solve,
9};
10
11pub struct Ode {
12    pub(crate) solve: Box<dyn Solve>,
13    code: String,
14    scalar_type: ScalarType,
15    jit_backend: Option<JitBackendType>,
16    linear_solver: LinearSolverType,
17    ode_solver: OdeSolverType,
18}
19
20unsafe impl Send for Ode {}
21unsafe impl Sync for Ode {}
22
23#[derive(Clone)]
24pub struct OdeWrapper(Arc<Mutex<Ode>>);
25
26impl OdeWrapper {
27    fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolJsError> {
28        self.0.lock().map_err(|_| {
29            DiffsolJsError::from(diffsol::error::DiffsolError::Other(
30                "Failed to acquire lock on ODE solver".to_string(),
31            ))
32        })
33    }
34}
35
36impl OdeWrapper {
37    fn build(
38        code: String,
39        scalar_type: ScalarType,
40        solve: Box<dyn Solve>,
41        jit_backend: Option<JitBackendType>,
42        linear_solver: LinearSolverType,
43        ode_solver: OdeSolverType,
44    ) -> Result<Self, DiffsolJsError> {
45        solve.check(linear_solver)?;
46        Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
47            code,
48            scalar_type,
49            solve,
50            jit_backend,
51            linear_solver,
52            ode_solver,
53        }))))
54    }
55
56    /// Construct an ODE solver backed by externally-provided DiffSL symbols.
57    #[cfg(feature = "external")]
58    pub fn new_external(
59        rhs_state_deps: Vec<(usize, usize)>,
60        rhs_input_deps: Vec<(usize, usize)>,
61        mass_state_deps: Vec<(usize, usize)>,
62        scalar_type: ScalarType,
63        matrix_type: MatrixType,
64        linear_solver: LinearSolverType,
65        ode_solver: OdeSolverType,
66    ) -> Result<Self, DiffsolJsError> {
67        let solve = crate::solve::solve_factory_external(
68            rhs_state_deps,
69            rhs_input_deps,
70            mass_state_deps,
71            matrix_type,
72            scalar_type,
73        )?;
74        Self::build(
75            String::new(),
76            scalar_type,
77            solve,
78            None,
79            linear_solver,
80            ode_solver,
81        )
82    }
83
84    /// Construct an ODE solver by JIT-compiling DiffSL code immediately.
85    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
86    pub fn new_jit(
87        code: &str,
88        jit_backend: JitBackendType,
89        scalar_type: ScalarType,
90        matrix_type: MatrixType,
91        linear_solver: LinearSolverType,
92        ode_solver: OdeSolverType,
93    ) -> Result<Self, DiffsolJsError> {
94        let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
95        Self::build(
96            code.to_owned(),
97            scalar_type,
98            solve,
99            Some(jit_backend),
100            linear_solver,
101            ode_solver,
102        )
103    }
104
105    /// Matrix type used in the ODE solver. This is fixed after construction.
106    pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolJsError> {
107        Ok(self.guard()?.solve.matrix_type())
108    }
109
110    /// Ode solver method, default Bdf (backward differentiation formula).
111    pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolJsError> {
112        Ok(self.guard()?.ode_solver)
113    }
114
115    pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolJsError> {
116        self.guard()?.ode_solver = value;
117        Ok(())
118    }
119
120    /// Linear solver type used in the ODE solver. Set to default to use the
121    /// solver's default choice, which is typically an LU solver.
122    pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolJsError> {
123        Ok(self.guard()?.linear_solver)
124    }
125
126    pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolJsError> {
127        self.guard()?.solve.check(value)?;
128        self.guard()?.linear_solver = value;
129        Ok(())
130    }
131
132    /// Relative tolerance for the solver, default 1e-6. Governs the error relative to the solution size.
133    pub fn get_rtol(&self) -> Result<f64, DiffsolJsError> {
134        Ok(self.guard()?.solve.rtol())
135    }
136
137    pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolJsError> {
138        self.guard()?.solve.set_rtol(value);
139        Ok(())
140    }
141
142    /// Absolute tolerance for the solver, default 1e-6. Governs the error as the solution goes to zero.
143    pub fn get_atol(&self) -> Result<f64, DiffsolJsError> {
144        Ok(self.guard()?.solve.atol())
145    }
146
147    pub fn set_atol(&self, value: f64) -> Result<(), DiffsolJsError> {
148        self.guard()?.solve.set_atol(value);
149        Ok(())
150    }
151
152    pub fn get_code(&self) -> Result<String, DiffsolJsError> {
153        Ok(self.guard()?.code.clone())
154    }
155
156    pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolJsError> {
157        Ok(self.guard()?.scalar_type)
158    }
159
160    pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolJsError> {
161        Ok(self.guard()?.jit_backend)
162    }
163
164    pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
165        InitialConditionSolverOptions::new(self.0.clone())
166    }
167
168    pub fn get_options(&self) -> OdeSolverOptions {
169        OdeSolverOptions::new(self.0.clone())
170    }
171
172    /// Get the initial condition vector y0 as a 1D numpy array.
173    pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolJsError> {
174        let mut self_guard = self.guard()?;
175        self_guard.solve.y0(params.as_slice()?)
176    }
177
178    /// evaluate the right-hand side function at time `t` and state `y`.
179    pub fn rhs(
180        &self,
181        params: HostArray,
182        t: f64,
183        y: HostArray,
184    ) -> Result<HostArray, DiffsolJsError> {
185        let mut self_guard = self.guard()?;
186        self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
187    }
188
189    /// evaluate the right-hand side Jacobian-vector product `Jv`` at time `t` and state `y`.
190    pub fn rhs_jac_mul(
191        &self,
192        params: HostArray,
193        t: f64,
194        y: HostArray,
195        v: HostArray,
196    ) -> Result<HostArray, DiffsolJsError> {
197        let mut self_guard = self.guard()?;
198        self_guard
199            .solve
200            .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
201    }
202
203    /// Using the provided state, solve the problem up to time `final_time`.
204    ///
205    /// The number of params must match the expected params in the diffsl code.
206    /// If specified, the config can be used to override the solver method
207    /// (Bdf by default) and SolverType (Lu by default) along with other solver
208    /// params like `rtol`.
209    ///
210    /// :param params: 1D array of solver parameters
211    /// :type params: numpy.ndarray
212    /// :param final_time: end time of solver
213    /// :type final_time: float
214    /// :return: `(ys, ts)` tuple where `ys` is a 2D array of values at times
215    ///     `ts` chosen by the solver
216    /// :rtype: Tuple[numpy.ndarray, numpy.ndarray]
217    ///
218    /// Example:
219    ///     >>> print(ode.solve(np.array([]), 0.5))
220    #[allow(clippy::type_complexity)]
221    pub fn solve(
222        &self,
223        params: HostArray,
224        final_time: f64,
225    ) -> Result<SolutionWrapper, DiffsolJsError> {
226        let mut self_guard = self.guard()?;
227        let params = params.as_slice()?;
228        let linear_solver = self_guard.linear_solver;
229        let method = self_guard.ode_solver;
230        let solution = self_guard
231            .solve
232            .solve(method, linear_solver, params, final_time)?;
233        Ok(SolutionWrapper::new(solution))
234    }
235
236    /// Solve a hybrid ODE up to `final_time`, automatically applying reset
237    /// functions and continuing after root events until the solution completes.
238    pub fn solve_hybrid(
239        &self,
240        params: HostArray,
241        final_time: f64,
242    ) -> Result<SolutionWrapper, DiffsolJsError> {
243        let mut self_guard = self.guard()?;
244        let params = params.as_slice()?;
245        let linear_solver = self_guard.linear_solver;
246        let method = self_guard.ode_solver;
247        let solution = self_guard
248            .solve
249            .solve_hybrid(method, linear_solver, params, final_time)?;
250        Ok(SolutionWrapper::new(solution))
251    }
252
253    /// Using the provided state, solve the problem up to time
254    /// `t_eval[t_eval.len()-1]`. Returns 2D array of solution values at
255    /// timepoints given by `t_eval`.
256    ///
257    /// The number of params must match the expected params in the diffsl code.
258    /// The config may be optionally specified to override solver settings.
259    ///
260    /// :param params: 1D array of solver parameters
261    /// :type params: numpy.ndarray
262    /// :param t_eval: 1D array of solver times
263    /// :type params: numpy.ndarray
264    /// :return: 2D array of values at times `t_eval`
265    /// :rtype: numpy.ndarray
266    pub fn solve_dense(
267        &self,
268        params: HostArray,
269        t_eval: HostArray,
270    ) -> Result<SolutionWrapper, DiffsolJsError> {
271        let mut self_guard = self.guard()?;
272        let params = params.as_slice()?;
273        let t_eval = t_eval.as_slice()?;
274        let linear_solver = self_guard.linear_solver;
275        let method = self_guard.ode_solver;
276        let solution = self_guard
277            .solve
278            .solve_dense(method, linear_solver, params, t_eval)?;
279        Ok(SolutionWrapper::new(solution))
280    }
281
282    /// Solve a hybrid ODE at dense evaluation times, automatically applying
283    /// reset functions and continuing after root events until all requested
284    /// output points are filled.
285    pub fn solve_hybrid_dense(
286        &self,
287        params: HostArray,
288        t_eval: HostArray,
289    ) -> Result<SolutionWrapper, DiffsolJsError> {
290        let mut self_guard = self.guard()?;
291        let params = params.as_slice()?;
292        let t_eval = t_eval.as_slice()?;
293        let linear_solver = self_guard.linear_solver;
294        let method = self_guard.ode_solver;
295        let solution =
296            self_guard
297                .solve
298                .solve_hybrid_dense(method, linear_solver, params, t_eval)?;
299        Ok(SolutionWrapper::new(solution))
300    }
301
302    /// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]`.
303    /// Returns 2D array of solution values at timepoints given by `t_eval`.
304    /// Also returns a list of 2D arrays of sensitivities at the same timepoints
305    /// as the solution.
306    /// The number of params must match the expected params in the diffsl code.
307    /// The config may be optionally specified to override solver settings.
308    /// :param params: 1D array of solver parameters
309    /// :type params: numpy.ndarray
310    /// :param t_eval: 1D array of solver times
311    /// :type params: numpy.ndarray
312    /// :return: 2D array of values at times `t_eval` and a list of 2D arrays of sensitivities at the same timepoints
313    /// :rtype: (numpy.ndarray, List[numpy.ndarray])
314    #[allow(clippy::type_complexity)]
315    pub fn solve_fwd_sens(
316        &self,
317        params: HostArray,
318        t_eval: HostArray,
319    ) -> Result<SolutionWrapper, DiffsolJsError> {
320        let mut self_guard = self.guard()?;
321        let params = params.as_slice()?;
322        let t_eval = t_eval.as_slice()?;
323        let linear_solver = self_guard.linear_solver;
324        let method = self_guard.ode_solver;
325        let solution = self_guard
326            .solve
327            .solve_fwd_sens(method, linear_solver, params, t_eval)?;
328        Ok(SolutionWrapper::new(solution))
329    }
330
331    /// Solve a hybrid ODE with forward sensitivities at dense evaluation times,
332    /// automatically applying sensitivity-aware reset functions and continuing
333    /// after root events until all requested output points are filled.
334    #[allow(clippy::type_complexity)]
335    pub fn solve_hybrid_fwd_sens(
336        &self,
337        params: HostArray,
338        t_eval: HostArray,
339    ) -> Result<SolutionWrapper, DiffsolJsError> {
340        let mut self_guard = self.guard()?;
341        let params = params.as_slice()?;
342        let t_eval = t_eval.as_slice()?;
343        let linear_solver = self_guard.linear_solver;
344        let method = self_guard.ode_solver;
345        let solution =
346            self_guard
347                .solve
348                .solve_hybrid_fwd_sens(method, linear_solver, params, t_eval)?;
349        Ok(SolutionWrapper::new(solution))
350    }
351
352    /// Using the provided state, solve the adjoint problem for the sum of squares
353    /// objective given data at timepoints `t_eval`.
354    /// Returns the objective value and a list of 1D arrays of adjoint sensitivities
355    /// for each parameter.
356    #[allow(clippy::type_complexity)]
357    pub(crate) fn solve_sum_squares_adj(
358        &self,
359        params: HostArray,
360        data: HostArray,
361        t_eval: HostArray,
362    ) -> Result<(f64, HostArray), DiffsolJsError> {
363        let mut self_guard = self.guard()?;
364        let linear_solver = self_guard.linear_solver;
365        let ode_solver = self_guard.ode_solver;
366
367        self_guard.solve.solve_sum_squares_adj(
368            ode_solver,
369            linear_solver,
370            ode_solver,
371            linear_solver,
372            params.as_slice()?,
373            data,
374            t_eval.as_slice()?,
375        )
376    }
377}
378
#[cfg(all(test, feature = "diffsl-external-f64"))]
mod tests {
    use crate::host_array::FromHostArray;
    use crate::linear_solver_type::LinearSolverType;
    use crate::scalar_type::ScalarType;
    use crate::test_support::{
        ASSERT_TOL, LOGISTIC_X0, assert_close, assert_solution_tail, logistic_integral,
        logistic_state, logistic_state_dr, mass_state_deps, rhs_input_deps, rhs_state_deps,
        vector_host,
    };

    use super::*;

    /// Builds a solver for the externally linked test model with the default
    /// linear solver.
    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        OdeWrapper::new_external(
            rhs_state_deps(),
            rhs_input_deps(),
            mass_state_deps(),
            ScalarType::F64,
            matrix_type,
            LinearSolverType::Default,
            ode_solver,
        )
        .unwrap()
    }

    /// Like [`make_ode`], but with `rtol` and `atol` tightened to `1e-8` so the
    /// numerical solution can be compared against analytic values.
    fn make_tight_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
        let ode = make_ode(matrix_type, ode_solver);
        ode.set_rtol(1e-8).unwrap();
        ode.set_atol(1e-8).unwrap();
        ode
    }

    /// Checks that the requested matrix type is reported back and that `y0`,
    /// `rhs` and `rhs_jac_mul` evaluate to the expected values.
    fn assert_runtime_dispatch(matrix_type: MatrixType) {
        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);

        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);

        let rhs = ode
            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs).unwrap()[0],
            0.375,
            ASSERT_TOL,
            "rhs(0.25)",
        );

        let rhs_jac_mul = ode
            .rhs_jac_mul(
                vector_host(&[2.0]),
                0.0,
                vector_host(&[0.25]),
                vector_host(&[3.0]),
            )
            .unwrap();
        assert_close(
            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
            3.0,
            ASSERT_TOL,
            "rhs_jac_mul(0.25, 3.0)",
        );
    }

    /// Dense-output solve at `[0.25, 0.5, 1.0]` must match the analytic
    /// logistic solution with `r = 2`.
    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
        let ode = make_tight_ode(matrix_type, ode_solver);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
    }

    /// Time of the root event expected by the hybrid tests below, `ln(9) / 2`.
    fn hybrid_root_time() -> f64 {
        0.5 * 9.0_f64.ln()
    }

    #[test]
    fn runtime_dispatch_matches_requested_matrix_type() {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            assert_runtime_dispatch(matrix_type);
        }
    }

    #[test]
    fn bdf_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Bdf);
    }

    #[test]
    fn esdirk34_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
    }

    #[test]
    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
    }

    #[test]
    fn tsit45_dense_solution_matches_logistic_solution() {
        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
    }

    #[test]
    fn bdf_forward_sensitivities_match_logistic_derivative() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.25, 0.5, 1.0];
        let solution = ode
            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();

        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
        let sens = solution.get_sens().unwrap();
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_values[(0, i)],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                ASSERT_TOL,
                &format!("sensitivity[{i}]"),
            );
        }
    }

    #[test]
    fn bdf_sum_squares_adjoint_matches_external_logistic_model() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        // Data generated from the exact model output, so both the objective
        // and its gradient should vanish.
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data_values: Vec<f64> = t_eval
            .iter()
            .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
            .collect();
        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
        let (value, sens) = ode
            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
            .unwrap();
        let grad = Vec::<f64>::from_host_array(sens).unwrap();

        assert_close(value, 0.0, ASSERT_TOL, "sum_squares objective");
        assert_eq!(grad.len(), 1);
        assert_close(grad[0], 0.0, ASSERT_TOL, "sum_squares gradient");
    }

    #[test]
    fn bdf_hybrid_solution_applies_reset_after_root() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let final_time = 2.0;
        let solution = ode.solve_hybrid(vector_host(&[2.0]), final_time).unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
        let root_time = hybrid_root_time();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), ts.len());
        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
        assert_close(
            *ts.last().unwrap(),
            final_time,
            ASSERT_TOL,
            "hybrid final time",
        );
        assert_close(ys[(0, ys.ncols() - 1)], 1.0, 5e-4, "hybrid final value");
        assert!(
            ts.iter().any(|&t| t < root_time),
            "expected pre-root samples"
        );
        assert!(
            ts.iter().any(|&t| t > root_time),
            "expected post-root samples after reset"
        );
    }

    #[test]
    fn bdf_hybrid_dense_solution_continues_after_reset() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_hybrid_dense(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            5e-4,
            "hybrid dense pre-root value",
        );
        assert_close(
            ys[(0, 1)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[1]),
            5e-4,
            "hybrid dense near-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid dense post-root value");
        }
    }

    #[test]
    fn bdf_hybrid_forward_sensitivities_complete_across_reset() {
        let ode = make_tight_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);

        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        let solution = ode
            .solve_hybrid_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
            .unwrap();
        let ys = solution.get_ys().unwrap();
        let ys = ys.as_array::<f64>().unwrap();
        let sens = solution.get_sens().unwrap();

        assert_eq!(ys.nrows(), 1);
        assert_eq!(ys.ncols(), t_eval.len());
        assert_eq!(sens.len(), 1);
        let sens_values = sens[0].as_array::<f64>().unwrap();
        assert_eq!(sens_values.nrows(), 1);
        assert_eq!(sens_values.ncols(), t_eval.len());
        assert_close(
            ys[(0, 0)],
            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
            5e-4,
            "hybrid sens pre-root value",
        );
        for col in 2..t_eval.len() {
            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid sens post-root value");
            assert!(
                sens_values[(0, col)].is_finite(),
                "expected finite post-root sensitivity at column {col}"
            );
        }
    }
}
644
645#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
646mod jit_tests {
647    use crate::host_array::FromHostArray;
648    use crate::jit::JitBackendType;
649    use crate::linear_solver_type::LinearSolverType;
650    use crate::scalar_type::ScalarType;
651    use crate::test_support::{
652        ASSERT_TOL, LOGISTIC_X0, assert_close, assert_solution_tail, available_jit_backends,
653        hybrid_logistic_diffsl_code, hybrid_logistic_period, hybrid_logistic_state,
654        hybrid_logistic_state_dr, logistic_diffsl_code, logistic_state, vector_host,
655    };
656    #[cfg(feature = "diffsl-llvm")]
657    use crate::test_support::{logistic_integral, logistic_state_dr};
658
659    use super::*;
660
661    fn make_ode(
662        jit_backend: JitBackendType,
663        matrix_type: MatrixType,
664        ode_solver: OdeSolverType,
665    ) -> OdeWrapper {
666        OdeWrapper::new_jit(
667            logistic_diffsl_code(),
668            jit_backend,
669            ScalarType::F64,
670            matrix_type,
671            LinearSolverType::Default,
672            ode_solver,
673        )
674        .unwrap()
675    }
676
677    fn make_hybrid_ode(
678        jit_backend: JitBackendType,
679        matrix_type: MatrixType,
680        ode_solver: OdeSolverType,
681    ) -> OdeWrapper {
682        OdeWrapper::new_jit(
683            hybrid_logistic_diffsl_code(),
684            jit_backend,
685            ScalarType::F64,
686            matrix_type,
687            LinearSolverType::Default,
688            ode_solver,
689        )
690        .unwrap()
691    }
692
693    fn assert_runtime_dispatch(jit_backend: JitBackendType, matrix_type: MatrixType) {
694        let ode = make_ode(jit_backend, matrix_type, OdeSolverType::Bdf);
695        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
696        assert_eq!(ode.get_code().unwrap(), logistic_diffsl_code());
697
698        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
699        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);
700
701        let rhs = ode
702            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
703            .unwrap();
704        assert_close(
705            Vec::<f64>::from_host_array(rhs).unwrap()[0],
706            0.375,
707            ASSERT_TOL,
708            "jit rhs(0.25)",
709        );
710
711        let rhs_jac_mul = ode
712            .rhs_jac_mul(
713                vector_host(&[2.0]),
714                0.0,
715                vector_host(&[0.25]),
716                vector_host(&[3.0]),
717            )
718            .unwrap();
719        assert_close(
720            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
721            3.0,
722            ASSERT_TOL,
723            "jit rhs_jac_mul(0.25, 3.0)",
724        );
725    }
726
727    fn assert_solver_dense_solution(
728        jit_backend: JitBackendType,
729        matrix_type: MatrixType,
730        ode_solver: OdeSolverType,
731    ) {
732        let ode = make_ode(jit_backend, matrix_type, ode_solver);
733        ode.set_rtol(1e-8).unwrap();
734        ode.set_atol(1e-8).unwrap();
735
736        let t_eval = [0.25, 0.5, 1.0];
737        let solution = ode
738            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
739            .unwrap();
740
741        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
742    }
743
744    fn hybrid_t_eval() -> [f64; 7] {
745        [0.5, 1.0, 2.0, 2.5, 3.0, 4.0, 4.8]
746    }
747
748    #[test]
749    fn runtime_dispatch_matches_requested_matrix_type_from_diffsl() {
750        for jit_backend in available_jit_backends() {
751            for matrix_type in [
752                MatrixType::NalgebraDense,
753                MatrixType::FaerDense,
754                MatrixType::FaerSparse,
755            ] {
756                assert_runtime_dispatch(jit_backend, matrix_type);
757            }
758        }
759    }
760
761    #[test]
762    fn dense_solution_matches_logistic_solution_from_diffsl() {
763        for jit_backend in available_jit_backends() {
764            for (matrix_type, solver) in [
765                (MatrixType::FaerDense, OdeSolverType::Esdirk34),
766                (MatrixType::FaerSparse, OdeSolverType::TrBdf2),
767                (MatrixType::NalgebraDense, OdeSolverType::Tsit45),
768            ] {
769                assert_solver_dense_solution(jit_backend, matrix_type, solver);
770            }
771        }
772    }
773
774    #[test]
775    fn bdf_dense_solution_matches_logistic_diffsl_model() {
776        for jit_backend in available_jit_backends() {
777            let ode = make_ode(jit_backend, MatrixType::NalgebraDense, OdeSolverType::Bdf);
778            ode.set_rtol(1e-8).unwrap();
779            ode.set_atol(1e-8).unwrap();
780
781            let t_eval = [0.25, 0.5, 1.0];
782            let solution = ode
783                .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
784                .unwrap();
785
786            assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
787        }
788    }
789
790    #[test]
791    fn bdf_solution_matches_logistic_diffsl_model() {
792        for jit_backend in available_jit_backends() {
793            let x0 = LOGISTIC_X0;
794            let r = 2.0;
795            let ode = make_ode(jit_backend, MatrixType::NalgebraDense, OdeSolverType::Bdf);
796            ode.set_rtol(1e-8).unwrap();
797            ode.set_atol(1e-8).unwrap();
798
799            let final_time = 1.0;
800            let solution = ode.solve(vector_host(&[r]), final_time).unwrap();
801
802            let ys = solution.get_ys().unwrap();
803            let ys = ys.as_array::<f64>().unwrap();
804            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
805
806            assert_eq!(ys.nrows(), 1);
807            assert_eq!(ys.ncols(), ts.len());
808            assert!(
809                !ts.is_empty(),
810                "expected solve() to record at least one time point"
811            );
812            assert_close(
813                *ts.last().unwrap(),
814                final_time,
815                ASSERT_TOL,
816                "solve final time",
817            );
818            for (i, &t) in ts.iter().enumerate() {
819                assert_close(
820                    ys[(0, i)],
821                    logistic_state(x0, r, t),
822                    5e-4,
823                    &format!("solve value[{i}]"),
824                );
825            }
826        }
827    }
828
829    #[test]
830    fn hybrid_solution_matches_piecewise_logistic_diffsl_model() {
831        let r = 2.0;
832        let final_time = 5.0;
833        let tau = hybrid_logistic_period(r);
834        for jit_backend in available_jit_backends() {
835            let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, OdeSolverType::Bdf);
836            ode.set_rtol(1e-8).unwrap();
837            ode.set_atol(1e-8).unwrap();
838
839            let solution = ode.solve_hybrid(vector_host(&[r]), final_time).unwrap();
840            let ys = solution.get_ys().unwrap();
841            let ys = ys.as_array::<f64>().unwrap();
842            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
843
844            assert_eq!(ys.nrows(), 1);
845            assert_eq!(ys.ncols(), ts.len());
846            assert!(!ts.is_empty(), "expected hybrid solve to produce output");
847            assert_close(
848                *ts.last().unwrap(),
849                final_time,
850                ASSERT_TOL,
851                "jit hybrid final time",
852            );
853            assert_close(
854                ys[(0, ys.ncols() - 1)],
855                hybrid_logistic_state(r, final_time),
856                5e-4,
857                "jit hybrid final value",
858            );
859            assert!(ts.iter().any(|&t| (t - tau).abs() < 1e-3));
860            assert!(ts.iter().any(|&t| (t - 2.0 * tau).abs() < 1e-3));
861            for (col, &t) in ts.iter().enumerate() {
862                if ((t / tau).round() * tau - t).abs() < 1e-3 {
863                    continue;
864                }
865                assert_close(
866                    ys[(0, col)],
867                    hybrid_logistic_state(r, t),
868                    5e-4,
869                    &format!("jit hybrid value[{col}]"),
870                );
871            }
872        }
873    }
874
875    #[test]
876    fn hybrid_dense_solution_matches_piecewise_logistic_diffsl_model() {
877        let r = 2.0;
878        let t_eval = hybrid_t_eval();
879        for jit_backend in available_jit_backends() {
880            let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, OdeSolverType::Bdf);
881            ode.set_rtol(1e-8).unwrap();
882            ode.set_atol(1e-8).unwrap();
883
884            let solution = ode
885                .solve_hybrid_dense(vector_host(&[r]), vector_host(&t_eval))
886                .unwrap();
887            let ys = solution.get_ys().unwrap();
888            let ys = ys.as_array::<f64>().unwrap();
889            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
890
891            assert_eq!(ys.nrows(), 1);
892            assert_eq!(ys.ncols(), t_eval.len());
893            assert_eq!(ts, t_eval);
894            for (col, &t) in t_eval.iter().enumerate() {
895                assert_close(
896                    ys[(0, col)],
897                    hybrid_logistic_state(r, t),
898                    5e-4,
899                    &format!("jit hybrid dense value[{col}]"),
900                );
901            }
902        }
903    }
904
905    #[cfg(feature = "diffsl-llvm")]
906    #[test]
907    fn bdf_forward_sensitivities_match_logistic_derivative_from_diffsl() {
908        let ode = make_ode(
909            JitBackendType::Llvm,
910            MatrixType::NalgebraDense,
911            OdeSolverType::Bdf,
912        );
913        ode.set_rtol(1e-8).unwrap();
914        ode.set_atol(1e-8).unwrap();
915
916        let t_eval = [0.25, 0.5, 1.0];
917        let solution = ode
918            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
919            .unwrap();
920
921        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
922        let sens = solution.get_sens().unwrap();
923        assert_eq!(sens.len(), 1);
924        let sens_values = sens[0].as_array::<f64>().unwrap();
925        assert_eq!(sens_values.nrows(), 1);
926        assert_eq!(sens_values.ncols(), t_eval.len());
927        for (i, &t) in t_eval.iter().enumerate() {
928            assert_close(
929                sens_values[(0, i)],
930                logistic_state_dr(LOGISTIC_X0, 2.0, t),
931                ASSERT_TOL,
932                &format!("jit sensitivity[{i}]"),
933            );
934        }
935    }
936
937    #[cfg(feature = "diffsl-llvm")]
938    #[test]
939    fn hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model() {
940        let r = 2.0;
941        let t_eval = hybrid_t_eval();
942        let ode = make_hybrid_ode(
943            JitBackendType::Llvm,
944            MatrixType::NalgebraDense,
945            OdeSolverType::Bdf,
946        );
947        ode.set_rtol(1e-8).unwrap();
948        ode.set_atol(1e-8).unwrap();
949
950        let solution = ode
951            .solve_hybrid_fwd_sens(vector_host(&[r]), vector_host(&t_eval))
952            .unwrap();
953        let ys = solution.get_ys().unwrap();
954        let ys = ys.as_array::<f64>().unwrap();
955        let sens = solution.get_sens().unwrap();
956
957        assert_eq!(ys.nrows(), 1);
958        assert_eq!(ys.ncols(), t_eval.len());
959        assert_eq!(sens.len(), 1);
960        let sens_values = sens[0].as_array::<f64>().unwrap();
961        assert_eq!(sens_values.nrows(), 1);
962        assert_eq!(sens_values.ncols(), t_eval.len());
963        for (col, &t) in t_eval.iter().enumerate() {
964            assert_close(
965                ys[(0, col)],
966                hybrid_logistic_state(r, t),
967                5e-4,
968                &format!("jit hybrid sens value[{col}]"),
969            );
970            assert_close(
971                sens_values[(0, col)],
972                hybrid_logistic_state_dr(r, t),
973                5e-4,
974                &format!("jit hybrid sensitivity[{col}]"),
975            );
976        }
977    }
978
979    #[cfg(feature = "diffsl-llvm")]
980    #[test]
981    fn bdf_sum_squares_adjoint_matches_logistic_diffsl_model() {
982        let ode = make_ode(
983            JitBackendType::Llvm,
984            MatrixType::NalgebraDense,
985            OdeSolverType::Bdf,
986        );
987        ode.set_rtol(1e-8).unwrap();
988        ode.set_atol(1e-8).unwrap();
989
990        let t_eval = [0.0, 0.25, 0.5, 1.0];
991        let data_values: Vec<f64> = t_eval
992            .iter()
993            .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
994            .collect();
995        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
996        let (value, sens) = ode
997            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
998            .unwrap();
999        let grad = Vec::<f64>::from_host_array(sens).unwrap();
1000
1001        assert_close(value, 0.0, ASSERT_TOL, "jit sum_squares objective");
1002        assert_eq!(grad.len(), 1);
1003        assert!(
1004            grad[0].is_finite(),
1005            "jit sum_squares gradient should be finite"
1006        );
1007    }
1008}