Skip to main content

diffsol_c/
ode_solver_type.rs

// Solver method enum exposed to Python. It selects the overarching solver
// strategy used by diffsol, such as bdf or esdirk34.
3
4use diffsol::error::{DiffsolError, OdeSolverError};
5use diffsol::ode_equations::OdeEquationsImplicitSensWithReset;
6use diffsol::{
7    matrix::MatrixRef, DefaultDenseMatrix, DiffSl, LinearSolver, Matrix, OdeSolverMethod,
8    OdeSolverProblem, OdeSolverState, Vector, VectorHost, VectorRef,
9};
10use diffsol::{
11    ode_solver_error, AdjointOdeSolverMethod, Checkpointing, CodegenModule, DefaultSolver,
12    DenseMatrix, MatrixCommon, OdeEquations, OdeSolverStopReason, Op, SensitivitiesOdeSolverMethod,
13    Solution, VectorViewMut,
14};
15use ndarray::ArrayView2;
16use num_traits::{FromPrimitive, Zero}; // for generic nums in _solve_sum_squares_adj
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
20use crate::scalar_type::Scalar;
21use crate::utils::is_sens_available;
22use crate::{
23    linear_solver_type::LinearSolverType,
24    valid_linear_solver::{KluValidator, LuValidator},
25};
26
// Selects which diffsol solver the `solve*` methods in `impl OdeSolverType`
// construct; chosen by callers such as the Python bindings.
// With `rename_all = "snake_case"` the variants (de)serialize as "bdf",
// "esdirk34", "tr_bdf2" and "tsit45".
// NOTE(review): the `///` text below is presumably also emitted as the schema
// description by the `JsonSchema` derive — confirm before rewording it.
/// Enumerates the possible ODE solver methods for diffsol. See the solver descriptions in the diffsol documentation (https://github.com/martinjrobins/diffsol) for more details.
///
/// :attr bdf: Backward Differentiation Formula (BDF) method for stiff ODEs and singular mass matrices
/// :attr esdirk34: Explicit Singly Diagonally Implicit Runge-Kutta (ESDIRK) method for moderately stiff ODEs and singular mass matrices.
/// :attr tr_bdf2: Trapezoidal Backward Differentiation Formula of order 2 (TR-BDF2) method for moderately stiff ODEs and singular mass matrices.
/// :attr tsit45: Tsitouras 4/5th order Explicit Runge-Kutta (TSIT45) method for non-stiff ODEs. This is an explicit method, it cannot handle singular mass matrices and does not require a linear solver.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OdeSolverType {
    // Implicit; built with the linear solver type `LS` (`problem.bdf::<LS>()`).
    Bdf,
    // Implicit; built with `LS` (`problem.esdirk34::<LS>()`).
    Esdirk34,
    // Implicit; built with `LS` (`problem.tr_bdf2::<LS>()`).
    TrBdf2,
    // Explicit; built without a linear solver (`problem.tsit45()`) and
    // rejected by `solve_adjoint_backwards`.
    Tsit45,
}
41
42fn apply_state_reset<Eqn, S>(
43    problem: &OdeSolverProblem<Eqn>,
44    state: &mut S,
45) -> Result<(), DiffsolError>
46where
47    Eqn: OdeEquations,
48    S: OdeSolverState<Eqn::V>,
49{
50    let eqn = &problem.eqn;
51    if let Some(reset_fn) = eqn.reset() {
52        state.state_mut_op(eqn, &reset_fn)?;
53    }
54    Ok(())
55}
56
57fn apply_state_reset_with_sens<Eqn, S>(
58    problem: &OdeSolverProblem<Eqn>,
59    state: &mut S,
60    root_idx: usize,
61) -> Result<(), DiffsolError>
62where
63    Eqn: OdeEquationsImplicitSensWithReset,
64    S: OdeSolverState<Eqn::V>,
65{
66    let eqn = &problem.eqn;
67    match (eqn.reset(), eqn.root()) {
68        (None, _) => Ok(()),
69        (Some(_), None) => Err(ode_solver_error!(ResetRequiresRootOperator)),
70        (Some(reset_fn), Some(root_fn)) => {
71            state.state_mut_op_with_sens_and_reset(eqn, &reset_fn, &root_fn, root_idx)?;
72            Ok(())
73        }
74    }
75}
76
77impl OdeSolverType {
78    pub(crate) fn solve<M, CG, LS>(
79        &self,
80        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
81        final_time: M::T,
82    ) -> Result<Solution<M::V>, DiffsolError>
83    where
84        M: Matrix<T: Scalar>,
85        CG: CodegenModule,
86        M::V: VectorHost + DefaultDenseMatrix,
87        LS: LinearSolver<M>,
88        for<'b> &'b M::V: VectorRef<M::V>,
89        for<'b> &'b M: MatrixRef<M>,
90    {
91        match self {
92            OdeSolverType::Bdf => {
93                let solver = problem.bdf::<LS>()?;
94                let mut soln = Solution::new(final_time);
95                solver.solve_soln(&mut soln)?;
96                Ok(soln)
97            }
98            OdeSolverType::Esdirk34 => {
99                let solver = problem.esdirk34::<LS>()?;
100                let mut soln = Solution::new(final_time);
101                solver.solve_soln(&mut soln)?;
102                Ok(soln)
103            }
104            OdeSolverType::TrBdf2 => {
105                let solver = problem.tr_bdf2::<LS>()?;
106                let mut soln = Solution::new(final_time);
107                solver.solve_soln(&mut soln)?;
108                Ok(soln)
109            }
110            OdeSolverType::Tsit45 => {
111                let solver = problem.tsit45()?;
112                let mut soln = Solution::new(final_time);
113                solver.solve_soln(&mut soln)?;
114                Ok(soln)
115            }
116        }
117    }
118
119    pub(crate) fn solve_dense<M, CG, LS>(
120        &self,
121        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
122        t_eval: &[M::T],
123    ) -> Result<Solution<M::V>, DiffsolError>
124    where
125        M: Matrix<T: Scalar>,
126        CG: CodegenModule,
127        M::V: VectorHost + DefaultDenseMatrix,
128        LS: LinearSolver<M>,
129        for<'b> &'b M::V: VectorRef<M::V>,
130        for<'b> &'b M: MatrixRef<M>,
131    {
132        match self {
133            OdeSolverType::Bdf => {
134                let solver = problem.bdf::<LS>()?;
135                let mut soln = Solution::new_dense(t_eval.to_vec())?;
136                solver.solve_soln(&mut soln)?;
137                Ok(soln)
138            }
139            OdeSolverType::Esdirk34 => {
140                let solver = problem.esdirk34::<LS>()?;
141                let mut soln = Solution::new_dense(t_eval.to_vec())?;
142                solver.solve_soln(&mut soln)?;
143                Ok(soln)
144            }
145            OdeSolverType::TrBdf2 => {
146                let solver = problem.tr_bdf2::<LS>()?;
147                let mut soln = Solution::new_dense(t_eval.to_vec())?;
148                solver.solve_soln(&mut soln)?;
149                Ok(soln)
150            }
151            OdeSolverType::Tsit45 => {
152                let solver = problem.tsit45()?;
153                let mut soln = Solution::new_dense(t_eval.to_vec())?;
154                solver.solve_soln(&mut soln)?;
155                Ok(soln)
156            }
157        }
158    }
159
160    pub(crate) fn solve_hybrid<M, CG, LS>(
161        &self,
162        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
163        final_time: M::T,
164    ) -> Result<Solution<M::V>, DiffsolError>
165    where
166        M: Matrix<T: Scalar>,
167        CG: CodegenModule,
168        M::V: VectorHost + DefaultDenseMatrix,
169        LS: LinearSolver<M>,
170        for<'b> &'b M::V: VectorRef<M::V>,
171        for<'b> &'b M: MatrixRef<M>,
172    {
173        match self {
174            OdeSolverType::Bdf => {
175                let mut soln = Solution::new(final_time);
176                let mut solver = problem.bdf::<LS>()?;
177                while !soln.is_complete() {
178                    solver = solver.solve_soln(&mut soln)?;
179                    let root_idx = match soln.stop_reason {
180                        Some(OdeSolverStopReason::RootFound(_, root_idx))
181                            if !soln.is_complete() =>
182                        {
183                            root_idx
184                        }
185                        _ => continue,
186                    };
187                    let mut state = solver.into_state();
188                    problem.eqn.set_model_index(root_idx);
189                    apply_state_reset(problem, &mut state)?;
190                    solver = problem.bdf_solver::<LS>(state)?;
191                }
192                Ok(soln)
193            }
194            OdeSolverType::Esdirk34 => {
195                let mut soln = Solution::new(final_time);
196                let mut solver = problem.esdirk34::<LS>()?;
197                while !soln.is_complete() {
198                    solver = solver.solve_soln(&mut soln)?;
199                    let root_idx = match soln.stop_reason {
200                        Some(OdeSolverStopReason::RootFound(_, root_idx))
201                            if !soln.is_complete() =>
202                        {
203                            root_idx
204                        }
205                        _ => continue,
206                    };
207                    let mut state = solver.into_state();
208                    problem.eqn.set_model_index(root_idx);
209                    apply_state_reset(problem, &mut state)?;
210                    solver = problem.esdirk34_solver::<LS>(state)?;
211                }
212                Ok(soln)
213            }
214            OdeSolverType::TrBdf2 => {
215                let mut soln = Solution::new(final_time);
216                let mut solver = problem.tr_bdf2::<LS>()?;
217                while !soln.is_complete() {
218                    solver = solver.solve_soln(&mut soln)?;
219                    let root_idx = match soln.stop_reason {
220                        Some(OdeSolverStopReason::RootFound(_, root_idx))
221                            if !soln.is_complete() =>
222                        {
223                            root_idx
224                        }
225                        _ => continue,
226                    };
227                    let mut state = solver.into_state();
228                    problem.eqn.set_model_index(root_idx);
229                    apply_state_reset(problem, &mut state)?;
230                    solver = problem.tr_bdf2_solver::<LS>(state)?;
231                }
232                Ok(soln)
233            }
234            OdeSolverType::Tsit45 => {
235                let mut soln = Solution::new(final_time);
236                let mut solver = problem.tsit45()?;
237                while !soln.is_complete() {
238                    solver = solver.solve_soln(&mut soln)?;
239                    let root_idx = match soln.stop_reason {
240                        Some(OdeSolverStopReason::RootFound(_, root_idx))
241                            if !soln.is_complete() =>
242                        {
243                            root_idx
244                        }
245                        _ => continue,
246                    };
247                    let mut state = solver.into_state();
248                    problem.eqn.set_model_index(root_idx);
249                    apply_state_reset(problem, &mut state)?;
250                    solver = problem.tsit45_solver(state)?;
251                }
252                Ok(soln)
253            }
254        }
255    }
256
257    pub(crate) fn solve_hybrid_dense<M, CG, LS>(
258        &self,
259        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
260        t_eval: &[M::T],
261    ) -> Result<Solution<M::V>, DiffsolError>
262    where
263        M: Matrix<T: Scalar>,
264        CG: CodegenModule,
265        M::V: VectorHost + DefaultDenseMatrix,
266        LS: LinearSolver<M>,
267        for<'b> &'b M::V: VectorRef<M::V>,
268        for<'b> &'b M: MatrixRef<M>,
269    {
270        match self {
271            OdeSolverType::Bdf => {
272                let mut soln = Solution::new_dense(t_eval.to_vec())?;
273                let mut solver = problem.bdf::<LS>()?;
274                while !soln.is_complete() {
275                    solver = solver.solve_soln(&mut soln)?;
276                    let root_idx = match soln.stop_reason {
277                        Some(OdeSolverStopReason::RootFound(_, root_idx))
278                            if !soln.is_complete() =>
279                        {
280                            root_idx
281                        }
282                        _ => continue,
283                    };
284                    let mut state = solver.into_state();
285                    problem.eqn.set_model_index(root_idx);
286                    apply_state_reset(problem, &mut state)?;
287                    solver = problem.bdf_solver::<LS>(state)?;
288                }
289                Ok(soln)
290            }
291            OdeSolverType::Esdirk34 => {
292                let mut soln = Solution::new_dense(t_eval.to_vec())?;
293                let mut solver = problem.esdirk34::<LS>()?;
294                while !soln.is_complete() {
295                    solver = solver.solve_soln(&mut soln)?;
296                    let root_idx = match soln.stop_reason {
297                        Some(OdeSolverStopReason::RootFound(_, root_idx))
298                            if !soln.is_complete() =>
299                        {
300                            root_idx
301                        }
302                        _ => continue,
303                    };
304                    let mut state = solver.into_state();
305                    problem.eqn.set_model_index(root_idx);
306                    apply_state_reset(problem, &mut state)?;
307                    solver = problem.esdirk34_solver::<LS>(state)?;
308                }
309                Ok(soln)
310            }
311            OdeSolverType::TrBdf2 => {
312                let mut soln = Solution::new_dense(t_eval.to_vec())?;
313                let mut solver = problem.tr_bdf2::<LS>()?;
314                while !soln.is_complete() {
315                    solver = solver.solve_soln(&mut soln)?;
316                    let root_idx = match soln.stop_reason {
317                        Some(OdeSolverStopReason::RootFound(_, root_idx))
318                            if !soln.is_complete() =>
319                        {
320                            root_idx
321                        }
322                        _ => continue,
323                    };
324                    let mut state = solver.into_state();
325                    problem.eqn.set_model_index(root_idx);
326                    apply_state_reset(problem, &mut state)?;
327                    solver = problem.tr_bdf2_solver::<LS>(state)?;
328                }
329                Ok(soln)
330            }
331            OdeSolverType::Tsit45 => {
332                let mut soln = Solution::new_dense(t_eval.to_vec())?;
333                let mut solver = problem.tsit45()?;
334                while !soln.is_complete() {
335                    solver = solver.solve_soln(&mut soln)?;
336                    let root_idx = match soln.stop_reason {
337                        Some(OdeSolverStopReason::RootFound(_, root_idx))
338                            if !soln.is_complete() =>
339                        {
340                            root_idx
341                        }
342                        _ => continue,
343                    };
344                    let mut state = solver.into_state();
345                    problem.eqn.set_model_index(root_idx);
346                    apply_state_reset(problem, &mut state)?;
347                    solver = problem.tsit45_solver(state)?;
348                }
349                Ok(soln)
350            }
351        }
352    }
353
354    fn check_sens_available() -> Result<(), DiffsolError> {
355        if !is_sens_available() {
356            return Err(DiffsolError::Other(
357                "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
358            ));
359        }
360        Ok(())
361    }
362
363    #[allow(clippy::type_complexity)]
364    pub(crate) fn solve_fwd_sens<M, CG, LS>(
365        &self,
366        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
367        t_eval: &[M::T],
368    ) -> Result<Solution<M::V>, DiffsolError>
369    where
370        M: Matrix<T: Scalar> + DefaultSolver,
371        CG: CodegenModule,
372        M::V: VectorHost + DefaultDenseMatrix,
373        LS: LinearSolver<M>,
374        for<'b> &'b M::V: VectorRef<M::V>,
375        for<'b> &'b M: MatrixRef<M>,
376    {
377        Self::check_sens_available()?;
378        match self {
379            OdeSolverType::Bdf => {
380                let solver = problem.bdf_sens::<LS>()?;
381                let mut soln = Solution::new_dense(t_eval.to_vec())?;
382                solver.solve_soln_sensitivities(&mut soln)?;
383                Ok(soln)
384            }
385            OdeSolverType::Esdirk34 => {
386                let solver = problem.esdirk34_sens::<LS>()?;
387                let mut soln = Solution::new_dense(t_eval.to_vec())?;
388                solver.solve_soln_sensitivities(&mut soln)?;
389                Ok(soln)
390            }
391            OdeSolverType::TrBdf2 => {
392                let solver = problem.tr_bdf2_sens::<LS>()?;
393                let mut soln = Solution::new_dense(t_eval.to_vec())?;
394                solver.solve_soln_sensitivities(&mut soln)?;
395                Ok(soln)
396            }
397            OdeSolverType::Tsit45 => {
398                let solver = problem.tsit45_sens()?;
399                let mut soln = Solution::new_dense(t_eval.to_vec())?;
400                solver.solve_soln_sensitivities(&mut soln)?;
401                Ok(soln)
402            }
403        }
404    }
405
406    #[allow(clippy::type_complexity)]
407    pub(crate) fn solve_hybrid_fwd_sens<M, CG, LS>(
408        &self,
409        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
410        t_eval: &[M::T],
411    ) -> Result<Solution<M::V>, DiffsolError>
412    where
413        M: Matrix<T: Scalar> + DefaultSolver,
414        CG: CodegenModule,
415        M::V: VectorHost + DefaultDenseMatrix,
416        LS: LinearSolver<M>,
417        for<'b> &'b M::V: VectorRef<M::V>,
418        for<'b> &'b M: MatrixRef<M>,
419    {
420        Self::check_sens_available()?;
421        match self {
422            OdeSolverType::Bdf => {
423                let mut soln = Solution::new_dense(t_eval.to_vec())?;
424                let mut solver = problem.bdf_sens::<LS>()?;
425                while !soln.is_complete() {
426                    solver = solver.solve_soln_sensitivities(&mut soln)?;
427                    let root_idx = match soln.stop_reason {
428                        Some(OdeSolverStopReason::RootFound(_, root_idx))
429                            if !soln.is_complete() =>
430                        {
431                            root_idx
432                        }
433                        _ => continue,
434                    };
435                    let mut state = solver.into_state();
436                    problem.eqn.set_model_index(root_idx);
437                    apply_state_reset_with_sens(problem, &mut state, root_idx)?;
438                    solver = problem.bdf_solver_sens::<LS>(state)?;
439                }
440                Ok(soln)
441            }
442            OdeSolverType::Esdirk34 => {
443                let mut soln = Solution::new_dense(t_eval.to_vec())?;
444                let mut solver = problem.esdirk34_sens::<LS>()?;
445                while !soln.is_complete() {
446                    solver = solver.solve_soln_sensitivities(&mut soln)?;
447                    let root_idx = match soln.stop_reason {
448                        Some(OdeSolverStopReason::RootFound(_, root_idx))
449                            if !soln.is_complete() =>
450                        {
451                            root_idx
452                        }
453                        _ => continue,
454                    };
455                    let mut state = solver.into_state();
456                    problem.eqn.set_model_index(root_idx);
457                    apply_state_reset_with_sens(problem, &mut state, root_idx)?;
458                    solver = problem.esdirk34_solver_sens::<LS>(state)?;
459                }
460                Ok(soln)
461            }
462            OdeSolverType::TrBdf2 => {
463                let mut soln = Solution::new_dense(t_eval.to_vec())?;
464                let mut solver = problem.tr_bdf2_sens::<LS>()?;
465                while !soln.is_complete() {
466                    solver = solver.solve_soln_sensitivities(&mut soln)?;
467                    let root_idx = match soln.stop_reason {
468                        Some(OdeSolverStopReason::RootFound(_, root_idx))
469                            if !soln.is_complete() =>
470                        {
471                            root_idx
472                        }
473                        _ => continue,
474                    };
475                    let mut state = solver.into_state();
476                    problem.eqn.set_model_index(root_idx);
477                    apply_state_reset_with_sens(problem, &mut state, root_idx)?;
478                    solver = problem.tr_bdf2_solver_sens::<LS>(state)?;
479                }
480                Ok(soln)
481            }
482            OdeSolverType::Tsit45 => {
483                let mut soln = Solution::new_dense(t_eval.to_vec())?;
484                let mut solver = problem.tsit45_sens()?;
485                while !soln.is_complete() {
486                    solver = solver.solve_soln_sensitivities(&mut soln)?;
487                    let root_idx = match soln.stop_reason {
488                        Some(OdeSolverStopReason::RootFound(_, root_idx))
489                            if !soln.is_complete() =>
490                        {
491                            root_idx
492                        }
493                        _ => continue,
494                    };
495                    let mut state = solver.into_state();
496                    problem.eqn.set_model_index(root_idx);
497                    apply_state_reset_with_sens(problem, &mut state, root_idx)?;
498                    solver = problem.tsit45_solver_sens(state)?;
499                }
500                Ok(soln)
501            }
502        }
503    }
504
505    pub(crate) fn solve_sum_squares_adj<'a, M, CG, LS>(
506        &self,
507        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
508        data: ArrayView2<'a, M::T>,
509        t_eval: &[M::T],
510        backwards_method: OdeSolverType,
511        backwards_linear_solver: LinearSolverType,
512    ) -> Result<(M::T, M::V), DiffsolError>
513    where
514        M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
515        CG: CodegenModule,
516        M::V: VectorHost + DefaultDenseMatrix,
517        LS: LinearSolver<M>,
518        for<'b> &'b M::V: VectorRef<M::V>,
519        for<'b> &'b M: MatrixRef<M>,
520    {
521        Self::check_sens_available()?;
522        match self {
523            OdeSolverType::Bdf => self._solve_sum_squares_adj(
524                problem.bdf::<LS>()?,
525                data,
526                t_eval,
527                backwards_method,
528                backwards_linear_solver,
529            ),
530            OdeSolverType::Esdirk34 => self._solve_sum_squares_adj(
531                problem.esdirk34::<LS>()?,
532                data,
533                t_eval,
534                backwards_method,
535                backwards_linear_solver,
536            ),
537            OdeSolverType::TrBdf2 => self._solve_sum_squares_adj(
538                problem.tr_bdf2::<LS>()?,
539                data,
540                t_eval,
541                backwards_method,
542                backwards_linear_solver,
543            ),
544            OdeSolverType::Tsit45 => self._solve_sum_squares_adj(
545                problem.tsit45()?,
546                data,
547                t_eval,
548                backwards_method,
549                backwards_linear_solver,
550            ),
551        }
552    }
553
554    pub(crate) fn _solve_sum_squares_adj<'data, 'solver, M, CG, S>(
555        &self,
556        mut solver: S,
557        data: ArrayView2<'data, M::T>,
558        t_eval: &[M::T],
559        backwards_method: OdeSolverType,
560        backwards_linear_solver: LinearSolverType,
561    ) -> Result<(M::T, M::V), DiffsolError>
562    where
563        M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
564        CG: CodegenModule,
565        M::V: VectorHost + DefaultDenseMatrix,
566        S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
567        for<'b> &'b M::V: VectorRef<M::V>,
568        for<'b> &'b M: MatrixRef<M>,
569    {
570        let (chk, ys, stop_reason) = solver.solve_dense_with_checkpointing(t_eval, None)?;
571        let eqn = solver.problem().eqn();
572        let ctx = eqn.context();
573        let mut g_m = <M::V as DefaultDenseMatrix>::M::zeros(eqn.nout(), t_eval.len(), ctx.clone());
574        let mut y = M::T::zero();
575        for j in 0..g_m.ncols() {
576            let ys_col = ys.column(j);
577            // TODO: can we avoid this allocation? (I can't see how right now)
578            let mut tmp = M::V::from_slice(data.column(j).as_slice().unwrap(), ctx.clone());
579            // tmp = 2 * ys_col - 2 * tmp
580            tmp.axpy_v(
581                M::T::from_f64(2.0).unwrap(),
582                &ys_col,
583                M::T::from_f64(-2.0).unwrap(),
584            );
585            g_m.column_mut(j).copy_from(&tmp);
586
587            // y = (1/4) * dot(tmp, tmp) + y
588            let norm = tmp.norm(2);
589            y += M::T::from_f64(1.0 / 4.0).unwrap() * norm * norm;
590        }
591        let mut y_sens = match backwards_linear_solver {
592            LinearSolverType::Default => backwards_method
593                .solve_adjoint_backwards::<M, CG, <M as DefaultSolver>::LS, S>(
594                    solver.problem(),
595                    chk,
596                    stop_reason,
597                    &g_m,
598                    t_eval,
599                    Some(1),
600                )?,
601            LinearSolverType::Lu => backwards_method
602                .solve_adjoint_backwards::<M, CG, <M as LuValidator<M>>::LS, S>(
603                    solver.problem(),
604                    chk,
605                    stop_reason,
606                    &g_m,
607                    t_eval,
608                    Some(1),
609                )?,
610            LinearSolverType::Klu => backwards_method
611                .solve_adjoint_backwards::<M, CG, <M as KluValidator<M>>::LS, S>(
612                    solver.problem(),
613                    chk,
614                    stop_reason,
615                    &g_m,
616                    t_eval,
617                    Some(1),
618                )?,
619        };
620        Ok((y, y_sens.pop().unwrap()))
621    }
622
623    pub(crate) fn solve_adjoint_backwards<'solver, M, CG, LS, S>(
624        &self,
625        problem: &'solver OdeSolverProblem<DiffSl<M, CG>>,
626        checkpointing: Checkpointing<'solver, DiffSl<M, CG>, S>,
627        _stop_reason: OdeSolverStopReason<M::T>,
628        g_m: &<M::V as DefaultDenseMatrix>::M,
629        t_eval: &[M::T],
630        nout_override: Option<usize>,
631    ) -> Result<Vec<M::V>, DiffsolError>
632    where
633        M: Matrix<T: Scalar> + DefaultSolver,
634        CG: CodegenModule,
635        M::V: VectorHost + DefaultDenseMatrix,
636        S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
637        LS: LinearSolver<M>,
638        for<'b> &'b M::V: VectorRef<M::V>,
639        for<'b> &'b M: MatrixRef<M>,
640    {
641        match self {
642            OdeSolverType::Bdf => problem
643                .bdf_solver_adjoint::<LS, _>(checkpointing, nout_override)?
644                .solve_adjoint_backwards_pass(None, t_eval, &[g_m])
645                .map(|res| res.into_common().sg),
646            OdeSolverType::Esdirk34 => problem
647                .esdirk34_solver_adjoint::<LS, _>(checkpointing, nout_override)?
648                .solve_adjoint_backwards_pass(None, t_eval, &[g_m])
649                .map(|res| res.into_common().sg),
650            OdeSolverType::TrBdf2 => problem
651                .tr_bdf2_solver_adjoint::<LS, _>(checkpointing, nout_override)?
652                .solve_adjoint_backwards_pass(None, t_eval, &[g_m])
653                .map(|res| res.into_common().sg),
654            OdeSolverType::Tsit45 => Err(DiffsolError::Other(
655                "Tsit45 solver does not support adjoint sensitivity analysis.".to_string(),
656            )),
657        }
658    }
659}
660
661#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
662mod tests {
663    use diffsol::{
664        CodegenModuleCompile, CodegenModuleJit, DefaultSolver, DenseMatrix, OdeBuilder,
665        OdeSolverProblem, Vector,
666    };
667
668    #[cfg(feature = "diffsl-llvm")]
669    use crate::linear_solver_type::LinearSolverType;
670    use crate::test_support::{
671        assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
672        logistic_state, LOGISTIC_X0,
673    };
674    #[cfg(feature = "diffsl-llvm")]
675    use crate::test_support::{hybrid_logistic_state_dr, logistic_integral, logistic_state_dr};
676    use crate::valid_linear_solver::LuValidator;
677    #[cfg(feature = "diffsl-llvm")]
678    use ndarray::Array2;
679
680    use super::OdeSolverType;
681
682    type M = diffsol::NalgebraMat<f64>;
683
684    fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
685    where
686        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
687    {
688        OdeBuilder::<M>::new()
689            .p([2.0])
690            .rtol(1e-6)
691            .atol([1e-6])
692            .build_from_diffsl::<CG>(code)
693            .unwrap()
694    }
695
696    fn assert_dense_solution_matches_expected(
697        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
698        t_eval: &[f64],
699        expected: impl Fn(f64) -> f64,
700    ) {
701        assert_eq!(soln.ts, t_eval);
702        for (i, &t) in t_eval.iter().enumerate() {
703            assert_close(
704                soln.ys.get_index(0, i),
705                expected(t),
706                5e-4,
707                &format!("solution[{i}]"),
708            );
709        }
710    }
711
712    fn test_all_solver_variants<CG>()
713    where
714        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
715    {
716        let t_eval = [0.25, 0.5, 1.0];
717        for method in [
718            OdeSolverType::Bdf,
719            OdeSolverType::Esdirk34,
720            OdeSolverType::TrBdf2,
721            OdeSolverType::Tsit45,
722        ] {
723            let mut problem = build_problem::<CG>(logistic_diffsl_code());
724            let soln = method
725                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
726                .unwrap();
727            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "solve final time");
728            assert_close(
729                soln.ys.get_index(0, soln.ts.len() - 1),
730                logistic_state(LOGISTIC_X0, 2.0, 1.0),
731                5e-4,
732                "solve final value",
733            );
734
735            let mut problem = build_problem::<CG>(logistic_diffsl_code());
736            let soln = method
737                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
738                .unwrap();
739            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
740                logistic_state(LOGISTIC_X0, 2.0, t)
741            });
742        }
743    }
744
745    fn test_all_hybrid_solver_variants<CG>()
746    where
747        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
748    {
749        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
750        for method in [
751            OdeSolverType::Bdf,
752            OdeSolverType::Esdirk34,
753            OdeSolverType::TrBdf2,
754            OdeSolverType::Tsit45,
755        ] {
756            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
757            let soln = method
758                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
759                .unwrap();
760            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "hybrid final time");
761            assert_close(
762                soln.ys.get_index(0, soln.ts.len() - 1),
763                hybrid_logistic_state(2.0, 2.0),
764                5e-4,
765                "hybrid final value",
766            );
767
768            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
769            let soln = method
770                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
771                .unwrap();
772            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
773                hybrid_logistic_state(2.0, t)
774            });
775        }
776    }
777
778    fn test_all_solver_variants_with_lu<CG>()
779    where
780        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
781    {
782        let t_eval = [0.25, 0.5, 1.0];
783        for method in [
784            OdeSolverType::Bdf,
785            OdeSolverType::Esdirk34,
786            OdeSolverType::TrBdf2,
787            OdeSolverType::Tsit45,
788        ] {
789            let mut problem = build_problem::<CG>(logistic_diffsl_code());
790            let soln = method
791                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
792                .unwrap();
793            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "lu solve final time");
794
795            let mut problem = build_problem::<CG>(logistic_diffsl_code());
796            let soln = method
797                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
798                .unwrap();
799            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
800                logistic_state(LOGISTIC_X0, 2.0, t)
801            });
802        }
803    }
804
805    fn test_all_hybrid_solver_variants_with_lu<CG>()
806    where
807        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
808    {
809        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
810        for method in [
811            OdeSolverType::Bdf,
812            OdeSolverType::Esdirk34,
813            OdeSolverType::TrBdf2,
814            OdeSolverType::Tsit45,
815        ] {
816            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
817            let soln = method
818                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
819                .unwrap();
820            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "lu hybrid final time");
821
822            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
823            let soln = method
824                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
825                .unwrap();
826            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
827                hybrid_logistic_state(2.0, t)
828            });
829        }
830    }
831
832    fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
833    where
834        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
835    {
836        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
837
838        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
839        let soln = method
840            .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
841            .unwrap();
842        assert_close(
843            *soln.ts.last().unwrap(),
844            2.0,
845            5e-4,
846            "direct hybrid restart final time",
847        );
848        assert_close(
849            soln.ys.get_index(0, soln.ts.len() - 1),
850            hybrid_logistic_state(2.0, 2.0),
851            5e-4,
852            "direct hybrid restart final value",
853        );
854
855        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
856        let soln = method
857            .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
858            .unwrap();
859        assert_dense_solution_matches_expected(&soln, &t_eval, |t| hybrid_logistic_state(2.0, t));
860    }
861
862    #[cfg(feature = "diffsl-llvm")]
863    fn test_all_sensitivity_solver_variants() {
864        let t_eval = [0.25, 0.5, 1.0];
865        for method in [
866            OdeSolverType::Bdf,
867            OdeSolverType::Esdirk34,
868            OdeSolverType::TrBdf2,
869            OdeSolverType::Tsit45,
870        ] {
871            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
872            let soln = method
873                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
874                    &mut problem,
875                    &t_eval,
876                )
877                .unwrap();
878            for (i, &t) in t_eval.iter().enumerate() {
879                assert_close(
880                    soln.y_sens[0].get_index(0, i),
881                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
882                    5e-4,
883                    &format!("fwd_sens[{i}]"),
884                );
885            }
886
887            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
888            let soln = method
889                .solve_hybrid_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
890                    &mut problem,
891                    &t_eval,
892                )
893                .unwrap();
894            for (i, &t) in t_eval.iter().enumerate() {
895                assert_close(
896                    soln.y_sens[0].get_index(0, i),
897                    hybrid_logistic_state_dr(2.0, t),
898                    5e-4,
899                    &format!("hybrid_fwd_sens[{i}]"),
900                );
901            }
902        }
903    }
904
905    #[cfg(feature = "diffsl-llvm")]
906    fn test_lu_sensitivity_and_adjoint_solver_variants() {
907        let t_eval = [0.25, 0.5, 1.0];
908        for method in [
909            OdeSolverType::Bdf,
910            OdeSolverType::Esdirk34,
911            OdeSolverType::TrBdf2,
912            OdeSolverType::Tsit45,
913        ] {
914            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
915            let soln = method
916                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
917                    &mut problem,
918                    &t_eval,
919                )
920                .unwrap();
921            for (i, &t) in t_eval.iter().enumerate() {
922                assert_close(
923                    soln.y_sens[0].get_index(0, i),
924                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
925                    5e-4,
926                    &format!("lu fwd_sens[{i}]"),
927                );
928            }
929        }
930
931        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
932        let data = Array2::from_shape_vec(
933            (1, adjoint_t_eval.len()),
934            adjoint_t_eval
935                .iter()
936                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
937                .collect(),
938        )
939        .unwrap();
940
941        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
942        let (objective, gradient) = OdeSolverType::Bdf
943            .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
944                &mut problem,
945                data.view(),
946                &adjoint_t_eval,
947                OdeSolverType::TrBdf2,
948                LinearSolverType::Lu,
949            )
950            .unwrap();
951        assert!(objective.is_finite());
952        assert_eq!(gradient.len(), 1);
953        assert!(gradient.get_index(0).is_finite());
954    }
955
956    #[cfg(feature = "diffsl-llvm")]
957    fn test_direct_hybrid_sensitivity_restart_paths() {
958        let t_eval = [0.5, 1.0, 2.5, 3.0, 4.5];
959        for method in [
960            OdeSolverType::Esdirk34,
961            OdeSolverType::TrBdf2,
962            OdeSolverType::Tsit45,
963        ] {
964            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
965            let soln = method
966                .solve_hybrid_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
967                    &mut problem,
968                    &t_eval,
969                )
970                .unwrap();
971            for (i, &t) in t_eval.iter().enumerate() {
972                assert_close(
973                    soln.ys.get_index(0, i),
974                    hybrid_logistic_state(2.0, t),
975                    5e-4,
976                    &format!("direct hybrid value[{i}]"),
977                );
978                assert_close(
979                    soln.y_sens[0].get_index(0, i),
980                    hybrid_logistic_state_dr(2.0, t),
981                    5e-4,
982                    &format!("direct hybrid fwd sens[{i}]"),
983                );
984            }
985        }
986    }
987
988    #[cfg(feature = "diffsl-llvm")]
989    fn test_adjoint_backwards_methods_and_klu_branch() {
990        let t_eval = [0.0, 0.25, 0.5, 1.0];
991        let data = Array2::from_shape_vec(
992            (1, t_eval.len()),
993            t_eval
994                .iter()
995                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
996                .collect(),
997        )
998        .unwrap();
999
1000        for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
1001            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1002            let (objective, gradient) = OdeSolverType::Bdf
1003                .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1004                    &mut problem,
1005                    data.view(),
1006                    &t_eval,
1007                    backwards_method,
1008                    LinearSolverType::Klu,
1009                )
1010                .unwrap();
1011            assert!(objective.is_finite());
1012            assert_eq!(gradient.len(), 1);
1013            assert!(gradient.get_index(0).is_finite());
1014        }
1015
1016        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1017        let err = OdeSolverType::Bdf
1018            .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1019                &mut problem,
1020                data.view(),
1021                &t_eval,
1022                OdeSolverType::Tsit45,
1023                LinearSolverType::Default,
1024            )
1025            .unwrap_err();
1026        assert!(err
1027            .to_string()
1028            .contains("Tsit45 solver does not support adjoint sensitivity analysis"));
1029    }
1030
1031    #[cfg(feature = "diffsl-llvm")]
1032    fn test_all_adjoint_solver_variants() {
1033        let t_eval = [0.0, 0.25, 0.5, 1.0];
1034        let data = Array2::from_shape_vec(
1035            (1, t_eval.len()),
1036            t_eval
1037                .iter()
1038                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
1039                .collect(),
1040        )
1041        .unwrap();
1042
1043        for method in [
1044            OdeSolverType::Bdf,
1045            OdeSolverType::Esdirk34,
1046            OdeSolverType::TrBdf2,
1047            OdeSolverType::Tsit45,
1048        ] {
1049            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1050            let (objective, gradient) = method
1051                .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1052                    &mut problem,
1053                    data.view(),
1054                    &t_eval,
1055                    OdeSolverType::Bdf,
1056                    crate::linear_solver_type::LinearSolverType::Default,
1057                )
1058                .unwrap();
1059            assert!(objective.is_finite());
1060            assert_eq!(gradient.len(), 1);
1061            assert!(gradient.get_index(0).is_finite());
1062        }
1063    }
1064
1065    #[cfg(feature = "diffsl-cranelift")]
1066    #[test]
1067    fn runtime_dispatch_solves_all_variants_for_cranelift() {
1068        test_all_solver_variants::<diffsol::CraneliftJitModule>();
1069        test_all_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1070    }
1071
1072    #[cfg(feature = "diffsl-cranelift")]
1073    #[test]
1074    fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
1075        test_all_hybrid_solver_variants::<diffsol::CraneliftJitModule>();
1076        test_all_hybrid_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1077        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1078            OdeSolverType::Esdirk34,
1079        );
1080        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1081            OdeSolverType::TrBdf2,
1082        );
1083        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1084            OdeSolverType::Tsit45,
1085        );
1086    }
1087
1088    #[cfg(feature = "diffsl-llvm")]
1089    #[test]
1090    fn runtime_dispatch_solves_all_variants_for_llvm() {
1091        test_all_solver_variants::<diffsol::LlvmModule>();
1092        test_all_solver_variants_with_lu::<diffsol::LlvmModule>();
1093    }
1094
1095    #[cfg(feature = "diffsl-llvm")]
1096    #[test]
1097    fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
1098        test_all_hybrid_solver_variants::<diffsol::LlvmModule>();
1099        test_all_hybrid_solver_variants_with_lu::<diffsol::LlvmModule>();
1100        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(
1101            OdeSolverType::Esdirk34,
1102        );
1103        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::TrBdf2);
1104        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::Tsit45);
1105    }
1106
1107    #[cfg(feature = "diffsl-llvm")]
1108    #[test]
1109    fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
1110        test_all_sensitivity_solver_variants();
1111        test_lu_sensitivity_and_adjoint_solver_variants();
1112        test_direct_hybrid_sensitivity_restart_paths();
1113    }
1114
1115    #[cfg(feature = "diffsl-llvm")]
1116    #[test]
1117    fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
1118        test_all_adjoint_solver_variants();
1119        test_adjoint_backwards_methods_and_klu_branch();
1120    }
1121}