1use diffsol::error::{DiffsolError, OdeSolverError};
5use diffsol::ode_equations::{OdeEquationsImplicit, OdeEquationsImplicitSens};
6use diffsol::{
7 matrix::MatrixRef, DefaultDenseMatrix, DenseMatrix, DiffSl, LinearSolver, Matrix,
8 OdeSolverProblem, OdeSolverState, VectorHost, VectorRef, VectorView,
9};
10use diffsol::{
11 ode_solver_error, AdjointOdeSolverMethod, CheckpointingPath, CodegenModule, DefaultSolver,
12 OdeEquations, OdeSolverMethod, OdeSolverStopReason, SensitivitiesOdeSolverMethod, Solution,
13};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16
17use crate::adjoint_checkpoint::{AdjointCheckpoint, AdjointCheckpointData};
18use crate::ode_solver_tag::{BdfTag, Esdirk34Tag, OdeSolverMethodTag, TrBdf2Tag, Tsit45Tag};
19use crate::scalar_type::Scalar;
20use crate::utils::is_sens_available;
21use crate::{
22 linear_solver_type::LinearSolverType,
23 valid_linear_solver::{KluValidator, LuValidator},
24};
25
// ODE integration method selectable at runtime.
//
// Each variant is dispatched to the matching solver tag (`BdfTag`, `Esdirk34Tag`,
// `TrBdf2Tag`, `Tsit45Tag`) by the `solve*` methods on this type. Serde uses
// `snake_case` names, e.g. `"bdf"`, `"esdirk34"`, `"tr_bdf2"`, `"tsit45"`.
//
// Plain `//` comments are used on purpose: schemars copies `///` doc comments into the
// generated JSON schema `description`, which would change the emitted schema.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OdeSolverType {
    // Dispatched via `BdfTag`.
    Bdf,
    // Dispatched via `Esdirk34Tag`.
    Esdirk34,
    // Dispatched via `TrBdf2Tag`.
    TrBdf2,
    // Dispatched via `Tsit45Tag`.
    Tsit45,
}
40
41fn solve_with_tag<M, CG, LS, Tag>(
42 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
43 mut soln: Solution<M::V>,
44) -> Result<Solution<M::V>, DiffsolError>
45where
46 M: Matrix<T: Scalar> + DefaultSolver,
47 CG: CodegenModule,
48 M::V: VectorHost + DefaultDenseMatrix,
49 LS: LinearSolver<M>,
50 Tag: OdeSolverMethodTag<M, CG>,
51 DiffSl<M, CG>: OdeEquationsImplicit<M = M, T = M::T, V = M::V, C = M::C>,
52 for<'b> &'b M::V: VectorRef<M::V>,
53 for<'b> &'b M: MatrixRef<M>,
54{
55 let mut solver = Tag::solver::<LS>(problem)?;
56 while !soln.is_complete() {
57 solver = solver.solve_soln(&mut soln)?;
58 let root_idx = match soln.stop_reason {
59 Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
60 _ => continue,
61 };
62 if problem.eqn.reset().is_none() {
63 soln.truncate(problem, solver.state())?;
64 return Ok(soln);
65 }
66 let mut state = solver.into_state();
67 problem.eqn.set_model_index(root_idx);
68 state.as_mut().apply_reset_with_mass::<M::LS, _>(problem)?;
69 solver = Tag::solver_with_state::<LS>(problem, state)?;
70 }
71 Ok(soln)
72}
73
74fn solve_fwd_sens_with_tag<M, CG, LS, Tag>(
75 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
76 t_eval: &[M::T],
77) -> Result<Solution<M::V>, DiffsolError>
78where
79 M: Matrix<T: Scalar> + DefaultSolver,
80 CG: CodegenModule,
81 M::V: VectorHost + DefaultDenseMatrix,
82 LS: LinearSolver<M>,
83 Tag: OdeSolverMethodTag<M, CG>,
84 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
85 for<'b> &'b M::V: VectorRef<M::V>,
86 for<'b> &'b M: MatrixRef<M>,
87{
88 let mut soln = Solution::new_dense(t_eval.to_vec())?;
89 let mut solver = Tag::solver_sens::<LS>(problem)?;
90 while !soln.is_complete() {
91 solver = solver.solve_soln_sensitivities(&mut soln)?;
92 let root_idx = match soln.stop_reason {
93 Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
94 _ => continue,
95 };
96 if problem.eqn.reset().is_none() {
97 soln.truncate_sens(problem, solver.state())?;
98 return Ok(soln);
99 }
100 let mut state = solver.into_state();
101 problem.eqn.set_model_index(root_idx);
102 state
103 .as_mut()
104 .apply_reset_with_sens_mass::<M::LS, _>(problem, root_idx)?;
105 solver = Tag::solver_sens_with_state::<LS>(problem, state)?;
106 }
107 Ok(soln)
108}
109
110#[allow(clippy::type_complexity)]
111fn solve_with_checkpointing_with_tag<M, CG, LS, Tag>(
112 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
113 mut soln: Solution<M::V>,
114) -> Result<(Solution<M::V>, CheckpointingPath<DiffSl<M, CG>, Tag::State>), DiffsolError>
115where
116 M: Matrix<T: Scalar> + DefaultSolver,
117 CG: CodegenModule,
118 M::V: VectorHost + DefaultDenseMatrix,
119 LS: LinearSolver<M>,
120 Tag: OdeSolverMethodTag<M, CG>,
121 DiffSl<M, CG>: OdeEquationsImplicit<M = M, T = M::T, V = M::V, C = M::C>,
122 for<'b> &'b M::V: VectorRef<M::V>,
123 for<'b> &'b M: MatrixRef<M>,
124{
125 let mut solver = Tag::solver::<LS>(problem)?;
126 let mut checkpointing = Vec::new();
127 while !soln.is_complete() {
128 solver = solver.solve_soln_with_checkpointing(&mut soln, &mut checkpointing, None)?;
129 let root_idx = match soln.stop_reason {
130 Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
131 _ => continue,
132 };
133 if problem.eqn.reset().is_none() {
134 soln.truncate(problem, solver.state())?;
135 return Ok((soln, checkpointing));
136 }
137 let mut state = solver.into_state();
138 problem.eqn.set_model_index(root_idx);
139 state.as_mut().apply_reset_with_mass::<M::LS, _>(problem)?;
140 solver = Tag::solver_with_state::<LS>(problem, state)?;
141 }
142 Ok((soln, checkpointing))
143}
144
145fn integral_from_soln<V>(soln: &Solution<V>) -> Result<V, DiffsolError>
146where
147 V: DefaultDenseMatrix,
148{
149 if soln.ts.is_empty() {
150 return Err(ode_solver_error!(
151 Other,
152 "Continuous adjoint solve returned no integral samples"
153 ));
154 }
155 Ok(soln.ys.column(soln.ts.len() - 1).into_owned())
156}
157
158#[allow(clippy::type_complexity)]
159fn solve_adjoint_fwd_with_tag<M, CG, LS, Tag>(
160 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
161 t_eval: &[M::T],
162 params: &[f64],
163 method: OdeSolverType,
164 linear_solver: LinearSolverType,
165) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
166where
167 M: Matrix<T: Scalar> + DefaultSolver + 'static,
168 CG: CodegenModule + 'static,
169 M::V: VectorHost + DefaultDenseMatrix,
170 LS: LinearSolver<M>,
171 DiffSl<M, CG>: OdeEquationsImplicit<M = M, T = M::T, V = M::V, C = M::C>,
172 Tag: OdeSolverMethodTag<M, CG> + 'static,
173 for<'b> &'b M::V: VectorRef<M::V>,
174 for<'b> &'b M: MatrixRef<M>,
175{
176 let soln = Solution::new_dense(t_eval.to_vec())?;
177 let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
178 Ok((
179 soln,
180 Box::new(AdjointCheckpointData::<M, CG, Tag>::new(
181 checkpointing,
182 params.to_vec(),
183 method,
184 linear_solver,
185 )),
186 ))
187}
188
189fn solve_continuous_adjoint_with_tag<M, CG, LS, Tag>(
190 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
191 final_time: M::T,
192 method: OdeSolverType,
193) -> Result<(M::V, Vec<M::V>), DiffsolError>
194where
195 M: Matrix<T: Scalar> + DefaultSolver + 'static,
196 CG: CodegenModule + 'static,
197 M::V: VectorHost + DefaultDenseMatrix,
198 LS: LinearSolver<M>,
199 Tag: OdeSolverMethodTag<M, CG> + 'static,
200 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
201 + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
202 for<'b> &'b M::V: VectorRef<M::V>,
203 for<'b> &'b M: MatrixRef<M>,
204{
205 let soln = Solution::new(final_time);
206 let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
207 let integral = integral_from_soln(&soln)?;
208 let sg = match method {
209 OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<M, CG, LS, LS, Tag, BdfTag>(
210 problem,
211 &soln,
212 checkpointing,
213 &[],
214 None,
215 ),
216 OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
217 M,
218 CG,
219 LS,
220 LS,
221 Tag,
222 Esdirk34Tag,
223 >(problem, &soln, checkpointing, &[], None),
224 OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
225 M,
226 CG,
227 LS,
228 LS,
229 Tag,
230 TrBdf2Tag,
231 >(problem, &soln, checkpointing, &[], None),
232 OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
233 M,
234 CG,
235 LS,
236 LS,
237 Tag,
238 Tsit45Tag,
239 >(problem, &soln, checkpointing, &[], None),
240 }?;
241 Ok((integral, sg))
242}
243
244fn solve_adjoint_bkwds_with_fwd_tag<M, CG, FwdLS, BwdLS, Tag>(
245 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
246 checkpoint: &AdjointCheckpointData<M, CG, Tag>,
247 backwards_method: OdeSolverType,
248 dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
249 t_eval: &[M::T],
250) -> Result<Vec<M::V>, DiffsolError>
251where
252 M: Matrix<T: Scalar> + DefaultSolver + 'static,
253 CG: CodegenModule + 'static,
254 M::V: VectorHost + DefaultDenseMatrix,
255 FwdLS: LinearSolver<M>,
256 BwdLS: LinearSolver<M>,
257 Tag: OdeSolverMethodTag<M, CG> + 'static,
258 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
259 + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
260 for<'b> &'b M::V: VectorRef<M::V>,
261 for<'b> &'b M: MatrixRef<M>,
262{
263 let checkpointing = checkpoint.checkpointing.clone();
266 let soln = Solution::new_dense(t_eval.to_vec())?;
267
268 let dgdu_eval = [dgdu_eval];
270 match backwards_method {
271 OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
272 M,
273 CG,
274 FwdLS,
275 BwdLS,
276 Tag,
277 BdfTag,
278 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
279 OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
280 M,
281 CG,
282 FwdLS,
283 BwdLS,
284 Tag,
285 Esdirk34Tag,
286 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
287 OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
288 M,
289 CG,
290 FwdLS,
291 BwdLS,
292 Tag,
293 TrBdf2Tag,
294 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
295 OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
296 M,
297 CG,
298 FwdLS,
299 BwdLS,
300 Tag,
301 Tsit45Tag,
302 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
303 }
304}
305
306fn solve_adjoint_bkwds_with_fwd_bkwd_tag<'solver, M, CG, FwdLS, BwdLS, FwdTag, BwdTag>(
307 problem: &'solver mut OdeSolverProblem<DiffSl<M, CG>>,
308 soln: &Solution<M::V>,
309 mut checkpointing: CheckpointingPath<DiffSl<M, CG>, FwdTag::State>,
310 dgdu_eval: &[&<M::V as DefaultDenseMatrix>::M],
311 nout_override: Option<usize>,
312) -> Result<Vec<M::V>, DiffsolError>
313where
314 M: Matrix<T: Scalar> + DefaultSolver + 'solver,
315 CG: CodegenModule + 'solver,
316 M::V: VectorHost + DefaultDenseMatrix,
317 FwdLS: LinearSolver<M>,
318 BwdLS: LinearSolver<M>,
319 FwdTag: OdeSolverMethodTag<M, CG>,
320 BwdTag: OdeSolverMethodTag<M, CG>,
321 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
322 + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
323 for<'b> &'b M::V: VectorRef<M::V>,
324 for<'b> &'b M: MatrixRef<M>,
325{
326 let checkpointing_len = checkpointing.len();
327 if checkpointing_len == 0 {
328 return Err(ode_solver_error!(
329 Other,
330 "Adjoint backward pass requires at least one checkpointing segment"
331 ));
332 }
333
334 let t_eval = if dgdu_eval.is_empty() {
335 &[]
336 } else {
337 soln.ts.as_slice()
338 };
339
340 let current_checkpointing = checkpointing
341 .pop()
342 .ok_or_else(|| ode_solver_error!(Other, "Adjoint backward pass returned no state"))?;
343 let model_index = checkpointing
344 .last()
345 .map(|segment| {
346 segment
347 .terminal_reset_root_idx()
348 .expect("Missing reset root index")
349 })
350 .unwrap_or(0);
351 problem.eqn_mut().set_model_index(model_index);
352 let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
353 let mut adjoint = BwdTag::solver_adjoint::<BwdLS, _>(
354 &*problem,
355 vec![current_checkpointing],
356 Some(fwd_solver),
357 nout_override,
358 )?;
359 loop {
360 let (mut state, adjoint_checkpointing) =
361 adjoint.solve_adjoint_backwards_pass(t_eval, dgdu_eval)?;
362 let Some(previous_checkpointing) = checkpointing.pop() else {
363 return Ok(state.into_common().sg);
364 };
365 let model_index = checkpointing
366 .last()
367 .map(|segment| {
368 segment
369 .terminal_reset_root_idx()
370 .expect("Missing reset root index")
371 })
372 .unwrap_or(0);
373 let fwd_state_minus = previous_checkpointing.last_checkpoint();
374 let fwd_state_plus = adjoint_checkpointing
375 .first()
376 .ok_or_else(|| {
377 ode_solver_error!(Other, "Adjoint backward pass returned no checkpointing")
378 })?
379 .first_checkpoint();
380 state.as_mut().apply_reset_with_adjoint(
381 problem.eqn(),
382 previous_checkpointing.terminal_reset_root_idx().unwrap(),
383 fwd_state_minus.as_ref(),
384 fwd_state_plus.as_ref(),
385 problem.integrate_out,
386 )?;
387 problem.eqn_mut().set_model_index(model_index);
388 let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
389 let adjoint_eqn = problem.adjoint_equations(
391 vec![previous_checkpointing],
392 Some(fwd_solver),
393 nout_override,
394 );
395
396 adjoint = BwdTag::solver_adjoint_from_state::<BwdLS, _>(&*problem, state, adjoint_eqn)?;
397 }
398}
399
400impl OdeSolverType {
401 pub(crate) fn solve<M, CG, LS>(
402 &self,
403 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
404 final_time: M::T,
405 ) -> Result<Solution<M::V>, DiffsolError>
406 where
407 M: Matrix<T: Scalar> + DefaultSolver,
408 CG: CodegenModule,
409 M::V: VectorHost + DefaultDenseMatrix,
410 LS: LinearSolver<M>,
411 for<'b> &'b M::V: VectorRef<M::V>,
412 for<'b> &'b M: MatrixRef<M>,
413 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
414 {
415 match self {
416 OdeSolverType::Bdf => {
417 solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new(final_time))
418 }
419 OdeSolverType::Esdirk34 => {
420 solve_with_tag::<M, CG, LS, Esdirk34Tag>(problem, Solution::new(final_time))
421 }
422 OdeSolverType::TrBdf2 => {
423 solve_with_tag::<M, CG, LS, TrBdf2Tag>(problem, Solution::new(final_time))
424 }
425 OdeSolverType::Tsit45 => {
426 solve_with_tag::<M, CG, LS, Tsit45Tag>(problem, Solution::new(final_time))
427 }
428 }
429 }
430
431 pub(crate) fn solve_dense<M, CG, LS>(
432 &self,
433 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
434 t_eval: &[M::T],
435 ) -> Result<Solution<M::V>, DiffsolError>
436 where
437 M: Matrix<T: Scalar> + DefaultSolver,
438 CG: CodegenModule,
439 M::V: VectorHost + DefaultDenseMatrix,
440 LS: LinearSolver<M>,
441 for<'b> &'b M::V: VectorRef<M::V>,
442 for<'b> &'b M: MatrixRef<M>,
443 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
444 {
445 match self {
446 OdeSolverType::Bdf => {
447 solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new_dense(t_eval.to_vec())?)
448 }
449 OdeSolverType::Esdirk34 => solve_with_tag::<M, CG, LS, Esdirk34Tag>(
450 problem,
451 Solution::new_dense(t_eval.to_vec())?,
452 ),
453 OdeSolverType::TrBdf2 => solve_with_tag::<M, CG, LS, TrBdf2Tag>(
454 problem,
455 Solution::new_dense(t_eval.to_vec())?,
456 ),
457 OdeSolverType::Tsit45 => solve_with_tag::<M, CG, LS, Tsit45Tag>(
458 problem,
459 Solution::new_dense(t_eval.to_vec())?,
460 ),
461 }
462 }
463
464 fn check_sens_available() -> Result<(), DiffsolError> {
465 if !is_sens_available() {
466 return Err(DiffsolError::Other(
467 "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
468 ));
469 }
470 Ok(())
471 }
472
473 #[allow(clippy::type_complexity)]
474 pub(crate) fn solve_fwd_sens<M, CG, LS>(
475 &self,
476 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
477 t_eval: &[M::T],
478 ) -> Result<Solution<M::V>, DiffsolError>
479 where
480 M: Matrix<T: Scalar> + DefaultSolver,
481 CG: CodegenModule,
482 M::V: VectorHost + DefaultDenseMatrix,
483 LS: LinearSolver<M>,
484 for<'b> &'b M::V: VectorRef<M::V>,
485 for<'b> &'b M: MatrixRef<M>,
486 {
487 Self::check_sens_available()?;
488 match self {
489 OdeSolverType::Bdf => solve_fwd_sens_with_tag::<M, CG, LS, BdfTag>(problem, t_eval),
490 OdeSolverType::Esdirk34 => {
491 solve_fwd_sens_with_tag::<M, CG, LS, Esdirk34Tag>(problem, t_eval)
492 }
493 OdeSolverType::TrBdf2 => {
494 solve_fwd_sens_with_tag::<M, CG, LS, TrBdf2Tag>(problem, t_eval)
495 }
496 OdeSolverType::Tsit45 => {
497 solve_fwd_sens_with_tag::<M, CG, LS, Tsit45Tag>(problem, t_eval)
498 }
499 }
500 }
501
502 #[allow(clippy::type_complexity)]
503 pub(crate) fn solve_adjoint_fwd<M, CG, LS>(
504 &self,
505 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
506 t_eval: &[M::T],
507 params: &[f64],
508 linear_solver: LinearSolverType,
509 ) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
510 where
511 M: Matrix<T: Scalar> + DefaultSolver + 'static,
512 CG: CodegenModule + 'static,
513 M::V: VectorHost + DefaultDenseMatrix,
514 LS: LinearSolver<M>,
515 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
516 for<'b> &'b M::V: VectorRef<M::V>,
517 for<'b> &'b M: MatrixRef<M>,
518 {
519 Self::check_sens_available()?;
520 match self {
521 OdeSolverType::Bdf => solve_adjoint_fwd_with_tag::<M, CG, LS, BdfTag>(
522 problem,
523 t_eval,
524 params,
525 *self,
526 linear_solver,
527 ),
528 OdeSolverType::Esdirk34 => solve_adjoint_fwd_with_tag::<M, CG, LS, Esdirk34Tag>(
529 problem,
530 t_eval,
531 params,
532 *self,
533 linear_solver,
534 ),
535 OdeSolverType::TrBdf2 => solve_adjoint_fwd_with_tag::<M, CG, LS, TrBdf2Tag>(
536 problem,
537 t_eval,
538 params,
539 *self,
540 linear_solver,
541 ),
542 OdeSolverType::Tsit45 => solve_adjoint_fwd_with_tag::<M, CG, LS, Tsit45Tag>(
543 problem,
544 t_eval,
545 params,
546 *self,
547 linear_solver,
548 ),
549 }
550 }
551
552 pub(crate) fn solve_continuous_adjoint<M, CG, LS>(
553 &self,
554 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
555 final_time: M::T,
556 ) -> Result<(M::V, Vec<M::V>), DiffsolError>
557 where
558 M: Matrix<T: Scalar> + DefaultSolver + 'static,
559 CG: CodegenModule + 'static,
560 M::V: VectorHost + DefaultDenseMatrix,
561 LS: LinearSolver<M>,
562 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
563 + diffsol::OdeEquationsImplicitAdjoint,
564 for<'b> &'b M::V: VectorRef<M::V>,
565 for<'b> &'b M: MatrixRef<M>,
566 {
567 Self::check_sens_available()?;
568 match self {
569 OdeSolverType::Bdf => {
570 solve_continuous_adjoint_with_tag::<M, CG, LS, BdfTag>(problem, final_time, *self)
571 }
572 OdeSolverType::Esdirk34 => solve_continuous_adjoint_with_tag::<M, CG, LS, Esdirk34Tag>(
573 problem, final_time, *self,
574 ),
575 OdeSolverType::TrBdf2 => solve_continuous_adjoint_with_tag::<M, CG, LS, TrBdf2Tag>(
576 problem, final_time, *self,
577 ),
578 OdeSolverType::Tsit45 => solve_continuous_adjoint_with_tag::<M, CG, LS, Tsit45Tag>(
579 problem, final_time, *self,
580 ),
581 }
582 }
583
584 pub(crate) fn solve_adjoint_bkwd<M, CG, BwdLS>(
585 &self,
586 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
587 checkpoint: &dyn AdjointCheckpoint,
588 dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
589 t_eval: &[M::T],
590 ) -> Result<Vec<M::V>, DiffsolError>
591 where
592 M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M> + 'static,
593 CG: CodegenModule + 'static,
594 M::V: VectorHost + DefaultDenseMatrix,
595 BwdLS: LinearSolver<M>,
596 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
597 + diffsol::OdeEquationsImplicitAdjoint,
598 for<'b> &'b M::V: VectorRef<M::V>,
599 for<'b> &'b M: MatrixRef<M>,
600 {
601 Self::check_sens_available()?;
602 match checkpoint.method() {
603 OdeSolverType::Bdf => {
604 let data = checkpoint.data::<M, CG, BdfTag>()?;
605 match data.linear_solver() {
606 LinearSolverType::Default => {
607 solve_adjoint_bkwds_with_fwd_tag::<
608 M,
609 CG,
610 <M as DefaultSolver>::LS,
611 BwdLS,
612 BdfTag,
613 >(problem, data, *self, dgdu_eval, t_eval)
614 }
615 LinearSolverType::Lu => {
616 solve_adjoint_bkwds_with_fwd_tag::<
617 M,
618 CG,
619 <M as LuValidator<M>>::LS,
620 BwdLS,
621 BdfTag,
622 >(problem, data, *self, dgdu_eval, t_eval)
623 }
624 LinearSolverType::Klu => {
625 solve_adjoint_bkwds_with_fwd_tag::<
626 M,
627 CG,
628 <M as KluValidator<M>>::LS,
629 BwdLS,
630 BdfTag,
631 >(problem, data, *self, dgdu_eval, t_eval)
632 }
633 }
634 }
635 OdeSolverType::Esdirk34 => {
636 let data = checkpoint.data::<M, CG, Esdirk34Tag>()?;
637 match data.linear_solver() {
638 LinearSolverType::Default => {
639 solve_adjoint_bkwds_with_fwd_tag::<
640 M,
641 CG,
642 <M as DefaultSolver>::LS,
643 BwdLS,
644 Esdirk34Tag,
645 >(problem, data, *self, dgdu_eval, t_eval)
646 }
647 LinearSolverType::Lu => {
648 solve_adjoint_bkwds_with_fwd_tag::<
649 M,
650 CG,
651 <M as LuValidator<M>>::LS,
652 BwdLS,
653 Esdirk34Tag,
654 >(problem, data, *self, dgdu_eval, t_eval)
655 }
656 LinearSolverType::Klu => {
657 solve_adjoint_bkwds_with_fwd_tag::<
658 M,
659 CG,
660 <M as KluValidator<M>>::LS,
661 BwdLS,
662 Esdirk34Tag,
663 >(problem, data, *self, dgdu_eval, t_eval)
664 }
665 }
666 }
667 OdeSolverType::TrBdf2 => {
668 let data = checkpoint.data::<M, CG, TrBdf2Tag>()?;
669 match data.linear_solver() {
670 LinearSolverType::Default => {
671 solve_adjoint_bkwds_with_fwd_tag::<
672 M,
673 CG,
674 <M as DefaultSolver>::LS,
675 BwdLS,
676 TrBdf2Tag,
677 >(problem, data, *self, dgdu_eval, t_eval)
678 }
679 LinearSolverType::Lu => {
680 solve_adjoint_bkwds_with_fwd_tag::<
681 M,
682 CG,
683 <M as LuValidator<M>>::LS,
684 BwdLS,
685 TrBdf2Tag,
686 >(problem, data, *self, dgdu_eval, t_eval)
687 }
688 LinearSolverType::Klu => {
689 solve_adjoint_bkwds_with_fwd_tag::<
690 M,
691 CG,
692 <M as KluValidator<M>>::LS,
693 BwdLS,
694 TrBdf2Tag,
695 >(problem, data, *self, dgdu_eval, t_eval)
696 }
697 }
698 }
699 OdeSolverType::Tsit45 => {
700 let data = checkpoint.data::<M, CG, Tsit45Tag>()?;
701 match data.linear_solver() {
702 LinearSolverType::Default => {
703 solve_adjoint_bkwds_with_fwd_tag::<
704 M,
705 CG,
706 <M as DefaultSolver>::LS,
707 BwdLS,
708 Tsit45Tag,
709 >(problem, data, *self, dgdu_eval, t_eval)
710 }
711 LinearSolverType::Lu => {
712 solve_adjoint_bkwds_with_fwd_tag::<
713 M,
714 CG,
715 <M as LuValidator<M>>::LS,
716 BwdLS,
717 Tsit45Tag,
718 >(problem, data, *self, dgdu_eval, t_eval)
719 }
720 LinearSolverType::Klu => {
721 solve_adjoint_bkwds_with_fwd_tag::<
722 M,
723 CG,
724 <M as KluValidator<M>>::LS,
725 BwdLS,
726 Tsit45Tag,
727 >(problem, data, *self, dgdu_eval, t_eval)
728 }
729 }
730 }
731 }
732 }
733}
734
735#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
736mod tests {
737 use diffsol::{
738 CodegenModuleCompile, CodegenModuleJit, DefaultDenseMatrix, DefaultSolver, DenseMatrix,
739 Matrix, MatrixCommon, OdeBuilder, OdeSolverProblem, Op, Vector,
740 };
741
742 #[cfg(feature = "diffsl-llvm")]
743 use crate::linear_solver_type::LinearSolverType;
744 use crate::test_support::{
745 assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
746 logistic_state, LOGISTIC_X0,
747 };
748 #[cfg(feature = "diffsl-llvm")]
749 use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
750 use crate::valid_linear_solver::LuValidator;
751
752 use super::OdeSolverType;
753
754 type M = diffsol::NalgebraMat<f64>;
755
756 fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
757 where
758 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
759 {
760 OdeBuilder::<M>::new()
761 .p([2.0])
762 .rtol(1e-6)
763 .atol([1e-6])
764 .build_from_diffsl::<CG>(code)
765 .unwrap()
766 }
767
768 fn assert_dense_solution_matches_expected(
769 soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
770 t_eval: &[f64],
771 expected: impl Fn(f64) -> f64,
772 ) {
773 assert_eq!(soln.ts, t_eval);
774 for (i, &t) in t_eval.iter().enumerate() {
775 assert_close(
776 soln.ys.get_index(0, i),
777 expected(t),
778 5e-4,
779 &format!("solution[{i}]"),
780 );
781 }
782 }
783
784 fn test_all_solver_variants<CG>()
785 where
786 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
787 {
788 let t_eval = [0.25, 0.5, 1.0];
789 for method in [
790 OdeSolverType::Bdf,
791 OdeSolverType::Esdirk34,
792 OdeSolverType::TrBdf2,
793 OdeSolverType::Tsit45,
794 ] {
795 let mut problem = build_problem::<CG>(logistic_diffsl_code());
796 let soln = method
797 .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
798 .unwrap();
799 assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "solve final time");
800 assert_close(
801 soln.ys.get_index(0, soln.ts.len() - 1),
802 logistic_state(LOGISTIC_X0, 2.0, 1.0),
803 5e-4,
804 "solve final value",
805 );
806
807 let mut problem = build_problem::<CG>(logistic_diffsl_code());
808 let soln = method
809 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
810 .unwrap();
811 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
812 logistic_state(LOGISTIC_X0, 2.0, t)
813 });
814 }
815 }
816
817 fn test_all_hybrid_solver_variants<CG>()
818 where
819 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
820 {
821 let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
822 for method in [
823 OdeSolverType::Bdf,
824 OdeSolverType::Esdirk34,
825 OdeSolverType::TrBdf2,
826 OdeSolverType::Tsit45,
827 ] {
828 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
829 let soln = method
830 .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
831 .unwrap();
832 assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "hybrid final time");
833 assert_close(
834 soln.ys.get_index(0, soln.ts.len() - 1),
835 hybrid_logistic_state(2.0, 2.0),
836 5e-4,
837 "hybrid final value",
838 );
839
840 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
841 let soln = method
842 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
843 .unwrap();
844 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
845 hybrid_logistic_state(2.0, t)
846 });
847 }
848 }
849
850 fn test_all_solver_variants_with_lu<CG>()
851 where
852 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
853 {
854 let t_eval = [0.25, 0.5, 1.0];
855 for method in [
856 OdeSolverType::Bdf,
857 OdeSolverType::Esdirk34,
858 OdeSolverType::TrBdf2,
859 OdeSolverType::Tsit45,
860 ] {
861 let mut problem = build_problem::<CG>(logistic_diffsl_code());
862 let soln = method
863 .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
864 .unwrap();
865 assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "lu solve final time");
866
867 let mut problem = build_problem::<CG>(logistic_diffsl_code());
868 let soln = method
869 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
870 .unwrap();
871 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
872 logistic_state(LOGISTIC_X0, 2.0, t)
873 });
874 }
875 }
876
877 fn test_all_hybrid_solver_variants_with_lu<CG>()
878 where
879 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
880 {
881 let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
882 for method in [
883 OdeSolverType::Bdf,
884 OdeSolverType::Esdirk34,
885 OdeSolverType::TrBdf2,
886 OdeSolverType::Tsit45,
887 ] {
888 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
889 let soln = method
890 .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
891 .unwrap();
892 assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "lu hybrid final time");
893
894 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
895 let soln = method
896 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
897 .unwrap();
898 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
899 hybrid_logistic_state(2.0, t)
900 });
901 }
902 }
903
904 fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
905 where
906 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
907 {
908 let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
909
910 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
911 let soln = method
912 .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
913 .unwrap();
914 assert_close(
915 *soln.ts.last().unwrap(),
916 2.0,
917 5e-4,
918 "direct hybrid restart final time",
919 );
920 assert_close(
921 soln.ys.get_index(0, soln.ts.len() - 1),
922 hybrid_logistic_state(2.0, 2.0),
923 5e-4,
924 "direct hybrid restart final value",
925 );
926
927 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
928 let soln = method
929 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
930 .unwrap();
931 assert_dense_solution_matches_expected(&soln, &t_eval, |t| hybrid_logistic_state(2.0, t));
932 }
933
934 #[cfg(feature = "diffsl-llvm")]
935 fn test_all_sensitivity_solver_variants() {
936 let t_eval = [0.25, 0.5, 1.0];
937 for method in [
938 OdeSolverType::Bdf,
939 OdeSolverType::Esdirk34,
940 OdeSolverType::TrBdf2,
941 OdeSolverType::Tsit45,
942 ] {
943 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
944 let soln = method
945 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
946 &mut problem,
947 &t_eval,
948 )
949 .unwrap();
950 for (i, &t) in t_eval.iter().enumerate() {
951 assert_close(
952 soln.y_sens[0].get_index(0, i),
953 logistic_state_dr(LOGISTIC_X0, 2.0, t),
954 5e-4,
955 &format!("fwd_sens[{i}]"),
956 );
957 }
958
959 let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
960 let soln = method
961 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
962 &mut problem,
963 &t_eval,
964 )
965 .unwrap();
966 for (i, &t) in t_eval.iter().enumerate() {
967 assert_close(
968 soln.y_sens[0].get_index(0, i),
969 hybrid_logistic_state_dr(2.0, t),
970 5e-4,
971 &format!("hybrid_fwd_sens[{i}]"),
972 );
973 }
974 }
975 }
976
977 #[cfg(feature = "diffsl-llvm")]
978 fn test_lu_sensitivity_and_adjoint_solver_variants() {
979 let t_eval = [0.25, 0.5, 1.0];
980 for method in [
981 OdeSolverType::Bdf,
982 OdeSolverType::Esdirk34,
983 OdeSolverType::TrBdf2,
984 OdeSolverType::Tsit45,
985 ] {
986 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
987 let soln = method
988 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
989 &mut problem,
990 &t_eval,
991 )
992 .unwrap();
993 for (i, &t) in t_eval.iter().enumerate() {
994 assert_close(
995 soln.y_sens[0].get_index(0, i),
996 logistic_state_dr(LOGISTIC_X0, 2.0, t),
997 5e-4,
998 &format!("lu fwd_sens[{i}]"),
999 );
1000 }
1001 }
1002
1003 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1004 let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
1005 let (soln, checkpoint) = OdeSolverType::Bdf
1006 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1007 &mut problem,
1008 &adjoint_t_eval,
1009 &[2.0],
1010 LinearSolverType::Lu,
1011 )
1012 .unwrap();
1013 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1014 problem.eqn.nout(),
1015 soln.ts.len(),
1016 problem.context().to_owned(),
1017 );
1018 let gradient = OdeSolverType::TrBdf2
1019 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1020 &mut problem,
1021 checkpoint.as_ref(),
1022 &dgdu,
1023 &soln.ts,
1024 )
1025 .unwrap();
1026 assert_eq!(gradient.len(), 1);
1027 assert!(gradient[0].get_index(0).is_finite());
1028 }
1029
1030 #[cfg(feature = "diffsl-llvm")]
1031 fn test_direct_hybrid_sensitivity_restart_paths() {
1032 let t_eval = [0.5, 1.0, 2.5, 3.0, 4.5];
1033 for method in [
1034 OdeSolverType::Esdirk34,
1035 OdeSolverType::TrBdf2,
1036 OdeSolverType::Tsit45,
1037 ] {
1038 let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
1039 let soln = method
1040 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1041 &mut problem,
1042 &t_eval,
1043 )
1044 .unwrap();
1045 for (i, &t) in t_eval.iter().enumerate() {
1046 assert_close(
1047 soln.ys.get_index(0, i),
1048 hybrid_logistic_state(2.0, t),
1049 5e-4,
1050 &format!("direct hybrid value[{i}]"),
1051 );
1052 assert_close(
1053 soln.y_sens[0].get_index(0, i),
1054 hybrid_logistic_state_dr(2.0, t),
1055 5e-4,
1056 &format!("direct hybrid fwd sens[{i}]"),
1057 );
1058 }
1059 }
1060 }
1061
1062 #[cfg(feature = "diffsl-llvm")]
1063 fn test_adjoint_backwards_methods_and_klu_branch() {
1064 for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
1065 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1066 let t_eval = [0.0, 0.25, 0.5, 1.0];
1067 let (soln, checkpoint) = OdeSolverType::Bdf
1068 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1069 &mut problem,
1070 &t_eval,
1071 &[2.0],
1072 LinearSolverType::Default,
1073 )
1074 .unwrap();
1075 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1076 problem.eqn.nout(),
1077 soln.ts.len(),
1078 problem.context().to_owned(),
1079 );
1080 let gradient = backwards_method
1081 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as crate::valid_linear_solver::KluValidator<M>>::LS>(
1082 &mut problem,
1083 checkpoint.as_ref(),
1084 &dgdu,
1085 &soln.ts,
1086 )
1087 .unwrap();
1088 assert_eq!(gradient.len(), 1);
1089 assert!(gradient[0].get_index(0).is_finite());
1090 }
1091
1092 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1093 let t_eval = [0.0, 0.25, 0.5, 1.0];
1094 let (soln, checkpoint) = OdeSolverType::Tsit45
1095 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1096 &mut problem,
1097 &t_eval,
1098 &[2.0],
1099 LinearSolverType::Default,
1100 )
1101 .unwrap();
1102 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1103 problem.eqn.nout(),
1104 soln.ts.len(),
1105 problem.context().to_owned(),
1106 );
1107 let gradient = OdeSolverType::Bdf
1108 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1109 &mut problem,
1110 checkpoint.as_ref(),
1111 &dgdu,
1112 &soln.ts,
1113 )
1114 .unwrap();
1115 assert_eq!(gradient.len(), 1);
1116 assert!(gradient[0].get_index(0).is_finite());
1117 }
1118
1119 #[cfg(feature = "diffsl-llvm")]
1120 fn test_all_adjoint_solver_variants() {
1121 let t_eval = [0.0, 0.25, 0.5, 1.0];
1122 for method in [
1123 OdeSolverType::Bdf,
1124 OdeSolverType::Esdirk34,
1125 OdeSolverType::TrBdf2,
1126 ] {
1127 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1128 let (soln, checkpoint) = method
1129 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1130 &mut problem,
1131 &t_eval,
1132 &[2.0],
1133 LinearSolverType::Default,
1134 )
1135 .unwrap();
1136 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1137 problem.eqn.nout(),
1138 soln.ts.len(),
1139 problem.context().to_owned(),
1140 );
1141 let gradient = OdeSolverType::Bdf
1142 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1143 &mut problem,
1144 checkpoint.as_ref(),
1145 &dgdu,
1146 &soln.ts,
1147 )
1148 .unwrap();
1149 assert_eq!(gradient.len(), 1);
1150 assert!(gradient[0].get_index(0).is_finite());
1151 }
1152 }
1153
1154 #[cfg(feature = "diffsl-cranelift")]
1155 #[test]
1156 fn runtime_dispatch_solves_all_variants_for_cranelift() {
1157 test_all_solver_variants::<diffsol::CraneliftJitModule>();
1158 test_all_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1159 }
1160
1161 #[cfg(feature = "diffsl-cranelift")]
1162 #[test]
1163 fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
1164 test_all_hybrid_solver_variants::<diffsol::CraneliftJitModule>();
1165 test_all_hybrid_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1166 assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1167 OdeSolverType::Esdirk34,
1168 );
1169 assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1170 OdeSolverType::TrBdf2,
1171 );
1172 assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1173 OdeSolverType::Tsit45,
1174 );
1175 }
1176
1177 #[cfg(feature = "diffsl-llvm")]
1178 #[test]
1179 fn runtime_dispatch_solves_all_variants_for_llvm() {
1180 test_all_solver_variants::<diffsol::LlvmModule>();
1181 test_all_solver_variants_with_lu::<diffsol::LlvmModule>();
1182 }
1183
1184 #[cfg(feature = "diffsl-llvm")]
1185 #[test]
1186 fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
1187 test_all_hybrid_solver_variants::<diffsol::LlvmModule>();
1188 test_all_hybrid_solver_variants_with_lu::<diffsol::LlvmModule>();
1189 assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(
1190 OdeSolverType::Esdirk34,
1191 );
1192 assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::TrBdf2);
1193 assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::Tsit45);
1194 }
1195
1196 #[cfg(feature = "diffsl-llvm")]
1197 #[test]
1198 fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
1199 test_all_sensitivity_solver_variants();
1200 test_lu_sensitivity_and_adjoint_solver_variants();
1201 test_direct_hybrid_sensitivity_restart_paths();
1202 }
1203
1204 #[cfg(feature = "diffsl-llvm")]
1205 #[test]
1206 fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
1207 test_all_adjoint_solver_variants();
1208 test_adjoint_backwards_methods_and_klu_branch();
1209 }
1210}