Skip to main content

diffsol_c/
ode.rs

1use std::sync::{Arc, Mutex};
2
3use crate::jit::JitBackendType;
4use crate::{
5    error::DiffsolRtError, host_array::HostArray,
6    initial_condition_options::InitialConditionSolverOptions, linear_solver_type::LinearSolverType,
7    matrix_type::MatrixType, ode_options::OdeSolverOptions, ode_solver_type::OdeSolverType,
8    scalar_type::ScalarType, solution_wrapper::SolutionWrapper, solve::Solve,
9};
10
/// State shared behind an `OdeWrapper`: the solver backend plus the settings it was built with.
pub struct Ode {
    /// Type-erased solver backend; every numerical operation in this file is forwarded to it.
    pub(crate) solve: Box<dyn Solve>,
    /// DiffSL source the solver was compiled from (empty when built via `new_external`).
    code: String,
    /// Scalar type passed to the solver factory at construction.
    scalar_type: ScalarType,
    /// JIT backend used to compile `code`; `None` when built via `new_external`.
    jit_backend: Option<JitBackendType>,
    /// Linear solver handed to the backend on every solve call.
    linear_solver: LinearSolverType,
    /// ODE solver method handed to the backend on every solve call.
    ode_solver: OdeSolverType,
}
19
// NOTE(review): these impls promise the compiler that `Ode` — including its `Box<dyn Solve>` —
// may be moved to and shared between threads. Nothing visible in this file proves that for
// every `Solve` implementor; the impls are what allow `Arc<Mutex<Ode>>` to cross threads.
// Confirm against the `Solve` implementations before relying on it.
unsafe impl Send for Ode {}
unsafe impl Sync for Ode {}
22
/// Cloneable handle to a mutex-protected [`Ode`].
///
/// Cloning copies the `Arc`, so all clones refer to the same underlying `Ode`.
#[derive(Clone)]
pub struct OdeWrapper(Arc<Mutex<Ode>>);
25
26impl OdeWrapper {
27    fn guard(&self) -> Result<std::sync::MutexGuard<'_, Ode>, DiffsolRtError> {
28        self.0.lock().map_err(|_| {
29            DiffsolRtError::from(diffsol::error::DiffsolError::Other(
30                "Failed to acquire lock on ODE solver".to_string(),
31            ))
32        })
33    }
34}
35
36impl OdeWrapper {
37    fn build(
38        code: String,
39        scalar_type: ScalarType,
40        solve: Box<dyn Solve>,
41        jit_backend: Option<JitBackendType>,
42        linear_solver: LinearSolverType,
43        ode_solver: OdeSolverType,
44    ) -> Result<Self, DiffsolRtError> {
45        solve.check(linear_solver)?;
46        Ok(OdeWrapper(Arc::new(Mutex::new(Ode {
47            code,
48            scalar_type,
49            solve,
50            jit_backend,
51            linear_solver,
52            ode_solver,
53        }))))
54    }
55
56    /// Construct an ODE solver backed by externally-provided DiffSL symbols.
57    #[cfg(feature = "external")]
58    pub fn new_external(
59        rhs_state_deps: Vec<(usize, usize)>,
60        rhs_input_deps: Vec<(usize, usize)>,
61        mass_state_deps: Vec<(usize, usize)>,
62        scalar_type: ScalarType,
63        matrix_type: MatrixType,
64        linear_solver: LinearSolverType,
65        ode_solver: OdeSolverType,
66    ) -> Result<Self, DiffsolRtError> {
67        let solve = crate::solve::solve_factory_external(
68            rhs_state_deps,
69            rhs_input_deps,
70            mass_state_deps,
71            matrix_type,
72            scalar_type,
73        )?;
74        Self::build(
75            String::new(),
76            scalar_type,
77            solve,
78            None,
79            linear_solver,
80            ode_solver,
81        )
82    }
83
84    /// Construct an ODE solver by JIT-compiling DiffSL code immediately.
85    #[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
86    pub fn new_jit(
87        code: &str,
88        jit_backend: JitBackendType,
89        scalar_type: ScalarType,
90        matrix_type: MatrixType,
91        linear_solver: LinearSolverType,
92        ode_solver: OdeSolverType,
93    ) -> Result<Self, DiffsolRtError> {
94        let solve = crate::solve::solve_factory_jit(code, jit_backend, matrix_type, scalar_type)?;
95        Self::build(
96            code.to_owned(),
97            scalar_type,
98            solve,
99            Some(jit_backend),
100            linear_solver,
101            ode_solver,
102        )
103    }
104
105    /// Matrix type used in the ODE solver. This is fixed after construction.
106    pub fn get_matrix_type(&self) -> Result<MatrixType, DiffsolRtError> {
107        Ok(self.guard()?.solve.matrix_type())
108    }
109
110    pub fn get_nstates(&self) -> Result<usize, DiffsolRtError> {
111        Ok(self.guard()?.solve.nstates())
112    }
113
114    pub fn get_nparams(&self) -> Result<usize, DiffsolRtError> {
115        Ok(self.guard()?.solve.nparams())
116    }
117
118    pub fn get_nout(&self) -> Result<usize, DiffsolRtError> {
119        Ok(self.guard()?.solve.nout())
120    }
121
122    pub fn has_stop(&self) -> Result<bool, DiffsolRtError> {
123        Ok(self.guard()?.solve.has_stop())
124    }
125
126    /// Ode solver method, default Bdf (backward differentiation formula).
127    pub fn get_ode_solver(&self) -> Result<OdeSolverType, DiffsolRtError> {
128        Ok(self.guard()?.ode_solver)
129    }
130
131    pub fn set_ode_solver(&self, value: OdeSolverType) -> Result<(), DiffsolRtError> {
132        self.guard()?.ode_solver = value;
133        Ok(())
134    }
135
136    /// Linear solver type used in the ODE solver. Set to default to use the
137    /// solver's default choice, which is typically an LU solver.
138    pub fn get_linear_solver(&self) -> Result<LinearSolverType, DiffsolRtError> {
139        Ok(self.guard()?.linear_solver)
140    }
141
142    pub fn set_linear_solver(&self, value: LinearSolverType) -> Result<(), DiffsolRtError> {
143        self.guard()?.solve.check(value)?;
144        self.guard()?.linear_solver = value;
145        Ok(())
146    }
147
148    /// Relative tolerance for the solver, default 1e-6. Governs the error relative to the solution size.
149    pub fn get_rtol(&self) -> Result<f64, DiffsolRtError> {
150        Ok(self.guard()?.solve.rtol())
151    }
152
153    pub fn set_rtol(&self, value: f64) -> Result<(), DiffsolRtError> {
154        self.guard()?.solve.set_rtol(value);
155        Ok(())
156    }
157
158    /// Absolute tolerance for the solver, default 1e-6. Governs the error as the solution goes to zero.
159    pub fn get_atol(&self) -> Result<f64, DiffsolRtError> {
160        Ok(self.guard()?.solve.atol())
161    }
162
163    pub fn set_atol(&self, value: f64) -> Result<(), DiffsolRtError> {
164        self.guard()?.solve.set_atol(value);
165        Ok(())
166    }
167
168    pub fn get_code(&self) -> Result<String, DiffsolRtError> {
169        Ok(self.guard()?.code.clone())
170    }
171
172    pub fn get_scalar_type(&self) -> Result<ScalarType, DiffsolRtError> {
173        Ok(self.guard()?.scalar_type)
174    }
175
176    pub fn get_jit_backend(&self) -> Result<Option<JitBackendType>, DiffsolRtError> {
177        Ok(self.guard()?.jit_backend)
178    }
179
180    pub fn get_ic_options(&self) -> InitialConditionSolverOptions {
181        InitialConditionSolverOptions::new(self.0.clone())
182    }
183
184    pub fn get_options(&self) -> OdeSolverOptions {
185        OdeSolverOptions::new(self.0.clone())
186    }
187
188    /// Get the initial condition vector y0 as a 1D numpy array.
189    pub fn y0(&self, params: HostArray) -> Result<HostArray, DiffsolRtError> {
190        let mut self_guard = self.guard()?;
191        self_guard.solve.y0(params.as_slice()?)
192    }
193
194    /// evaluate the right-hand side function at time `t` and state `y`.
195    pub fn rhs(
196        &self,
197        params: HostArray,
198        t: f64,
199        y: HostArray,
200    ) -> Result<HostArray, DiffsolRtError> {
201        let mut self_guard = self.guard()?;
202        self_guard.solve.rhs(params.as_slice()?, t, y.as_slice()?)
203    }
204
205    /// evaluate the right-hand side Jacobian-vector product `Jv`` at time `t` and state `y`.
206    pub fn rhs_jac_mul(
207        &self,
208        params: HostArray,
209        t: f64,
210        y: HostArray,
211        v: HostArray,
212    ) -> Result<HostArray, DiffsolRtError> {
213        let mut self_guard = self.guard()?;
214        self_guard
215            .solve
216            .rhs_jac_mul(params.as_slice()?, t, y.as_slice()?, v.as_slice()?)
217    }
218
219    /// Using the provided state, solve the problem up to time `final_time`.
220    ///
221    /// The number of params must match the expected params in the diffsl code.
222    /// If specified, the config can be used to override the solver method
223    /// (Bdf by default) and SolverType (Lu by default) along with other solver
224    /// params like `rtol`.
225    ///
226    /// :param params: 1D array of solver parameters
227    /// :type params: numpy.ndarray
228    /// :param final_time: end time of solver
229    /// :type final_time: float
230    /// :return: `(ys, ts)` tuple where `ys` is a 2D array of values at times
231    ///     `ts` chosen by the solver
232    /// :rtype: Tuple[numpy.ndarray, numpy.ndarray]
233    ///
234    /// Example:
235    ///     >>> print(ode.solve(np.array([]), 0.5))
236    #[allow(clippy::type_complexity)]
237    pub fn solve(
238        &self,
239        params: HostArray,
240        final_time: f64,
241    ) -> Result<SolutionWrapper, DiffsolRtError> {
242        let mut self_guard = self.guard()?;
243        let params = params.as_slice()?;
244        let linear_solver = self_guard.linear_solver;
245        let method = self_guard.ode_solver;
246        let solution = self_guard
247            .solve
248            .solve(method, linear_solver, params, final_time)?;
249        Ok(SolutionWrapper::new(solution))
250    }
251
252    /// Solve a hybrid ODE up to `final_time`, automatically applying reset
253    /// functions and continuing after root events until the solution completes.
254    pub fn solve_hybrid(
255        &self,
256        params: HostArray,
257        final_time: f64,
258    ) -> Result<SolutionWrapper, DiffsolRtError> {
259        let mut self_guard = self.guard()?;
260        let params = params.as_slice()?;
261        let linear_solver = self_guard.linear_solver;
262        let method = self_guard.ode_solver;
263        let solution = self_guard
264            .solve
265            .solve_hybrid(method, linear_solver, params, final_time)?;
266        Ok(SolutionWrapper::new(solution))
267    }
268
269    /// Using the provided state, solve the problem up to time
270    /// `t_eval[t_eval.len()-1]`. Returns 2D array of solution values at
271    /// timepoints given by `t_eval`.
272    ///
273    /// The number of params must match the expected params in the diffsl code.
274    /// The config may be optionally specified to override solver settings.
275    ///
276    /// :param params: 1D array of solver parameters
277    /// :type params: numpy.ndarray
278    /// :param t_eval: 1D array of solver times
279    /// :type params: numpy.ndarray
280    /// :return: 2D array of values at times `t_eval`
281    /// :rtype: numpy.ndarray
282    pub fn solve_dense(
283        &self,
284        params: HostArray,
285        t_eval: HostArray,
286    ) -> Result<SolutionWrapper, DiffsolRtError> {
287        let mut self_guard = self.guard()?;
288        let params = params.as_slice()?;
289        let t_eval = t_eval.as_slice()?;
290        let linear_solver = self_guard.linear_solver;
291        let method = self_guard.ode_solver;
292        let solution = self_guard
293            .solve
294            .solve_dense(method, linear_solver, params, t_eval)?;
295        Ok(SolutionWrapper::new(solution))
296    }
297
298    /// Solve a hybrid ODE at dense evaluation times, automatically applying
299    /// reset functions and continuing after root events until all requested
300    /// output points are filled.
301    pub fn solve_hybrid_dense(
302        &self,
303        params: HostArray,
304        t_eval: HostArray,
305    ) -> Result<SolutionWrapper, DiffsolRtError> {
306        let mut self_guard = self.guard()?;
307        let params = params.as_slice()?;
308        let t_eval = t_eval.as_slice()?;
309        let linear_solver = self_guard.linear_solver;
310        let method = self_guard.ode_solver;
311        let solution =
312            self_guard
313                .solve
314                .solve_hybrid_dense(method, linear_solver, params, t_eval)?;
315        Ok(SolutionWrapper::new(solution))
316    }
317
318    /// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]`.
319    /// Returns 2D array of solution values at timepoints given by `t_eval`.
320    /// Also returns a list of 2D arrays of sensitivities at the same timepoints
321    /// as the solution.
322    /// The number of params must match the expected params in the diffsl code.
323    /// The config may be optionally specified to override solver settings.
324    /// :param params: 1D array of solver parameters
325    /// :type params: numpy.ndarray
326    /// :param t_eval: 1D array of solver times
327    /// :type params: numpy.ndarray
328    /// :return: 2D array of values at times `t_eval` and a list of 2D arrays of sensitivities at the same timepoints
329    /// :rtype: (numpy.ndarray, List[numpy.ndarray])
330    #[allow(clippy::type_complexity)]
331    pub fn solve_fwd_sens(
332        &self,
333        params: HostArray,
334        t_eval: HostArray,
335    ) -> Result<SolutionWrapper, DiffsolRtError> {
336        let mut self_guard = self.guard()?;
337        let params = params.as_slice()?;
338        let t_eval = t_eval.as_slice()?;
339        let linear_solver = self_guard.linear_solver;
340        let method = self_guard.ode_solver;
341        let solution = self_guard
342            .solve
343            .solve_fwd_sens(method, linear_solver, params, t_eval)?;
344        Ok(SolutionWrapper::new(solution))
345    }
346
347    /// Solve a hybrid ODE with forward sensitivities at dense evaluation times,
348    /// automatically applying sensitivity-aware reset functions and continuing
349    /// after root events until all requested output points are filled.
350    #[allow(clippy::type_complexity)]
351    pub fn solve_hybrid_fwd_sens(
352        &self,
353        params: HostArray,
354        t_eval: HostArray,
355    ) -> Result<SolutionWrapper, DiffsolRtError> {
356        let mut self_guard = self.guard()?;
357        let params = params.as_slice()?;
358        let t_eval = t_eval.as_slice()?;
359        let linear_solver = self_guard.linear_solver;
360        let method = self_guard.ode_solver;
361        let solution =
362            self_guard
363                .solve
364                .solve_hybrid_fwd_sens(method, linear_solver, params, t_eval)?;
365        Ok(SolutionWrapper::new(solution))
366    }
367
368    /// Using the provided state, solve the adjoint problem for the sum of squares
369    /// objective given data at timepoints `t_eval`.
370    /// Returns the objective value and a list of 1D arrays of adjoint sensitivities
371    /// for each parameter.
372    #[allow(clippy::type_complexity)]
373    pub fn solve_sum_squares_adj(
374        &self,
375        params: HostArray,
376        data: HostArray,
377        t_eval: HostArray,
378    ) -> Result<(f64, HostArray), DiffsolRtError> {
379        let mut self_guard = self.guard()?;
380        let linear_solver = self_guard.linear_solver;
381        let ode_solver = self_guard.ode_solver;
382
383        self_guard.solve.solve_sum_squares_adj(
384            ode_solver,
385            linear_solver,
386            ode_solver,
387            linear_solver,
388            params.as_slice()?,
389            data,
390            t_eval.as_slice()?,
391        )
392    }
393}
394
395#[cfg(all(test, feature = "diffsl-external-f64"))]
396mod tests {
397    use crate::host_array::FromHostArray;
398    use crate::linear_solver_type::LinearSolverType;
399    use crate::scalar_type::ScalarType;
400    use crate::test_support::{
401        assert_close, assert_solution_tail, logistic_integral, logistic_state, logistic_state_dr,
402        mass_state_deps, rhs_input_deps, rhs_state_deps, vector_host, ASSERT_TOL, LOGISTIC_X0,
403    };
404
405    use super::*;
406
407    fn all_ode_solvers() -> [OdeSolverType; 4] {
408        [
409            OdeSolverType::Bdf,
410            OdeSolverType::Esdirk34,
411            OdeSolverType::TrBdf2,
412            OdeSolverType::Tsit45,
413        ]
414    }
415
416    fn make_ode(matrix_type: MatrixType, ode_solver: OdeSolverType) -> OdeWrapper {
417        OdeWrapper::new_external(
418            rhs_state_deps(),
419            rhs_input_deps(),
420            mass_state_deps(),
421            ScalarType::F64,
422            matrix_type,
423            LinearSolverType::Default,
424            ode_solver,
425        )
426        .unwrap()
427    }
428
429    fn assert_runtime_dispatch(matrix_type: MatrixType) {
430        let ode = make_ode(matrix_type, OdeSolverType::Bdf);
431        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
432        assert_eq!(ode.get_nstates().unwrap(), 1);
433        assert_eq!(ode.get_nparams().unwrap(), 1);
434        assert_eq!(ode.get_nout().unwrap(), 1);
435        assert!(ode.has_stop().unwrap());
436
437        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
438        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);
439
440        let rhs = ode
441            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
442            .unwrap();
443        assert_close(
444            Vec::<f64>::from_host_array(rhs).unwrap()[0],
445            0.375,
446            ASSERT_TOL,
447            "rhs(0.25)",
448        );
449
450        let rhs_jac_mul = ode
451            .rhs_jac_mul(
452                vector_host(&[2.0]),
453                0.0,
454                vector_host(&[0.25]),
455                vector_host(&[3.0]),
456            )
457            .unwrap();
458        assert_close(
459            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
460            3.0,
461            ASSERT_TOL,
462            "rhs_jac_mul(0.25, 3.0)",
463        );
464    }
465
466    fn assert_solver_dense_solution(matrix_type: MatrixType, ode_solver: OdeSolverType) {
467        let ode = make_ode(matrix_type, ode_solver);
468        ode.set_rtol(1e-8).unwrap();
469        ode.set_atol(1e-8).unwrap();
470
471        let t_eval = [0.25, 0.5, 1.0];
472        let solution = ode
473            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
474            .unwrap();
475
476        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
477    }
478
479    fn hybrid_root_time() -> f64 {
480        0.5 * 9.0_f64.ln()
481    }
482
483    fn assert_hybrid_solution_applies_reset_after_root(ode_solver: OdeSolverType) {
484        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
485        ode.set_rtol(1e-8).unwrap();
486        ode.set_atol(1e-8).unwrap();
487
488        let final_time = 2.0;
489        let solution = ode.solve_hybrid(vector_host(&[2.0]), final_time).unwrap();
490        let ys = solution.get_ys().unwrap();
491        let ys = ys.as_array::<f64>().unwrap();
492        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
493        let root_time = hybrid_root_time();
494
495        assert_eq!(ys.nrows(), 1);
496        assert_eq!(ys.ncols(), ts.len());
497        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
498        assert_close(
499            *ts.last().unwrap(),
500            final_time,
501            ASSERT_TOL,
502            "hybrid final time",
503        );
504        assert_close(ys[(0, ys.ncols() - 1)], 1.0, 5e-4, "hybrid final value");
505        assert!(
506            ts.iter().any(|&t| t < root_time),
507            "expected pre-root samples"
508        );
509        assert!(
510            ts.iter().any(|&t| t > root_time),
511            "expected post-root samples after reset"
512        );
513    }
514
515    fn assert_hybrid_dense_solution_continues_after_reset(ode_solver: OdeSolverType) {
516        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
517        ode.set_rtol(1e-8).unwrap();
518        ode.set_atol(1e-8).unwrap();
519
520        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
521        let solution = ode
522            .solve_hybrid_dense(vector_host(&[2.0]), vector_host(&t_eval))
523            .unwrap();
524        let ys = solution.get_ys().unwrap();
525        let ys = ys.as_array::<f64>().unwrap();
526
527        assert_eq!(ys.nrows(), 1);
528        assert_eq!(ys.ncols(), t_eval.len());
529        assert_close(
530            ys[(0, 0)],
531            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
532            5e-4,
533            "hybrid dense pre-root value",
534        );
535        assert_close(
536            ys[(0, 1)],
537            logistic_state(LOGISTIC_X0, 2.0, t_eval[1]),
538            5e-4,
539            "hybrid dense near-root value",
540        );
541        for col in 2..t_eval.len() {
542            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid dense post-root value");
543        }
544    }
545
546    fn assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver: OdeSolverType) {
547        let ode = make_ode(MatrixType::NalgebraDense, ode_solver);
548        ode.set_rtol(1e-8).unwrap();
549        ode.set_atol(1e-8).unwrap();
550
551        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
552        let solution = ode
553            .solve_hybrid_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
554            .unwrap();
555        let ys = solution.get_ys().unwrap();
556        let ys = ys.as_array::<f64>().unwrap();
557        let sens = solution.get_sens().unwrap();
558
559        assert_eq!(ys.nrows(), 1);
560        assert_eq!(ys.ncols(), t_eval.len());
561        assert_eq!(sens.len(), 1);
562        let sens_values = sens[0].as_array::<f64>().unwrap();
563        assert_eq!(sens_values.nrows(), 1);
564        assert_eq!(sens_values.ncols(), t_eval.len());
565        assert_close(
566            ys[(0, 0)],
567            logistic_state(LOGISTIC_X0, 2.0, t_eval[0]),
568            5e-4,
569            "hybrid sens pre-root value",
570        );
571        for col in 2..t_eval.len() {
572            assert_close(ys[(0, col)], 1.0, 5e-4, "hybrid sens post-root value");
573            assert!(
574                sens_values[(0, col)].is_finite(),
575                "expected finite post-root sensitivity at column {col}"
576            );
577        }
578    }
579
580    #[test]
581    fn runtime_dispatch_matches_requested_matrix_type() {
582        for matrix_type in [
583            MatrixType::NalgebraDense,
584            MatrixType::FaerDense,
585            MatrixType::FaerSparse,
586        ] {
587            assert_runtime_dispatch(matrix_type);
588        }
589    }
590
591    #[test]
592    fn bdf_dense_solution_matches_logistic_solution() {
593        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
594        ode.set_rtol(1e-8).unwrap();
595        ode.set_atol(1e-8).unwrap();
596
597        let t_eval = [0.25, 0.5, 1.0];
598        let solution = ode
599            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
600            .unwrap();
601
602        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
603    }
604
605    #[test]
606    fn esdirk34_dense_solution_matches_logistic_solution() {
607        assert_solver_dense_solution(MatrixType::FaerDense, OdeSolverType::Esdirk34);
608    }
609
610    #[test]
611    fn tr_bdf2_sparse_solution_matches_logistic_solution() {
612        assert_solver_dense_solution(MatrixType::FaerSparse, OdeSolverType::TrBdf2);
613    }
614
615    #[test]
616    fn tsit45_dense_solution_matches_logistic_solution() {
617        assert_solver_dense_solution(MatrixType::NalgebraDense, OdeSolverType::Tsit45);
618    }
619
620    #[test]
621    fn bdf_forward_sensitivities_match_logistic_derivative() {
622        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
623        ode.set_rtol(1e-8).unwrap();
624        ode.set_atol(1e-8).unwrap();
625
626        let t_eval = [0.25, 0.5, 1.0];
627        let solution = ode
628            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
629            .unwrap();
630
631        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
632        let sens = solution.get_sens().unwrap();
633        assert_eq!(sens.len(), 1);
634        let sens_values = sens[0].as_array::<f64>().unwrap();
635        assert_eq!(sens_values.nrows(), 1);
636        assert_eq!(sens_values.ncols(), t_eval.len());
637        for (i, &t) in t_eval.iter().enumerate() {
638            assert_close(
639                sens_values[(0, i)],
640                logistic_state_dr(LOGISTIC_X0, 2.0, t),
641                ASSERT_TOL,
642                &format!("sensitivity[{i}]"),
643            );
644        }
645    }
646
647    #[test]
648    fn bdf_sum_squares_adjoint_matches_external_logistic_model() {
649        let ode = make_ode(MatrixType::NalgebraDense, OdeSolverType::Bdf);
650        ode.set_rtol(1e-8).unwrap();
651        ode.set_atol(1e-8).unwrap();
652
653        let t_eval = [0.0, 0.25, 0.5, 1.0];
654        let data_values: Vec<f64> = t_eval
655            .iter()
656            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
657            .collect();
658        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
659        let (value, sens) = ode
660            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
661            .unwrap();
662        let grad = Vec::<f64>::from_host_array(sens).unwrap();
663
664        assert_close(value, 0.0, ASSERT_TOL, "sum_squares objective");
665        assert_eq!(grad.len(), 1);
666        assert_close(grad[0], 0.0, ASSERT_TOL, "sum_squares gradient");
667    }
668
669    #[test]
670    fn hybrid_solution_applies_reset_after_root_for_all_solvers() {
671        for ode_solver in all_ode_solvers() {
672            assert_hybrid_solution_applies_reset_after_root(ode_solver);
673        }
674    }
675
676    #[test]
677    fn hybrid_dense_solution_continues_after_reset_for_all_solvers() {
678        for ode_solver in all_ode_solvers() {
679            assert_hybrid_dense_solution_continues_after_reset(ode_solver);
680        }
681    }
682
683    #[test]
684    fn hybrid_forward_sensitivities_complete_across_reset_for_all_solvers() {
685        for ode_solver in all_ode_solvers() {
686            assert_hybrid_forward_sensitivities_complete_across_reset(ode_solver);
687        }
688    }
689}
690
691#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
692mod jit_tests {
693    use crate::host_array::FromHostArray;
694    use crate::jit::JitBackendType;
695    use crate::linear_solver_type::LinearSolverType;
696    use crate::scalar_type::ScalarType;
697    use crate::test_support::{
698        assert_close, assert_solution_tail, available_jit_backends, hybrid_logistic_diffsl_code,
699        hybrid_logistic_period, hybrid_logistic_state, logistic_diffsl_code, logistic_state,
700        vector_host, ASSERT_TOL, LOGISTIC_X0,
701    };
702    #[cfg(feature = "diffsl-llvm")]
703    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
704
705    use super::*;
706
707    fn all_ode_solvers() -> [OdeSolverType; 4] {
708        [
709            OdeSolverType::Bdf,
710            OdeSolverType::Esdirk34,
711            OdeSolverType::TrBdf2,
712            OdeSolverType::Tsit45,
713        ]
714    }
715
716    fn make_ode(
717        jit_backend: JitBackendType,
718        matrix_type: MatrixType,
719        ode_solver: OdeSolverType,
720    ) -> OdeWrapper {
721        OdeWrapper::new_jit(
722            logistic_diffsl_code(),
723            jit_backend,
724            ScalarType::F64,
725            matrix_type,
726            LinearSolverType::Default,
727            ode_solver,
728        )
729        .unwrap()
730    }
731
732    fn make_hybrid_ode(
733        jit_backend: JitBackendType,
734        matrix_type: MatrixType,
735        ode_solver: OdeSolverType,
736    ) -> OdeWrapper {
737        OdeWrapper::new_jit(
738            hybrid_logistic_diffsl_code(),
739            jit_backend,
740            ScalarType::F64,
741            matrix_type,
742            LinearSolverType::Default,
743            ode_solver,
744        )
745        .unwrap()
746    }
747
748    fn assert_runtime_dispatch(jit_backend: JitBackendType, matrix_type: MatrixType) {
749        let ode = make_ode(jit_backend, matrix_type, OdeSolverType::Bdf);
750        assert_eq!(ode.get_matrix_type().unwrap(), matrix_type);
751        assert_eq!(ode.get_code().unwrap(), logistic_diffsl_code());
752        assert_eq!(ode.get_nstates().unwrap(), 1);
753        assert_eq!(ode.get_nparams().unwrap(), 1);
754        assert_eq!(ode.get_nout().unwrap(), 1);
755        assert!(!ode.has_stop().unwrap());
756
757        let y0 = ode.y0(vector_host(&[2.0])).unwrap();
758        assert_eq!(Vec::<f64>::from_host_array(y0).unwrap(), vec![LOGISTIC_X0]);
759
760        let rhs = ode
761            .rhs(vector_host(&[2.0]), 0.0, vector_host(&[0.25]))
762            .unwrap();
763        assert_close(
764            Vec::<f64>::from_host_array(rhs).unwrap()[0],
765            0.375,
766            ASSERT_TOL,
767            "jit rhs(0.25)",
768        );
769
770        let rhs_jac_mul = ode
771            .rhs_jac_mul(
772                vector_host(&[2.0]),
773                0.0,
774                vector_host(&[0.25]),
775                vector_host(&[3.0]),
776            )
777            .unwrap();
778        assert_close(
779            Vec::<f64>::from_host_array(rhs_jac_mul).unwrap()[0],
780            3.0,
781            ASSERT_TOL,
782            "jit rhs_jac_mul(0.25, 3.0)",
783        );
784    }
785
786    fn assert_solver_dense_solution(
787        jit_backend: JitBackendType,
788        matrix_type: MatrixType,
789        ode_solver: OdeSolverType,
790    ) {
791        let ode = make_ode(jit_backend, matrix_type, ode_solver);
792        ode.set_rtol(1e-8).unwrap();
793        ode.set_atol(1e-8).unwrap();
794
795        let t_eval = [0.25, 0.5, 1.0];
796        let solution = ode
797            .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
798            .unwrap();
799
800        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
801    }
802
803    fn hybrid_t_eval() -> [f64; 7] {
804        [0.5, 1.0, 2.0, 2.5, 3.0, 4.0, 4.8]
805    }
806
807    fn assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
808        jit_backend: JitBackendType,
809        ode_solver: OdeSolverType,
810    ) {
811        let r = 2.0;
812        let final_time = 5.0;
813        let tau = hybrid_logistic_period(r);
814        let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
815        ode.set_rtol(1e-8).unwrap();
816        ode.set_atol(1e-8).unwrap();
817        assert_eq!(ode.get_nstates().unwrap(), 1);
818        assert_eq!(ode.get_nparams().unwrap(), 1);
819        assert_eq!(ode.get_nout().unwrap(), 1);
820        assert!(ode.has_stop().unwrap());
821
822        let solution = ode.solve_hybrid(vector_host(&[r]), final_time).unwrap();
823        let ys = solution.get_ys().unwrap();
824        let ys = ys.as_array::<f64>().unwrap();
825        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
826
827        assert_eq!(ys.nrows(), 1);
828        assert_eq!(ys.ncols(), ts.len());
829        assert!(!ts.is_empty(), "expected hybrid solve to produce output");
830        assert_close(
831            *ts.last().unwrap(),
832            final_time,
833            ASSERT_TOL,
834            "jit hybrid final time",
835        );
836        assert_close(
837            ys[(0, ys.ncols() - 1)],
838            hybrid_logistic_state(r, final_time),
839            5e-4,
840            "jit hybrid final value",
841        );
842        assert!(ts.iter().any(|&t| (t - tau).abs() < 1e-3));
843        assert!(ts.iter().any(|&t| (t - 2.0 * tau).abs() < 1e-3));
844        for (col, &t) in ts.iter().enumerate() {
845            if ((t / tau).round() * tau - t).abs() < 1e-3 {
846                continue;
847            }
848            assert_close(
849                ys[(0, col)],
850                hybrid_logistic_state(r, t),
851                5e-4,
852                &format!("jit hybrid value[{col}]"),
853            );
854        }
855    }
856
857    fn assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
858        jit_backend: JitBackendType,
859        ode_solver: OdeSolverType,
860    ) {
861        let r = 2.0;
862        let t_eval = hybrid_t_eval();
863        let ode = make_hybrid_ode(jit_backend, MatrixType::NalgebraDense, ode_solver);
864        ode.set_rtol(1e-8).unwrap();
865        ode.set_atol(1e-8).unwrap();
866
867        let solution = ode
868            .solve_hybrid_dense(vector_host(&[r]), vector_host(&t_eval))
869            .unwrap();
870        let ys = solution.get_ys().unwrap();
871        let ys = ys.as_array::<f64>().unwrap();
872        let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
873
874        assert_eq!(ys.nrows(), 1);
875        assert_eq!(ys.ncols(), t_eval.len());
876        assert_eq!(ts, t_eval);
877        for (col, &t) in t_eval.iter().enumerate() {
878            assert_close(
879                ys[(0, col)],
880                hybrid_logistic_state(r, t),
881                5e-4,
882                &format!("jit hybrid dense value[{col}]"),
883            );
884        }
885    }
886
887    #[cfg(feature = "diffsl-llvm")]
888    fn assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(
889        ode_solver: OdeSolverType,
890    ) {
891        let r = 2.0;
892        let t_eval = hybrid_t_eval();
893        let ode = make_hybrid_ode(JitBackendType::Llvm, MatrixType::NalgebraDense, ode_solver);
894        ode.set_rtol(1e-8).unwrap();
895        ode.set_atol(1e-8).unwrap();
896
897        let solution = ode
898            .solve_hybrid_fwd_sens(vector_host(&[r]), vector_host(&t_eval))
899            .unwrap();
900        let ys = solution.get_ys().unwrap();
901        let ys = ys.as_array::<f64>().unwrap();
902        let sens = solution.get_sens().unwrap();
903
904        assert_eq!(ys.nrows(), 1);
905        assert_eq!(ys.ncols(), t_eval.len());
906        assert_eq!(sens.len(), 1);
907        let sens_values = sens[0].as_array::<f64>().unwrap();
908        assert_eq!(sens_values.nrows(), 1);
909        assert_eq!(sens_values.ncols(), t_eval.len());
910        for (col, &t) in t_eval.iter().enumerate() {
911            assert_close(
912                ys[(0, col)],
913                hybrid_logistic_state(r, t),
914                5e-4,
915                &format!("jit hybrid sens value[{col}]"),
916            );
917            assert_close(
918                sens_values[(0, col)],
919                hybrid_logistic_state_dr(r, t),
920                5e-4,
921                &format!("jit hybrid sensitivity[{col}]"),
922            );
923        }
924    }
925
926    #[test]
927    fn runtime_dispatch_matches_requested_matrix_type_from_diffsl() {
928        for jit_backend in available_jit_backends() {
929            for matrix_type in [
930                MatrixType::NalgebraDense,
931                MatrixType::FaerDense,
932                MatrixType::FaerSparse,
933            ] {
934                assert_runtime_dispatch(jit_backend, matrix_type);
935            }
936        }
937    }
938
939    #[test]
940    fn dense_solution_matches_logistic_solution_from_diffsl() {
941        for jit_backend in available_jit_backends() {
942            for (matrix_type, solver) in [
943                (MatrixType::FaerDense, OdeSolverType::Esdirk34),
944                (MatrixType::FaerSparse, OdeSolverType::TrBdf2),
945                (MatrixType::NalgebraDense, OdeSolverType::Tsit45),
946            ] {
947                assert_solver_dense_solution(jit_backend, matrix_type, solver);
948            }
949        }
950    }
951
952    #[test]
953    fn bdf_dense_solution_matches_logistic_diffsl_model() {
954        for jit_backend in available_jit_backends() {
955            let ode = make_ode(jit_backend, MatrixType::NalgebraDense, OdeSolverType::Bdf);
956            ode.set_rtol(1e-8).unwrap();
957            ode.set_atol(1e-8).unwrap();
958
959            let t_eval = [0.25, 0.5, 1.0];
960            let solution = ode
961                .solve_dense(vector_host(&[2.0]), vector_host(&t_eval))
962                .unwrap();
963
964            assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
965        }
966    }
967
968    #[test]
969    fn bdf_solution_matches_logistic_diffsl_model() {
970        for jit_backend in available_jit_backends() {
971            let x0 = LOGISTIC_X0;
972            let r = 2.0;
973            let ode = make_ode(jit_backend, MatrixType::NalgebraDense, OdeSolverType::Bdf);
974            ode.set_rtol(1e-8).unwrap();
975            ode.set_atol(1e-8).unwrap();
976
977            let final_time = 1.0;
978            let solution = ode.solve(vector_host(&[r]), final_time).unwrap();
979
980            let ys = solution.get_ys().unwrap();
981            let ys = ys.as_array::<f64>().unwrap();
982            let ts = Vec::<f64>::from_host_array(solution.get_ts().unwrap()).unwrap();
983
984            assert_eq!(ys.nrows(), 1);
985            assert_eq!(ys.ncols(), ts.len());
986            assert!(
987                !ts.is_empty(),
988                "expected solve() to record at least one time point"
989            );
990            assert_close(
991                *ts.last().unwrap(),
992                final_time,
993                ASSERT_TOL,
994                "solve final time",
995            );
996            for (i, &t) in ts.iter().enumerate() {
997                assert_close(
998                    ys[(0, i)],
999                    logistic_state(x0, r, t),
1000                    5e-4,
1001                    &format!("solve value[{i}]"),
1002                );
1003            }
1004        }
1005    }
1006
1007    #[test]
1008    fn hybrid_solution_matches_piecewise_logistic_diffsl_model() {
1009        for jit_backend in available_jit_backends() {
1010            for ode_solver in all_ode_solvers() {
1011                assert_hybrid_solution_matches_piecewise_logistic_diffsl_model(
1012                    jit_backend,
1013                    ode_solver,
1014                );
1015            }
1016        }
1017    }
1018
1019    #[test]
1020    fn hybrid_dense_solution_matches_piecewise_logistic_diffsl_model() {
1021        for jit_backend in available_jit_backends() {
1022            for ode_solver in all_ode_solvers() {
1023                assert_hybrid_dense_solution_matches_piecewise_logistic_diffsl_model(
1024                    jit_backend,
1025                    ode_solver,
1026                );
1027            }
1028        }
1029    }
1030
1031    #[cfg(feature = "diffsl-llvm")]
1032    #[test]
1033    fn bdf_forward_sensitivities_match_logistic_derivative_from_diffsl() {
1034        let ode = make_ode(
1035            JitBackendType::Llvm,
1036            MatrixType::NalgebraDense,
1037            OdeSolverType::Bdf,
1038        );
1039        ode.set_rtol(1e-8).unwrap();
1040        ode.set_atol(1e-8).unwrap();
1041
1042        let t_eval = [0.25, 0.5, 1.0];
1043        let solution = ode
1044            .solve_fwd_sens(vector_host(&[2.0]), vector_host(&t_eval))
1045            .unwrap();
1046
1047        assert_solution_tail(&solution, &t_eval, LOGISTIC_X0, 2.0, 5e-4);
1048        let sens = solution.get_sens().unwrap();
1049        assert_eq!(sens.len(), 1);
1050        let sens_values = sens[0].as_array::<f64>().unwrap();
1051        assert_eq!(sens_values.nrows(), 1);
1052        assert_eq!(sens_values.ncols(), t_eval.len());
1053        for (i, &t) in t_eval.iter().enumerate() {
1054            assert_close(
1055                sens_values[(0, i)],
1056                logistic_state_dr(LOGISTIC_X0, 2.0, t),
1057                ASSERT_TOL,
1058                &format!("jit sensitivity[{i}]"),
1059            );
1060        }
1061    }
1062
1063    #[cfg(feature = "diffsl-llvm")]
1064    #[test]
1065    fn hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model() {
1066        for ode_solver in all_ode_solvers() {
1067            assert_hybrid_forward_sensitivities_match_piecewise_logistic_diffsl_model(ode_solver);
1068        }
1069    }
1070
1071    #[cfg(feature = "diffsl-llvm")]
1072    #[test]
1073    fn bdf_sum_squares_adjoint_matches_logistic_diffsl_model() {
1074        let ode = make_ode(
1075            JitBackendType::Llvm,
1076            MatrixType::NalgebraDense,
1077            OdeSolverType::Bdf,
1078        );
1079        ode.set_rtol(1e-8).unwrap();
1080        ode.set_atol(1e-8).unwrap();
1081
1082        let t_eval = [0.0, 0.25, 0.5, 1.0];
1083        let data_values: Vec<f64> = t_eval
1084            .iter()
1085            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
1086            .collect();
1087        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
1088        let (value, sens) = ode
1089            .solve_sum_squares_adj(vector_host(&[2.0]), data, vector_host(&t_eval))
1090            .unwrap();
1091        let grad = Vec::<f64>::from_host_array(sens).unwrap();
1092
1093        assert_close(value, 0.0, ASSERT_TOL, "jit sum_squares objective");
1094        assert_eq!(grad.len(), 1);
1095        assert!(
1096            grad[0].is_finite(),
1097            "jit sum_squares gradient should be finite"
1098        );
1099    }
1100
1101    #[cfg(feature = "diffsl-llvm")]
1102    #[test]
1103    fn bdf_sum_squares_adjoint_matches_finite_difference_gradient_for_logistic_model() {
1104        let logistic_model = r#"
1105            in_i { r = 1, k = 1, y0 = 0.1 }
1106            u { y0 }
1107            F { r * u * (1.0 - u / k) }
1108        "#;
1109        let ode = OdeWrapper::new_jit(
1110            logistic_model,
1111            JitBackendType::Llvm,
1112            ScalarType::F64,
1113            MatrixType::NalgebraDense,
1114            LinearSolverType::Default,
1115            OdeSolverType::Bdf,
1116        )
1117        .unwrap();
1118        ode.set_rtol(1e-8).unwrap();
1119        ode.set_atol(1e-8).unwrap();
1120
1121        let t_eval = [0.0, 0.1, 0.3, 0.7, 1.0];
1122        let data_params = [1.2, 0.9, 0.2];
1123        let fit_params = [0.8, 1.3, 0.12];
1124        let fd_step = 1e-6;
1125
1126        let data_solution = ode
1127            .solve_dense(vector_host(&data_params), vector_host(&t_eval))
1128            .unwrap();
1129        let data_ys = data_solution.get_ys().unwrap();
1130        let data_ys = data_ys.as_array::<f64>().unwrap();
1131        let data_values: Vec<f64> = (0..t_eval.len()).map(|col| data_ys[(0, col)]).collect();
1132
1133        let objective_from_dense = |params: [f64; 3]| -> f64 {
1134            let solution = ode
1135                .solve_dense(vector_host(&params), vector_host(&t_eval))
1136                .unwrap();
1137            let ys = solution.get_ys().unwrap();
1138            let ys = ys.as_array::<f64>().unwrap();
1139            (0..t_eval.len())
1140                .map(|col| {
1141                    let residual = ys[(0, col)] - data_values[col];
1142                    residual * residual
1143                })
1144                .sum()
1145        };
1146
1147        let objective_fd = objective_from_dense(fit_params);
1148        let mut finite_difference_gradient = [0.0; 3];
1149        for i in 0..fit_params.len() {
1150            let mut plus = fit_params;
1151            let mut minus = fit_params;
1152            let step = fd_step * fit_params[i].abs().max(1.0);
1153            plus[i] += step;
1154            minus[i] -= step;
1155            finite_difference_gradient[i] =
1156                (objective_from_dense(plus) - objective_from_dense(minus)) / (2.0 * step);
1157        }
1158
1159        let data = crate::test_support::matrix_host(1, t_eval.len(), &data_values);
1160        let ode_adj = OdeWrapper::new_jit(
1161            logistic_model,
1162            JitBackendType::Llvm,
1163            ScalarType::F64,
1164            MatrixType::NalgebraDense,
1165            LinearSolverType::Default,
1166            OdeSolverType::Bdf,
1167        )
1168        .unwrap();
1169        ode_adj.set_rtol(1e-8).unwrap();
1170        ode_adj.set_atol(1e-8).unwrap();
1171
1172        let (objective_adj, sens) = ode_adj
1173            .solve_sum_squares_adj(vector_host(&fit_params), data, vector_host(&t_eval))
1174            .unwrap();
1175        let adjoint_gradient = Vec::<f64>::from_host_array(sens).unwrap();
1176
1177        assert_eq!(adjoint_gradient.len(), 3);
1178        assert_close(
1179            objective_adj,
1180            objective_fd,
1181            1e-5,
1182            "sum_squares objective from dense finite differences",
1183        );
1184        for i in 0..adjoint_gradient.len() {
1185            assert_close(
1186                adjoint_gradient[i],
1187                finite_difference_gradient[i],
1188                5e-4,
1189                &format!("sum_squares gradient component {i}"),
1190            );
1191        }
1192    }
1193}