Skip to main content

diffsol_c/
ode_solver_type.rs

// Solver method Python enum. This is used to select the overarching solver
// strategy like bdf or esdirk34 in diffsol.
3
4use diffsol::error::{DiffsolError, OdeSolverError};
5use diffsol::ode_equations::OdeEquationsImplicitSens;
6use diffsol::{
7    matrix::MatrixRef, DefaultDenseMatrix, DenseMatrix, DiffSl, LinearSolver, Matrix,
8    OdeSolverProblem, OdeSolverState, VectorHost, VectorRef, VectorView,
9};
10use diffsol::{
11    ode_solver_error, AdjointOdeSolverMethod, CheckpointingPath, CodegenModule, DefaultSolver,
12    OdeEquations, OdeSolverMethod, OdeSolverStopReason, SensitivitiesOdeSolverMethod, Solution,
13};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16
17use crate::adjoint_checkpoint::{AdjointCheckpoint, AdjointCheckpointData};
18use crate::ode_solver_tag::{BdfTag, Esdirk34Tag, OdeSolverMethodTag, TrBdf2Tag, Tsit45Tag};
19use crate::scalar_type::Scalar;
20use crate::utils::is_sens_available;
21use crate::{
22    linear_solver_type::LinearSolverType,
23    valid_linear_solver::{KluValidator, LuValidator},
24};
25
26/// Enumerates the possible ODE solver methods for diffsol. See the solver descriptions in the diffsol documentation (https://github.com/martinjrobins/diffsol) for more details.
27///
28/// :attr bdf: Backward Differentiation Formula (BDF) method for stiff ODEs and singular mass matrices
29/// :attr esdirk34: Explicit Singly Diagonally Implicit Runge-Kutta (ESDIRK) method for moderately stiff ODEs and singular mass matrices.
30/// :attr tr_bdf2: Trapezoidal Backward Differentiation Formula of order 2 (TR-BDF2) method for moderately stiff ODEs and singular mass matrices.
31/// :attr tsit45: Tsitouras 4/5th order Explicit Runge-Kutta (TSIT45) method for non-stiff ODEs. This is an explicit method, it cannot handle singular mass matrices and does not require a linear solver.
// `Copy` lets the selected method be passed around by value (the dispatch
// methods below hand `*self` to the tag-specific helpers).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
// With `snake_case` renaming the serialized variant names are `bdf`,
// `esdirk34`, `tr_bdf2` and `tsit45`, i.e. the `:attr` names listed in the doc
// comment on this enum.
#[serde(rename_all = "snake_case")]
pub enum OdeSolverType {
    // Dispatched to `BdfTag`.
    Bdf,
    // Dispatched to `Esdirk34Tag`.
    Esdirk34,
    // Dispatched to `TrBdf2Tag`.
    TrBdf2,
    // Dispatched to `Tsit45Tag`.
    Tsit45,
}
40
41fn solve_with_tag<M, CG, LS, Tag>(
42    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
43    mut soln: Solution<M::V>,
44) -> Result<Solution<M::V>, DiffsolError>
45where
46    M: Matrix<T: Scalar>,
47    CG: CodegenModule,
48    M::V: VectorHost + DefaultDenseMatrix,
49    LS: LinearSolver<M>,
50    Tag: OdeSolverMethodTag<M, CG>,
51    for<'b> &'b M::V: VectorRef<M::V>,
52    for<'b> &'b M: MatrixRef<M>,
53{
54    let mut solver = Tag::solver::<LS>(problem)?;
55    while !soln.is_complete() {
56        solver = solver.solve_soln(&mut soln)?;
57        let root_idx = match soln.stop_reason {
58            Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
59            _ => continue,
60        };
61        if problem.eqn.reset().is_none() {
62            soln.truncate(problem, solver.state())?;
63            return Ok(soln);
64        }
65        let mut state = solver.into_state();
66        problem.eqn.set_model_index(root_idx);
67        state.as_mut().apply_reset(&problem.eqn)?;
68        solver = Tag::solver_with_state::<LS>(problem, state)?;
69    }
70    Ok(soln)
71}
72
73fn solve_fwd_sens_with_tag<M, CG, LS, Tag>(
74    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
75    t_eval: &[M::T],
76) -> Result<Solution<M::V>, DiffsolError>
77where
78    M: Matrix<T: Scalar>,
79    CG: CodegenModule,
80    M::V: VectorHost + DefaultDenseMatrix,
81    LS: LinearSolver<M>,
82    Tag: OdeSolverMethodTag<M, CG>,
83    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
84    for<'b> &'b M::V: VectorRef<M::V>,
85    for<'b> &'b M: MatrixRef<M>,
86{
87    let mut soln = Solution::new_dense(t_eval.to_vec())?;
88    let mut solver = Tag::solver_sens::<LS>(problem)?;
89    while !soln.is_complete() {
90        solver = solver.solve_soln_sensitivities(&mut soln)?;
91        let root_idx = match soln.stop_reason {
92            Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
93            _ => continue,
94        };
95        if problem.eqn.reset().is_none() {
96            soln.truncate_sens(problem, solver.state())?;
97            return Ok(soln);
98        }
99        let mut state = solver.into_state();
100        problem.eqn.set_model_index(root_idx);
101        state
102            .as_mut()
103            .apply_reset_with_sens(&problem.eqn, root_idx)?;
104        solver = Tag::solver_sens_with_state::<LS>(problem, state)?;
105    }
106    Ok(soln)
107}
108
109#[allow(clippy::type_complexity)]
110fn solve_with_checkpointing_with_tag<M, CG, LS, Tag>(
111    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
112    mut soln: Solution<M::V>,
113) -> Result<(Solution<M::V>, CheckpointingPath<DiffSl<M, CG>, Tag::State>), DiffsolError>
114where
115    M: Matrix<T: Scalar>,
116    CG: CodegenModule,
117    M::V: VectorHost + DefaultDenseMatrix,
118    LS: LinearSolver<M>,
119    Tag: OdeSolverMethodTag<M, CG>,
120    for<'b> &'b M::V: VectorRef<M::V>,
121    for<'b> &'b M: MatrixRef<M>,
122{
123    let mut solver = Tag::solver::<LS>(problem)?;
124    let mut checkpointing = Vec::new();
125    while !soln.is_complete() {
126        solver = solver.solve_soln_with_checkpointing(&mut soln, &mut checkpointing, None)?;
127        let root_idx = match soln.stop_reason {
128            Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
129            _ => continue,
130        };
131        if problem.eqn.reset().is_none() {
132            soln.truncate(problem, solver.state())?;
133            return Ok((soln, checkpointing));
134        }
135        let mut state = solver.into_state();
136        problem.eqn.set_model_index(root_idx);
137        state.as_mut().apply_reset(&problem.eqn)?;
138        solver = Tag::solver_with_state::<LS>(problem, state)?;
139    }
140    Ok((soln, checkpointing))
141}
142
143fn integral_from_soln<V>(soln: &Solution<V>) -> Result<V, DiffsolError>
144where
145    V: DefaultDenseMatrix,
146{
147    if soln.ts.is_empty() {
148        return Err(ode_solver_error!(
149            Other,
150            "Continuous adjoint solve returned no integral samples"
151        ));
152    }
153    Ok(soln.ys.column(soln.ts.len() - 1).into_owned())
154}
155
156#[allow(clippy::type_complexity)]
157fn solve_adjoint_fwd_with_tag<M, CG, LS, Tag>(
158    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
159    t_eval: &[M::T],
160    params: &[f64],
161    method: OdeSolverType,
162    linear_solver: LinearSolverType,
163) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
164where
165    M: Matrix<T: Scalar> + 'static,
166    CG: CodegenModule + 'static,
167    M::V: VectorHost + DefaultDenseMatrix,
168    LS: LinearSolver<M>,
169    DiffSl<M, CG>: OdeEquations<M = M, T = M::T, V = M::V, C = M::C>,
170    Tag: OdeSolverMethodTag<M, CG> + 'static,
171    for<'b> &'b M::V: VectorRef<M::V>,
172    for<'b> &'b M: MatrixRef<M>,
173{
174    let soln = Solution::new_dense(t_eval.to_vec())?;
175    let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
176    Ok((
177        soln,
178        Box::new(AdjointCheckpointData::<M, CG, Tag>::new(
179            checkpointing,
180            params.to_vec(),
181            method,
182            linear_solver,
183        )),
184    ))
185}
186
187fn solve_continuous_adjoint_with_tag<M, CG, LS, Tag>(
188    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
189    final_time: M::T,
190    method: OdeSolverType,
191) -> Result<(M::V, Vec<M::V>), DiffsolError>
192where
193    M: Matrix<T: Scalar> + DefaultSolver + 'static,
194    CG: CodegenModule + 'static,
195    M::V: VectorHost + DefaultDenseMatrix,
196    LS: LinearSolver<M>,
197    Tag: OdeSolverMethodTag<M, CG> + 'static,
198    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
199        + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
200    for<'b> &'b M::V: VectorRef<M::V>,
201    for<'b> &'b M: MatrixRef<M>,
202{
203    let soln = Solution::new(final_time);
204    let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
205    let integral = integral_from_soln(&soln)?;
206    let sg = match method {
207        OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<M, CG, LS, LS, Tag, BdfTag>(
208            problem,
209            &soln,
210            checkpointing,
211            &[],
212            None,
213        ),
214        OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
215            M,
216            CG,
217            LS,
218            LS,
219            Tag,
220            Esdirk34Tag,
221        >(problem, &soln, checkpointing, &[], None),
222        OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
223            M,
224            CG,
225            LS,
226            LS,
227            Tag,
228            TrBdf2Tag,
229        >(problem, &soln, checkpointing, &[], None),
230        OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
231            M,
232            CG,
233            LS,
234            LS,
235            Tag,
236            Tsit45Tag,
237        >(problem, &soln, checkpointing, &[], None),
238    }?;
239    Ok((integral, sg))
240}
241
242fn solve_adjoint_bkwds_with_fwd_tag<M, CG, FwdLS, BwdLS, Tag>(
243    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
244    checkpoint: &AdjointCheckpointData<M, CG, Tag>,
245    backwards_method: OdeSolverType,
246    dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
247    t_eval: &[M::T],
248) -> Result<Vec<M::V>, DiffsolError>
249where
250    M: Matrix<T: Scalar> + DefaultSolver + 'static,
251    CG: CodegenModule + 'static,
252    M::V: VectorHost + DefaultDenseMatrix,
253    FwdLS: LinearSolver<M>,
254    BwdLS: LinearSolver<M>,
255    Tag: OdeSolverMethodTag<M, CG> + 'static,
256    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
257        + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
258    for<'b> &'b M::V: VectorRef<M::V>,
259    for<'b> &'b M: MatrixRef<M>,
260{
261    // TODO: can we avoid cloning here? Adjoint equations require ownership of the checkpointing segments so maybe not
262    // unless we can change the adjoint equations to take references to the checkpointing segments instead
263    let checkpointing = checkpoint.checkpointing.clone();
264    let soln = Solution::new_dense(t_eval.to_vec())?;
265
266    // we will only consider a single output g for now, so nout_override is 1
267    let dgdu_eval = [dgdu_eval];
268    match backwards_method {
269        OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
270            M,
271            CG,
272            FwdLS,
273            BwdLS,
274            Tag,
275            BdfTag,
276        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
277        OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
278            M,
279            CG,
280            FwdLS,
281            BwdLS,
282            Tag,
283            Esdirk34Tag,
284        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
285        OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
286            M,
287            CG,
288            FwdLS,
289            BwdLS,
290            Tag,
291            TrBdf2Tag,
292        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
293        OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
294            M,
295            CG,
296            FwdLS,
297            BwdLS,
298            Tag,
299            Tsit45Tag,
300        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
301    }
302}
303
304fn solve_adjoint_bkwds_with_fwd_bkwd_tag<'solver, M, CG, FwdLS, BwdLS, FwdTag, BwdTag>(
305    problem: &'solver mut OdeSolverProblem<DiffSl<M, CG>>,
306    soln: &Solution<M::V>,
307    mut checkpointing: CheckpointingPath<DiffSl<M, CG>, FwdTag::State>,
308    dgdu_eval: &[&<M::V as DefaultDenseMatrix>::M],
309    nout_override: Option<usize>,
310) -> Result<Vec<M::V>, DiffsolError>
311where
312    M: Matrix<T: Scalar> + DefaultSolver + 'solver,
313    CG: CodegenModule + 'solver,
314    M::V: VectorHost + DefaultDenseMatrix,
315    FwdLS: LinearSolver<M>,
316    BwdLS: LinearSolver<M>,
317    FwdTag: OdeSolverMethodTag<M, CG>,
318    BwdTag: OdeSolverMethodTag<M, CG>,
319    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
320        + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
321    for<'b> &'b M::V: VectorRef<M::V>,
322    for<'b> &'b M: MatrixRef<M>,
323{
324    let checkpointing_len = checkpointing.len();
325    if checkpointing_len == 0 {
326        return Err(ode_solver_error!(
327            Other,
328            "Adjoint backward pass requires at least one checkpointing segment"
329        ));
330    }
331
332    let t_eval = if dgdu_eval.is_empty() {
333        &[]
334    } else {
335        soln.ts.as_slice()
336    };
337
338    let current_checkpointing = checkpointing
339        .pop()
340        .ok_or_else(|| ode_solver_error!(Other, "Adjoint backward pass returned no state"))?;
341    let model_index = checkpointing
342        .last()
343        .map(|segment| {
344            segment
345                .terminal_reset_root_idx()
346                .expect("Missing reset root index")
347        })
348        .unwrap_or(0);
349    problem.eqn_mut().set_model_index(model_index);
350    let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
351    let mut adjoint = BwdTag::solver_adjoint::<BwdLS, _>(
352        &*problem,
353        vec![current_checkpointing],
354        Some(fwd_solver),
355        nout_override,
356    )?;
357    loop {
358        let (mut state, adjoint_checkpointing) =
359            adjoint.solve_adjoint_backwards_pass(t_eval, dgdu_eval)?;
360        let Some(previous_checkpointing) = checkpointing.pop() else {
361            return Ok(state.into_common().sg);
362        };
363        let model_index = checkpointing
364            .last()
365            .map(|segment| {
366                segment
367                    .terminal_reset_root_idx()
368                    .expect("Missing reset root index")
369            })
370            .unwrap_or(0);
371        let fwd_state_minus = previous_checkpointing.last_checkpoint();
372        let fwd_state_plus = adjoint_checkpointing
373            .first()
374            .ok_or_else(|| {
375                ode_solver_error!(Other, "Adjoint backward pass returned no checkpointing")
376            })?
377            .first_checkpoint();
378        state.as_mut().apply_reset_with_adjoint(
379            problem.eqn(),
380            previous_checkpointing.terminal_reset_root_idx().unwrap(),
381            fwd_state_minus.as_ref(),
382            fwd_state_plus.as_ref(),
383            problem.integrate_out,
384        )?;
385        problem.eqn_mut().set_model_index(model_index);
386        let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
387        // TODO: remove clone here
388        let adjoint_eqn = problem.adjoint_equations(
389            vec![previous_checkpointing],
390            Some(fwd_solver),
391            nout_override,
392        );
393
394        adjoint = BwdTag::solver_adjoint_from_state::<BwdLS, _>(&*problem, state, adjoint_eqn)?;
395    }
396}
397
398impl OdeSolverType {
399    pub(crate) fn solve<M, CG, LS>(
400        &self,
401        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
402        final_time: M::T,
403    ) -> Result<Solution<M::V>, DiffsolError>
404    where
405        M: Matrix<T: Scalar>,
406        CG: CodegenModule,
407        M::V: VectorHost + DefaultDenseMatrix,
408        LS: LinearSolver<M>,
409        for<'b> &'b M::V: VectorRef<M::V>,
410        for<'b> &'b M: MatrixRef<M>,
411    {
412        match self {
413            OdeSolverType::Bdf => {
414                solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new(final_time))
415            }
416            OdeSolverType::Esdirk34 => {
417                solve_with_tag::<M, CG, LS, Esdirk34Tag>(problem, Solution::new(final_time))
418            }
419            OdeSolverType::TrBdf2 => {
420                solve_with_tag::<M, CG, LS, TrBdf2Tag>(problem, Solution::new(final_time))
421            }
422            OdeSolverType::Tsit45 => {
423                solve_with_tag::<M, CG, LS, Tsit45Tag>(problem, Solution::new(final_time))
424            }
425        }
426    }
427
428    pub(crate) fn solve_dense<M, CG, LS>(
429        &self,
430        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
431        t_eval: &[M::T],
432    ) -> Result<Solution<M::V>, DiffsolError>
433    where
434        M: Matrix<T: Scalar>,
435        CG: CodegenModule,
436        M::V: VectorHost + DefaultDenseMatrix,
437        LS: LinearSolver<M>,
438        for<'b> &'b M::V: VectorRef<M::V>,
439        for<'b> &'b M: MatrixRef<M>,
440    {
441        match self {
442            OdeSolverType::Bdf => {
443                solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new_dense(t_eval.to_vec())?)
444            }
445            OdeSolverType::Esdirk34 => solve_with_tag::<M, CG, LS, Esdirk34Tag>(
446                problem,
447                Solution::new_dense(t_eval.to_vec())?,
448            ),
449            OdeSolverType::TrBdf2 => solve_with_tag::<M, CG, LS, TrBdf2Tag>(
450                problem,
451                Solution::new_dense(t_eval.to_vec())?,
452            ),
453            OdeSolverType::Tsit45 => solve_with_tag::<M, CG, LS, Tsit45Tag>(
454                problem,
455                Solution::new_dense(t_eval.to_vec())?,
456            ),
457        }
458    }
459
460    fn check_sens_available() -> Result<(), DiffsolError> {
461        if !is_sens_available() {
462            return Err(DiffsolError::Other(
463                "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
464            ));
465        }
466        Ok(())
467    }
468
469    #[allow(clippy::type_complexity)]
470    pub(crate) fn solve_fwd_sens<M, CG, LS>(
471        &self,
472        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
473        t_eval: &[M::T],
474    ) -> Result<Solution<M::V>, DiffsolError>
475    where
476        M: Matrix<T: Scalar> + DefaultSolver,
477        CG: CodegenModule,
478        M::V: VectorHost + DefaultDenseMatrix,
479        LS: LinearSolver<M>,
480        for<'b> &'b M::V: VectorRef<M::V>,
481        for<'b> &'b M: MatrixRef<M>,
482    {
483        Self::check_sens_available()?;
484        match self {
485            OdeSolverType::Bdf => solve_fwd_sens_with_tag::<M, CG, LS, BdfTag>(problem, t_eval),
486            OdeSolverType::Esdirk34 => {
487                solve_fwd_sens_with_tag::<M, CG, LS, Esdirk34Tag>(problem, t_eval)
488            }
489            OdeSolverType::TrBdf2 => {
490                solve_fwd_sens_with_tag::<M, CG, LS, TrBdf2Tag>(problem, t_eval)
491            }
492            OdeSolverType::Tsit45 => {
493                solve_fwd_sens_with_tag::<M, CG, LS, Tsit45Tag>(problem, t_eval)
494            }
495        }
496    }
497
498    #[allow(clippy::type_complexity)]
499    pub(crate) fn solve_adjoint_fwd<M, CG, LS>(
500        &self,
501        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
502        t_eval: &[M::T],
503        params: &[f64],
504        linear_solver: LinearSolverType,
505    ) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
506    where
507        M: Matrix<T: Scalar> + DefaultSolver + 'static,
508        CG: CodegenModule + 'static,
509        M::V: VectorHost + DefaultDenseMatrix,
510        LS: LinearSolver<M>,
511        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
512        for<'b> &'b M::V: VectorRef<M::V>,
513        for<'b> &'b M: MatrixRef<M>,
514    {
515        Self::check_sens_available()?;
516        match self {
517            OdeSolverType::Bdf => solve_adjoint_fwd_with_tag::<M, CG, LS, BdfTag>(
518                problem,
519                t_eval,
520                params,
521                *self,
522                linear_solver,
523            ),
524            OdeSolverType::Esdirk34 => solve_adjoint_fwd_with_tag::<M, CG, LS, Esdirk34Tag>(
525                problem,
526                t_eval,
527                params,
528                *self,
529                linear_solver,
530            ),
531            OdeSolverType::TrBdf2 => solve_adjoint_fwd_with_tag::<M, CG, LS, TrBdf2Tag>(
532                problem,
533                t_eval,
534                params,
535                *self,
536                linear_solver,
537            ),
538            OdeSolverType::Tsit45 => solve_adjoint_fwd_with_tag::<M, CG, LS, Tsit45Tag>(
539                problem,
540                t_eval,
541                params,
542                *self,
543                linear_solver,
544            ),
545        }
546    }
547
548    pub(crate) fn solve_continuous_adjoint<M, CG, LS>(
549        &self,
550        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
551        final_time: M::T,
552    ) -> Result<(M::V, Vec<M::V>), DiffsolError>
553    where
554        M: Matrix<T: Scalar> + DefaultSolver + 'static,
555        CG: CodegenModule + 'static,
556        M::V: VectorHost + DefaultDenseMatrix,
557        LS: LinearSolver<M>,
558        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
559            + diffsol::OdeEquationsImplicitAdjoint,
560        for<'b> &'b M::V: VectorRef<M::V>,
561        for<'b> &'b M: MatrixRef<M>,
562    {
563        Self::check_sens_available()?;
564        match self {
565            OdeSolverType::Bdf => {
566                solve_continuous_adjoint_with_tag::<M, CG, LS, BdfTag>(problem, final_time, *self)
567            }
568            OdeSolverType::Esdirk34 => solve_continuous_adjoint_with_tag::<M, CG, LS, Esdirk34Tag>(
569                problem, final_time, *self,
570            ),
571            OdeSolverType::TrBdf2 => solve_continuous_adjoint_with_tag::<M, CG, LS, TrBdf2Tag>(
572                problem, final_time, *self,
573            ),
574            OdeSolverType::Tsit45 => solve_continuous_adjoint_with_tag::<M, CG, LS, Tsit45Tag>(
575                problem, final_time, *self,
576            ),
577        }
578    }
579
580    pub(crate) fn solve_adjoint_bkwd<M, CG, BwdLS>(
581        &self,
582        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
583        checkpoint: &dyn AdjointCheckpoint,
584        dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
585        t_eval: &[M::T],
586    ) -> Result<Vec<M::V>, DiffsolError>
587    where
588        M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M> + 'static,
589        CG: CodegenModule + 'static,
590        M::V: VectorHost + DefaultDenseMatrix,
591        BwdLS: LinearSolver<M>,
592        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
593            + diffsol::OdeEquationsImplicitAdjoint,
594        for<'b> &'b M::V: VectorRef<M::V>,
595        for<'b> &'b M: MatrixRef<M>,
596    {
597        Self::check_sens_available()?;
598        match checkpoint.method() {
599            OdeSolverType::Bdf => {
600                let data = checkpoint.data::<M, CG, BdfTag>()?;
601                match data.linear_solver() {
602                    LinearSolverType::Default => {
603                        solve_adjoint_bkwds_with_fwd_tag::<
604                            M,
605                            CG,
606                            <M as DefaultSolver>::LS,
607                            BwdLS,
608                            BdfTag,
609                        >(problem, data, *self, dgdu_eval, t_eval)
610                    }
611                    LinearSolverType::Lu => {
612                        solve_adjoint_bkwds_with_fwd_tag::<
613                            M,
614                            CG,
615                            <M as LuValidator<M>>::LS,
616                            BwdLS,
617                            BdfTag,
618                        >(problem, data, *self, dgdu_eval, t_eval)
619                    }
620                    LinearSolverType::Klu => {
621                        solve_adjoint_bkwds_with_fwd_tag::<
622                            M,
623                            CG,
624                            <M as KluValidator<M>>::LS,
625                            BwdLS,
626                            BdfTag,
627                        >(problem, data, *self, dgdu_eval, t_eval)
628                    }
629                }
630            }
631            OdeSolverType::Esdirk34 => {
632                let data = checkpoint.data::<M, CG, Esdirk34Tag>()?;
633                match data.linear_solver() {
634                    LinearSolverType::Default => {
635                        solve_adjoint_bkwds_with_fwd_tag::<
636                            M,
637                            CG,
638                            <M as DefaultSolver>::LS,
639                            BwdLS,
640                            Esdirk34Tag,
641                        >(problem, data, *self, dgdu_eval, t_eval)
642                    }
643                    LinearSolverType::Lu => {
644                        solve_adjoint_bkwds_with_fwd_tag::<
645                            M,
646                            CG,
647                            <M as LuValidator<M>>::LS,
648                            BwdLS,
649                            Esdirk34Tag,
650                        >(problem, data, *self, dgdu_eval, t_eval)
651                    }
652                    LinearSolverType::Klu => {
653                        solve_adjoint_bkwds_with_fwd_tag::<
654                            M,
655                            CG,
656                            <M as KluValidator<M>>::LS,
657                            BwdLS,
658                            Esdirk34Tag,
659                        >(problem, data, *self, dgdu_eval, t_eval)
660                    }
661                }
662            }
663            OdeSolverType::TrBdf2 => {
664                let data = checkpoint.data::<M, CG, TrBdf2Tag>()?;
665                match data.linear_solver() {
666                    LinearSolverType::Default => {
667                        solve_adjoint_bkwds_with_fwd_tag::<
668                            M,
669                            CG,
670                            <M as DefaultSolver>::LS,
671                            BwdLS,
672                            TrBdf2Tag,
673                        >(problem, data, *self, dgdu_eval, t_eval)
674                    }
675                    LinearSolverType::Lu => {
676                        solve_adjoint_bkwds_with_fwd_tag::<
677                            M,
678                            CG,
679                            <M as LuValidator<M>>::LS,
680                            BwdLS,
681                            TrBdf2Tag,
682                        >(problem, data, *self, dgdu_eval, t_eval)
683                    }
684                    LinearSolverType::Klu => {
685                        solve_adjoint_bkwds_with_fwd_tag::<
686                            M,
687                            CG,
688                            <M as KluValidator<M>>::LS,
689                            BwdLS,
690                            TrBdf2Tag,
691                        >(problem, data, *self, dgdu_eval, t_eval)
692                    }
693                }
694            }
695            OdeSolverType::Tsit45 => {
696                let data = checkpoint.data::<M, CG, Tsit45Tag>()?;
697                match data.linear_solver() {
698                    LinearSolverType::Default => {
699                        solve_adjoint_bkwds_with_fwd_tag::<
700                            M,
701                            CG,
702                            <M as DefaultSolver>::LS,
703                            BwdLS,
704                            Tsit45Tag,
705                        >(problem, data, *self, dgdu_eval, t_eval)
706                    }
707                    LinearSolverType::Lu => {
708                        solve_adjoint_bkwds_with_fwd_tag::<
709                            M,
710                            CG,
711                            <M as LuValidator<M>>::LS,
712                            BwdLS,
713                            Tsit45Tag,
714                        >(problem, data, *self, dgdu_eval, t_eval)
715                    }
716                    LinearSolverType::Klu => {
717                        solve_adjoint_bkwds_with_fwd_tag::<
718                            M,
719                            CG,
720                            <M as KluValidator<M>>::LS,
721                            BwdLS,
722                            Tsit45Tag,
723                        >(problem, data, *self, dgdu_eval, t_eval)
724                    }
725                }
726            }
727        }
728    }
729}
730
731#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
732mod tests {
733    use diffsol::{
734        CodegenModuleCompile, CodegenModuleJit, DefaultDenseMatrix, DefaultSolver, DenseMatrix,
735        Matrix, MatrixCommon, OdeBuilder, OdeSolverProblem, Op, Vector,
736    };
737
738    #[cfg(feature = "diffsl-llvm")]
739    use crate::linear_solver_type::LinearSolverType;
740    use crate::test_support::{
741        assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
742        logistic_state, LOGISTIC_X0,
743    };
744    #[cfg(feature = "diffsl-llvm")]
745    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
746    use crate::valid_linear_solver::LuValidator;
747
748    use super::OdeSolverType;
749
750    type M = diffsol::NalgebraMat<f64>;
751
752    fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
753    where
754        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
755    {
756        OdeBuilder::<M>::new()
757            .p([2.0])
758            .rtol(1e-6)
759            .atol([1e-6])
760            .build_from_diffsl::<CG>(code)
761            .unwrap()
762    }
763
764    fn assert_dense_solution_matches_expected(
765        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
766        t_eval: &[f64],
767        expected: impl Fn(f64) -> f64,
768    ) {
769        assert_eq!(soln.ts, t_eval);
770        for (i, &t) in t_eval.iter().enumerate() {
771            assert_close(
772                soln.ys.get_index(0, i),
773                expected(t),
774                5e-4,
775                &format!("solution[{i}]"),
776            );
777        }
778    }
779
780    fn test_all_solver_variants<CG>()
781    where
782        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
783    {
784        let t_eval = [0.25, 0.5, 1.0];
785        for method in [
786            OdeSolverType::Bdf,
787            OdeSolverType::Esdirk34,
788            OdeSolverType::TrBdf2,
789            OdeSolverType::Tsit45,
790        ] {
791            let mut problem = build_problem::<CG>(logistic_diffsl_code());
792            let soln = method
793                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
794                .unwrap();
795            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "solve final time");
796            assert_close(
797                soln.ys.get_index(0, soln.ts.len() - 1),
798                logistic_state(LOGISTIC_X0, 2.0, 1.0),
799                5e-4,
800                "solve final value",
801            );
802
803            let mut problem = build_problem::<CG>(logistic_diffsl_code());
804            let soln = method
805                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
806                .unwrap();
807            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
808                logistic_state(LOGISTIC_X0, 2.0, t)
809            });
810        }
811    }
812
813    fn test_all_hybrid_solver_variants<CG>()
814    where
815        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
816    {
817        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
818        for method in [
819            OdeSolverType::Bdf,
820            OdeSolverType::Esdirk34,
821            OdeSolverType::TrBdf2,
822            OdeSolverType::Tsit45,
823        ] {
824            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
825            let soln = method
826                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
827                .unwrap();
828            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "hybrid final time");
829            assert_close(
830                soln.ys.get_index(0, soln.ts.len() - 1),
831                hybrid_logistic_state(2.0, 2.0),
832                5e-4,
833                "hybrid final value",
834            );
835
836            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
837            let soln = method
838                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
839                .unwrap();
840            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
841                hybrid_logistic_state(2.0, t)
842            });
843        }
844    }
845
846    fn test_all_solver_variants_with_lu<CG>()
847    where
848        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
849    {
850        let t_eval = [0.25, 0.5, 1.0];
851        for method in [
852            OdeSolverType::Bdf,
853            OdeSolverType::Esdirk34,
854            OdeSolverType::TrBdf2,
855            OdeSolverType::Tsit45,
856        ] {
857            let mut problem = build_problem::<CG>(logistic_diffsl_code());
858            let soln = method
859                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
860                .unwrap();
861            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "lu solve final time");
862
863            let mut problem = build_problem::<CG>(logistic_diffsl_code());
864            let soln = method
865                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
866                .unwrap();
867            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
868                logistic_state(LOGISTIC_X0, 2.0, t)
869            });
870        }
871    }
872
873    fn test_all_hybrid_solver_variants_with_lu<CG>()
874    where
875        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
876    {
877        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
878        for method in [
879            OdeSolverType::Bdf,
880            OdeSolverType::Esdirk34,
881            OdeSolverType::TrBdf2,
882            OdeSolverType::Tsit45,
883        ] {
884            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
885            let soln = method
886                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
887                .unwrap();
888            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "lu hybrid final time");
889
890            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
891            let soln = method
892                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
893                .unwrap();
894            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
895                hybrid_logistic_state(2.0, t)
896            });
897        }
898    }
899
900    fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
901    where
902        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
903    {
904        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
905
906        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
907        let soln = method
908            .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
909            .unwrap();
910        assert_close(
911            *soln.ts.last().unwrap(),
912            2.0,
913            5e-4,
914            "direct hybrid restart final time",
915        );
916        assert_close(
917            soln.ys.get_index(0, soln.ts.len() - 1),
918            hybrid_logistic_state(2.0, 2.0),
919            5e-4,
920            "direct hybrid restart final value",
921        );
922
923        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
924        let soln = method
925            .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
926            .unwrap();
927        assert_dense_solution_matches_expected(&soln, &t_eval, |t| hybrid_logistic_state(2.0, t));
928    }
929
930    #[cfg(feature = "diffsl-llvm")]
931    fn test_all_sensitivity_solver_variants() {
932        let t_eval = [0.25, 0.5, 1.0];
933        for method in [
934            OdeSolverType::Bdf,
935            OdeSolverType::Esdirk34,
936            OdeSolverType::TrBdf2,
937            OdeSolverType::Tsit45,
938        ] {
939            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
940            let soln = method
941                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
942                    &mut problem,
943                    &t_eval,
944                )
945                .unwrap();
946            for (i, &t) in t_eval.iter().enumerate() {
947                assert_close(
948                    soln.y_sens[0].get_index(0, i),
949                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
950                    5e-4,
951                    &format!("fwd_sens[{i}]"),
952                );
953            }
954
955            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
956            let soln = method
957                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
958                    &mut problem,
959                    &t_eval,
960                )
961                .unwrap();
962            for (i, &t) in t_eval.iter().enumerate() {
963                assert_close(
964                    soln.y_sens[0].get_index(0, i),
965                    hybrid_logistic_state_dr(2.0, t),
966                    5e-4,
967                    &format!("hybrid_fwd_sens[{i}]"),
968                );
969            }
970        }
971    }
972
973    #[cfg(feature = "diffsl-llvm")]
974    fn test_lu_sensitivity_and_adjoint_solver_variants() {
975        let t_eval = [0.25, 0.5, 1.0];
976        for method in [
977            OdeSolverType::Bdf,
978            OdeSolverType::Esdirk34,
979            OdeSolverType::TrBdf2,
980            OdeSolverType::Tsit45,
981        ] {
982            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
983            let soln = method
984                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
985                    &mut problem,
986                    &t_eval,
987                )
988                .unwrap();
989            for (i, &t) in t_eval.iter().enumerate() {
990                assert_close(
991                    soln.y_sens[0].get_index(0, i),
992                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
993                    5e-4,
994                    &format!("lu fwd_sens[{i}]"),
995                );
996            }
997        }
998
999        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1000        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
1001        let (soln, checkpoint) = OdeSolverType::Bdf
1002            .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1003                &mut problem,
1004                &adjoint_t_eval,
1005                &[2.0],
1006                LinearSolverType::Lu,
1007            )
1008            .unwrap();
1009        let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1010            problem.eqn.nout(),
1011            soln.ts.len(),
1012            problem.context().to_owned(),
1013        );
1014        let gradient = OdeSolverType::TrBdf2
1015            .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1016                &mut problem,
1017                checkpoint.as_ref(),
1018                &dgdu,
1019                &soln.ts,
1020            )
1021            .unwrap();
1022        assert_eq!(gradient.len(), 1);
1023        assert!(gradient[0].get_index(0).is_finite());
1024    }
1025
1026    #[cfg(feature = "diffsl-llvm")]
1027    fn test_direct_hybrid_sensitivity_restart_paths() {
1028        let t_eval = [0.5, 1.0, 2.5, 3.0, 4.5];
1029        for method in [
1030            OdeSolverType::Esdirk34,
1031            OdeSolverType::TrBdf2,
1032            OdeSolverType::Tsit45,
1033        ] {
1034            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
1035            let soln = method
1036                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1037                    &mut problem,
1038                    &t_eval,
1039                )
1040                .unwrap();
1041            for (i, &t) in t_eval.iter().enumerate() {
1042                assert_close(
1043                    soln.ys.get_index(0, i),
1044                    hybrid_logistic_state(2.0, t),
1045                    5e-4,
1046                    &format!("direct hybrid value[{i}]"),
1047                );
1048                assert_close(
1049                    soln.y_sens[0].get_index(0, i),
1050                    hybrid_logistic_state_dr(2.0, t),
1051                    5e-4,
1052                    &format!("direct hybrid fwd sens[{i}]"),
1053                );
1054            }
1055        }
1056    }
1057
1058    #[cfg(feature = "diffsl-llvm")]
1059    fn test_adjoint_backwards_methods_and_klu_branch() {
1060        for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
1061            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1062            let t_eval = [0.0, 0.25, 0.5, 1.0];
1063            let (soln, checkpoint) = OdeSolverType::Bdf
1064                .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1065                    &mut problem,
1066                    &t_eval,
1067                    &[2.0],
1068                    LinearSolverType::Default,
1069                )
1070                .unwrap();
1071            let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1072                problem.eqn.nout(),
1073                soln.ts.len(),
1074                problem.context().to_owned(),
1075            );
1076            let gradient = backwards_method
1077                .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as crate::valid_linear_solver::KluValidator<M>>::LS>(
1078                    &mut problem,
1079                    checkpoint.as_ref(),
1080                    &dgdu,
1081                    &soln.ts,
1082                )
1083                .unwrap();
1084            assert_eq!(gradient.len(), 1);
1085            assert!(gradient[0].get_index(0).is_finite());
1086        }
1087
1088        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1089        let t_eval = [0.0, 0.25, 0.5, 1.0];
1090        let (soln, checkpoint) = OdeSolverType::Tsit45
1091            .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1092                &mut problem,
1093                &t_eval,
1094                &[2.0],
1095                LinearSolverType::Default,
1096            )
1097            .unwrap();
1098        let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1099            problem.eqn.nout(),
1100            soln.ts.len(),
1101            problem.context().to_owned(),
1102        );
1103        let gradient = OdeSolverType::Bdf
1104            .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1105                &mut problem,
1106                checkpoint.as_ref(),
1107                &dgdu,
1108                &soln.ts,
1109            )
1110            .unwrap();
1111        assert_eq!(gradient.len(), 1);
1112        assert!(gradient[0].get_index(0).is_finite());
1113    }
1114
1115    #[cfg(feature = "diffsl-llvm")]
1116    fn test_all_adjoint_solver_variants() {
1117        let t_eval = [0.0, 0.25, 0.5, 1.0];
1118        for method in [
1119            OdeSolverType::Bdf,
1120            OdeSolverType::Esdirk34,
1121            OdeSolverType::TrBdf2,
1122        ] {
1123            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1124            let (soln, checkpoint) = method
1125                .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1126                    &mut problem,
1127                    &t_eval,
1128                    &[2.0],
1129                    LinearSolverType::Default,
1130                )
1131                .unwrap();
1132            let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1133                problem.eqn.nout(),
1134                soln.ts.len(),
1135                problem.context().to_owned(),
1136            );
1137            let gradient = OdeSolverType::Bdf
1138                .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1139                    &mut problem,
1140                    checkpoint.as_ref(),
1141                    &dgdu,
1142                    &soln.ts,
1143                )
1144                .unwrap();
1145            assert_eq!(gradient.len(), 1);
1146            assert!(gradient[0].get_index(0).is_finite());
1147        }
1148    }
1149
1150    #[cfg(feature = "diffsl-cranelift")]
1151    #[test]
1152    fn runtime_dispatch_solves_all_variants_for_cranelift() {
1153        test_all_solver_variants::<diffsol::CraneliftJitModule>();
1154        test_all_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1155    }
1156
1157    #[cfg(feature = "diffsl-cranelift")]
1158    #[test]
1159    fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
1160        test_all_hybrid_solver_variants::<diffsol::CraneliftJitModule>();
1161        test_all_hybrid_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1162        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1163            OdeSolverType::Esdirk34,
1164        );
1165        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1166            OdeSolverType::TrBdf2,
1167        );
1168        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1169            OdeSolverType::Tsit45,
1170        );
1171    }
1172
1173    #[cfg(feature = "diffsl-llvm")]
1174    #[test]
1175    fn runtime_dispatch_solves_all_variants_for_llvm() {
1176        test_all_solver_variants::<diffsol::LlvmModule>();
1177        test_all_solver_variants_with_lu::<diffsol::LlvmModule>();
1178    }
1179
1180    #[cfg(feature = "diffsl-llvm")]
1181    #[test]
1182    fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
1183        test_all_hybrid_solver_variants::<diffsol::LlvmModule>();
1184        test_all_hybrid_solver_variants_with_lu::<diffsol::LlvmModule>();
1185        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(
1186            OdeSolverType::Esdirk34,
1187        );
1188        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::TrBdf2);
1189        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::Tsit45);
1190    }
1191
1192    #[cfg(feature = "diffsl-llvm")]
1193    #[test]
1194    fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
1195        test_all_sensitivity_solver_variants();
1196        test_lu_sensitivity_and_adjoint_solver_variants();
1197        test_direct_hybrid_sensitivity_restart_paths();
1198    }
1199
1200    #[cfg(feature = "diffsl-llvm")]
1201    #[test]
1202    fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
1203        test_all_adjoint_solver_variants();
1204        test_adjoint_backwards_methods_and_klu_branch();
1205    }
1206}