1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod no_checkpointing_solver;
11pub mod problem;
12pub mod runge_kutta;
13pub mod sde;
14pub mod sdirk;
15pub mod sdirk_state;
16pub mod sensitivities;
17pub mod solution;
18pub mod state;
19pub mod tableau;
20
21#[cfg(test)]
22mod tests {
23 use std::rc::Rc;
24
25 use self::problem::OdeSolverSolution;
26
27 use super::*;
28 use crate::error::{DiffsolError, OdeSolverError};
29 use crate::matrix::Matrix;
30 use crate::ode_solver::sensitivities::SensitivitiesOdeSolverMethod;
31 use crate::ode_solver::solution::Solution;
32 use crate::op::unit::UnitCallable;
33 use crate::op::ParameterisedOp;
34 use crate::Scalar;
35 use crate::{
36 op::OpStatistics, AdjointEquations, AdjointOdeSolverMethod, Context, DenseMatrix,
37 MatrixCommon, MatrixRef, NonLinearOp, NonLinearOpJacobian, OdeEquations,
38 OdeEquationsImplicit, OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens,
39 OdeEquationsRef, OdeSolverConfig, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
40 OdeSolverStopReason, Scale, VectorRef, VectorView, VectorViewMut,
41 };
42 use crate::{
43 ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver,
44 NonLinearOpSens, Op, Vector,
45 };
46 use num_traits::{FromPrimitive, One, Signed, ToPrimitive, Zero};
47
48 pub fn test_ode_solver<'a, M, Eqn, Method>(
49 method: &mut Method,
50 solution: OdeSolverSolution<M::V>,
51 override_tol: Option<M::T>,
52 use_tstop: bool,
53 solve_for_sensitivities: bool,
54 ) -> Eqn::V
55 where
56 M: Matrix,
57 Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
58 Method: OdeSolverMethod<'a, Eqn>,
59 {
60 let have_root = method.problem().eqn.root().is_some();
61 for (i, point) in solution.solution_points.iter().enumerate() {
62 let (soln, sens_soln) = if use_tstop {
63 match method.set_stop_time(point.t) {
64 Ok(_) => loop {
65 match method.step() {
66 Ok(OdeSolverStopReason::RootFound(_, _)) => {
67 assert!(have_root);
68 return method.state().y.clone();
69 }
70 Ok(OdeSolverStopReason::TstopReached) => {
71 break (method.state().y.clone(), method.state().s.to_vec());
72 }
73 _ => (),
74 }
75 },
76 Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
77 }
78 } else {
79 while method.state().t.abs() < point.t.abs() {
80 if let OdeSolverStopReason::RootFound(t, _) = method.step().unwrap() {
81 assert!(have_root);
82 return method.interpolate(t).unwrap();
83 }
84 }
85 let soln = method.interpolate(point.t).unwrap();
86 let sens_soln = method.interpolate_sens(point.t).unwrap();
87 (soln, sens_soln)
88 };
89 let soln = if let Some(out) = method.problem().eqn.out() {
90 out.call(&soln, point.t)
91 } else {
92 soln
93 };
94 assert_eq!(
95 soln.len(),
96 point.state.len(),
97 "soln.len() != point.state.len()"
98 );
99 if let Some(override_tol) = override_tol {
100 soln.assert_eq_st(&point.state, override_tol);
101 } else {
102 let (rtol, atol) = if method.problem().eqn.out().is_some() {
103 (solution.rtol, &solution.atol)
105 } else {
106 (method.problem().rtol, &method.problem().atol)
107 };
108 let error = soln.clone() - &point.state;
109 let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
110 assert!(
111 error_norm < M::T::from_f64(20.0).unwrap(),
112 "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
113 error_norm,
114 point.t,
115 soln,
116 point.state
117 );
118 if solve_for_sensitivities {
119 if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
120 for (j, sens_points) in sens_soln_points.iter().enumerate() {
121 let sens_point = &sens_points[i];
122 let sens_soln = &sens_soln[j];
123 let error = sens_soln.clone() - &sens_point.state;
124 let error_norm =
125 error.squared_norm(&sens_point.state, atol, rtol).sqrt();
126 assert!(
127 error_norm < M::T::from_f64(29.0).unwrap(),
128 "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
129 point.t,
130 sens_point.state
131 );
132 }
133 }
134 }
135 }
136 }
137 method.state().y.clone()
138 }
139
140 pub fn setup_test_adjoint<'a, LS, Eqn>(
141 problem: &'a mut OdeSolverProblem<Eqn>,
142 soln: OdeSolverSolution<Eqn::V>,
143 ) -> <Eqn::V as DefaultDenseMatrix>::M
144 where
145 Eqn: OdeEquationsImplicitAdjoint + 'a,
146 LS: LinearSolver<Eqn::M>,
147 Eqn::M: DefaultSolver,
148 Eqn::V: DefaultDenseMatrix,
149 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
150 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
151 {
152 let nparams = problem.eqn.nparams();
153 let nout = problem.eqn.nout();
154 let ctx = problem.eqn.context();
155 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
156 let final_time = soln.solution_points.last().unwrap().t;
157 let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
158 problem.eqn.get_params(&mut p_0);
159 let h_base = Eqn::T::from_f64(1e-10).unwrap();
160 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
161 h.axpy(h_base, &p_0, Eqn::T::one());
162 let p_base = p_0.clone();
163 for i in 0..nparams {
164 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
165 problem.eqn.set_params(&p_0);
166 let g_pos = {
167 let mut s = problem.bdf::<LS>().unwrap();
168 s.set_stop_time(final_time).unwrap();
169 while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
170 s.state().g.clone()
171 };
172
173 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
174 problem.eqn.set_params(&p_0);
175 let g_neg = {
176 let mut s = problem.bdf::<LS>().unwrap();
177 s.set_stop_time(final_time).unwrap();
178 while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
179 s.state().g.clone()
180 };
181 p_0.set_index(i, p_base.get_index(i));
182
183 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
184 for j in 0..nout {
185 dgdp.set_index(i, j, delta.get_index(j));
186 }
187 }
188 problem.eqn.set_params(&p_base);
189 dgdp
190 }
191
192 pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
195 where
196 DM: DenseMatrix,
197 {
198 let mut ret = DM::V::zeros(2, soln.context().clone());
199 for j in 0..soln.ncols() {
200 let soln_j = soln.column(j);
201 let data_j = data.column(j);
202 let delta = soln_j - data_j;
203 let norm2 = delta.norm(2);
204 ret.set_index(0, ret.get_index(0) + norm2 * norm2);
205 let norm4 = delta.norm(4);
206 let norm4_sq = norm4 * norm4;
207 ret.set_index(1, ret.get_index(1) + norm4_sq * norm4_sq);
208 }
209 ret
210 }
211
212 pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
215 where
216 DM: DenseMatrix,
217 {
218 let delta = soln.clone() - data;
219 let mut delta3 = delta.clone();
220 for j in 0..delta3.ncols() {
221 let delta_col = delta.column(j).into_owned();
222
223 let mut delta3_col = delta_col.clone();
224 delta3_col.component_mul_assign(&delta_col);
225 delta3_col.component_mul_assign(&delta_col);
226
227 delta3.column_mut(j).copy_from(&delta3_col);
228 }
229 let ret = vec![
230 delta * Scale(DM::T::from_f64(2.).unwrap()),
231 delta3 * Scale(DM::T::from_f64(4.).unwrap()),
232 ];
233 ret
234 }
235
236 pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
237 problem: &'a mut OdeSolverProblem<Eqn>,
238 times: &[Eqn::T],
239 ) -> (
240 <Eqn::V as DefaultDenseMatrix>::M,
241 <Eqn::V as DefaultDenseMatrix>::M,
242 )
243 where
244 Eqn: OdeEquationsImplicitAdjoint + 'a,
245 LS: LinearSolver<Eqn::M>,
246 Eqn::M: DefaultSolver,
247 Eqn::V: DefaultDenseMatrix,
248 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
249 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
250 {
251 let nparams = problem.eqn.nparams();
252 let nout = 2;
253 let ctx = problem.eqn.context();
254 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
255
256 let mut p_0 = ctx.vector_zeros(nparams);
257 problem.eqn.get_params(&mut p_0);
258 let h_base = Eqn::T::from_f64(1e-10).unwrap();
259 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
260 h.axpy(h_base, &p_0, Eqn::T::one());
261 let mut p_data = p_0.clone();
262 p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
263 let p_base = p_0.clone();
264
265 problem.eqn.set_params(&p_data);
266 let data = {
267 let mut s = problem.bdf::<LS>().unwrap();
268 s.solve_dense(times).unwrap().0
269 };
270
271 for i in 0..nparams {
272 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
273 problem.eqn.set_params(&p_0);
274 let g_pos = {
275 let mut s = problem.bdf::<LS>().unwrap();
276 let v = s.solve_dense(times).unwrap().0;
277 sum_squares(&v, &data)
278 };
279
280 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
281 problem.eqn.set_params(&p_0);
282 let g_neg = {
283 let mut s = problem.bdf::<LS>().unwrap();
284 let v = s.solve_dense(times).unwrap().0;
285 sum_squares(&v, &data)
286 };
287
288 p_0.set_index(i, p_base.get_index(i));
289
290 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
291 for j in 0..nout {
292 dgdp.set_index(i, j, delta.get_index(j));
293 }
294 }
295 problem.eqn.set_params(&p_base);
296 (dgdp, data)
297 }
298
299 pub fn single_reset_root_discrete_times<T: Scalar>(t_stop: T) -> Vec<T> {
300 let t_root = t_stop / T::from_f64(2.0).unwrap();
301 [0.25, 0.75, 1.25, 1.75]
302 .into_iter()
303 .map(|factor| t_root * T::from_f64(factor).unwrap())
304 .collect()
305 }
306
307 fn solve_dense_with_single_reset_root<'a, Eqn, Method, BuildForward>(
308 build_forward: BuildForward,
309 times: &[Eqn::T],
310 ) -> <Eqn::V as DefaultDenseMatrix>::M
311 where
312 Eqn: OdeEquationsImplicitAdjoint + 'a,
313 Eqn::M: DefaultSolver,
314 Eqn::V: DefaultDenseMatrix,
315 Method: OdeSolverMethod<'a, Eqn>,
316 BuildForward: Fn(Option<Method::State>) -> Result<Method, DiffsolError>,
317 {
318 let mut soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
319 let first_forward_solver = build_forward(None).unwrap().solve_soln(&mut soln).unwrap();
320 match soln.stop_reason {
321 Some(OdeSolverStopReason::RootFound(_, 0)) => {}
322 Some(OdeSolverStopReason::RootFound(_, idx)) => {
323 panic!("expected first solve_soln() segment to stop on root 0, got root {idx}")
324 }
325 Some(OdeSolverStopReason::TstopReached) => {
326 panic!("expected first solve_soln() segment to stop on the interior root")
327 }
328 Some(OdeSolverStopReason::InternalTimestep) | None => {
329 panic!("first solve_soln() segment did not finish with a terminal stop reason")
330 }
331 }
332
333 let mut state_after_reset = first_forward_solver.state_clone();
334 {
335 let problem = first_forward_solver.problem();
336 state_after_reset
337 .as_mut()
338 .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
339 .unwrap();
340 }
341
342 build_forward(Some(state_after_reset))
343 .unwrap()
344 .solve_soln(&mut soln)
345 .unwrap();
346 assert!(
347 soln.is_complete(),
348 "expected stitched solve_soln() output to cover all requested observation times",
349 );
350 soln.ys
351 }
352
353 fn state_after_manual_reset<'a, Eqn, Method>(solver: &Method) -> Method::State
354 where
355 Eqn: OdeEquationsImplicitAdjoint + 'a,
356 Eqn::M: DefaultSolver,
357 Method: OdeSolverMethod<'a, Eqn>,
358 {
359 let mut state_after_reset = solver.state_clone();
360 {
361 let problem = solver.problem();
362 state_after_reset
363 .as_mut()
364 .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
365 .unwrap();
366 }
367 state_after_reset
368 }
369
370 pub fn setup_test_adjoint_sum_squares_with_single_reset_root<'a, LS, Eqn>(
371 problem: &'a mut OdeSolverProblem<Eqn>,
372 times: &[Eqn::T],
373 ) -> (
374 <Eqn::V as DefaultDenseMatrix>::M,
375 <Eqn::V as DefaultDenseMatrix>::M,
376 )
377 where
378 Eqn: OdeEquationsImplicitAdjoint + 'a,
379 LS: LinearSolver<Eqn::M>,
380 Eqn::M: DefaultSolver,
381 Eqn::V: DefaultDenseMatrix,
382 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
383 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
384 {
385 let nparams = problem.eqn.nparams();
386 let nout = 2;
387 let ctx = problem.eqn.context();
388 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
389
390 let mut p_0 = ctx.vector_zeros(nparams);
391 problem.eqn.get_params(&mut p_0);
392 let h_base = Eqn::T::from_f64(1e-10).unwrap();
393 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
394 h.axpy(h_base, &p_0, Eqn::T::one());
395 let mut p_data = p_0.clone();
396 p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
397 let p_base = p_0.clone();
398
399 problem.eqn.set_params(&p_data);
400 let data = solve_dense_with_single_reset_root::<Eqn, _, _>(
401 |state| match state {
402 Some(state) => problem.bdf_solver(state),
403 None => problem.bdf::<LS>(),
404 },
405 times,
406 );
407
408 for i in 0..nparams {
409 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
410 problem.eqn.set_params(&p_0);
411 let g_pos = {
412 let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
413 |state| match state {
414 Some(state) => problem.bdf_solver(state),
415 None => problem.bdf::<LS>(),
416 },
417 times,
418 );
419 sum_squares(&v, &data)
420 };
421
422 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
423 problem.eqn.set_params(&p_0);
424 let g_neg = {
425 let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
426 |state| match state {
427 Some(state) => problem.bdf_solver(state),
428 None => problem.bdf::<LS>(),
429 },
430 times,
431 );
432 sum_squares(&v, &data)
433 };
434
435 p_0.set_index(i, p_base.get_index(i));
436
437 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
438 for j in 0..nout {
439 dgdp.set_index(i, j, delta.get_index(j));
440 }
441 }
442 problem.eqn.set_params(&p_base);
443 (dgdp, data)
444 }
445
446 pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
447 backwards_solver: SolverB,
448 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
449 forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
450 data: <Eqn::V as DefaultDenseMatrix>::M,
451 times: &[Eqn::T],
452 ) where
453 SolverF: OdeSolverMethod<'a, Eqn>,
454 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
455 Eqn: OdeEquationsImplicitAdjoint + 'a,
456 Eqn::V: DefaultDenseMatrix,
457 Eqn::M: DefaultSolver,
458 {
459 let nparams = dgdp_check.nrows();
460 let dgdu = dsum_squaresdp(&forwards_soln, &data);
461
462 let atol = Eqn::V::from_element(
463 nparams,
464 Eqn::T::from_f64(1e-6).unwrap(),
465 data.context().clone(),
466 );
467 let rtol = Eqn::T::from_f64(1e-6).unwrap();
468 let (state, _) = backwards_solver
469 .solve_adjoint_backwards_pass(times, dgdu.iter().collect::<Vec<_>>().as_slice())
470 .unwrap();
471 let gs_adj = state.into_common().sg;
472 #[allow(clippy::needless_range_loop)]
473 for j in 0..dgdp_check.ncols() {
474 gs_adj[j].assert_eq_norm(
475 &dgdp_check.column(j).into_owned(),
476 &atol,
477 rtol,
478 Eqn::T::from_f64(260.).unwrap(),
479 );
480 }
481 }
482
483 pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
484 backwards_solver: SolverB,
485 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
486 ) where
487 SolverF: OdeSolverMethod<'a, Eqn>,
488 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
489 Eqn: OdeEquationsImplicitAdjoint + 'a,
490 Eqn::V: DefaultDenseMatrix,
491 Eqn::M: DefaultSolver,
492 {
493 let nout = backwards_solver.problem().eqn.nout();
494 let atol = Eqn::V::from_element(
495 nout,
496 Eqn::T::from_f64(1e-6).unwrap(),
497 dgdp_check.context().clone(),
498 );
499 let rtol = Eqn::T::from_f64(1e-6).unwrap();
500 let (state, _) = backwards_solver
501 .solve_adjoint_backwards_pass(&[], &[])
502 .unwrap();
503 let gs_adj = state.into_common().sg;
504 #[allow(clippy::needless_range_loop)]
505 for j in 0..dgdp_check.ncols() {
506 gs_adj[j].assert_eq_norm(
507 &dgdp_check.column(j).into_owned(),
508 &atol,
509 rtol,
510 Eqn::T::from_f64(40.).unwrap(),
511 );
512 }
513 }
514
515 pub struct TestEqnInit<M: Matrix> {
516 ctx: M::C,
517 }
518
519 impl<M: Matrix> Op for TestEqnInit<M> {
520 type T = M::T;
521 type V = M::V;
522 type M = M;
523 type C = M::C;
524
525 fn nout(&self) -> usize {
526 1
527 }
528 fn nparams(&self) -> usize {
529 1
530 }
531 fn nstates(&self) -> usize {
532 1
533 }
534 fn context(&self) -> &Self::C {
535 &self.ctx
536 }
537 }
538
539 impl<M: Matrix> ConstantOp for TestEqnInit<M> {
540 fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
541 y.fill(M::T::one());
542 }
543 }
544
545 impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
546 fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
547 sens.fill(M::T::zero());
548 }
549 }
550
551 pub struct TestEqnRhs<M: Matrix> {
552 ctx: M::C,
553 }
554
555 impl<M: Matrix> Op for TestEqnRhs<M> {
556 type T = M::T;
557 type V = M::V;
558 type M = M;
559 type C = M::C;
560
561 fn nout(&self) -> usize {
562 1
563 }
564 fn nparams(&self) -> usize {
565 1
566 }
567 fn nstates(&self) -> usize {
568 1
569 }
570 fn context(&self) -> &Self::C {
571 &self.ctx
572 }
573 }
574
575 impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
576 fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
577 y.fill(M::T::zero());
578 }
579 }
580
581 impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
582 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
583 y.fill(M::T::zero());
584 }
585 }
586
587 impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
588 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
589 sens.fill(M::T::zero());
590 }
591 }
592
593 pub struct TestEqnOut<M: Matrix> {
594 ctx: M::C,
595 }
596
597 impl<M: Matrix> Op for TestEqnOut<M> {
598 type T = M::T;
599 type V = M::V;
600 type M = M;
601 type C = M::C;
602
603 fn nout(&self) -> usize {
604 1
605 }
606 fn nparams(&self) -> usize {
607 1
608 }
609 fn nstates(&self) -> usize {
610 1
611 }
612 fn context(&self) -> &Self::C {
613 &self.ctx
614 }
615 }
616
617 impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
618 fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
619 y.copy_from(x);
620 }
621 }
622
623 impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
624 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
625 y.copy_from(v);
626 }
627 }
628
629 impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
630 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
631 sens.fill(M::T::zero());
632 }
633 }
634
635 pub struct TestEqn<M: Matrix> {
636 rhs: Rc<TestEqnRhs<M>>,
637 init: Rc<TestEqnInit<M>>,
638 out: Rc<TestEqnOut<M>>,
639 ctx: M::C,
640 }
641
642 impl<M: Matrix> TestEqn<M> {
643 pub fn new() -> Self {
644 let ctx = M::C::default();
645 Self {
646 rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
647 init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
648 out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
649 ctx,
650 }
651 }
652 }
653
654 impl<M: Matrix> Op for TestEqn<M> {
655 type T = M::T;
656 type V = M::V;
657 type M = M;
658 type C = M::C;
659 fn nout(&self) -> usize {
660 1
661 }
662 fn nparams(&self) -> usize {
663 1
664 }
665 fn nstates(&self) -> usize {
666 1
667 }
668 fn statistics(&self) -> crate::op::OpStatistics {
669 OpStatistics::default()
670 }
671 fn context(&self) -> &Self::C {
672 &self.ctx
673 }
674 }
675
676 impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
677 type Rhs = &'a TestEqnRhs<M>;
678 type Mass = ParameterisedOp<'a, UnitCallable<M>>;
679 type Root = ParameterisedOp<'a, UnitCallable<M>>;
680 type Init = &'a TestEqnInit<M>;
681 type Out = &'a TestEqnOut<M>;
682 type Reset = ParameterisedOp<'a, UnitCallable<M>>;
683 }
684
685 impl<M: Matrix> OdeEquations for TestEqn<M> {
686 fn rhs(&self) -> &TestEqnRhs<M> {
687 &self.rhs
688 }
689
690 fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
691 None
692 }
693
694 fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
695 None
696 }
697
698 fn init(&self) -> &TestEqnInit<M> {
699 &self.init
700 }
701
702 fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
703 Some(&self.out)
704 }
705 fn set_params(&mut self, _p: &Self::V) {
706 unimplemented!()
707 }
708 fn get_params(&self, _p: &mut Self::V) {
709 unimplemented!()
710 }
711 }
712
713 pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
714 let eqn = TestEqn::<M>::new();
715 let atol = eqn
716 .context()
717 .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
718 OdeSolverProblem::new(
719 eqn,
720 M::T::from_f64(1e-6).unwrap(),
721 atol,
722 None,
723 None,
724 None,
725 None,
726 None,
727 None,
728 M::T::zero(),
729 M::T::one(),
730 integrate_out,
731 Default::default(),
732 Default::default(),
733 )
734 .unwrap()
735 }
736
    /// Exercise the interpolation API of a solver for [`TestEqn`].
    ///
    /// Expectations checked, in order:
    /// * before any step, `interpolate(t0)` reproduces the initial state, while times far in
    ///   the future (`t0 + 1e6`) are rejected;
    /// * after one step, the current time and the step midpoint are interpolable, and
    ///   `interpolate_out` succeeds only when the problem integrates its output;
    /// * output buffers of the wrong length (or wrong number of sensitivity vectors) are
    ///   rejected;
    /// * after overwriting `y` through `state_mut`, the current time still interpolates but the
    ///   earlier midpoint no longer does (presumably because the step history is invalidated).
    ///
    /// When the solver is not integrating sensitivities, `interpolate_sens` is expected to
    /// succeed for any time.
    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
        let state = s.checkpoint();
        let integrating_sens = !s.state().s.is_empty();
        let integrating_out = s.problem().integrate_out;
        let t0 = state.as_ref().t;
        // A time far beyond anything the solver could have covered.
        let t1 = t0 + M::T::from_f64(1e6).unwrap();
        s.interpolate(t0)
            .unwrap()
            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
        assert!(s.interpolate(t1).is_err());
        assert!(s.interpolate_out(t1).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(t1).is_err());
        } else {
            assert!(s.interpolate_sens(t0).is_ok());
        }
        s.step().unwrap();
        let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
        assert!(s.interpolate(s.state().t).is_ok());
        assert!(s.interpolate(tmid).is_ok());
        if integrating_out {
            assert!(s.interpolate_out(s.state().t).is_ok());
        } else {
            assert!(s.interpolate_out(s.state().t).is_err());
        }
        assert!(s.interpolate_sens(s.state().t).is_ok());
        assert!(s.interpolate(s.state().t + t1).is_err());
        assert!(s.interpolate_out(s.state().t + t1).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(s.state().t + t1).is_err());
        } else {
            assert!(s.interpolate_sens(s.state().t + t1).is_ok());
        }

        // TestEqn has one state and one output, so length-2 buffers are wrong.
        let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
        assert!(s
            .interpolate_inplace(s.state().t, &mut y_wrong_length)
            .is_err());
        let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
        assert!(s
            .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
            .is_err());
        // Two sensitivity vectors is the wrong count whether or not sensitivities are on.
        let mut s_wrong_length = vec![
            M::V::zeros(1, s.problem().context().clone()),
            M::V::zeros(1, s.problem().context().clone()),
        ];
        assert!(s
            .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
            .is_err());
        // Right count but wrong vector length when integrating sensitivities; an empty list
        // (the expected shape) otherwise.
        let mut s_wrong_vec_length = if integrating_sens {
            vec![M::V::zeros(2, s.problem().context().clone())]
        } else {
            vec![]
        };
        if integrating_sens {
            assert!(s
                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
                .is_err());
        } else {
            assert!(s
                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
                .is_ok());
        }

        // Mutating the state: only the current time should remain interpolable.
        s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
        assert!(s.interpolate(s.state().t).is_ok());
        if integrating_out {
            assert!(s.interpolate_out(s.state().t).is_ok());
        }
        if integrating_sens {
            assert!(s.interpolate_sens(s.state().t).is_ok());
        }
        assert!(s.interpolate(tmid).is_err());
        assert!(s.interpolate_out(tmid).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(tmid).is_err());
        } else {
            assert!(s.interpolate_sens(tmid).is_ok());
        }
    }
817
818 pub fn test_interpolate_dy<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(
819 mut s: Method,
820 ) {
821 let t_future = s.state().t + M::T::from_f64(1e6).unwrap();
823 assert!(s.interpolate_dy(t_future).is_err());
824
825 let t0 = s.state().t;
826 s.step().unwrap();
827 let t1 = s.state().t;
828 let dt = t1 - t0;
829 let tmid = t0 + dt / M::T::from_f64(2.0).unwrap();
830
831 let mut dy_wrong = M::V::zeros(2, s.problem().context().clone());
833 assert!(s.interpolate_dy_inplace(t1, &mut dy_wrong).is_err());
834
835 assert!(s.interpolate_dy(t1 + M::T::from_f64(1e6).unwrap()).is_err());
837
838 let eps = dt.abs() * M::T::from_f64(1e-5).unwrap();
840 let y_plus = s.interpolate(tmid + eps).unwrap();
841 let y_minus = s.interpolate(tmid - eps).unwrap();
842 let fd_dy = (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps));
843 let dy = s.interpolate_dy(tmid).unwrap();
844 dy.assert_eq_norm(
845 &fd_dy,
846 &s.problem().atol,
847 s.problem().rtol,
848 M::T::from_f64(1e3).unwrap(),
849 );
850
851 let t1 = s.state().t;
853 s.step().unwrap();
854 let t2 = s.state().t;
855 let dt2 = t2 - t1;
856 let tmid2 = t1 + dt2 / M::T::from_f64(2.0).unwrap();
857 let eps2 = dt2.abs() * M::T::from_f64(1e-5).unwrap();
858 let y_plus = s.interpolate(tmid2 + eps2).unwrap();
859 let y_minus = s.interpolate(tmid2 - eps2).unwrap();
860 let fd_dy2 =
861 (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps2));
862 let dy2 = s.interpolate_dy(tmid2).unwrap();
863 dy2.assert_eq_norm(
864 &fd_dy2,
865 &s.problem().atol,
866 s.problem().rtol,
867 M::T::from_f64(1e3).unwrap(),
868 );
869 }
870
871 pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
872 mut s: Method,
873 ) {
874 *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
875 assert_eq!(
876 *s.config().as_base_ref().minimum_timestep,
877 Eqn::T::from_f64(1.0e8).unwrap()
878 );
879 *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
881
882 let mut failed = false;
883 for _ in 0..10 {
884 if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
885 s.step()
886 {
887 failed = true;
888 break;
889 }
890 }
891 assert!(failed);
892 }
893
894 pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
895 let state = s.checkpoint();
896 let state2 = s.state();
897 state2
898 .y
899 .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
900 s.state_mut()
901 .y
902 .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
903 assert_eq!(
904 s.state_mut().y.get_index(0),
905 M::T::from_f64(std::f64::consts::PI).unwrap()
906 );
907 }
908
    /// Build the bouncing-ball problem from DiffSL.
    ///
    /// The state is `u = (x, v)` with `x(0) = h = 10` and `v(0) = 0`. The right-hand side is
    /// `(v, -g)` with `g = 9.81`, and the `stop` (root) function is `x`, so a root is reported
    /// when the height reaches zero.
    #[cfg(feature = "diffsl-cranelift")]
    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
        crate::OdeBuilder::<M>::new()
            .build_from_diffsl(
                "
            g { 9.81 } h { 10.0 }
            u_i {
                x = h,
                v = 0,
            }
            F_i {
                v,
                -g,
            }
            stop {
                x,
            }
        ",
            )
            .unwrap()
    }
931
932 #[cfg(feature = "diffsl-cranelift")]
933 pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
934 where
935 M: crate::MatrixHost<T = f64>,
936 M: DefaultSolver<T = f64>,
937 M::V: DefaultDenseMatrix<T = f64>,
938 Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
939 {
940 let e = 0.8;
941
942 let final_time = 2.5;
943
944 solver.set_stop_time(final_time).unwrap();
946 loop {
947 match solver.step() {
948 Ok(OdeSolverStopReason::InternalTimestep) => (),
949 Ok(OdeSolverStopReason::RootFound(t, _)) => {
950 let mut y = solver.interpolate(t).unwrap();
952
953 y.set_index(1, y.get_index(1) * -e);
955
956 y.set_index(0, y.get_index(0).max(f64::EPSILON));
958
959 solver.state_mut().y.copy_from(&y);
961 solver.state_mut().dy.set_index(0, y.get_index(1));
962 *solver.state_mut().t = t;
963
964 break;
965 }
966 Ok(OdeSolverStopReason::TstopReached) => break,
967 Err(_) => panic!("unexpected solver error"),
968 }
969 }
970 let mut x = vec![];
972 let mut v = vec![];
973 let mut t = vec![];
974 for _ in 0..3 {
975 let ret = solver.step();
976 x.push(solver.state().y.get_index(0));
977 v.push(solver.state().y.get_index(1));
978 t.push(solver.state().t);
979 match ret {
980 Ok(OdeSolverStopReason::InternalTimestep) => (),
981 Ok(OdeSolverStopReason::RootFound(_, _)) => {
982 panic!("should be an internal timestep but found a root")
983 }
984 Ok(OdeSolverStopReason::TstopReached) => break,
985 _ => panic!("should be an internal timestep"),
986 }
987 }
988 (x, v, t)
989 }
990
    /// Check that restoring a checkpoint reproduces the original trajectory.
    ///
    /// `solver1` is stepped past the middle solution point, and its checkpoint is loaded into
    /// `solver2`. Both solvers are then advanced in lockstep through the remaining solution
    /// points at or after the checkpoint time. After each step:
    /// * their times must agree to within 20 in units of `|t| * rtol + atol[0]`;
    /// * their states must agree within a weighted norm factor of 20.
    ///
    /// At each solution point, both solvers must also interpolate the expected state within
    /// factor 15. All comparisons use `solver1`'s problem tolerances.
    pub fn test_checkpointing<'a, M, Method, Eqn>(
        soln: OdeSolverSolution<M::V>,
        mut solver1: Method,
        mut solver2: Method,
    ) where
        M: Matrix + DefaultSolver,
        Method: OdeSolverMethod<'a, Eqn>,
        Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
    {
        let half_i = soln.solution_points.len() / 2;
        let half_t = soln.solution_points[half_i].t;
        while solver1.state().t <= half_t {
            solver1.step().unwrap();
        }
        let checkpoint = solver1.checkpoint();
        let checkpoint_t = checkpoint.as_ref().t;
        solver2.set_state(checkpoint);

        for point in soln.solution_points.iter().skip(half_i + 1) {
            // solver2 cannot interpolate before its starting (checkpoint) time.
            if point.t < checkpoint_t {
                continue;
            }
            while solver2.state().t < point.t {
                solver1.step().unwrap();
                solver2.step().unwrap();
                // Step times should coincide to within the time-scaled tolerance.
                let time_error = (solver1.state().t - solver2.state().t).abs()
                    / (solver1.state().t.abs() * solver1.problem().rtol
                        + solver1.problem().atol.get_index(0));
                assert!(
                    time_error < M::T::from_f64(20.0).unwrap(),
                    "time_error: {} at t = {}",
                    time_error,
                    solver1.state().t
                );
                solver1.state().y.assert_eq_norm(
                    solver2.state().y,
                    &solver1.problem().atol,
                    solver1.problem().rtol,
                    M::T::from_f64(20.0).unwrap(),
                );
            }
            let soln = solver1.interpolate(point.t).unwrap();
            soln.assert_eq_norm(
                &point.state,
                &solver1.problem().atol,
                solver1.problem().rtol,
                M::T::from_f64(15.0).unwrap(),
            );
            let soln = solver2.interpolate(point.t).unwrap();
            soln.assert_eq_norm(
                &point.state,
                &solver1.problem().atol,
                solver1.problem().rtol,
                M::T::from_f64(15.0).unwrap(),
            );
        }
    }
1050
1051 pub fn test_state_mut_on_problem<'a, Eqn, Method>(
1052 mut s: Method,
1053 soln: OdeSolverSolution<Eqn::V>,
1054 ) where
1055 Eqn: OdeEquationsImplicit + 'a,
1056 Eqn::M: DefaultSolver,
1057 Method: OdeSolverMethod<'a, Eqn>,
1058 Eqn::V: DefaultDenseMatrix,
1059 {
1060 let state = s.checkpoint();
1062 s.solve(Eqn::T::one()).unwrap();
1063
1064 s.state_mut().y.copy_from(state.as_ref().y);
1066 s.state_mut().dy.copy_from(state.as_ref().dy);
1067 *s.state_mut().t = state.as_ref().t;
1068
1069 for point in soln.solution_points.iter() {
1071 while s.state().t < point.t {
1072 s.step().unwrap();
1073 }
1074 let soln = s.interpolate(point.t).unwrap();
1075 let error = soln.clone() - &point.state;
1076 let error_norm = error
1077 .squared_norm(&error, &s.problem().atol, s.problem().rtol)
1078 .sqrt();
1079 assert!(
1080 error_norm < Eqn::T::from_f64(19.0).unwrap(),
1081 "error_norm: {} at t = {}",
1082 error_norm,
1083 point.t
1084 );
1085 }
1086 }
1087
1088 pub fn test_root_found_index<'a, Eqn, Method>(
1097 mut solver: Method,
1098 soln: &OdeSolverSolution<Eqn::V>,
1099 expected_root_index: usize,
1100 tol: Eqn::T,
1101 ) where
1102 Eqn: OdeEquations + 'a,
1103 Method: OdeSolverMethod<'a, Eqn>,
1104 {
1105 let t_root_expected = soln.solution_points[0].t;
1106 solver
1107 .set_stop_time(Eqn::T::from_f64(100.0).unwrap())
1108 .unwrap();
1109 loop {
1110 match solver.step().unwrap() {
1111 OdeSolverStopReason::RootFound(t, index) => {
1113 assert_eq!(
1114 index, expected_root_index,
1115 "expected root index {expected_root_index} but got {index}",
1116 );
1117 assert!(
1118 (t - t_root_expected).abs() < tol,
1119 "expected t ≈ {t_root_expected:?}, got {t:?}",
1120 );
1121 break;
1122 }
1123 OdeSolverStopReason::TstopReached => {
1124 panic!("reached tstop without finding a root")
1125 }
1126 OdeSolverStopReason::InternalTimestep => {}
1127 }
1128 }
1129 }
1130
1131 pub fn test_solve_with_reset<'a, Eqn, Method>(
1134 mut solver: Method,
1135 soln: &OdeSolverSolution<Eqn::V>,
1136 ) where
1137 Eqn: OdeEquationsImplicit + 'a,
1138 Eqn::M: DefaultSolver,
1139 Eqn::V: DefaultDenseMatrix,
1140 Method: OdeSolverMethod<'a, Eqn>,
1141 {
1142 let final_time = Eqn::T::from_f64(100.0).unwrap();
1143 let (ys, ts, stop_reason) = solver.solve(final_time).unwrap();
1144 assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1145 let t_last = *ts.last().unwrap();
1146 let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1147 assert!(
1148 (t_last - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1149 "expected solve() to reach final_time ≈ {:?}, got {:?}",
1150 final_time,
1151 t_last,
1152 );
1153 assert!(
1154 (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1155 "expected solver state at final_time ≈ {:?}, got {:?}",
1156 final_time,
1157 solver.state().t,
1158 );
1159
1160 let expected = &soln.solution_points[0];
1161 let root_time_tol = soln.rtol * expected.t.abs() + soln.atol.get_index(0);
1162 let root_col = ts
1163 .iter()
1164 .position(|&t| (t - expected.t).abs() < Eqn::T::from_f64(30.0).unwrap() * root_time_tol)
1165 .expect("expected solve() output to include the second-root/reset time");
1166 let root_expected = Eqn::V::from_element(
1167 expected.state.len(),
1168 Eqn::T::from_f64(0.4).unwrap(),
1169 expected.state.context().clone(),
1170 );
1171 let root_state = ys.column(root_col).into_owned();
1172 let root_error = root_state - &root_expected;
1173 let root_error_norm = root_error
1174 .squared_norm(&root_expected, &soln.atol, soln.rtol)
1175 .sqrt();
1176 let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1177 assert!(
1178 root_error_norm < error_threshold,
1179 "expected reset state y=0.4 at second-root time; WRMS error norm {root_error_norm:?} ≥ {error_threshold:?}",
1180 );
1181
1182 let reset_value = Eqn::T::from_f64(0.4).unwrap();
1183 let reset_tol = Eqn::T::from_f64(30.0).unwrap()
1184 * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1185 let last_reset_col = (0..ts.len())
1186 .rev()
1187 .find(|&i| (ys.get_index(0, i) - reset_value).abs() < reset_tol)
1188 .expect("expected solve() output to include at least one reset state");
1189 let final_time_f64 = final_time.to_f64().unwrap();
1190 let last_reset_time_f64 = ts[last_reset_col].to_f64().unwrap();
1191 let expected_final_value =
1192 Eqn::T::from_f64(0.4 * (-0.1 * (final_time_f64 - last_reset_time_f64)).exp()).unwrap();
1193 let expected_final = Eqn::V::from_element(
1194 expected.state.len(),
1195 expected_final_value,
1196 expected.state.context().clone(),
1197 );
1198 let final_state = ys.column(ts.len() - 1).into_owned();
1199 let final_error = final_state - &expected_final;
1200 let final_error_norm = final_error
1201 .squared_norm(&expected_final, &soln.atol, soln.rtol)
1202 .sqrt();
1203 assert!(
1204 final_error_norm < error_threshold,
1205 "final state mismatch after automatic reset continuation: WRMS error norm {final_error_norm:?} ≥ {error_threshold:?}",
1206 );
1207 }
1208
1209 pub fn test_solve_dense_with_reset<'a, Eqn, Method>(
1212 mut solver: Method,
1213 soln: &OdeSolverSolution<Eqn::V>,
1214 ) where
1215 Eqn: OdeEquationsImplicit + 'a,
1216 Eqn::M: DefaultSolver,
1217 Eqn::V: DefaultDenseMatrix,
1218 Method: OdeSolverMethod<'a, Eqn>,
1219 {
1220 let t_stop = soln.solution_points[0].t;
1221 let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1222 let mut probe_solver = solver.clone();
1223 let (probe_ys, probe_ts, probe_stop_reason) = probe_solver.solve(final_time).unwrap();
1224 assert_eq!(probe_stop_reason, OdeSolverStopReason::TstopReached);
1225
1226 let reset_time_tol =
1227 Eqn::T::from_f64(30.0).unwrap() * (soln.rtol * t_stop.abs() + soln.atol.get_index(0));
1228 let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1229 let reset_value = Eqn::T::from_f64(0.4).unwrap();
1230 let reset_value_tol = Eqn::T::from_f64(30.0).unwrap()
1231 * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1232 let reset_col = (0..probe_ts.len())
1233 .find(|&i| {
1234 (probe_ts[i] - t_stop).abs() < reset_time_tol
1235 && (probe_ys.get_index(0, i) - reset_value).abs() < reset_value_tol
1236 })
1237 .expect("expected solve() probe output to contain the second-root reset state");
1238 let t_event = probe_ts[reset_col];
1239 let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, final_time];
1240
1241 let (ret, stop_reason) = solver.solve_dense(&t_eval).unwrap();
1242 assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1243 assert!(
1244 ret.ncols() == t_eval.len(),
1245 "expected solve_dense() to fill all requested evaluation times"
1246 );
1247 let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1248 assert!(
1249 (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1250 "expected solver state at final_time ≈ {:?}, got {:?}",
1251 final_time,
1252 solver.state().t,
1253 );
1254
1255 let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1256 let pre_reset_state = ret.column(1).into_owned();
1257 let pre_reset_error = pre_reset_state - &soln.solution_points[0].state;
1258 let pre_reset_error_norm = pre_reset_error
1259 .squared_norm(&soln.solution_points[0].state, &soln.atol, soln.rtol)
1260 .sqrt();
1261 assert!(
1262 pre_reset_error_norm < error_threshold,
1263 "expected pre-reset state at event time; WRMS norm {pre_reset_error_norm:?} >= {error_threshold:?}",
1264 );
1265
1266 let expected_post_reset_value =
1267 reset_value * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1268 let expected_post_reset = Eqn::V::from_element(
1269 soln.solution_points[0].state.len(),
1270 expected_post_reset_value,
1271 soln.solution_points[0].state.context().clone(),
1272 );
1273 let post_reset_state = ret.column(2).into_owned();
1274 let post_reset_error = post_reset_state - &expected_post_reset;
1275 let post_reset_error_norm = post_reset_error
1276 .squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1277 .sqrt();
1278 assert!(
1279 post_reset_error_norm < error_threshold,
1280 "expected reset state just after event time; WRMS norm {post_reset_error_norm:?} >= {error_threshold:?}",
1281 );
1282 }
1283
1284 pub fn test_solve_dense_sensitivities_with_reset<'a, Eqn, Method>(
1287 mut solver: Method,
1288 soln: &OdeSolverSolution<Eqn::V>,
1289 ) where
1290 Eqn: OdeEquationsImplicitSens + 'a,
1291 Eqn::V: DefaultDenseMatrix,
1292 Eqn::M: DefaultSolver,
1293 Method: SensitivitiesOdeSolverMethod<'a, Eqn>,
1294 {
1295 let t_stop = soln.solution_points[0].t;
1296 let t_event = Eqn::T::from_f64(10.0 * (5.0_f64 / 3.0_f64).ln()).unwrap();
1297
1298 let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1299 let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, t_stop];
1300 let (ret, ret_sens, stop_reason) = solver.solve_dense_sensitivities(&t_eval).unwrap();
1301 assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1302 assert_eq!(ret.ncols(), t_eval.len());
1303 for ret_sens_j in &ret_sens {
1304 assert_eq!(ret_sens_j.ncols(), t_eval.len());
1305 }
1306
1307 let error_threshold = Eqn::T::from_f64(100.0).unwrap();
1308 let ctx = soln.solution_points[0].state.context().clone();
1309 let nstates = soln.solution_points[0].state.len();
1310
1311 let post_reset_y = Eqn::T::from_f64(2.6).unwrap()
1312 * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1313 let post_reset_t = t_event + post_event_dt;
1314 let expected_post_reset = Eqn::V::from_element(nstates, post_reset_y, ctx.clone());
1315 let expected_post_reset_sk =
1316 Eqn::V::from_element(nstates, -post_reset_y * post_reset_t, ctx.clone());
1317 let expected_post_reset_sy0 = Eqn::V::from_element(nstates, post_reset_y, ctx);
1318
1319 let col = 2;
1320 let ey = ret.column(col).into_owned() - &expected_post_reset;
1321 let esk = ret_sens[0].column(col).into_owned() - &expected_post_reset_sk;
1322 let esy0 = ret_sens[1].column(col).into_owned() - &expected_post_reset_sy0;
1323 let norm = (ey.squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1324 + esk.squared_norm(&expected_post_reset_sk, &soln.atol, soln.rtol)
1325 + esy0.squared_norm(&expected_post_reset_sy0, &soln.atol, soln.rtol))
1326 .sqrt();
1327 assert!(
1328 norm < error_threshold,
1329 "dense sensitivity mismatch just after reset; combined WRMS {norm:?} >= {error_threshold:?}",
1330 );
1331 }
1332
/// Adjoint check on a problem whose root function resets the state. The
/// forward pass uses `solve_with_checkpointing`.
///
/// Steps:
/// 1. Build a forward solver with `build_forward(None)` and run
///    `solve_with_checkpointing` to `soln.solution_points[0].t + 1`. It must
///    stop with `TstopReached` and produce at least three checkpointer segments.
/// 2. Take the last checkpoint of `checkpointers[1]` as the forward state at the
///    second root. Its `g` (integrated output) must be within WRMS 50 of
///    `soln.solution_points[0].state`. Its time must be within
///    `30 * (rtol * |t| + atol[0])` of `soln.solution_points[0].t`.
/// 3. Negative path: clear the terminal reset-root index on a copy of the first
///    segment. The adjoint backward pass must then fail with an error whose
///    `Debug` text contains "Missing reset root metadata".
/// 4. Run the adjoint backward pass over the first two segments, starting from
///    the second root. It must end within `10 * EPSILON` of `t0`, and `sg[0]`
///    must match the first component of each entry of
///    `soln.sens_solution_points`.
///
/// * `build_forward` - builds the forward solver, optionally from a given state.
/// * `build_adjoint_state` - builds the initial adjoint state from the adjoint
///   equations.
/// * `build_adjoint_from_state` - builds the adjoint solver from that state and
///   the equations.
/// * `use_replay_solver` - when true, a clone of the forward solver (taken after
///   the forward solve) is passed to `adjoint_equations`. When false, `None` is
///   passed.
///
/// # Panics
/// If any build or solve step fails, or any of the checks above does not hold.
pub fn test_solve_adjoint_with_single_reset_root<
    'a,
    Eqn,
    MethodF,
    MethodB,
    BuildForward,
    BuildAdjointState,
    BuildAdjointFromState,
>(
    build_forward: BuildForward,
    soln: &OdeSolverSolution<Eqn::V>,
    build_adjoint_state: BuildAdjointState,
    build_adjoint_from_state: BuildAdjointFromState,
    use_replay_solver: bool,
) where
    Eqn: OdeEquationsImplicitAdjoint + 'a,
    Eqn::M: DefaultSolver,
    Eqn::V: DefaultDenseMatrix,
    MethodF: OdeSolverMethod<'a, Eqn>,
    MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
    BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
    BuildAdjointState:
        Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
    BuildAdjointFromState:
        Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
{
    // solution_points[0] holds the expected second-root time and integrated output.
    let expected_out = &soln.solution_points[0];
    // Stop one time unit past the expected second root. The forward solve is
    // asserted below to continue past it and reach this stop time.
    let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();

    let mut forward_solver = build_forward(None).unwrap();
    let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
        .solve_with_checkpointing(forward_stop_time, None)
        .unwrap();
    assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
    assert!(
        checkpointers.len() >= 3,
        "expected checkpointing path to include the two reset events"
    );
    let problem = forward_solver.problem();
    // Copy of the solver as it is after the forward solve. It is optionally
    // handed to the adjoint equations below.
    let post_reset_solver = forward_solver.clone();
    // Segment 1 ends at the second root; its recorded root index seeds the adjoint.
    let post_reset_root_idx = checkpointers[1]
        .terminal_reset_root_idx()
        .expect("second reset segment should record its terminal root index");
    let final_forward_state = checkpointers[1].last_checkpoint().clone();
    let t_second_root = final_forward_state.as_ref().t;

    // Forward sanity checks at the second root: integrated output, then time.
    let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
    let out_norm = out_error
        .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
        .sqrt();
    assert!(
        out_norm < Eqn::T::from_f64(50.0).unwrap(),
        "forward integrated output mismatch at second root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
        final_forward_state.as_ref().g,
        expected_out.state,
    );
    let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
    assert!(
        (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
        "expected second root time ≈ {:?}, got {:?}",
        expected_out.t,
        t_second_root,
    );

    // The adjoint starts at the second root, so only the first two segments are kept.
    let adjoint_checkpointers = checkpointers.into_iter().take(2).collect::<Vec<_>>();

    // Negative path: without the reset-root metadata on segment 0, the backward
    // pass must report an error rather than silently continue.
    let mut missing_metadata_checkpointers = adjoint_checkpointers.clone();
    missing_metadata_checkpointers[0].clear_terminal_reset_root_idx();
    let missing_metadata_solver = use_replay_solver.then(|| post_reset_solver.clone());
    let mut missing_metadata_adjoint_eqn = problem.adjoint_equations(
        missing_metadata_checkpointers,
        missing_metadata_solver,
        None,
    );
    let mut missing_metadata_adjoint_state =
        build_adjoint_state(&mut missing_metadata_adjoint_eqn).unwrap();
    // Initialise the adjoint state for the terminal root event (root index,
    // forward state at the root, integrate_out flag).
    missing_metadata_adjoint_state
        .as_mut()
        .state_mut_adjoint_terminal_root(
            &problem.eqn,
            post_reset_root_idx,
            &final_forward_state,
            problem.integrate_out,
        )
        .unwrap();
    let missing_metadata_adjoint =
        build_adjoint_from_state(missing_metadata_adjoint_state, missing_metadata_adjoint_eqn)
            .unwrap();
    let missing_metadata_err =
        match missing_metadata_adjoint.solve_adjoint_backwards_pass(&[], &[]) {
            Ok(_) => panic!("expected missing reset metadata error"),
            Err(err) => err,
        };
    assert!(
        format!("{missing_metadata_err:?}").contains("Missing reset root metadata"),
        "expected missing reset metadata error, got {missing_metadata_err:?}",
    );

    // Positive path: same setup with intact metadata. No observation times or
    // dgdu terms are passed (empty slices).
    let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
    let mut adjoint_eqn =
        problem.adjoint_equations(adjoint_checkpointers, adjoint_solver, None);
    let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
    adjoint_state
        .as_mut()
        .state_mut_adjoint_terminal_root(
            &problem.eqn,
            post_reset_root_idx,
            &final_forward_state,
            problem.integrate_out,
        )
        .unwrap();
    let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
    let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();

    let t0 = problem.t0;
    let ctx = problem.context().clone();

    // Expected gradient: component 0 of the first point of each sensitivity series.
    let sens_points = soln.sens_solution_points.as_ref().unwrap();
    let expected_grad = Eqn::V::from_vec(
        sens_points
            .iter()
            .map(|pts| pts[0].state.get_index(0))
            .collect(),
        ctx.clone(),
    );
    let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
    // The backward pass must land on t0 up to a few ulps (absolute tolerance).
    let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
    assert!(
        (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
        "expected adjoint final time {:?}, got {:?}",
        t0,
        adjoint_state.as_ref().t,
    );
    // NOTE(review): assert_eq_norm arguments are presumably (expected, atol, rtol,
    // allowed norm factor) - confirm against its definition.
    adjoint_state.as_ref().sg[0].assert_eq_norm(
        &expected_grad,
        &atol,
        Eqn::T::from_f64(1e-6).unwrap(),
        Eqn::T::from_f64(60.0).unwrap(),
    );
}
1475
// Many generic builder parameters are inherent to this shared test driver.
#[allow(clippy::too_many_arguments)]
/// Adjoint check for a sum-of-squares objective on a problem whose root
/// function resets the state. The forward pass uses `solve_with_checkpointing`.
///
/// Steps:
/// 1. Produce forward samples at `times` with the helper
///    `solve_dense_with_single_reset_root`. There must be one column per
///    requested time.
/// 2. Turn the samples and `data` into per-time `dgdu` terms with
///    `dsum_squaresdp`.
/// 3. Run a checkpointed forward solve to `soln.solution_points[0].t + 1`. It
///    must stop with `TstopReached`, have at least three segments, and have its
///    second segment end within `30 * (rtol * |t| + atol[0])` of the expected
///    second-root time.
/// 4. Run the adjoint backward pass from the second root over the first two
///    segments, feeding `times` and `dgdu`. It must end within `10 * EPSILON`
///    of `t0`, and each `sg[j]` must match column `j` of `dgdp_check`.
///
/// * `build_forward` - builds the forward solver, optionally from a given state.
/// * `build_adjoint_state` / `build_adjoint_from_state` - build the initial
///   adjoint state and the adjoint solver from it.
/// * `use_replay_solver` - when true, a clone of the forward solver (taken after
///   the checkpointed solve) is passed to `adjoint_equations`.
/// * `dgdp_check` - expected gradients, one column per `sg` entry; its row count
///   sizes the absolute tolerance vector.
/// * `data`, `times` - observations and their times for the sum-of-squares
///   objective.
///
/// # Panics
/// If any build or solve step fails, or any of the checks above does not hold.
pub fn test_solve_adjoint_sum_squares_with_single_reset_root<
    'a,
    Eqn,
    MethodF,
    MethodB,
    BuildForward,
    BuildAdjointState,
    BuildAdjointFromState,
>(
    build_forward: BuildForward,
    soln: &OdeSolverSolution<Eqn::V>,
    build_adjoint_state: BuildAdjointState,
    build_adjoint_from_state: BuildAdjointFromState,
    use_replay_solver: bool,
    dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
    data: <Eqn::V as DefaultDenseMatrix>::M,
    times: &[Eqn::T],
) where
    Eqn: OdeEquationsImplicitAdjoint + 'a,
    Eqn::M: DefaultSolver,
    Eqn::V: DefaultDenseMatrix,
    MethodF: OdeSolverMethod<'a, Eqn>,
    MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
    BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
    BuildAdjointState:
        Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
    BuildAdjointFromState:
        Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
{
    let expected_out = &soln.solution_points[0];
    // Stop one time unit past the expected second root; the solve is asserted to reach it.
    let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
    // Forward samples at the observation times, used to build the objective's dgdu terms.
    let forwards_soln =
        solve_dense_with_single_reset_root::<Eqn, MethodF, _>(&build_forward, times);
    assert_eq!(
        forwards_soln.ncols(),
        times.len(),
        "expected stitched forward samples to cover every requested observation time",
    );
    // NOTE(review): dsum_squaresdp presumably yields d(sum of squares)/du per
    // observation time - confirm against its definition.
    let dgdu = dsum_squaresdp(&forwards_soln, &data);
    let dgdu_refs = dgdu.iter().collect::<Vec<_>>();

    let mut forward_solver = build_forward(None).unwrap();
    let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
        .solve_with_checkpointing(forward_stop_time, None)
        .unwrap();
    assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
    assert!(
        checkpointers.len() >= 3,
        "expected checkpointing path to include the two reset events"
    );
    let problem = forward_solver.problem();
    let post_reset_solver = forward_solver.clone();
    // Segment 1 ends at the second root: its root index and last checkpoint seed the adjoint.
    let post_reset_root_idx = checkpointers[1]
        .terminal_reset_root_idx()
        .expect("second reset segment should record its terminal root index");
    let final_forward_state = checkpointers[1].last_checkpoint().clone();
    let t_second_root = final_forward_state.as_ref().t;

    let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
    assert!(
        (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
        "expected second root time ≈ {:?}, got {:?}",
        expected_out.t,
        t_second_root,
    );

    // Only the two segments up to the second root take part in the adjoint.
    // The last argument is the number of dgdu terms.
    let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
    let mut adjoint_eqn = problem.adjoint_equations(
        checkpointers.into_iter().take(2).collect(),
        adjoint_solver,
        Some(dgdu.len()),
    );
    let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
    // Initialise the adjoint state for the terminal root event.
    adjoint_state
        .as_mut()
        .state_mut_adjoint_terminal_root(
            &problem.eqn,
            post_reset_root_idx,
            &final_forward_state,
            problem.integrate_out,
        )
        .unwrap();
    let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
    let (adjoint_state, _) = adjoint
        .solve_adjoint_backwards_pass(times, dgdu_refs.as_slice())
        .unwrap();

    let t0 = problem.t0;
    let ctx = problem.context().clone();

    let nparams = dgdp_check.nrows();
    let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
    // The backward pass must land on t0 up to a few ulps (absolute tolerance).
    let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
    assert!(
        (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
        "expected adjoint final time {:?}, got {:?}",
        t0,
        adjoint_state.as_ref().t,
    );
    // `j` indexes both `sg` and the columns of `dgdp_check`, hence the range loop.
    #[allow(clippy::needless_range_loop)]
    for j in 0..dgdp_check.ncols() {
        adjoint_state.as_ref().sg[j].assert_eq_norm(
            &dgdp_check.column(j).into_owned(),
            &atol,
            Eqn::T::from_f64(1e-6).unwrap(),
            Eqn::T::from_f64(260.0).unwrap(),
        );
    }
}
1586
/// Adjoint check on a reset problem where the forward pass is staged manually
/// with `solve_soln_with_checkpointing`.
///
/// Steps:
/// 1. Solve from `build_forward(None)` into a shared `Solution` (stop time
///    `soln.solution_points[0].t + 1`). The solve must stop at a root and leave
///    exactly one checkpointer, whose recorded terminal root index equals that
///    root.
/// 2. Apply `state_after_manual_reset` to the first solver. Rebuild the forward
///    solver from the resulting state and solve again into the same `Solution`
///    and checkpointer list. This solve must stop at a root and leave exactly two
///    checkpointers, the second recording the new root index.
/// 3. Check the terminal state's integrated output `g` (WRMS < 50) and time
///    against `soln.solution_points[0]`.
/// 4. Run the adjoint backward pass from the terminal root over both segments.
///    It must end within `10 * EPSILON` of `t0`, and `sg[0]` must match the first
///    component of each entry of `soln.sens_solution_points`.
///
/// * `build_forward` - builds the forward solver, optionally from a given state.
/// * `build_adjoint_state` / `build_adjoint_from_state` - build the initial
///   adjoint state and the adjoint solver from it.
/// * `use_replay_solver` - when true, a clone of the second-stage forward solver
///   is passed to `adjoint_equations`.
///
/// # Panics
/// If any build or solve step fails, or any of the checks above does not hold.
pub fn test_solve_soln_adjoint_with_single_reset_root<
    'a,
    Eqn,
    MethodF,
    MethodB,
    BuildForward,
    BuildAdjointState,
    BuildAdjointFromState,
>(
    build_forward: BuildForward,
    soln: &OdeSolverSolution<Eqn::V>,
    build_adjoint_state: BuildAdjointState,
    build_adjoint_from_state: BuildAdjointFromState,
    use_replay_solver: bool,
) where
    Eqn: OdeEquationsImplicitAdjoint + 'a,
    Eqn::M: DefaultSolver,
    Eqn::V: DefaultDenseMatrix,
    MethodF: OdeSolverMethod<'a, Eqn>,
    MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
    BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
    BuildAdjointState:
        Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
    BuildAdjointFromState:
        Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
{
    let expected_out = &soln.solution_points[0];
    let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
    // One Solution and one checkpointer list are shared across both forward stages.
    let mut forward_soln = Solution::<Eqn::V>::new(forward_stop_time);
    let mut checkpointers = Vec::new();

    // Stage 1: integrate from the initial state up to the first (reset) root.
    let first_forward_solver = build_forward(None)
        .unwrap()
        .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
        .unwrap();
    let first_root_idx = match forward_soln.stop_reason {
        Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
        Some(reason) => {
            panic!("expected first staged solve to stop at reset root, got {reason:?}")
        }
        None => panic!("first staged solve did not set a stop reason"),
    };
    assert_eq!(checkpointers.len(), 1);
    assert_eq!(
        checkpointers[0].terminal_reset_root_idx(),
        Some(first_root_idx)
    );

    // Stage 2: apply the reset by hand and integrate on to the terminal root.
    let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
    let terminal_forward_solver = build_forward(Some(state_after_reset))
        .unwrap()
        .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
        .unwrap();
    let terminal_root_idx = match forward_soln.stop_reason {
        Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
        Some(reason) => {
            panic!("expected second staged solve to stop at terminal root, got {reason:?}")
        }
        None => panic!("second staged solve did not set a stop reason"),
    };
    assert_eq!(checkpointers.len(), 2);
    assert_eq!(
        checkpointers[1].terminal_reset_root_idx(),
        Some(terminal_root_idx)
    );

    // Forward sanity checks at the terminal root: integrated output, then time.
    let problem = terminal_forward_solver.problem();
    let final_forward_state = terminal_forward_solver.state_clone();
    let t_second_root = final_forward_state.as_ref().t;
    let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
    let out_norm = out_error
        .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
        .sqrt();
    assert!(
        out_norm < Eqn::T::from_f64(50.0).unwrap(),
        "forward integrated output mismatch at terminal root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
        final_forward_state.as_ref().g,
        expected_out.state,
    );
    let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
    assert!(
        (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
        "expected terminal root time ≈ {:?}, got {:?}",
        expected_out.t,
        t_second_root,
    );

    // Adjoint from the terminal root over both staged segments. No observation
    // times or dgdu terms are passed (empty slices).
    let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
    let mut adjoint_eqn = problem.adjoint_equations(checkpointers, adjoint_solver, None);
    let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
    adjoint_state
        .as_mut()
        .state_mut_adjoint_terminal_root(
            &problem.eqn,
            terminal_root_idx,
            &final_forward_state,
            problem.integrate_out,
        )
        .unwrap();
    let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
    let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();

    let t0 = problem.t0;
    let ctx = problem.context().clone();
    // Expected gradient: component 0 of the first point of each sensitivity series.
    let sens_points = soln.sens_solution_points.as_ref().unwrap();
    let expected_grad = Eqn::V::from_vec(
        sens_points
            .iter()
            .map(|pts| pts[0].state.get_index(0))
            .collect(),
        ctx.clone(),
    );
    let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
    // The backward pass must land on t0 up to a few ulps (absolute tolerance).
    let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
    assert!(
        (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
        "expected adjoint final time {:?}, got {:?}",
        t0,
        adjoint_state.as_ref().t,
    );
    adjoint_state.as_ref().sg[0].assert_eq_norm(
        &expected_grad,
        &atol,
        Eqn::T::from_f64(1e-6).unwrap(),
        Eqn::T::from_f64(60.0).unwrap(),
    );
}
1714
// Many generic builder parameters are inherent to this shared test driver.
#[allow(clippy::too_many_arguments)]
/// Adjoint check for a sum-of-squares objective on a reset problem, with the
/// forward pass staged manually via `Solution`.
///
/// The forward pass runs in three stages:
/// 1. A dense `Solution` over `times` is filled by
///    `solve_soln_with_checkpointing` from `build_forward(None)`. This must stop
///    at a root and leave one checkpointer recording that root index.
/// 2. After `state_after_manual_reset`, a solver built from the post-reset state
///    finishes the dense `Solution` with plain `solve_soln`. The `Solution` must
///    then be complete with stop reason `TstopReached`.
/// 3. A second solver built from the same post-reset state runs
///    `solve_soln_with_checkpointing` into a separate `Solution` with stop time
///    `soln.solution_points[0].t + 1`. It must stop at a terminal root within
///    `30 * (rtol * |t| + atol[0])` of the expected time and leave two
///    checkpointers in total.
///
/// The dense samples and `data` then give `dgdu` via `dsum_squaresdp`. The
/// adjoint backward pass from the terminal root over both segments must end
/// within `10 * EPSILON` of `t0`, and each `sg[j]` must match column `j` of
/// `dgdp_check`.
///
/// # Panics
/// If any build or solve step fails, or any of the checks above does not hold.
pub fn test_solve_soln_adjoint_sum_squares_with_single_reset_root<
    'a,
    Eqn,
    MethodF,
    MethodB,
    BuildForward,
    BuildAdjointState,
    BuildAdjointFromState,
>(
    build_forward: BuildForward,
    soln: &OdeSolverSolution<Eqn::V>,
    build_adjoint_state: BuildAdjointState,
    build_adjoint_from_state: BuildAdjointFromState,
    use_replay_solver: bool,
    dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
    data: <Eqn::V as DefaultDenseMatrix>::M,
    times: &[Eqn::T],
) where
    Eqn: OdeEquationsImplicitAdjoint + 'a,
    Eqn::M: DefaultSolver,
    Eqn::V: DefaultDenseMatrix,
    MethodF: OdeSolverMethod<'a, Eqn>,
    MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
    BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
    BuildAdjointState:
        Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
    BuildAdjointFromState:
        Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
{
    let expected_out = &soln.solution_points[0];
    let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
    // Dense solution sampled at the observation times; shared by stages 1 and 2.
    let mut forward_soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
    let mut checkpointers = Vec::new();

    // Stage 1: integrate from the initial state up to the first (reset) root.
    let first_forward_solver = build_forward(None)
        .unwrap()
        .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
        .unwrap();
    let first_root_idx = match forward_soln.stop_reason {
        Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
        Some(reason) => {
            panic!("expected first staged solve to stop at reset root, got {reason:?}")
        }
        None => panic!("first staged solve did not set a stop reason"),
    };
    assert_eq!(checkpointers.len(), 1);
    assert_eq!(
        checkpointers[0].terminal_reset_root_idx(),
        Some(first_root_idx)
    );

    // Stage 2: after the manual reset, finish the dense samples without checkpointing.
    let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
    build_forward(Some(state_after_reset.clone()))
        .unwrap()
        .solve_soln(&mut forward_soln)
        .unwrap();
    assert!(forward_soln.is_complete());
    assert_eq!(
        forward_soln.stop_reason,
        Some(OdeSolverStopReason::TstopReached)
    );

    // Stage 3: from the same post-reset state, checkpoint up to the terminal
    // root using a separate (non-dense) Solution.
    let mut terminal_soln = Solution::<Eqn::V>::new(forward_stop_time);
    let terminal_forward_solver = build_forward(Some(state_after_reset))
        .unwrap()
        .solve_soln_with_checkpointing(&mut terminal_soln, &mut checkpointers, None)
        .unwrap();
    let terminal_root_idx = match terminal_soln.stop_reason {
        Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
        Some(reason) => {
            panic!("expected terminal staged solve to stop at root, got {reason:?}")
        }
        None => panic!("terminal staged solve did not set a stop reason"),
    };
    assert_eq!(checkpointers.len(), 2);
    assert_eq!(
        checkpointers.last().unwrap().terminal_reset_root_idx(),
        Some(terminal_root_idx)
    );

    // NOTE(review): dsum_squaresdp presumably yields d(sum of squares)/du per
    // observation time from the dense samples - confirm against its definition.
    let dgdu_eval = dsum_squaresdp(&forward_soln.ys, &data);
    let dgdu_eval_refs = dgdu_eval.iter().collect::<Vec<_>>();
    let problem = terminal_forward_solver.problem();
    let final_forward_state = terminal_forward_solver.state_clone();
    let t_second_root = final_forward_state.as_ref().t;
    let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
    assert!(
        (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
        "expected terminal root time ≈ {:?}, got {:?}",
        expected_out.t,
        t_second_root,
    );

    // Adjoint from the terminal root over both checkpointed segments.
    // The last argument is the number of dgdu terms.
    let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
    let mut adjoint_eqn =
        problem.adjoint_equations(checkpointers, adjoint_solver, Some(dgdu_eval_refs.len()));
    let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
    adjoint_state
        .as_mut()
        .state_mut_adjoint_terminal_root(
            &problem.eqn,
            terminal_root_idx,
            &final_forward_state,
            problem.integrate_out,
        )
        .unwrap();
    let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
    let (adjoint_state, _) = adjoint
        .solve_adjoint_backwards_pass(times, dgdu_eval_refs.as_slice())
        .unwrap();

    let t0 = problem.t0;
    let ctx = problem.context().clone();
    let nparams = dgdp_check.nrows();
    let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
    // The backward pass must land on t0 up to a few ulps (absolute tolerance).
    let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
    assert!(
        (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
        "expected adjoint final time {:?}, got {:?}",
        t0,
        adjoint_state.as_ref().t,
    );
    // `j` indexes both `sg` and the columns of `dgdp_check`, hence the range loop.
    #[allow(clippy::needless_range_loop)]
    for j in 0..dgdp_check.ncols() {
        adjoint_state.as_ref().sg[j].assert_eq_norm(
            &dgdp_check.column(j).into_owned(),
            &atol,
            Eqn::T::from_f64(1e-6).unwrap(),
            Eqn::T::from_f64(260.0).unwrap(),
        );
    }
}
1848}