1use diffsol::error::{DiffsolError, OdeSolverError};
5use diffsol::ode_equations::OdeEquationsImplicitSens;
6use diffsol::{
7 matrix::MatrixRef, DefaultDenseMatrix, DenseMatrix, DiffSl, LinearSolver, Matrix,
8 OdeSolverProblem, OdeSolverState, VectorHost, VectorRef, VectorView,
9};
10use diffsol::{
11 ode_solver_error, AdjointOdeSolverMethod, CheckpointingPath, CodegenModule, DefaultSolver,
12 OdeEquations, OdeSolverMethod, OdeSolverStopReason, SensitivitiesOdeSolverMethod, Solution,
13};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16
17use crate::adjoint_checkpoint::{AdjointCheckpoint, AdjointCheckpointData};
18use crate::ode_solver_tag::{BdfTag, Esdirk34Tag, OdeSolverMethodTag, TrBdf2Tag, Tsit45Tag};
19use crate::scalar_type::Scalar;
20use crate::utils::is_sens_available;
21use crate::{
22 linear_solver_type::LinearSolverType,
23 valid_linear_solver::{KluValidator, LuValidator},
24};
25
// ODE integration method selectable by callers.
//
// Each variant dispatches (see `impl OdeSolverType`) to the solver tag of the same name.
// Serialized by serde, and described by the JsonSchema derive, using snake_case
// variant names.
//
// NOTE: these are plain `//` comments on purpose. schemars turns `///` doc comments
// into schema `description`s, so adding them would change the generated schema.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OdeSolverType {
    // Dispatches to `BdfTag`.
    Bdf,
    // Dispatches to `Esdirk34Tag`.
    Esdirk34,
    // Dispatches to `TrBdf2Tag`.
    TrBdf2,
    // Dispatches to `Tsit45Tag`.
    Tsit45,
}
40
41fn solve_with_tag<M, CG, LS, Tag>(
42 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
43 mut soln: Solution<M::V>,
44) -> Result<Solution<M::V>, DiffsolError>
45where
46 M: Matrix<T: Scalar>,
47 CG: CodegenModule,
48 M::V: VectorHost + DefaultDenseMatrix,
49 LS: LinearSolver<M>,
50 Tag: OdeSolverMethodTag<M, CG>,
51 for<'b> &'b M::V: VectorRef<M::V>,
52 for<'b> &'b M: MatrixRef<M>,
53{
54 let mut solver = Tag::solver::<LS>(problem)?;
55 while !soln.is_complete() {
56 solver = solver.solve_soln(&mut soln)?;
57 let root_idx = match soln.stop_reason {
58 Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
59 _ => continue,
60 };
61 if problem.eqn.reset().is_none() {
62 soln.truncate(problem, solver.state())?;
63 return Ok(soln);
64 }
65 let mut state = solver.into_state();
66 problem.eqn.set_model_index(root_idx);
67 state.as_mut().apply_reset(&problem.eqn)?;
68 solver = Tag::solver_with_state::<LS>(problem, state)?;
69 }
70 Ok(soln)
71}
72
73fn solve_fwd_sens_with_tag<M, CG, LS, Tag>(
74 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
75 t_eval: &[M::T],
76) -> Result<Solution<M::V>, DiffsolError>
77where
78 M: Matrix<T: Scalar>,
79 CG: CodegenModule,
80 M::V: VectorHost + DefaultDenseMatrix,
81 LS: LinearSolver<M>,
82 Tag: OdeSolverMethodTag<M, CG>,
83 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
84 for<'b> &'b M::V: VectorRef<M::V>,
85 for<'b> &'b M: MatrixRef<M>,
86{
87 let mut soln = Solution::new_dense(t_eval.to_vec())?;
88 let mut solver = Tag::solver_sens::<LS>(problem)?;
89 while !soln.is_complete() {
90 solver = solver.solve_soln_sensitivities(&mut soln)?;
91 let root_idx = match soln.stop_reason {
92 Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
93 _ => continue,
94 };
95 if problem.eqn.reset().is_none() {
96 soln.truncate_sens(problem, solver.state())?;
97 return Ok(soln);
98 }
99 let mut state = solver.into_state();
100 problem.eqn.set_model_index(root_idx);
101 state
102 .as_mut()
103 .apply_reset_with_sens(&problem.eqn, root_idx)?;
104 solver = Tag::solver_sens_with_state::<LS>(problem, state)?;
105 }
106 Ok(soln)
107}
108
109#[allow(clippy::type_complexity)]
110fn solve_with_checkpointing_with_tag<M, CG, LS, Tag>(
111 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
112 mut soln: Solution<M::V>,
113) -> Result<(Solution<M::V>, CheckpointingPath<DiffSl<M, CG>, Tag::State>), DiffsolError>
114where
115 M: Matrix<T: Scalar>,
116 CG: CodegenModule,
117 M::V: VectorHost + DefaultDenseMatrix,
118 LS: LinearSolver<M>,
119 Tag: OdeSolverMethodTag<M, CG>,
120 for<'b> &'b M::V: VectorRef<M::V>,
121 for<'b> &'b M: MatrixRef<M>,
122{
123 let mut solver = Tag::solver::<LS>(problem)?;
124 let mut checkpointing = Vec::new();
125 while !soln.is_complete() {
126 solver = solver.solve_soln_with_checkpointing(&mut soln, &mut checkpointing, None)?;
127 let root_idx = match soln.stop_reason {
128 Some(OdeSolverStopReason::RootFound(_, root_idx)) if !soln.is_complete() => root_idx,
129 _ => continue,
130 };
131 if problem.eqn.reset().is_none() {
132 soln.truncate(problem, solver.state())?;
133 return Ok((soln, checkpointing));
134 }
135 let mut state = solver.into_state();
136 problem.eqn.set_model_index(root_idx);
137 state.as_mut().apply_reset(&problem.eqn)?;
138 solver = Tag::solver_with_state::<LS>(problem, state)?;
139 }
140 Ok((soln, checkpointing))
141}
142
143fn integral_from_soln<V>(soln: &Solution<V>) -> Result<V, DiffsolError>
144where
145 V: DefaultDenseMatrix,
146{
147 if soln.ts.is_empty() {
148 return Err(ode_solver_error!(
149 Other,
150 "Continuous adjoint solve returned no integral samples"
151 ));
152 }
153 Ok(soln.ys.column(soln.ts.len() - 1).into_owned())
154}
155
156#[allow(clippy::type_complexity)]
157fn solve_adjoint_fwd_with_tag<M, CG, LS, Tag>(
158 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
159 t_eval: &[M::T],
160 params: &[f64],
161 method: OdeSolverType,
162 linear_solver: LinearSolverType,
163) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
164where
165 M: Matrix<T: Scalar> + 'static,
166 CG: CodegenModule + 'static,
167 M::V: VectorHost + DefaultDenseMatrix,
168 LS: LinearSolver<M>,
169 DiffSl<M, CG>: OdeEquations<M = M, T = M::T, V = M::V, C = M::C>,
170 Tag: OdeSolverMethodTag<M, CG> + 'static,
171 for<'b> &'b M::V: VectorRef<M::V>,
172 for<'b> &'b M: MatrixRef<M>,
173{
174 let soln = Solution::new_dense(t_eval.to_vec())?;
175 let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
176 Ok((
177 soln,
178 Box::new(AdjointCheckpointData::<M, CG, Tag>::new(
179 checkpointing,
180 params.to_vec(),
181 method,
182 linear_solver,
183 )),
184 ))
185}
186
187fn solve_continuous_adjoint_with_tag<M, CG, LS, Tag>(
188 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
189 final_time: M::T,
190 method: OdeSolverType,
191) -> Result<(M::V, Vec<M::V>), DiffsolError>
192where
193 M: Matrix<T: Scalar> + DefaultSolver + 'static,
194 CG: CodegenModule + 'static,
195 M::V: VectorHost + DefaultDenseMatrix,
196 LS: LinearSolver<M>,
197 Tag: OdeSolverMethodTag<M, CG> + 'static,
198 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
199 + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
200 for<'b> &'b M::V: VectorRef<M::V>,
201 for<'b> &'b M: MatrixRef<M>,
202{
203 let soln = Solution::new(final_time);
204 let (soln, checkpointing) = solve_with_checkpointing_with_tag::<M, CG, LS, Tag>(problem, soln)?;
205 let integral = integral_from_soln(&soln)?;
206 let sg = match method {
207 OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<M, CG, LS, LS, Tag, BdfTag>(
208 problem,
209 &soln,
210 checkpointing,
211 &[],
212 None,
213 ),
214 OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
215 M,
216 CG,
217 LS,
218 LS,
219 Tag,
220 Esdirk34Tag,
221 >(problem, &soln, checkpointing, &[], None),
222 OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
223 M,
224 CG,
225 LS,
226 LS,
227 Tag,
228 TrBdf2Tag,
229 >(problem, &soln, checkpointing, &[], None),
230 OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
231 M,
232 CG,
233 LS,
234 LS,
235 Tag,
236 Tsit45Tag,
237 >(problem, &soln, checkpointing, &[], None),
238 }?;
239 Ok((integral, sg))
240}
241
242fn solve_adjoint_bkwds_with_fwd_tag<M, CG, FwdLS, BwdLS, Tag>(
243 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
244 checkpoint: &AdjointCheckpointData<M, CG, Tag>,
245 backwards_method: OdeSolverType,
246 dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
247 t_eval: &[M::T],
248) -> Result<Vec<M::V>, DiffsolError>
249where
250 M: Matrix<T: Scalar> + DefaultSolver + 'static,
251 CG: CodegenModule + 'static,
252 M::V: VectorHost + DefaultDenseMatrix,
253 FwdLS: LinearSolver<M>,
254 BwdLS: LinearSolver<M>,
255 Tag: OdeSolverMethodTag<M, CG> + 'static,
256 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
257 + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
258 for<'b> &'b M::V: VectorRef<M::V>,
259 for<'b> &'b M: MatrixRef<M>,
260{
261 let checkpointing = checkpoint.checkpointing.clone();
264 let soln = Solution::new_dense(t_eval.to_vec())?;
265
266 let dgdu_eval = [dgdu_eval];
268 match backwards_method {
269 OdeSolverType::Bdf => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
270 M,
271 CG,
272 FwdLS,
273 BwdLS,
274 Tag,
275 BdfTag,
276 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
277 OdeSolverType::Esdirk34 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
278 M,
279 CG,
280 FwdLS,
281 BwdLS,
282 Tag,
283 Esdirk34Tag,
284 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
285 OdeSolverType::TrBdf2 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
286 M,
287 CG,
288 FwdLS,
289 BwdLS,
290 Tag,
291 TrBdf2Tag,
292 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
293 OdeSolverType::Tsit45 => solve_adjoint_bkwds_with_fwd_bkwd_tag::<
294 M,
295 CG,
296 FwdLS,
297 BwdLS,
298 Tag,
299 Tsit45Tag,
300 >(problem, &soln, checkpointing, &dgdu_eval, Some(1)),
301 }
302}
303
304fn solve_adjoint_bkwds_with_fwd_bkwd_tag<'solver, M, CG, FwdLS, BwdLS, FwdTag, BwdTag>(
305 problem: &'solver mut OdeSolverProblem<DiffSl<M, CG>>,
306 soln: &Solution<M::V>,
307 mut checkpointing: CheckpointingPath<DiffSl<M, CG>, FwdTag::State>,
308 dgdu_eval: &[&<M::V as DefaultDenseMatrix>::M],
309 nout_override: Option<usize>,
310) -> Result<Vec<M::V>, DiffsolError>
311where
312 M: Matrix<T: Scalar> + DefaultSolver + 'solver,
313 CG: CodegenModule + 'solver,
314 M::V: VectorHost + DefaultDenseMatrix,
315 FwdLS: LinearSolver<M>,
316 BwdLS: LinearSolver<M>,
317 FwdTag: OdeSolverMethodTag<M, CG>,
318 BwdTag: OdeSolverMethodTag<M, CG>,
319 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
320 + diffsol::OdeEquationsImplicitAdjoint<M = M, T = M::T, V = M::V, C = M::C>,
321 for<'b> &'b M::V: VectorRef<M::V>,
322 for<'b> &'b M: MatrixRef<M>,
323{
324 let checkpointing_len = checkpointing.len();
325 if checkpointing_len == 0 {
326 return Err(ode_solver_error!(
327 Other,
328 "Adjoint backward pass requires at least one checkpointing segment"
329 ));
330 }
331
332 let t_eval = if dgdu_eval.is_empty() {
333 &[]
334 } else {
335 soln.ts.as_slice()
336 };
337
338 let current_checkpointing = checkpointing
339 .pop()
340 .ok_or_else(|| ode_solver_error!(Other, "Adjoint backward pass returned no state"))?;
341 let model_index = checkpointing
342 .last()
343 .map(|segment| {
344 segment
345 .terminal_reset_root_idx()
346 .expect("Missing reset root index")
347 })
348 .unwrap_or(0);
349 problem.eqn_mut().set_model_index(model_index);
350 let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
351 let mut adjoint = BwdTag::solver_adjoint::<BwdLS, _>(
352 &*problem,
353 vec![current_checkpointing],
354 Some(fwd_solver),
355 nout_override,
356 )?;
357 loop {
358 let (mut state, adjoint_checkpointing) =
359 adjoint.solve_adjoint_backwards_pass(t_eval, dgdu_eval)?;
360 let Some(previous_checkpointing) = checkpointing.pop() else {
361 return Ok(state.into_common().sg);
362 };
363 let model_index = checkpointing
364 .last()
365 .map(|segment| {
366 segment
367 .terminal_reset_root_idx()
368 .expect("Missing reset root index")
369 })
370 .unwrap_or(0);
371 let fwd_state_minus = previous_checkpointing.last_checkpoint();
372 let fwd_state_plus = adjoint_checkpointing
373 .first()
374 .ok_or_else(|| {
375 ode_solver_error!(Other, "Adjoint backward pass returned no checkpointing")
376 })?
377 .first_checkpoint();
378 state.as_mut().apply_reset_with_adjoint(
379 problem.eqn(),
380 previous_checkpointing.terminal_reset_root_idx().unwrap(),
381 fwd_state_minus.as_ref(),
382 fwd_state_plus.as_ref(),
383 problem.integrate_out,
384 )?;
385 problem.eqn_mut().set_model_index(model_index);
386 let fwd_solver = FwdTag::uninitialised_solver::<FwdLS>(&*problem)?;
387 let adjoint_eqn = problem.adjoint_equations(
389 vec![previous_checkpointing],
390 Some(fwd_solver),
391 nout_override,
392 );
393
394 adjoint = BwdTag::solver_adjoint_from_state::<BwdLS, _>(&*problem, state, adjoint_eqn)?;
395 }
396}
397
398impl OdeSolverType {
399 pub(crate) fn solve<M, CG, LS>(
400 &self,
401 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
402 final_time: M::T,
403 ) -> Result<Solution<M::V>, DiffsolError>
404 where
405 M: Matrix<T: Scalar>,
406 CG: CodegenModule,
407 M::V: VectorHost + DefaultDenseMatrix,
408 LS: LinearSolver<M>,
409 for<'b> &'b M::V: VectorRef<M::V>,
410 for<'b> &'b M: MatrixRef<M>,
411 {
412 match self {
413 OdeSolverType::Bdf => {
414 solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new(final_time))
415 }
416 OdeSolverType::Esdirk34 => {
417 solve_with_tag::<M, CG, LS, Esdirk34Tag>(problem, Solution::new(final_time))
418 }
419 OdeSolverType::TrBdf2 => {
420 solve_with_tag::<M, CG, LS, TrBdf2Tag>(problem, Solution::new(final_time))
421 }
422 OdeSolverType::Tsit45 => {
423 solve_with_tag::<M, CG, LS, Tsit45Tag>(problem, Solution::new(final_time))
424 }
425 }
426 }
427
428 pub(crate) fn solve_dense<M, CG, LS>(
429 &self,
430 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
431 t_eval: &[M::T],
432 ) -> Result<Solution<M::V>, DiffsolError>
433 where
434 M: Matrix<T: Scalar>,
435 CG: CodegenModule,
436 M::V: VectorHost + DefaultDenseMatrix,
437 LS: LinearSolver<M>,
438 for<'b> &'b M::V: VectorRef<M::V>,
439 for<'b> &'b M: MatrixRef<M>,
440 {
441 match self {
442 OdeSolverType::Bdf => {
443 solve_with_tag::<M, CG, LS, BdfTag>(problem, Solution::new_dense(t_eval.to_vec())?)
444 }
445 OdeSolverType::Esdirk34 => solve_with_tag::<M, CG, LS, Esdirk34Tag>(
446 problem,
447 Solution::new_dense(t_eval.to_vec())?,
448 ),
449 OdeSolverType::TrBdf2 => solve_with_tag::<M, CG, LS, TrBdf2Tag>(
450 problem,
451 Solution::new_dense(t_eval.to_vec())?,
452 ),
453 OdeSolverType::Tsit45 => solve_with_tag::<M, CG, LS, Tsit45Tag>(
454 problem,
455 Solution::new_dense(t_eval.to_vec())?,
456 ),
457 }
458 }
459
460 fn check_sens_available() -> Result<(), DiffsolError> {
461 if !is_sens_available() {
462 return Err(DiffsolError::Other(
463 "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
464 ));
465 }
466 Ok(())
467 }
468
469 #[allow(clippy::type_complexity)]
470 pub(crate) fn solve_fwd_sens<M, CG, LS>(
471 &self,
472 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
473 t_eval: &[M::T],
474 ) -> Result<Solution<M::V>, DiffsolError>
475 where
476 M: Matrix<T: Scalar> + DefaultSolver,
477 CG: CodegenModule,
478 M::V: VectorHost + DefaultDenseMatrix,
479 LS: LinearSolver<M>,
480 for<'b> &'b M::V: VectorRef<M::V>,
481 for<'b> &'b M: MatrixRef<M>,
482 {
483 Self::check_sens_available()?;
484 match self {
485 OdeSolverType::Bdf => solve_fwd_sens_with_tag::<M, CG, LS, BdfTag>(problem, t_eval),
486 OdeSolverType::Esdirk34 => {
487 solve_fwd_sens_with_tag::<M, CG, LS, Esdirk34Tag>(problem, t_eval)
488 }
489 OdeSolverType::TrBdf2 => {
490 solve_fwd_sens_with_tag::<M, CG, LS, TrBdf2Tag>(problem, t_eval)
491 }
492 OdeSolverType::Tsit45 => {
493 solve_fwd_sens_with_tag::<M, CG, LS, Tsit45Tag>(problem, t_eval)
494 }
495 }
496 }
497
498 #[allow(clippy::type_complexity)]
499 pub(crate) fn solve_adjoint_fwd<M, CG, LS>(
500 &self,
501 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
502 t_eval: &[M::T],
503 params: &[f64],
504 linear_solver: LinearSolverType,
505 ) -> Result<(Solution<M::V>, Box<dyn AdjointCheckpoint>), DiffsolError>
506 where
507 M: Matrix<T: Scalar> + DefaultSolver + 'static,
508 CG: CodegenModule + 'static,
509 M::V: VectorHost + DefaultDenseMatrix,
510 LS: LinearSolver<M>,
511 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>,
512 for<'b> &'b M::V: VectorRef<M::V>,
513 for<'b> &'b M: MatrixRef<M>,
514 {
515 Self::check_sens_available()?;
516 match self {
517 OdeSolverType::Bdf => solve_adjoint_fwd_with_tag::<M, CG, LS, BdfTag>(
518 problem,
519 t_eval,
520 params,
521 *self,
522 linear_solver,
523 ),
524 OdeSolverType::Esdirk34 => solve_adjoint_fwd_with_tag::<M, CG, LS, Esdirk34Tag>(
525 problem,
526 t_eval,
527 params,
528 *self,
529 linear_solver,
530 ),
531 OdeSolverType::TrBdf2 => solve_adjoint_fwd_with_tag::<M, CG, LS, TrBdf2Tag>(
532 problem,
533 t_eval,
534 params,
535 *self,
536 linear_solver,
537 ),
538 OdeSolverType::Tsit45 => solve_adjoint_fwd_with_tag::<M, CG, LS, Tsit45Tag>(
539 problem,
540 t_eval,
541 params,
542 *self,
543 linear_solver,
544 ),
545 }
546 }
547
548 pub(crate) fn solve_continuous_adjoint<M, CG, LS>(
549 &self,
550 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
551 final_time: M::T,
552 ) -> Result<(M::V, Vec<M::V>), DiffsolError>
553 where
554 M: Matrix<T: Scalar> + DefaultSolver + 'static,
555 CG: CodegenModule + 'static,
556 M::V: VectorHost + DefaultDenseMatrix,
557 LS: LinearSolver<M>,
558 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
559 + diffsol::OdeEquationsImplicitAdjoint,
560 for<'b> &'b M::V: VectorRef<M::V>,
561 for<'b> &'b M: MatrixRef<M>,
562 {
563 Self::check_sens_available()?;
564 match self {
565 OdeSolverType::Bdf => {
566 solve_continuous_adjoint_with_tag::<M, CG, LS, BdfTag>(problem, final_time, *self)
567 }
568 OdeSolverType::Esdirk34 => solve_continuous_adjoint_with_tag::<M, CG, LS, Esdirk34Tag>(
569 problem, final_time, *self,
570 ),
571 OdeSolverType::TrBdf2 => solve_continuous_adjoint_with_tag::<M, CG, LS, TrBdf2Tag>(
572 problem, final_time, *self,
573 ),
574 OdeSolverType::Tsit45 => solve_continuous_adjoint_with_tag::<M, CG, LS, Tsit45Tag>(
575 problem, final_time, *self,
576 ),
577 }
578 }
579
580 pub(crate) fn solve_adjoint_bkwd<M, CG, BwdLS>(
581 &self,
582 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
583 checkpoint: &dyn AdjointCheckpoint,
584 dgdu_eval: &<M::V as DefaultDenseMatrix>::M,
585 t_eval: &[M::T],
586 ) -> Result<Vec<M::V>, DiffsolError>
587 where
588 M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M> + 'static,
589 CG: CodegenModule + 'static,
590 M::V: VectorHost + DefaultDenseMatrix,
591 BwdLS: LinearSolver<M>,
592 DiffSl<M, CG>: OdeEquationsImplicitSens<M = M, T = M::T, V = M::V, C = M::C>
593 + diffsol::OdeEquationsImplicitAdjoint,
594 for<'b> &'b M::V: VectorRef<M::V>,
595 for<'b> &'b M: MatrixRef<M>,
596 {
597 Self::check_sens_available()?;
598 match checkpoint.method() {
599 OdeSolverType::Bdf => {
600 let data = checkpoint.data::<M, CG, BdfTag>()?;
601 match data.linear_solver() {
602 LinearSolverType::Default => {
603 solve_adjoint_bkwds_with_fwd_tag::<
604 M,
605 CG,
606 <M as DefaultSolver>::LS,
607 BwdLS,
608 BdfTag,
609 >(problem, data, *self, dgdu_eval, t_eval)
610 }
611 LinearSolverType::Lu => {
612 solve_adjoint_bkwds_with_fwd_tag::<
613 M,
614 CG,
615 <M as LuValidator<M>>::LS,
616 BwdLS,
617 BdfTag,
618 >(problem, data, *self, dgdu_eval, t_eval)
619 }
620 LinearSolverType::Klu => {
621 solve_adjoint_bkwds_with_fwd_tag::<
622 M,
623 CG,
624 <M as KluValidator<M>>::LS,
625 BwdLS,
626 BdfTag,
627 >(problem, data, *self, dgdu_eval, t_eval)
628 }
629 }
630 }
631 OdeSolverType::Esdirk34 => {
632 let data = checkpoint.data::<M, CG, Esdirk34Tag>()?;
633 match data.linear_solver() {
634 LinearSolverType::Default => {
635 solve_adjoint_bkwds_with_fwd_tag::<
636 M,
637 CG,
638 <M as DefaultSolver>::LS,
639 BwdLS,
640 Esdirk34Tag,
641 >(problem, data, *self, dgdu_eval, t_eval)
642 }
643 LinearSolverType::Lu => {
644 solve_adjoint_bkwds_with_fwd_tag::<
645 M,
646 CG,
647 <M as LuValidator<M>>::LS,
648 BwdLS,
649 Esdirk34Tag,
650 >(problem, data, *self, dgdu_eval, t_eval)
651 }
652 LinearSolverType::Klu => {
653 solve_adjoint_bkwds_with_fwd_tag::<
654 M,
655 CG,
656 <M as KluValidator<M>>::LS,
657 BwdLS,
658 Esdirk34Tag,
659 >(problem, data, *self, dgdu_eval, t_eval)
660 }
661 }
662 }
663 OdeSolverType::TrBdf2 => {
664 let data = checkpoint.data::<M, CG, TrBdf2Tag>()?;
665 match data.linear_solver() {
666 LinearSolverType::Default => {
667 solve_adjoint_bkwds_with_fwd_tag::<
668 M,
669 CG,
670 <M as DefaultSolver>::LS,
671 BwdLS,
672 TrBdf2Tag,
673 >(problem, data, *self, dgdu_eval, t_eval)
674 }
675 LinearSolverType::Lu => {
676 solve_adjoint_bkwds_with_fwd_tag::<
677 M,
678 CG,
679 <M as LuValidator<M>>::LS,
680 BwdLS,
681 TrBdf2Tag,
682 >(problem, data, *self, dgdu_eval, t_eval)
683 }
684 LinearSolverType::Klu => {
685 solve_adjoint_bkwds_with_fwd_tag::<
686 M,
687 CG,
688 <M as KluValidator<M>>::LS,
689 BwdLS,
690 TrBdf2Tag,
691 >(problem, data, *self, dgdu_eval, t_eval)
692 }
693 }
694 }
695 OdeSolverType::Tsit45 => {
696 let data = checkpoint.data::<M, CG, Tsit45Tag>()?;
697 match data.linear_solver() {
698 LinearSolverType::Default => {
699 solve_adjoint_bkwds_with_fwd_tag::<
700 M,
701 CG,
702 <M as DefaultSolver>::LS,
703 BwdLS,
704 Tsit45Tag,
705 >(problem, data, *self, dgdu_eval, t_eval)
706 }
707 LinearSolverType::Lu => {
708 solve_adjoint_bkwds_with_fwd_tag::<
709 M,
710 CG,
711 <M as LuValidator<M>>::LS,
712 BwdLS,
713 Tsit45Tag,
714 >(problem, data, *self, dgdu_eval, t_eval)
715 }
716 LinearSolverType::Klu => {
717 solve_adjoint_bkwds_with_fwd_tag::<
718 M,
719 CG,
720 <M as KluValidator<M>>::LS,
721 BwdLS,
722 Tsit45Tag,
723 >(problem, data, *self, dgdu_eval, t_eval)
724 }
725 }
726 }
727 }
728 }
729}
730
731#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
732mod tests {
733 use diffsol::{
734 CodegenModuleCompile, CodegenModuleJit, DefaultDenseMatrix, DefaultSolver, DenseMatrix,
735 Matrix, MatrixCommon, OdeBuilder, OdeSolverProblem, Op, Vector,
736 };
737
738 #[cfg(feature = "diffsl-llvm")]
739 use crate::linear_solver_type::LinearSolverType;
740 use crate::test_support::{
741 assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
742 logistic_state, LOGISTIC_X0,
743 };
744 #[cfg(feature = "diffsl-llvm")]
745 use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
746 use crate::valid_linear_solver::LuValidator;
747
748 use super::OdeSolverType;
749
    /// Matrix type (nalgebra-backed, `f64`) used to build every problem in these tests.
    type M = diffsol::NalgebraMat<f64>;
751
752 fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
753 where
754 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
755 {
756 OdeBuilder::<M>::new()
757 .p([2.0])
758 .rtol(1e-6)
759 .atol([1e-6])
760 .build_from_diffsl::<CG>(code)
761 .unwrap()
762 }
763
764 fn assert_dense_solution_matches_expected(
765 soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
766 t_eval: &[f64],
767 expected: impl Fn(f64) -> f64,
768 ) {
769 assert_eq!(soln.ts, t_eval);
770 for (i, &t) in t_eval.iter().enumerate() {
771 assert_close(
772 soln.ys.get_index(0, i),
773 expected(t),
774 5e-4,
775 &format!("solution[{i}]"),
776 );
777 }
778 }
779
780 fn test_all_solver_variants<CG>()
781 where
782 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
783 {
784 let t_eval = [0.25, 0.5, 1.0];
785 for method in [
786 OdeSolverType::Bdf,
787 OdeSolverType::Esdirk34,
788 OdeSolverType::TrBdf2,
789 OdeSolverType::Tsit45,
790 ] {
791 let mut problem = build_problem::<CG>(logistic_diffsl_code());
792 let soln = method
793 .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
794 .unwrap();
795 assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "solve final time");
796 assert_close(
797 soln.ys.get_index(0, soln.ts.len() - 1),
798 logistic_state(LOGISTIC_X0, 2.0, 1.0),
799 5e-4,
800 "solve final value",
801 );
802
803 let mut problem = build_problem::<CG>(logistic_diffsl_code());
804 let soln = method
805 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
806 .unwrap();
807 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
808 logistic_state(LOGISTIC_X0, 2.0, t)
809 });
810 }
811 }
812
813 fn test_all_hybrid_solver_variants<CG>()
814 where
815 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
816 {
817 let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
818 for method in [
819 OdeSolverType::Bdf,
820 OdeSolverType::Esdirk34,
821 OdeSolverType::TrBdf2,
822 OdeSolverType::Tsit45,
823 ] {
824 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
825 let soln = method
826 .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
827 .unwrap();
828 assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "hybrid final time");
829 assert_close(
830 soln.ys.get_index(0, soln.ts.len() - 1),
831 hybrid_logistic_state(2.0, 2.0),
832 5e-4,
833 "hybrid final value",
834 );
835
836 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
837 let soln = method
838 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
839 .unwrap();
840 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
841 hybrid_logistic_state(2.0, t)
842 });
843 }
844 }
845
846 fn test_all_solver_variants_with_lu<CG>()
847 where
848 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
849 {
850 let t_eval = [0.25, 0.5, 1.0];
851 for method in [
852 OdeSolverType::Bdf,
853 OdeSolverType::Esdirk34,
854 OdeSolverType::TrBdf2,
855 OdeSolverType::Tsit45,
856 ] {
857 let mut problem = build_problem::<CG>(logistic_diffsl_code());
858 let soln = method
859 .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
860 .unwrap();
861 assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "lu solve final time");
862
863 let mut problem = build_problem::<CG>(logistic_diffsl_code());
864 let soln = method
865 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
866 .unwrap();
867 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
868 logistic_state(LOGISTIC_X0, 2.0, t)
869 });
870 }
871 }
872
873 fn test_all_hybrid_solver_variants_with_lu<CG>()
874 where
875 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
876 {
877 let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
878 for method in [
879 OdeSolverType::Bdf,
880 OdeSolverType::Esdirk34,
881 OdeSolverType::TrBdf2,
882 OdeSolverType::Tsit45,
883 ] {
884 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
885 let soln = method
886 .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
887 .unwrap();
888 assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "lu hybrid final time");
889
890 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
891 let soln = method
892 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
893 .unwrap();
894 assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
895 hybrid_logistic_state(2.0, t)
896 });
897 }
898 }
899
900 fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
901 where
902 CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
903 {
904 let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
905
906 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
907 let soln = method
908 .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
909 .unwrap();
910 assert_close(
911 *soln.ts.last().unwrap(),
912 2.0,
913 5e-4,
914 "direct hybrid restart final time",
915 );
916 assert_close(
917 soln.ys.get_index(0, soln.ts.len() - 1),
918 hybrid_logistic_state(2.0, 2.0),
919 5e-4,
920 "direct hybrid restart final value",
921 );
922
923 let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
924 let soln = method
925 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
926 .unwrap();
927 assert_dense_solution_matches_expected(&soln, &t_eval, |t| hybrid_logistic_state(2.0, t));
928 }
929
930 #[cfg(feature = "diffsl-llvm")]
931 fn test_all_sensitivity_solver_variants() {
932 let t_eval = [0.25, 0.5, 1.0];
933 for method in [
934 OdeSolverType::Bdf,
935 OdeSolverType::Esdirk34,
936 OdeSolverType::TrBdf2,
937 OdeSolverType::Tsit45,
938 ] {
939 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
940 let soln = method
941 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
942 &mut problem,
943 &t_eval,
944 )
945 .unwrap();
946 for (i, &t) in t_eval.iter().enumerate() {
947 assert_close(
948 soln.y_sens[0].get_index(0, i),
949 logistic_state_dr(LOGISTIC_X0, 2.0, t),
950 5e-4,
951 &format!("fwd_sens[{i}]"),
952 );
953 }
954
955 let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
956 let soln = method
957 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
958 &mut problem,
959 &t_eval,
960 )
961 .unwrap();
962 for (i, &t) in t_eval.iter().enumerate() {
963 assert_close(
964 soln.y_sens[0].get_index(0, i),
965 hybrid_logistic_state_dr(2.0, t),
966 5e-4,
967 &format!("hybrid_fwd_sens[{i}]"),
968 );
969 }
970 }
971 }
972
973 #[cfg(feature = "diffsl-llvm")]
974 fn test_lu_sensitivity_and_adjoint_solver_variants() {
975 let t_eval = [0.25, 0.5, 1.0];
976 for method in [
977 OdeSolverType::Bdf,
978 OdeSolverType::Esdirk34,
979 OdeSolverType::TrBdf2,
980 OdeSolverType::Tsit45,
981 ] {
982 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
983 let soln = method
984 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
985 &mut problem,
986 &t_eval,
987 )
988 .unwrap();
989 for (i, &t) in t_eval.iter().enumerate() {
990 assert_close(
991 soln.y_sens[0].get_index(0, i),
992 logistic_state_dr(LOGISTIC_X0, 2.0, t),
993 5e-4,
994 &format!("lu fwd_sens[{i}]"),
995 );
996 }
997 }
998
999 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1000 let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
1001 let (soln, checkpoint) = OdeSolverType::Bdf
1002 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1003 &mut problem,
1004 &adjoint_t_eval,
1005 &[2.0],
1006 LinearSolverType::Lu,
1007 )
1008 .unwrap();
1009 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1010 problem.eqn.nout(),
1011 soln.ts.len(),
1012 problem.context().to_owned(),
1013 );
1014 let gradient = OdeSolverType::TrBdf2
1015 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
1016 &mut problem,
1017 checkpoint.as_ref(),
1018 &dgdu,
1019 &soln.ts,
1020 )
1021 .unwrap();
1022 assert_eq!(gradient.len(), 1);
1023 assert!(gradient[0].get_index(0).is_finite());
1024 }
1025
1026 #[cfg(feature = "diffsl-llvm")]
1027 fn test_direct_hybrid_sensitivity_restart_paths() {
1028 let t_eval = [0.5, 1.0, 2.5, 3.0, 4.5];
1029 for method in [
1030 OdeSolverType::Esdirk34,
1031 OdeSolverType::TrBdf2,
1032 OdeSolverType::Tsit45,
1033 ] {
1034 let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
1035 let soln = method
1036 .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1037 &mut problem,
1038 &t_eval,
1039 )
1040 .unwrap();
1041 for (i, &t) in t_eval.iter().enumerate() {
1042 assert_close(
1043 soln.ys.get_index(0, i),
1044 hybrid_logistic_state(2.0, t),
1045 5e-4,
1046 &format!("direct hybrid value[{i}]"),
1047 );
1048 assert_close(
1049 soln.y_sens[0].get_index(0, i),
1050 hybrid_logistic_state_dr(2.0, t),
1051 5e-4,
1052 &format!("direct hybrid fwd sens[{i}]"),
1053 );
1054 }
1055 }
1056 }
1057
1058 #[cfg(feature = "diffsl-llvm")]
1059 fn test_adjoint_backwards_methods_and_klu_branch() {
1060 for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
1061 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1062 let t_eval = [0.0, 0.25, 0.5, 1.0];
1063 let (soln, checkpoint) = OdeSolverType::Bdf
1064 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1065 &mut problem,
1066 &t_eval,
1067 &[2.0],
1068 LinearSolverType::Default,
1069 )
1070 .unwrap();
1071 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1072 problem.eqn.nout(),
1073 soln.ts.len(),
1074 problem.context().to_owned(),
1075 );
1076 let gradient = backwards_method
1077 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as crate::valid_linear_solver::KluValidator<M>>::LS>(
1078 &mut problem,
1079 checkpoint.as_ref(),
1080 &dgdu,
1081 &soln.ts,
1082 )
1083 .unwrap();
1084 assert_eq!(gradient.len(), 1);
1085 assert!(gradient[0].get_index(0).is_finite());
1086 }
1087
1088 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1089 let t_eval = [0.0, 0.25, 0.5, 1.0];
1090 let (soln, checkpoint) = OdeSolverType::Tsit45
1091 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1092 &mut problem,
1093 &t_eval,
1094 &[2.0],
1095 LinearSolverType::Default,
1096 )
1097 .unwrap();
1098 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1099 problem.eqn.nout(),
1100 soln.ts.len(),
1101 problem.context().to_owned(),
1102 );
1103 let gradient = OdeSolverType::Bdf
1104 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1105 &mut problem,
1106 checkpoint.as_ref(),
1107 &dgdu,
1108 &soln.ts,
1109 )
1110 .unwrap();
1111 assert_eq!(gradient.len(), 1);
1112 assert!(gradient[0].get_index(0).is_finite());
1113 }
1114
1115 #[cfg(feature = "diffsl-llvm")]
1116 fn test_all_adjoint_solver_variants() {
1117 let t_eval = [0.0, 0.25, 0.5, 1.0];
1118 for method in [
1119 OdeSolverType::Bdf,
1120 OdeSolverType::Esdirk34,
1121 OdeSolverType::TrBdf2,
1122 ] {
1123 let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1124 let (soln, checkpoint) = method
1125 .solve_adjoint_fwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1126 &mut problem,
1127 &t_eval,
1128 &[2.0],
1129 LinearSolverType::Default,
1130 )
1131 .unwrap();
1132 let dgdu = <<M as MatrixCommon>::V as DefaultDenseMatrix>::M::zeros(
1133 problem.eqn.nout(),
1134 soln.ts.len(),
1135 problem.context().to_owned(),
1136 );
1137 let gradient = OdeSolverType::Bdf
1138 .solve_adjoint_bkwd::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1139 &mut problem,
1140 checkpoint.as_ref(),
1141 &dgdu,
1142 &soln.ts,
1143 )
1144 .unwrap();
1145 assert_eq!(gradient.len(), 1);
1146 assert!(gradient[0].get_index(0).is_finite());
1147 }
1148 }
1149
1150 #[cfg(feature = "diffsl-cranelift")]
1151 #[test]
1152 fn runtime_dispatch_solves_all_variants_for_cranelift() {
1153 test_all_solver_variants::<diffsol::CraneliftJitModule>();
1154 test_all_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1155 }
1156
1157 #[cfg(feature = "diffsl-cranelift")]
1158 #[test]
1159 fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
1160 test_all_hybrid_solver_variants::<diffsol::CraneliftJitModule>();
1161 test_all_hybrid_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1162 assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1163 OdeSolverType::Esdirk34,
1164 );
1165 assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1166 OdeSolverType::TrBdf2,
1167 );
1168 assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1169 OdeSolverType::Tsit45,
1170 );
1171 }
1172
1173 #[cfg(feature = "diffsl-llvm")]
1174 #[test]
1175 fn runtime_dispatch_solves_all_variants_for_llvm() {
1176 test_all_solver_variants::<diffsol::LlvmModule>();
1177 test_all_solver_variants_with_lu::<diffsol::LlvmModule>();
1178 }
1179
1180 #[cfg(feature = "diffsl-llvm")]
1181 #[test]
1182 fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
1183 test_all_hybrid_solver_variants::<diffsol::LlvmModule>();
1184 test_all_hybrid_solver_variants_with_lu::<diffsol::LlvmModule>();
1185 assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(
1186 OdeSolverType::Esdirk34,
1187 );
1188 assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::TrBdf2);
1189 assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::Tsit45);
1190 }
1191
1192 #[cfg(feature = "diffsl-llvm")]
1193 #[test]
1194 fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
1195 test_all_sensitivity_solver_variants();
1196 test_lu_sensitivity_and_adjoint_solver_variants();
1197 test_direct_hybrid_sensitivity_restart_paths();
1198 }
1199
1200 #[cfg(feature = "diffsl-llvm")]
1201 #[test]
1202 fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
1203 test_all_adjoint_solver_variants();
1204 test_adjoint_backwards_methods_and_klu_branch();
1205 }
1206}