diffsol/ode_solver/
mod.rs

1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod problem;
11pub mod runge_kutta;
12pub mod sde;
13pub mod sdirk;
14pub mod sdirk_state;
15pub mod sensitivities;
16pub mod state;
17pub mod tableau;
18
19#[cfg(test)]
20mod tests {
21    use std::rc::Rc;
22
23    use self::problem::OdeSolverSolution;
24    use nalgebra::ComplexField;
25
26    use super::*;
27    use crate::error::{DiffsolError, OdeSolverError};
28    use crate::matrix::Matrix;
29    use crate::op::unit::UnitCallable;
30    use crate::op::ParameterisedOp;
31    use crate::{
32        op::OpStatistics, AdjointOdeSolverMethod, Context, DenseMatrix, MatrixCommon, MatrixRef,
33        NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit, OdeEquationsImplicitAdjoint,
34        OdeEquationsRef, OdeSolverConfig, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
35        OdeSolverStopReason, Scale, VectorRef, VectorView, VectorViewMut,
36    };
37    use crate::{
38        ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver, NonLinearOp,
39        NonLinearOpSens, Op, Vector,
40    };
41    use num_traits::{FromPrimitive, One, Zero};
42
43    pub fn test_ode_solver<'a, M, Eqn, Method>(
44        method: &mut Method,
45        solution: OdeSolverSolution<M::V>,
46        override_tol: Option<M::T>,
47        use_tstop: bool,
48        solve_for_sensitivities: bool,
49    ) -> Eqn::V
50    where
51        M: Matrix,
52        Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
53        Method: OdeSolverMethod<'a, Eqn>,
54    {
55        let have_root = method.problem().eqn.root().is_some();
56        for (i, point) in solution.solution_points.iter().enumerate() {
57            let (soln, sens_soln) = if use_tstop {
58                match method.set_stop_time(point.t) {
59                    Ok(_) => loop {
60                        match method.step() {
61                            Ok(OdeSolverStopReason::RootFound(_)) => {
62                                assert!(have_root);
63                                return method.state().y.clone();
64                            }
65                            Ok(OdeSolverStopReason::TstopReached) => {
66                                break (method.state().y.clone(), method.state().s.to_vec());
67                            }
68                            _ => (),
69                        }
70                    },
71                    Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
72                }
73            } else {
74                while method.state().t.abs() < point.t.abs() {
75                    if let OdeSolverStopReason::RootFound(t) = method.step().unwrap() {
76                        assert!(have_root);
77                        return method.interpolate(t).unwrap();
78                    }
79                }
80                let soln = method.interpolate(point.t).unwrap();
81                let sens_soln = method.interpolate_sens(point.t).unwrap();
82                (soln, sens_soln)
83            };
84            let soln = if let Some(out) = method.problem().eqn.out() {
85                out.call(&soln, point.t)
86            } else {
87                soln
88            };
89            assert_eq!(
90                soln.len(),
91                point.state.len(),
92                "soln.len() != point.state.len()"
93            );
94            if let Some(override_tol) = override_tol {
95                soln.assert_eq_st(&point.state, override_tol);
96            } else {
97                let (rtol, atol) = if method.problem().eqn.out().is_some() {
98                    // problem rtol and atol is on the state, so just use solution tolerance here
99                    (solution.rtol, &solution.atol)
100                } else {
101                    (method.problem().rtol, &method.problem().atol)
102                };
103                let error = soln.clone() - &point.state;
104                let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
105                assert!(
106                    error_norm < M::T::from_f64(20.0).unwrap(),
107                    "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
108                    error_norm,
109                    point.t,
110                    soln,
111                    point.state
112                );
113                if solve_for_sensitivities {
114                    if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
115                        for (j, sens_points) in sens_soln_points.iter().enumerate() {
116                            let sens_point = &sens_points[i];
117                            let sens_soln = &sens_soln[j];
118                            let error = sens_soln.clone() - &sens_point.state;
119                            let error_norm =
120                                error.squared_norm(&sens_point.state, atol, rtol).sqrt();
121                            assert!(
122                                error_norm < M::T::from_f64(29.0).unwrap(),
123                                "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
124                                point.t,
125                                sens_point.state
126                            );
127                        }
128                    }
129                }
130            }
131        }
132        method.state().y.clone()
133    }
134
135    pub fn setup_test_adjoint<'a, LS, Eqn>(
136        problem: &'a mut OdeSolverProblem<Eqn>,
137        soln: OdeSolverSolution<Eqn::V>,
138    ) -> <Eqn::V as DefaultDenseMatrix>::M
139    where
140        Eqn: OdeEquationsImplicitAdjoint + 'a,
141        LS: LinearSolver<Eqn::M>,
142        Eqn::V: DefaultDenseMatrix,
143        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
144        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
145    {
146        let nparams = problem.eqn.nparams();
147        let nout = problem.eqn.nout();
148        let ctx = problem.eqn.context();
149        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
150        let final_time = soln.solution_points.last().unwrap().t;
151        let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
152        problem.eqn.get_params(&mut p_0);
153        let h_base = Eqn::T::from_f64(1e-10).unwrap();
154        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
155        h.axpy(h_base, &p_0, Eqn::T::one());
156        let p_base = p_0.clone();
157        for i in 0..nparams {
158            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
159            problem.eqn.set_params(&p_0);
160            let g_pos = {
161                let mut s = problem.bdf::<LS>().unwrap();
162                s.set_stop_time(final_time).unwrap();
163                while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
164                s.state().g.clone()
165            };
166
167            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
168            problem.eqn.set_params(&p_0);
169            let g_neg = {
170                let mut s = problem.bdf::<LS>().unwrap();
171                s.set_stop_time(final_time).unwrap();
172                while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
173                s.state().g.clone()
174            };
175            p_0.set_index(i, p_base.get_index(i));
176
177            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
178            for j in 0..nout {
179                dgdp.set_index(i, j, delta.get_index(j));
180            }
181        }
182        problem.eqn.set_params(&p_base);
183        dgdp
184    }
185
186    /// sum_i^n (soln_i - data_i)^2
187    /// sum_i^n (soln_i - data_i)^4
188    pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
189    where
190        DM: DenseMatrix,
191    {
192        let mut ret = DM::V::zeros(2, soln.context().clone());
193        for j in 0..soln.ncols() {
194            let soln_j = soln.column(j);
195            let data_j = data.column(j);
196            let delta = soln_j - data_j;
197            ret.set_index(0, ret.get_index(0) + delta.norm(2).powi(2));
198            ret.set_index(1, ret.get_index(1) + delta.norm(4).powi(4));
199        }
200        ret
201    }
202
203    /// sum_i^n 2 * (soln_i - data_i)
204    /// sum_i^n 4 * (soln_i - data_i)^3
205    pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
206    where
207        DM: DenseMatrix,
208    {
209        let delta = soln.clone() - data;
210        let mut delta3 = delta.clone();
211        for j in 0..delta3.ncols() {
212            let delta_col = delta.column(j).into_owned();
213
214            let mut delta3_col = delta_col.clone();
215            delta3_col.component_mul_assign(&delta_col);
216            delta3_col.component_mul_assign(&delta_col);
217
218            delta3.column_mut(j).copy_from(&delta3_col);
219        }
220        let ret = vec![
221            delta * Scale(DM::T::from_f64(2.).unwrap()),
222            delta3 * Scale(DM::T::from_f64(4.).unwrap()),
223        ];
224        ret
225    }
226
227    pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
228        problem: &'a mut OdeSolverProblem<Eqn>,
229        times: &[Eqn::T],
230    ) -> (
231        <Eqn::V as DefaultDenseMatrix>::M,
232        <Eqn::V as DefaultDenseMatrix>::M,
233    )
234    where
235        Eqn: OdeEquationsImplicitAdjoint + 'a,
236        LS: LinearSolver<Eqn::M>,
237        Eqn::V: DefaultDenseMatrix,
238        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
239        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
240    {
241        let nparams = problem.eqn.nparams();
242        let nout = 2;
243        let ctx = problem.eqn.context();
244        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
245
246        let mut p_0 = ctx.vector_zeros(nparams);
247        problem.eqn.get_params(&mut p_0);
248        let h_base = Eqn::T::from_f64(1e-10).unwrap();
249        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
250        h.axpy(h_base, &p_0, Eqn::T::one());
251        let mut p_data = p_0.clone();
252        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
253        let p_base = p_0.clone();
254
255        problem.eqn.set_params(&p_data);
256        let data = {
257            let mut s = problem.bdf::<LS>().unwrap();
258            s.solve_dense(times).unwrap()
259        };
260
261        for i in 0..nparams {
262            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
263            problem.eqn.set_params(&p_0);
264            let g_pos = {
265                let mut s = problem.bdf::<LS>().unwrap();
266                let v = s.solve_dense(times).unwrap();
267                sum_squares(&v, &data)
268            };
269
270            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
271            problem.eqn.set_params(&p_0);
272            let g_neg = {
273                let mut s = problem.bdf::<LS>().unwrap();
274                let v = s.solve_dense(times).unwrap();
275                sum_squares(&v, &data)
276            };
277
278            p_0.set_index(i, p_base.get_index(i));
279
280            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
281            for j in 0..nout {
282                dgdp.set_index(i, j, delta.get_index(j));
283            }
284        }
285        problem.eqn.set_params(&p_base);
286        (dgdp, data)
287    }
288
289    pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
290        backwards_solver: SolverB,
291        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
292        forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
293        data: <Eqn::V as DefaultDenseMatrix>::M,
294        times: &[Eqn::T],
295    ) where
296        SolverF: OdeSolverMethod<'a, Eqn>,
297        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
298        Eqn: OdeEquationsImplicitAdjoint + 'a,
299        Eqn::V: DefaultDenseMatrix,
300        Eqn::M: DefaultSolver,
301    {
302        let nparams = dgdp_check.nrows();
303        let dgdu = dsum_squaresdp(&forwards_soln, &data);
304
305        let atol = Eqn::V::from_element(
306            nparams,
307            Eqn::T::from_f64(1e-6).unwrap(),
308            data.context().clone(),
309        );
310        let rtol = Eqn::T::from_f64(1e-6).unwrap();
311        let state = backwards_solver
312            .solve_adjoint_backwards_pass(times, dgdu.iter().collect::<Vec<_>>().as_slice())
313            .unwrap();
314        let gs_adj = state.into_common().sg;
315        #[allow(clippy::needless_range_loop)]
316        for j in 0..dgdp_check.ncols() {
317            gs_adj[j].assert_eq_norm(
318                &dgdp_check.column(j).into_owned(),
319                &atol,
320                rtol,
321                Eqn::T::from_f64(260.).unwrap(),
322            );
323        }
324    }
325
326    pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
327        backwards_solver: SolverB,
328        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
329    ) where
330        SolverF: OdeSolverMethod<'a, Eqn>,
331        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
332        Eqn: OdeEquationsImplicitAdjoint + 'a,
333        Eqn::V: DefaultDenseMatrix,
334        Eqn::M: DefaultSolver,
335    {
336        let nout = backwards_solver.problem().eqn.nout();
337        let atol = Eqn::V::from_element(
338            nout,
339            Eqn::T::from_f64(1e-6).unwrap(),
340            dgdp_check.context().clone(),
341        );
342        let rtol = Eqn::T::from_f64(1e-6).unwrap();
343        let state = backwards_solver
344            .solve_adjoint_backwards_pass(&[], &[])
345            .unwrap();
346        let gs_adj = state.into_common().sg;
347        #[allow(clippy::needless_range_loop)]
348        for j in 0..dgdp_check.ncols() {
349            gs_adj[j].assert_eq_norm(
350                &dgdp_check.column(j).into_owned(),
351                &atol,
352                rtol,
353                Eqn::T::from_f64(40.).unwrap(),
354            );
355        }
356    }
357
358    pub struct TestEqnInit<M: Matrix> {
359        ctx: M::C,
360    }
361
362    impl<M: Matrix> Op for TestEqnInit<M> {
363        type T = M::T;
364        type V = M::V;
365        type M = M;
366        type C = M::C;
367
368        fn nout(&self) -> usize {
369            1
370        }
371        fn nparams(&self) -> usize {
372            1
373        }
374        fn nstates(&self) -> usize {
375            1
376        }
377        fn context(&self) -> &Self::C {
378            &self.ctx
379        }
380    }
381
382    impl<M: Matrix> ConstantOp for TestEqnInit<M> {
383        fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
384            y.fill(M::T::one());
385        }
386    }
387
388    impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
389        fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
390            sens.fill(M::T::zero());
391        }
392    }
393
394    pub struct TestEqnRhs<M: Matrix> {
395        ctx: M::C,
396    }
397
398    impl<M: Matrix> Op for TestEqnRhs<M> {
399        type T = M::T;
400        type V = M::V;
401        type M = M;
402        type C = M::C;
403
404        fn nout(&self) -> usize {
405            1
406        }
407        fn nparams(&self) -> usize {
408            1
409        }
410        fn nstates(&self) -> usize {
411            1
412        }
413        fn context(&self) -> &Self::C {
414            &self.ctx
415        }
416    }
417
418    impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
419        fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
420            y.fill(M::T::zero());
421        }
422    }
423
424    impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
425        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
426            y.fill(M::T::zero());
427        }
428    }
429
430    impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
431        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
432            sens.fill(M::T::zero());
433        }
434    }
435
436    pub struct TestEqnOut<M: Matrix> {
437        ctx: M::C,
438    }
439
440    impl<M: Matrix> Op for TestEqnOut<M> {
441        type T = M::T;
442        type V = M::V;
443        type M = M;
444        type C = M::C;
445
446        fn nout(&self) -> usize {
447            1
448        }
449        fn nparams(&self) -> usize {
450            1
451        }
452        fn nstates(&self) -> usize {
453            1
454        }
455        fn context(&self) -> &Self::C {
456            &self.ctx
457        }
458    }
459
460    impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
461        fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
462            y.copy_from(x);
463        }
464    }
465
466    impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
467        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
468            y.copy_from(v);
469        }
470    }
471
472    impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
473        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
474            sens.fill(M::T::zero());
475        }
476    }
477
478    pub struct TestEqn<M: Matrix> {
479        rhs: Rc<TestEqnRhs<M>>,
480        init: Rc<TestEqnInit<M>>,
481        out: Rc<TestEqnOut<M>>,
482        ctx: M::C,
483    }
484
485    impl<M: Matrix> TestEqn<M> {
486        pub fn new() -> Self {
487            let ctx = M::C::default();
488            Self {
489                rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
490                init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
491                out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
492                ctx,
493            }
494        }
495    }
496
497    impl<M: Matrix> Op for TestEqn<M> {
498        type T = M::T;
499        type V = M::V;
500        type M = M;
501        type C = M::C;
502        fn nout(&self) -> usize {
503            1
504        }
505        fn nparams(&self) -> usize {
506            1
507        }
508        fn nstates(&self) -> usize {
509            1
510        }
511        fn statistics(&self) -> crate::op::OpStatistics {
512            OpStatistics::default()
513        }
514        fn context(&self) -> &Self::C {
515            &self.ctx
516        }
517    }
518
519    impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
520        type Rhs = &'a TestEqnRhs<M>;
521        type Mass = ParameterisedOp<'a, UnitCallable<M>>;
522        type Root = ParameterisedOp<'a, UnitCallable<M>>;
523        type Init = &'a TestEqnInit<M>;
524        type Out = &'a TestEqnOut<M>;
525    }
526
527    impl<M: Matrix> OdeEquations for TestEqn<M> {
528        fn rhs(&self) -> &TestEqnRhs<M> {
529            &self.rhs
530        }
531
532        fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
533            None
534        }
535
536        fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
537            None
538        }
539
540        fn init(&self) -> &TestEqnInit<M> {
541            &self.init
542        }
543
544        fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
545            Some(&self.out)
546        }
547        fn set_params(&mut self, _p: &Self::V) {
548            unimplemented!()
549        }
550        fn get_params(&self, _p: &mut Self::V) {
551            unimplemented!()
552        }
553    }
554
555    pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
556        let eqn = TestEqn::<M>::new();
557        let atol = eqn
558            .context()
559            .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
560        OdeSolverProblem::new(
561            eqn,
562            M::T::from_f64(1e-6).unwrap(),
563            atol,
564            None,
565            None,
566            None,
567            None,
568            None,
569            None,
570            M::T::zero(),
571            M::T::one(),
572            integrate_out,
573            Default::default(),
574            Default::default(),
575        )
576        .unwrap()
577    }
578
579    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
580        let state = s.checkpoint();
581        let integrating_sens = !s.state().s.is_empty();
582        let integrating_out = s.problem().integrate_out;
583        let t0 = state.as_ref().t;
584        let t1 = t0 + M::T::from_f64(1e6).unwrap();
585        s.interpolate(t0)
586            .unwrap()
587            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
588        assert!(s.interpolate(t1).is_err());
589        assert!(s.interpolate_out(t1).is_err());
590        if integrating_sens {
591            assert!(s.interpolate_sens(t1).is_err());
592        } else {
593            assert!(s.interpolate_sens(t0).is_ok());
594        }
595        s.step().unwrap();
596        let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
597        assert!(s.interpolate(s.state().t).is_ok());
598        assert!(s.interpolate(tmid).is_ok());
599        if integrating_out {
600            assert!(s.interpolate_out(s.state().t).is_ok());
601        } else {
602            assert!(s.interpolate_out(s.state().t).is_err());
603        }
604        assert!(s.interpolate_sens(s.state().t).is_ok());
605        assert!(s.interpolate(s.state().t + t1).is_err());
606        assert!(s.interpolate_out(s.state().t + t1).is_err());
607        if integrating_sens {
608            assert!(s.interpolate_sens(s.state().t + t1).is_err());
609        } else {
610            assert!(s.interpolate_sens(s.state().t + t1).is_ok());
611        }
612
613        let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
614        assert!(s
615            .interpolate_inplace(s.state().t, &mut y_wrong_length)
616            .is_err());
617        let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
618        assert!(s
619            .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
620            .is_err());
621        let mut s_wrong_length = vec![
622            M::V::zeros(1, s.problem().context().clone()),
623            M::V::zeros(1, s.problem().context().clone()),
624        ];
625        assert!(s
626            .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
627            .is_err());
628        let mut s_wrong_vec_length = if integrating_sens {
629            vec![M::V::zeros(2, s.problem().context().clone())]
630        } else {
631            vec![]
632        };
633        if integrating_sens {
634            assert!(s
635                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
636                .is_err());
637        } else {
638            assert!(s
639                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
640                .is_ok());
641        }
642
643        s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
644        assert!(s.interpolate(s.state().t).is_ok());
645        if integrating_out {
646            assert!(s.interpolate_out(s.state().t).is_ok());
647        }
648        if integrating_sens {
649            assert!(s.interpolate_sens(s.state().t).is_ok());
650        }
651        assert!(s.interpolate(tmid).is_err());
652        assert!(s.interpolate_out(tmid).is_err());
653        if integrating_sens {
654            assert!(s.interpolate_sens(tmid).is_err());
655        } else {
656            assert!(s.interpolate_sens(tmid).is_ok());
657        }
658    }
659
660    pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
661        mut s: Method,
662    ) {
663        *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
664        assert_eq!(
665            *s.config().as_base_ref().minimum_timestep,
666            Eqn::T::from_f64(1.0e8).unwrap()
667        );
668        // force a step size reduction
669        *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
670
671        let mut failed = false;
672        for _ in 0..10 {
673            if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
674                s.step()
675            {
676                failed = true;
677                break;
678            }
679        }
680        assert!(failed);
681    }
682
683    pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
684        let state = s.checkpoint();
685        let state2 = s.state();
686        state2
687            .y
688            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
689        s.state_mut()
690            .y
691            .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
692        assert_eq!(
693            s.state_mut().y.get_index(0),
694            M::T::from_f64(std::f64::consts::PI).unwrap()
695        );
696    }
697
698    #[cfg(feature = "diffsl-cranelift")]
699    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
700    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
701        crate::OdeBuilder::<M>::new()
702            .build_from_diffsl(
703                "
704            g { 9.81 } h { 10.0 }
705            u_i {
706                x = h,
707                v = 0,
708            }
709            F_i {
710                v,
711                -g,
712            }
713            stop {
714                x,
715            }
716        ",
717            )
718            .unwrap()
719    }
720
721    #[cfg(feature = "diffsl-cranelift")]
722    pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
723    where
724        M: crate::MatrixHost<T = f64>,
725        M: DefaultSolver<T = f64>,
726        M::V: DefaultDenseMatrix<T = f64>,
727        Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
728    {
729        let e = 0.8;
730
731        let final_time = 2.5;
732
733        // solve and apply the remaining doses
734        solver.set_stop_time(final_time).unwrap();
735        loop {
736            match solver.step() {
737                Ok(OdeSolverStopReason::InternalTimestep) => (),
738                Ok(OdeSolverStopReason::RootFound(t)) => {
739                    // get the state when the event occurred
740                    let mut y = solver.interpolate(t).unwrap();
741
742                    // update the velocity of the ball
743                    y.set_index(1, y.get_index(1) * -e);
744
745                    // make sure the ball is above the ground
746                    y.set_index(0, y.get_index(0).max(f64::EPSILON));
747
748                    // set the state to the updated state
749                    solver.state_mut().y.copy_from(&y);
750                    solver.state_mut().dy.set_index(0, y.get_index(1));
751                    *solver.state_mut().t = t;
752
753                    break;
754                }
755                Ok(OdeSolverStopReason::TstopReached) => break,
756                Err(_) => panic!("unexpected solver error"),
757            }
758        }
759        // do three more steps after the 1st bound and many sure they are correct
760        let mut x = vec![];
761        let mut v = vec![];
762        let mut t = vec![];
763        for _ in 0..3 {
764            let ret = solver.step();
765            x.push(solver.state().y.get_index(0));
766            v.push(solver.state().y.get_index(1));
767            t.push(solver.state().t);
768            match ret {
769                Ok(OdeSolverStopReason::InternalTimestep) => (),
770                Ok(OdeSolverStopReason::RootFound(_)) => {
771                    panic!("should be an internal timestep but found a root")
772                }
773                Ok(OdeSolverStopReason::TstopReached) => break,
774                _ => panic!("should be an internal timestep"),
775            }
776        }
777        (x, v, t)
778    }
779
780    pub fn test_checkpointing<'a, M, Method, Eqn>(
781        soln: OdeSolverSolution<M::V>,
782        mut solver1: Method,
783        mut solver2: Method,
784    ) where
785        M: Matrix + DefaultSolver,
786        Method: OdeSolverMethod<'a, Eqn>,
787        Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
788    {
789        let half_i = soln.solution_points.len() / 2;
790        let half_t = soln.solution_points[half_i].t;
791        while solver1.state().t <= half_t {
792            solver1.step().unwrap();
793        }
794        let checkpoint = solver1.checkpoint();
795        let checkpoint_t = checkpoint.as_ref().t;
796        solver2.set_state(checkpoint);
797
798        // carry on solving with both solvers, they should produce about the same results (probably might diverge a bit, but should always match the solution)
799        for point in soln.solution_points.iter().skip(half_i + 1) {
800            // point should be past checkpoint
801            if point.t < checkpoint_t {
802                continue;
803            }
804            while solver2.state().t < point.t {
805                solver1.step().unwrap();
806                solver2.step().unwrap();
807                let time_error = (solver1.state().t - solver2.state().t).abs()
808                    / (solver1.state().t.abs() * solver1.problem().rtol
809                        + solver1.problem().atol.get_index(0));
810                assert!(
811                    time_error < M::T::from_f64(20.0).unwrap(),
812                    "time_error: {} at t = {}",
813                    time_error,
814                    solver1.state().t
815                );
816                solver1.state().y.assert_eq_norm(
817                    solver2.state().y,
818                    &solver1.problem().atol,
819                    solver1.problem().rtol,
820                    M::T::from_f64(20.0).unwrap(),
821                );
822            }
823            let soln = solver1.interpolate(point.t).unwrap();
824            soln.assert_eq_norm(
825                &point.state,
826                &solver1.problem().atol,
827                solver1.problem().rtol,
828                M::T::from_f64(15.0).unwrap(),
829            );
830            let soln = solver2.interpolate(point.t).unwrap();
831            soln.assert_eq_norm(
832                &point.state,
833                &solver1.problem().atol,
834                solver1.problem().rtol,
835                M::T::from_f64(15.0).unwrap(),
836            );
837        }
838    }
839
840    pub fn test_state_mut_on_problem<'a, Eqn, Method>(
841        mut s: Method,
842        soln: OdeSolverSolution<Eqn::V>,
843    ) where
844        Eqn: OdeEquationsImplicit + 'a,
845        Method: OdeSolverMethod<'a, Eqn>,
846        Eqn::V: DefaultDenseMatrix,
847    {
848        // save state and solve for a little bit
849        let state = s.checkpoint();
850        s.solve(Eqn::T::one()).unwrap();
851
852        // reinit using state_mut
853        s.state_mut().y.copy_from(state.as_ref().y);
854        s.state_mut().dy.copy_from(state.as_ref().dy);
855        *s.state_mut().t = state.as_ref().t;
856
857        // solve and check against solution
858        for point in soln.solution_points.iter() {
859            while s.state().t < point.t {
860                s.step().unwrap();
861            }
862            let soln = s.interpolate(point.t).unwrap();
863            let error = soln.clone() - &point.state;
864            let error_norm = error
865                .squared_norm(&error, &s.problem().atol, s.problem().rtol)
866                .sqrt();
867            assert!(
868                error_norm < Eqn::T::from_f64(19.0).unwrap(),
869                "error_norm: {} at t = {}",
870                error_norm,
871                point.t
872            );
873        }
874    }
875}