Skip to main content

diffsol/ode_solver/
mod.rs

1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod problem;
11pub mod runge_kutta;
12pub mod sde;
13pub mod sdirk;
14pub mod sdirk_state;
15pub mod sensitivities;
16pub mod solution;
17pub mod state;
18pub mod tableau;
19
20#[cfg(test)]
21mod tests {
22    use std::rc::Rc;
23
24    use self::problem::OdeSolverSolution;
25
26    use super::*;
27    use crate::error::{DiffsolError, OdeSolverError};
28    use crate::matrix::Matrix;
29    use crate::ode_solver::sensitivities::SensitivitiesOdeSolverMethod;
30    use crate::ode_solver::solution::Solution;
31    use crate::op::unit::UnitCallable;
32    use crate::op::ParameterisedOp;
33    use crate::Scalar;
34    use crate::{
35        ode_equations::{OdeEquationsImplicitAdjointWithReset, OdeEquationsImplicitSensWithReset},
36        op::OpStatistics,
37        AdjointEquations, AdjointOdeSolverMethod, Context, DenseMatrix, MatrixCommon, MatrixRef,
38        NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit,
39        OdeEquationsImplicitAdjoint, OdeEquationsRef, OdeSolverConfig, OdeSolverMethod,
40        OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Scale, VectorRef, VectorView,
41        VectorViewMut,
42    };
43    use crate::{
44        ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver,
45        NonLinearOpSens, Op, Vector,
46    };
47    use num_traits::{FromPrimitive, One, Signed, Zero};
48
49    pub fn test_ode_solver<'a, M, Eqn, Method>(
50        method: &mut Method,
51        solution: OdeSolverSolution<M::V>,
52        override_tol: Option<M::T>,
53        use_tstop: bool,
54        solve_for_sensitivities: bool,
55    ) -> Eqn::V
56    where
57        M: Matrix,
58        Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
59        Method: OdeSolverMethod<'a, Eqn>,
60    {
61        let have_root = method.problem().eqn.root().is_some();
62        for (i, point) in solution.solution_points.iter().enumerate() {
63            let (soln, sens_soln) = if use_tstop {
64                match method.set_stop_time(point.t) {
65                    Ok(_) => loop {
66                        match method.step() {
67                            Ok(OdeSolverStopReason::RootFound(_, _)) => {
68                                assert!(have_root);
69                                return method.state().y.clone();
70                            }
71                            Ok(OdeSolverStopReason::TstopReached) => {
72                                break (method.state().y.clone(), method.state().s.to_vec());
73                            }
74                            _ => (),
75                        }
76                    },
77                    Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
78                }
79            } else {
80                while method.state().t.abs() < point.t.abs() {
81                    if let OdeSolverStopReason::RootFound(t, _) = method.step().unwrap() {
82                        assert!(have_root);
83                        return method.interpolate(t).unwrap();
84                    }
85                }
86                let soln = method.interpolate(point.t).unwrap();
87                let sens_soln = method.interpolate_sens(point.t).unwrap();
88                (soln, sens_soln)
89            };
90            let soln = if let Some(out) = method.problem().eqn.out() {
91                out.call(&soln, point.t)
92            } else {
93                soln
94            };
95            assert_eq!(
96                soln.len(),
97                point.state.len(),
98                "soln.len() != point.state.len()"
99            );
100            if let Some(override_tol) = override_tol {
101                soln.assert_eq_st(&point.state, override_tol);
102            } else {
103                let (rtol, atol) = if method.problem().eqn.out().is_some() {
104                    // problem rtol and atol is on the state, so just use solution tolerance here
105                    (solution.rtol, &solution.atol)
106                } else {
107                    (method.problem().rtol, &method.problem().atol)
108                };
109                let error = soln.clone() - &point.state;
110                let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
111                assert!(
112                    error_norm < M::T::from_f64(20.0).unwrap(),
113                    "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
114                    error_norm,
115                    point.t,
116                    soln,
117                    point.state
118                );
119                if solve_for_sensitivities {
120                    if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
121                        for (j, sens_points) in sens_soln_points.iter().enumerate() {
122                            let sens_point = &sens_points[i];
123                            let sens_soln = &sens_soln[j];
124                            let error = sens_soln.clone() - &sens_point.state;
125                            let error_norm =
126                                error.squared_norm(&sens_point.state, atol, rtol).sqrt();
127                            assert!(
128                                error_norm < M::T::from_f64(29.0).unwrap(),
129                                "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
130                                point.t,
131                                sens_point.state
132                            );
133                        }
134                    }
135                }
136            }
137        }
138        method.state().y.clone()
139    }
140
141    pub fn setup_test_adjoint<'a, LS, Eqn>(
142        problem: &'a mut OdeSolverProblem<Eqn>,
143        soln: OdeSolverSolution<Eqn::V>,
144    ) -> <Eqn::V as DefaultDenseMatrix>::M
145    where
146        Eqn: OdeEquationsImplicitAdjoint + 'a,
147        LS: LinearSolver<Eqn::M>,
148        Eqn::V: DefaultDenseMatrix,
149        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
150        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
151    {
152        let nparams = problem.eqn.nparams();
153        let nout = problem.eqn.nout();
154        let ctx = problem.eqn.context();
155        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
156        let final_time = soln.solution_points.last().unwrap().t;
157        let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
158        problem.eqn.get_params(&mut p_0);
159        let h_base = Eqn::T::from_f64(1e-10).unwrap();
160        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
161        h.axpy(h_base, &p_0, Eqn::T::one());
162        let p_base = p_0.clone();
163        for i in 0..nparams {
164            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
165            problem.eqn.set_params(&p_0);
166            let g_pos = {
167                let mut s = problem.bdf::<LS>().unwrap();
168                s.set_stop_time(final_time).unwrap();
169                while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
170                s.state().g.clone()
171            };
172
173            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
174            problem.eqn.set_params(&p_0);
175            let g_neg = {
176                let mut s = problem.bdf::<LS>().unwrap();
177                s.set_stop_time(final_time).unwrap();
178                while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
179                s.state().g.clone()
180            };
181            p_0.set_index(i, p_base.get_index(i));
182
183            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
184            for j in 0..nout {
185                dgdp.set_index(i, j, delta.get_index(j));
186            }
187        }
188        problem.eqn.set_params(&p_base);
189        dgdp
190    }
191
192    /// sum_i^n (soln_i - data_i)^2
193    /// sum_i^n (soln_i - data_i)^4
194    pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
195    where
196        DM: DenseMatrix,
197    {
198        let mut ret = DM::V::zeros(2, soln.context().clone());
199        for j in 0..soln.ncols() {
200            let soln_j = soln.column(j);
201            let data_j = data.column(j);
202            let delta = soln_j - data_j;
203            let norm2 = delta.norm(2);
204            ret.set_index(0, ret.get_index(0) + norm2 * norm2);
205            let norm4 = delta.norm(4);
206            let norm4_sq = norm4 * norm4;
207            ret.set_index(1, ret.get_index(1) + norm4_sq * norm4_sq);
208        }
209        ret
210    }
211
212    /// sum_i^n 2 * (soln_i - data_i)
213    /// sum_i^n 4 * (soln_i - data_i)^3
214    pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
215    where
216        DM: DenseMatrix,
217    {
218        let delta = soln.clone() - data;
219        let mut delta3 = delta.clone();
220        for j in 0..delta3.ncols() {
221            let delta_col = delta.column(j).into_owned();
222
223            let mut delta3_col = delta_col.clone();
224            delta3_col.component_mul_assign(&delta_col);
225            delta3_col.component_mul_assign(&delta_col);
226
227            delta3.column_mut(j).copy_from(&delta3_col);
228        }
229        let ret = vec![
230            delta * Scale(DM::T::from_f64(2.).unwrap()),
231            delta3 * Scale(DM::T::from_f64(4.).unwrap()),
232        ];
233        ret
234    }
235
236    pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
237        problem: &'a mut OdeSolverProblem<Eqn>,
238        times: &[Eqn::T],
239    ) -> (
240        <Eqn::V as DefaultDenseMatrix>::M,
241        <Eqn::V as DefaultDenseMatrix>::M,
242    )
243    where
244        Eqn: OdeEquationsImplicitAdjoint + 'a,
245        LS: LinearSolver<Eqn::M>,
246        Eqn::V: DefaultDenseMatrix,
247        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
248        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
249    {
250        let nparams = problem.eqn.nparams();
251        let nout = 2;
252        let ctx = problem.eqn.context();
253        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
254
255        let mut p_0 = ctx.vector_zeros(nparams);
256        problem.eqn.get_params(&mut p_0);
257        let h_base = Eqn::T::from_f64(1e-10).unwrap();
258        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
259        h.axpy(h_base, &p_0, Eqn::T::one());
260        let mut p_data = p_0.clone();
261        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
262        let p_base = p_0.clone();
263
264        problem.eqn.set_params(&p_data);
265        let data = {
266            let mut s = problem.bdf::<LS>().unwrap();
267            s.solve_dense(times).unwrap().0
268        };
269
270        for i in 0..nparams {
271            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
272            problem.eqn.set_params(&p_0);
273            let g_pos = {
274                let mut s = problem.bdf::<LS>().unwrap();
275                let v = s.solve_dense(times).unwrap().0;
276                sum_squares(&v, &data)
277            };
278
279            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
280            problem.eqn.set_params(&p_0);
281            let g_neg = {
282                let mut s = problem.bdf::<LS>().unwrap();
283                let v = s.solve_dense(times).unwrap().0;
284                sum_squares(&v, &data)
285            };
286
287            p_0.set_index(i, p_base.get_index(i));
288
289            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
290            for j in 0..nout {
291                dgdp.set_index(i, j, delta.get_index(j));
292            }
293        }
294        problem.eqn.set_params(&p_base);
295        (dgdp, data)
296    }
297
298    pub fn single_reset_root_discrete_times<T: Scalar>(t_stop: T) -> Vec<T> {
299        let t_root = t_stop / T::from_f64(2.0).unwrap();
300        [0.25, 0.75, 1.25, 1.75]
301            .into_iter()
302            .map(|factor| t_root * T::from_f64(factor).unwrap())
303            .collect()
304    }
305
306    fn solve_dense_with_single_reset_root<'a, Eqn, Method, BuildForward>(
307        build_forward: BuildForward,
308        times: &[Eqn::T],
309    ) -> <Eqn::V as DefaultDenseMatrix>::M
310    where
311        Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
312        Eqn::V: DefaultDenseMatrix,
313        Method: OdeSolverMethod<'a, Eqn>,
314        BuildForward: Fn(Option<Method::State>) -> Result<Method, DiffsolError>,
315    {
316        let mut soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
317        let first_forward_solver = build_forward(None).unwrap().solve_soln(&mut soln).unwrap();
318        match soln.stop_reason {
319            Some(OdeSolverStopReason::RootFound(_, 0)) => {}
320            Some(OdeSolverStopReason::RootFound(_, idx)) => {
321                panic!("expected first solve_soln() segment to stop on root 0, got root {idx}")
322            }
323            Some(OdeSolverStopReason::TstopReached) => {
324                panic!("expected first solve_soln() segment to stop on the interior root")
325            }
326            Some(OdeSolverStopReason::InternalTimestep) | None => {
327                panic!("first solve_soln() segment did not finish with a terminal stop reason")
328            }
329        }
330
331        let mut state_after_reset = first_forward_solver.state_clone();
332        {
333            let problem = first_forward_solver.problem();
334            let reset_fn = problem.eqn.reset().unwrap();
335            state_after_reset
336                .state_mut_op(&problem.eqn, &reset_fn)
337                .unwrap();
338        }
339
340        build_forward(Some(state_after_reset))
341            .unwrap()
342            .solve_soln(&mut soln)
343            .unwrap();
344        assert!(
345            soln.is_complete(),
346            "expected stitched solve_soln() output to cover all requested observation times",
347        );
348        soln.ys
349    }
350
351    pub fn setup_test_adjoint_sum_squares_with_single_reset_root<'a, LS, Eqn>(
352        problem: &'a mut OdeSolverProblem<Eqn>,
353        times: &[Eqn::T],
354    ) -> (
355        <Eqn::V as DefaultDenseMatrix>::M,
356        <Eqn::V as DefaultDenseMatrix>::M,
357    )
358    where
359        Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
360        LS: LinearSolver<Eqn::M>,
361        Eqn::V: DefaultDenseMatrix,
362        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
363        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
364    {
365        let nparams = problem.eqn.nparams();
366        let nout = 2;
367        let ctx = problem.eqn.context();
368        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
369
370        let mut p_0 = ctx.vector_zeros(nparams);
371        problem.eqn.get_params(&mut p_0);
372        let h_base = Eqn::T::from_f64(1e-10).unwrap();
373        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
374        h.axpy(h_base, &p_0, Eqn::T::one());
375        let mut p_data = p_0.clone();
376        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
377        let p_base = p_0.clone();
378
379        problem.eqn.set_params(&p_data);
380        let data = solve_dense_with_single_reset_root::<Eqn, _, _>(
381            |state| match state {
382                Some(state) => problem.bdf_solver(state),
383                None => problem.bdf::<LS>(),
384            },
385            times,
386        );
387
388        for i in 0..nparams {
389            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
390            problem.eqn.set_params(&p_0);
391            let g_pos = {
392                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
393                    |state| match state {
394                        Some(state) => problem.bdf_solver(state),
395                        None => problem.bdf::<LS>(),
396                    },
397                    times,
398                );
399                sum_squares(&v, &data)
400            };
401
402            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
403            problem.eqn.set_params(&p_0);
404            let g_neg = {
405                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
406                    |state| match state {
407                        Some(state) => problem.bdf_solver(state),
408                        None => problem.bdf::<LS>(),
409                    },
410                    times,
411                );
412                sum_squares(&v, &data)
413            };
414
415            p_0.set_index(i, p_base.get_index(i));
416
417            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
418            for j in 0..nout {
419                dgdp.set_index(i, j, delta.get_index(j));
420            }
421        }
422        problem.eqn.set_params(&p_base);
423        (dgdp, data)
424    }
425
426    pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
427        backwards_solver: SolverB,
428        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
429        forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
430        data: <Eqn::V as DefaultDenseMatrix>::M,
431        times: &[Eqn::T],
432    ) where
433        SolverF: OdeSolverMethod<'a, Eqn>,
434        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
435        Eqn: OdeEquationsImplicitAdjoint + 'a,
436        Eqn::V: DefaultDenseMatrix,
437        Eqn::M: DefaultSolver,
438    {
439        let nparams = dgdp_check.nrows();
440        let dgdu = dsum_squaresdp(&forwards_soln, &data);
441
442        let atol = Eqn::V::from_element(
443            nparams,
444            Eqn::T::from_f64(1e-6).unwrap(),
445            data.context().clone(),
446        );
447        let rtol = Eqn::T::from_f64(1e-6).unwrap();
448        let state = backwards_solver
449            .solve_adjoint_backwards_pass(None, times, dgdu.iter().collect::<Vec<_>>().as_slice())
450            .unwrap();
451        let gs_adj = state.into_common().sg;
452        #[allow(clippy::needless_range_loop)]
453        for j in 0..dgdp_check.ncols() {
454            gs_adj[j].assert_eq_norm(
455                &dgdp_check.column(j).into_owned(),
456                &atol,
457                rtol,
458                Eqn::T::from_f64(260.).unwrap(),
459            );
460        }
461    }
462
463    pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
464        backwards_solver: SolverB,
465        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
466    ) where
467        SolverF: OdeSolverMethod<'a, Eqn>,
468        SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
469        Eqn: OdeEquationsImplicitAdjoint + 'a,
470        Eqn::V: DefaultDenseMatrix,
471        Eqn::M: DefaultSolver,
472    {
473        let nout = backwards_solver.problem().eqn.nout();
474        let atol = Eqn::V::from_element(
475            nout,
476            Eqn::T::from_f64(1e-6).unwrap(),
477            dgdp_check.context().clone(),
478        );
479        let rtol = Eqn::T::from_f64(1e-6).unwrap();
480        let state = backwards_solver
481            .solve_adjoint_backwards_pass(None, &[], &[])
482            .unwrap();
483        let gs_adj = state.into_common().sg;
484        #[allow(clippy::needless_range_loop)]
485        for j in 0..dgdp_check.ncols() {
486            gs_adj[j].assert_eq_norm(
487                &dgdp_check.column(j).into_owned(),
488                &atol,
489                rtol,
490                Eqn::T::from_f64(40.).unwrap(),
491            );
492        }
493    }
494
495    pub struct TestEqnInit<M: Matrix> {
496        ctx: M::C,
497    }
498
499    impl<M: Matrix> Op for TestEqnInit<M> {
500        type T = M::T;
501        type V = M::V;
502        type M = M;
503        type C = M::C;
504
505        fn nout(&self) -> usize {
506            1
507        }
508        fn nparams(&self) -> usize {
509            1
510        }
511        fn nstates(&self) -> usize {
512            1
513        }
514        fn context(&self) -> &Self::C {
515            &self.ctx
516        }
517    }
518
519    impl<M: Matrix> ConstantOp for TestEqnInit<M> {
520        fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
521            y.fill(M::T::one());
522        }
523    }
524
525    impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
526        fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
527            sens.fill(M::T::zero());
528        }
529    }
530
531    pub struct TestEqnRhs<M: Matrix> {
532        ctx: M::C,
533    }
534
535    impl<M: Matrix> Op for TestEqnRhs<M> {
536        type T = M::T;
537        type V = M::V;
538        type M = M;
539        type C = M::C;
540
541        fn nout(&self) -> usize {
542            1
543        }
544        fn nparams(&self) -> usize {
545            1
546        }
547        fn nstates(&self) -> usize {
548            1
549        }
550        fn context(&self) -> &Self::C {
551            &self.ctx
552        }
553    }
554
555    impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
556        fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
557            y.fill(M::T::zero());
558        }
559    }
560
561    impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
562        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
563            y.fill(M::T::zero());
564        }
565    }
566
567    impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
568        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
569            sens.fill(M::T::zero());
570        }
571    }
572
573    pub struct TestEqnOut<M: Matrix> {
574        ctx: M::C,
575    }
576
577    impl<M: Matrix> Op for TestEqnOut<M> {
578        type T = M::T;
579        type V = M::V;
580        type M = M;
581        type C = M::C;
582
583        fn nout(&self) -> usize {
584            1
585        }
586        fn nparams(&self) -> usize {
587            1
588        }
589        fn nstates(&self) -> usize {
590            1
591        }
592        fn context(&self) -> &Self::C {
593            &self.ctx
594        }
595    }
596
597    impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
598        fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
599            y.copy_from(x);
600        }
601    }
602
603    impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
604        fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
605            y.copy_from(v);
606        }
607    }
608
609    impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
610        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
611            sens.fill(M::T::zero());
612        }
613    }
614
615    pub struct TestEqn<M: Matrix> {
616        rhs: Rc<TestEqnRhs<M>>,
617        init: Rc<TestEqnInit<M>>,
618        out: Rc<TestEqnOut<M>>,
619        ctx: M::C,
620    }
621
622    impl<M: Matrix> TestEqn<M> {
623        pub fn new() -> Self {
624            let ctx = M::C::default();
625            Self {
626                rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
627                init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
628                out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
629                ctx,
630            }
631        }
632    }
633
634    impl<M: Matrix> Op for TestEqn<M> {
635        type T = M::T;
636        type V = M::V;
637        type M = M;
638        type C = M::C;
639        fn nout(&self) -> usize {
640            1
641        }
642        fn nparams(&self) -> usize {
643            1
644        }
645        fn nstates(&self) -> usize {
646            1
647        }
648        fn statistics(&self) -> crate::op::OpStatistics {
649            OpStatistics::default()
650        }
651        fn context(&self) -> &Self::C {
652            &self.ctx
653        }
654    }
655
656    impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
657        type Rhs = &'a TestEqnRhs<M>;
658        type Mass = ParameterisedOp<'a, UnitCallable<M>>;
659        type Root = ParameterisedOp<'a, UnitCallable<M>>;
660        type Init = &'a TestEqnInit<M>;
661        type Out = &'a TestEqnOut<M>;
662        type Reset = ParameterisedOp<'a, UnitCallable<M>>;
663    }
664
665    impl<M: Matrix> OdeEquations for TestEqn<M> {
666        fn rhs(&self) -> &TestEqnRhs<M> {
667            &self.rhs
668        }
669
670        fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
671            None
672        }
673
674        fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
675            None
676        }
677
678        fn init(&self) -> &TestEqnInit<M> {
679            &self.init
680        }
681
682        fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
683            Some(&self.out)
684        }
685        fn set_params(&mut self, _p: &Self::V) {
686            unimplemented!()
687        }
688        fn get_params(&self, _p: &mut Self::V) {
689            unimplemented!()
690        }
691    }
692
693    pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
694        let eqn = TestEqn::<M>::new();
695        let atol = eqn
696            .context()
697            .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
698        OdeSolverProblem::new(
699            eqn,
700            M::T::from_f64(1e-6).unwrap(),
701            atol,
702            None,
703            None,
704            None,
705            None,
706            None,
707            None,
708            M::T::zero(),
709            M::T::one(),
710            integrate_out,
711            Default::default(),
712            Default::default(),
713        )
714        .unwrap()
715    }
716
717    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
718        let state = s.checkpoint();
719        let integrating_sens = !s.state().s.is_empty();
720        let integrating_out = s.problem().integrate_out;
721        let t0 = state.as_ref().t;
722        let t1 = t0 + M::T::from_f64(1e6).unwrap();
723        s.interpolate(t0)
724            .unwrap()
725            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
726        assert!(s.interpolate(t1).is_err());
727        assert!(s.interpolate_out(t1).is_err());
728        if integrating_sens {
729            assert!(s.interpolate_sens(t1).is_err());
730        } else {
731            assert!(s.interpolate_sens(t0).is_ok());
732        }
733        s.step().unwrap();
734        let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
735        assert!(s.interpolate(s.state().t).is_ok());
736        assert!(s.interpolate(tmid).is_ok());
737        if integrating_out {
738            assert!(s.interpolate_out(s.state().t).is_ok());
739        } else {
740            assert!(s.interpolate_out(s.state().t).is_err());
741        }
742        assert!(s.interpolate_sens(s.state().t).is_ok());
743        assert!(s.interpolate(s.state().t + t1).is_err());
744        assert!(s.interpolate_out(s.state().t + t1).is_err());
745        if integrating_sens {
746            assert!(s.interpolate_sens(s.state().t + t1).is_err());
747        } else {
748            assert!(s.interpolate_sens(s.state().t + t1).is_ok());
749        }
750
751        let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
752        assert!(s
753            .interpolate_inplace(s.state().t, &mut y_wrong_length)
754            .is_err());
755        let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
756        assert!(s
757            .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
758            .is_err());
759        let mut s_wrong_length = vec![
760            M::V::zeros(1, s.problem().context().clone()),
761            M::V::zeros(1, s.problem().context().clone()),
762        ];
763        assert!(s
764            .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
765            .is_err());
766        let mut s_wrong_vec_length = if integrating_sens {
767            vec![M::V::zeros(2, s.problem().context().clone())]
768        } else {
769            vec![]
770        };
771        if integrating_sens {
772            assert!(s
773                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
774                .is_err());
775        } else {
776            assert!(s
777                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
778                .is_ok());
779        }
780
781        s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
782        assert!(s.interpolate(s.state().t).is_ok());
783        if integrating_out {
784            assert!(s.interpolate_out(s.state().t).is_ok());
785        }
786        if integrating_sens {
787            assert!(s.interpolate_sens(s.state().t).is_ok());
788        }
789        assert!(s.interpolate(tmid).is_err());
790        assert!(s.interpolate_out(tmid).is_err());
791        if integrating_sens {
792            assert!(s.interpolate_sens(tmid).is_err());
793        } else {
794            assert!(s.interpolate_sens(tmid).is_ok());
795        }
796    }
797
798    pub fn test_interpolate_dy<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(
799        mut s: Method,
800    ) {
801        // Error before first step: t is in the future
802        let t_future = s.state().t + M::T::from_f64(1e6).unwrap();
803        assert!(s.interpolate_dy(t_future).is_err());
804
805        let t0 = s.state().t;
806        s.step().unwrap();
807        let t1 = s.state().t;
808        let dt = t1 - t0;
809        let tmid = t0 + dt / M::T::from_f64(2.0).unwrap();
810
811        // Wrong vector length should return error
812        let mut dy_wrong = M::V::zeros(2, s.problem().context().clone());
813        assert!(s.interpolate_dy_inplace(t1, &mut dy_wrong).is_err());
814
815        // t after current time should return error
816        assert!(s.interpolate_dy(t1 + M::T::from_f64(1e6).unwrap()).is_err());
817
818        // interpolate_dy should be consistent with finite-difference of interpolate (step 1)
819        let eps = dt.abs() * M::T::from_f64(1e-5).unwrap();
820        let y_plus = s.interpolate(tmid + eps).unwrap();
821        let y_minus = s.interpolate(tmid - eps).unwrap();
822        let fd_dy = (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps));
823        let dy = s.interpolate_dy(tmid).unwrap();
824        dy.assert_eq_norm(
825            &fd_dy,
826            &s.problem().atol,
827            s.problem().rtol,
828            M::T::from_f64(1e3).unwrap(),
829        );
830
831        // take a second step and check consistency again
832        let t1 = s.state().t;
833        s.step().unwrap();
834        let t2 = s.state().t;
835        let dt2 = t2 - t1;
836        let tmid2 = t1 + dt2 / M::T::from_f64(2.0).unwrap();
837        let eps2 = dt2.abs() * M::T::from_f64(1e-5).unwrap();
838        let y_plus = s.interpolate(tmid2 + eps2).unwrap();
839        let y_minus = s.interpolate(tmid2 - eps2).unwrap();
840        let fd_dy2 =
841            (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps2));
842        let dy2 = s.interpolate_dy(tmid2).unwrap();
843        dy2.assert_eq_norm(
844            &fd_dy2,
845            &s.problem().atol,
846            s.problem().rtol,
847            M::T::from_f64(1e3).unwrap(),
848        );
849    }
850
851    pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
852        mut s: Method,
853    ) {
854        *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
855        assert_eq!(
856            *s.config().as_base_ref().minimum_timestep,
857            Eqn::T::from_f64(1.0e8).unwrap()
858        );
859        // force a step size reduction
860        *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
861
862        let mut failed = false;
863        for _ in 0..10 {
864            if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
865                s.step()
866            {
867                failed = true;
868                break;
869            }
870        }
871        assert!(failed);
872    }
873
874    pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
875        let state = s.checkpoint();
876        let state2 = s.state();
877        state2
878            .y
879            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
880        s.state_mut()
881            .y
882            .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
883        assert_eq!(
884            s.state_mut().y.get_index(0),
885            M::T::from_f64(std::f64::consts::PI).unwrap()
886        );
887    }
888
    /// Builds the bouncing-ball DiffSL model used by the ball-bounce tests (requires the
    /// `diffsl-cranelift` feature).
    ///
    /// The model defines `g = 9.81` and `h = 10.0`. It has two states, `x` (initially `h`) and
    /// `v` (initially 0), evolving as `dx/dt = v` and `dv/dt = -g`. The `stop` section makes `x`
    /// the root function, so the solver reports a root when `x` reaches zero.
    #[cfg(feature = "diffsl-cranelift")]
    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
        crate::OdeBuilder::<M>::new()
            .build_from_diffsl(
                "
            g { 9.81 } h { 10.0 }
            u_i {
                x = h,
                v = 0,
            }
            F_i {
                v,
                -g,
            }
            stop {
                x,
            }
        ",
            )
            .unwrap()
    }
911
912    #[cfg(feature = "diffsl-cranelift")]
913    pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
914    where
915        M: crate::MatrixHost<T = f64>,
916        M: DefaultSolver<T = f64>,
917        M::V: DefaultDenseMatrix<T = f64>,
918        Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
919    {
920        let e = 0.8;
921
922        let final_time = 2.5;
923
924        // solve and apply the remaining doses
925        solver.set_stop_time(final_time).unwrap();
926        loop {
927            match solver.step() {
928                Ok(OdeSolverStopReason::InternalTimestep) => (),
929                Ok(OdeSolverStopReason::RootFound(t, _)) => {
930                    // get the state when the event occurred
931                    let mut y = solver.interpolate(t).unwrap();
932
933                    // update the velocity of the ball
934                    y.set_index(1, y.get_index(1) * -e);
935
936                    // make sure the ball is above the ground
937                    y.set_index(0, y.get_index(0).max(f64::EPSILON));
938
939                    // set the state to the updated state
940                    solver.state_mut().y.copy_from(&y);
941                    solver.state_mut().dy.set_index(0, y.get_index(1));
942                    *solver.state_mut().t = t;
943
944                    break;
945                }
946                Ok(OdeSolverStopReason::TstopReached) => break,
947                Err(_) => panic!("unexpected solver error"),
948            }
949        }
950        // do three more steps after the 1st bound and many sure they are correct
951        let mut x = vec![];
952        let mut v = vec![];
953        let mut t = vec![];
954        for _ in 0..3 {
955            let ret = solver.step();
956            x.push(solver.state().y.get_index(0));
957            v.push(solver.state().y.get_index(1));
958            t.push(solver.state().t);
959            match ret {
960                Ok(OdeSolverStopReason::InternalTimestep) => (),
961                Ok(OdeSolverStopReason::RootFound(_, _)) => {
962                    panic!("should be an internal timestep but found a root")
963                }
964                Ok(OdeSolverStopReason::TstopReached) => break,
965                _ => panic!("should be an internal timestep"),
966            }
967        }
968        (x, v, t)
969    }
970
971    pub fn test_checkpointing<'a, M, Method, Eqn>(
972        soln: OdeSolverSolution<M::V>,
973        mut solver1: Method,
974        mut solver2: Method,
975    ) where
976        M: Matrix + DefaultSolver,
977        Method: OdeSolverMethod<'a, Eqn>,
978        Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
979    {
980        let half_i = soln.solution_points.len() / 2;
981        let half_t = soln.solution_points[half_i].t;
982        while solver1.state().t <= half_t {
983            solver1.step().unwrap();
984        }
985        let checkpoint = solver1.checkpoint();
986        let checkpoint_t = checkpoint.as_ref().t;
987        solver2.set_state(checkpoint);
988
989        // carry on solving with both solvers, they should produce about the same results (probably might diverge a bit, but should always match the solution)
990        for point in soln.solution_points.iter().skip(half_i + 1) {
991            // point should be past checkpoint
992            if point.t < checkpoint_t {
993                continue;
994            }
995            while solver2.state().t < point.t {
996                solver1.step().unwrap();
997                solver2.step().unwrap();
998                let time_error = (solver1.state().t - solver2.state().t).abs()
999                    / (solver1.state().t.abs() * solver1.problem().rtol
1000                        + solver1.problem().atol.get_index(0));
1001                assert!(
1002                    time_error < M::T::from_f64(20.0).unwrap(),
1003                    "time_error: {} at t = {}",
1004                    time_error,
1005                    solver1.state().t
1006                );
1007                solver1.state().y.assert_eq_norm(
1008                    solver2.state().y,
1009                    &solver1.problem().atol,
1010                    solver1.problem().rtol,
1011                    M::T::from_f64(20.0).unwrap(),
1012                );
1013            }
1014            let soln = solver1.interpolate(point.t).unwrap();
1015            soln.assert_eq_norm(
1016                &point.state,
1017                &solver1.problem().atol,
1018                solver1.problem().rtol,
1019                M::T::from_f64(15.0).unwrap(),
1020            );
1021            let soln = solver2.interpolate(point.t).unwrap();
1022            soln.assert_eq_norm(
1023                &point.state,
1024                &solver1.problem().atol,
1025                solver1.problem().rtol,
1026                M::T::from_f64(15.0).unwrap(),
1027            );
1028        }
1029    }
1030
1031    pub fn test_state_mut_on_problem<'a, Eqn, Method>(
1032        mut s: Method,
1033        soln: OdeSolverSolution<Eqn::V>,
1034    ) where
1035        Eqn: OdeEquationsImplicit + 'a,
1036        Method: OdeSolverMethod<'a, Eqn>,
1037        Eqn::V: DefaultDenseMatrix,
1038    {
1039        // save state and solve for a little bit
1040        let state = s.checkpoint();
1041        s.solve(Eqn::T::one()).unwrap();
1042
1043        // reinit using state_mut
1044        s.state_mut().y.copy_from(state.as_ref().y);
1045        s.state_mut().dy.copy_from(state.as_ref().dy);
1046        *s.state_mut().t = state.as_ref().t;
1047
1048        // solve and check against solution
1049        for point in soln.solution_points.iter() {
1050            while s.state().t < point.t {
1051                s.step().unwrap();
1052            }
1053            let soln = s.interpolate(point.t).unwrap();
1054            let error = soln.clone() - &point.state;
1055            let error_norm = error
1056                .squared_norm(&error, &s.problem().atol, s.problem().rtol)
1057                .sqrt();
1058            assert!(
1059                error_norm < Eqn::T::from_f64(19.0).unwrap(),
1060                "error_norm: {} at t = {}",
1061                error_norm,
1062                point.t
1063            );
1064        }
1065    }
1066
1067    /// Test that `step()` returns `RootFound(t, index)` with the correct root index.
1068    ///
1069    /// The problem must have a root function with **two** outputs and **no** Reset:
1070    ///   - Root 0 fires first (at `t ≈ 5.108`, `y[0] ≈ 0.6` for the exponential-decay test model)
1071    ///   - Root 1 fires second
1072    ///
1073    /// The test asserts that the first `RootFound` reports index 0 and the time
1074    /// matches `t_root_0_expected` within `tol`.
1075    pub fn test_root_found_index<'a, Eqn, Method>(
1076        mut solver: Method,
1077        soln: &OdeSolverSolution<Eqn::V>,
1078        expected_root_index: usize,
1079        tol: Eqn::T,
1080    ) where
1081        Eqn: OdeEquations + 'a,
1082        Method: OdeSolverMethod<'a, Eqn>,
1083    {
1084        let t_root_expected = soln.solution_points[0].t;
1085        solver
1086            .set_stop_time(Eqn::T::from_f64(100.0).unwrap())
1087            .unwrap();
1088        loop {
1089            match solver.step().unwrap() {
1090                // RED: `RootFound` currently has one field; adding `index` makes this fail.
1091                OdeSolverStopReason::RootFound(t, index) => {
1092                    assert_eq!(
1093                        index, expected_root_index,
1094                        "expected root index {expected_root_index} but got {index}",
1095                    );
1096                    assert!(
1097                        (t - t_root_expected).abs() < tol,
1098                        "expected t ≈ {t_root_expected:?}, got {t:?}",
1099                    );
1100                    break;
1101                }
1102                OdeSolverStopReason::TstopReached => {
1103                    panic!("reached tstop without finding a root")
1104                }
1105                OdeSolverStopReason::InternalTimestep => {}
1106            }
1107        }
1108    }
1109
1110    /// Test that `solve()` can be continued manually after a root by applying
1111    /// `reset()` and calling `solve()` again.
1112    ///
1113    /// `soln` must contain one solution point at `t_stop` (the state when the
1114    /// second root fires after the manual reset).
1115    pub fn test_solve_with_reset<'a, Eqn, Method>(
1116        mut solver: Method,
1117        soln: &OdeSolverSolution<Eqn::V>,
1118    ) where
1119        Eqn: OdeEquations + 'a,
1120        Eqn::V: DefaultDenseMatrix,
1121        Method: OdeSolverMethod<'a, Eqn>,
1122    {
1123        let final_time = Eqn::T::from_f64(100.0).unwrap();
1124        let (_ys_first, ts_first, stop_reason_first) = solver.solve(final_time).unwrap();
1125        assert!(matches!(
1126            stop_reason_first,
1127            OdeSolverStopReason::RootFound(_, _)
1128        ));
1129        let t_first_root = *ts_first.last().unwrap();
1130        assert!(
1131            t_first_root < final_time,
1132            "expected first solve() call to stop at a root before final_time"
1133        );
1134
1135        // Manually apply the reset on the solver state and continue to the next root.
1136        let mut state = solver.state_clone();
1137        {
1138            let problem = solver.problem();
1139            if let Some(reset_fn) = problem.eqn.reset() {
1140                state.state_mut_op(&problem.eqn, &reset_fn).unwrap();
1141            }
1142        }
1143        solver.set_state(state);
1144        let (ys_second, ts_second, stop_reason_second) = solver.solve(final_time).unwrap();
1145        assert!(matches!(
1146            stop_reason_second,
1147            OdeSolverStopReason::RootFound(_, _)
1148        ));
1149
1150        let expected = &soln.solution_points[0];
1151        let t_second_root = *ts_second.last().unwrap();
1152        let time_tol = soln.rtol * expected.t.abs() + soln.atol.get_index(0);
1153        assert!(
1154            (t_second_root - expected.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1155            "expected second root time ≈ {:?}, got {:?}",
1156            expected.t,
1157            t_second_root,
1158        );
1159
1160        let last_col = ts_second.len() - 1;
1161        let n = expected.state.len();
1162        let ctx = soln.atol.context().clone();
1163        let mut actual = Eqn::V::zeros(n, ctx);
1164        for j in 0..n {
1165            actual.set_index(j, ys_second.get_index(j, last_col));
1166        }
1167        let error = actual - &expected.state;
1168        let error_norm = error
1169            .squared_norm(&expected.state, &soln.atol, soln.rtol)
1170            .sqrt();
1171        let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1172        assert!(
1173            error_norm < error_threshold,
1174            "second-root state mismatch: WRMS error norm {error_norm:?} ≥ {error_threshold:?}",
1175        );
1176    }
1177
1178    /// Test that `solve_dense()` can be continued manually after a root by
1179    /// applying the reset directly to the solver state and calling `solve_dense()` again.
1180    ///
1181    /// `soln` must contain one solution point at `t_stop`.
1182    ///
1183    /// The test verifies that:
1184    ///  - The output matrix is truncated (fewer columns than t_eval.len())
1185    ///  - The last column matches the stop state (soln[0])
1186    pub fn test_solve_dense_with_reset<'a, Eqn, Method>(
1187        mut solver: Method,
1188        soln: &OdeSolverSolution<Eqn::V>,
1189    ) where
1190        Eqn: OdeEquations + 'a,
1191        Eqn::V: DefaultDenseMatrix,
1192        Method: OdeSolverMethod<'a, Eqn>,
1193    {
1194        let t_stop = soln.solution_points[0].t;
1195
1196        let n_steps = 20usize;
1197        let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1198        let dt = final_time / Eqn::T::from_f64(n_steps as f64).unwrap();
1199        let t_eval: Vec<Eqn::T> = (0..=n_steps)
1200            .map(|i| dt * Eqn::T::from_f64(i as f64).unwrap())
1201            .collect();
1202
1203        let (ret_first, stop_reason_first) = solver.solve_dense(&t_eval).unwrap();
1204        assert!(matches!(
1205            stop_reason_first,
1206            OdeSolverStopReason::RootFound(_, _)
1207        ));
1208        let ncols_first = ret_first.ncols();
1209
1210        // First pass should halt at the first root.
1211        assert!(
1212            ncols_first < t_eval.len(),
1213            "expected first solve_dense() call to stop at a root"
1214        );
1215        let t_first_root = solver.state().t;
1216
1217        // Manually apply the reset directly to the solver state.
1218        let mut state = solver.state_clone();
1219        {
1220            let problem = solver.problem();
1221            if let Some(reset_fn) = problem.eqn.reset() {
1222                state.state_mut_op(&problem.eqn, &reset_fn).unwrap();
1223            }
1224        }
1225        solver.set_state(state);
1226
1227        // Continue from just after the first-root time so t_eval remains valid.
1228        let t_eval_after_reset: Vec<Eqn::T> = t_eval
1229            .iter()
1230            .copied()
1231            .filter(|&t| t > t_first_root)
1232            .collect();
1233        assert!(
1234            !t_eval_after_reset.is_empty(),
1235            "expected at least one evaluation time after first root"
1236        );
1237
1238        let (ret_second, stop_reason_second) = solver.solve_dense(&t_eval_after_reset).unwrap();
1239        assert!(matches!(
1240            stop_reason_second,
1241            OdeSolverStopReason::RootFound(_, _)
1242        ));
1243        let ncols = ret_second.ncols();
1244
1245        // The second root fires before the last t_eval_after_reset, so the matrix should be truncated.
1246        assert!(
1247            ncols < t_eval_after_reset.len(),
1248            "expected early stop after manual reset: ncols ({ncols}) should be < t_eval_after_reset.len() ({})",
1249            t_eval_after_reset.len(),
1250        );
1251
1252        let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1253
1254        // The last column should be the second-root stop state (soln[0]).
1255        let last_col = ncols - 1;
1256        let actual = ret_second.column(last_col).into_owned();
1257        let error = actual - &soln.solution_points[0].state;
1258        let error_norm = error
1259            .squared_norm(&soln.solution_points[0].state, &soln.atol, soln.rtol)
1260            .sqrt();
1261        assert!(
1262            error_norm < error_threshold,
1263            "second-root stop state (soln[0], t ≈ {:?}) not found in last column ({last_col}); \
1264             WRMS norm {error_norm:?} ≥ {error_threshold:?}",
1265            t_stop,
1266        );
1267
1268        let t_second_root = solver.state().t;
1269        let time_tol = soln.rtol * t_stop.abs() + soln.atol.get_index(0);
1270        assert!(
1271            (t_second_root - t_stop).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1272            "expected second root time ≈ {:?}, got {:?}",
1273            t_stop,
1274            t_second_root,
1275        );
1276    }
1277
1278    /// Test that `solve_dense_sensitivities()` can be continued manually after
1279    /// a root by applying the root-aware reset directly to the solver state and
1280    /// calling `solve_dense_sensitivities()` again.
1281    ///
1282    /// `soln` must contain one solution point at `t_stop` with exact `y` and sensitivity vectors.
1283    pub fn test_solve_dense_sensitivities_with_reset<'a, Eqn, Method>(
1284        mut solver: Method,
1285        soln: &OdeSolverSolution<Eqn::V>,
1286    ) where
1287        Eqn: OdeEquationsImplicitSensWithReset + 'a,
1288        Eqn::V: DefaultDenseMatrix,
1289        Method: SensitivitiesOdeSolverMethod<'a, Eqn>,
1290    {
1291        let t_stop = soln.solution_points[0].t;
1292
1293        let n_steps = 20usize;
1294        let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1295        let dt = final_time / Eqn::T::from_f64(n_steps as f64).unwrap();
1296        let t_eval: Vec<Eqn::T> = (0..=n_steps)
1297            .map(|i| dt * Eqn::T::from_f64(i as f64).unwrap())
1298            .collect();
1299
1300        let (ret_first, _ret_sens_first, stop_reason_first) =
1301            solver.solve_dense_sensitivities(&t_eval).unwrap();
1302        assert!(matches!(
1303            stop_reason_first,
1304            OdeSolverStopReason::RootFound(_, _)
1305        ));
1306        let ncols_first = ret_first.ncols();
1307
1308        // First pass should halt at the first root.
1309        assert!(
1310            ncols_first < t_eval.len(),
1311            "expected first solve_dense_sensitivities() call to stop at a root"
1312        );
1313        let t_first_root = solver.state().t;
1314
1315        let first_root_idx = match stop_reason_first {
1316            OdeSolverStopReason::RootFound(_, root_idx) => root_idx,
1317            _ => unreachable!("expected first sensitivity solve to stop on a root"),
1318        };
1319
1320        // Manually apply the root-aware reset directly to the solver state.
1321        let mut state = solver.state_clone();
1322        {
1323            let problem = solver.problem();
1324            let reset_fn = problem.eqn.reset().unwrap();
1325            let root_fn = problem.eqn.root().unwrap();
1326            state
1327                .state_mut_op_with_sens_and_reset(&problem.eqn, &reset_fn, &root_fn, first_root_idx)
1328                .unwrap();
1329        }
1330        solver.set_state(state);
1331
1332        // Continue from just after the first-root time so t_eval remains valid.
1333        let t_eval_after_reset: Vec<Eqn::T> = t_eval
1334            .iter()
1335            .copied()
1336            .filter(|&t| t > t_first_root)
1337            .collect();
1338        assert!(
1339            !t_eval_after_reset.is_empty(),
1340            "expected at least one evaluation time after first root"
1341        );
1342
1343        let (ret_second, ret_sens_second, stop_reason_second) = solver
1344            .solve_dense_sensitivities(&t_eval_after_reset)
1345            .unwrap();
1346        assert!(matches!(
1347            stop_reason_second,
1348            OdeSolverStopReason::RootFound(_, _)
1349        ));
1350        let ncols = ret_second.ncols();
1351
1352        // The second root fires before the final t_eval_after_reset → output must be truncated.
1353        assert!(
1354            ncols < t_eval_after_reset.len(),
1355            "expected early stop after manual reset: ncols ({ncols}) should be < t_eval_after_reset.len() ({})",
1356            t_eval_after_reset.len(),
1357        );
1358
1359        // Check the last column matches the expected second-root solution.
1360        let expected = &soln.solution_points[0];
1361        let error_threshold = Eqn::T::from_f64(100.0).unwrap();
1362        let sens_points = soln.sens_solution_points.as_ref().unwrap();
1363
1364        let last_col = ncols - 1;
1365        let ey = ret_second.column(last_col).into_owned() - &expected.state;
1366        let mut combined = ey.squared_norm(&expected.state, &soln.atol, soln.rtol);
1367        for (param_j, sens_pts_j) in sens_points.iter().enumerate() {
1368            let expected_s = &sens_pts_j[0].state;
1369            let es = ret_sens_second[param_j].column(last_col).into_owned() - expected_s;
1370            combined += es.squared_norm(expected_s, &soln.atol, soln.rtol);
1371        }
1372        let norm = combined.sqrt();
1373        assert!(
1374            norm < error_threshold,
1375            "t_stop solution not found in last column; combined WRMS {norm:?} ≥ {error_threshold:?}",
1376        );
1377
1378        let t_second_root = solver.state().t;
1379        let time_tol = soln.rtol * t_stop.abs() + soln.atol.get_index(0);
1380        assert!(
1381            (t_second_root - t_stop).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1382            "expected second root time ≈ {:?}, got {:?}",
1383            t_stop,
1384            t_second_root,
1385        );
1386    }
1387
1388    pub fn test_solve_adjoint_with_single_reset_root<
1389        'a,
1390        Eqn,
1391        MethodF,
1392        MethodB,
1393        BuildForward,
1394        BuildAdjointState,
1395        BuildAdjointFromState,
1396    >(
1397        build_forward: BuildForward,
1398        soln: &OdeSolverSolution<Eqn::V>,
1399        build_adjoint_state: BuildAdjointState,
1400        build_adjoint_from_state: BuildAdjointFromState,
1401    ) where
1402        Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
1403        Eqn::M: DefaultSolver,
1404        Eqn::V: DefaultDenseMatrix,
1405        MethodF: OdeSolverMethod<'a, Eqn>,
1406        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1407        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1408        BuildAdjointState:
1409            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1410        BuildAdjointFromState:
1411            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1412    {
1413        let expected_out = &soln.solution_points[0];
1414        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1415
1416        let mut first_forward_solver = build_forward(None).unwrap();
1417        let (pre_reset_checkpointer, _, _, _) = first_forward_solver
1418            .solve_with_checkpointing(forward_stop_time, None)
1419            .unwrap();
1420        let fwd_state_minus = first_forward_solver.into_state();
1421        let mut state_after_reset = fwd_state_minus.clone();
1422        let problem = pre_reset_checkpointer.problem();
1423        let reset_fn = problem.eqn.reset().unwrap();
1424        state_after_reset
1425            .state_mut_op(&problem.eqn, &reset_fn)
1426            .unwrap();
1427        let fwd_state_plus = state_after_reset.clone();
1428
1429        let mut second_forward_solver = build_forward(Some(state_after_reset)).unwrap();
1430        let (post_reset_checkpointer, _, _, post_reset_stop_reason) = second_forward_solver
1431            .solve_with_checkpointing(forward_stop_time, None)
1432            .unwrap();
1433        let final_forward_state = second_forward_solver.into_state();
1434        let t_second_root = final_forward_state.as_ref().t;
1435
1436        let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
1437        let out_norm = out_error
1438            .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
1439            .sqrt();
1440        assert!(
1441            out_norm < Eqn::T::from_f64(50.0).unwrap(),
1442            "forward integrated output mismatch at second root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
1443            final_forward_state.as_ref().g,
1444            expected_out.state,
1445        );
1446        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1447        assert!(
1448            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1449            "expected second root time ≈ {:?}, got {:?}",
1450            expected_out.t,
1451            t_second_root,
1452        );
1453
1454        let mut post_reset_adjoint_eqn =
1455            problem.adjoint_equations(post_reset_checkpointer.clone(), None);
1456        let mut post_reset_adjoint_state =
1457            build_adjoint_state(&mut post_reset_adjoint_eqn).unwrap();
1458        let post_reset_root_idx = match post_reset_stop_reason {
1459            OdeSolverStopReason::RootFound(_, idx) => idx,
1460            OdeSolverStopReason::TstopReached => {
1461                panic!("expected second forward segment to stop on a root, got TstopReached")
1462            }
1463            OdeSolverStopReason::InternalTimestep => {
1464                panic!("expected second forward segment to stop on a root, got InternalTimestep")
1465            }
1466        };
1467        post_reset_adjoint_state
1468            .state_mut_adjoint_terminal_root(
1469                &mut post_reset_adjoint_eqn,
1470                post_reset_root_idx,
1471                &final_forward_state,
1472            )
1473            .unwrap();
1474        let post_reset_adjoint =
1475            build_adjoint_from_state(post_reset_adjoint_state, post_reset_adjoint_eqn).unwrap();
1476        let mut adjoint_state = post_reset_adjoint
1477            .solve_adjoint_backwards_pass(Some(fwd_state_minus.as_ref().t), &[], &[])
1478            .unwrap();
1479        let t0 = pre_reset_checkpointer.problem().t0;
1480        let ctx = pre_reset_checkpointer.problem().context().clone();
1481        let reset_problem = pre_reset_checkpointer.problem();
1482        let mut pre_reset_adjoint_eqn = problem.adjoint_equations(pre_reset_checkpointer, None);
1483        {
1484            let reset_fn = reset_problem.eqn.reset().unwrap();
1485            let root_fn = reset_problem.eqn.root().unwrap();
1486            adjoint_state
1487                .state_mut_op_with_adjoint_and_reset(
1488                    &mut pre_reset_adjoint_eqn,
1489                    &reset_fn,
1490                    &root_fn,
1491                    0,
1492                    &fwd_state_minus,
1493                    &fwd_state_plus,
1494                )
1495                .unwrap();
1496        }
1497        let pre_reset_adjoint =
1498            build_adjoint_from_state(adjoint_state, pre_reset_adjoint_eqn).unwrap();
1499        let adjoint_state = pre_reset_adjoint
1500            .solve_adjoint_backwards_pass(None, &[], &[])
1501            .unwrap();
1502
1503        let sens_points = soln.sens_solution_points.as_ref().unwrap();
1504        let expected_grad = Eqn::V::from_vec(
1505            sens_points
1506                .iter()
1507                .map(|pts| pts[0].state.get_index(0))
1508                .collect(),
1509            ctx.clone(),
1510        );
1511        let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
1512        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1513        assert!(
1514            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1515            "expected adjoint final time {:?}, got {:?}",
1516            t0,
1517            adjoint_state.as_ref().t,
1518        );
1519        adjoint_state.as_ref().sg[0].assert_eq_norm(
1520            &expected_grad,
1521            &atol,
1522            Eqn::T::from_f64(1e-6).unwrap(),
1523            Eqn::T::from_f64(60.0).unwrap(),
1524        );
1525    }
1526
1527    pub fn test_solve_adjoint_sum_squares_with_single_reset_root<
1528        'a,
1529        Eqn,
1530        MethodF,
1531        MethodB,
1532        BuildForward,
1533        BuildAdjointState,
1534        BuildAdjointFromState,
1535    >(
1536        build_forward: BuildForward,
1537        soln: &OdeSolverSolution<Eqn::V>,
1538        build_adjoint_state: BuildAdjointState,
1539        build_adjoint_from_state: BuildAdjointFromState,
1540        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
1541        data: <Eqn::V as DefaultDenseMatrix>::M,
1542        times: &[Eqn::T],
1543    ) where
1544        Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
1545        Eqn::M: DefaultSolver,
1546        Eqn::V: DefaultDenseMatrix,
1547        MethodF: OdeSolverMethod<'a, Eqn>,
1548        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1549        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1550        BuildAdjointState:
1551            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1552        BuildAdjointFromState:
1553            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1554    {
1555        let expected_out = &soln.solution_points[0];
1556        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1557        let forwards_soln =
1558            solve_dense_with_single_reset_root::<Eqn, MethodF, _>(&build_forward, times);
1559        assert_eq!(
1560            forwards_soln.ncols(),
1561            times.len(),
1562            "expected stitched forward samples to cover every requested observation time",
1563        );
1564        let dgdu = dsum_squaresdp(&forwards_soln, &data);
1565        let dgdu_refs = dgdu.iter().collect::<Vec<_>>();
1566
1567        let mut first_forward_solver = build_forward(None).unwrap();
1568        let (pre_reset_checkpointer, _, _, pre_reset_stop_reason) = first_forward_solver
1569            .solve_with_checkpointing(forward_stop_time, None)
1570            .unwrap();
1571        let fwd_state_minus = first_forward_solver.into_state();
1572        match pre_reset_stop_reason {
1573            OdeSolverStopReason::RootFound(_, 0) => {}
1574            OdeSolverStopReason::RootFound(_, idx) => {
1575                panic!("expected first checkpointed segment to stop on root 0, got root {idx}")
1576            }
1577            OdeSolverStopReason::TstopReached => {
1578                panic!("expected first checkpointed segment to stop on the interior root")
1579            }
1580            OdeSolverStopReason::InternalTimestep => {
1581                panic!("first checkpointed segment ended without a terminal stop reason")
1582            }
1583        }
1584
1585        let mut state_after_reset = fwd_state_minus.clone();
1586        let problem = pre_reset_checkpointer.problem();
1587        let reset_fn = problem.eqn.reset().unwrap();
1588        state_after_reset
1589            .state_mut_op(&problem.eqn, &reset_fn)
1590            .unwrap();
1591        let fwd_state_plus = state_after_reset.clone();
1592
1593        let mut second_forward_solver = build_forward(Some(state_after_reset)).unwrap();
1594        let (post_reset_checkpointer, _, _, post_reset_stop_reason) = second_forward_solver
1595            .solve_with_checkpointing(forward_stop_time, None)
1596            .unwrap();
1597        let final_forward_state = second_forward_solver.into_state();
1598        let t_second_root = final_forward_state.as_ref().t;
1599
1600        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1601        assert!(
1602            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1603            "expected second root time ≈ {:?}, got {:?}",
1604            expected_out.t,
1605            t_second_root,
1606        );
1607
1608        let mut post_reset_adjoint_eqn =
1609            problem.adjoint_equations(post_reset_checkpointer.clone(), Some(dgdu.len()));
1610        let mut post_reset_adjoint_state =
1611            build_adjoint_state(&mut post_reset_adjoint_eqn).unwrap();
1612        let post_reset_root_idx = match post_reset_stop_reason {
1613            OdeSolverStopReason::RootFound(_, idx) => idx,
1614            OdeSolverStopReason::TstopReached => {
1615                panic!("expected second forward segment to stop on a root, got TstopReached")
1616            }
1617            OdeSolverStopReason::InternalTimestep => {
1618                panic!("expected second forward segment to stop on a root, got InternalTimestep")
1619            }
1620        };
1621        post_reset_adjoint_state
1622            .state_mut_adjoint_terminal_root(
1623                &mut post_reset_adjoint_eqn,
1624                post_reset_root_idx,
1625                &final_forward_state,
1626            )
1627            .unwrap();
1628        let post_reset_adjoint =
1629            build_adjoint_from_state(post_reset_adjoint_state, post_reset_adjoint_eqn).unwrap();
1630        let mut adjoint_state = post_reset_adjoint
1631            .solve_adjoint_backwards_pass(
1632                Some(fwd_state_minus.as_ref().t),
1633                times,
1634                dgdu_refs.as_slice(),
1635            )
1636            .unwrap();
1637
1638        let t0 = pre_reset_checkpointer.problem().t0;
1639        let ctx = pre_reset_checkpointer.problem().context().clone();
1640        let reset_problem = pre_reset_checkpointer.problem();
1641        let mut pre_reset_adjoint_eqn =
1642            problem.adjoint_equations(pre_reset_checkpointer, Some(dgdu.len()));
1643        {
1644            let reset_fn = reset_problem.eqn.reset().unwrap();
1645            let root_fn = reset_problem.eqn.root().unwrap();
1646            adjoint_state
1647                .state_mut_op_with_adjoint_and_reset(
1648                    &mut pre_reset_adjoint_eqn,
1649                    &reset_fn,
1650                    &root_fn,
1651                    0,
1652                    &fwd_state_minus,
1653                    &fwd_state_plus,
1654                )
1655                .unwrap();
1656        }
1657        let pre_reset_adjoint =
1658            build_adjoint_from_state(adjoint_state, pre_reset_adjoint_eqn).unwrap();
1659        let adjoint_state = pre_reset_adjoint
1660            .solve_adjoint_backwards_pass(None, times, dgdu_refs.as_slice())
1661            .unwrap();
1662
1663        let nparams = dgdp_check.nrows();
1664        let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
1665        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1666        assert!(
1667            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1668            "expected adjoint final time {:?}, got {:?}",
1669            t0,
1670            adjoint_state.as_ref().t,
1671        );
1672        #[allow(clippy::needless_range_loop)]
1673        for j in 0..dgdp_check.ncols() {
1674            adjoint_state.as_ref().sg[j].assert_eq_norm(
1675                &dgdp_check.column(j).into_owned(),
1676                &atol,
1677                Eqn::T::from_f64(1e-6).unwrap(),
1678                Eqn::T::from_f64(260.0).unwrap(),
1679            );
1680        }
1681    }
1682}