Skip to main content

diffsol_c/
ode_solver_type.rs

// Solver method Python enum. This is used to select the overarching solver
// strategy, such as bdf or esdirk34, in diffsol.
3
4use diffsol::error::{DiffsolError, OdeSolverError};
5use diffsol::ode_equations::{OdeEquationsImplicit, OdeEquationsImplicitSens};
6use diffsol::{
7    matrix::MatrixRef, DefaultDenseMatrix, DenseMatrix, DiffSl, LinearSolver, Matrix,
8    OdeSolverProblem, OdeSolverState, VectorHost, VectorRef, VectorView,
9};
10use diffsol::{
11    ode_solver_error, AdjointOdeSolverMethod, CheckpointingPath, CodegenModule, DefaultSolver,
12    OdeEquations, OdeSolverMethod, OdeSolverStopReason, SensitivitiesOdeSolverMethod, Solution,
13};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16
17use crate::adjoint_checkpoint::{AdjointCheckpoint, AdjointCheckpointData};
18use crate::ode_solver_tag::{BdfTag, Esdirk34Tag, OdeSolverMethodTag, TrBdf2Tag, Tsit45Tag};
19use crate::scalar_type::Scalar;
20use crate::utils::is_sens_available;
21use crate::{
22    linear_solver_type::LinearSolverType,
23    valid_linear_solver::{KluValidator, LuValidator},
24};
25
/// Enumerates the possible ODE solver methods for diffsol. See the solver descriptions in the diffsol documentation (<https://github.com/martinjrobins/diffsol>) for more details.
///
/// :attr bdf: Backward Differentiation Formula (BDF) method for stiff ODEs and singular mass matrices
/// :attr esdirk34: Explicit Singly Diagonally Implicit Runge-Kutta (ESDIRK) method for moderately stiff ODEs and singular mass matrices.
/// :attr tr_bdf2: Trapezoidal Backward Differentiation Formula of order 2 (TR-BDF2) method for moderately stiff ODEs and singular mass matrices.
/// :attr tsit45: Tsitouras 4/5th order Explicit Runge-Kutta (TSIT45) method for non-stiff ODEs. This is an explicit method, it cannot handle singular mass matrices and does not require a linear solver.
// Because of `rename_all = "snake_case"` the variants are (de)serialised as `bdf`, `esdirk34`,
// `tr_bdf2` and `tsit45`, i.e. the same names used by the `:attr` entries above.
// NOTE(review): plain `//` comments are used below on purpose; the `JsonSchema` derive presumably
// turns `///` doc comments into schema descriptions, so adding those would change the generated
// schema — confirm before converting these notes into doc comments.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OdeSolverType {
    // Dispatches to `BdfTag` in the solve methods of this file.
    Bdf,
    // Dispatches to `Esdirk34Tag`.
    Esdirk34,
    // Dispatches to `TrBdf2Tag`.
    TrBdf2,
    // Dispatches to `Tsit45Tag`.
    Tsit45,
}
40
41fn solve_with_tag<M, CG, LS, Tag>(
42    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
43    mut soln: Solution<M::V>,
44) -> Result<Solution<M::V>, DiffsolError>
45where
46    M: Matrix<T: Scalar> + DefaultSolver,
47    CG: CodegenModule,
48    M::V: VectorHost + DefaultDenseMatrix,
49    LS: LinearSolver<M>,
50    Tag: OdeSolverMethodTag<M, CG>,
51    DiffSl<M, CG>: OdeEquationsImplicit<M = M, T = M::T, V = M::V, C = M::C>,
52    for<'b> &'b M::V: VectorRef<M::V>,
53    for<'b> &'b M: MatrixRef<M>,
54{
55    let mut solver = Tag::solver::<LS>(problem)?;
56    while !soln.is_complete() {
57        solver = solver.solve_soln(&mut soln)?;
58        let root_idx = match soln.stop_reason {
59            Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
60            _ => continue,
61        };
62        if problem.eqn.reset().is_none() {
63            soln.truncate(problem, solver.state())?;
64            return Ok(soln);
65        }
66        let mut state = solver.into_state();
67        problem.eqn.set_model_index(root_idx);
68        state.as_mut().apply_reset_with_mass::<M::LS, _>(problem)?;
69        solver = Tag::solver_with_state::<LS>(problem, state)?;
70    }
71    Ok(soln)
72}
73
74fn solve_fwd_sens_with_tag<M, CG, LS, Tag>(
75    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
76    t_eval: &[M::T],
77) -> Result<Solution<M::V>, DiffsolError>
78where
79    M: Matrix<T: Scalar> + DefaultSolver,
80    CG: CodegenModule,
81    M::V: VectorHost + DefaultDenseMatrix,
82    LS: LinearSolver<M>,
83    Tag: OdeSolverMethodTag<M, CG>,
84    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
85    for<'b> &'b M::V: VectorRef<M::V>,
86    for<'b> &'b M: MatrixRef<M>,
87{
88    let mut soln = Solution::new_dense(t_eval.to_vec())?;
89    let mut solver = Tag::solver_sens::<LS>(problem)?;
90    while !soln.is_complete() {
91        solver = solver.solve_soln_sensitivities(&mut soln)?;
92        let root_idx = match soln.stop_reason {
93            Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
94            _ => continue,
95        };
96        if problem.eqn.reset().is_none() {
97            soln.truncate_sens(problem, solver.state())?;
98            return Ok(soln);
99        }
100        let mut state = solver.into_state();
101        problem.eqn.set_model_index(root_idx);
102        state
103            .as_mut()
104            .apply_reset_with_sens_mass::<M::LS, _>(problem, root_idx)?;
105        solver = Tag::solver_sens_with_state::<LS>(problem, state)?;
106    }
107    Ok(soln)
108}
109
110#[allow(clippy::type_complexity)]
111fn solve_with_checkpointing_with_tag<M, CG, LS, Tag>(
112    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
113    mut soln: Solution<M::V>,
114) -> Result<(Solution<M::V>, CheckpointingPath<DiffSl<M, CG>, Tag::State>), DiffsolError>
115where
116    M: Matrix<T: Scalar> + DefaultSolver,
117    CG: CodegenModule,
118    M::V: VectorHost + DefaultDenseMatrix,
119    LS: LinearSolver<M>,
120    Tag: OdeSolverMethodTag<M, CG>,
121    DiffSl<M, CG>: OdeEquationsImplicit<M = M, T = M::T, V = M::V, C = M::C>,
122    for<'b> &'b M::V: VectorRef<M::V>,
123    for<'b> &'b M: MatrixRef<M>,
124{
125    let mut solver = Tag::solver::<LS>(problem)?;
126    let mut checkpointing = Vec::new();
127    while !soln.is_complete() {
128        solver = solver.solve_soln_with_checkpointing(&mut soln, &mut checkpointing, None)?;
129        let root_idx = match soln.stop_reason {
130            Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
131            _ => continue,
132        };
133        if problem.eqn.reset().is_none() {
134            soln.truncate(problem, solver.state())?;
135            return Ok((soln, checkpointing));
136        }
137        let mut state = solver.into_state();
138        problem.eqn.set_model_index(root_idx);
139        state.as_mut().apply_reset_with_mass::<M::LS, _>(problem)?;
140        solver = Tag::solver_with_state::<LS>(problem, state)?;
141    }
142    Ok((soln, checkpointing))
143}
144
145fn integral_from_soln<V>(soln: &Solution<V>) -> Result<V, DiffsolError>
146where
147    V: DefaultDenseMatrix,
148{
149    if soln.ts.is_empty() {
150        return Err(ode_solver_error!(
151            Other,
152            "Continuous adjoint solve returned no integral samples"
153        ));
154    }
155    Ok(soln.ys.column(soln.ts.len() - 1).into_owned())
156}
157
158#[allow(clippy::type_complexity)]
159fn solve_adjoint_fwd_with_tag<M, CG, LS, Tag>(
160    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
161    t_eval: &[M::T],
162    params: &[f64],
163    method: OdeSolverType,
164    linear_solver: LinearSolverType,
165) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
166where
167    M: Matrix<T: Scalar> + DefaultSolver + 'static,
168    CG: CodegenModule + 'static,
169    M::V: VectorHost + DefaultDenseMatrix,
170    LS: LinearSolver<M>,
171    DiffSl<M, CG>: OdeEquationsImplicit<M = M, T = M::T, V = M::V, C = M::C>,
172    Tag: OdeSolverMethodTag<M, CG> + 'static,
173    for<'b> &'b M::V: VectorRef<M::V>,
174    for<'b> &'b M: MatrixRef<M>,
175{
176    let soln = Solution::new_dense(t_eval.to_vec())?;
177    let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
178    Ok((
179        soln,
180        Box::new(AdjointCheckpointData::<M, CG, Tag>::new(
181            checkpointing,
182            params.to_vec(),
183            method,
184            linear_solver,
185        )),
186    ))
187}
188
189fn solve_continuous_adjoint_with_tag<M, CG, LS, Tag>(
190    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
191    final_time: M::T,
192    method: OdeSolverType,
193) -> Result<(M::V, Vec<M::V>), DiffsolError>
194where
195    M: Matrix<T: Scalar> + DefaultSolver + 'static,
196    CG: CodegenModule + 'static,
197    M::V: VectorHost + DefaultDenseMatrix,
198    LS: LinearSolver<M>,
199    Tag: OdeSolverMethodTag<M, CG> + 'static,
200    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
201        + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
202    for<'b> &'b M::V: VectorRef<M::V>,
203    for<'b> &'b M: MatrixRef<M>,
204{
205    let soln = Solution::new(final_time);
206    let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
207    let integral = integral_from_soln(&soln)?;
208    let sg = match method {
209        OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<M, CG, LS, LS, Tag, BdfTag>(
210            problem,
211            &soln,
212            checkpointing,
213            &[],
214            None,
215        ),
216        OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
217            M,
218            CG,
219            LS,
220            LS,
221            Tag,
222            Esdirk34Tag,
223        >(problem, &soln, checkpointing, &[], None),
224        OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
225            M,
226            CG,
227            LS,
228            LS,
229            Tag,
230            TrBdf2Tag,
231        >(problem, &soln, checkpointing, &[], None),
232        OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
233            M,
234            CG,
235            LS,
236            LS,
237            Tag,
238            Tsit45Tag,
239        >(problem, &soln, checkpointing, &[], None),
240    }?;
241    Ok((integral, sg))
242}
243
244fn solve_adjoint_bkwds_with_fwd_tag<M, CG, FwdLS, BwdLS, Tag>(
245    problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
246    checkpoint: &AdjointCheckpointData<M, CG, Tag>,
247    backwards_method: OdeSolverType,
248    dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
249    t_eval: &[M::T],
250) -> Result<Vec<M::V>, DiffsolError>
251where
252    M: Matrix<T: Scalar> + DefaultSolver + 'static,
253    CG: CodegenModule + 'static,
254    M::V: VectorHost + DefaultDenseMatrix,
255    FwdLS: LinearSolver<M>,
256    BwdLS: LinearSolver<M>,
257    Tag: OdeSolverMethodTag<M, CG> + 'static,
258    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
259        + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
260    for<'b> &'b M::V: VectorRef<M::V>,
261    for<'b> &'b M: MatrixRef<M>,
262{
263    // TODO: can we avoid cloning here? Adjoint equations require ownership of the checkpointing segments so maybe not
264    // unless we can change the adjoint equations to take references to the checkpointing segments instead
265    let checkpointing = checkpoint.checkpointing.clone();
266    let soln = Solution::new_dense(t_eval.to_vec())?;
267
268    // we will only consider a single output g for now, so nout_override is 1
269    let dgdu_eval = [dgdu_eval];
270    match backwards_method {
271        OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
272            M,
273            CG,
274            FwdLS,
275            BwdLS,
276            Tag,
277            BdfTag,
278        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
279        OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
280            M,
281            CG,
282            FwdLS,
283            BwdLS,
284            Tag,
285            Esdirk34Tag,
286        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
287        OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
288            M,
289            CG,
290            FwdLS,
291            BwdLS,
292            Tag,
293            TrBdf2Tag,
294        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
295        OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
296            M,
297            CG,
298            FwdLS,
299            BwdLS,
300            Tag,
301            Tsit45Tag,
302        >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
303    }
304}
305
306fn solve_adjoint_bkwds_with_fwd_bkwd_tag<'solver, M, CG, FwdLS, BwdLS, FwdTag, BwdTag>(
307    problem: &'solver mut OdeSolverProblem<DiffSl<M, CG>>,
308    soln: &Solution<M::V>,
309    mut checkpointing: CheckpointingPath<DiffSl<M, CG>, FwdTag::State>,
310    dgdu_eval: &[&<M::V as DefaultDenseMatrix>::M],
311    nout_override: Option<usize>,
312) -> Result<Vec<M::V>, DiffsolError>
313where
314    M: Matrix<T: Scalar> + DefaultSolver + 'solver,
315    CG: CodegenModule + 'solver,
316    M::V: VectorHost + DefaultDenseMatrix,
317    FwdLS: LinearSolver<M>,
318    BwdLS: LinearSolver<M>,
319    FwdTag: OdeSolverMethodTag<M, CG>,
320    BwdTag: OdeSolverMethodTag<M, CG>,
321    DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
322        + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
323    for<'b> &'b M::V: VectorRef<M::V>,
324    for<'b> &'b M: MatrixRef<M>,
325{
326    let checkpointing_len = checkpointing.len();
327    if checkpointing_len == 0 {
328        return Err(ode_solver_error!(
329            Other,
330            "Adjoint backward pass requires at least one checkpointing segment"
331        ));
332    }
333
334    let t_eval = if dgdu_eval.is_empty() {
335        &[]
336    } else {
337        soln.ts.as_slice()
338    };
339
340    let current_checkpointing = checkpointing
341        .pop()
342        .ok_or_else(|| ode_solver_error!(Other, "Adjoint backward pass returned no state"))?;
343    let model_index = checkpointing
344        .last()
345        .map(|segment| {
346            segment
347                .terminal_reset_root_idx()
348                .expect("Missing reset root index")
349        })
350        .unwrap_or(0);
351    problem.eqn_mut().set_model_index(model_index);
352    let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
353    let mut adjoint = BwdTag::solver_adjoint::<BwdLS, _>(
354        &*problem,
355        vec![current_checkpointing],
356        Some(fwd_solver),
357        nout_override,
358    )?;
359    loop {
360        let (mut state, adjoint_checkpointing) =
361            adjoint.solve_adjoint_backwards_pass(t_eval, dgdu_eval)?;
362        let Some(previous_checkpointing) = checkpointing.pop() else {
363            return Ok(state.into_common().sg);
364        };
365        let model_index = checkpointing
366            .last()
367            .map(|segment| {
368                segment
369                    .terminal_reset_root_idx()
370                    .expect("Missing reset root index")
371            })
372            .unwrap_or(0);
373        let fwd_state_minus = previous_checkpointing.last_checkpoint();
374        let fwd_state_plus = adjoint_checkpointing
375            .first()
376            .ok_or_else(|| {
377                ode_solver_error!(Other, "Adjoint backward pass returned no checkpointing")
378            })?
379            .first_checkpoint();
380        state.as_mut().apply_reset_with_adjoint(
381            problem.eqn(),
382            previous_checkpointing.terminal_reset_root_idx().unwrap(),
383            fwd_state_minus.as_ref(),
384            fwd_state_plus.as_ref(),
385            problem.integrate_out,
386        )?;
387        problem.eqn_mut().set_model_index(model_index);
388        let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
389        // TODO: remove clone here
390        let adjoint_eqn = problem.adjoint_equations(
391            vec![previous_checkpointing],
392            Some(fwd_solver),
393            nout_override,
394        );
395
396        adjoint = BwdTag::solver_adjoint_from_state::<BwdLS, _>(&*problem, state, adjoint_eqn)?;
397    }
398}
399
400impl OdeSolverType {
401    pub(crate) fn solve<M, CG, LS>(
402        &self,
403        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
404        final_time: M::T,
405    ) -> Result<Solution<M::V>, DiffsolError>
406    where
407        M: Matrix<T: Scalar> + DefaultSolver,
408        CG: CodegenModule,
409        M::V: VectorHost + DefaultDenseMatrix,
410        LS: LinearSolver<M>,
411        for<'b> &'b M::V: VectorRef<M::V>,
412        for<'b> &'b M: MatrixRef<M>,
413        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
414    {
415        match self {
416            OdeSolverType::Bdf => {
417                solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new(final_time))
418            }
419            OdeSolverType::Esdirk34 => {
420                solve_with_tag::<M, CG, LS, Esdirk34Tag>(problem, Solution::new(final_time))
421            }
422            OdeSolverType::TrBdf2 => {
423                solve_with_tag::<M, CG, LS, TrBdf2Tag>(problem, Solution::new(final_time))
424            }
425            OdeSolverType::Tsit45 => {
426                solve_with_tag::<M, CG, LS, Tsit45Tag>(problem, Solution::new(final_time))
427            }
428        }
429    }
430
431    pub(crate) fn solve_dense<M, CG, LS>(
432        &self,
433        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
434        t_eval: &[M::T],
435    ) -> Result<Solution<M::V>, DiffsolError>
436    where
437        M: Matrix<T: Scalar> + DefaultSolver,
438        CG: CodegenModule,
439        M::V: VectorHost + DefaultDenseMatrix,
440        LS: LinearSolver<M>,
441        for<'b> &'b M::V: VectorRef<M::V>,
442        for<'b> &'b M: MatrixRef<M>,
443        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
444    {
445        match self {
446            OdeSolverType::Bdf => {
447                solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new_dense(t_eval.to_vec())?)
448            }
449            OdeSolverType::Esdirk34 => solve_with_tag::<M, CG, LS, Esdirk34Tag>(
450                problem,
451                Solution::new_dense(t_eval.to_vec())?,
452            ),
453            OdeSolverType::TrBdf2 => solve_with_tag::<M, CG, LS, TrBdf2Tag>(
454                problem,
455                Solution::new_dense(t_eval.to_vec())?,
456            ),
457            OdeSolverType::Tsit45 => solve_with_tag::<M, CG, LS, Tsit45Tag>(
458                problem,
459                Solution::new_dense(t_eval.to_vec())?,
460            ),
461        }
462    }
463
464    fn check_sens_available() -> Result<(), DiffsolError> {
465        if !is_sens_available() {
466            return Err(DiffsolError::Other(
467                "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
468            ));
469        }
470        Ok(())
471    }
472
473    #[allow(clippy::type_complexity)]
474    pub(crate) fn solve_fwd_sens<M, CG, LS>(
475        &self,
476        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
477        t_eval: &[M::T],
478    ) -> Result<Solution<M::V>, DiffsolError>
479    where
480        M: Matrix<T: Scalar> + DefaultSolver,
481        CG: CodegenModule,
482        M::V: VectorHost + DefaultDenseMatrix,
483        LS: LinearSolver<M>,
484        for<'b> &'b M::V: VectorRef<M::V>,
485        for<'b> &'b M: MatrixRef<M>,
486    {
487        Self::check_sens_available()?;
488        match self {
489            OdeSolverType::Bdf => solve_fwd_sens_with_tag::<M, CG, LS, BdfTag>(problem, t_eval),
490            OdeSolverType::Esdirk34 => {
491                solve_fwd_sens_with_tag::<M, CG, LS, Esdirk34Tag>(problem, t_eval)
492            }
493            OdeSolverType::TrBdf2 => {
494                solve_fwd_sens_with_tag::<M, CG, LS, TrBdf2Tag>(problem, t_eval)
495            }
496            OdeSolverType::Tsit45 => {
497                solve_fwd_sens_with_tag::<M, CG, LS, Tsit45Tag>(problem, t_eval)
498            }
499        }
500    }
501
502    #[allow(clippy::type_complexity)]
503    pub(crate) fn solve_adjoint_fwd<M, CG, LS>(
504        &self,
505        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
506        t_eval: &[M::T],
507        params: &[f64],
508        linear_solver: LinearSolverType,
509    ) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
510    where
511        M: Matrix<T: Scalar> + DefaultSolver + 'static,
512        CG: CodegenModule + 'static,
513        M::V: VectorHost + DefaultDenseMatrix,
514        LS: LinearSolver<M>,
515        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
516        for<'b> &'b M::V: VectorRef<M::V>,
517        for<'b> &'b M: MatrixRef<M>,
518    {
519        Self::check_sens_available()?;
520        match self {
521            OdeSolverType::Bdf => solve_adjoint_fwd_with_tag::<M, CG, LS, BdfTag>(
522                problem,
523                t_eval,
524                params,
525                *self,
526                linear_solver,
527            ),
528            OdeSolverType::Esdirk34 => solve_adjoint_fwd_with_tag::<M, CG, LS, Esdirk34Tag>(
529                problem,
530                t_eval,
531                params,
532                *self,
533                linear_solver,
534            ),
535            OdeSolverType::TrBdf2 => solve_adjoint_fwd_with_tag::<M, CG, LS, TrBdf2Tag>(
536                problem,
537                t_eval,
538                params,
539                *self,
540                linear_solver,
541            ),
542            OdeSolverType::Tsit45 => solve_adjoint_fwd_with_tag::<M, CG, LS, Tsit45Tag>(
543                problem,
544                t_eval,
545                params,
546                *self,
547                linear_solver,
548            ),
549        }
550    }
551
552    pub(crate) fn solve_continuous_adjoint<M, CG, LS>(
553        &self,
554        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
555        final_time: M::T,
556    ) -> Result<(M::V, Vec<M::V>), DiffsolError>
557    where
558        M: Matrix<T: Scalar> + DefaultSolver + 'static,
559        CG: CodegenModule + 'static,
560        M::V: VectorHost + DefaultDenseMatrix,
561        LS: LinearSolver<M>,
562        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
563            + diffsol::OdeEquationsImplicitAdjoint,
564        for<'b> &'b M::V: VectorRef<M::V>,
565        for<'b> &'b M: MatrixRef<M>,
566    {
567        Self::check_sens_available()?;
568        match self {
569            OdeSolverType::Bdf => {
570                solve_continuous_adjoint_with_tag::<M, CG, LS, BdfTag>(problem, final_time, *self)
571            }
572            OdeSolverType::Esdirk34 => solve_continuous_adjoint_with_tag::<M, CG, LS, Esdirk34Tag>(
573                problem, final_time, *self,
574            ),
575            OdeSolverType::TrBdf2 => solve_continuous_adjoint_with_tag::<M, CG, LS, TrBdf2Tag>(
576                problem, final_time, *self,
577            ),
578            OdeSolverType::Tsit45 => solve_continuous_adjoint_with_tag::<M, CG, LS, Tsit45Tag>(
579                problem, final_time, *self,
580            ),
581        }
582    }
583
584    pub(crate) fn solve_adjoint_bkwd<M, CG, BwdLS>(
585        &self,
586        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
587        checkpoint: &dyn AdjointCheckpoint,
588        dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
589        t_eval: &[M::T],
590    ) -> Result<Vec<M::V>, DiffsolError>
591    where
592        M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M> + 'static,
593        CG: CodegenModule + 'static,
594        M::V: VectorHost + DefaultDenseMatrix,
595        BwdLS: LinearSolver<M>,
596        DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
597            + diffsol::OdeEquationsImplicitAdjoint,
598        for<'b> &'b M::V: VectorRef<M::V>,
599        for<'b> &'b M: MatrixRef<M>,
600    {
601        Self::check_sens_available()?;
602        match checkpoint.method() {
603            OdeSolverType::Bdf => {
604                let data = checkpoint.data::<M, CG, BdfTag>()?;
605                match data.linear_solver() {
606                    LinearSolverType::Default => {
607                        solve_adjoint_bkwds_with_fwd_tag::<
608                            M,
609                            CG,
610                            <M as DefaultSolver>::LS,
611                            BwdLS,
612                            BdfTag,
613                        >(problem, data, *self, dgdu_eval, t_eval)
614                    }
615                    LinearSolverType::Lu => {
616                        solve_adjoint_bkwds_with_fwd_tag::<
617                            M,
618                            CG,
619                            <M as LuValidator<M>>::LS,
620                            BwdLS,
621                            BdfTag,
622                        >(problem, data, *self, dgdu_eval, t_eval)
623                    }
624                    LinearSolverType::Klu => {
625                        solve_adjoint_bkwds_with_fwd_tag::<
626                            M,
627                            CG,
628                            <M as KluValidator<M>>::LS,
629                            BwdLS,
630                            BdfTag,
631                        >(problem, data, *self, dgdu_eval, t_eval)
632                    }
633                }
634            }
635            OdeSolverType::Esdirk34 => {
636                let data = checkpoint.data::<M, CG, Esdirk34Tag>()?;
637                match data.linear_solver() {
638                    LinearSolverType::Default => {
639                        solve_adjoint_bkwds_with_fwd_tag::<
640                            M,
641                            CG,
642                            <M as DefaultSolver>::LS,
643                            BwdLS,
644                            Esdirk34Tag,
645                        >(problem, data, *self, dgdu_eval, t_eval)
646                    }
647                    LinearSolverType::Lu => {
648                        solve_adjoint_bkwds_with_fwd_tag::<
649                            M,
650                            CG,
651                            <M as LuValidator<M>>::LS,
652                            BwdLS,
653                            Esdirk34Tag,
654                        >(problem, data, *self, dgdu_eval, t_eval)
655                    }
656                    LinearSolverType::Klu => {
657                        solve_adjoint_bkwds_with_fwd_tag::<
658                            M,
659                            CG,
660                            <M as KluValidator<M>>::LS,
661                            BwdLS,
662                            Esdirk34Tag,
663                        >(problem, data, *self, dgdu_eval, t_eval)
664                    }
665                }
666            }
667            OdeSolverType::TrBdf2 => {
668                let data = checkpoint.data::<M, CG, TrBdf2Tag>()?;
669                match data.linear_solver() {
670                    LinearSolverType::Default => {
671                        solve_adjoint_bkwds_with_fwd_tag::<
672                            M,
673                            CG,
674                            <M as DefaultSolver>::LS,
675                            BwdLS,
676                            TrBdf2Tag,
677                        >(problem, data, *self, dgdu_eval, t_eval)
678                    }
679                    LinearSolverType::Lu => {
680                        solve_adjoint_bkwds_with_fwd_tag::<
681                            M,
682                            CG,
683                            <M as LuValidator<M>>::LS,
684                            BwdLS,
685                            TrBdf2Tag,
686                        >(problem, data, *self, dgdu_eval, t_eval)
687                    }
688                    LinearSolverType::Klu => {
689                        solve_adjoint_bkwds_with_fwd_tag::<
690                            M,
691                            CG,
692                            <M as KluValidator<M>>::LS,
693                            BwdLS,
694                            TrBdf2Tag,
695                        >(problem, data, *self, dgdu_eval, t_eval)
696                    }
697                }
698            }
699            OdeSolverType::Tsit45 => {
700                let data = checkpoint.data::<M, CG, Tsit45Tag>()?;
701                match data.linear_solver() {
702                    LinearSolverType::Default => {
703                        solve_adjoint_bkwds_with_fwd_tag::<
704                            M,
705                            CG,
706                            <M as DefaultSolver>::LS,
707                            BwdLS,
708                            Tsit45Tag,
709                        >(problem, data, *self, dgdu_eval, t_eval)
710                    }
711                    LinearSolverType::Lu => {
712                        solve_adjoint_bkwds_with_fwd_tag::<
713                            M,
714                            CG,
715                            <M as LuValidator<M>>::LS,
716                            BwdLS,
717                            Tsit45Tag,
718                        >(problem, data, *self, dgdu_eval, t_eval)
719                    }
720                    LinearSolverType::Klu => {
721                        solve_adjoint_bkwds_with_fwd_tag::<
722                            M,
723                            CG,
724                            <M as KluValidator<M>>::LS,
725                            BwdLS,
726                            Tsit45Tag,
727                        >(problem, data, *self, dgdu_eval, t_eval)
728                    }
729                }
730            }
731        }
732    }
733}
734
735#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
736mod tests {
737    use diffsol::{
738        CodegenModuleCompile, CodegenModuleJit, DefaultDenseMatrix, DefaultSolver, DenseMatrix,
739        Matrix, MatrixCommon, OdeBuilder, OdeSolverProblem, Op, Vector,
740    };
741
742    #[cfg(feature = "diffsl-llvm")]
743    use crate::linear_solver_type::LinearSolverType;
744    use crate::test_support::{
745        assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
746        logistic_state, LOGISTIC_X0,
747    };
748    #[cfg(feature = "diffsl-llvm")]
749    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
750    use crate::valid_linear_solver::LuValidator;
751
752    use super::OdeSolverType;
753
754    type M = diffsol::NalgebraMat<f64>;
755
756    fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
757    where
758        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
759    {
760        OdeBuilder::<M>::new()
761            .p([2.0])
762            .rtol(1e-6)
763            .atol([1e-6])
764            .build_from_diffsl::<CG>(code)
765            .unwrap()
766    }
767
768    fn assert_dense_solution_matches_expected(
769        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
770        t_eval: &[f64],
771        expected: impl Fn(f64) -> f64,
772    ) {
773        assert_eq!(soln.ts, t_eval);
774        for (i, &t) in t_eval.iter().enumerate() {
775            assert_close(
776                soln.ys.get_index(0, i),
777                expected(t),
778                5e-4,
779                &format!("solution[{i}]"),
780            );
781        }
782    }
783
784    fn test_all_solver_variants<CG>()
785    where
786        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
787    {
788        let t_eval = [0.25, 0.5, 1.0];
789        for method in [
790            OdeSolverType::Bdf,
791            OdeSolverType::Esdirk34,
792            OdeSolverType::TrBdf2,
793            OdeSolverType::Tsit45,
794        ] {
795            let mut problem = build_problem::<CG>(logistic_diffsl_code());
796            let soln = method
797                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
798                .unwrap();
799            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "solve final time");
800            assert_close(
801                soln.ys.get_index(0, soln.ts.len() - 1),
802                logistic_state(LOGISTIC_X0, 2.0, 1.0),
803                5e-4,
804                "solve final value",
805            );
806
807            let mut problem = build_problem::<CG>(logistic_diffsl_code());
808            let soln = method
809                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
810                .unwrap();
811            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
812                logistic_state(LOGISTIC_X0, 2.0, t)
813            });
814        }
815    }
816
817    fn test_all_hybrid_solver_variants<CG>()
818    where
819        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
820    {
821        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
822        for method in [
823            OdeSolverType::Bdf,
824            OdeSolverType::Esdirk34,
825            OdeSolverType::TrBdf2,
826            OdeSolverType::Tsit45,
827        ] {
828            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
829            let soln = method
830                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
831                .unwrap();
832            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "hybrid final time");
833            assert_close(
834                soln.ys.get_index(0, soln.ts.len() - 1),
835                hybrid_logistic_state(2.0, 2.0),
836                5e-4,
837                "hybrid final value",
838            );
839
840            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
841            let soln = method
842                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
843                .unwrap();
844            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
845                hybrid_logistic_state(2.0, t)
846            });
847        }
848    }
849
850    fn test_all_solver_variants_with_lu<CG>()
851    where
852        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
853    {
854        let t_eval = [0.25, 0.5, 1.0];
855        for method in [
856            OdeSolverType::Bdf,
857            OdeSolverType::Esdirk34,
858            OdeSolverType::TrBdf2,
859            OdeSolverType::Tsit45,
860        ] {
861            let mut problem = build_problem::<CG>(logistic_diffsl_code());
862            let soln = method
863                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
864                .unwrap();
865            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "lu solve final time");
866
867            let mut problem = build_problem::<CG>(logistic_diffsl_code());
868            let soln = method
869                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
870                .unwrap();
871            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
872                logistic_state(LOGISTIC_X0, 2.0, t)
873            });
874        }
875    }
876
877    fn test_all_hybrid_solver_variants_with_lu<CG>()
878    where
879        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
880    {
881        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
882        for method in [
883            OdeSolverType::Bdf,
884            OdeSolverType::Esdirk34,
885            OdeSolverType::TrBdf2,
886            OdeSolverType::Tsit45,
887        ] {
888            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
889            let soln = method
890                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
891                .unwrap();
892            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "lu hybrid final time");
893
894            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
895            let soln = method
896                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
897                .unwrap();
898            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
899                hybrid_logistic_state(2.0, t)
900            });
901        }
902    }
903
904    fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
905    where
906        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
907    {
908        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
909
910        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
911        let soln = method
912            .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
913            .unwrap();
914        assert_close(
915            *soln.ts.last().unwrap(),
916            2.0,
917            5e-4,
918            "direct hybrid restart final time",
919        );
920        assert_close(
921            soln.ys.get_index(0, soln.ts.len() - 1),
922            hybrid_logistic_state(2.0, 2.0),
923            5e-4,
924            "direct hybrid restart final value",
925        );
926
927        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
928        let soln = method
929            .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
930            .unwrap();
931        assert_dense_solution_matches_expected(&soln, &t_eval, |t| hybrid_logistic_state(2.0, t));
932    }
933
934    #[cfg(feature = "diffsl-llvm")]
935    fn test_all_sensitivity_solver_variants() {
936        let t_eval = [0.25, 0.5, 1.0];
937        for method in [
938            OdeSolverType::Bdf,
939            OdeSolverType::Esdirk34,
940            OdeSolverType::TrBdf2,
941            OdeSolverType::Tsit45,
942        ] {
943            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
944            let soln = method
945                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
946                    &mut problem,
947                    &t_eval,
948                )
949                .unwrap();
950            for (i, &t) in t_eval.iter().enumerate() {
951                assert_close(
952                    soln.y_sens[0].get_index(0, i),
953                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
954                    5e-4,
955                    &format!("fwd_sens[{i}]"),
956                );
957            }
958
959            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
960            let soln = method
961                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
962                    &mut problem,
963                    &t_eval,
964                )
965                .unwrap();
966            for (i, &t) in t_eval.iter().enumerate() {
967                assert_close(
968                    soln.y_sens[0].get_index(0, i),
969                    hybrid_logistic_state_dr(2.0, t),
970                    5e-4,
971                    &format!("hybrid_fwd_sens[{i}]"),
972                );
973            }
974        }
975    }
976
977    #[cfg(feature = "diffsl-llvm")]
978    fn test_lu_sensitivity_and_adjoint_solver_variants() {
979        let t_eval = [0.25, 0.5, 1.0];
980        for method in [
981            OdeSolverType::Bdf,
982            OdeSolverType::Esdirk34,
983            OdeSolverType::TrBdf2,
984            OdeSolverType::Tsit45,
985        ] {
986            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
987            let soln = method
988                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
989                    &mut problem,
990                    &t_eval,
991                )
992                .unwrap();
993            for (i, &t) in t_eval.iter().enumerate() {
994                assert_close(
995                    soln.y_sens[0].get_index(0, i),
996                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
997                    5e-4,
998                    &format!("lu fwd_sens[{i}]"),
999                );
1000            }
1001        }
1002
1003        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1004        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
1005        let (soln, checkpoint) = OdeSolverType::Bdf
1006            .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1007                &mut problem,
1008                &adjoint_t_eval,
1009                &[2.0],
1010                LinearSolverType::Lu,
1011            )
1012            .unwrap();
1013        let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1014            problem.eqn.nout(),
1015            soln.ts.len(),
1016            problem.context().to_owned(),
1017        );
1018        let gradient = OdeSolverType::TrBdf2
1019            .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1020                &mut problem,
1021                checkpoint.as_ref(),
1022                &dgdu,
1023                &soln.ts,
1024            )
1025            .unwrap();
1026        assert_eq!(gradient.len(), 1);
1027        assert!(gradient[0].get_index(0).is_finite());
1028    }
1029
1030    #[cfg(feature = "diffsl-llvm")]
1031    fn test_direct_hybrid_sensitivity_restart_paths() {
1032        let t_eval = [0.5, 1.0, 2.5, 3.0, 4.5];
1033        for method in [
1034            OdeSolverType::Esdirk34,
1035            OdeSolverType::TrBdf2,
1036            OdeSolverType::Tsit45,
1037        ] {
1038            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
1039            let soln = method
1040                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1041                    &mut problem,
1042                    &t_eval,
1043                )
1044                .unwrap();
1045            for (i, &t) in t_eval.iter().enumerate() {
1046                assert_close(
1047                    soln.ys.get_index(0, i),
1048                    hybrid_logistic_state(2.0, t),
1049                    5e-4,
1050                    &format!("direct hybrid value[{i}]"),
1051                );
1052                assert_close(
1053                    soln.y_sens[0].get_index(0, i),
1054                    hybrid_logistic_state_dr(2.0, t),
1055                    5e-4,
1056                    &format!("direct hybrid fwd sens[{i}]"),
1057                );
1058            }
1059        }
1060    }
1061
1062    #[cfg(feature = "diffsl-llvm")]
1063    fn test_adjoint_backwards_methods_and_klu_branch() {
1064        for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
1065            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1066            let t_eval = [0.0, 0.25, 0.5, 1.0];
1067            let (soln, checkpoint) = OdeSolverType::Bdf
1068                .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1069                    &mut problem,
1070                    &t_eval,
1071                    &[2.0],
1072                    LinearSolverType::Default,
1073                )
1074                .unwrap();
1075            let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1076                problem.eqn.nout(),
1077                soln.ts.len(),
1078                problem.context().to_owned(),
1079            );
1080            let gradient = backwards_method
1081                .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as crate::valid_linear_solver::KluValidator<M>>::LS>(
1082                    &mut problem,
1083                    checkpoint.as_ref(),
1084                    &dgdu,
1085                    &soln.ts,
1086                )
1087                .unwrap();
1088            assert_eq!(gradient.len(), 1);
1089            assert!(gradient[0].get_index(0).is_finite());
1090        }
1091
1092        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1093        let t_eval = [0.0, 0.25, 0.5, 1.0];
1094        let (soln, checkpoint) = OdeSolverType::Tsit45
1095            .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1096                &mut problem,
1097                &t_eval,
1098                &[2.0],
1099                LinearSolverType::Default,
1100            )
1101            .unwrap();
1102        let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1103            problem.eqn.nout(),
1104            soln.ts.len(),
1105            problem.context().to_owned(),
1106        );
1107        let gradient = OdeSolverType::Bdf
1108            .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1109                &mut problem,
1110                checkpoint.as_ref(),
1111                &dgdu,
1112                &soln.ts,
1113            )
1114            .unwrap();
1115        assert_eq!(gradient.len(), 1);
1116        assert!(gradient[0].get_index(0).is_finite());
1117    }
1118
1119    #[cfg(feature = "diffsl-llvm")]
1120    fn test_all_adjoint_solver_variants() {
1121        let t_eval = [0.0, 0.25, 0.5, 1.0];
1122        for method in [
1123            OdeSolverType::Bdf,
1124            OdeSolverType::Esdirk34,
1125            OdeSolverType::TrBdf2,
1126        ] {
1127            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1128            let (soln, checkpoint) = method
1129                .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1130                    &mut problem,
1131                    &t_eval,
1132                    &[2.0],
1133                    LinearSolverType::Default,
1134                )
1135                .unwrap();
1136            let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1137                problem.eqn.nout(),
1138                soln.ts.len(),
1139                problem.context().to_owned(),
1140            );
1141            let gradient = OdeSolverType::Bdf
1142                .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1143                    &mut problem,
1144                    checkpoint.as_ref(),
1145                    &dgdu,
1146                    &soln.ts,
1147                )
1148                .unwrap();
1149            assert_eq!(gradient.len(), 1);
1150            assert!(gradient[0].get_index(0).is_finite());
1151        }
1152    }
1153
1154    #[cfg(feature = "diffsl-cranelift")]
1155    #[test]
1156    fn runtime_dispatch_solves_all_variants_for_cranelift() {
1157        test_all_solver_variants::<diffsol::CraneliftJitModule>();
1158        test_all_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1159    }
1160
1161    #[cfg(feature = "diffsl-cranelift")]
1162    #[test]
1163    fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
1164        test_all_hybrid_solver_variants::<diffsol::CraneliftJitModule>();
1165        test_all_hybrid_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1166        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1167            OdeSolverType::Esdirk34,
1168        );
1169        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1170            OdeSolverType::TrBdf2,
1171        );
1172        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1173            OdeSolverType::Tsit45,
1174        );
1175    }
1176
1177    #[cfg(feature = "diffsl-llvm")]
1178    #[test]
1179    fn runtime_dispatch_solves_all_variants_for_llvm() {
1180        test_all_solver_variants::<diffsol::LlvmModule>();
1181        test_all_solver_variants_with_lu::<diffsol::LlvmModule>();
1182    }
1183
1184    #[cfg(feature = "diffsl-llvm")]
1185    #[test]
1186    fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
1187        test_all_hybrid_solver_variants::<diffsol::LlvmModule>();
1188        test_all_hybrid_solver_variants_with_lu::<diffsol::LlvmModule>();
1189        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(
1190            OdeSolverType::Esdirk34,
1191        );
1192        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::TrBdf2);
1193        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::Tsit45);
1194    }
1195
1196    #[cfg(feature = "diffsl-llvm")]
1197    #[test]
1198    fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
1199        test_all_sensitivity_solver_variants();
1200        test_lu_sensitivity_and_adjoint_solver_variants();
1201        test_direct_hybrid_sensitivity_restart_paths();
1202    }
1203
1204    #[cfg(feature = "diffsl-llvm")]
1205    #[test]
1206    fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
1207        test_all_adjoint_solver_variants();
1208        test_adjoint_backwards_methods_and_klu_branch();
1209    }
1210}