Skip to main content

diffsol/ode_solver/
mod.rs

1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod no_checkpointing_solver;
11pub mod problem;
12pub mod runge_kutta;
13pub mod sde;
14pub mod sdirk;
15pub mod sdirk_state;
16pub mod sensitivities;
17pub mod solution;
18pub mod state;
19pub mod tableau;
20
21#[cfg(test)]
22mod tests {
23    use std::rc::Rc;
24
25    use self::problem::OdeSolverSolution;
26
27    use super::*;
28    use crate::error::{DiffsolError, OdeSolverError};
29    use crate::matrix::Matrix;
30    use crate::ode_solver::sensitivities::SensitivitiesOdeSolverMethod;
31    use crate::ode_solver::solution::Solution;
32    use crate::op::unit::UnitCallable;
33    use crate::op::ParameterisedOp;
34    use crate::Scalar;
35    use crate::{
36        op::OpStatistics, AdjointEquations, AdjointOdeSolverMethod, Context, DenseMatrix,
37        MatrixCommon, MatrixRef, NonLinearOp, NonLinearOpJacobian, OdeEquations,
38        OdeEquationsImplicit, OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens,
39        OdeEquationsRef, OdeSolverConfig, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
40        OdeSolverStopReason, Scale, VectorRef, VectorView, VectorViewMut,
41    };
42    use crate::{
43        ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver,
44        NonLinearOpSens, Op, Vector,
45    };
46    use num_traits::{FromPrimitive, One, Signed, ToPrimitive, Zero};
47
48    pub fn test_ode_solver<'a, M, Eqn, Method>(
49        method: &mut Method,
50        solution: OdeSolverSolution<M::V>,
51        override_tol: Option<M::T>,
52        use_tstop: bool,
53        solve_for_sensitivities: bool,
54    ) -> Eqn::V
55    where
56        M: Matrix,
57        Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
58        Method: OdeSolverMethod<'a, Eqn>,
59    {
60        let have_root = method.problem().eqn.root().is_some();
61        for (i, point) in solution.solution_points.iter().enumerate() {
62            let (soln, sens_soln) = if use_tstop {
63                match method.set_stop_time(point.t) {
64                    Ok(_) => loop {
65                        match method.step() {
66                            Ok(OdeSolverStopReason::RootFound(_, _)) => {
67                                assert!(have_root);
68                                return method.state().y.clone();
69                            }
70                            Ok(OdeSolverStopReason::TstopReached) => {
71                                break (method.state().y.clone(), method.state().s.to_vec());
72                            }
73                            _ => (),
74                        }
75                    },
76                    Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
77                }
78            } else {
79                while method.state().t.abs() < point.t.abs() {
80                    if let OdeSolverStopReason::RootFound(t, _) = method.step().unwrap() {
81                        assert!(have_root);
82                        return method.interpolate(t).unwrap();
83                    }
84                }
85                let soln = method.interpolate(point.t).unwrap();
86                let sens_soln = method.interpolate_sens(point.t).unwrap();
87                (soln, sens_soln)
88            };
89            let soln = if let Some(out) = method.problem().eqn.out() {
90                out.call(&soln, point.t)
91            } else {
92                soln
93            };
94            assert_eq!(
95                soln.len(),
96                point.state.len(),
97                "soln.len() != point.state.len()"
98            );
99            if let Some(override_tol) = override_tol {
100                soln.assert_eq_st(&point.state, override_tol);
101            } else {
102                let (rtol, atol) = if method.problem().eqn.out().is_some() {
103                    // problem rtol and atol is on the state, so just use solution tolerance here
104                    (solution.rtol, &solution.atol)
105                } else {
106                    (method.problem().rtol, &method.problem().atol)
107                };
108                let error = soln.clone() - &point.state;
109                let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
110                assert!(
111                    error_norm < M::T::from_f64(20.0).unwrap(),
112                    "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
113                    error_norm,
114                    point.t,
115                    soln,
116                    point.state
117                );
118                if solve_for_sensitivities {
119                    if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
120                        for (j, sens_points) in sens_soln_points.iter().enumerate() {
121                            let sens_point = &sens_points[i];
122                            let sens_soln = &sens_soln[j];
123                            let error = sens_soln.clone() - &sens_point.state;
124                            let error_norm =
125                                error.squared_norm(&sens_point.state, atol, rtol).sqrt();
126                            assert!(
127                                error_norm < M::T::from_f64(29.0).unwrap(),
128                                "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
129                                point.t,
130                                sens_point.state
131                            );
132                        }
133                    }
134                }
135            }
136        }
137        method.state().y.clone()
138    }
139
140    pub fn setup_test_adjoint<'a, LS, Eqn>(
141        problem: &'a mut OdeSolverProblem<Eqn>,
142        soln: OdeSolverSolution<Eqn::V>,
143    ) -> <Eqn::V as DefaultDenseMatrix>::M
144    where
145        Eqn: OdeEquationsImplicitAdjoint + 'a,
146        LS: LinearSolver<Eqn::M>,
147        Eqn::M: DefaultSolver,
148        Eqn::V: DefaultDenseMatrix,
149        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
150        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
151    {
152        let nparams = problem.eqn.nparams();
153        let nout = problem.eqn.nout();
154        let ctx = problem.eqn.context();
155        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
156        let final_time = soln.solution_points.last().unwrap().t;
157        let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
158        problem.eqn.get_params(&mut p_0);
159        let h_base = Eqn::T::from_f64(1e-10).unwrap();
160        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
161        h.axpy(h_base, &p_0, Eqn::T::one());
162        let p_base = p_0.clone();
163        for i in 0..nparams {
164            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
165            problem.eqn.set_params(&p_0);
166            let g_pos = {
167                let mut s = problem.bdf::<LS>().unwrap();
168                s.set_stop_time(final_time).unwrap();
169                while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
170                s.state().g.clone()
171            };
172
173            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
174            problem.eqn.set_params(&p_0);
175            let g_neg = {
176                let mut s = problem.bdf::<LS>().unwrap();
177                s.set_stop_time(final_time).unwrap();
178                while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
179                s.state().g.clone()
180            };
181            p_0.set_index(i, p_base.get_index(i));
182
183            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
184            for j in 0..nout {
185                dgdp.set_index(i, j, delta.get_index(j));
186            }
187        }
188        problem.eqn.set_params(&p_base);
189        dgdp
190    }
191
192    /// sum_i^n (soln_i - data_i)^2
193    /// sum_i^n (soln_i - data_i)^4
194    pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
195    where
196        DM: DenseMatrix,
197    {
198        let mut ret = DM::V::zeros(2, soln.context().clone());
199        for j in 0..soln.ncols() {
200            let soln_j = soln.column(j);
201            let data_j = data.column(j);
202            let delta = soln_j - data_j;
203            let norm2 = delta.norm(2);
204            ret.set_index(0, ret.get_index(0) + norm2 * norm2);
205            let norm4 = delta.norm(4);
206            let norm4_sq = norm4 * norm4;
207            ret.set_index(1, ret.get_index(1) + norm4_sq * norm4_sq);
208        }
209        ret
210    }
211
212    /// sum_i^n 2 * (soln_i - data_i)
213    /// sum_i^n 4 * (soln_i - data_i)^3
214    pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
215    where
216        DM: DenseMatrix,
217    {
218        let delta = soln.clone() - data;
219        let mut delta3 = delta.clone();
220        for j in 0..delta3.ncols() {
221            let delta_col = delta.column(j).into_owned();
222
223            let mut delta3_col = delta_col.clone();
224            delta3_col.component_mul_assign(&delta_col);
225            delta3_col.component_mul_assign(&delta_col);
226
227            delta3.column_mut(j).copy_from(&delta3_col);
228        }
229        let ret = vec![
230            delta * Scale(DM::T::from_f64(2.).unwrap()),
231            delta3 * Scale(DM::T::from_f64(4.).unwrap()),
232        ];
233        ret
234    }
235
236    pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
237        problem: &'a mut OdeSolverProblem<Eqn>,
238        times: &[Eqn::T],
239    ) -> (
240        <Eqn::V as DefaultDenseMatrix>::M,
241        <Eqn::V as DefaultDenseMatrix>::M,
242    )
243    where
244        Eqn: OdeEquationsImplicitAdjoint + 'a,
245        LS: LinearSolver<Eqn::M>,
246        Eqn::M: DefaultSolver,
247        Eqn::V: DefaultDenseMatrix,
248        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
249        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
250    {
251        let nparams = problem.eqn.nparams();
252        let nout = 2;
253        let ctx = problem.eqn.context();
254        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
255
256        let mut p_0 = ctx.vector_zeros(nparams);
257        problem.eqn.get_params(&mut p_0);
258        let h_base = Eqn::T::from_f64(1e-10).unwrap();
259        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
260        h.axpy(h_base, &p_0, Eqn::T::one());
261        let mut p_data = p_0.clone();
262        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
263        let p_base = p_0.clone();
264
265        problem.eqn.set_params(&p_data);
266        let data = {
267            let mut s = problem.bdf::<LS>().unwrap();
268            s.solve_dense(times).unwrap().0
269        };
270
271        for i in 0..nparams {
272            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
273            problem.eqn.set_params(&p_0);
274            let g_pos = {
275                let mut s = problem.bdf::<LS>().unwrap();
276                let v = s.solve_dense(times).unwrap().0;
277                sum_squares(&v, &data)
278            };
279
280            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
281            problem.eqn.set_params(&p_0);
282            let g_neg = {
283                let mut s = problem.bdf::<LS>().unwrap();
284                let v = s.solve_dense(times).unwrap().0;
285                sum_squares(&v, &data)
286            };
287
288            p_0.set_index(i, p_base.get_index(i));
289
290            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
291            for j in 0..nout {
292                dgdp.set_index(i, j, delta.get_index(j));
293            }
294        }
295        problem.eqn.set_params(&p_base);
296        (dgdp, data)
297    }
298
299    pub fn single_reset_root_discrete_times<T: Scalar>(t_stop: T) -> Vec<T> {
300        let t_root = t_stop / T::from_f64(2.0).unwrap();
301        [0.25, 0.75, 1.25, 1.75]
302            .into_iter()
303            .map(|factor| t_root * T::from_f64(factor).unwrap())
304            .collect()
305    }
306
307    fn solve_dense_with_single_reset_root<'a, Eqn, Method, BuildForward>(
308        build_forward: BuildForward,
309        times: &[Eqn::T],
310    ) -> <Eqn::V as DefaultDenseMatrix>::M
311    where
312        Eqn: OdeEquationsImplicitAdjoint + 'a,
313        Eqn::M: DefaultSolver,
314        Eqn::V: DefaultDenseMatrix,
315        Method: OdeSolverMethod<'a, Eqn>,
316        BuildForward: Fn(Option<Method::State>) -> Result<Method, DiffsolError>,
317    {
318        let mut soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
319        let first_forward_solver = build_forward(None).unwrap().solve_soln(&mut soln).unwrap();
320        match soln.stop_reason {
321            Some(OdeSolverStopReason::RootFound(_, 0)) => {}
322            Some(OdeSolverStopReason::RootFound(_, idx)) => {
323                panic!("expected first solve_soln() segment to stop on root 0, got root {idx}")
324            }
325            Some(OdeSolverStopReason::TstopReached) => {
326                panic!("expected first solve_soln() segment to stop on the interior root")
327            }
328            Some(OdeSolverStopReason::InternalTimestep) | None => {
329                panic!("first solve_soln() segment did not finish with a terminal stop reason")
330            }
331        }
332
333        let mut state_after_reset = first_forward_solver.state_clone();
334        {
335            let problem = first_forward_solver.problem();
336            state_after_reset
337                .as_mut()
338                .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
339                .unwrap();
340        }
341
342        build_forward(Some(state_after_reset))
343            .unwrap()
344            .solve_soln(&mut soln)
345            .unwrap();
346        assert!(
347            soln.is_complete(),
348            "expected stitched solve_soln() output to cover all requested observation times",
349        );
350        soln.ys
351    }
352
353    fn state_after_manual_reset<'a, Eqn, Method>(solver: &Method) -> Method::State
354    where
355        Eqn: OdeEquationsImplicitAdjoint + 'a,
356        Eqn::M: DefaultSolver,
357        Method: OdeSolverMethod<'a, Eqn>,
358    {
359        let mut state_after_reset = solver.state_clone();
360        {
361            let problem = solver.problem();
362            state_after_reset
363                .as_mut()
364                .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
365                .unwrap();
366        }
367        state_after_reset
368    }
369
370    pub fn setup_test_adjoint_sum_squares_with_single_reset_root<'a, LS, Eqn>(
371        problem: &'a mut OdeSolverProblem<Eqn>,
372        times: &[Eqn::T],
373    ) -> (
374        <Eqn::V as DefaultDenseMatrix>::M,
375        <Eqn::V as DefaultDenseMatrix>::M,
376    )
377    where
378        Eqn: OdeEquationsImplicitAdjoint + 'a,
379        LS: LinearSolver<Eqn::M>,
380        Eqn::M: DefaultSolver,
381        Eqn::V: DefaultDenseMatrix,
382        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
383        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
384    {
385        let nparams = problem.eqn.nparams();
386        let nout = 2;
387        let ctx = problem.eqn.context();
388        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
389
390        let mut p_0 = ctx.vector_zeros(nparams);
391        problem.eqn.get_params(&mut p_0);
392        let h_base = Eqn::T::from_f64(1e-10).unwrap();
393        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
394        h.axpy(h_base, &p_0, Eqn::T::one());
395        let mut p_data = p_0.clone();
396        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
397        let p_base = p_0.clone();
398
399        problem.eqn.set_params(&p_data);
400        let data = solve_dense_with_single_reset_root::<Eqn, _, _>(
401            |state| match state {
402                Some(state) => problem.bdf_solver(state),
403                None => problem.bdf::<LS>(),
404            },
405            times,
406        );
407
408        for i in 0..nparams {
409            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
410            problem.eqn.set_params(&p_0);
411            let g_pos = {
412                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
413                    |state| match state {
414                        Some(state) => problem.bdf_solver(state),
415                        None => problem.bdf::<LS>(),
416                    },
417                    times,
418                );
419                sum_squares(&v, &data)
420            };
421
422            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
423            problem.eqn.set_params(&p_0);
424            let g_neg = {
425                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
426                    |state| match state {
427                        Some(state) => problem.bdf_solver(state),
428                        None => problem.bdf::<LS>(),
429                    },
430                    times,
431                );
432                sum_squares(&v, &data)
433            };
434
435            p_0.set_index(i, p_base.get_index(i));
436
437            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
438            for j in 0..nout {
439                dgdp.set_index(i, j, delta.get_index(j));
440            }
441        }
442        problem.eqn.set_params(&p_base);
443        (dgdp, data)
444    }
445
446    pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
447        backwards_solver: SolverB,
448        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
449        forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
450        data: <Eqn::V as DefaultDenseMatrix>::M,
451        times: &[Eqn::T],
452    ) where
453        SolverF: OdeSolverMethod<'a, Eqn>,
454        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
455        Eqn: OdeEquationsImplicitAdjoint + 'a,
456        Eqn::V: DefaultDenseMatrix,
457        Eqn::M: DefaultSolver,
458    {
459        let nparams = dgdp_check.nrows();
460        let dgdu = dsum_squaresdp(&forwards_soln, &data);
461
462        let atol = Eqn::V::from_element(
463            nparams,
464            Eqn::T::from_f64(1e-6).unwrap(),
465            data.context().clone(),
466        );
467        let rtol = Eqn::T::from_f64(1e-6).unwrap();
468        let (state, _) = backwards_solver
469            .solve_adjoint_backwards_pass(times, dgdu.iter().collect::<Vec<_>>().as_slice())
470            .unwrap();
471        let gs_adj = state.into_common().sg;
472        #[allow(clippy::needless_range_loop)]
473        for j in 0..dgdp_check.ncols() {
474            gs_adj[j].assert_eq_norm(
475                &dgdp_check.column(j).into_owned(),
476                &atol,
477                rtol,
478                Eqn::T::from_f64(260.).unwrap(),
479            );
480        }
481    }
482
483    pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
484        backwards_solver: SolverB,
485        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
486    ) where
487        SolverF: OdeSolverMethod<'a, Eqn>,
488        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
489        Eqn: OdeEquationsImplicitAdjoint + 'a,
490        Eqn::V: DefaultDenseMatrix,
491        Eqn::M: DefaultSolver,
492    {
493        let nout = backwards_solver.problem().eqn.nout();
494        let atol = Eqn::V::from_element(
495            nout,
496            Eqn::T::from_f64(1e-6).unwrap(),
497            dgdp_check.context().clone(),
498        );
499        let rtol = Eqn::T::from_f64(1e-6).unwrap();
500        let (state, _) = backwards_solver
501            .solve_adjoint_backwards_pass(&[], &[])
502            .unwrap();
503        let gs_adj = state.into_common().sg;
504        #[allow(clippy::needless_range_loop)]
505        for j in 0..dgdp_check.ncols() {
506            gs_adj[j].assert_eq_norm(
507                &dgdp_check.column(j).into_owned(),
508                &atol,
509                rtol,
510                Eqn::T::from_f64(40.).unwrap(),
511            );
512        }
513    }
514
515    pub struct TestEqnInit<M: Matrix> {
516        ctx: M::C,
517    }
518
519    impl<M: Matrix> Op for TestEqnInit<M> {
520        type T = M::T;
521        type V = M::V;
522        type M = M;
523        type C = M::C;
524
525        fn nout(&self) -> usize {
526            1
527        }
528        fn nparams(&self) -> usize {
529            1
530        }
531        fn nstates(&self) -> usize {
532            1
533        }
534        fn context(&self) -> &Self::C {
535            &self.ctx
536        }
537    }
538
539    impl<M: Matrix> ConstantOp for TestEqnInit<M> {
540        fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
541            y.fill(M::T::one());
542        }
543    }
544
545    impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
546        fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
547            sens.fill(M::T::zero());
548        }
549    }
550
551    pub struct TestEqnRhs<M: Matrix> {
552        ctx: M::C,
553    }
554
555    impl<M: Matrix> Op for TestEqnRhs<M> {
556        type T = M::T;
557        type V = M::V;
558        type M = M;
559        type C = M::C;
560
561        fn nout(&self) -> usize {
562            1
563        }
564        fn nparams(&self) -> usize {
565            1
566        }
567        fn nstates(&self) -> usize {
568            1
569        }
570        fn context(&self) -> &Self::C {
571            &self.ctx
572        }
573    }
574
575    impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
576        fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
577            y.fill(M::T::zero());
578        }
579    }
580
581    impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
582        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
583            y.fill(M::T::zero());
584        }
585    }
586
587    impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
588        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
589            sens.fill(M::T::zero());
590        }
591    }
592
593    pub struct TestEqnOut<M: Matrix> {
594        ctx: M::C,
595    }
596
597    impl<M: Matrix> Op for TestEqnOut<M> {
598        type T = M::T;
599        type V = M::V;
600        type M = M;
601        type C = M::C;
602
603        fn nout(&self) -> usize {
604            1
605        }
606        fn nparams(&self) -> usize {
607            1
608        }
609        fn nstates(&self) -> usize {
610            1
611        }
612        fn context(&self) -> &Self::C {
613            &self.ctx
614        }
615    }
616
617    impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
618        fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
619            y.copy_from(x);
620        }
621    }
622
623    impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
624        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
625            y.copy_from(v);
626        }
627    }
628
629    impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
630        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
631            sens.fill(M::T::zero());
632        }
633    }
634
635    pub struct TestEqn<M: Matrix> {
636        rhs: Rc<TestEqnRhs<M>>,
637        init: Rc<TestEqnInit<M>>,
638        out: Rc<TestEqnOut<M>>,
639        ctx: M::C,
640    }
641
642    impl<M: Matrix> TestEqn<M> {
643        pub fn new() -> Self {
644            let ctx = M::C::default();
645            Self {
646                rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
647                init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
648                out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
649                ctx,
650            }
651        }
652    }
653
654    impl<M: Matrix> Op for TestEqn<M> {
655        type T = M::T;
656        type V = M::V;
657        type M = M;
658        type C = M::C;
659        fn nout(&self) -> usize {
660            1
661        }
662        fn nparams(&self) -> usize {
663            1
664        }
665        fn nstates(&self) -> usize {
666            1
667        }
668        fn statistics(&self) -> crate::op::OpStatistics {
669            OpStatistics::default()
670        }
671        fn context(&self) -> &Self::C {
672            &self.ctx
673        }
674    }
675
676    impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
677        type Rhs = &'a TestEqnRhs<M>;
678        type Mass = ParameterisedOp<'a, UnitCallable<M>>;
679        type Root = ParameterisedOp<'a, UnitCallable<M>>;
680        type Init = &'a TestEqnInit<M>;
681        type Out = &'a TestEqnOut<M>;
682        type Reset = ParameterisedOp<'a, UnitCallable<M>>;
683    }
684
685    impl<M: Matrix> OdeEquations for TestEqn<M> {
686        fn rhs(&self) -> &TestEqnRhs<M> {
687            &self.rhs
688        }
689
690        fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
691            None
692        }
693
694        fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
695            None
696        }
697
698        fn init(&self) -> &TestEqnInit<M> {
699            &self.init
700        }
701
702        fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
703            Some(&self.out)
704        }
705        fn set_params(&mut self, _p: &Self::V) {
706            unimplemented!()
707        }
708        fn get_params(&self, _p: &mut Self::V) {
709            unimplemented!()
710        }
711    }
712
713    pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
714        let eqn = TestEqn::<M>::new();
715        let atol = eqn
716            .context()
717            .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
718        OdeSolverProblem::new(
719            eqn,
720            M::T::from_f64(1e-6).unwrap(),
721            atol,
722            None,
723            None,
724            None,
725            None,
726            None,
727            None,
728            M::T::zero(),
729            M::T::one(),
730            integrate_out,
731            Default::default(),
732            Default::default(),
733        )
734        .unwrap()
735    }
736
737    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
738        let state = s.checkpoint();
739        let integrating_sens = !s.state().s.is_empty();
740        let integrating_out = s.problem().integrate_out;
741        let t0 = state.as_ref().t;
742        let t1 = t0 + M::T::from_f64(1e6).unwrap();
743        s.interpolate(t0)
744            .unwrap()
745            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
746        assert!(s.interpolate(t1).is_err());
747        assert!(s.interpolate_out(t1).is_err());
748        if integrating_sens {
749            assert!(s.interpolate_sens(t1).is_err());
750        } else {
751            assert!(s.interpolate_sens(t0).is_ok());
752        }
753        s.step().unwrap();
754        let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
755        assert!(s.interpolate(s.state().t).is_ok());
756        assert!(s.interpolate(tmid).is_ok());
757        if integrating_out {
758            assert!(s.interpolate_out(s.state().t).is_ok());
759        } else {
760            assert!(s.interpolate_out(s.state().t).is_err());
761        }
762        assert!(s.interpolate_sens(s.state().t).is_ok());
763        assert!(s.interpolate(s.state().t + t1).is_err());
764        assert!(s.interpolate_out(s.state().t + t1).is_err());
765        if integrating_sens {
766            assert!(s.interpolate_sens(s.state().t + t1).is_err());
767        } else {
768            assert!(s.interpolate_sens(s.state().t + t1).is_ok());
769        }
770
771        let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
772        assert!(s
773            .interpolate_inplace(s.state().t, &mut y_wrong_length)
774            .is_err());
775        let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
776        assert!(s
777            .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
778            .is_err());
779        let mut s_wrong_length = vec![
780            M::V::zeros(1, s.problem().context().clone()),
781            M::V::zeros(1, s.problem().context().clone()),
782        ];
783        assert!(s
784            .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
785            .is_err());
786        let mut s_wrong_vec_length = if integrating_sens {
787            vec![M::V::zeros(2, s.problem().context().clone())]
788        } else {
789            vec![]
790        };
791        if integrating_sens {
792            assert!(s
793                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
794                .is_err());
795        } else {
796            assert!(s
797                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
798                .is_ok());
799        }
800
801        s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
802        assert!(s.interpolate(s.state().t).is_ok());
803        if integrating_out {
804            assert!(s.interpolate_out(s.state().t).is_ok());
805        }
806        if integrating_sens {
807            assert!(s.interpolate_sens(s.state().t).is_ok());
808        }
809        assert!(s.interpolate(tmid).is_err());
810        assert!(s.interpolate_out(tmid).is_err());
811        if integrating_sens {
812            assert!(s.interpolate_sens(tmid).is_err());
813        } else {
814            assert!(s.interpolate_sens(tmid).is_ok());
815        }
816    }
817
818    pub fn test_interpolate_dy<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(
819        mut s: Method,
820    ) {
821        // Error before first step: t is in the future
822        let t_future = s.state().t + M::T::from_f64(1e6).unwrap();
823        assert!(s.interpolate_dy(t_future).is_err());
824
825        let t0 = s.state().t;
826        s.step().unwrap();
827        let t1 = s.state().t;
828        let dt = t1 - t0;
829        let tmid = t0 + dt / M::T::from_f64(2.0).unwrap();
830
831        // Wrong vector length should return error
832        let mut dy_wrong = M::V::zeros(2, s.problem().context().clone());
833        assert!(s.interpolate_dy_inplace(t1, &mut dy_wrong).is_err());
834
835        // t after current time should return error
836        assert!(s.interpolate_dy(t1 + M::T::from_f64(1e6).unwrap()).is_err());
837
838        // interpolate_dy should be consistent with finite-difference of interpolate (step 1)
839        let eps = dt.abs() * M::T::from_f64(1e-5).unwrap();
840        let y_plus = s.interpolate(tmid + eps).unwrap();
841        let y_minus = s.interpolate(tmid - eps).unwrap();
842        let fd_dy = (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps));
843        let dy = s.interpolate_dy(tmid).unwrap();
844        dy.assert_eq_norm(
845            &fd_dy,
846            &s.problem().atol,
847            s.problem().rtol,
848            M::T::from_f64(1e3).unwrap(),
849        );
850
851        // take a second step and check consistency again
852        let t1 = s.state().t;
853        s.step().unwrap();
854        let t2 = s.state().t;
855        let dt2 = t2 - t1;
856        let tmid2 = t1 + dt2 / M::T::from_f64(2.0).unwrap();
857        let eps2 = dt2.abs() * M::T::from_f64(1e-5).unwrap();
858        let y_plus = s.interpolate(tmid2 + eps2).unwrap();
859        let y_minus = s.interpolate(tmid2 - eps2).unwrap();
860        let fd_dy2 =
861            (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps2));
862        let dy2 = s.interpolate_dy(tmid2).unwrap();
863        dy2.assert_eq_norm(
864            &fd_dy2,
865            &s.problem().atol,
866            s.problem().rtol,
867            M::T::from_f64(1e3).unwrap(),
868        );
869    }
870
871    pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
872        mut s: Method,
873    ) {
874        *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
875        assert_eq!(
876            *s.config().as_base_ref().minimum_timestep,
877            Eqn::T::from_f64(1.0e8).unwrap()
878        );
879        // force a step size reduction
880        *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
881
882        let mut failed = false;
883        for _ in 0..10 {
884            if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
885                s.step()
886            {
887                failed = true;
888                break;
889            }
890        }
891        assert!(failed);
892    }
893
894    pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
895        let state = s.checkpoint();
896        let state2 = s.state();
897        state2
898            .y
899            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
900        s.state_mut()
901            .y
902            .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
903        assert_eq!(
904            s.state_mut().y.get_index(0),
905            M::T::from_f64(std::f64::consts::PI).unwrap()
906        );
907    }
908
909    #[cfg(feature = "diffsl-cranelift")]
910    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
911    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
912        crate::OdeBuilder::<M>::new()
913            .build_from_diffsl(
914                "
915            g { 9.81 } h { 10.0 }
916            u_i {
917                x = h,
918                v = 0,
919            }
920            F_i {
921                v,
922                -g,
923            }
924            stop {
925                x,
926            }
927        ",
928            )
929            .unwrap()
930    }
931
932    #[cfg(feature = "diffsl-cranelift")]
933    pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
934    where
935        M: crate::MatrixHost<T = f64>,
936        M: DefaultSolver<T = f64>,
937        M::V: DefaultDenseMatrix<T = f64>,
938        Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
939    {
940        let e = 0.8;
941
942        let final_time = 2.5;
943
944        // solve and apply the remaining doses
945        solver.set_stop_time(final_time).unwrap();
946        loop {
947            match solver.step() {
948                Ok(OdeSolverStopReason::InternalTimestep) => (),
949                Ok(OdeSolverStopReason::RootFound(t, _)) => {
950                    // get the state when the event occurred
951                    let mut y = solver.interpolate(t).unwrap();
952
953                    // update the velocity of the ball
954                    y.set_index(1, y.get_index(1) * -e);
955
956                    // make sure the ball is above the ground
957                    y.set_index(0, y.get_index(0).max(f64::EPSILON));
958
959                    // set the state to the updated state
960                    solver.state_mut().y.copy_from(&y);
961                    solver.state_mut().dy.set_index(0, y.get_index(1));
962                    *solver.state_mut().t = t;
963
964                    break;
965                }
966                Ok(OdeSolverStopReason::TstopReached) => break,
967                Err(_) => panic!("unexpected solver error"),
968            }
969        }
970        // do three more steps after the 1st bound and many sure they are correct
971        let mut x = vec![];
972        let mut v = vec![];
973        let mut t = vec![];
974        for _ in 0..3 {
975            let ret = solver.step();
976            x.push(solver.state().y.get_index(0));
977            v.push(solver.state().y.get_index(1));
978            t.push(solver.state().t);
979            match ret {
980                Ok(OdeSolverStopReason::InternalTimestep) => (),
981                Ok(OdeSolverStopReason::RootFound(_, _)) => {
982                    panic!("should be an internal timestep but found a root")
983                }
984                Ok(OdeSolverStopReason::TstopReached) => break,
985                _ => panic!("should be an internal timestep"),
986            }
987        }
988        (x, v, t)
989    }
990
991    pub fn test_checkpointing<'a, M, Method, Eqn>(
992        soln: OdeSolverSolution<M::V>,
993        mut solver1: Method,
994        mut solver2: Method,
995    ) where
996        M: Matrix + DefaultSolver,
997        Method: OdeSolverMethod<'a, Eqn>,
998        Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
999    {
1000        let half_i = soln.solution_points.len() / 2;
1001        let half_t = soln.solution_points[half_i].t;
1002        while solver1.state().t <= half_t {
1003            solver1.step().unwrap();
1004        }
1005        let checkpoint = solver1.checkpoint();
1006        let checkpoint_t = checkpoint.as_ref().t;
1007        solver2.set_state(checkpoint);
1008
1009        // carry on solving with both solvers, they should produce about the same results (probably might diverge a bit, but should always match the solution)
1010        for point in soln.solution_points.iter().skip(half_i + 1) {
1011            // point should be past checkpoint
1012            if point.t < checkpoint_t {
1013                continue;
1014            }
1015            while solver2.state().t < point.t {
1016                solver1.step().unwrap();
1017                solver2.step().unwrap();
1018                let time_error = (solver1.state().t - solver2.state().t).abs()
1019                    / (solver1.state().t.abs() * solver1.problem().rtol
1020                        + solver1.problem().atol.get_index(0));
1021                assert!(
1022                    time_error < M::T::from_f64(20.0).unwrap(),
1023                    "time_error: {} at t = {}",
1024                    time_error,
1025                    solver1.state().t
1026                );
1027                solver1.state().y.assert_eq_norm(
1028                    solver2.state().y,
1029                    &solver1.problem().atol,
1030                    solver1.problem().rtol,
1031                    M::T::from_f64(20.0).unwrap(),
1032                );
1033            }
1034            let soln = solver1.interpolate(point.t).unwrap();
1035            soln.assert_eq_norm(
1036                &point.state,
1037                &solver1.problem().atol,
1038                solver1.problem().rtol,
1039                M::T::from_f64(15.0).unwrap(),
1040            );
1041            let soln = solver2.interpolate(point.t).unwrap();
1042            soln.assert_eq_norm(
1043                &point.state,
1044                &solver1.problem().atol,
1045                solver1.problem().rtol,
1046                M::T::from_f64(15.0).unwrap(),
1047            );
1048        }
1049    }
1050
1051    pub fn test_state_mut_on_problem<'a, Eqn, Method>(
1052        mut s: Method,
1053        soln: OdeSolverSolution<Eqn::V>,
1054    ) where
1055        Eqn: OdeEquationsImplicit + 'a,
1056        Eqn::M: DefaultSolver,
1057        Method: OdeSolverMethod<'a, Eqn>,
1058        Eqn::V: DefaultDenseMatrix,
1059    {
1060        // save state and solve for a little bit
1061        let state = s.checkpoint();
1062        s.solve(Eqn::T::one()).unwrap();
1063
1064        // reinit using state_mut
1065        s.state_mut().y.copy_from(state.as_ref().y);
1066        s.state_mut().dy.copy_from(state.as_ref().dy);
1067        *s.state_mut().t = state.as_ref().t;
1068
1069        // solve and check against solution
1070        for point in soln.solution_points.iter() {
1071            while s.state().t < point.t {
1072                s.step().unwrap();
1073            }
1074            let soln = s.interpolate(point.t).unwrap();
1075            let error = soln.clone() - &point.state;
1076            let error_norm = error
1077                .squared_norm(&error, &s.problem().atol, s.problem().rtol)
1078                .sqrt();
1079            assert!(
1080                error_norm < Eqn::T::from_f64(19.0).unwrap(),
1081                "error_norm: {} at t = {}",
1082                error_norm,
1083                point.t
1084            );
1085        }
1086    }
1087
1088    /// Test that `step()` returns `RootFound(t, index)` with the correct root index.
1089    ///
1090    /// The problem must have a root function with **two** outputs and **no** Reset:
1091    ///   - Root 0 fires first (at `t ≈ 5.108`, `y[0] ≈ 0.6` for the exponential-decay test model)
1092    ///   - Root 1 fires second
1093    ///
1094    /// The test asserts that the first `RootFound` reports index 0 and the time
1095    /// matches `t_root_0_expected` within `tol`.
1096    pub fn test_root_found_index<'a, Eqn, Method>(
1097        mut solver: Method,
1098        soln: &OdeSolverSolution<Eqn::V>,
1099        expected_root_index: usize,
1100        tol: Eqn::T,
1101    ) where
1102        Eqn: OdeEquations + 'a,
1103        Method: OdeSolverMethod<'a, Eqn>,
1104    {
1105        let t_root_expected = soln.solution_points[0].t;
1106        solver
1107            .set_stop_time(Eqn::T::from_f64(100.0).unwrap())
1108            .unwrap();
1109        loop {
1110            match solver.step().unwrap() {
1111                // RED: `RootFound` currently has one field; adding `index` makes this fail.
1112                OdeSolverStopReason::RootFound(t, index) => {
1113                    assert_eq!(
1114                        index, expected_root_index,
1115                        "expected root index {expected_root_index} but got {index}",
1116                    );
1117                    assert!(
1118                        (t - t_root_expected).abs() < tol,
1119                        "expected t ≈ {t_root_expected:?}, got {t:?}",
1120                    );
1121                    break;
1122                }
1123                OdeSolverStopReason::TstopReached => {
1124                    panic!("reached tstop without finding a root")
1125                }
1126                OdeSolverStopReason::InternalTimestep => {}
1127            }
1128        }
1129    }
1130
1131    /// Test that `solve()` automatically applies resets at roots and continues
1132    /// integrating until `final_time`.
1133    pub fn test_solve_with_reset<'a, Eqn, Method>(
1134        mut solver: Method,
1135        soln: &OdeSolverSolution<Eqn::V>,
1136    ) where
1137        Eqn: OdeEquationsImplicit + 'a,
1138        Eqn::M: DefaultSolver,
1139        Eqn::V: DefaultDenseMatrix,
1140        Method: OdeSolverMethod<'a, Eqn>,
1141    {
1142        let final_time = Eqn::T::from_f64(100.0).unwrap();
1143        let (ys, ts, stop_reason) = solver.solve(final_time).unwrap();
1144        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1145        let t_last = *ts.last().unwrap();
1146        let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1147        assert!(
1148            (t_last - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1149            "expected solve() to reach final_time ≈ {:?}, got {:?}",
1150            final_time,
1151            t_last,
1152        );
1153        assert!(
1154            (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1155            "expected solver state at final_time ≈ {:?}, got {:?}",
1156            final_time,
1157            solver.state().t,
1158        );
1159
1160        let expected = &soln.solution_points[0];
1161        let root_time_tol = soln.rtol * expected.t.abs() + soln.atol.get_index(0);
1162        let root_col = ts
1163            .iter()
1164            .position(|&t| (t - expected.t).abs() < Eqn::T::from_f64(30.0).unwrap() * root_time_tol)
1165            .expect("expected solve() output to include the second-root/reset time");
1166        let root_expected = Eqn::V::from_element(
1167            expected.state.len(),
1168            Eqn::T::from_f64(0.4).unwrap(),
1169            expected.state.context().clone(),
1170        );
1171        let root_state = ys.column(root_col).into_owned();
1172        let root_error = root_state - &root_expected;
1173        let root_error_norm = root_error
1174            .squared_norm(&root_expected, &soln.atol, soln.rtol)
1175            .sqrt();
1176        let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1177        assert!(
1178            root_error_norm < error_threshold,
1179            "expected reset state y=0.4 at second-root time; WRMS error norm {root_error_norm:?} ≥ {error_threshold:?}",
1180        );
1181
1182        let reset_value = Eqn::T::from_f64(0.4).unwrap();
1183        let reset_tol = Eqn::T::from_f64(30.0).unwrap()
1184            * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1185        let last_reset_col = (0..ts.len())
1186            .rev()
1187            .find(|&i| (ys.get_index(0, i) - reset_value).abs() < reset_tol)
1188            .expect("expected solve() output to include at least one reset state");
1189        let final_time_f64 = final_time.to_f64().unwrap();
1190        let last_reset_time_f64 = ts[last_reset_col].to_f64().unwrap();
1191        let expected_final_value =
1192            Eqn::T::from_f64(0.4 * (-0.1 * (final_time_f64 - last_reset_time_f64)).exp()).unwrap();
1193        let expected_final = Eqn::V::from_element(
1194            expected.state.len(),
1195            expected_final_value,
1196            expected.state.context().clone(),
1197        );
1198        let final_state = ys.column(ts.len() - 1).into_owned();
1199        let final_error = final_state - &expected_final;
1200        let final_error_norm = final_error
1201            .squared_norm(&expected_final, &soln.atol, soln.rtol)
1202            .sqrt();
1203        assert!(
1204            final_error_norm < error_threshold,
1205            "final state mismatch after automatic reset continuation: WRMS error norm {final_error_norm:?} ≥ {error_threshold:?}",
1206        );
1207    }
1208
1209    /// Test that `solve_dense()` automatically applies resets at roots and
1210    /// continues filling the requested evaluation times.
1211    pub fn test_solve_dense_with_reset<'a, Eqn, Method>(
1212        mut solver: Method,
1213        soln: &OdeSolverSolution<Eqn::V>,
1214    ) where
1215        Eqn: OdeEquationsImplicit + 'a,
1216        Eqn::M: DefaultSolver,
1217        Eqn::V: DefaultDenseMatrix,
1218        Method: OdeSolverMethod<'a, Eqn>,
1219    {
1220        let t_stop = soln.solution_points[0].t;
1221        let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1222        let mut probe_solver = solver.clone();
1223        let (probe_ys, probe_ts, probe_stop_reason) = probe_solver.solve(final_time).unwrap();
1224        assert_eq!(probe_stop_reason, OdeSolverStopReason::TstopReached);
1225
1226        let reset_time_tol =
1227            Eqn::T::from_f64(30.0).unwrap() * (soln.rtol * t_stop.abs() + soln.atol.get_index(0));
1228        let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1229        let reset_value = Eqn::T::from_f64(0.4).unwrap();
1230        let reset_value_tol = Eqn::T::from_f64(30.0).unwrap()
1231            * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1232        let reset_col = (0..probe_ts.len())
1233            .find(|&i| {
1234                (probe_ts[i] - t_stop).abs() < reset_time_tol
1235                    && (probe_ys.get_index(0, i) - reset_value).abs() < reset_value_tol
1236            })
1237            .expect("expected solve() probe output to contain the second-root reset state");
1238        let t_event = probe_ts[reset_col];
1239        let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, final_time];
1240
1241        let (ret, stop_reason) = solver.solve_dense(&t_eval).unwrap();
1242        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1243        assert!(
1244            ret.ncols() == t_eval.len(),
1245            "expected solve_dense() to fill all requested evaluation times"
1246        );
1247        let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1248        assert!(
1249            (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1250            "expected solver state at final_time ≈ {:?}, got {:?}",
1251            final_time,
1252            solver.state().t,
1253        );
1254
1255        let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1256        let pre_reset_state = ret.column(1).into_owned();
1257        let pre_reset_error = pre_reset_state - &soln.solution_points[0].state;
1258        let pre_reset_error_norm = pre_reset_error
1259            .squared_norm(&soln.solution_points[0].state, &soln.atol, soln.rtol)
1260            .sqrt();
1261        assert!(
1262            pre_reset_error_norm < error_threshold,
1263            "expected pre-reset state at event time; WRMS norm {pre_reset_error_norm:?} >= {error_threshold:?}",
1264        );
1265
1266        let expected_post_reset_value =
1267            reset_value * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1268        let expected_post_reset = Eqn::V::from_element(
1269            soln.solution_points[0].state.len(),
1270            expected_post_reset_value,
1271            soln.solution_points[0].state.context().clone(),
1272        );
1273        let post_reset_state = ret.column(2).into_owned();
1274        let post_reset_error = post_reset_state - &expected_post_reset;
1275        let post_reset_error_norm = post_reset_error
1276            .squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1277            .sqrt();
1278        assert!(
1279            post_reset_error_norm < error_threshold,
1280            "expected reset state just after event time; WRMS norm {post_reset_error_norm:?} >= {error_threshold:?}",
1281        );
1282    }
1283
1284    /// Test that `solve_dense_sensitivities()` applies root-aware resets and
1285    /// continues filling the requested evaluation times.
1286    pub fn test_solve_dense_sensitivities_with_reset<'a, Eqn, Method>(
1287        mut solver: Method,
1288        soln: &OdeSolverSolution<Eqn::V>,
1289    ) where
1290        Eqn: OdeEquationsImplicitSens + 'a,
1291        Eqn::V: DefaultDenseMatrix,
1292        Eqn::M: DefaultSolver,
1293        Method: SensitivitiesOdeSolverMethod<'a, Eqn>,
1294    {
1295        let t_stop = soln.solution_points[0].t;
1296        let t_event = Eqn::T::from_f64(10.0 * (5.0_f64 / 3.0_f64).ln()).unwrap();
1297
1298        let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1299        let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, t_stop];
1300        let (ret, ret_sens, stop_reason) = solver.solve_dense_sensitivities(&t_eval).unwrap();
1301        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1302        assert_eq!(ret.ncols(), t_eval.len());
1303        for ret_sens_j in &ret_sens {
1304            assert_eq!(ret_sens_j.ncols(), t_eval.len());
1305        }
1306
1307        let error_threshold = Eqn::T::from_f64(100.0).unwrap();
1308        let ctx = soln.solution_points[0].state.context().clone();
1309        let nstates = soln.solution_points[0].state.len();
1310
1311        let post_reset_y = Eqn::T::from_f64(2.6).unwrap()
1312            * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1313        let post_reset_t = t_event + post_event_dt;
1314        let expected_post_reset = Eqn::V::from_element(nstates, post_reset_y, ctx.clone());
1315        let expected_post_reset_sk =
1316            Eqn::V::from_element(nstates, -post_reset_y * post_reset_t, ctx.clone());
1317        let expected_post_reset_sy0 = Eqn::V::from_element(nstates, post_reset_y, ctx);
1318
1319        let col = 2;
1320        let ey = ret.column(col).into_owned() - &expected_post_reset;
1321        let esk = ret_sens[0].column(col).into_owned() - &expected_post_reset_sk;
1322        let esy0 = ret_sens[1].column(col).into_owned() - &expected_post_reset_sy0;
1323        let norm = (ey.squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1324            + esk.squared_norm(&expected_post_reset_sk, &soln.atol, soln.rtol)
1325            + esy0.squared_norm(&expected_post_reset_sy0, &soln.atol, soln.rtol))
1326        .sqrt();
1327        assert!(
1328            norm < error_threshold,
1329            "dense sensitivity mismatch just after reset; combined WRMS {norm:?} >= {error_threshold:?}",
1330        );
1331    }
1332
1333    pub fn test_solve_adjoint_with_single_reset_root<
1334        'a,
1335        Eqn,
1336        MethodF,
1337        MethodB,
1338        BuildForward,
1339        BuildAdjointState,
1340        BuildAdjointFromState,
1341    >(
1342        build_forward: BuildForward,
1343        soln: &OdeSolverSolution<Eqn::V>,
1344        build_adjoint_state: BuildAdjointState,
1345        build_adjoint_from_state: BuildAdjointFromState,
1346        use_replay_solver: bool,
1347    ) where
1348        Eqn: OdeEquationsImplicitAdjoint + 'a,
1349        Eqn::M: DefaultSolver,
1350        Eqn::V: DefaultDenseMatrix,
1351        MethodF: OdeSolverMethod<'a, Eqn>,
1352        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1353        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1354        BuildAdjointState:
1355            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1356        BuildAdjointFromState:
1357            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1358    {
1359        let expected_out = &soln.solution_points[0];
1360        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1361
1362        let mut forward_solver = build_forward(None).unwrap();
1363        let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
1364            .solve_with_checkpointing(forward_stop_time, None)
1365            .unwrap();
1366        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1367        assert!(
1368            checkpointers.len() >= 3,
1369            "expected checkpointing path to include the two reset events"
1370        );
1371        let problem = forward_solver.problem();
1372        let post_reset_solver = forward_solver.clone();
1373        let post_reset_root_idx = checkpointers[1]
1374            .terminal_reset_root_idx()
1375            .expect("second reset segment should record its terminal root index");
1376        let final_forward_state = checkpointers[1].last_checkpoint().clone();
1377        let t_second_root = final_forward_state.as_ref().t;
1378
1379        let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
1380        let out_norm = out_error
1381            .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
1382            .sqrt();
1383        assert!(
1384            out_norm < Eqn::T::from_f64(50.0).unwrap(),
1385            "forward integrated output mismatch at second root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
1386            final_forward_state.as_ref().g,
1387            expected_out.state,
1388        );
1389        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1390        assert!(
1391            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1392            "expected second root time ≈ {:?}, got {:?}",
1393            expected_out.t,
1394            t_second_root,
1395        );
1396
1397        let adjoint_checkpointers = checkpointers.into_iter().take(2).collect::<Vec<_>>();
1398
1399        // make a broken adjoint that is missing the reset root metadata on the first segment, which should cause an error
1400        let mut missing_metadata_checkpointers = adjoint_checkpointers.clone();
1401        missing_metadata_checkpointers[0].clear_terminal_reset_root_idx();
1402        let missing_metadata_solver = use_replay_solver.then(|| post_reset_solver.clone());
1403        let mut missing_metadata_adjoint_eqn = problem.adjoint_equations(
1404            missing_metadata_checkpointers,
1405            missing_metadata_solver,
1406            None,
1407        );
1408        let mut missing_metadata_adjoint_state =
1409            build_adjoint_state(&mut missing_metadata_adjoint_eqn).unwrap();
1410        missing_metadata_adjoint_state
1411            .as_mut()
1412            .state_mut_adjoint_terminal_root(
1413                &problem.eqn,
1414                post_reset_root_idx,
1415                &final_forward_state,
1416                problem.integrate_out,
1417            )
1418            .unwrap();
1419        let missing_metadata_adjoint =
1420            build_adjoint_from_state(missing_metadata_adjoint_state, missing_metadata_adjoint_eqn)
1421                .unwrap();
1422        let missing_metadata_err =
1423            match missing_metadata_adjoint.solve_adjoint_backwards_pass(&[], &[]) {
1424                Ok(_) => panic!("expected missing reset metadata error"),
1425                Err(err) => err,
1426            };
1427        assert!(
1428            format!("{missing_metadata_err:?}").contains("Missing reset root metadata"),
1429            "expected missing reset metadata error, got {missing_metadata_err:?}",
1430        );
1431
1432        // now build the correct adjoint and check that it produces the correct gradient
1433        let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
1434        let mut adjoint_eqn =
1435            problem.adjoint_equations(adjoint_checkpointers, adjoint_solver, None);
1436        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
1437        adjoint_state
1438            .as_mut()
1439            .state_mut_adjoint_terminal_root(
1440                &problem.eqn,
1441                post_reset_root_idx,
1442                &final_forward_state,
1443                problem.integrate_out,
1444            )
1445            .unwrap();
1446        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
1447        let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();
1448
1449        let t0 = problem.t0;
1450        let ctx = problem.context().clone();
1451
1452        let sens_points = soln.sens_solution_points.as_ref().unwrap();
1453        let expected_grad = Eqn::V::from_vec(
1454            sens_points
1455                .iter()
1456                .map(|pts| pts[0].state.get_index(0))
1457                .collect(),
1458            ctx.clone(),
1459        );
1460        let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
1461        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1462        assert!(
1463            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1464            "expected adjoint final time {:?}, got {:?}",
1465            t0,
1466            adjoint_state.as_ref().t,
1467        );
1468        adjoint_state.as_ref().sg[0].assert_eq_norm(
1469            &expected_grad,
1470            &atol,
1471            Eqn::T::from_f64(1e-6).unwrap(),
1472            Eqn::T::from_f64(60.0).unwrap(),
1473        );
1474    }
1475
1476    #[allow(clippy::too_many_arguments)]
1477    pub fn test_solve_adjoint_sum_squares_with_single_reset_root<
1478        'a,
1479        Eqn,
1480        MethodF,
1481        MethodB,
1482        BuildForward,
1483        BuildAdjointState,
1484        BuildAdjointFromState,
1485    >(
1486        build_forward: BuildForward,
1487        soln: &OdeSolverSolution<Eqn::V>,
1488        build_adjoint_state: BuildAdjointState,
1489        build_adjoint_from_state: BuildAdjointFromState,
1490        use_replay_solver: bool,
1491        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
1492        data: <Eqn::V as DefaultDenseMatrix>::M,
1493        times: &[Eqn::T],
1494    ) where
1495        Eqn: OdeEquationsImplicitAdjoint + 'a,
1496        Eqn::M: DefaultSolver,
1497        Eqn::V: DefaultDenseMatrix,
1498        MethodF: OdeSolverMethod<'a, Eqn>,
1499        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1500        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1501        BuildAdjointState:
1502            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1503        BuildAdjointFromState:
1504            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1505    {
1506        let expected_out = &soln.solution_points[0];
1507        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1508        let forwards_soln =
1509            solve_dense_with_single_reset_root::<Eqn, MethodF, _>(&build_forward, times);
1510        assert_eq!(
1511            forwards_soln.ncols(),
1512            times.len(),
1513            "expected stitched forward samples to cover every requested observation time",
1514        );
1515        let dgdu = dsum_squaresdp(&forwards_soln, &data);
1516        let dgdu_refs = dgdu.iter().collect::<Vec<_>>();
1517
1518        let mut forward_solver = build_forward(None).unwrap();
1519        let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
1520            .solve_with_checkpointing(forward_stop_time, None)
1521            .unwrap();
1522        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1523        assert!(
1524            checkpointers.len() >= 3,
1525            "expected checkpointing path to include the two reset events"
1526        );
1527        let problem = forward_solver.problem();
1528        let post_reset_solver = forward_solver.clone();
1529        let post_reset_root_idx = checkpointers[1]
1530            .terminal_reset_root_idx()
1531            .expect("second reset segment should record its terminal root index");
1532        let final_forward_state = checkpointers[1].last_checkpoint().clone();
1533        let t_second_root = final_forward_state.as_ref().t;
1534
1535        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1536        assert!(
1537            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1538            "expected second root time ≈ {:?}, got {:?}",
1539            expected_out.t,
1540            t_second_root,
1541        );
1542
1543        let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
1544        let mut adjoint_eqn = problem.adjoint_equations(
1545            checkpointers.into_iter().take(2).collect(),
1546            adjoint_solver,
1547            Some(dgdu.len()),
1548        );
1549        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
1550        adjoint_state
1551            .as_mut()
1552            .state_mut_adjoint_terminal_root(
1553                &problem.eqn,
1554                post_reset_root_idx,
1555                &final_forward_state,
1556                problem.integrate_out,
1557            )
1558            .unwrap();
1559        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
1560        let (adjoint_state, _) = adjoint
1561            .solve_adjoint_backwards_pass(times, dgdu_refs.as_slice())
1562            .unwrap();
1563
1564        let t0 = problem.t0;
1565        let ctx = problem.context().clone();
1566
1567        let nparams = dgdp_check.nrows();
1568        let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
1569        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1570        assert!(
1571            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1572            "expected adjoint final time {:?}, got {:?}",
1573            t0,
1574            adjoint_state.as_ref().t,
1575        );
1576        #[allow(clippy::needless_range_loop)]
1577        for j in 0..dgdp_check.ncols() {
1578            adjoint_state.as_ref().sg[j].assert_eq_norm(
1579                &dgdp_check.column(j).into_owned(),
1580                &atol,
1581                Eqn::T::from_f64(1e-6).unwrap(),
1582                Eqn::T::from_f64(260.0).unwrap(),
1583            );
1584        }
1585    }
1586
1587    pub fn test_solve_soln_adjoint_with_single_reset_root<
1588        'a,
1589        Eqn,
1590        MethodF,
1591        MethodB,
1592        BuildForward,
1593        BuildAdjointState,
1594        BuildAdjointFromState,
1595    >(
1596        build_forward: BuildForward,
1597        soln: &OdeSolverSolution<Eqn::V>,
1598        build_adjoint_state: BuildAdjointState,
1599        build_adjoint_from_state: BuildAdjointFromState,
1600        use_replay_solver: bool,
1601    ) where
1602        Eqn: OdeEquationsImplicitAdjoint + 'a,
1603        Eqn::M: DefaultSolver,
1604        Eqn::V: DefaultDenseMatrix,
1605        MethodF: OdeSolverMethod<'a, Eqn>,
1606        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1607        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1608        BuildAdjointState:
1609            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1610        BuildAdjointFromState:
1611            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1612    {
1613        let expected_out = &soln.solution_points[0];
1614        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1615        let mut forward_soln = Solution::<Eqn::V>::new(forward_stop_time);
1616        let mut checkpointers = Vec::new();
1617
1618        let first_forward_solver = build_forward(None)
1619            .unwrap()
1620            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
1621            .unwrap();
1622        let first_root_idx = match forward_soln.stop_reason {
1623            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1624            Some(reason) => {
1625                panic!("expected first staged solve to stop at reset root, got {reason:?}")
1626            }
1627            None => panic!("first staged solve did not set a stop reason"),
1628        };
1629        assert_eq!(checkpointers.len(), 1);
1630        assert_eq!(
1631            checkpointers[0].terminal_reset_root_idx(),
1632            Some(first_root_idx)
1633        );
1634
1635        let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
1636        let terminal_forward_solver = build_forward(Some(state_after_reset))
1637            .unwrap()
1638            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
1639            .unwrap();
1640        let terminal_root_idx = match forward_soln.stop_reason {
1641            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1642            Some(reason) => {
1643                panic!("expected second staged solve to stop at terminal root, got {reason:?}")
1644            }
1645            None => panic!("second staged solve did not set a stop reason"),
1646        };
1647        assert_eq!(checkpointers.len(), 2);
1648        assert_eq!(
1649            checkpointers[1].terminal_reset_root_idx(),
1650            Some(terminal_root_idx)
1651        );
1652
1653        let problem = terminal_forward_solver.problem();
1654        let final_forward_state = terminal_forward_solver.state_clone();
1655        let t_second_root = final_forward_state.as_ref().t;
1656        let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
1657        let out_norm = out_error
1658            .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
1659            .sqrt();
1660        assert!(
1661            out_norm < Eqn::T::from_f64(50.0).unwrap(),
1662            "forward integrated output mismatch at terminal root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
1663            final_forward_state.as_ref().g,
1664            expected_out.state,
1665        );
1666        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1667        assert!(
1668            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1669            "expected terminal root time ≈ {:?}, got {:?}",
1670            expected_out.t,
1671            t_second_root,
1672        );
1673
1674        let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
1675        let mut adjoint_eqn = problem.adjoint_equations(checkpointers, adjoint_solver, None);
1676        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
1677        adjoint_state
1678            .as_mut()
1679            .state_mut_adjoint_terminal_root(
1680                &problem.eqn,
1681                terminal_root_idx,
1682                &final_forward_state,
1683                problem.integrate_out,
1684            )
1685            .unwrap();
1686        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
1687        let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();
1688
1689        let t0 = problem.t0;
1690        let ctx = problem.context().clone();
1691        let sens_points = soln.sens_solution_points.as_ref().unwrap();
1692        let expected_grad = Eqn::V::from_vec(
1693            sens_points
1694                .iter()
1695                .map(|pts| pts[0].state.get_index(0))
1696                .collect(),
1697            ctx.clone(),
1698        );
1699        let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
1700        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1701        assert!(
1702            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1703            "expected adjoint final time {:?}, got {:?}",
1704            t0,
1705            adjoint_state.as_ref().t,
1706        );
1707        adjoint_state.as_ref().sg[0].assert_eq_norm(
1708            &expected_grad,
1709            &atol,
1710            Eqn::T::from_f64(1e-6).unwrap(),
1711            Eqn::T::from_f64(60.0).unwrap(),
1712        );
1713    }
1714
1715    #[allow(clippy::too_many_arguments)]
1716    pub fn test_solve_soln_adjoint_sum_squares_with_single_reset_root<
1717        'a,
1718        Eqn,
1719        MethodF,
1720        MethodB,
1721        BuildForward,
1722        BuildAdjointState,
1723        BuildAdjointFromState,
1724    >(
1725        build_forward: BuildForward,
1726        soln: &OdeSolverSolution<Eqn::V>,
1727        build_adjoint_state: BuildAdjointState,
1728        build_adjoint_from_state: BuildAdjointFromState,
1729        use_replay_solver: bool,
1730        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
1731        data: <Eqn::V as DefaultDenseMatrix>::M,
1732        times: &[Eqn::T],
1733    ) where
1734        Eqn: OdeEquationsImplicitAdjoint + 'a,
1735        Eqn::M: DefaultSolver,
1736        Eqn::V: DefaultDenseMatrix,
1737        MethodF: OdeSolverMethod<'a, Eqn>,
1738        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1739        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1740        BuildAdjointState:
1741            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1742        BuildAdjointFromState:
1743            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1744    {
1745        let expected_out = &soln.solution_points[0];
1746        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1747        let mut forward_soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
1748        let mut checkpointers = Vec::new();
1749
1750        let first_forward_solver = build_forward(None)
1751            .unwrap()
1752            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
1753            .unwrap();
1754        let first_root_idx = match forward_soln.stop_reason {
1755            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1756            Some(reason) => {
1757                panic!("expected first staged solve to stop at reset root, got {reason:?}")
1758            }
1759            None => panic!("first staged solve did not set a stop reason"),
1760        };
1761        assert_eq!(checkpointers.len(), 1);
1762        assert_eq!(
1763            checkpointers[0].terminal_reset_root_idx(),
1764            Some(first_root_idx)
1765        );
1766
1767        let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
1768        build_forward(Some(state_after_reset.clone()))
1769            .unwrap()
1770            .solve_soln(&mut forward_soln)
1771            .unwrap();
1772        assert!(forward_soln.is_complete());
1773        assert_eq!(
1774            forward_soln.stop_reason,
1775            Some(OdeSolverStopReason::TstopReached)
1776        );
1777
1778        let mut terminal_soln = Solution::<Eqn::V>::new(forward_stop_time);
1779        let terminal_forward_solver = build_forward(Some(state_after_reset))
1780            .unwrap()
1781            .solve_soln_with_checkpointing(&mut terminal_soln, &mut checkpointers, None)
1782            .unwrap();
1783        let terminal_root_idx = match terminal_soln.stop_reason {
1784            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1785            Some(reason) => {
1786                panic!("expected terminal staged solve to stop at root, got {reason:?}")
1787            }
1788            None => panic!("terminal staged solve did not set a stop reason"),
1789        };
1790        assert_eq!(checkpointers.len(), 2);
1791        assert_eq!(
1792            checkpointers.last().unwrap().terminal_reset_root_idx(),
1793            Some(terminal_root_idx)
1794        );
1795
1796        let dgdu_eval = dsum_squaresdp(&forward_soln.ys, &data);
1797        let dgdu_eval_refs = dgdu_eval.iter().collect::<Vec<_>>();
1798        let problem = terminal_forward_solver.problem();
1799        let final_forward_state = terminal_forward_solver.state_clone();
1800        let t_second_root = final_forward_state.as_ref().t;
1801        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1802        assert!(
1803            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1804            "expected terminal root time ≈ {:?}, got {:?}",
1805            expected_out.t,
1806            t_second_root,
1807        );
1808
1809        let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
1810        let mut adjoint_eqn =
1811            problem.adjoint_equations(checkpointers, adjoint_solver, Some(dgdu_eval_refs.len()));
1812        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
1813        adjoint_state
1814            .as_mut()
1815            .state_mut_adjoint_terminal_root(
1816                &problem.eqn,
1817                terminal_root_idx,
1818                &final_forward_state,
1819                problem.integrate_out,
1820            )
1821            .unwrap();
1822        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
1823        let (adjoint_state, _) = adjoint
1824            .solve_adjoint_backwards_pass(times, dgdu_eval_refs.as_slice())
1825            .unwrap();
1826
1827        let t0 = problem.t0;
1828        let ctx = problem.context().clone();
1829        let nparams = dgdp_check.nrows();
1830        let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
1831        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1832        assert!(
1833            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1834            "expected adjoint final time {:?}, got {:?}",
1835            t0,
1836            adjoint_state.as_ref().t,
1837        );
1838        #[allow(clippy::needless_range_loop)]
1839        for j in 0..dgdp_check.ncols() {
1840            adjoint_state.as_ref().sg[j].assert_eq_norm(
1841                &dgdp_check.column(j).into_owned(),
1842                &atol,
1843                Eqn::T::from_f64(1e-6).unwrap(),
1844                Eqn::T::from_f64(260.0).unwrap(),
1845            );
1846        }
1847    }
1848}