Skip to main content

diffsol/ode_solver/
mod.rs

1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod no_checkpointing_solver;
11pub mod problem;
12pub mod runge_kutta;
13pub mod sde;
14pub mod sdirk;
15pub mod sdirk_state;
16pub mod sensitivities;
17pub mod solution;
18pub mod state;
19pub mod tableau;
20
21use serde::Serialize;
22use std::fmt::Display;
23
24use crate::ode_solver::jacobian_update::SolverState;
25
/// Solver statistics shared by all ODE solver methods.
///
/// Every field is a plain event counter that starts at zero (via `Default`).
/// `record_linear_solver_setup` keeps `number_of_linear_solver_setups` and the
/// `number_of_linear_solver_setups_from_*` fields in step with each other; the other
/// counters are presumably maintained by the individual solver implementations
/// (not visible from this file section).
///
/// The `Display` implementation renders the struct as pretty-printed JSON, so the
/// field order below is also the order of the keys in that output.
#[derive(Clone, Debug, Serialize, Default)]
pub struct OdeSolverStatistics {
    /// Total Jacobian/LU setups (CVODE `nsetups`); sum of the per-cause counters below.
    pub number_of_linear_solver_setups: usize,
    /// Number of time steps taken by the solver.
    pub number_of_steps: usize,
    /// Number of local error test failures (steps rejected for excessive local error).
    pub number_of_error_test_failures: usize,
    /// Total number of nonlinear (Newton) solver iterations across all steps.
    pub number_of_nonlinear_solver_iterations: usize,
    /// Number of nonlinear (Newton) solver convergence failures.
    pub number_of_nonlinear_solver_fails: usize,
    /// Jacobian/LU setups triggered by checkpoint or reinitialisation
    /// (recorded for `SolverState::Checkpoint`).
    pub number_of_linear_solver_setups_from_checkpoint: usize,
    /// Jacobian/LU setups triggered by a first nonlinear convergence failure
    /// (recorded for `SolverState::FirstConvergenceFail`).
    pub number_of_linear_solver_setups_from_first_convergence_fail: usize,
    /// Jacobian/LU setups triggered by a second nonlinear convergence failure
    /// (recorded for `SolverState::SecondConvergenceFail`).
    pub number_of_linear_solver_setups_from_second_convergence_fail: usize,
    /// Jacobian/LU setups triggered by a local error test failure
    /// (recorded for `SolverState::ErrorTestFail`).
    pub number_of_linear_solver_setups_from_error_test_fail: usize,
    /// Jacobian/LU setups triggered by the normal step-success heuristic
    /// (recorded for `SolverState::StepSuccess`).
    pub number_of_linear_solver_setups_from_step_success: usize,
}
50
51impl OdeSolverStatistics {
52    /// Record a Jacobian/LU setup, incrementing the total and the per-cause counter.
53    pub(crate) fn record_linear_solver_setup(&mut self, cause: SolverState) {
54        self.number_of_linear_solver_setups += 1;
55        match cause {
56            SolverState::Checkpoint => self.number_of_linear_solver_setups_from_checkpoint += 1,
57            SolverState::FirstConvergenceFail => {
58                self.number_of_linear_solver_setups_from_first_convergence_fail += 1
59            }
60            SolverState::SecondConvergenceFail => {
61                self.number_of_linear_solver_setups_from_second_convergence_fail += 1
62            }
63            SolverState::ErrorTestFail => {
64                self.number_of_linear_solver_setups_from_error_test_fail += 1
65            }
66            SolverState::StepSuccess => self.number_of_linear_solver_setups_from_step_success += 1,
67        }
68    }
69}
70
71impl Display for OdeSolverStatistics {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        write!(f, "{}", serde_json::to_string_pretty(self).unwrap())
74    }
75}
76
77#[cfg(test)]
78mod tests {
79    use std::rc::Rc;
80
81    use self::problem::OdeSolverSolution;
82
83    use super::*;
84    use crate::error::{DiffsolError, OdeSolverError};
85    use crate::matrix::Matrix;
86    use crate::ode_solver::sensitivities::SensitivitiesOdeSolverMethod;
87    use crate::ode_solver::solution::Solution;
88    use crate::op::unit::UnitCallable;
89    use crate::op::ParameterisedOp;
90    use crate::Scalar;
91    use crate::{
92        op::OpStatistics, AdjointEquations, AdjointOdeSolverMethod, Context, DenseMatrix,
93        MatrixCommon, MatrixRef, NonLinearOp, NonLinearOpJacobian, OdeEquations,
94        OdeEquationsImplicit, OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens,
95        OdeEquationsRef, OdeSolverConfig, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
96        OdeSolverStopReason, Scale, VectorRef, VectorView, VectorViewMut,
97    };
98    use crate::{
99        ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver,
100        NonLinearOpSens, Op, Vector,
101    };
102    use num_traits::{FromPrimitive, One, Signed, ToPrimitive, Zero};
103
    /// Drive `method` through every point of `solution` and assert that the computed
    /// values match the expected ones.
    ///
    /// * `method` - solver under test; it is stepped in place.
    /// * `solution` - expected states (and optionally expected sensitivities) at a
    ///   sequence of times.
    /// * `override_tol` - when `Some`, the state is compared with `assert_eq_st` using
    ///   this tolerance and the sensitivity check is skipped entirely.
    /// * `use_tstop` - when `true`, each point is reached via `set_stop_time` + `step`
    ///   and the solver state is read directly; otherwise the solver steps past the
    ///   point and the value is obtained with `interpolate` / `interpolate_sens`.
    /// * `solve_for_sensitivities` - when `true` (and `override_tol` is `None`), the
    ///   sensitivities are compared against `solution.sens_solution_points`.
    ///
    /// Returns the solver state `y` after the last point, or, if a root is found on the
    /// way, the state (tstop mode) or interpolated state (interpolation mode) at the root.
    ///
    /// Panics on any failed comparison, if a root is reported for a problem without a
    /// root function, or if stepping/interpolating fails in interpolation mode.
    pub fn test_ode_solver<'a, M, Eqn, Method>(
        method: &mut Method,
        solution: OdeSolverSolution<M::V>,
        override_tol: Option<M::T>,
        use_tstop: bool,
        solve_for_sensitivities: bool,
    ) -> Eqn::V
    where
        M: Matrix,
        Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
        Method: OdeSolverMethod<'a, Eqn>,
    {
        let have_root = method.problem().eqn.root().is_some();
        for (i, point) in solution.solution_points.iter().enumerate() {
            let (soln, sens_soln) = if use_tstop {
                match method.set_stop_time(point.t) {
                    Ok(_) => loop {
                        match method.step() {
                            Ok(OdeSolverStopReason::RootFound(_, _)) => {
                                // A root may only be reported if the problem defines one.
                                assert!(have_root);
                                return method.state().y.clone();
                            }
                            Ok(OdeSolverStopReason::TstopReached) => {
                                break (method.state().y.clone(), method.state().s.to_vec());
                            }
                            // NOTE(review): besides ordinary internal steps this arm also
                            // matches `Err(_)`, so a failing `step()` is retried instead of
                            // reported (the interpolation branch below unwraps instead) —
                            // confirm this is intended.
                            _ => (),
                        }
                    },
                    // The stop time was rejected: compare against the current state as is.
                    Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
                }
            } else {
                // Compare magnitudes so the loop also terminates for negative times.
                while method.state().t.abs() < point.t.abs() {
                    if let OdeSolverStopReason::RootFound(t, _) = method.step().unwrap() {
                        assert!(have_root);
                        return method.interpolate(t).unwrap();
                    }
                }
                let soln = method.interpolate(point.t).unwrap();
                let sens_soln = method.interpolate_sens(point.t).unwrap();
                (soln, sens_soln)
            };
            // Expected values are given in terms of the output function when there is one.
            let soln = if let Some(out) = method.problem().eqn.out() {
                out.call(&soln, point.t)
            } else {
                soln
            };
            assert_eq!(
                soln.len(),
                point.state.len(),
                "soln.len() != point.state.len()"
            );
            if let Some(override_tol) = override_tol {
                soln.assert_eq_st(&point.state, override_tol);
            } else {
                let (rtol, atol) = if method.problem().eqn.out().is_some() {
                    // problem rtol and atol is on the state, so just use solution tolerance here
                    (solution.rtol, &solution.atol)
                } else {
                    (method.problem().rtol, &method.problem().atol)
                };
                let error = soln.clone() - &point.state;
                let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
                // The tolerance-scaled error norm must stay below 20.
                assert!(
                    error_norm < M::T::from_f64(20.0).unwrap(),
                    "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
                    error_norm,
                    point.t,
                    soln,
                    point.state
                );
                if solve_for_sensitivities {
                    if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
                        // `sens_soln_points[j][i]` is the expected sensitivity `j` at point `i`.
                        for (j, sens_points) in sens_soln_points.iter().enumerate() {
                            let sens_point = &sens_points[i];
                            let sens_soln = &sens_soln[j];
                            let error = sens_soln.clone() - &sens_point.state;
                            let error_norm =
                                error.squared_norm(&sens_point.state, atol, rtol).sqrt();
                            // Sensitivities use a looser bound (29) than the state (20).
                            assert!(
                                error_norm < M::T::from_f64(29.0).unwrap(),
                                "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
                                point.t,
                                sens_point.state
                            );
                        }
                    }
                }
            }
        }
        method.state().y.clone()
    }
195
196    pub fn setup_test_adjoint<'a, LS, Eqn>(
197        problem: &'a mut OdeSolverProblem<Eqn>,
198        soln: OdeSolverSolution<Eqn::V>,
199    ) -> <Eqn::V as DefaultDenseMatrix>::M
200    where
201        Eqn: OdeEquationsImplicitAdjoint + 'a,
202        LS: LinearSolver<Eqn::M>,
203        Eqn::M: DefaultSolver,
204        Eqn::V: DefaultDenseMatrix,
205        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
206        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
207    {
208        let nparams = problem.eqn.nparams();
209        let nout = problem.eqn.nout();
210        let ctx = problem.eqn.context();
211        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
212        let final_time = soln.solution_points.last().unwrap().t;
213        let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
214        problem.eqn.get_params(&mut p_0);
215        let nbatch = p_0.context().nbatch();
216        let h_base = Eqn::T::from_f64(1e-6).unwrap();
217        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
218        h.axpy(h_base, &p_0, Eqn::T::one());
219        let p_base = p_0.clone();
220        for i in 0..nparams {
221            for b in 0..nbatch {
222                let base = p_base.get_batch(b).get_index(i);
223                let hb = h.get_batch(b).get_index(i);
224                p_0.get_batch_mut(b).set_index(i, base + hb);
225            }
226            problem.eqn.set_params(&p_0);
227            let g_pos = {
228                let mut s = problem.bdf::<LS>().unwrap();
229                s.solve(final_time).unwrap();
230                s.state().g.clone()
231            };
232
233            for b in 0..nbatch {
234                let base = p_base.get_batch(b).get_index(i);
235                let hb = h.get_batch(b).get_index(i);
236                p_0.get_batch_mut(b).set_index(i, base - hb);
237            }
238            problem.eqn.set_params(&p_0);
239            let g_neg = {
240                let mut s = problem.bdf::<LS>().unwrap();
241                s.solve(final_time).unwrap();
242                s.state().g.clone()
243            };
244            for b in 0..nbatch {
245                let base = p_base.get_batch(b).get_index(i);
246                p_0.get_batch_mut(b).set_index(i, base);
247            }
248
249            let delta_full = g_pos - g_neg;
250            for b in 0..nbatch {
251                let hb = h.get_batch(b).get_index(i);
252                let denom = Eqn::T::from_f64(2.0).unwrap() * hb;
253                for j in 0..nout {
254                    let delta_val = delta_full.get_batch(b).get_index(j) / denom;
255                    dgdp.set_index(i, b * nout + j, delta_val);
256                }
257            }
258        }
259        problem.eqn.set_params(&p_base);
260        dgdp
261    }
262
263    /// sum_i^n (soln_i - data_i)^2
264    /// sum_i^n (soln_i - data_i)^4
265    pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
266    where
267        DM: DenseMatrix,
268    {
269        let nbatch = soln.context().nbatch();
270        let mut ret = DM::V::zeros(2, soln.context().clone());
271        for j in 0..soln.ncols() {
272            let soln_j = soln.column(j);
273            let data_j = data.column(j);
274            let delta = soln_j - data_j;
275            for b in 0..nbatch {
276                let delta_b = delta.get_batch(b).into_owned();
277                let norm2 = delta_b.norm(2);
278                let norm4 = delta_b.norm(4);
279                let cur0 = ret.get_batch(b).get_index(0);
280                let cur1 = ret.get_batch(b).get_index(1);
281                ret.get_batch_mut(b).set_index(0, cur0 + norm2 * norm2);
282                let norm4_sq = norm4 * norm4;
283                ret.get_batch_mut(b)
284                    .set_index(1, cur1 + norm4_sq * norm4_sq);
285            }
286        }
287        ret
288    }
289
290    /// sum_i^n 2 * (soln_i - data_i)
291    /// sum_i^n 4 * (soln_i - data_i)^3
292    pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
293    where
294        DM: DenseMatrix,
295    {
296        let delta = soln.clone() - data;
297        let mut delta3 = delta.clone();
298        for j in 0..delta3.ncols() {
299            let delta_col = delta.column(j).into_owned();
300
301            let mut delta3_col = delta_col.clone();
302            delta3_col.component_mul_assign(&delta_col);
303            delta3_col.component_mul_assign(&delta_col);
304
305            delta3.column_mut(j).copy_from(&delta3_col);
306        }
307        let ret = vec![
308            delta * Scale(DM::T::from_f64(2.).unwrap()),
309            delta3 * Scale(DM::T::from_f64(4.).unwrap()),
310        ];
311        ret
312    }
313
314    pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
315        problem: &'a mut OdeSolverProblem<Eqn>,
316        times: &[Eqn::T],
317    ) -> (
318        <Eqn::V as DefaultDenseMatrix>::M,
319        <Eqn::V as DefaultDenseMatrix>::M,
320    )
321    where
322        Eqn: OdeEquationsImplicitAdjoint + 'a,
323        LS: LinearSolver<Eqn::M>,
324        Eqn::M: DefaultSolver,
325        Eqn::V: DefaultDenseMatrix,
326        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
327        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
328    {
329        let nparams = problem.eqn.nparams();
330        let nout = 2;
331        let ctx = problem.eqn.context();
332        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
333
334        let mut p_0 = ctx.vector_zeros(nparams);
335        problem.eqn.get_params(&mut p_0);
336        let nbatch = p_0.context().nbatch();
337        let h_base = Eqn::T::from_f64(1e-6).unwrap();
338        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
339        h.axpy(h_base, &p_0, Eqn::T::one());
340        let mut p_data = p_0.clone();
341        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
342        let p_base = p_0.clone();
343
344        problem.eqn.set_params(&p_data);
345        let data = {
346            let mut s = problem.bdf::<LS>().unwrap();
347            s.solve_dense(times).unwrap().0
348        };
349
350        for i in 0..nparams {
351            for b in 0..nbatch {
352                let base = p_base.get_batch(b).get_index(i);
353                let hb = h.get_batch(b).get_index(i);
354                p_0.get_batch_mut(b).set_index(i, base + hb);
355            }
356            problem.eqn.set_params(&p_0);
357            let g_pos = {
358                let mut s = problem.bdf::<LS>().unwrap();
359                let v = s.solve_dense(times).unwrap().0;
360                sum_squares(&v, &data)
361            };
362
363            for b in 0..nbatch {
364                let base = p_base.get_batch(b).get_index(i);
365                let hb = h.get_batch(b).get_index(i);
366                p_0.get_batch_mut(b).set_index(i, base - hb);
367            }
368            problem.eqn.set_params(&p_0);
369            let g_neg = {
370                let mut s = problem.bdf::<LS>().unwrap();
371                let v = s.solve_dense(times).unwrap().0;
372                sum_squares(&v, &data)
373            };
374
375            for b in 0..nbatch {
376                let base = p_base.get_batch(b).get_index(i);
377                p_0.get_batch_mut(b).set_index(i, base);
378            }
379
380            let delta_full = g_pos - g_neg;
381            for b in 0..nbatch {
382                let hb = h.get_batch(b).get_index(i);
383                let denom = Eqn::T::from_f64(2.0).unwrap() * hb;
384                for j in 0..nout {
385                    let delta_val = delta_full.get_batch(b).get_index(j) / denom;
386                    dgdp.set_index(i, b * nout + j, delta_val);
387                }
388            }
389        }
390        problem.eqn.set_params(&p_base);
391        (dgdp, data)
392    }
393
394    pub fn single_reset_root_discrete_times<T: Scalar>(t_stop: T) -> Vec<T> {
395        let t_root = t_stop / T::from_f64(2.0).unwrap();
396        [0.25, 0.75, 1.25, 1.75]
397            .into_iter()
398            .map(|factor| t_root * T::from_f64(factor).unwrap())
399            .collect()
400    }
401
402    fn solve_dense_with_single_reset_root<'a, Eqn, Method, BuildForward>(
403        build_forward: BuildForward,
404        times: &[Eqn::T],
405    ) -> <Eqn::V as DefaultDenseMatrix>::M
406    where
407        Eqn: OdeEquationsImplicitAdjoint + 'a,
408        Eqn::M: DefaultSolver,
409        Eqn::V: DefaultDenseMatrix,
410        Method: OdeSolverMethod<'a, Eqn>,
411        BuildForward: Fn(Option<Method::State>) -> Result<Method, DiffsolError>,
412    {
413        let mut soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
414        let first_forward_solver = build_forward(None).unwrap().solve_soln(&mut soln).unwrap();
415        match soln.stop_reason {
416            Some(OdeSolverStopReason::RootFound(_, 0)) => {}
417            Some(OdeSolverStopReason::RootFound(_, idx)) => {
418                panic!("expected first solve_soln() segment to stop on root 0, got root {idx}")
419            }
420            Some(OdeSolverStopReason::TstopReached) => {
421                panic!("expected first solve_soln() segment to stop on the interior root")
422            }
423            Some(OdeSolverStopReason::InternalTimestep) | None => {
424                panic!("first solve_soln() segment did not finish with a terminal stop reason")
425            }
426        }
427
428        let mut state_after_reset = first_forward_solver.state_clone();
429        {
430            let problem = first_forward_solver.problem();
431            state_after_reset
432                .as_mut()
433                .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
434                .unwrap();
435        }
436
437        build_forward(Some(state_after_reset))
438            .unwrap()
439            .solve_soln(&mut soln)
440            .unwrap();
441        assert!(
442            soln.is_complete(),
443            "expected stitched solve_soln() output to cover all requested observation times",
444        );
445        soln.ys
446    }
447
448    fn state_after_manual_reset<'a, Eqn, Method>(solver: &Method) -> Method::State
449    where
450        Eqn: OdeEquationsImplicitAdjoint + 'a,
451        Eqn::M: DefaultSolver,
452        Method: OdeSolverMethod<'a, Eqn>,
453    {
454        let mut state_after_reset = solver.state_clone();
455        {
456            let problem = solver.problem();
457            state_after_reset
458                .as_mut()
459                .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
460                .unwrap();
461        }
462        state_after_reset
463    }
464
    /// Build synthetic data and a finite-difference reference gradient for the
    /// sum-of-squares adjoint tests on a problem with a single interior reset root.
    ///
    /// Every forward solve goes through `solve_dense_with_single_reset_root`, which
    /// stops at the root, applies the reset by hand and continues. The data matrix is
    /// the solution at `times` for parameters `p + 0.1 * p`; the gradient of the two
    /// `sum_squares` objectives is estimated with central differences around the
    /// original parameters `p`, using a step `h_i = 1e-10 * (1 + p_i)`.
    ///
    /// Returns `(dgdp, data)` and restores the original parameters in `problem`.
    ///
    /// NOTE(review): unlike `setup_test_adjoint_sum_squares` this variant uses a base
    /// step of `1e-10` instead of `1e-6` and does not loop over batches — confirm both
    /// differences are intentional.
    pub fn setup_test_adjoint_sum_squares_with_single_reset_root<'a, LS, Eqn>(
        problem: &'a mut OdeSolverProblem<Eqn>,
        times: &[Eqn::T],
    ) -> (
        <Eqn::V as DefaultDenseMatrix>::M,
        <Eqn::V as DefaultDenseMatrix>::M,
    )
    where
        Eqn: OdeEquationsImplicitAdjoint + 'a,
        LS: LinearSolver<Eqn::M>,
        Eqn::M: DefaultSolver,
        Eqn::V: DefaultDenseMatrix,
        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
    {
        let nparams = problem.eqn.nparams();
        // Two objectives: the squared and the fourth-power residual sums.
        let nout = 2;
        let ctx = problem.eqn.context();
        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());

        let mut p_0 = ctx.vector_zeros(nparams);
        problem.eqn.get_params(&mut p_0);
        // Step per parameter: h = h_base + h_base * p_0.
        let h_base = Eqn::T::from_f64(1e-10).unwrap();
        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
        h.axpy(h_base, &p_0, Eqn::T::one());
        // Parameters used to generate the data: p_data = p_0 + 0.1 * p_0.
        let mut p_data = p_0.clone();
        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
        let p_base = p_0.clone();

        problem.eqn.set_params(&p_data);
        // The closure is rebuilt for every solve so that its shared borrow of `problem`
        // ends before the next `set_params` call.
        let data = solve_dense_with_single_reset_root::<Eqn, _, _>(
            |state| match state {
                Some(state) => problem.bdf_solver(state),
                None => problem.bdf::<LS>(),
            },
            times,
        );

        for i in 0..nparams {
            // Objective at p_i + h_i.
            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
            problem.eqn.set_params(&p_0);
            let g_pos = {
                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
                    |state| match state {
                        Some(state) => problem.bdf_solver(state),
                        None => problem.bdf::<LS>(),
                    },
                    times,
                );
                sum_squares(&v, &data)
            };

            // Objective at p_i - h_i.
            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
            problem.eqn.set_params(&p_0);
            let g_neg = {
                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
                    |state| match state {
                        Some(state) => problem.bdf_solver(state),
                        None => problem.bdf::<LS>(),
                    },
                    times,
                );
                sum_squares(&v, &data)
            };

            // Restore entry `i` so the next iteration perturbs only one parameter.
            p_0.set_index(i, p_base.get_index(i));

            // Central difference: (g(p + h) - g(p - h)) / (2 h).
            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
            for j in 0..nout {
                dgdp.set_index(i, j, delta.get_index(j));
            }
        }
        problem.eqn.set_params(&p_base);
        (dgdp, data)
    }
540
541    pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
542        backwards_solver: SolverB,
543        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
544        forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
545        data: <Eqn::V as DefaultDenseMatrix>::M,
546        times: &[Eqn::T],
547    ) where
548        SolverF: OdeSolverMethod<'a, Eqn>,
549        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
550        Eqn: OdeEquationsImplicitAdjoint + 'a,
551        Eqn::V: DefaultDenseMatrix,
552        Eqn::M: DefaultSolver,
553    {
554        let nparams = dgdp_check.nrows();
555        let dgdu = dsum_squaresdp(&forwards_soln, &data);
556
557        let atol = Eqn::V::from_element(
558            nparams,
559            Eqn::T::from_f64(1e-6).unwrap(),
560            data.context().clone(),
561        );
562        let rtol = Eqn::T::from_f64(1e-6).unwrap();
563        let (state, _) = backwards_solver
564            .solve_adjoint_backwards_pass(times, dgdu.iter().collect::<Vec<_>>().as_slice())
565            .unwrap();
566        let gs_adj = state.into_common().sg;
567        #[allow(clippy::needless_range_loop)]
568        for j in 0..dgdp_check.ncols() {
569            gs_adj[j].assert_eq_norm(
570                &dgdp_check.column(j).into_owned(),
571                &atol,
572                rtol,
573                Eqn::T::from_f64(260.).unwrap(),
574            );
575        }
576    }
577
578    pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
579        backwards_solver: SolverB,
580        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
581    ) where
582        SolverF: OdeSolverMethod<'a, Eqn>,
583        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
584        Eqn: OdeEquationsImplicitAdjoint + 'a,
585        Eqn::V: DefaultDenseMatrix,
586        Eqn::M: DefaultSolver,
587    {
588        let nout = backwards_solver.problem().eqn.nout();
589        let atol = Eqn::V::from_element(
590            nout,
591            Eqn::T::from_f64(1e-6).unwrap(),
592            dgdp_check.context().clone(),
593        );
594        let rtol = Eqn::T::from_f64(1e-6).unwrap();
595        let (state, _) = backwards_solver
596            .solve_adjoint_backwards_pass(&[], &[])
597            .unwrap();
598        let gs_adj = state.into_common().sg;
599        #[allow(clippy::needless_range_loop)]
600        for j in 0..dgdp_check.ncols() {
601            gs_adj[j].assert_eq_norm(
602                &dgdp_check.column(j).into_owned(),
603                &atol,
604                rtol,
605                Eqn::T::from_f64(40.).unwrap(),
606            );
607        }
608    }
609
610    pub struct TestEqnInit<M: Matrix> {
611        ctx: M::C,
612    }
613
614    impl<M: Matrix> Op for TestEqnInit<M> {
615        type T = M::T;
616        type V = M::V;
617        type M = M;
618        type C = M::C;
619
620        fn nout(&self) -> usize {
621            1
622        }
623        fn nparams(&self) -> usize {
624            1
625        }
626        fn nstates(&self) -> usize {
627            1
628        }
629        fn context(&self) -> &Self::C {
630            &self.ctx
631        }
632    }
633
634    impl<M: Matrix> ConstantOp for TestEqnInit<M> {
635        fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
636            y.fill(M::T::one());
637        }
638    }
639
640    impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
641        fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
642            sens.fill(M::T::zero());
643        }
644    }
645
646    pub struct TestEqnRhs<M: Matrix> {
647        ctx: M::C,
648    }
649
650    impl<M: Matrix> Op for TestEqnRhs<M> {
651        type T = M::T;
652        type V = M::V;
653        type M = M;
654        type C = M::C;
655
656        fn nout(&self) -> usize {
657            1
658        }
659        fn nparams(&self) -> usize {
660            1
661        }
662        fn nstates(&self) -> usize {
663            1
664        }
665        fn context(&self) -> &Self::C {
666            &self.ctx
667        }
668    }
669
670    impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
671        fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
672            y.fill(M::T::zero());
673        }
674    }
675
676    impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
677        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
678            y.fill(M::T::zero());
679        }
680    }
681
682    impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
683        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
684            sens.fill(M::T::zero());
685        }
686    }
687
688    pub struct TestEqnOut<M: Matrix> {
689        ctx: M::C,
690    }
691
692    impl<M: Matrix> Op for TestEqnOut<M> {
693        type T = M::T;
694        type V = M::V;
695        type M = M;
696        type C = M::C;
697
698        fn nout(&self) -> usize {
699            1
700        }
701        fn nparams(&self) -> usize {
702            1
703        }
704        fn nstates(&self) -> usize {
705            1
706        }
707        fn context(&self) -> &Self::C {
708            &self.ctx
709        }
710    }
711
712    impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
713        fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
714            y.copy_from(x);
715        }
716    }
717
718    impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
719        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
720            y.copy_from(v);
721        }
722    }
723
724    impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
725        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
726            sens.fill(M::T::zero());
727        }
728    }
729
730    pub struct TestEqn<M: Matrix> {
731        rhs: Rc<TestEqnRhs<M>>,
732        init: Rc<TestEqnInit<M>>,
733        out: Rc<TestEqnOut<M>>,
734        ctx: M::C,
735    }
736
737    impl<M: Matrix> TestEqn<M> {
738        pub fn new() -> Self {
739            let ctx = M::C::default();
740            Self {
741                rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
742                init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
743                out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
744                ctx,
745            }
746        }
747    }
748
749    impl<M: Matrix> Op for TestEqn<M> {
750        type T = M::T;
751        type V = M::V;
752        type M = M;
753        type C = M::C;
754        fn nout(&self) -> usize {
755            1
756        }
757        fn nparams(&self) -> usize {
758            1
759        }
760        fn nstates(&self) -> usize {
761            1
762        }
763        fn statistics(&self) -> crate::op::OpStatistics {
764            OpStatistics::default()
765        }
766        fn context(&self) -> &Self::C {
767            &self.ctx
768        }
769    }
770
771    impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
772        type Rhs = &'a TestEqnRhs<M>;
773        type Mass = ParameterisedOp<'a, UnitCallable<M>>;
774        type Root = ParameterisedOp<'a, UnitCallable<M>>;
775        type Init = &'a TestEqnInit<M>;
776        type Out = &'a TestEqnOut<M>;
777        type Reset = ParameterisedOp<'a, UnitCallable<M>>;
778    }
779
780    impl<M: Matrix> OdeEquations for TestEqn<M> {
781        fn rhs(&self) -> &TestEqnRhs<M> {
782            &self.rhs
783        }
784
785        fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
786            None
787        }
788
789        fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
790            None
791        }
792
793        fn init(&self) -> &TestEqnInit<M> {
794            &self.init
795        }
796
797        fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
798            Some(&self.out)
799        }
800        fn set_params(&mut self, _p: &Self::V) {
801            unimplemented!()
802        }
803        fn get_params(&self, _p: &mut Self::V) {
804            unimplemented!()
805        }
806    }
807
808    pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
809        let eqn = TestEqn::<M>::new();
810        let atol = eqn
811            .context()
812            .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
813        OdeSolverProblem::new(
814            eqn,
815            M::T::from_f64(1e-6).unwrap(),
816            atol,
817            None,
818            None,
819            None,
820            None,
821            None,
822            None,
823            M::T::zero(),
824            M::T::one(),
825            integrate_out,
826            Default::default(),
827            Default::default(),
828        )
829        .unwrap()
830    }
831
832    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
833        let state = s.checkpoint();
834        let integrating_sens = !s.state().s.is_empty();
835        let integrating_out = s.problem().integrate_out;
836        let t0 = state.as_ref().t;
837        let t1 = t0 + M::T::from_f64(1e6).unwrap();
838        s.interpolate(t0)
839            .unwrap()
840            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
841        assert!(s.interpolate(t1).is_err());
842        assert!(s.interpolate_out(t1).is_err());
843        if integrating_sens {
844            assert!(s.interpolate_sens(t1).is_err());
845        } else {
846            assert!(s.interpolate_sens(t0).is_ok());
847        }
848        s.step().unwrap();
849        let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
850        assert!(s.interpolate(s.state().t).is_ok());
851        assert!(s.interpolate(tmid).is_ok());
852        if integrating_out {
853            assert!(s.interpolate_out(s.state().t).is_ok());
854        } else {
855            assert!(s.interpolate_out(s.state().t).is_err());
856        }
857        assert!(s.interpolate_sens(s.state().t).is_ok());
858        assert!(s.interpolate(s.state().t + t1).is_err());
859        assert!(s.interpolate_out(s.state().t + t1).is_err());
860        if integrating_sens {
861            assert!(s.interpolate_sens(s.state().t + t1).is_err());
862        } else {
863            assert!(s.interpolate_sens(s.state().t + t1).is_ok());
864        }
865
866        let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
867        assert!(s
868            .interpolate_inplace(s.state().t, &mut y_wrong_length)
869            .is_err());
870        let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
871        assert!(s
872            .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
873            .is_err());
874        let mut s_wrong_length = vec![
875            M::V::zeros(1, s.problem().context().clone()),
876            M::V::zeros(1, s.problem().context().clone()),
877        ];
878        assert!(s
879            .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
880            .is_err());
881        let mut s_wrong_vec_length = if integrating_sens {
882            vec![M::V::zeros(2, s.problem().context().clone())]
883        } else {
884            vec![]
885        };
886        if integrating_sens {
887            assert!(s
888                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
889                .is_err());
890        } else {
891            assert!(s
892                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
893                .is_ok());
894        }
895
896        s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
897        assert!(s.interpolate(s.state().t).is_ok());
898        if integrating_out {
899            assert!(s.interpolate_out(s.state().t).is_ok());
900        }
901        if integrating_sens {
902            assert!(s.interpolate_sens(s.state().t).is_ok());
903        }
904        assert!(s.interpolate(tmid).is_err());
905        assert!(s.interpolate_out(tmid).is_err());
906        if integrating_sens {
907            assert!(s.interpolate_sens(tmid).is_err());
908        } else {
909            assert!(s.interpolate_sens(tmid).is_ok());
910        }
911    }
912
913    pub fn test_interpolate_dy<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(
914        mut s: Method,
915    ) {
916        // Error before first step: t is in the future
917        let t_future = s.state().t + M::T::from_f64(1e6).unwrap();
918        assert!(s.interpolate_dy(t_future).is_err());
919
920        let t0 = s.state().t;
921        s.step().unwrap();
922        let t1 = s.state().t;
923        let dt = t1 - t0;
924        let tmid = t0 + dt / M::T::from_f64(2.0).unwrap();
925
926        // Wrong vector length should return error
927        let mut dy_wrong = M::V::zeros(2, s.problem().context().clone());
928        assert!(s.interpolate_dy_inplace(t1, &mut dy_wrong).is_err());
929
930        // t after current time should return error
931        assert!(s.interpolate_dy(t1 + M::T::from_f64(1e6).unwrap()).is_err());
932
933        // interpolate_dy should be consistent with finite-difference of interpolate (step 1)
934        let eps = dt.abs() * M::T::from_f64(1e-5).unwrap();
935        let y_plus = s.interpolate(tmid + eps).unwrap();
936        let y_minus = s.interpolate(tmid - eps).unwrap();
937        let fd_dy = (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps));
938        let dy = s.interpolate_dy(tmid).unwrap();
939        dy.assert_eq_norm(
940            &fd_dy,
941            &s.problem().atol,
942            s.problem().rtol,
943            M::T::from_f64(1e3).unwrap(),
944        );
945
946        // take a second step and check consistency again
947        let t1 = s.state().t;
948        s.step().unwrap();
949        let t2 = s.state().t;
950        let dt2 = t2 - t1;
951        let tmid2 = t1 + dt2 / M::T::from_f64(2.0).unwrap();
952        let eps2 = dt2.abs() * M::T::from_f64(1e-5).unwrap();
953        let y_plus = s.interpolate(tmid2 + eps2).unwrap();
954        let y_minus = s.interpolate(tmid2 - eps2).unwrap();
955        let fd_dy2 =
956            (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps2));
957        let dy2 = s.interpolate_dy(tmid2).unwrap();
958        dy2.assert_eq_norm(
959            &fd_dy2,
960            &s.problem().atol,
961            s.problem().rtol,
962            M::T::from_f64(1e3).unwrap(),
963        );
964    }
965
966    pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
967        mut s: Method,
968    ) {
969        *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
970        assert_eq!(
971            *s.config().as_base_ref().minimum_timestep,
972            Eqn::T::from_f64(1.0e8).unwrap()
973        );
974        // force a step size reduction
975        *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
976
977        let mut failed = false;
978        for _ in 0..10 {
979            if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
980                s.step()
981            {
982                failed = true;
983                break;
984            }
985        }
986        assert!(failed);
987    }
988
989    pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
990        let state = s.checkpoint();
991        let state2 = s.state();
992        state2
993            .y
994            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
995        s.state_mut()
996            .y
997            .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
998        assert_eq!(
999            s.state_mut().y.get_index(0),
1000            M::T::from_f64(std::f64::consts::PI).unwrap()
1001        );
1002    }
1003
1004    #[cfg(feature = "diffsl-cranelift")]
1005    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
1006    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
1007        crate::OdeBuilder::<M>::new()
1008            .build_from_diffsl(
1009                "
1010            g { 9.81 } h { 10.0 }
1011            u_i {
1012                x = h,
1013                v = 0,
1014            }
1015            F_i {
1016                v,
1017                -g,
1018            }
1019            stop {
1020                x,
1021            }
1022        ",
1023            )
1024            .unwrap()
1025    }
1026
1027    #[cfg(feature = "diffsl-cranelift")]
1028    pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
1029    where
1030        M: crate::MatrixHost<T = f64>,
1031        M: DefaultSolver<T = f64>,
1032        M::V: DefaultDenseMatrix<T = f64>,
1033        Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
1034    {
1035        let e = 0.8;
1036
1037        let final_time = 2.5;
1038
1039        // solve and apply the remaining doses
1040        solver.set_stop_time(final_time).unwrap();
1041        loop {
1042            match solver.step() {
1043                Ok(OdeSolverStopReason::InternalTimestep) => (),
1044                Ok(OdeSolverStopReason::RootFound(t, _)) => {
1045                    // get the state when the event occurred
1046                    let mut y = solver.interpolate(t).unwrap();
1047
1048                    // update the velocity of the ball
1049                    y.set_index(1, y.get_index(1) * -e);
1050
1051                    // make sure the ball is above the ground
1052                    y.set_index(0, y.get_index(0).max(f64::EPSILON));
1053
1054                    // set the state to the updated state
1055                    solver.state_mut().y.copy_from(&y);
1056                    solver.state_mut().dy.set_index(0, y.get_index(1));
1057                    *solver.state_mut().t = t;
1058
1059                    break;
1060                }
1061                Ok(OdeSolverStopReason::TstopReached) => break,
1062                Err(_) => panic!("unexpected solver error"),
1063            }
1064        }
1065        // do three more steps after the 1st bound and many sure they are correct
1066        let mut x = vec![];
1067        let mut v = vec![];
1068        let mut t = vec![];
1069        for _ in 0..3 {
1070            let ret = solver.step();
1071            x.push(solver.state().y.get_index(0));
1072            v.push(solver.state().y.get_index(1));
1073            t.push(solver.state().t);
1074            match ret {
1075                Ok(OdeSolverStopReason::InternalTimestep) => (),
1076                Ok(OdeSolverStopReason::RootFound(_, _)) => {
1077                    panic!("should be an internal timestep but found a root")
1078                }
1079                Ok(OdeSolverStopReason::TstopReached) => break,
1080                _ => panic!("should be an internal timestep"),
1081            }
1082        }
1083        (x, v, t)
1084    }
1085
1086    pub fn test_checkpointing<'a, M, Method, Eqn>(
1087        soln: OdeSolverSolution<M::V>,
1088        mut solver1: Method,
1089        mut solver2: Method,
1090    ) where
1091        M: Matrix + DefaultSolver,
1092        Method: OdeSolverMethod<'a, Eqn>,
1093        Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
1094    {
1095        let half_i = soln.solution_points.len() / 2;
1096        let half_t = soln.solution_points[half_i].t;
1097        while solver1.state().t <= half_t {
1098            solver1.step().unwrap();
1099        }
1100        let checkpoint = solver1.checkpoint();
1101        let checkpoint_t = checkpoint.as_ref().t;
1102        solver2.set_state(checkpoint);
1103
1104        // carry on solving with both solvers, they should produce about the same results (probably might diverge a bit, but should always match the solution)
1105        for point in soln.solution_points.iter().skip(half_i + 1) {
1106            // point should be past checkpoint
1107            if point.t < checkpoint_t {
1108                continue;
1109            }
1110            while solver2.state().t < point.t {
1111                solver1.step().unwrap();
1112                solver2.step().unwrap();
1113                let time_error = (solver1.state().t - solver2.state().t).abs()
1114                    / (solver1.state().t.abs() * solver1.problem().rtol
1115                        + solver1.problem().atol.get_index(0));
1116                assert!(
1117                    time_error < M::T::from_f64(20.0).unwrap(),
1118                    "time_error: {} at t = {}",
1119                    time_error,
1120                    solver1.state().t
1121                );
1122                solver1.state().y.assert_eq_norm(
1123                    solver2.state().y,
1124                    &solver1.problem().atol,
1125                    solver1.problem().rtol,
1126                    M::T::from_f64(20.0).unwrap(),
1127                );
1128            }
1129            let soln = solver1.interpolate(point.t).unwrap();
1130            soln.assert_eq_norm(
1131                &point.state,
1132                &solver1.problem().atol,
1133                solver1.problem().rtol,
1134                M::T::from_f64(15.0).unwrap(),
1135            );
1136            let soln = solver2.interpolate(point.t).unwrap();
1137            soln.assert_eq_norm(
1138                &point.state,
1139                &solver1.problem().atol,
1140                solver1.problem().rtol,
1141                M::T::from_f64(15.0).unwrap(),
1142            );
1143        }
1144    }
1145
1146    pub fn test_state_mut_on_problem<'a, Eqn, Method>(
1147        mut s: Method,
1148        soln: OdeSolverSolution<Eqn::V>,
1149    ) where
1150        Eqn: OdeEquationsImplicit + 'a,
1151        Eqn::M: DefaultSolver,
1152        Method: OdeSolverMethod<'a, Eqn>,
1153        Eqn::V: DefaultDenseMatrix,
1154    {
1155        // save state and solve for a little bit
1156        let state = s.checkpoint();
1157        s.solve(Eqn::T::one()).unwrap();
1158
1159        // reinit using state_mut
1160        s.state_mut().y.copy_from(state.as_ref().y);
1161        s.state_mut().dy.copy_from(state.as_ref().dy);
1162        *s.state_mut().t = state.as_ref().t;
1163
1164        // solve and check against solution
1165        for point in soln.solution_points.iter() {
1166            while s.state().t < point.t {
1167                s.step().unwrap();
1168            }
1169            let soln = s.interpolate(point.t).unwrap();
1170            let error = soln.clone() - &point.state;
1171            let error_norm = error
1172                .squared_norm(&error, &s.problem().atol, s.problem().rtol)
1173                .sqrt();
1174            assert!(
1175                error_norm < Eqn::T::from_f64(19.0).unwrap(),
1176                "error_norm: {} at t = {}",
1177                error_norm,
1178                point.t
1179            );
1180        }
1181    }
1182
1183    /// Test that `step()` returns `RootFound(t, index)` with the correct root index.
1184    ///
1185    /// The problem must have a root function with **two** outputs and **no** Reset:
1186    ///   - Root 0 fires first (at `t ≈ 5.108`, `y[0] ≈ 0.6` for the exponential-decay test model)
1187    ///   - Root 1 fires second
1188    ///
1189    /// The test asserts that the first `RootFound` reports index 0 and the time
1190    /// matches `t_root_0_expected` within `tol`.
1191    pub fn test_root_found_index<'a, Eqn, Method>(
1192        mut solver: Method,
1193        soln: &OdeSolverSolution<Eqn::V>,
1194        expected_root_index: usize,
1195        tol: Eqn::T,
1196    ) where
1197        Eqn: OdeEquations + 'a,
1198        Method: OdeSolverMethod<'a, Eqn>,
1199    {
1200        let t_root_expected = soln.solution_points[0].t;
1201        solver
1202            .set_stop_time(Eqn::T::from_f64(100.0).unwrap())
1203            .unwrap();
1204        loop {
1205            match solver.step().unwrap() {
1206                // RED: `RootFound` currently has one field; adding `index` makes this fail.
1207                OdeSolverStopReason::RootFound(t, index) => {
1208                    assert_eq!(
1209                        index, expected_root_index,
1210                        "expected root index {expected_root_index} but got {index}",
1211                    );
1212                    assert!(
1213                        (t - t_root_expected).abs() < tol,
1214                        "expected t ≈ {t_root_expected:?}, got {t:?}",
1215                    );
1216                    break;
1217                }
1218                OdeSolverStopReason::TstopReached => {
1219                    panic!("reached tstop without finding a root")
1220                }
1221                OdeSolverStopReason::InternalTimestep => {}
1222            }
1223        }
1224    }
1225
1226    /// Test that `solve()` automatically applies resets at roots and continues
1227    /// integrating until `final_time`.
1228    pub fn test_solve_with_reset<'a, Eqn, Method>(
1229        mut solver: Method,
1230        soln: &OdeSolverSolution<Eqn::V>,
1231        final_time: Eqn::T,
1232    ) where
1233        Eqn: OdeEquationsImplicit + 'a,
1234        Eqn::M: DefaultSolver,
1235        Eqn::V: DefaultDenseMatrix,
1236        Method: OdeSolverMethod<'a, Eqn>,
1237    {
1238        let (ys, ts, stop_reason) = solver.solve(final_time).unwrap();
1239        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1240        let t_last = *ts.last().unwrap();
1241        let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1242        assert!(
1243            (t_last - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1244            "expected solve() to reach final_time ≈ {:?}, got {:?}",
1245            final_time,
1246            t_last,
1247        );
1248        assert!(
1249            (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1250            "expected solver state at final_time ≈ {:?}, got {:?}",
1251            final_time,
1252            solver.state().t,
1253        );
1254
1255        let expected = &soln.solution_points[0];
1256        let root_time_tol = soln.rtol * expected.t.abs() + soln.atol.get_index(0);
1257        let root_col = ts
1258            .iter()
1259            .position(|&t| (t - expected.t).abs() < Eqn::T::from_f64(30.0).unwrap() * root_time_tol)
1260            .expect("expected solve() output to include the second-root/reset time");
1261        let root_expected = Eqn::V::from_element(
1262            expected.state.len(),
1263            Eqn::T::from_f64(0.4).unwrap(),
1264            expected.state.context().clone(),
1265        );
1266        let root_state = ys.column(root_col).into_owned();
1267        let root_error = root_state - &root_expected;
1268        let root_error_norm = root_error
1269            .squared_norm(&root_expected, &soln.atol, soln.rtol)
1270            .sqrt();
1271        let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1272        assert!(
1273            root_error_norm < error_threshold,
1274            "expected reset state y=0.4 at second-root time; WRMS error norm {root_error_norm:?} ≥ {error_threshold:?}",
1275        );
1276
1277        let reset_value = Eqn::T::from_f64(0.4).unwrap();
1278        let reset_tol = Eqn::T::from_f64(30.0).unwrap()
1279            * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1280        let last_reset_col = (0..ts.len())
1281            .rev()
1282            .find(|&i| (ys.get_index(0, i) - reset_value).abs() < reset_tol)
1283            .expect("expected solve() output to include at least one reset state");
1284        let final_time_f64 = final_time.to_f64().unwrap();
1285        let last_reset_time_f64 = ts[last_reset_col].to_f64().unwrap();
1286        let expected_final_value =
1287            Eqn::T::from_f64(0.4 * (-0.1 * (final_time_f64 - last_reset_time_f64)).exp()).unwrap();
1288        let expected_final = Eqn::V::from_element(
1289            expected.state.len(),
1290            expected_final_value,
1291            expected.state.context().clone(),
1292        );
1293        let final_state = ys.column(ts.len() - 1).into_owned();
1294        let final_error = final_state - &expected_final;
1295        let final_error_norm = final_error
1296            .squared_norm(&expected_final, &soln.atol, soln.rtol)
1297            .sqrt();
1298        assert!(
1299            final_error_norm < error_threshold,
1300            "final state mismatch after automatic reset continuation: WRMS error norm {final_error_norm:?} ≥ {error_threshold:?}",
1301        );
1302    }
1303
1304    /// Test that `solve_dense()` automatically applies resets at roots and
1305    /// continues filling the requested evaluation times.
1306    pub fn test_solve_dense_with_reset<'a, Eqn, Method>(
1307        mut solver: Method,
1308        soln: &OdeSolverSolution<Eqn::V>,
1309    ) where
1310        Eqn: OdeEquationsImplicit + 'a,
1311        Eqn::M: DefaultSolver,
1312        Eqn::V: DefaultDenseMatrix,
1313        Method: OdeSolverMethod<'a, Eqn>,
1314    {
1315        let t_stop = soln.solution_points[0].t;
1316        let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1317        let mut probe_solver = solver.clone();
1318        let (probe_ys, probe_ts, probe_stop_reason) = probe_solver.solve(final_time).unwrap();
1319        assert_eq!(probe_stop_reason, OdeSolverStopReason::TstopReached);
1320
1321        let reset_time_tol =
1322            Eqn::T::from_f64(30.0).unwrap() * (soln.rtol * t_stop.abs() + soln.atol.get_index(0));
1323        let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1324        let reset_value = Eqn::T::from_f64(0.4).unwrap();
1325        let reset_value_tol = Eqn::T::from_f64(30.0).unwrap()
1326            * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1327        let reset_col = (0..probe_ts.len())
1328            .find(|&i| {
1329                (probe_ts[i] - t_stop).abs() < reset_time_tol
1330                    && (probe_ys.get_index(0, i) - reset_value).abs() < reset_value_tol
1331            })
1332            .expect("expected solve() probe output to contain the second-root reset state");
1333        let t_event = probe_ts[reset_col];
1334        let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, final_time];
1335
1336        let (ret, stop_reason) = solver.solve_dense(&t_eval).unwrap();
1337        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1338        assert!(
1339            ret.ncols() == t_eval.len(),
1340            "expected solve_dense() to fill all requested evaluation times"
1341        );
1342        let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1343        assert!(
1344            (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1345            "expected solver state at final_time ≈ {:?}, got {:?}",
1346            final_time,
1347            solver.state().t,
1348        );
1349
1350        let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1351        let pre_reset_state = ret.column(1).into_owned();
1352        let pre_reset_error = pre_reset_state - &soln.solution_points[0].state;
1353        let pre_reset_error_norm = pre_reset_error
1354            .squared_norm(&soln.solution_points[0].state, &soln.atol, soln.rtol)
1355            .sqrt();
1356        assert!(
1357            pre_reset_error_norm < error_threshold,
1358            "expected pre-reset state at event time; WRMS norm {pre_reset_error_norm:?} >= {error_threshold:?}",
1359        );
1360
1361        let expected_post_reset_value =
1362            reset_value * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1363        let expected_post_reset = Eqn::V::from_element(
1364            soln.solution_points[0].state.len(),
1365            expected_post_reset_value,
1366            soln.solution_points[0].state.context().clone(),
1367        );
1368        let post_reset_state = ret.column(2).into_owned();
1369        let post_reset_error = post_reset_state - &expected_post_reset;
1370        let post_reset_error_norm = post_reset_error
1371            .squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1372            .sqrt();
1373        assert!(
1374            post_reset_error_norm < error_threshold,
1375            "expected reset state just after event time; WRMS norm {post_reset_error_norm:?} >= {error_threshold:?}",
1376        );
1377    }
1378
1379    /// Test that `solve_dense_sensitivities()` applies root-aware resets and
1380    /// continues filling the requested evaluation times.
1381    pub fn test_solve_dense_sensitivities_with_reset<'a, Eqn, Method>(
1382        mut solver: Method,
1383        soln: &OdeSolverSolution<Eqn::V>,
1384    ) where
1385        Eqn: OdeEquationsImplicitSens + 'a,
1386        Eqn::V: DefaultDenseMatrix,
1387        Eqn::M: DefaultSolver,
1388        Method: SensitivitiesOdeSolverMethod<'a, Eqn>,
1389    {
1390        let t_stop = soln.solution_points[0].t;
1391        let t_event = Eqn::T::from_f64(10.0 * (5.0_f64 / 3.0_f64).ln()).unwrap();
1392
1393        let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1394        let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, t_stop];
1395        let (ret, ret_sens, stop_reason) = solver.solve_dense_sensitivities(&t_eval).unwrap();
1396        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1397        assert_eq!(ret.ncols(), t_eval.len());
1398        for ret_sens_j in &ret_sens {
1399            assert_eq!(ret_sens_j.ncols(), t_eval.len());
1400        }
1401
1402        let error_threshold = Eqn::T::from_f64(100.0).unwrap();
1403        let ctx = soln.solution_points[0].state.context().clone();
1404        let nstates = soln.solution_points[0].state.len();
1405
1406        let post_reset_y = Eqn::T::from_f64(2.6).unwrap()
1407            * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1408        let post_reset_t = t_event + post_event_dt;
1409        let expected_post_reset = Eqn::V::from_element(nstates, post_reset_y, ctx.clone());
1410        let expected_post_reset_sk =
1411            Eqn::V::from_element(nstates, -post_reset_y * post_reset_t, ctx.clone());
1412        let expected_post_reset_sy0 = Eqn::V::from_element(nstates, post_reset_y, ctx);
1413
1414        let col = 2;
1415        let ey = ret.column(col).into_owned() - &expected_post_reset;
1416        let esk = ret_sens[0].column(col).into_owned() - &expected_post_reset_sk;
1417        let esy0 = ret_sens[1].column(col).into_owned() - &expected_post_reset_sy0;
1418        let norm = (ey.squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1419            + esk.squared_norm(&expected_post_reset_sk, &soln.atol, soln.rtol)
1420            + esy0.squared_norm(&expected_post_reset_sy0, &soln.atol, soln.rtol))
1421        .sqrt();
1422        assert!(
1423            norm < error_threshold,
1424            "dense sensitivity mismatch just after reset; combined WRMS {norm:?} >= {error_threshold:?}",
1425        );
1426    }
1427
    /// Test the adjoint backwards pass across a reset, starting from a terminal root.
    ///
    /// Steps performed:
    /// 1. Build a forward solver with `build_forward(None)` and run
    ///    `solve_with_checkpointing` to one time unit past `soln.solution_points[0].t`; at
    ///    least three checkpointing segments are expected.
    /// 2. Take the last checkpoint of the second segment as the forward state at the second
    ///    root and compare its integrated output `g` and its time with the reference point.
    /// 3. Build an adjoint whose first segment has had its terminal reset root index cleared
    ///    and assert that the backwards pass fails with "Missing reset root metadata".
    /// 4. Build the adjoint from the intact first two segments, run the backwards pass, and
    ///    check that it ends at `problem.t0` and that `sg[0]` matches the gradient assembled
    ///    from `soln.sens_solution_points`.
    ///
    /// `build_adjoint_state` and `build_adjoint_from_state` construct the adjoint state and
    /// solver. When `use_replay_solver` is true, a clone of the forward solver is passed to
    /// `adjoint_equations`; otherwise `None` is passed.
    pub fn test_solve_adjoint_with_single_reset_root<
        'a,
        Eqn,
        MethodF,
        MethodB,
        BuildForward,
        BuildAdjointState,
        BuildAdjointFromState,
    >(
        build_forward: BuildForward,
        soln: &OdeSolverSolution<Eqn::V>,
        build_adjoint_state: BuildAdjointState,
        build_adjoint_from_state: BuildAdjointFromState,
        use_replay_solver: bool,
    ) where
        Eqn: OdeEquationsImplicitAdjoint + 'a,
        Eqn::M: DefaultSolver,
        Eqn::V: DefaultDenseMatrix,
        MethodF: OdeSolverMethod<'a, Eqn>,
        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
        BuildAdjointState:
            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
        BuildAdjointFromState:
            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
    {
        let expected_out = &soln.solution_points[0];
        // integrate one time unit past the reference point
        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();

        // forward pass with checkpointing
        let mut forward_solver = build_forward(None).unwrap();
        let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
            .solve_with_checkpointing(forward_stop_time, None)
            .unwrap();
        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
        assert!(
            checkpointers.len() >= 3,
            "expected checkpointing path to include the two reset events"
        );
        let problem = forward_solver.problem();
        // clone of the forward solver, handed to the adjoint equations when replay is requested
        let post_reset_solver = forward_solver.clone();
        // the second segment (index 1) ends at the root the adjoint starts from
        let post_reset_root_idx = checkpointers[1]
            .terminal_reset_root_idx()
            .expect("second reset segment should record its terminal root index");
        let final_forward_state = checkpointers[1].last_checkpoint().clone();
        let t_second_root = final_forward_state.as_ref().t;

        // the integrated output at the second root must match the reference point
        let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
        let out_norm = out_error
            .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
            .sqrt();
        assert!(
            out_norm < Eqn::T::from_f64(50.0).unwrap(),
            "forward integrated output mismatch at second root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
            final_forward_state.as_ref().g,
            expected_out.state,
        );
        // and so must the time at which it was recorded
        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
        assert!(
            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
            "expected second root time ≈ {:?}, got {:?}",
            expected_out.t,
            t_second_root,
        );

        // only the first two segments take part in the adjoint solve
        let adjoint_checkpointers = checkpointers.into_iter().take(2).collect::<Vec<_>>();

        // make a broken adjoint that is missing the reset root metadata on the first segment, which should cause an error
        let mut missing_metadata_checkpointers = adjoint_checkpointers.clone();
        missing_metadata_checkpointers[0].clear_terminal_reset_root_idx();
        let missing_metadata_solver = use_replay_solver.then(|| post_reset_solver.clone());
        let mut missing_metadata_adjoint_eqn = problem.adjoint_equations(
            missing_metadata_checkpointers,
            missing_metadata_solver,
            None,
        );
        let mut missing_metadata_adjoint_state =
            build_adjoint_state(&mut missing_metadata_adjoint_eqn).unwrap();
        missing_metadata_adjoint_state
            .as_mut()
            .state_mut_adjoint_terminal_root(
                &problem.eqn,
                post_reset_root_idx,
                &final_forward_state,
                problem.integrate_out,
            )
            .unwrap();
        let missing_metadata_adjoint =
            build_adjoint_from_state(missing_metadata_adjoint_state, missing_metadata_adjoint_eqn)
                .unwrap();
        // the backwards pass itself is where the missing metadata must be reported
        let missing_metadata_err =
            match missing_metadata_adjoint.solve_adjoint_backwards_pass(&[], &[]) {
                Ok(_) => panic!("expected missing reset metadata error"),
                Err(err) => err,
            };
        assert!(
            format!("{missing_metadata_err:?}").contains("Missing reset root metadata"),
            "expected missing reset metadata error, got {missing_metadata_err:?}",
        );

        // now build the correct adjoint and check that it produces the correct gradient
        let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
        let mut adjoint_eqn =
            problem.adjoint_equations(adjoint_checkpointers, adjoint_solver, None);
        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
        adjoint_state
            .as_mut()
            .state_mut_adjoint_terminal_root(
                &problem.eqn,
                post_reset_root_idx,
                &final_forward_state,
                problem.integrate_out,
            )
            .unwrap();
        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
        let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();

        let t0 = problem.t0;
        let ctx = problem.context().clone();

        // expected gradient: first component of the first point of each sensitivity series
        let sens_points = soln.sens_solution_points.as_ref().unwrap();
        let expected_grad = Eqn::V::from_vec(
            sens_points
                .iter()
                .map(|pts| pts[0].state.get_index(0))
                .collect(),
            ctx.clone(),
        );
        let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
        // the backwards pass must end at the problem's initial time, up to a few machine epsilons
        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
        assert!(
            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
            "expected adjoint final time {:?}, got {:?}",
            t0,
            adjoint_state.as_ref().t,
        );
        adjoint_state.as_ref().sg[0].assert_eq_norm(
            &expected_grad,
            &atol,
            Eqn::T::from_f64(1e-6).unwrap(),
            Eqn::T::from_f64(60.0).unwrap(),
        );
    }
1570
1571    #[allow(clippy::too_many_arguments)]
        /// Test helper: checks adjoint gradients of a sum-of-squares objective against
        /// `dgdp_check`, integrating backwards over the first two checkpointed segments of a
        /// forward solve.
        ///
        /// Steps, in order:
        /// 1. Sample the forward solution at `times` with `solve_dense_with_single_reset_root`
        ///    and pass the samples together with `data` to `dsum_squaresdp` to obtain `dgdu`.
        /// 2. Run `solve_with_checkpointing` up to `soln.solution_points[0].t + 1`, requiring a
        ///    `TstopReached` stop reason and at least three checkpointers.
        /// 3. Read the terminal root index and last checkpoint of `checkpointers[1]`, check the
        ///    checkpoint time against `soln.solution_points[0].t`, and initialise the adjoint
        ///    state through `state_mut_adjoint_terminal_root`.
        /// 4. Run `solve_adjoint_backwards_pass(times, dgdu)` and assert that the pass ends at
        ///    `problem.t0` and that `sg[j]` matches column `j` of `dgdp_check`.
        ///
        /// # Arguments
        /// * `build_forward` - builds a forward solver from an optional initial state; called
        ///   here with `None` (and also handed to `solve_dense_with_single_reset_root`).
        /// * `soln` - reference solution; supplies `solution_points[0]`, `rtol` and `atol`.
        /// * `build_adjoint_state` - builds the adjoint solver state from the adjoint equations.
        /// * `build_adjoint_from_state` - builds the adjoint solver from that state and the
        ///   adjoint equations.
        /// * `use_replay_solver` - when `true`, a clone of the forward solver is passed to
        ///   `adjoint_equations`; otherwise `None` is passed.
        /// * `dgdp_check` - expected gradients; column `j` is compared with `sg[j]`.
        /// * `data` - passed with the forward samples to `dsum_squaresdp`.
        /// * `times` - observation times used for the dense forward solve and the backwards pass.
        ///
        /// # Panics
        /// Panics if any assertion fails or if any of the solver/builder calls returns `Err`.
1572    pub fn test_solve_adjoint_sum_squares_with_single_reset_root<
1573        'a,
1574        Eqn,
1575        MethodF,
1576        MethodB,
1577        BuildForward,
1578        BuildAdjointState,
1579        BuildAdjointFromState,
1580    >(
1581        build_forward: BuildForward,
1582        soln: &OdeSolverSolution<Eqn::V>,
1583        build_adjoint_state: BuildAdjointState,
1584        build_adjoint_from_state: BuildAdjointFromState,
1585        use_replay_solver: bool,
1586        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
1587        data: <Eqn::V as DefaultDenseMatrix>::M,
1588        times: &[Eqn::T],
1589    ) where
1590        Eqn: OdeEquationsImplicitAdjoint + 'a,
1591        Eqn::M: DefaultSolver,
1592        Eqn::V: DefaultDenseMatrix,
1593        MethodF: OdeSolverMethod<'a, Eqn>,
1594        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1595        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1596        BuildAdjointState:
1597            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1598        BuildAdjointFromState:
1599            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1600    {
1601        let expected_out = &soln.solution_points[0];
            // Stop one time unit past the reference point's time.
1602        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
            // Dense forward samples at `times`; one column per requested time is required below.
1603        let forwards_soln =
1604            solve_dense_with_single_reset_root::<Eqn, MethodF, _>(&build_forward, times);
1605        assert_eq!(
1606            forwards_soln.ncols(),
1607            times.len(),
1608            "expected stitched forward samples to cover every requested observation time",
1609        );
1610        let dgdu = dsum_squaresdp(&forwards_soln, &data);
            // The backwards pass takes a slice of references, so collect borrows of `dgdu`.
1611        let dgdu_refs = dgdu.iter().collect::<Vec<_>>();
1612
            // Second, checkpointed forward solve from a fresh solver (no initial state override).
1613        let mut forward_solver = build_forward(None).unwrap();
1614        let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
1615            .solve_with_checkpointing(forward_stop_time, None)
1616            .unwrap();
1617        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1618        assert!(
1619            checkpointers.len() >= 3,
1620            "expected checkpointing path to include the two reset events"
1621        );
1622        let problem = forward_solver.problem();
            // Clone taken after the checkpointed solve; only used when `use_replay_solver` is set.
1623        let post_reset_solver = forward_solver.clone();
            // The adjoint is started from the end of the second checkpointer (index 1).
1624        let post_reset_root_idx = checkpointers[1]
1625            .terminal_reset_root_idx()
1626            .expect("second reset segment should record its terminal root index");
1627        let final_forward_state = checkpointers[1].last_checkpoint().clone();
1628        let t_second_root = final_forward_state.as_ref().t;
1629
            // Mixed tolerance on the root time: rtol * |t| + atol[0], loosened by a factor of 30.
1630        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1631        assert!(
1632            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1633            "expected second root time ≈ {:?}, got {:?}",
1634            expected_out.t,
1635            t_second_root,
1636        );
1637
1638        let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
            // Only the first two checkpointers are handed to the adjoint equations; the rest
            // are dropped. NOTE(review): the third argument is presumably the number of
            // `dgdu` vectors the backwards pass will receive — confirm against
            // `adjoint_equations`.
1639        let mut adjoint_eqn = problem.adjoint_equations(
1640            checkpointers.into_iter().take(2).collect(),
1641            adjoint_solver,
1642            Some(dgdu.len()),
1643        );
1644        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
            // Initialise the adjoint state from the terminal root of the second segment.
1645        adjoint_state
1646            .as_mut()
1647            .state_mut_adjoint_terminal_root(
1648                &problem.eqn,
1649                post_reset_root_idx,
1650                &final_forward_state,
1651                problem.integrate_out,
1652            )
1653            .unwrap();
1654        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
1655        let (adjoint_state, _) = adjoint
1656            .solve_adjoint_backwards_pass(times, dgdu_refs.as_slice())
1657            .unwrap();
1658
1659        let t0 = problem.t0;
1660        let ctx = problem.context().clone();
1661
1662        let nparams = dgdp_check.nrows();
1663        let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
            // The backwards pass must finish at the problem's start time, within 10 * EPSILON.
1664        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1665        assert!(
1666            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1667            "expected adjoint final time {:?}, got {:?}",
1668            t0,
1669            adjoint_state.as_ref().t,
1670        );
            // Compare each `sg[j]` with column `j` of the expected gradients.
            // NOTE(review): the two trailing scalars look like a relative tolerance and an
            // allowed error-norm factor — confirm against `assert_eq_norm`.
1671        #[allow(clippy::needless_range_loop)]
1672        for j in 0..dgdp_check.ncols() {
1673            adjoint_state.as_ref().sg[j].assert_eq_norm(
1674                &dgdp_check.column(j).into_owned(),
1675                &atol,
1676                Eqn::T::from_f64(1e-6).unwrap(),
1677                Eqn::T::from_f64(260.0).unwrap(),
1678            );
1679        }
1680    }
1681
1682    pub fn test_solve_soln_adjoint_with_single_reset_root<
1683        'a,
1684        Eqn,
1685        MethodF,
1686        MethodB,
1687        BuildForward,
1688        BuildAdjointState,
1689        BuildAdjointFromState,
1690    >(
1691        build_forward: BuildForward,
1692        soln: &OdeSolverSolution<Eqn::V>,
1693        build_adjoint_state: BuildAdjointState,
1694        build_adjoint_from_state: BuildAdjointFromState,
1695        use_replay_solver: bool,
1696    ) where
1697        Eqn: OdeEquationsImplicitAdjoint + 'a,
1698        Eqn::M: DefaultSolver,
1699        Eqn::V: DefaultDenseMatrix,
1700        MethodF: OdeSolverMethod<'a, Eqn>,
1701        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1702        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1703        BuildAdjointState:
1704            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1705        BuildAdjointFromState:
1706            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1707    {
            // Test helper: drives two staged `solve_soln_with_checkpointing` calls that share
            // one `Solution` and one checkpointer list, then runs the adjoint backwards pass
            // over both checkpointers and compares `sg[0]` with the gradient stored in
            // `soln.sens_solution_points`.
            //
            // Arguments:
            // * `build_forward` - builds a forward solver from an optional initial state.
            // * `soln` - reference solution (`solution_points[0]`, `sens_solution_points`,
            //   `rtol`, `atol`).
            // * `build_adjoint_state` / `build_adjoint_from_state` - build the adjoint state
            //   and then the adjoint solver from the adjoint equations.
            // * `use_replay_solver` - when `true`, a clone of the terminal forward solver is
            //   passed to `adjoint_equations`; otherwise `None`.
            //
            // Panics if any assertion fails or any solver/builder call returns `Err`.
1708        let expected_out = &soln.solution_points[0];
            // The solution's final time is one time unit past the reference point's time.
1709        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1710        let mut forward_soln = Solution::<Eqn::V>::new(forward_stop_time);
1711        let mut checkpointers = Vec::new();
1712
            // Stage 1: solve from the default initial state; must stop on a root.
1713        let first_forward_solver = build_forward(None)
1714            .unwrap()
1715            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
1716            .unwrap();
1717        let first_root_idx = match forward_soln.stop_reason {
1718            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1719            Some(reason) => {
1720                panic!("expected first staged solve to stop at reset root, got {reason:?}")
1721            }
1722            None => panic!("first staged solve did not set a stop reason"),
1723        };
            // Stage 1 must have appended exactly one checkpointer that records the same root.
1724        assert_eq!(checkpointers.len(), 1);
1725        assert_eq!(
1726            checkpointers[0].terminal_reset_root_idx(),
1727            Some(first_root_idx)
1728        );
1729
            // Stage 2: restart from the state produced by `state_after_manual_reset` and
            // continue into the same `forward_soln` / `checkpointers`; must stop on a root again.
1730        let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
1731        let terminal_forward_solver = build_forward(Some(state_after_reset))
1732            .unwrap()
1733            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
1734            .unwrap();
1735        let terminal_root_idx = match forward_soln.stop_reason {
1736            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1737            Some(reason) => {
1738                panic!("expected second staged solve to stop at terminal root, got {reason:?}")
1739            }
1740            None => panic!("second staged solve did not set a stop reason"),
1741        };
1742        assert_eq!(checkpointers.len(), 2);
1743        assert_eq!(
1744            checkpointers[1].terminal_reset_root_idx(),
1745            Some(terminal_root_idx)
1746        );
1747
1748        let problem = terminal_forward_solver.problem();
1749        let final_forward_state = terminal_forward_solver.state_clone();
1750        let t_second_root = final_forward_state.as_ref().t;
            // Compare the state's `g` at the terminal root with the reference output using the
            // square root of `squared_norm` evaluated with the reference tolerances.
1751        let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
1752        let out_norm = out_error
1753            .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
1754            .sqrt();
1755        assert!(
1756            out_norm < Eqn::T::from_f64(50.0).unwrap(),
1757            "forward integrated output mismatch at terminal root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
1758            final_forward_state.as_ref().g,
1759            expected_out.state,
1760        );
            // Mixed tolerance on the root time: rtol * |t| + atol[0], loosened by a factor of 30.
1761        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1762        assert!(
1763            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1764            "expected terminal root time ≈ {:?}, got {:?}",
1765            expected_out.t,
1766            t_second_root,
1767        );
1768
1769        let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
            // Both checkpointers are moved into the adjoint equations here.
1770        let mut adjoint_eqn = problem.adjoint_equations(checkpointers, adjoint_solver, None);
1771        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
            // Initialise the adjoint state from the terminal root reached in stage 2.
1772        adjoint_state
1773            .as_mut()
1774            .state_mut_adjoint_terminal_root(
1775                &problem.eqn,
1776                terminal_root_idx,
1777                &final_forward_state,
1778                problem.integrate_out,
1779            )
1780            .unwrap();
1781        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
            // No observation times or `dgdu` vectors are supplied to the backwards pass.
1782        let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();
1783
1784        let t0 = problem.t0;
1785        let ctx = problem.context().clone();
            // Expected gradient: element 0 of the first point's state in every entry of
            // `sens_solution_points`.
1786        let sens_points = soln.sens_solution_points.as_ref().unwrap();
1787        let expected_grad = Eqn::V::from_vec(
1788            sens_points
1789                .iter()
1790                .map(|pts| pts[0].state.get_index(0))
1791                .collect(),
1792            ctx.clone(),
1793        );
1794        let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
            // The backwards pass must finish at the problem's start time, within 10 * EPSILON.
1795        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1796        assert!(
1797            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1798            "expected adjoint final time {:?}, got {:?}",
1799            t0,
1800            adjoint_state.as_ref().t,
1801        );
            // NOTE(review): the two trailing scalars look like a relative tolerance and an
            // allowed error-norm factor — confirm against `assert_eq_norm`.
1802        adjoint_state.as_ref().sg[0].assert_eq_norm(
1803            &expected_grad,
1804            &atol,
1805            Eqn::T::from_f64(1e-6).unwrap(),
1806            Eqn::T::from_f64(60.0).unwrap(),
1807        );
1808    }
1809
1810    #[allow(clippy::too_many_arguments)]
        /// Test helper: checks adjoint gradients of a sum-of-squares objective against
        /// `dgdp_check`, using staged `solve_soln` calls for the forward problem.
        ///
        /// Steps, in order:
        /// 1. Solve with checkpointing into a dense `Solution` over `times` until the first
        ///    root is found (one checkpointer expected).
        /// 2. Restart from `state_after_manual_reset` and finish the dense solution with
        ///    `solve_soln` (no checkpointing); it must be complete with `TstopReached`.
        /// 3. Restart again from the same post-reset state with checkpointing, into a separate
        ///    `Solution` ending at `soln.solution_points[0].t + 1`, until a root is found
        ///    (two checkpointers expected in total).
        /// 4. Build `dgdu` from the dense samples and `data` via `dsum_squaresdp`, run
        ///    `solve_adjoint_backwards_pass(times, dgdu)`, and assert that the pass ends at
        ///    `problem.t0` and that `sg[j]` matches column `j` of `dgdp_check`.
        ///
        /// # Arguments
        /// * `build_forward` - builds a forward solver from an optional initial state.
        /// * `soln` - reference solution; supplies `solution_points[0]`, `rtol` and `atol`.
        /// * `build_adjoint_state` - builds the adjoint solver state from the adjoint equations.
        /// * `build_adjoint_from_state` - builds the adjoint solver from that state and the
        ///   adjoint equations.
        /// * `use_replay_solver` - when `true`, a clone of the terminal forward solver is passed
        ///   to `adjoint_equations`; otherwise `None` is passed.
        /// * `dgdp_check` - expected gradients; column `j` is compared with `sg[j]`.
        /// * `data` - passed with the dense forward samples to `dsum_squaresdp`.
        /// * `times` - observation times for the dense solution and the backwards pass.
        ///
        /// # Panics
        /// Panics if any assertion fails or if any of the solver/builder calls returns `Err`.
1811    pub fn test_solve_soln_adjoint_sum_squares_with_single_reset_root<
1812        'a,
1813        Eqn,
1814        MethodF,
1815        MethodB,
1816        BuildForward,
1817        BuildAdjointState,
1818        BuildAdjointFromState,
1819    >(
1820        build_forward: BuildForward,
1821        soln: &OdeSolverSolution<Eqn::V>,
1822        build_adjoint_state: BuildAdjointState,
1823        build_adjoint_from_state: BuildAdjointFromState,
1824        use_replay_solver: bool,
1825        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
1826        data: <Eqn::V as DefaultDenseMatrix>::M,
1827        times: &[Eqn::T],
1828    ) where
1829        Eqn: OdeEquationsImplicitAdjoint + 'a,
1830        Eqn::M: DefaultSolver,
1831        Eqn::V: DefaultDenseMatrix,
1832        MethodF: OdeSolverMethod<'a, Eqn>,
1833        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1834        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1835        BuildAdjointState:
1836            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1837        BuildAdjointFromState:
1838            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1839    {
1840        let expected_out = &soln.solution_points[0];
            // Final time for the terminal (checkpointed) solve: one unit past the reference time.
1841        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
            // Dense solution sampled at the observation times; filled across stages 1 and 2.
1842        let mut forward_soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
1843        let mut checkpointers = Vec::new();
1844
            // Stage 1: checkpointed solve from the default initial state; must stop on a root.
1845        let first_forward_solver = build_forward(None)
1846            .unwrap()
1847            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
1848            .unwrap();
1849        let first_root_idx = match forward_soln.stop_reason {
1850            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1851            Some(reason) => {
1852                panic!("expected first staged solve to stop at reset root, got {reason:?}")
1853            }
1854            None => panic!("first staged solve did not set a stop reason"),
1855        };
1856        assert_eq!(checkpointers.len(), 1);
1857        assert_eq!(
1858            checkpointers[0].terminal_reset_root_idx(),
1859            Some(first_root_idx)
1860        );
1861
            // Stage 2: finish the dense solution from the post-reset state without recording
            // checkpoints. The state is cloned because stage 3 restarts from it as well.
1862        let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
1863        build_forward(Some(state_after_reset.clone()))
1864            .unwrap()
1865            .solve_soln(&mut forward_soln)
1866            .unwrap();
1867        assert!(forward_soln.is_complete());
1868        assert_eq!(
1869            forward_soln.stop_reason,
1870            Some(OdeSolverStopReason::TstopReached)
1871        );
1872
            // Stage 3: checkpointed solve from the same post-reset state into a separate
            // solution; must stop on a root and append the second checkpointer.
1873        let mut terminal_soln = Solution::<Eqn::V>::new(forward_stop_time);
1874        let terminal_forward_solver = build_forward(Some(state_after_reset))
1875            .unwrap()
1876            .solve_soln_with_checkpointing(&mut terminal_soln, &mut checkpointers, None)
1877            .unwrap();
1878        let terminal_root_idx = match terminal_soln.stop_reason {
1879            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
1880            Some(reason) => {
1881                panic!("expected terminal staged solve to stop at root, got {reason:?}")
1882            }
1883            None => panic!("terminal staged solve did not set a stop reason"),
1884        };
1885        assert_eq!(checkpointers.len(), 2);
1886        assert_eq!(
1887            checkpointers.last().unwrap().terminal_reset_root_idx(),
1888            Some(terminal_root_idx)
1889        );
1890
            // `dgdu` is computed from the dense samples (`forward_soln.ys`), not from the
            // terminal solution.
1891        let dgdu_eval = dsum_squaresdp(&forward_soln.ys, &data);
            // The backwards pass takes a slice of references, so collect borrows of `dgdu_eval`.
1892        let dgdu_eval_refs = dgdu_eval.iter().collect::<Vec<_>>();
1893        let problem = terminal_forward_solver.problem();
1894        let final_forward_state = terminal_forward_solver.state_clone();
1895        let t_second_root = final_forward_state.as_ref().t;
            // Mixed tolerance on the root time: rtol * |t| + atol[0], loosened by a factor of 30.
1896        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1897        assert!(
1898            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1899            "expected terminal root time ≈ {:?}, got {:?}",
1900            expected_out.t,
1901            t_second_root,
1902        );
1903
1904        let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
            // Both checkpointers are moved into the adjoint equations.
            // NOTE(review): the third argument is presumably the number of `dgdu` vectors the
            // backwards pass will receive — confirm against `adjoint_equations`.
1905        let mut adjoint_eqn =
1906            problem.adjoint_equations(checkpointers, adjoint_solver, Some(dgdu_eval_refs.len()));
1907        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
            // Initialise the adjoint state from the terminal root reached in stage 3.
1908        adjoint_state
1909            .as_mut()
1910            .state_mut_adjoint_terminal_root(
1911                &problem.eqn,
1912                terminal_root_idx,
1913                &final_forward_state,
1914                problem.integrate_out,
1915            )
1916            .unwrap();
1917        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
1918        let (adjoint_state, _) = adjoint
1919            .solve_adjoint_backwards_pass(times, dgdu_eval_refs.as_slice())
1920            .unwrap();
1921
1922        let t0 = problem.t0;
1923        let ctx = problem.context().clone();
1924        let nparams = dgdp_check.nrows();
1925        let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
            // The backwards pass must finish at the problem's start time, within 10 * EPSILON.
1926        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1927        assert!(
1928            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1929            "expected adjoint final time {:?}, got {:?}",
1930            t0,
1931            adjoint_state.as_ref().t,
1932        );
            // Compare each `sg[j]` with column `j` of the expected gradients.
            // NOTE(review): the two trailing scalars look like a relative tolerance and an
            // allowed error-norm factor — confirm against `assert_eq_norm`.
1933        #[allow(clippy::needless_range_loop)]
1934        for j in 0..dgdp_check.ncols() {
1935            adjoint_state.as_ref().sg[j].assert_eq_norm(
1936                &dgdp_check.column(j).into_owned(),
1937                &atol,
1938                Eqn::T::from_f64(1e-6).unwrap(),
1939                Eqn::T::from_f64(260.0).unwrap(),
1940            );
1941        }
1942    }
1943}