diffsol/ode_solver/
mod.rs

1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod explicit_rk;
7pub mod jacobian_update;
8pub mod method;
9pub mod problem;
10pub mod runge_kutta;
11pub mod sde;
12pub mod sdirk;
13pub mod sdirk_state;
14pub mod sensitivities;
15pub mod state;
16pub mod tableau;
17
18#[cfg(test)]
19mod tests {
20    use std::rc::Rc;
21
22    use self::problem::OdeSolverSolution;
23    use nalgebra::ComplexField;
24
25    use super::*;
26    use crate::matrix::Matrix;
27    use crate::op::unit::UnitCallable;
28    use crate::op::ParameterisedOp;
29    use crate::{
30        op::OpStatistics, AdjointOdeSolverMethod, Context, DenseMatrix, MatrixCommon, MatrixRef,
31        NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit, OdeEquationsImplicitAdjoint,
32        OdeEquationsRef, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason,
33        Scale, VectorRef, VectorView, VectorViewMut,
34    };
35    use crate::{
36        ConstantOp, DefaultDenseMatrix, DefaultSolver, LinearSolver, NonLinearOp, Op, Vector,
37    };
38    use num_traits::{One, Zero};
39
40    pub fn test_ode_solver<'a, M, Eqn, Method>(
41        method: &mut Method,
42        solution: OdeSolverSolution<M::V>,
43        override_tol: Option<M::T>,
44        use_tstop: bool,
45        solve_for_sensitivities: bool,
46    ) -> Eqn::V
47    where
48        M: Matrix,
49        Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
50        Method: OdeSolverMethod<'a, Eqn>,
51    {
52        let have_root = method.problem().eqn.root().is_some();
53        for (i, point) in solution.solution_points.iter().enumerate() {
54            let (soln, sens_soln) = if use_tstop {
55                match method.set_stop_time(point.t) {
56                    Ok(_) => loop {
57                        match method.step() {
58                            Ok(OdeSolverStopReason::RootFound(_)) => {
59                                assert!(have_root);
60                                return method.state().y.clone();
61                            }
62                            Ok(OdeSolverStopReason::TstopReached) => {
63                                break (method.state().y.clone(), method.state().s.to_vec());
64                            }
65                            _ => (),
66                        }
67                    },
68                    Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
69                }
70            } else {
71                while method.state().t.abs() < point.t.abs() {
72                    if let OdeSolverStopReason::RootFound(t) = method.step().unwrap() {
73                        assert!(have_root);
74                        return method.interpolate(t).unwrap();
75                    }
76                }
77                let soln = method.interpolate(point.t).unwrap();
78                let sens_soln = method.interpolate_sens(point.t).unwrap();
79                (soln, sens_soln)
80            };
81            let soln = if let Some(out) = method.problem().eqn.out() {
82                out.call(&soln, point.t)
83            } else {
84                soln
85            };
86            assert_eq!(
87                soln.len(),
88                point.state.len(),
89                "soln.len() != point.state.len()"
90            );
91            if let Some(override_tol) = override_tol {
92                soln.assert_eq_st(&point.state, override_tol);
93            } else {
94                let (rtol, atol) = if method.problem().eqn.out().is_some() {
95                    // problem rtol and atol is on the state, so just use solution tolerance here
96                    (solution.rtol, &solution.atol)
97                } else {
98                    (method.problem().rtol, &method.problem().atol)
99                };
100                let error = soln.clone() - &point.state;
101                let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
102                assert!(
103                    error_norm < M::T::from(15.0),
104                    "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
105                    error_norm,
106                    point.t,
107                    soln,
108                    point.state
109                );
110                if solve_for_sensitivities {
111                    if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
112                        for (j, sens_points) in sens_soln_points.iter().enumerate() {
113                            let sens_point = &sens_points[i];
114                            let sens_soln = &sens_soln[j];
115                            let error = sens_soln.clone() - &sens_point.state;
116                            let error_norm =
117                                error.squared_norm(&sens_point.state, atol, rtol).sqrt();
118                            assert!(
119                                error_norm < M::T::from(29.0),
120                                "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
121                                point.t,
122                                sens_point.state
123                            );
124                        }
125                    }
126                }
127            }
128        }
129        method.state().y.clone()
130    }
131
132    pub fn setup_test_adjoint<'a, LS, Eqn>(
133        problem: &'a mut OdeSolverProblem<Eqn>,
134        soln: OdeSolverSolution<Eqn::V>,
135    ) -> <Eqn::V as DefaultDenseMatrix>::M
136    where
137        Eqn: OdeEquationsImplicitAdjoint + 'a,
138        LS: LinearSolver<Eqn::M>,
139        Eqn::V: DefaultDenseMatrix,
140        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
141        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
142    {
143        let nparams = problem.eqn.nparams();
144        let nout = problem.eqn.nout();
145        let ctx = problem.eqn.context();
146        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
147        let final_time = soln.solution_points.last().unwrap().t;
148        let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
149        problem.eqn.get_params(&mut p_0);
150        let h_base = Eqn::T::from(1e-10);
151        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
152        h.axpy(h_base, &p_0, Eqn::T::one());
153        let p_base = p_0.clone();
154        for i in 0..nparams {
155            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
156            problem.eqn.set_params(&p_0);
157            let mut s = problem.bdf::<LS>().unwrap();
158            s.set_stop_time(final_time).unwrap();
159            while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
160            let g_pos = s.state().g.clone();
161
162            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
163            problem.eqn.set_params(&p_0);
164            let mut s = problem.bdf::<LS>().unwrap();
165            s.set_stop_time(final_time).unwrap();
166            while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
167            let g_neg = s.state().g.clone();
168            p_0.set_index(i, p_base.get_index(i));
169
170            let delta = (g_pos - g_neg) / Scale(Eqn::T::from(2.) * h.get_index(i));
171            for j in 0..nout {
172                dgdp.set_index(i, j, delta.get_index(j));
173            }
174        }
175        problem.eqn.set_params(&p_base);
176        dgdp
177    }
178
179    /// sum_i^n (soln_i - data_i)^2
180    /// sum_i^n (soln_i - data_i)^4
181    pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
182    where
183        DM: DenseMatrix,
184    {
185        let mut ret = DM::V::zeros(2, soln.context().clone());
186        for j in 0..soln.ncols() {
187            let soln_j = soln.column(j);
188            let data_j = data.column(j);
189            let delta = soln_j - data_j;
190            ret.set_index(0, ret.get_index(0) + delta.norm(2).powi(2));
191            ret.set_index(1, ret.get_index(1) + delta.norm(4).powi(4));
192        }
193        ret
194    }
195
196    /// sum_i^n 2 * (soln_i - data_i)
197    /// sum_i^n 4 * (soln_i - data_i)^3
198    pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
199    where
200        DM: DenseMatrix,
201    {
202        let delta = soln.clone() - data;
203        let mut delta3 = delta.clone();
204        for j in 0..delta3.ncols() {
205            let delta_col = delta.column(j).into_owned();
206
207            let mut delta3_col = delta_col.clone();
208            delta3_col.component_mul_assign(&delta_col);
209            delta3_col.component_mul_assign(&delta_col);
210
211            delta3.column_mut(j).copy_from(&delta3_col);
212        }
213        let ret = vec![
214            delta * Scale(DM::T::from(2.)),
215            delta3 * Scale(DM::T::from(4.)),
216        ];
217        ret
218    }
219
220    pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
221        problem: &'a mut OdeSolverProblem<Eqn>,
222        times: &[Eqn::T],
223    ) -> (
224        <Eqn::V as DefaultDenseMatrix>::M,
225        <Eqn::V as DefaultDenseMatrix>::M,
226    )
227    where
228        Eqn: OdeEquationsImplicitAdjoint + 'a,
229        LS: LinearSolver<Eqn::M>,
230        Eqn::V: DefaultDenseMatrix,
231        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
232        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
233    {
234        let nparams = problem.eqn.nparams();
235        let nout = 2;
236        let ctx = problem.eqn.context();
237        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
238
239        let mut p_0 = ctx.vector_zeros(nparams);
240        problem.eqn.get_params(&mut p_0);
241        let h_base = Eqn::T::from(1e-10);
242        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
243        h.axpy(h_base, &p_0, Eqn::T::one());
244        let mut p_data = p_0.clone();
245        p_data.axpy(Eqn::T::from(0.1), &p_0, Eqn::T::one());
246        let p_base = p_0.clone();
247
248        problem.eqn.set_params(&p_data);
249        let mut s = problem.bdf::<LS>().unwrap();
250        let data = s.solve_dense(times).unwrap();
251
252        for i in 0..nparams {
253            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
254            problem.eqn.set_params(&p_0);
255            let mut s = problem.bdf::<LS>().unwrap();
256            let v = s.solve_dense(times).unwrap();
257            let g_pos = sum_squares(&v, &data);
258
259            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
260            problem.eqn.set_params(&p_0);
261            let mut s = problem.bdf::<LS>().unwrap();
262            let v = s.solve_dense(times).unwrap();
263            let g_neg = sum_squares(&v, &data);
264
265            p_0.set_index(i, p_base.get_index(i));
266
267            let delta = (g_pos - g_neg) / Scale(Eqn::T::from(2.) * h.get_index(i));
268            for j in 0..nout {
269                dgdp.set_index(i, j, delta.get_index(j));
270            }
271        }
272        problem.eqn.set_params(&p_base);
273        (dgdp, data)
274    }
275
276    pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
277        backwards_solver: SolverB,
278        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
279        forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
280        data: <Eqn::V as DefaultDenseMatrix>::M,
281        times: &[Eqn::T],
282    ) where
283        SolverF: OdeSolverMethod<'a, Eqn>,
284        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
285        Eqn: OdeEquationsImplicitAdjoint + 'a,
286        Eqn::V: DefaultDenseMatrix,
287        Eqn::M: DefaultSolver,
288    {
289        let nparams = dgdp_check.nrows();
290        let dgdu = dsum_squaresdp(&forwards_soln, &data);
291
292        let atol = Eqn::V::from_element(nparams, Eqn::T::from(1e-6), data.context().clone());
293        let rtol = Eqn::T::from(1e-6);
294        let state = backwards_solver
295            .solve_adjoint_backwards_pass(times, dgdu.iter().collect::<Vec<_>>().as_slice())
296            .unwrap();
297        let gs_adj = state.into_common().sg;
298        #[allow(clippy::needless_range_loop)]
299        for j in 0..dgdp_check.ncols() {
300            gs_adj[j].assert_eq_norm(
301                &dgdp_check.column(j).into_owned(),
302                &atol,
303                rtol,
304                Eqn::T::from(66.),
305            );
306        }
307    }
308
309    pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
310        backwards_solver: SolverB,
311        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
312    ) where
313        SolverF: OdeSolverMethod<'a, Eqn>,
314        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
315        Eqn: OdeEquationsImplicitAdjoint + 'a,
316        Eqn::V: DefaultDenseMatrix,
317        Eqn::M: DefaultSolver,
318    {
319        let nout = backwards_solver.problem().eqn.nout();
320        let atol = Eqn::V::from_element(nout, Eqn::T::from(1e-6), dgdp_check.context().clone());
321        let rtol = Eqn::T::from(1e-6);
322        let state = backwards_solver
323            .solve_adjoint_backwards_pass(&[], &[])
324            .unwrap();
325        let gs_adj = state.into_common().sg;
326        #[allow(clippy::needless_range_loop)]
327        for j in 0..dgdp_check.ncols() {
328            gs_adj[j].assert_eq_norm(
329                &dgdp_check.column(j).into_owned(),
330                &atol,
331                rtol,
332                Eqn::T::from(33.),
333            );
334        }
335    }
336
337    pub struct TestEqnInit<M: Matrix> {
338        ctx: M::C,
339    }
340
341    impl<M: Matrix> Op for TestEqnInit<M> {
342        type T = M::T;
343        type V = M::V;
344        type M = M;
345        type C = M::C;
346
347        fn nout(&self) -> usize {
348            1
349        }
350        fn nparams(&self) -> usize {
351            0
352        }
353        fn nstates(&self) -> usize {
354            1
355        }
356        fn context(&self) -> &Self::C {
357            &self.ctx
358        }
359    }
360
361    impl<M: Matrix> ConstantOp for TestEqnInit<M> {
362        fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
363            y.fill(M::T::one());
364        }
365    }
366
367    pub struct TestEqnRhs<M: Matrix> {
368        ctx: M::C,
369    }
370
371    impl<M: Matrix> Op for TestEqnRhs<M> {
372        type T = M::T;
373        type V = M::V;
374        type M = M;
375        type C = M::C;
376
377        fn nout(&self) -> usize {
378            1
379        }
380        fn nparams(&self) -> usize {
381            0
382        }
383        fn nstates(&self) -> usize {
384            1
385        }
386        fn context(&self) -> &Self::C {
387            &self.ctx
388        }
389    }
390
391    impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
392        fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
393            y.fill(M::T::zero());
394        }
395    }
396
397    impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
398        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
399            y.fill(M::T::zero());
400        }
401    }
402
403    pub struct TestEqn<M: Matrix> {
404        rhs: Rc<TestEqnRhs<M>>,
405        init: Rc<TestEqnInit<M>>,
406        ctx: M::C,
407    }
408
409    impl<M: Matrix> TestEqn<M> {
410        pub fn new() -> Self {
411            let ctx = M::C::default();
412            Self {
413                rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
414                init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
415                ctx,
416            }
417        }
418    }
419
420    impl<M: Matrix> Op for TestEqn<M> {
421        type T = M::T;
422        type V = M::V;
423        type M = M;
424        type C = M::C;
425        fn nout(&self) -> usize {
426            1
427        }
428        fn nparams(&self) -> usize {
429            0
430        }
431        fn nstates(&self) -> usize {
432            1
433        }
434        fn statistics(&self) -> crate::op::OpStatistics {
435            OpStatistics::default()
436        }
437        fn context(&self) -> &Self::C {
438            &self.ctx
439        }
440    }
441
442    impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
443        type Rhs = &'a TestEqnRhs<M>;
444        type Mass = ParameterisedOp<'a, UnitCallable<M>>;
445        type Root = ParameterisedOp<'a, UnitCallable<M>>;
446        type Init = &'a TestEqnInit<M>;
447        type Out = ParameterisedOp<'a, UnitCallable<M>>;
448    }
449
450    impl<M: Matrix> OdeEquations for TestEqn<M> {
451        fn rhs(&self) -> &TestEqnRhs<M> {
452            &self.rhs
453        }
454
455        fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
456            None
457        }
458
459        fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
460            None
461        }
462
463        fn init(&self) -> &TestEqnInit<M> {
464            &self.init
465        }
466
467        fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
468            None
469        }
470        fn set_params(&mut self, _p: &Self::V) {
471            unimplemented!()
472        }
473        fn get_params(&self, _p: &mut Self::V) {
474            unimplemented!()
475        }
476    }
477
478    pub fn test_problem<M: Matrix>() -> OdeSolverProblem<TestEqn<M>> {
479        let eqn = TestEqn::<M>::new();
480        let atol = eqn.context().vector_from_element(1, M::T::from(1e-6));
481        OdeSolverProblem::new(
482            eqn,
483            M::T::from(1e-6),
484            atol,
485            None,
486            None,
487            None,
488            None,
489            None,
490            None,
491            M::T::zero(),
492            M::T::one(),
493            false,
494        )
495        .unwrap()
496    }
497
498    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
499        let state = s.checkpoint();
500        let t0 = state.as_ref().t;
501        let t1 = t0 + M::T::from(1e6);
502        s.interpolate(t0)
503            .unwrap()
504            .assert_eq_st(state.as_ref().y, M::T::from(1e-9));
505        assert!(s.interpolate(t1).is_err());
506        s.step().unwrap();
507        assert!(s.interpolate(s.state().t).is_ok());
508        assert!(s.interpolate(s.state().t + t1).is_err());
509    }
510
511    pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
512        let state = s.checkpoint();
513        let state2 = s.state();
514        state2.y.assert_eq_st(state.as_ref().y, M::T::from(1e-9));
515        s.state_mut()
516            .y
517            .set_index(0, M::T::from(std::f64::consts::PI));
518        assert_eq!(
519            s.state_mut().y.get_index(0),
520            M::T::from(std::f64::consts::PI)
521        );
522    }
523
524    #[cfg(feature = "diffsl-cranelift")]
525    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
526    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
527        crate::OdeBuilder::<M>::new()
528            .build_from_diffsl(
529                "
530            g { 9.81 } h { 10.0 }
531            u_i {
532                x = h,
533                v = 0,
534            }
535            F_i {
536                v,
537                -g,
538            }
539            stop {
540                x,
541            }
542        ",
543            )
544            .unwrap()
545    }
546
547    #[cfg(feature = "diffsl-cranelift")]
548    pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
549    where
550        M: crate::MatrixHost<T = f64>,
551        M: DefaultSolver<T = f64>,
552        M::V: DefaultDenseMatrix<T = f64>,
553        Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
554    {
555        let e = 0.8;
556
557        let final_time = 2.5;
558
559        // solve and apply the remaining doses
560        solver.set_stop_time(final_time).unwrap();
561        loop {
562            match solver.step() {
563                Ok(OdeSolverStopReason::InternalTimestep) => (),
564                Ok(OdeSolverStopReason::RootFound(t)) => {
565                    // get the state when the event occurred
566                    let mut y = solver.interpolate(t).unwrap();
567
568                    // update the velocity of the ball
569                    y.set_index(1, y.get_index(1) * -e);
570
571                    // make sure the ball is above the ground
572                    y.set_index(0, y.get_index(0).max(f64::EPSILON));
573
574                    // set the state to the updated state
575                    solver.state_mut().y.copy_from(&y);
576                    solver.state_mut().dy.set_index(0, y.get_index(1));
577                    *solver.state_mut().t = t;
578
579                    break;
580                }
581                Ok(OdeSolverStopReason::TstopReached) => break,
582                Err(_) => panic!("unexpected solver error"),
583            }
584        }
585        // do three more steps after the 1st bound and many sure they are correct
586        let mut x = vec![];
587        let mut v = vec![];
588        let mut t = vec![];
589        for _ in 0..3 {
590            let ret = solver.step();
591            x.push(solver.state().y.get_index(0));
592            v.push(solver.state().y.get_index(1));
593            t.push(solver.state().t);
594            match ret {
595                Ok(OdeSolverStopReason::InternalTimestep) => (),
596                Ok(OdeSolverStopReason::RootFound(_)) => {
597                    panic!("should be an internal timestep but found a root")
598                }
599                Ok(OdeSolverStopReason::TstopReached) => break,
600                _ => panic!("should be an internal timestep"),
601            }
602        }
603        (x, v, t)
604    }
605
606    pub fn test_checkpointing<'a, M, Method, Eqn>(
607        soln: OdeSolverSolution<M::V>,
608        mut solver1: Method,
609        mut solver2: Method,
610    ) where
611        M: Matrix + DefaultSolver,
612        Method: OdeSolverMethod<'a, Eqn>,
613        Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
614    {
615        let half_i = soln.solution_points.len() / 2;
616        let half_t = soln.solution_points[half_i].t;
617        while solver1.state().t <= half_t {
618            solver1.step().unwrap();
619        }
620        let checkpoint = solver1.checkpoint();
621        let checkpoint_t = checkpoint.as_ref().t;
622        solver2.set_state(checkpoint);
623
624        // carry on solving with both solvers, they should produce about the same results (probably might diverge a bit, but should always match the solution)
625        for point in soln.solution_points.iter().skip(half_i + 1) {
626            // point should be past checkpoint
627            if point.t < checkpoint_t {
628                continue;
629            }
630            while solver2.state().t < point.t {
631                solver1.step().unwrap();
632                solver2.step().unwrap();
633                let time_error = (solver1.state().t - solver2.state().t).abs()
634                    / (solver1.state().t.abs() * solver1.problem().rtol
635                        + solver1.problem().atol.get_index(0));
636                assert!(
637                    time_error < M::T::from(20.0),
638                    "time_error: {} at t = {}",
639                    time_error,
640                    solver1.state().t
641                );
642                solver1.state().y.assert_eq_norm(
643                    solver2.state().y,
644                    &solver1.problem().atol,
645                    solver1.problem().rtol,
646                    M::T::from(20.0),
647                );
648            }
649            let soln = solver1.interpolate(point.t).unwrap();
650            soln.assert_eq_norm(
651                &point.state,
652                &solver1.problem().atol,
653                solver1.problem().rtol,
654                M::T::from(15.0),
655            );
656            let soln = solver2.interpolate(point.t).unwrap();
657            soln.assert_eq_norm(
658                &point.state,
659                &solver1.problem().atol,
660                solver1.problem().rtol,
661                M::T::from(15.0),
662            );
663        }
664    }
665
666    pub fn test_state_mut_on_problem<'a, Eqn, Method>(
667        mut s: Method,
668        soln: OdeSolverSolution<Eqn::V>,
669    ) where
670        Eqn: OdeEquationsImplicit + 'a,
671        Method: OdeSolverMethod<'a, Eqn>,
672        Eqn::V: DefaultDenseMatrix,
673    {
674        // save state and solve for a little bit
675        let state = s.checkpoint();
676        s.solve(Eqn::T::from(1.0)).unwrap();
677
678        // reinit using state_mut
679        s.state_mut().y.copy_from(state.as_ref().y);
680        s.state_mut().dy.copy_from(state.as_ref().dy);
681        *s.state_mut().t = state.as_ref().t;
682
683        // solve and check against solution
684        for point in soln.solution_points.iter() {
685            while s.state().t < point.t {
686                s.step().unwrap();
687            }
688            let soln = s.interpolate(point.t).unwrap();
689            let error = soln.clone() - &point.state;
690            let error_norm = error
691                .squared_norm(&error, &s.problem().atol, s.problem().rtol)
692                .sqrt();
693            assert!(
694                error_norm < Eqn::T::from(19.0),
695                "error_norm: {} at t = {}",
696                error_norm,
697                point.t
698            );
699        }
700    }
701}