Skip to main content

diffsol_c/
ode_solver_type.rs

// Solver method Python enum. This is used to select the overarching solver
// strategy, such as bdf or esdirk34, in diffsol.
3
4use diffsol::error::DiffsolError;
5use diffsol::{
6    AdjointOdeSolverMethod, Checkpointing, CodegenModule, DefaultSolver, DenseMatrix, MatrixCommon,
7    OdeEquations, OdeSolverState, OdeSolverStopReason, Op, SensitivitiesOdeSolverMethod, Solution,
8    VectorViewMut,
9};
10use diffsol::{
11    DefaultDenseMatrix, DiffSl, LinearSolver, Matrix, OdeSolverMethod, OdeSolverProblem, Vector,
12    VectorHost, VectorRef, matrix::MatrixRef,
13};
14use ndarray::ArrayView2;
15use num_traits::{FromPrimitive, Zero}; // for generic nums in _solve_sum_squares_adj
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18
19use crate::scalar_type::Scalar;
20use crate::utils::is_sens_available;
21use crate::{
22    linear_solver_type::LinearSolverType,
23    valid_linear_solver::{KluValidator, LuValidator},
24};
25
26/// Enumerates the possible ODE solver methods for diffsol. See the solver descriptions in the diffsol documentation (https://github.com/martinjrobins/diffsol) for more details.
27///
28/// :attr bdf: Backward Differentiation Formula (BDF) method for stiff ODEs and singular mass matrices
29/// :attr esdirk34: Explicit Singly Diagonally Implicit Runge-Kutta (ESDIRK) method for moderately stiff ODEs and singular mass matrices.
30/// :attr tr_bdf2: Trapezoidal Backward Differentiation Formula of order 2 (TR-BDF2) method for moderately stiff ODEs and singular mass matrices.
31/// :attr tsit45: Tsitouras 4/5th order Explicit Runge-Kutta (TSIT45) method for non-stiff ODEs. This is an explicit method, it cannot handle singular mass matrices and does not require a linear solver.
32#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
33#[serde(rename_all = "snake_case")]
34pub enum OdeSolverType {
35    Bdf,
36    Esdirk34,
37    TrBdf2,
38    Tsit45,
39}
40
41impl OdeSolverType {
42    pub(crate) fn solve<M, CG, LS>(
43        &self,
44        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
45        final_time: M::T,
46    ) -> Result<Solution<M::V>, DiffsolError>
47    where
48        M: Matrix<T: Scalar>,
49        CG: CodegenModule,
50        M::V: VectorHost + DefaultDenseMatrix,
51        LS: LinearSolver<M>,
52        for<'b> &'b M::V: VectorRef<M::V>,
53        for<'b> &'b M: MatrixRef<M>,
54    {
55        match self {
56            OdeSolverType::Bdf => {
57                let solver = problem.bdf::<LS>()?;
58                let mut soln = Solution::new(final_time);
59                solver.solve_soln(&mut soln)?;
60                Ok(soln)
61            }
62            OdeSolverType::Esdirk34 => {
63                let solver = problem.esdirk34::<LS>()?;
64                let mut soln = Solution::new(final_time);
65                solver.solve_soln(&mut soln)?;
66                Ok(soln)
67            }
68            OdeSolverType::TrBdf2 => {
69                let solver = problem.tr_bdf2::<LS>()?;
70                let mut soln = Solution::new(final_time);
71                solver.solve_soln(&mut soln)?;
72                Ok(soln)
73            }
74            OdeSolverType::Tsit45 => {
75                let solver = problem.tsit45()?;
76                let mut soln = Solution::new(final_time);
77                solver.solve_soln(&mut soln)?;
78                Ok(soln)
79            }
80        }
81    }
82
83    pub(crate) fn solve_dense<M, CG, LS>(
84        &self,
85        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
86        t_eval: &[M::T],
87    ) -> Result<Solution<M::V>, DiffsolError>
88    where
89        M: Matrix<T: Scalar>,
90        CG: CodegenModule,
91        M::V: VectorHost + DefaultDenseMatrix,
92        LS: LinearSolver<M>,
93        for<'b> &'b M::V: VectorRef<M::V>,
94        for<'b> &'b M: MatrixRef<M>,
95    {
96        match self {
97            OdeSolverType::Bdf => {
98                let solver = problem.bdf::<LS>()?;
99                let mut soln = Solution::new_dense(t_eval.to_vec())?;
100                solver.solve_soln(&mut soln)?;
101                Ok(soln)
102            }
103            OdeSolverType::Esdirk34 => {
104                let solver = problem.esdirk34::<LS>()?;
105                let mut soln = Solution::new_dense(t_eval.to_vec())?;
106                solver.solve_soln(&mut soln)?;
107                Ok(soln)
108            }
109            OdeSolverType::TrBdf2 => {
110                let solver = problem.tr_bdf2::<LS>()?;
111                let mut soln = Solution::new_dense(t_eval.to_vec())?;
112                solver.solve_soln(&mut soln)?;
113                Ok(soln)
114            }
115            OdeSolverType::Tsit45 => {
116                let solver = problem.tsit45()?;
117                let mut soln = Solution::new_dense(t_eval.to_vec())?;
118                solver.solve_soln(&mut soln)?;
119                Ok(soln)
120            }
121        }
122    }
123
124    pub(crate) fn solve_hybrid<M, CG, LS>(
125        &self,
126        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
127        final_time: M::T,
128    ) -> Result<Solution<M::V>, DiffsolError>
129    where
130        M: Matrix<T: Scalar>,
131        CG: CodegenModule,
132        M::V: VectorHost + DefaultDenseMatrix,
133        LS: LinearSolver<M>,
134        for<'b> &'b M::V: VectorRef<M::V>,
135        for<'b> &'b M: MatrixRef<M>,
136    {
137        match self {
138            OdeSolverType::Bdf => {
139                let mut soln = Solution::new(final_time);
140                let mut solver = problem.bdf::<LS>()?;
141                while !soln.is_complete() {
142                    solver = solver.solve_soln(&mut soln)?;
143                    let root_idx = match soln.stop_reason {
144                        Some(OdeSolverStopReason::RootFound(_, root_idx))
145                            if !soln.is_complete() =>
146                        {
147                            root_idx
148                        }
149                        _ => continue,
150                    };
151                    let state = solver.into_state();
152                    problem.eqn.set_model_index(root_idx);
153                    let mut restarted_solver = problem.bdf_solver::<LS>(state)?;
154                    restarted_solver.reset()?;
155                    solver = restarted_solver;
156                }
157                Ok(soln)
158            }
159            OdeSolverType::Esdirk34 => {
160                let mut soln = Solution::new(final_time);
161                let mut solver = problem.esdirk34::<LS>()?;
162                while !soln.is_complete() {
163                    solver = solver.solve_soln(&mut soln)?;
164                    let root_idx = match soln.stop_reason {
165                        Some(OdeSolverStopReason::RootFound(_, root_idx))
166                            if !soln.is_complete() =>
167                        {
168                            root_idx
169                        }
170                        _ => continue,
171                    };
172                    let state = solver.into_state();
173                    problem.eqn.set_model_index(root_idx);
174                    let mut restarted_solver = problem.esdirk34_solver::<LS>(state)?;
175                    restarted_solver.reset()?;
176                    solver = restarted_solver;
177                }
178                Ok(soln)
179            }
180            OdeSolverType::TrBdf2 => {
181                let mut soln = Solution::new(final_time);
182                let mut solver = problem.tr_bdf2::<LS>()?;
183                while !soln.is_complete() {
184                    solver = solver.solve_soln(&mut soln)?;
185                    let root_idx = match soln.stop_reason {
186                        Some(OdeSolverStopReason::RootFound(_, root_idx))
187                            if !soln.is_complete() =>
188                        {
189                            root_idx
190                        }
191                        _ => continue,
192                    };
193                    let state = solver.into_state();
194                    problem.eqn.set_model_index(root_idx);
195                    let mut restarted_solver = problem.tr_bdf2_solver::<LS>(state)?;
196                    restarted_solver.reset()?;
197                    solver = restarted_solver;
198                }
199                Ok(soln)
200            }
201            OdeSolverType::Tsit45 => {
202                let mut soln = Solution::new(final_time);
203                let mut solver = problem.tsit45()?;
204                while !soln.is_complete() {
205                    solver = solver.solve_soln(&mut soln)?;
206                    let root_idx = match soln.stop_reason {
207                        Some(OdeSolverStopReason::RootFound(_, root_idx))
208                            if !soln.is_complete() =>
209                        {
210                            root_idx
211                        }
212                        _ => continue,
213                    };
214                    let state = solver.into_state();
215                    problem.eqn.set_model_index(root_idx);
216                    let mut restarted_solver = problem.tsit45_solver(state)?;
217                    restarted_solver.reset()?;
218                    solver = restarted_solver;
219                }
220                Ok(soln)
221            }
222        }
223    }
224
225    pub(crate) fn solve_hybrid_dense<M, CG, LS>(
226        &self,
227        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
228        t_eval: &[M::T],
229    ) -> Result<Solution<M::V>, DiffsolError>
230    where
231        M: Matrix<T: Scalar>,
232        CG: CodegenModule,
233        M::V: VectorHost + DefaultDenseMatrix,
234        LS: LinearSolver<M>,
235        for<'b> &'b M::V: VectorRef<M::V>,
236        for<'b> &'b M: MatrixRef<M>,
237    {
238        match self {
239            OdeSolverType::Bdf => {
240                let mut soln = Solution::new_dense(t_eval.to_vec())?;
241                let mut solver = problem.bdf::<LS>()?;
242                while !soln.is_complete() {
243                    solver = solver.solve_soln(&mut soln)?;
244                    let root_idx = match soln.stop_reason {
245                        Some(OdeSolverStopReason::RootFound(_, root_idx))
246                            if !soln.is_complete() =>
247                        {
248                            root_idx
249                        }
250                        _ => continue,
251                    };
252                    let state = solver.into_state();
253                    problem.eqn.set_model_index(root_idx);
254                    let mut restarted_solver = problem.bdf_solver::<LS>(state)?;
255                    restarted_solver.reset()?;
256                    solver = restarted_solver;
257                }
258                Ok(soln)
259            }
260            OdeSolverType::Esdirk34 => {
261                let mut soln = Solution::new_dense(t_eval.to_vec())?;
262                let mut solver = problem.esdirk34::<LS>()?;
263                while !soln.is_complete() {
264                    solver = solver.solve_soln(&mut soln)?;
265                    let root_idx = match soln.stop_reason {
266                        Some(OdeSolverStopReason::RootFound(_, root_idx))
267                            if !soln.is_complete() =>
268                        {
269                            root_idx
270                        }
271                        _ => continue,
272                    };
273                    let state = solver.into_state();
274                    problem.eqn.set_model_index(root_idx);
275                    let mut restarted_solver = problem.esdirk34_solver::<LS>(state)?;
276                    restarted_solver.reset()?;
277                    solver = restarted_solver;
278                }
279                Ok(soln)
280            }
281            OdeSolverType::TrBdf2 => {
282                let mut soln = Solution::new_dense(t_eval.to_vec())?;
283                let mut solver = problem.tr_bdf2::<LS>()?;
284                while !soln.is_complete() {
285                    solver = solver.solve_soln(&mut soln)?;
286                    let root_idx = match soln.stop_reason {
287                        Some(OdeSolverStopReason::RootFound(_, root_idx))
288                            if !soln.is_complete() =>
289                        {
290                            root_idx
291                        }
292                        _ => continue,
293                    };
294                    let state = solver.into_state();
295                    problem.eqn.set_model_index(root_idx);
296                    let mut restarted_solver = problem.tr_bdf2_solver::<LS>(state)?;
297                    restarted_solver.reset()?;
298                    solver = restarted_solver;
299                }
300                Ok(soln)
301            }
302            OdeSolverType::Tsit45 => {
303                let mut soln = Solution::new_dense(t_eval.to_vec())?;
304                let mut solver = problem.tsit45()?;
305                while !soln.is_complete() {
306                    solver = solver.solve_soln(&mut soln)?;
307                    let root_idx = match soln.stop_reason {
308                        Some(OdeSolverStopReason::RootFound(_, root_idx))
309                            if !soln.is_complete() =>
310                        {
311                            root_idx
312                        }
313                        _ => continue,
314                    };
315                    let state = solver.into_state();
316                    problem.eqn.set_model_index(root_idx);
317                    let mut restarted_solver = problem.tsit45_solver(state)?;
318                    restarted_solver.reset()?;
319                    solver = restarted_solver;
320                }
321                Ok(soln)
322            }
323        }
324    }
325
326    fn check_sens_available() -> Result<(), DiffsolError> {
327        if !is_sens_available() {
328            return Err(DiffsolError::Other(
329                "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
330            ));
331        }
332        Ok(())
333    }
334
335    #[allow(clippy::type_complexity)]
336    pub(crate) fn solve_fwd_sens<M, CG, LS>(
337        &self,
338        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
339        t_eval: &[M::T],
340    ) -> Result<Solution<M::V>, DiffsolError>
341    where
342        M: Matrix<T: Scalar> + DefaultSolver,
343        CG: CodegenModule,
344        M::V: VectorHost + DefaultDenseMatrix,
345        LS: LinearSolver<M>,
346        for<'b> &'b M::V: VectorRef<M::V>,
347        for<'b> &'b M: MatrixRef<M>,
348    {
349        Self::check_sens_available()?;
350        match self {
351            OdeSolverType::Bdf => {
352                let solver = problem.bdf_sens::<LS>()?;
353                let mut soln = Solution::new_dense(t_eval.to_vec())?;
354                solver.solve_soln_sensitivities(&mut soln)?;
355                Ok(soln)
356            }
357            OdeSolverType::Esdirk34 => {
358                let solver = problem.esdirk34_sens::<LS>()?;
359                let mut soln = Solution::new_dense(t_eval.to_vec())?;
360                solver.solve_soln_sensitivities(&mut soln)?;
361                Ok(soln)
362            }
363            OdeSolverType::TrBdf2 => {
364                let solver = problem.tr_bdf2_sens::<LS>()?;
365                let mut soln = Solution::new_dense(t_eval.to_vec())?;
366                solver.solve_soln_sensitivities(&mut soln)?;
367                Ok(soln)
368            }
369            OdeSolverType::Tsit45 => {
370                let solver = problem.tsit45_sens()?;
371                let mut soln = Solution::new_dense(t_eval.to_vec())?;
372                solver.solve_soln_sensitivities(&mut soln)?;
373                Ok(soln)
374            }
375        }
376    }
377
378    #[allow(clippy::type_complexity)]
379    pub(crate) fn solve_hybrid_fwd_sens<M, CG, LS>(
380        &self,
381        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
382        t_eval: &[M::T],
383    ) -> Result<Solution<M::V>, DiffsolError>
384    where
385        M: Matrix<T: Scalar> + DefaultSolver,
386        CG: CodegenModule,
387        M::V: VectorHost + DefaultDenseMatrix,
388        LS: LinearSolver<M>,
389        for<'b> &'b M::V: VectorRef<M::V>,
390        for<'b> &'b M: MatrixRef<M>,
391    {
392        Self::check_sens_available()?;
393        match self {
394            OdeSolverType::Bdf => {
395                let mut soln = Solution::new_dense(t_eval.to_vec())?;
396                let mut solver = problem.bdf_sens::<LS>()?;
397                while !soln.is_complete() {
398                    solver = solver.solve_soln_sensitivities(&mut soln)?;
399                    let root_idx = match soln.stop_reason {
400                        Some(OdeSolverStopReason::RootFound(_, root_idx))
401                            if !soln.is_complete() =>
402                        {
403                            root_idx
404                        }
405                        _ => continue,
406                    };
407                    let state = solver.into_state();
408                    problem.eqn.set_model_index(root_idx);
409                    let mut restarted_solver = problem.bdf_solver_sens::<LS>(state)?;
410                    restarted_solver.reset_with_sens()?;
411                    solver = restarted_solver;
412                }
413                Ok(soln)
414            }
415            OdeSolverType::Esdirk34 => {
416                let mut soln = Solution::new_dense(t_eval.to_vec())?;
417                let mut solver = problem.esdirk34_sens::<LS>()?;
418                while !soln.is_complete() {
419                    solver = solver.solve_soln_sensitivities(&mut soln)?;
420                    let root_idx = match soln.stop_reason {
421                        Some(OdeSolverStopReason::RootFound(_, root_idx))
422                            if !soln.is_complete() =>
423                        {
424                            root_idx
425                        }
426                        _ => continue,
427                    };
428                    let state = solver.into_state();
429                    problem.eqn.set_model_index(root_idx);
430                    let mut restarted_solver = problem.esdirk34_solver_sens::<LS>(state)?;
431                    restarted_solver.reset_with_sens()?;
432                    solver = restarted_solver;
433                }
434                Ok(soln)
435            }
436            OdeSolverType::TrBdf2 => {
437                let mut soln = Solution::new_dense(t_eval.to_vec())?;
438                let mut solver = problem.tr_bdf2_sens::<LS>()?;
439                while !soln.is_complete() {
440                    solver = solver.solve_soln_sensitivities(&mut soln)?;
441                    let root_idx = match soln.stop_reason {
442                        Some(OdeSolverStopReason::RootFound(_, root_idx))
443                            if !soln.is_complete() =>
444                        {
445                            root_idx
446                        }
447                        _ => continue,
448                    };
449                    let state = solver.into_state();
450                    problem.eqn.set_model_index(root_idx);
451                    let mut restarted_solver = problem.tr_bdf2_solver_sens::<LS>(state)?;
452                    restarted_solver.reset_with_sens()?;
453                    solver = restarted_solver;
454                }
455                Ok(soln)
456            }
457            OdeSolverType::Tsit45 => {
458                let mut soln = Solution::new_dense(t_eval.to_vec())?;
459                let mut solver = problem.tsit45_sens()?;
460                while !soln.is_complete() {
461                    solver = solver.solve_soln_sensitivities(&mut soln)?;
462                    let root_idx = match soln.stop_reason {
463                        Some(OdeSolverStopReason::RootFound(_, root_idx))
464                            if !soln.is_complete() =>
465                        {
466                            root_idx
467                        }
468                        _ => continue,
469                    };
470                    let state = solver.into_state();
471                    problem.eqn.set_model_index(root_idx);
472                    let mut restarted_solver = problem.tsit45_solver_sens(state)?;
473                    restarted_solver.reset_with_sens()?;
474                    solver = restarted_solver;
475                }
476                Ok(soln)
477            }
478        }
479    }
480
481    pub(crate) fn solve_sum_squares_adj<'a, M, CG, LS>(
482        &self,
483        problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
484        data: ArrayView2<'a, M::T>,
485        t_eval: &[M::T],
486        backwards_method: OdeSolverType,
487        backwards_linear_solver: LinearSolverType,
488    ) -> Result<(M::T, M::V), DiffsolError>
489    where
490        M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
491        CG: CodegenModule,
492        M::V: VectorHost + DefaultDenseMatrix,
493        LS: LinearSolver<M>,
494        for<'b> &'b M::V: VectorRef<M::V>,
495        for<'b> &'b M: MatrixRef<M>,
496    {
497        Self::check_sens_available()?;
498        match self {
499            OdeSolverType::Bdf => self._solve_sum_squares_adj(
500                problem.bdf::<LS>()?,
501                data,
502                t_eval,
503                backwards_method,
504                backwards_linear_solver,
505            ),
506            OdeSolverType::Esdirk34 => self._solve_sum_squares_adj(
507                problem.esdirk34::<LS>()?,
508                data,
509                t_eval,
510                backwards_method,
511                backwards_linear_solver,
512            ),
513            OdeSolverType::TrBdf2 => self._solve_sum_squares_adj(
514                problem.tr_bdf2::<LS>()?,
515                data,
516                t_eval,
517                backwards_method,
518                backwards_linear_solver,
519            ),
520            OdeSolverType::Tsit45 => self._solve_sum_squares_adj(
521                problem.tsit45()?,
522                data,
523                t_eval,
524                backwards_method,
525                backwards_linear_solver,
526            ),
527        }
528    }
529
530    pub(crate) fn _solve_sum_squares_adj<'data, 'solver, M, CG, S>(
531        &self,
532        mut solver: S,
533        data: ArrayView2<'data, M::T>,
534        t_eval: &[M::T],
535        backwards_method: OdeSolverType,
536        backwards_linear_solver: LinearSolverType,
537    ) -> Result<(M::T, M::V), DiffsolError>
538    where
539        M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
540        CG: CodegenModule,
541        M::V: VectorHost + DefaultDenseMatrix,
542        S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
543        for<'b> &'b M::V: VectorRef<M::V>,
544        for<'b> &'b M: MatrixRef<M>,
545    {
546        let (chk, ys) = solver.solve_dense_with_checkpointing(t_eval, None)?;
547        let eqn = solver.problem().eqn();
548        let ctx = eqn.context();
549        let mut g_m = <M::V as DefaultDenseMatrix>::M::zeros(eqn.nout(), t_eval.len(), ctx.clone());
550        let mut y = M::T::zero();
551        for j in 0..g_m.ncols() {
552            let ys_col = ys.column(j);
553            // TODO: can we avoid this allocation? (I can't see how right now)
554            let mut tmp = M::V::from_slice(data.column(j).as_slice().unwrap(), ctx.clone());
555            // tmp = 2 * ys_col - 2 * tmp
556            tmp.axpy_v(
557                M::T::from_f64(2.0).unwrap(),
558                &ys_col,
559                M::T::from_f64(-2.0).unwrap(),
560            );
561            g_m.column_mut(j).copy_from(&tmp);
562
563            // y = (1/4) * dot(tmp, tmp) + y
564            let norm = tmp.norm(2);
565            y += M::T::from_f64(1.0 / 4.0).unwrap() * norm * norm;
566        }
567        let mut y_sens = match backwards_linear_solver {
568            LinearSolverType::Default => backwards_method
569                .solve_adjoint_backwards::<M, CG, <M as DefaultSolver>::LS, S>(
570                    solver.problem(),
571                    chk,
572                    &g_m,
573                    t_eval,
574                    Some(1),
575                )?,
576            LinearSolverType::Lu => backwards_method
577                .solve_adjoint_backwards::<M, CG, <M as LuValidator<M>>::LS, S>(
578                    solver.problem(),
579                    chk,
580                    &g_m,
581                    t_eval,
582                    Some(1),
583                )?,
584            LinearSolverType::Klu => backwards_method
585                .solve_adjoint_backwards::<M, CG, <M as KluValidator<M>>::LS, S>(
586                    solver.problem(),
587                    chk,
588                    &g_m,
589                    t_eval,
590                    Some(1),
591                )?,
592        };
593        Ok((y, y_sens.pop().unwrap()))
594    }
595
596    pub(crate) fn solve_adjoint_backwards<'solver, M, CG, LS, S>(
597        &self,
598        problem: &'solver OdeSolverProblem<DiffSl<M, CG>>,
599        checkpointing: Checkpointing<'solver, DiffSl<M, CG>, S>,
600        g_m: &<M::V as DefaultDenseMatrix>::M,
601        t_eval: &[M::T],
602        nout_override: Option<usize>,
603    ) -> Result<Vec<M::V>, DiffsolError>
604    where
605        M: Matrix<T: Scalar> + DefaultSolver,
606        CG: CodegenModule,
607        M::V: VectorHost + DefaultDenseMatrix,
608        S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
609        LS: LinearSolver<M>,
610        for<'b> &'b M::V: VectorRef<M::V>,
611        for<'b> &'b M: MatrixRef<M>,
612    {
613        match self {
614            OdeSolverType::Bdf => problem
615                .bdf_solver_adjoint::<LS, _>(checkpointing, nout_override)?
616                .solve_adjoint_backwards_pass(t_eval, &[g_m])
617                .map(|res| res.into_common().sg),
618            OdeSolverType::Esdirk34 => problem
619                .esdirk34_solver_adjoint::<LS, _>(checkpointing, nout_override)?
620                .solve_adjoint_backwards_pass(t_eval, &[g_m])
621                .map(|res| res.into_common().sg),
622            OdeSolverType::TrBdf2 => problem
623                .tr_bdf2_solver_adjoint::<LS, _>(checkpointing, nout_override)?
624                .solve_adjoint_backwards_pass(t_eval, &[g_m])
625                .map(|res| res.into_common().sg),
626            OdeSolverType::Tsit45 => Err(DiffsolError::Other(
627                "Tsit45 solver does not support adjoint sensitivity analysis.".to_string(),
628            )),
629        }
630    }
631}
632
633#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
634mod tests {
635    use diffsol::Vector;
636    use diffsol::{
637        CodegenModuleCompile, CodegenModuleJit, DefaultSolver, DenseMatrix, OdeBuilder,
638        OdeSolverProblem,
639    };
640    use ndarray::Array2;
641
642    use crate::linear_solver_type::LinearSolverType;
643    use crate::test_support::{
644        LOGISTIC_X0, assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state,
645        hybrid_logistic_state_dr, logistic_diffsl_code, logistic_integral, logistic_state,
646        logistic_state_dr,
647    };
648    use crate::valid_linear_solver::LuValidator;
649
650    use super::OdeSolverType;
651
652    type M = diffsol::NalgebraMat<f64>;
653
654    fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
655    where
656        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
657    {
658        OdeBuilder::<M>::new()
659            .p([2.0])
660            .rtol(1e-6)
661            .atol([1e-6])
662            .build_from_diffsl::<CG>(code)
663            .unwrap()
664    }
665
666    fn assert_dense_solution_matches_expected(
667        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
668        t_eval: &[f64],
669        expected: impl Fn(f64) -> f64,
670    ) {
671        assert_eq!(soln.ts, t_eval);
672        for (i, &t) in t_eval.iter().enumerate() {
673            assert_close(
674                soln.ys.get_index(0, i),
675                expected(t),
676                5e-4,
677                &format!("solution[{i}]"),
678            );
679        }
680    }
681
682    fn test_all_solver_variants<CG>()
683    where
684        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
685    {
686        let t_eval = [0.25, 0.5, 1.0];
687        for method in [
688            OdeSolverType::Bdf,
689            OdeSolverType::Esdirk34,
690            OdeSolverType::TrBdf2,
691            OdeSolverType::Tsit45,
692        ] {
693            let mut problem = build_problem::<CG>(logistic_diffsl_code());
694            let soln = method
695                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
696                .unwrap();
697            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "solve final time");
698            assert_close(
699                soln.ys.get_index(0, soln.ts.len() - 1),
700                logistic_state(LOGISTIC_X0, 2.0, 1.0),
701                5e-4,
702                "solve final value",
703            );
704
705            let mut problem = build_problem::<CG>(logistic_diffsl_code());
706            let soln = method
707                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
708                .unwrap();
709            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
710                logistic_state(LOGISTIC_X0, 2.0, t)
711            });
712        }
713    }
714
715    fn test_all_hybrid_solver_variants<CG>()
716    where
717        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
718    {
719        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
720        for method in [
721            OdeSolverType::Bdf,
722            OdeSolverType::Esdirk34,
723            OdeSolverType::TrBdf2,
724            OdeSolverType::Tsit45,
725        ] {
726            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
727            let soln = method
728                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
729                .unwrap();
730            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "hybrid final time");
731            assert_close(
732                soln.ys.get_index(0, soln.ts.len() - 1),
733                hybrid_logistic_state(2.0, 2.0),
734                5e-4,
735                "hybrid final value",
736            );
737
738            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
739            let soln = method
740                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
741                .unwrap();
742            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
743                hybrid_logistic_state(2.0, t)
744            });
745        }
746    }
747
748    fn test_all_solver_variants_with_lu<CG>()
749    where
750        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
751    {
752        let t_eval = [0.25, 0.5, 1.0];
753        for method in [
754            OdeSolverType::Bdf,
755            OdeSolverType::Esdirk34,
756            OdeSolverType::TrBdf2,
757            OdeSolverType::Tsit45,
758        ] {
759            let mut problem = build_problem::<CG>(logistic_diffsl_code());
760            let soln = method
761                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
762                .unwrap();
763            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "lu solve final time");
764
765            let mut problem = build_problem::<CG>(logistic_diffsl_code());
766            let soln = method
767                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
768                .unwrap();
769            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
770                logistic_state(LOGISTIC_X0, 2.0, t)
771            });
772        }
773    }
774
775    fn test_all_hybrid_solver_variants_with_lu<CG>()
776    where
777        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
778    {
779        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
780        for method in [
781            OdeSolverType::Bdf,
782            OdeSolverType::Esdirk34,
783            OdeSolverType::TrBdf2,
784            OdeSolverType::Tsit45,
785        ] {
786            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
787            let soln = method
788                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
789                .unwrap();
790            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "lu hybrid final time");
791
792            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
793            let soln = method
794                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
795                .unwrap();
796            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
797                hybrid_logistic_state(2.0, t)
798            });
799        }
800    }
801
802    fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
803    where
804        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
805    {
806        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
807
808        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
809        let soln = method
810            .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
811            .unwrap();
812        assert_close(
813            *soln.ts.last().unwrap(),
814            2.0,
815            5e-4,
816            "direct hybrid restart final time",
817        );
818        assert_close(
819            soln.ys.get_index(0, soln.ts.len() - 1),
820            hybrid_logistic_state(2.0, 2.0),
821            5e-4,
822            "direct hybrid restart final value",
823        );
824
825        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
826        let soln = method
827            .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
828            .unwrap();
829        assert_dense_solution_matches_expected(&soln, &t_eval, |t| hybrid_logistic_state(2.0, t));
830    }
831
832    #[cfg(feature = "diffsl-llvm")]
833    fn test_all_sensitivity_solver_variants() {
834        let t_eval = [0.25, 0.5, 1.0];
835        for method in [
836            OdeSolverType::Bdf,
837            OdeSolverType::Esdirk34,
838            OdeSolverType::TrBdf2,
839            OdeSolverType::Tsit45,
840        ] {
841            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
842            let soln = method
843                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
844                    &mut problem,
845                    &t_eval,
846                )
847                .unwrap();
848            for (i, &t) in t_eval.iter().enumerate() {
849                assert_close(
850                    soln.y_sens[0].get_index(0, i),
851                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
852                    5e-4,
853                    &format!("fwd_sens[{i}]"),
854                );
855            }
856
857            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
858            let soln = method
859                .solve_hybrid_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
860                    &mut problem,
861                    &t_eval,
862                )
863                .unwrap();
864            for (i, &t) in t_eval.iter().enumerate() {
865                assert_close(
866                    soln.y_sens[0].get_index(0, i),
867                    hybrid_logistic_state_dr(2.0, t),
868                    5e-4,
869                    &format!("hybrid_fwd_sens[{i}]"),
870                );
871            }
872        }
873    }
874
875    #[cfg(feature = "diffsl-llvm")]
876    fn test_lu_sensitivity_and_adjoint_solver_variants() {
877        let t_eval = [0.25, 0.5, 1.0];
878        for method in [
879            OdeSolverType::Bdf,
880            OdeSolverType::Esdirk34,
881            OdeSolverType::TrBdf2,
882            OdeSolverType::Tsit45,
883        ] {
884            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
885            let soln = method
886                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
887                    &mut problem,
888                    &t_eval,
889                )
890                .unwrap();
891            for (i, &t) in t_eval.iter().enumerate() {
892                assert_close(
893                    soln.y_sens[0].get_index(0, i),
894                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
895                    5e-4,
896                    &format!("lu fwd_sens[{i}]"),
897                );
898            }
899        }
900
901        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
902        let data = Array2::from_shape_vec(
903            (1, adjoint_t_eval.len()),
904            adjoint_t_eval
905                .iter()
906                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
907                .collect(),
908        )
909        .unwrap();
910
911        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
912        let (objective, gradient) = OdeSolverType::Bdf
913            .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
914                &mut problem,
915                data.view(),
916                &adjoint_t_eval,
917                OdeSolverType::TrBdf2,
918                LinearSolverType::Lu,
919            )
920            .unwrap();
921        assert!(objective.is_finite());
922        assert_eq!(gradient.len(), 1);
923        assert!(gradient.get_index(0).is_finite());
924    }
925
926    #[cfg(feature = "diffsl-llvm")]
927    fn test_direct_hybrid_sensitivity_restart_paths() {
928        let t_eval = [0.25, 0.5, 1.0];
929        for method in [
930            OdeSolverType::Esdirk34,
931            OdeSolverType::TrBdf2,
932            OdeSolverType::Tsit45,
933        ] {
934            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
935            let soln = method
936                .solve_hybrid_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
937                    &mut problem,
938                    &t_eval,
939                )
940                .unwrap();
941            for (i, &t) in t_eval.iter().enumerate() {
942                assert_close(
943                    soln.y_sens[0].get_index(0, i),
944                    hybrid_logistic_state_dr(2.0, t),
945                    5e-4,
946                    &format!("direct hybrid fwd sens[{i}]"),
947                );
948            }
949        }
950    }
951
952    #[cfg(feature = "diffsl-llvm")]
953    fn test_adjoint_backwards_methods_and_klu_branch() {
954        let t_eval = [0.0, 0.25, 0.5, 1.0];
955        let data = Array2::from_shape_vec(
956            (1, t_eval.len()),
957            t_eval
958                .iter()
959                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
960                .collect(),
961        )
962        .unwrap();
963
964        for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
965            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
966            let (objective, gradient) = OdeSolverType::Bdf
967                .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
968                    &mut problem,
969                    data.view(),
970                    &t_eval,
971                    backwards_method,
972                    LinearSolverType::Klu,
973                )
974                .unwrap();
975            assert!(objective.is_finite());
976            assert_eq!(gradient.len(), 1);
977            assert!(gradient.get_index(0).is_finite());
978        }
979
980        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
981        let err = OdeSolverType::Bdf
982            .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
983                &mut problem,
984                data.view(),
985                &t_eval,
986                OdeSolverType::Tsit45,
987                LinearSolverType::Default,
988            )
989            .unwrap_err();
990        assert!(
991            err.to_string()
992                .contains("Tsit45 solver does not support adjoint sensitivity analysis")
993        );
994    }
995
996    #[cfg(feature = "diffsl-llvm")]
997    fn test_all_adjoint_solver_variants() {
998        let t_eval = [0.0, 0.25, 0.5, 1.0];
999        let data = Array2::from_shape_vec(
1000            (1, t_eval.len()),
1001            t_eval
1002                .iter()
1003                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
1004                .collect(),
1005        )
1006        .unwrap();
1007
1008        for method in [
1009            OdeSolverType::Bdf,
1010            OdeSolverType::Esdirk34,
1011            OdeSolverType::TrBdf2,
1012            OdeSolverType::Tsit45,
1013        ] {
1014            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
1015            let (objective, gradient) = method
1016                .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
1017                    &mut problem,
1018                    data.view(),
1019                    &t_eval,
1020                    OdeSolverType::Bdf,
1021                    crate::linear_solver_type::LinearSolverType::Default,
1022                )
1023                .unwrap();
1024            assert!(objective.is_finite());
1025            assert_eq!(gradient.len(), 1);
1026            assert!(gradient.get_index(0).is_finite());
1027        }
1028    }
1029
1030    #[cfg(feature = "diffsl-cranelift")]
1031    #[test]
1032    fn runtime_dispatch_solves_all_variants_for_cranelift() {
1033        test_all_solver_variants::<diffsol::CraneliftJitModule>();
1034        test_all_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1035    }
1036
1037    #[cfg(feature = "diffsl-cranelift")]
1038    #[test]
1039    fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
1040        test_all_hybrid_solver_variants::<diffsol::CraneliftJitModule>();
1041        test_all_hybrid_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
1042        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1043            OdeSolverType::Esdirk34,
1044        );
1045        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1046            OdeSolverType::TrBdf2,
1047        );
1048        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
1049            OdeSolverType::Tsit45,
1050        );
1051    }
1052
1053    #[cfg(feature = "diffsl-llvm")]
1054    #[test]
1055    fn runtime_dispatch_solves_all_variants_for_llvm() {
1056        test_all_solver_variants::<diffsol::LlvmModule>();
1057        test_all_solver_variants_with_lu::<diffsol::LlvmModule>();
1058    }
1059
1060    #[cfg(feature = "diffsl-llvm")]
1061    #[test]
1062    fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
1063        test_all_hybrid_solver_variants::<diffsol::LlvmModule>();
1064        test_all_hybrid_solver_variants_with_lu::<diffsol::LlvmModule>();
1065        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(
1066            OdeSolverType::Esdirk34,
1067        );
1068        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::TrBdf2);
1069        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::Tsit45);
1070    }
1071
1072    #[cfg(feature = "diffsl-llvm")]
1073    #[test]
1074    fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
1075        test_all_sensitivity_solver_variants();
1076        test_lu_sensitivity_and_adjoint_solver_variants();
1077        test_direct_hybrid_sensitivity_restart_paths();
1078    }
1079
1080    #[cfg(feature = "diffsl-llvm")]
1081    #[test]
1082    fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
1083        test_all_adjoint_solver_variants();
1084        test_adjoint_backwards_methods_and_klu_branch();
1085    }
1086}