1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod no_checkpointing_solver;
11pub mod problem;
12pub mod runge_kutta;
13pub mod sde;
14pub mod sdirk;
15pub mod sdirk_state;
16pub mod sensitivities;
17pub mod solution;
18pub mod state;
19pub mod tableau;
20
21use serde::Serialize;
22use std::fmt::Display;
23
24use crate::ode_solver::jacobian_update::SolverState;
25
/// Counters collected while an ODE solver runs.
///
/// `Default` starts every counter at zero. The `Display` impl in this file
/// prints the struct as pretty-printed JSON (via `Serialize`), in field order.
#[derive(Clone, Debug, Serialize, Default)]
pub struct OdeSolverStatistics {
    /// Total number of linear-solver setups, whatever the cause; incremented by
    /// `record_linear_solver_setup` on every call.
    pub number_of_linear_solver_setups: usize,
    /// Number of solver steps. NOTE(review): maintained outside this file —
    /// confirm whether rejected steps are included.
    pub number_of_steps: usize,
    /// Number of local error-test failures. NOTE(review): maintained outside
    /// this file.
    pub number_of_error_test_failures: usize,
    /// Number of nonlinear-solver iterations. NOTE(review): maintained outside
    /// this file.
    pub number_of_nonlinear_solver_iterations: usize,
    /// Number of nonlinear-solver failures. NOTE(review): maintained outside
    /// this file.
    pub number_of_nonlinear_solver_fails: usize,
    /// Linear-solver setups recorded with cause `SolverState::Checkpoint`.
    pub number_of_linear_solver_setups_from_checkpoint: usize,
    /// Linear-solver setups recorded with cause `SolverState::FirstConvergenceFail`.
    pub number_of_linear_solver_setups_from_first_convergence_fail: usize,
    /// Linear-solver setups recorded with cause `SolverState::SecondConvergenceFail`.
    pub number_of_linear_solver_setups_from_second_convergence_fail: usize,
    /// Linear-solver setups recorded with cause `SolverState::ErrorTestFail`.
    pub number_of_linear_solver_setups_from_error_test_fail: usize,
    /// Linear-solver setups recorded with cause `SolverState::StepSuccess`.
    pub number_of_linear_solver_setups_from_step_success: usize,
}
50
51impl OdeSolverStatistics {
52 pub(crate) fn record_linear_solver_setup(&mut self, cause: SolverState) {
54 self.number_of_linear_solver_setups += 1;
55 match cause {
56 SolverState::Checkpoint => self.number_of_linear_solver_setups_from_checkpoint += 1,
57 SolverState::FirstConvergenceFail => {
58 self.number_of_linear_solver_setups_from_first_convergence_fail += 1
59 }
60 SolverState::SecondConvergenceFail => {
61 self.number_of_linear_solver_setups_from_second_convergence_fail += 1
62 }
63 SolverState::ErrorTestFail => {
64 self.number_of_linear_solver_setups_from_error_test_fail += 1
65 }
66 SolverState::StepSuccess => self.number_of_linear_solver_setups_from_step_success += 1,
67 }
68 }
69}
70
71impl Display for OdeSolverStatistics {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 write!(f, "{}", serde_json::to_string_pretty(self).unwrap())
74 }
75}
76
77#[cfg(test)]
78mod tests {
79 use std::rc::Rc;
80
81 use self::problem::OdeSolverSolution;
82
83 use super::*;
84 use crate::error::{DiffsolError, OdeSolverError};
85 use crate::matrix::Matrix;
86 use crate::ode_solver::sensitivities::SensitivitiesOdeSolverMethod;
87 use crate::ode_solver::solution::Solution;
88 use crate::op::unit::UnitCallable;
89 use crate::op::ParameterisedOp;
90 use crate::Scalar;
91 use crate::{
92 op::OpStatistics, AdjointEquations, AdjointOdeSolverMethod, Context, DenseMatrix,
93 MatrixCommon, MatrixRef, NonLinearOp, NonLinearOpJacobian, OdeEquations,
94 OdeEquationsImplicit, OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens,
95 OdeEquationsRef, OdeSolverConfig, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
96 OdeSolverStopReason, Scale, VectorRef, VectorView, VectorViewMut,
97 };
98 use crate::{
99 ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver,
100 NonLinearOpSens, Op, Vector,
101 };
102 use num_traits::{FromPrimitive, One, Signed, ToPrimitive, Zero};
103
104 pub fn test_ode_solver<'a, M, Eqn, Method>(
105 method: &mut Method,
106 solution: OdeSolverSolution<M::V>,
107 override_tol: Option<M::T>,
108 use_tstop: bool,
109 solve_for_sensitivities: bool,
110 ) -> Eqn::V
111 where
112 M: Matrix,
113 Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
114 Method: OdeSolverMethod<'a, Eqn>,
115 {
116 let have_root = method.problem().eqn.root().is_some();
117 for (i, point) in solution.solution_points.iter().enumerate() {
118 let (soln, sens_soln) = if use_tstop {
119 match method.set_stop_time(point.t) {
120 Ok(_) => loop {
121 match method.step() {
122 Ok(OdeSolverStopReason::RootFound(_, _)) => {
123 assert!(have_root);
124 return method.state().y.clone();
125 }
126 Ok(OdeSolverStopReason::TstopReached) => {
127 break (method.state().y.clone(), method.state().s.to_vec());
128 }
129 _ => (),
130 }
131 },
132 Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
133 }
134 } else {
135 while method.state().t.abs() < point.t.abs() {
136 if let OdeSolverStopReason::RootFound(t, _) = method.step().unwrap() {
137 assert!(have_root);
138 return method.interpolate(t).unwrap();
139 }
140 }
141 let soln = method.interpolate(point.t).unwrap();
142 let sens_soln = method.interpolate_sens(point.t).unwrap();
143 (soln, sens_soln)
144 };
145 let soln = if let Some(out) = method.problem().eqn.out() {
146 out.call(&soln, point.t)
147 } else {
148 soln
149 };
150 assert_eq!(
151 soln.len(),
152 point.state.len(),
153 "soln.len() != point.state.len()"
154 );
155 if let Some(override_tol) = override_tol {
156 soln.assert_eq_st(&point.state, override_tol);
157 } else {
158 let (rtol, atol) = if method.problem().eqn.out().is_some() {
159 (solution.rtol, &solution.atol)
161 } else {
162 (method.problem().rtol, &method.problem().atol)
163 };
164 let error = soln.clone() - &point.state;
165 let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
166 assert!(
167 error_norm < M::T::from_f64(20.0).unwrap(),
168 "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
169 error_norm,
170 point.t,
171 soln,
172 point.state
173 );
174 if solve_for_sensitivities {
175 if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
176 for (j, sens_points) in sens_soln_points.iter().enumerate() {
177 let sens_point = &sens_points[i];
178 let sens_soln = &sens_soln[j];
179 let error = sens_soln.clone() - &sens_point.state;
180 let error_norm =
181 error.squared_norm(&sens_point.state, atol, rtol).sqrt();
182 assert!(
183 error_norm < M::T::from_f64(29.0).unwrap(),
184 "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
185 point.t,
186 sens_point.state
187 );
188 }
189 }
190 }
191 }
192 }
193 method.state().y.clone()
194 }
195
196 pub fn setup_test_adjoint<'a, LS, Eqn>(
197 problem: &'a mut OdeSolverProblem<Eqn>,
198 soln: OdeSolverSolution<Eqn::V>,
199 ) -> <Eqn::V as DefaultDenseMatrix>::M
200 where
201 Eqn: OdeEquationsImplicitAdjoint + 'a,
202 LS: LinearSolver<Eqn::M>,
203 Eqn::M: DefaultSolver,
204 Eqn::V: DefaultDenseMatrix,
205 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
206 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
207 {
208 let nparams = problem.eqn.nparams();
209 let nout = problem.eqn.nout();
210 let ctx = problem.eqn.context();
211 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
212 let final_time = soln.solution_points.last().unwrap().t;
213 let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
214 problem.eqn.get_params(&mut p_0);
215 let nbatch = p_0.context().nbatch();
216 let h_base = Eqn::T::from_f64(1e-6).unwrap();
217 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
218 h.axpy(h_base, &p_0, Eqn::T::one());
219 let p_base = p_0.clone();
220 for i in 0..nparams {
221 for b in 0..nbatch {
222 let base = p_base.get_batch(b).get_index(i);
223 let hb = h.get_batch(b).get_index(i);
224 p_0.get_batch_mut(b).set_index(i, base + hb);
225 }
226 problem.eqn.set_params(&p_0);
227 let g_pos = {
228 let mut s = problem.bdf::<LS>().unwrap();
229 s.solve(final_time).unwrap();
230 s.state().g.clone()
231 };
232
233 for b in 0..nbatch {
234 let base = p_base.get_batch(b).get_index(i);
235 let hb = h.get_batch(b).get_index(i);
236 p_0.get_batch_mut(b).set_index(i, base - hb);
237 }
238 problem.eqn.set_params(&p_0);
239 let g_neg = {
240 let mut s = problem.bdf::<LS>().unwrap();
241 s.solve(final_time).unwrap();
242 s.state().g.clone()
243 };
244 for b in 0..nbatch {
245 let base = p_base.get_batch(b).get_index(i);
246 p_0.get_batch_mut(b).set_index(i, base);
247 }
248
249 let delta_full = g_pos - g_neg;
250 for b in 0..nbatch {
251 let hb = h.get_batch(b).get_index(i);
252 let denom = Eqn::T::from_f64(2.0).unwrap() * hb;
253 for j in 0..nout {
254 let delta_val = delta_full.get_batch(b).get_index(j) / denom;
255 dgdp.set_index(i, b * nout + j, delta_val);
256 }
257 }
258 }
259 problem.eqn.set_params(&p_base);
260 dgdp
261 }
262
263 pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
266 where
267 DM: DenseMatrix,
268 {
269 let nbatch = soln.context().nbatch();
270 let mut ret = DM::V::zeros(2, soln.context().clone());
271 for j in 0..soln.ncols() {
272 let soln_j = soln.column(j);
273 let data_j = data.column(j);
274 let delta = soln_j - data_j;
275 for b in 0..nbatch {
276 let delta_b = delta.get_batch(b).into_owned();
277 let norm2 = delta_b.norm(2);
278 let norm4 = delta_b.norm(4);
279 let cur0 = ret.get_batch(b).get_index(0);
280 let cur1 = ret.get_batch(b).get_index(1);
281 ret.get_batch_mut(b).set_index(0, cur0 + norm2 * norm2);
282 let norm4_sq = norm4 * norm4;
283 ret.get_batch_mut(b)
284 .set_index(1, cur1 + norm4_sq * norm4_sq);
285 }
286 }
287 ret
288 }
289
290 pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
293 where
294 DM: DenseMatrix,
295 {
296 let delta = soln.clone() - data;
297 let mut delta3 = delta.clone();
298 for j in 0..delta3.ncols() {
299 let delta_col = delta.column(j).into_owned();
300
301 let mut delta3_col = delta_col.clone();
302 delta3_col.component_mul_assign(&delta_col);
303 delta3_col.component_mul_assign(&delta_col);
304
305 delta3.column_mut(j).copy_from(&delta3_col);
306 }
307 let ret = vec![
308 delta * Scale(DM::T::from_f64(2.).unwrap()),
309 delta3 * Scale(DM::T::from_f64(4.).unwrap()),
310 ];
311 ret
312 }
313
314 pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
315 problem: &'a mut OdeSolverProblem<Eqn>,
316 times: &[Eqn::T],
317 ) -> (
318 <Eqn::V as DefaultDenseMatrix>::M,
319 <Eqn::V as DefaultDenseMatrix>::M,
320 )
321 where
322 Eqn: OdeEquationsImplicitAdjoint + 'a,
323 LS: LinearSolver<Eqn::M>,
324 Eqn::M: DefaultSolver,
325 Eqn::V: DefaultDenseMatrix,
326 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
327 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
328 {
329 let nparams = problem.eqn.nparams();
330 let nout = 2;
331 let ctx = problem.eqn.context();
332 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
333
334 let mut p_0 = ctx.vector_zeros(nparams);
335 problem.eqn.get_params(&mut p_0);
336 let nbatch = p_0.context().nbatch();
337 let h_base = Eqn::T::from_f64(1e-6).unwrap();
338 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
339 h.axpy(h_base, &p_0, Eqn::T::one());
340 let mut p_data = p_0.clone();
341 p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
342 let p_base = p_0.clone();
343
344 problem.eqn.set_params(&p_data);
345 let data = {
346 let mut s = problem.bdf::<LS>().unwrap();
347 s.solve_dense(times).unwrap().0
348 };
349
350 for i in 0..nparams {
351 for b in 0..nbatch {
352 let base = p_base.get_batch(b).get_index(i);
353 let hb = h.get_batch(b).get_index(i);
354 p_0.get_batch_mut(b).set_index(i, base + hb);
355 }
356 problem.eqn.set_params(&p_0);
357 let g_pos = {
358 let mut s = problem.bdf::<LS>().unwrap();
359 let v = s.solve_dense(times).unwrap().0;
360 sum_squares(&v, &data)
361 };
362
363 for b in 0..nbatch {
364 let base = p_base.get_batch(b).get_index(i);
365 let hb = h.get_batch(b).get_index(i);
366 p_0.get_batch_mut(b).set_index(i, base - hb);
367 }
368 problem.eqn.set_params(&p_0);
369 let g_neg = {
370 let mut s = problem.bdf::<LS>().unwrap();
371 let v = s.solve_dense(times).unwrap().0;
372 sum_squares(&v, &data)
373 };
374
375 for b in 0..nbatch {
376 let base = p_base.get_batch(b).get_index(i);
377 p_0.get_batch_mut(b).set_index(i, base);
378 }
379
380 let delta_full = g_pos - g_neg;
381 for b in 0..nbatch {
382 let hb = h.get_batch(b).get_index(i);
383 let denom = Eqn::T::from_f64(2.0).unwrap() * hb;
384 for j in 0..nout {
385 let delta_val = delta_full.get_batch(b).get_index(j) / denom;
386 dgdp.set_index(i, b * nout + j, delta_val);
387 }
388 }
389 }
390 problem.eqn.set_params(&p_base);
391 (dgdp, data)
392 }
393
394 pub fn single_reset_root_discrete_times<T: Scalar>(t_stop: T) -> Vec<T> {
395 let t_root = t_stop / T::from_f64(2.0).unwrap();
396 [0.25, 0.75, 1.25, 1.75]
397 .into_iter()
398 .map(|factor| t_root * T::from_f64(factor).unwrap())
399 .collect()
400 }
401
402 fn solve_dense_with_single_reset_root<'a, Eqn, Method, BuildForward>(
403 build_forward: BuildForward,
404 times: &[Eqn::T],
405 ) -> <Eqn::V as DefaultDenseMatrix>::M
406 where
407 Eqn: OdeEquationsImplicitAdjoint + 'a,
408 Eqn::M: DefaultSolver,
409 Eqn::V: DefaultDenseMatrix,
410 Method: OdeSolverMethod<'a, Eqn>,
411 BuildForward: Fn(Option<Method::State>) -> Result<Method, DiffsolError>,
412 {
413 let mut soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
414 let first_forward_solver = build_forward(None).unwrap().solve_soln(&mut soln).unwrap();
415 match soln.stop_reason {
416 Some(OdeSolverStopReason::RootFound(_, 0)) => {}
417 Some(OdeSolverStopReason::RootFound(_, idx)) => {
418 panic!("expected first solve_soln() segment to stop on root 0, got root {idx}")
419 }
420 Some(OdeSolverStopReason::TstopReached) => {
421 panic!("expected first solve_soln() segment to stop on the interior root")
422 }
423 Some(OdeSolverStopReason::InternalTimestep) | None => {
424 panic!("first solve_soln() segment did not finish with a terminal stop reason")
425 }
426 }
427
428 let mut state_after_reset = first_forward_solver.state_clone();
429 {
430 let problem = first_forward_solver.problem();
431 state_after_reset
432 .as_mut()
433 .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
434 .unwrap();
435 }
436
437 build_forward(Some(state_after_reset))
438 .unwrap()
439 .solve_soln(&mut soln)
440 .unwrap();
441 assert!(
442 soln.is_complete(),
443 "expected stitched solve_soln() output to cover all requested observation times",
444 );
445 soln.ys
446 }
447
448 fn state_after_manual_reset<'a, Eqn, Method>(solver: &Method) -> Method::State
449 where
450 Eqn: OdeEquationsImplicitAdjoint + 'a,
451 Eqn::M: DefaultSolver,
452 Method: OdeSolverMethod<'a, Eqn>,
453 {
454 let mut state_after_reset = solver.state_clone();
455 {
456 let problem = solver.problem();
457 state_after_reset
458 .as_mut()
459 .apply_reset_with_mass::<<Eqn::M as DefaultSolver>::LS, _>(problem)
460 .unwrap();
461 }
462 state_after_reset
463 }
464
    /// Finite-difference reference gradients of the two `sum_squares`
    /// objectives for a problem with one interior reset root.
    ///
    /// Like `setup_test_adjoint_sum_squares`, but every solve goes through
    /// `solve_dense_with_single_reset_root` (BDF, two stitched segments), and the
    /// finite-difference step base is `1e-10` instead of `1e-6`.
    ///
    /// NOTE(review): unlike `setup_test_adjoint_sum_squares`, parameters are
    /// perturbed with plain `get_index`/`set_index` and results are written to
    /// column `j` only — confirm this helper is not meant for batched contexts.
    ///
    /// Returns `(dgdp, data)`; the original parameters are restored on return.
    pub fn setup_test_adjoint_sum_squares_with_single_reset_root<'a, LS, Eqn>(
        problem: &'a mut OdeSolverProblem<Eqn>,
        times: &[Eqn::T],
    ) -> (
        <Eqn::V as DefaultDenseMatrix>::M,
        <Eqn::V as DefaultDenseMatrix>::M,
    )
    where
        Eqn: OdeEquationsImplicitAdjoint + 'a,
        LS: LinearSolver<Eqn::M>,
        Eqn::M: DefaultSolver,
        Eqn::V: DefaultDenseMatrix,
        for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
        for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
    {
        let nparams = problem.eqn.nparams();
        // Two objectives: the squared 2-norm and 4th-power 4-norm from `sum_squares`.
        let nout = 2;
        let ctx = problem.eqn.context();
        let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());

        let mut p_0 = ctx.vector_zeros(nparams);
        problem.eqn.get_params(&mut p_0);
        let h_base = Eqn::T::from_f64(1e-10).unwrap();
        // Relative step h = h_base * p_0 + h_base, assuming axpy(a, x, b) is
        // `self = a * x + b * self` (TODO confirm).
        let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
        h.axpy(h_base, &p_0, Eqn::T::one());
        // Data parameters: 0.1 * p_0 + p_0 under the same axpy assumption.
        let mut p_data = p_0.clone();
        p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
        let p_base = p_0.clone();

        problem.eqn.set_params(&p_data);
        // The closure is rebuilt for each solve so its borrow of `problem` ends
        // before the next `set_params` call.
        let data = solve_dense_with_single_reset_root::<Eqn, _, _>(
            |state| match state {
                Some(state) => problem.bdf_solver(state),
                None => problem.bdf::<LS>(),
            },
            times,
        );

        for i in 0..nparams {
            // Forward perturbation of parameter i.
            p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
            problem.eqn.set_params(&p_0);
            let g_pos = {
                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
                    |state| match state {
                        Some(state) => problem.bdf_solver(state),
                        None => problem.bdf::<LS>(),
                    },
                    times,
                );
                sum_squares(&v, &data)
            };

            // Backward perturbation of parameter i.
            p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
            problem.eqn.set_params(&p_0);
            let g_neg = {
                let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
                    |state| match state {
                        Some(state) => problem.bdf_solver(state),
                        None => problem.bdf::<LS>(),
                    },
                    times,
                );
                sum_squares(&v, &data)
            };

            // Restore parameter i.
            p_0.set_index(i, p_base.get_index(i));

            // Central difference (g(+h) - g(-h)) / (2h).
            let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
            for j in 0..nout {
                dgdp.set_index(i, j, delta.get_index(j));
            }
        }
        problem.eqn.set_params(&p_base);
        (dgdp, data)
    }
540
541 pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
542 backwards_solver: SolverB,
543 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
544 forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
545 data: <Eqn::V as DefaultDenseMatrix>::M,
546 times: &[Eqn::T],
547 ) where
548 SolverF: OdeSolverMethod<'a, Eqn>,
549 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
550 Eqn: OdeEquationsImplicitAdjoint + 'a,
551 Eqn::V: DefaultDenseMatrix,
552 Eqn::M: DefaultSolver,
553 {
554 let nparams = dgdp_check.nrows();
555 let dgdu = dsum_squaresdp(&forwards_soln, &data);
556
557 let atol = Eqn::V::from_element(
558 nparams,
559 Eqn::T::from_f64(1e-6).unwrap(),
560 data.context().clone(),
561 );
562 let rtol = Eqn::T::from_f64(1e-6).unwrap();
563 let (state, _) = backwards_solver
564 .solve_adjoint_backwards_pass(times, dgdu.iter().collect::<Vec<_>>().as_slice())
565 .unwrap();
566 let gs_adj = state.into_common().sg;
567 #[allow(clippy::needless_range_loop)]
568 for j in 0..dgdp_check.ncols() {
569 gs_adj[j].assert_eq_norm(
570 &dgdp_check.column(j).into_owned(),
571 &atol,
572 rtol,
573 Eqn::T::from_f64(260.).unwrap(),
574 );
575 }
576 }
577
578 pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
579 backwards_solver: SolverB,
580 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
581 ) where
582 SolverF: OdeSolverMethod<'a, Eqn>,
583 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
584 Eqn: OdeEquationsImplicitAdjoint + 'a,
585 Eqn::V: DefaultDenseMatrix,
586 Eqn::M: DefaultSolver,
587 {
588 let nout = backwards_solver.problem().eqn.nout();
589 let atol = Eqn::V::from_element(
590 nout,
591 Eqn::T::from_f64(1e-6).unwrap(),
592 dgdp_check.context().clone(),
593 );
594 let rtol = Eqn::T::from_f64(1e-6).unwrap();
595 let (state, _) = backwards_solver
596 .solve_adjoint_backwards_pass(&[], &[])
597 .unwrap();
598 let gs_adj = state.into_common().sg;
599 #[allow(clippy::needless_range_loop)]
600 for j in 0..dgdp_check.ncols() {
601 gs_adj[j].assert_eq_norm(
602 &dgdp_check.column(j).into_owned(),
603 &atol,
604 rtol,
605 Eqn::T::from_f64(40.).unwrap(),
606 );
607 }
608 }
609
610 pub struct TestEqnInit<M: Matrix> {
611 ctx: M::C,
612 }
613
614 impl<M: Matrix> Op for TestEqnInit<M> {
615 type T = M::T;
616 type V = M::V;
617 type M = M;
618 type C = M::C;
619
620 fn nout(&self) -> usize {
621 1
622 }
623 fn nparams(&self) -> usize {
624 1
625 }
626 fn nstates(&self) -> usize {
627 1
628 }
629 fn context(&self) -> &Self::C {
630 &self.ctx
631 }
632 }
633
634 impl<M: Matrix> ConstantOp for TestEqnInit<M> {
635 fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
636 y.fill(M::T::one());
637 }
638 }
639
640 impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
641 fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
642 sens.fill(M::T::zero());
643 }
644 }
645
646 pub struct TestEqnRhs<M: Matrix> {
647 ctx: M::C,
648 }
649
650 impl<M: Matrix> Op for TestEqnRhs<M> {
651 type T = M::T;
652 type V = M::V;
653 type M = M;
654 type C = M::C;
655
656 fn nout(&self) -> usize {
657 1
658 }
659 fn nparams(&self) -> usize {
660 1
661 }
662 fn nstates(&self) -> usize {
663 1
664 }
665 fn context(&self) -> &Self::C {
666 &self.ctx
667 }
668 }
669
670 impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
671 fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
672 y.fill(M::T::zero());
673 }
674 }
675
676 impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
677 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
678 y.fill(M::T::zero());
679 }
680 }
681
682 impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
683 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
684 sens.fill(M::T::zero());
685 }
686 }
687
688 pub struct TestEqnOut<M: Matrix> {
689 ctx: M::C,
690 }
691
692 impl<M: Matrix> Op for TestEqnOut<M> {
693 type T = M::T;
694 type V = M::V;
695 type M = M;
696 type C = M::C;
697
698 fn nout(&self) -> usize {
699 1
700 }
701 fn nparams(&self) -> usize {
702 1
703 }
704 fn nstates(&self) -> usize {
705 1
706 }
707 fn context(&self) -> &Self::C {
708 &self.ctx
709 }
710 }
711
712 impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
713 fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
714 y.copy_from(x);
715 }
716 }
717
718 impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
719 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
720 y.copy_from(v);
721 }
722 }
723
724 impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
725 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
726 sens.fill(M::T::zero());
727 }
728 }
729
730 pub struct TestEqn<M: Matrix> {
731 rhs: Rc<TestEqnRhs<M>>,
732 init: Rc<TestEqnInit<M>>,
733 out: Rc<TestEqnOut<M>>,
734 ctx: M::C,
735 }
736
737 impl<M: Matrix> TestEqn<M> {
738 pub fn new() -> Self {
739 let ctx = M::C::default();
740 Self {
741 rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
742 init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
743 out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
744 ctx,
745 }
746 }
747 }
748
749 impl<M: Matrix> Op for TestEqn<M> {
750 type T = M::T;
751 type V = M::V;
752 type M = M;
753 type C = M::C;
754 fn nout(&self) -> usize {
755 1
756 }
757 fn nparams(&self) -> usize {
758 1
759 }
760 fn nstates(&self) -> usize {
761 1
762 }
763 fn statistics(&self) -> crate::op::OpStatistics {
764 OpStatistics::default()
765 }
766 fn context(&self) -> &Self::C {
767 &self.ctx
768 }
769 }
770
771 impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
772 type Rhs = &'a TestEqnRhs<M>;
773 type Mass = ParameterisedOp<'a, UnitCallable<M>>;
774 type Root = ParameterisedOp<'a, UnitCallable<M>>;
775 type Init = &'a TestEqnInit<M>;
776 type Out = &'a TestEqnOut<M>;
777 type Reset = ParameterisedOp<'a, UnitCallable<M>>;
778 }
779
780 impl<M: Matrix> OdeEquations for TestEqn<M> {
781 fn rhs(&self) -> &TestEqnRhs<M> {
782 &self.rhs
783 }
784
785 fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
786 None
787 }
788
789 fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
790 None
791 }
792
793 fn init(&self) -> &TestEqnInit<M> {
794 &self.init
795 }
796
797 fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
798 Some(&self.out)
799 }
800 fn set_params(&mut self, _p: &Self::V) {
801 unimplemented!()
802 }
803 fn get_params(&self, _p: &mut Self::V) {
804 unimplemented!()
805 }
806 }
807
808 pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
809 let eqn = TestEqn::<M>::new();
810 let atol = eqn
811 .context()
812 .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
813 OdeSolverProblem::new(
814 eqn,
815 M::T::from_f64(1e-6).unwrap(),
816 atol,
817 None,
818 None,
819 None,
820 None,
821 None,
822 None,
823 M::T::zero(),
824 M::T::one(),
825 integrate_out,
826 Default::default(),
827 Default::default(),
828 )
829 .unwrap()
830 }
831
    /// Exercises the interpolation API of a solver built on `TestEqn`.
    ///
    /// Checks, in order:
    /// * interpolation at the initial checkpoint time, and rejection of a time far
    ///   in the future;
    /// * after one step: success at the current time and mid-step, rejection
    ///   beyond the current time, and `interpolate_out` succeeding only when the
    ///   problem integrates its output;
    /// * rejection of wrong-length output buffers;
    /// * after overwriting the state through `state_mut`, rejection of times
    ///   inside the previous step.
    ///
    /// When the solver is not integrating sensitivities, `interpolate_sens` is
    /// expected to succeed even for out-of-range times (see the asserts below).
    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
        let state = s.checkpoint();
        let integrating_sens = !s.state().s.is_empty();
        let integrating_out = s.problem().integrate_out;
        let t0 = state.as_ref().t;
        let t1 = t0 + M::T::from_f64(1e6).unwrap();
        // Before stepping: the checkpoint time interpolates to the checkpoint state,
        // and a far-future time is rejected.
        s.interpolate(t0)
            .unwrap()
            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
        assert!(s.interpolate(t1).is_err());
        assert!(s.interpolate_out(t1).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(t1).is_err());
        } else {
            assert!(s.interpolate_sens(t0).is_ok());
        }
        s.step().unwrap();
        let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
        // After one step: the current time and the step midpoint are valid.
        assert!(s.interpolate(s.state().t).is_ok());
        assert!(s.interpolate(tmid).is_ok());
        if integrating_out {
            assert!(s.interpolate_out(s.state().t).is_ok());
        } else {
            assert!(s.interpolate_out(s.state().t).is_err());
        }
        assert!(s.interpolate_sens(s.state().t).is_ok());
        // Times beyond the current time are rejected.
        assert!(s.interpolate(s.state().t + t1).is_err());
        assert!(s.interpolate_out(s.state().t + t1).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(s.state().t + t1).is_err());
        } else {
            assert!(s.interpolate_sens(s.state().t + t1).is_ok());
        }

        // Wrong-length buffers: `TestEqn` has one state and one output, so length 2 is wrong.
        let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
        assert!(s
            .interpolate_inplace(s.state().t, &mut y_wrong_length)
            .is_err());
        let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
        assert!(s
            .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
            .is_err());
        // Two sensitivity vectors for a one-parameter system: expected to fail in
        // both modes.
        let mut s_wrong_length = vec![
            M::V::zeros(1, s.problem().context().clone()),
            M::V::zeros(1, s.problem().context().clone()),
        ];
        assert!(s
            .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
            .is_err());
        // One sensitivity vector of the wrong length when integrating
        // sensitivities; an empty buffer (accepted) otherwise.
        let mut s_wrong_vec_length = if integrating_sens {
            vec![M::V::zeros(2, s.problem().context().clone())]
        } else {
            vec![]
        };
        if integrating_sens {
            assert!(s
                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
                .is_err());
        } else {
            assert!(s
                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
                .is_ok());
        }

        // After mutating the state through `state_mut`, the current time is still
        // interpolable but times inside the previous step are rejected.
        s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
        assert!(s.interpolate(s.state().t).is_ok());
        if integrating_out {
            assert!(s.interpolate_out(s.state().t).is_ok());
        }
        if integrating_sens {
            assert!(s.interpolate_sens(s.state().t).is_ok());
        }
        assert!(s.interpolate(tmid).is_err());
        assert!(s.interpolate_out(tmid).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(tmid).is_err());
        } else {
            assert!(s.interpolate_sens(tmid).is_ok());
        }
    }
912
913 pub fn test_interpolate_dy<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(
914 mut s: Method,
915 ) {
916 let t_future = s.state().t + M::T::from_f64(1e6).unwrap();
918 assert!(s.interpolate_dy(t_future).is_err());
919
920 let t0 = s.state().t;
921 s.step().unwrap();
922 let t1 = s.state().t;
923 let dt = t1 - t0;
924 let tmid = t0 + dt / M::T::from_f64(2.0).unwrap();
925
926 let mut dy_wrong = M::V::zeros(2, s.problem().context().clone());
928 assert!(s.interpolate_dy_inplace(t1, &mut dy_wrong).is_err());
929
930 assert!(s.interpolate_dy(t1 + M::T::from_f64(1e6).unwrap()).is_err());
932
933 let eps = dt.abs() * M::T::from_f64(1e-5).unwrap();
935 let y_plus = s.interpolate(tmid + eps).unwrap();
936 let y_minus = s.interpolate(tmid - eps).unwrap();
937 let fd_dy = (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps));
938 let dy = s.interpolate_dy(tmid).unwrap();
939 dy.assert_eq_norm(
940 &fd_dy,
941 &s.problem().atol,
942 s.problem().rtol,
943 M::T::from_f64(1e3).unwrap(),
944 );
945
946 let t1 = s.state().t;
948 s.step().unwrap();
949 let t2 = s.state().t;
950 let dt2 = t2 - t1;
951 let tmid2 = t1 + dt2 / M::T::from_f64(2.0).unwrap();
952 let eps2 = dt2.abs() * M::T::from_f64(1e-5).unwrap();
953 let y_plus = s.interpolate(tmid2 + eps2).unwrap();
954 let y_minus = s.interpolate(tmid2 - eps2).unwrap();
955 let fd_dy2 =
956 (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps2));
957 let dy2 = s.interpolate_dy(tmid2).unwrap();
958 dy2.assert_eq_norm(
959 &fd_dy2,
960 &s.problem().atol,
961 s.problem().rtol,
962 M::T::from_f64(1e3).unwrap(),
963 );
964 }
965
966 pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
967 mut s: Method,
968 ) {
969 *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
970 assert_eq!(
971 *s.config().as_base_ref().minimum_timestep,
972 Eqn::T::from_f64(1.0e8).unwrap()
973 );
974 *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
976
977 let mut failed = false;
978 for _ in 0..10 {
979 if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
980 s.step()
981 {
982 failed = true;
983 break;
984 }
985 }
986 assert!(failed);
987 }
988
989 pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
990 let state = s.checkpoint();
991 let state2 = s.state();
992 state2
993 .y
994 .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
995 s.state_mut()
996 .y
997 .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
998 assert_eq!(
999 s.state_mut().y.get_index(0),
1000 M::T::from_f64(std::f64::consts::PI).unwrap()
1001 );
1002 }
1003
    /// Bouncing-ball problem written in DiffSL (Cranelift JIT backend).
    ///
    /// States `x` (height, starting at `h = 10`) and `v` (velocity, starting at
    /// 0) obey `dx/dt = v` and `dv/dt = -g` with `g = 9.81`. The `stop` entry `x`
    /// is the root function, so a root is reported when the ball reaches the ground.
    ///
    /// # Panics
    /// Panics if the DiffSL source fails to build.
    #[cfg(feature = "diffsl-cranelift")]
    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
        crate::OdeBuilder::<M>::new()
            .build_from_diffsl(
                "
            g { 9.81 } h { 10.0 }
            u_i {
                x = h,
                v = 0,
            }
            F_i {
                v,
                -g,
            }
            stop {
                x,
            }
        ",
            )
            .unwrap()
    }
1026
    /// Steps a bouncing-ball solver to its first root, applies the bounce by
    /// hand, and then records up to three further steps.
    ///
    /// Assumes the state is `[x, v]` (height, velocity), as built by
    /// `test_ball_bounce_problem` — TODO confirm for other callers.
    ///
    /// Returns `(x, v, t)`: height, velocity and time after each post-bounce step.
    ///
    /// # Panics
    /// Panics on any solver error, or if a root is found during the post-bounce
    /// steps.
    #[cfg(feature = "diffsl-cranelift")]
    pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
    where
        M: crate::MatrixHost<T = f64>,
        M: DefaultSolver<T = f64>,
        M::V: DefaultDenseMatrix<T = f64>,
        Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
    {
        // Coefficient of restitution: fraction of the speed kept at the bounce.
        let e = 0.8;

        let final_time = 2.5;

        solver.set_stop_time(final_time).unwrap();
        // Step until the first root (the bounce) or the stop time.
        loop {
            match solver.step() {
                Ok(OdeSolverStopReason::InternalTimestep) => (),
                Ok(OdeSolverStopReason::RootFound(t, _)) => {
                    let mut y = solver.interpolate(t).unwrap();

                    // Reverse the velocity (index 1) and scale it by `e`.
                    y.set_index(1, y.get_index(1) * -e);

                    // Keep the height (index 0) strictly positive. NOTE(review):
                    // presumably so the `x` root does not fire again immediately.
                    y.set_index(0, y.get_index(0).max(f64::EPSILON));

                    // Write the reset state back at the root time. dy[0] is set to the
                    // new velocity, mirroring dx/dt = v.
                    solver.state_mut().y.copy_from(&y);
                    solver.state_mut().dy.set_index(0, y.get_index(1));
                    *solver.state_mut().t = t;

                    break;
                }
                Ok(OdeSolverStopReason::TstopReached) => break,
                Err(_) => panic!("unexpected solver error"),
            }
        }
        // Record up to three post-bounce steps; no further root is expected.
        let mut x = vec![];
        let mut v = vec![];
        let mut t = vec![];
        for _ in 0..3 {
            let ret = solver.step();
            x.push(solver.state().y.get_index(0));
            v.push(solver.state().y.get_index(1));
            t.push(solver.state().t);
            match ret {
                Ok(OdeSolverStopReason::InternalTimestep) => (),
                Ok(OdeSolverStopReason::RootFound(_, _)) => {
                    panic!("should be an internal timestep but found a root")
                }
                Ok(OdeSolverStopReason::TstopReached) => break,
                _ => panic!("should be an internal timestep"),
            }
        }
        (x, v, t)
    }
1085
1086 pub fn test_checkpointing<'a, M, Method, Eqn>(
1087 soln: OdeSolverSolution<M::V>,
1088 mut solver1: Method,
1089 mut solver2: Method,
1090 ) where
1091 M: Matrix + DefaultSolver,
1092 Method: OdeSolverMethod<'a, Eqn>,
1093 Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
1094 {
1095 let half_i = soln.solution_points.len() / 2;
1096 let half_t = soln.solution_points[half_i].t;
1097 while solver1.state().t <= half_t {
1098 solver1.step().unwrap();
1099 }
1100 let checkpoint = solver1.checkpoint();
1101 let checkpoint_t = checkpoint.as_ref().t;
1102 solver2.set_state(checkpoint);
1103
1104 for point in soln.solution_points.iter().skip(half_i + 1) {
1106 if point.t < checkpoint_t {
1108 continue;
1109 }
1110 while solver2.state().t < point.t {
1111 solver1.step().unwrap();
1112 solver2.step().unwrap();
1113 let time_error = (solver1.state().t - solver2.state().t).abs()
1114 / (solver1.state().t.abs() * solver1.problem().rtol
1115 + solver1.problem().atol.get_index(0));
1116 assert!(
1117 time_error < M::T::from_f64(20.0).unwrap(),
1118 "time_error: {} at t = {}",
1119 time_error,
1120 solver1.state().t
1121 );
1122 solver1.state().y.assert_eq_norm(
1123 solver2.state().y,
1124 &solver1.problem().atol,
1125 solver1.problem().rtol,
1126 M::T::from_f64(20.0).unwrap(),
1127 );
1128 }
1129 let soln = solver1.interpolate(point.t).unwrap();
1130 soln.assert_eq_norm(
1131 &point.state,
1132 &solver1.problem().atol,
1133 solver1.problem().rtol,
1134 M::T::from_f64(15.0).unwrap(),
1135 );
1136 let soln = solver2.interpolate(point.t).unwrap();
1137 soln.assert_eq_norm(
1138 &point.state,
1139 &solver1.problem().atol,
1140 solver1.problem().rtol,
1141 M::T::from_f64(15.0).unwrap(),
1142 );
1143 }
1144 }
1145
1146 pub fn test_state_mut_on_problem<'a, Eqn, Method>(
1147 mut s: Method,
1148 soln: OdeSolverSolution<Eqn::V>,
1149 ) where
1150 Eqn: OdeEquationsImplicit + 'a,
1151 Eqn::M: DefaultSolver,
1152 Method: OdeSolverMethod<'a, Eqn>,
1153 Eqn::V: DefaultDenseMatrix,
1154 {
1155 let state = s.checkpoint();
1157 s.solve(Eqn::T::one()).unwrap();
1158
1159 s.state_mut().y.copy_from(state.as_ref().y);
1161 s.state_mut().dy.copy_from(state.as_ref().dy);
1162 *s.state_mut().t = state.as_ref().t;
1163
1164 for point in soln.solution_points.iter() {
1166 while s.state().t < point.t {
1167 s.step().unwrap();
1168 }
1169 let soln = s.interpolate(point.t).unwrap();
1170 let error = soln.clone() - &point.state;
1171 let error_norm = error
1172 .squared_norm(&error, &s.problem().atol, s.problem().rtol)
1173 .sqrt();
1174 assert!(
1175 error_norm < Eqn::T::from_f64(19.0).unwrap(),
1176 "error_norm: {} at t = {}",
1177 error_norm,
1178 point.t
1179 );
1180 }
1181 }
1182
1183 pub fn test_root_found_index<'a, Eqn, Method>(
1192 mut solver: Method,
1193 soln: &OdeSolverSolution<Eqn::V>,
1194 expected_root_index: usize,
1195 tol: Eqn::T,
1196 ) where
1197 Eqn: OdeEquations + 'a,
1198 Method: OdeSolverMethod<'a, Eqn>,
1199 {
1200 let t_root_expected = soln.solution_points[0].t;
1201 solver
1202 .set_stop_time(Eqn::T::from_f64(100.0).unwrap())
1203 .unwrap();
1204 loop {
1205 match solver.step().unwrap() {
1206 OdeSolverStopReason::RootFound(t, index) => {
1208 assert_eq!(
1209 index, expected_root_index,
1210 "expected root index {expected_root_index} but got {index}",
1211 );
1212 assert!(
1213 (t - t_root_expected).abs() < tol,
1214 "expected t ≈ {t_root_expected:?}, got {t:?}",
1215 );
1216 break;
1217 }
1218 OdeSolverStopReason::TstopReached => {
1219 panic!("reached tstop without finding a root")
1220 }
1221 OdeSolverStopReason::InternalTimestep => {}
1222 }
1223 }
1224 }
1225
1226 pub fn test_solve_with_reset<'a, Eqn, Method>(
1229 mut solver: Method,
1230 soln: &OdeSolverSolution<Eqn::V>,
1231 final_time: Eqn::T,
1232 ) where
1233 Eqn: OdeEquationsImplicit + 'a,
1234 Eqn::M: DefaultSolver,
1235 Eqn::V: DefaultDenseMatrix,
1236 Method: OdeSolverMethod<'a, Eqn>,
1237 {
1238 let (ys, ts, stop_reason) = solver.solve(final_time).unwrap();
1239 assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1240 let t_last = *ts.last().unwrap();
1241 let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1242 assert!(
1243 (t_last - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1244 "expected solve() to reach final_time ≈ {:?}, got {:?}",
1245 final_time,
1246 t_last,
1247 );
1248 assert!(
1249 (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1250 "expected solver state at final_time ≈ {:?}, got {:?}",
1251 final_time,
1252 solver.state().t,
1253 );
1254
1255 let expected = &soln.solution_points[0];
1256 let root_time_tol = soln.rtol * expected.t.abs() + soln.atol.get_index(0);
1257 let root_col = ts
1258 .iter()
1259 .position(|&t| (t - expected.t).abs() < Eqn::T::from_f64(30.0).unwrap() * root_time_tol)
1260 .expect("expected solve() output to include the second-root/reset time");
1261 let root_expected = Eqn::V::from_element(
1262 expected.state.len(),
1263 Eqn::T::from_f64(0.4).unwrap(),
1264 expected.state.context().clone(),
1265 );
1266 let root_state = ys.column(root_col).into_owned();
1267 let root_error = root_state - &root_expected;
1268 let root_error_norm = root_error
1269 .squared_norm(&root_expected, &soln.atol, soln.rtol)
1270 .sqrt();
1271 let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1272 assert!(
1273 root_error_norm < error_threshold,
1274 "expected reset state y=0.4 at second-root time; WRMS error norm {root_error_norm:?} ≥ {error_threshold:?}",
1275 );
1276
1277 let reset_value = Eqn::T::from_f64(0.4).unwrap();
1278 let reset_tol = Eqn::T::from_f64(30.0).unwrap()
1279 * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1280 let last_reset_col = (0..ts.len())
1281 .rev()
1282 .find(|&i| (ys.get_index(0, i) - reset_value).abs() < reset_tol)
1283 .expect("expected solve() output to include at least one reset state");
1284 let final_time_f64 = final_time.to_f64().unwrap();
1285 let last_reset_time_f64 = ts[last_reset_col].to_f64().unwrap();
1286 let expected_final_value =
1287 Eqn::T::from_f64(0.4 * (-0.1 * (final_time_f64 - last_reset_time_f64)).exp()).unwrap();
1288 let expected_final = Eqn::V::from_element(
1289 expected.state.len(),
1290 expected_final_value,
1291 expected.state.context().clone(),
1292 );
1293 let final_state = ys.column(ts.len() - 1).into_owned();
1294 let final_error = final_state - &expected_final;
1295 let final_error_norm = final_error
1296 .squared_norm(&expected_final, &soln.atol, soln.rtol)
1297 .sqrt();
1298 assert!(
1299 final_error_norm < error_threshold,
1300 "final state mismatch after automatic reset continuation: WRMS error norm {final_error_norm:?} ≥ {error_threshold:?}",
1301 );
1302 }
1303
1304 pub fn test_solve_dense_with_reset<'a, Eqn, Method>(
1307 mut solver: Method,
1308 soln: &OdeSolverSolution<Eqn::V>,
1309 ) where
1310 Eqn: OdeEquationsImplicit + 'a,
1311 Eqn::M: DefaultSolver,
1312 Eqn::V: DefaultDenseMatrix,
1313 Method: OdeSolverMethod<'a, Eqn>,
1314 {
1315 let t_stop = soln.solution_points[0].t;
1316 let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1317 let mut probe_solver = solver.clone();
1318 let (probe_ys, probe_ts, probe_stop_reason) = probe_solver.solve(final_time).unwrap();
1319 assert_eq!(probe_stop_reason, OdeSolverStopReason::TstopReached);
1320
1321 let reset_time_tol =
1322 Eqn::T::from_f64(30.0).unwrap() * (soln.rtol * t_stop.abs() + soln.atol.get_index(0));
1323 let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1324 let reset_value = Eqn::T::from_f64(0.4).unwrap();
1325 let reset_value_tol = Eqn::T::from_f64(30.0).unwrap()
1326 * (soln.rtol * reset_value.abs() + soln.atol.get_index(0));
1327 let reset_col = (0..probe_ts.len())
1328 .find(|&i| {
1329 (probe_ts[i] - t_stop).abs() < reset_time_tol
1330 && (probe_ys.get_index(0, i) - reset_value).abs() < reset_value_tol
1331 })
1332 .expect("expected solve() probe output to contain the second-root reset state");
1333 let t_event = probe_ts[reset_col];
1334 let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, final_time];
1335
1336 let (ret, stop_reason) = solver.solve_dense(&t_eval).unwrap();
1337 assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1338 assert!(
1339 ret.ncols() == t_eval.len(),
1340 "expected solve_dense() to fill all requested evaluation times"
1341 );
1342 let time_tol = soln.rtol * final_time.abs() + soln.atol.get_index(0);
1343 assert!(
1344 (solver.state().t - final_time).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1345 "expected solver state at final_time ≈ {:?}, got {:?}",
1346 final_time,
1347 solver.state().t,
1348 );
1349
1350 let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1351 let pre_reset_state = ret.column(1).into_owned();
1352 let pre_reset_error = pre_reset_state - &soln.solution_points[0].state;
1353 let pre_reset_error_norm = pre_reset_error
1354 .squared_norm(&soln.solution_points[0].state, &soln.atol, soln.rtol)
1355 .sqrt();
1356 assert!(
1357 pre_reset_error_norm < error_threshold,
1358 "expected pre-reset state at event time; WRMS norm {pre_reset_error_norm:?} >= {error_threshold:?}",
1359 );
1360
1361 let expected_post_reset_value =
1362 reset_value * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1363 let expected_post_reset = Eqn::V::from_element(
1364 soln.solution_points[0].state.len(),
1365 expected_post_reset_value,
1366 soln.solution_points[0].state.context().clone(),
1367 );
1368 let post_reset_state = ret.column(2).into_owned();
1369 let post_reset_error = post_reset_state - &expected_post_reset;
1370 let post_reset_error_norm = post_reset_error
1371 .squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1372 .sqrt();
1373 assert!(
1374 post_reset_error_norm < error_threshold,
1375 "expected reset state just after event time; WRMS norm {post_reset_error_norm:?} >= {error_threshold:?}",
1376 );
1377 }
1378
1379 pub fn test_solve_dense_sensitivities_with_reset<'a, Eqn, Method>(
1382 mut solver: Method,
1383 soln: &OdeSolverSolution<Eqn::V>,
1384 ) where
1385 Eqn: OdeEquationsImplicitSens + 'a,
1386 Eqn::V: DefaultDenseMatrix,
1387 Eqn::M: DefaultSolver,
1388 Method: SensitivitiesOdeSolverMethod<'a, Eqn>,
1389 {
1390 let t_stop = soln.solution_points[0].t;
1391 let t_event = Eqn::T::from_f64(10.0 * (5.0_f64 / 3.0_f64).ln()).unwrap();
1392
1393 let post_event_dt = Eqn::T::from_f64(1e-6).unwrap();
1394 let t_eval = vec![Eqn::T::zero(), t_event, t_event + post_event_dt, t_stop];
1395 let (ret, ret_sens, stop_reason) = solver.solve_dense_sensitivities(&t_eval).unwrap();
1396 assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
1397 assert_eq!(ret.ncols(), t_eval.len());
1398 for ret_sens_j in &ret_sens {
1399 assert_eq!(ret_sens_j.ncols(), t_eval.len());
1400 }
1401
1402 let error_threshold = Eqn::T::from_f64(100.0).unwrap();
1403 let ctx = soln.solution_points[0].state.context().clone();
1404 let nstates = soln.solution_points[0].state.len();
1405
1406 let post_reset_y = Eqn::T::from_f64(2.6).unwrap()
1407 * (-Eqn::T::from_f64(0.1).unwrap() * post_event_dt).exp();
1408 let post_reset_t = t_event + post_event_dt;
1409 let expected_post_reset = Eqn::V::from_element(nstates, post_reset_y, ctx.clone());
1410 let expected_post_reset_sk =
1411 Eqn::V::from_element(nstates, -post_reset_y * post_reset_t, ctx.clone());
1412 let expected_post_reset_sy0 = Eqn::V::from_element(nstates, post_reset_y, ctx);
1413
1414 let col = 2;
1415 let ey = ret.column(col).into_owned() - &expected_post_reset;
1416 let esk = ret_sens[0].column(col).into_owned() - &expected_post_reset_sk;
1417 let esy0 = ret_sens[1].column(col).into_owned() - &expected_post_reset_sy0;
1418 let norm = (ey.squared_norm(&expected_post_reset, &soln.atol, soln.rtol)
1419 + esk.squared_norm(&expected_post_reset_sk, &soln.atol, soln.rtol)
1420 + esy0.squared_norm(&expected_post_reset_sy0, &soln.atol, soln.rtol))
1421 .sqrt();
1422 assert!(
1423 norm < error_threshold,
1424 "dense sensitivity mismatch just after reset; combined WRMS {norm:?} >= {error_threshold:?}",
1425 );
1426 }
1427
    /// Checks the adjoint backward pass for an integrated-output objective that
    /// ends at the second reset root.
    ///
    /// Steps:
    /// 1. Build a forward solver with `build_forward(None)`. Run
    ///    `solve_with_checkpointing` to one time unit past
    ///    `soln.solution_points[0].t`, expecting `TstopReached` and at least
    ///    three checkpointers.
    /// 2. Use the last checkpoint of `checkpointers[1]` as the terminal forward
    ///    state. Assert that its integrated output `g` (WRMS < 50) and its time
    ///    (within 30x the time tolerance) match `soln.solution_points[0]`.
    /// 3. Negative check: clear the terminal reset root index on the first
    ///    checkpointer and assert that the backward pass fails with an error
    ///    whose debug text contains "Missing reset root metadata".
    /// 4. Run the backward pass over the first two checkpointers, seeded with
    ///    `state_mut_adjoint_terminal_root`. Assert that it ends at `t0` and that
    ///    `sg[0]` matches the first component of each entry of
    ///    `soln.sens_solution_points`.
    ///
    /// `build_forward` builds a forward solver, optionally from a given state.
    /// `build_adjoint_state` builds the initial adjoint state from the adjoint
    /// equations. `build_adjoint_from_state` builds the backward solver.
    /// When `use_replay_solver` is true, a clone of the finished forward solver
    /// is passed to `adjoint_equations`.
    pub fn test_solve_adjoint_with_single_reset_root<
        'a,
        Eqn,
        MethodF,
        MethodB,
        BuildForward,
        BuildAdjointState,
        BuildAdjointFromState,
    >(
        build_forward: BuildForward,
        soln: &OdeSolverSolution<Eqn::V>,
        build_adjoint_state: BuildAdjointState,
        build_adjoint_from_state: BuildAdjointFromState,
        use_replay_solver: bool,
    ) where
        Eqn: OdeEquationsImplicitAdjoint + 'a,
        Eqn::M: DefaultSolver,
        Eqn::V: DefaultDenseMatrix,
        MethodF: OdeSolverMethod<'a, Eqn>,
        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
        BuildAdjointState:
            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
        BuildAdjointFromState:
            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
    {
        let expected_out = &soln.solution_points[0];
        // Integrate past the expected second root so the solve must cross it.
        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();

        let mut forward_solver = build_forward(None).unwrap();
        let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
            .solve_with_checkpointing(forward_stop_time, None)
            .unwrap();
        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
        assert!(
            checkpointers.len() >= 3,
            "expected checkpointing path to include the two reset events"
        );
        let problem = forward_solver.problem();
        // Clone of the solver after the whole forward solve; optionally handed to
        // adjoint_equations below.
        let post_reset_solver = forward_solver.clone();
        let post_reset_root_idx = checkpointers[1]
            .terminal_reset_root_idx()
            .expect("second reset segment should record its terminal root index");
        // The objective is taken at the end of the second checkpointer (the second root).
        let final_forward_state = checkpointers[1].last_checkpoint().clone();
        let t_second_root = final_forward_state.as_ref().t;

        let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
        let out_norm = out_error
            .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
            .sqrt();
        assert!(
            out_norm < Eqn::T::from_f64(50.0).unwrap(),
            "forward integrated output mismatch at second root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
            final_forward_state.as_ref().g,
            expected_out.state,
        );
        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
        assert!(
            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
            "expected second root time ≈ {:?}, got {:?}",
            expected_out.t,
            t_second_root,
        );

        // Only the first two checkpointers (up to the second root) feed the adjoint.
        let adjoint_checkpointers = checkpointers.into_iter().take(2).collect::<Vec<_>>();

        // Negative path: with the first checkpointer's terminal reset root index
        // cleared, the backward pass is expected to fail.
        let mut missing_metadata_checkpointers = adjoint_checkpointers.clone();
        missing_metadata_checkpointers[0].clear_terminal_reset_root_idx();
        let missing_metadata_solver = use_replay_solver.then(|| post_reset_solver.clone());
        let mut missing_metadata_adjoint_eqn = problem.adjoint_equations(
            missing_metadata_checkpointers,
            missing_metadata_solver,
            None,
        );
        let mut missing_metadata_adjoint_state =
            build_adjoint_state(&mut missing_metadata_adjoint_eqn).unwrap();
        missing_metadata_adjoint_state
            .as_mut()
            .state_mut_adjoint_terminal_root(
                &problem.eqn,
                post_reset_root_idx,
                &final_forward_state,
                problem.integrate_out,
            )
            .unwrap();
        let missing_metadata_adjoint =
            build_adjoint_from_state(missing_metadata_adjoint_state, missing_metadata_adjoint_eqn)
                .unwrap();
        let missing_metadata_err =
            match missing_metadata_adjoint.solve_adjoint_backwards_pass(&[], &[]) {
                Ok(_) => panic!("expected missing reset metadata error"),
                Err(err) => err,
            };
        assert!(
            format!("{missing_metadata_err:?}").contains("Missing reset root metadata"),
            "expected missing reset metadata error, got {missing_metadata_err:?}",
        );

        // Positive path: the same setup with intact metadata must succeed.
        let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
        let mut adjoint_eqn =
            problem.adjoint_equations(adjoint_checkpointers, adjoint_solver, None);
        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
        adjoint_state
            .as_mut()
            .state_mut_adjoint_terminal_root(
                &problem.eqn,
                post_reset_root_idx,
                &final_forward_state,
                problem.integrate_out,
            )
            .unwrap();
        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
        // No observation times / dgdu terms: only the integrated output contributes.
        let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();

        let t0 = problem.t0;
        let ctx = problem.context().clone();

        // Expected gradient: first component of the first point of each sensitivity series.
        let sens_points = soln.sens_solution_points.as_ref().unwrap();
        let expected_grad = Eqn::V::from_vec(
            sens_points
                .iter()
                .map(|pts| pts[0].state.get_index(0))
                .collect(),
            ctx.clone(),
        );
        let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
        // The backward pass must integrate all the way back to the initial time.
        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
        assert!(
            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
            "expected adjoint final time {:?}, got {:?}",
            t0,
            adjoint_state.as_ref().t,
        );
        adjoint_state.as_ref().sg[0].assert_eq_norm(
            &expected_grad,
            &atol,
            Eqn::T::from_f64(1e-6).unwrap(),
            Eqn::T::from_f64(60.0).unwrap(),
        );
    }
1570
    /// Checks the adjoint backward pass for a sum-of-squares objective on a
    /// problem with a reset root, using `solve_with_checkpointing`.
    ///
    /// Steps:
    /// 1. Build a stitched dense forward solution at `times` with
    ///    `solve_dense_with_single_reset_root`. Turn it into per-time gradient
    ///    terms `dgdu` with `dsum_squaresdp` against `data`.
    /// 2. Run `solve_with_checkpointing` to one time unit past
    ///    `soln.solution_points[0].t`. Take the last checkpoint of
    ///    `checkpointers[1]` as the terminal state and assert that its time
    ///    matches the expected second root.
    /// 3. Seed the adjoint with `state_mut_adjoint_terminal_root`, run the
    ///    backward pass over the first two checkpointers with `times` and
    ///    `dgdu`, and assert that it ends at `t0`.
    /// 4. Assert that each `sg[j]` matches column `j` of `dgdp_check`.
    ///
    /// `build_forward`, `build_adjoint_state`, `build_adjoint_from_state` and
    /// `use_replay_solver` play the same roles as in the integrated-output
    /// variant of this test.
    #[allow(clippy::too_many_arguments)]
    pub fn test_solve_adjoint_sum_squares_with_single_reset_root<
        'a,
        Eqn,
        MethodF,
        MethodB,
        BuildForward,
        BuildAdjointState,
        BuildAdjointFromState,
    >(
        build_forward: BuildForward,
        soln: &OdeSolverSolution<Eqn::V>,
        build_adjoint_state: BuildAdjointState,
        build_adjoint_from_state: BuildAdjointFromState,
        use_replay_solver: bool,
        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
        data: <Eqn::V as DefaultDenseMatrix>::M,
        times: &[Eqn::T],
    ) where
        Eqn: OdeEquationsImplicitAdjoint + 'a,
        Eqn::M: DefaultSolver,
        Eqn::V: DefaultDenseMatrix,
        MethodF: OdeSolverMethod<'a, Eqn>,
        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
        BuildAdjointState:
            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
        BuildAdjointFromState:
            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
    {
        let expected_out = &soln.solution_points[0];
        // Integrate past the expected second root so the solve must cross it.
        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
        // Forward samples at every observation time, used to build the dgdu terms.
        let forwards_soln =
            solve_dense_with_single_reset_root::<Eqn, MethodF, _>(&build_forward, times);
        assert_eq!(
            forwards_soln.ncols(),
            times.len(),
            "expected stitched forward samples to cover every requested observation time",
        );
        let dgdu = dsum_squaresdp(&forwards_soln, &data);
        let dgdu_refs = dgdu.iter().collect::<Vec<_>>();

        let mut forward_solver = build_forward(None).unwrap();
        let (checkpointers, _forward_y, _forward_t, stop_reason) = forward_solver
            .solve_with_checkpointing(forward_stop_time, None)
            .unwrap();
        assert_eq!(stop_reason, OdeSolverStopReason::TstopReached);
        assert!(
            checkpointers.len() >= 3,
            "expected checkpointing path to include the two reset events"
        );
        let problem = forward_solver.problem();
        // Clone of the solver after the whole forward solve; optionally handed to
        // adjoint_equations below.
        let post_reset_solver = forward_solver.clone();
        let post_reset_root_idx = checkpointers[1]
            .terminal_reset_root_idx()
            .expect("second reset segment should record its terminal root index");
        // The adjoint is seeded at the end of the second checkpointer (the second root).
        let final_forward_state = checkpointers[1].last_checkpoint().clone();
        let t_second_root = final_forward_state.as_ref().t;

        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
        assert!(
            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
            "expected second root time ≈ {:?}, got {:?}",
            expected_out.t,
            t_second_root,
        );

        let adjoint_solver = use_replay_solver.then_some(post_reset_solver);
        // Only the first two checkpointers (up to the second root) feed the adjoint;
        // one extra adjoint output per dgdu term.
        let mut adjoint_eqn = problem.adjoint_equations(
            checkpointers.into_iter().take(2).collect(),
            adjoint_solver,
            Some(dgdu.len()),
        );
        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
        adjoint_state
            .as_mut()
            .state_mut_adjoint_terminal_root(
                &problem.eqn,
                post_reset_root_idx,
                &final_forward_state,
                problem.integrate_out,
            )
            .unwrap();
        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
        let (adjoint_state, _) = adjoint
            .solve_adjoint_backwards_pass(times, dgdu_refs.as_slice())
            .unwrap();

        let t0 = problem.t0;
        let ctx = problem.context().clone();

        let nparams = dgdp_check.nrows();
        let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
        // The backward pass must integrate all the way back to the initial time.
        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
        assert!(
            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
            "expected adjoint final time {:?}, got {:?}",
            t0,
            adjoint_state.as_ref().t,
        );
        // `j` indexes both `sg` and the columns of `dgdp_check`.
        #[allow(clippy::needless_range_loop)]
        for j in 0..dgdp_check.ncols() {
            adjoint_state.as_ref().sg[j].assert_eq_norm(
                &dgdp_check.column(j).into_owned(),
                &atol,
                Eqn::T::from_f64(1e-6).unwrap(),
                Eqn::T::from_f64(260.0).unwrap(),
            );
        }
    }
1681
    /// Checks the adjoint backward pass for an integrated-output objective when
    /// the forward solve is staged manually with `solve_soln_with_checkpointing`.
    ///
    /// Steps:
    /// 1. The first staged solve must stop at a root and record that root as
    ///    the first checkpointer's terminal reset root index.
    /// 2. `state_after_manual_reset` builds the post-reset state, and a second
    ///    solver built from it continues into the same `forward_soln` and
    ///    `checkpointers`. That solve must stop at another root, recorded on the
    ///    second checkpointer.
    /// 3. The terminal solver's state must match `soln.solution_points[0]` in
    ///    integrated output `g` (WRMS < 50) and in time.
    /// 4. The adjoint is seeded with `state_mut_adjoint_terminal_root` and run
    ///    backwards over both checkpointers. It must end at `t0`, and `sg[0]`
    ///    must match the first component of each entry of
    ///    `soln.sens_solution_points`.
    ///
    /// When `use_replay_solver` is true, a clone of the terminal forward solver
    /// is passed to `adjoint_equations`.
    pub fn test_solve_soln_adjoint_with_single_reset_root<
        'a,
        Eqn,
        MethodF,
        MethodB,
        BuildForward,
        BuildAdjointState,
        BuildAdjointFromState,
    >(
        build_forward: BuildForward,
        soln: &OdeSolverSolution<Eqn::V>,
        build_adjoint_state: BuildAdjointState,
        build_adjoint_from_state: BuildAdjointFromState,
        use_replay_solver: bool,
    ) where
        Eqn: OdeEquationsImplicitAdjoint + 'a,
        Eqn::M: DefaultSolver,
        Eqn::V: DefaultDenseMatrix,
        MethodF: OdeSolverMethod<'a, Eqn>,
        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
        BuildAdjointState:
            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
        BuildAdjointFromState:
            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
    {
        let expected_out = &soln.solution_points[0];
        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
        // Both staged solves append to this solution and to `checkpointers`.
        let mut forward_soln = Solution::<Eqn::V>::new(forward_stop_time);
        let mut checkpointers = Vec::new();

        // Stage 1: solve from the initial condition up to the first (reset) root.
        let first_forward_solver = build_forward(None)
            .unwrap()
            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
            .unwrap();
        let first_root_idx = match forward_soln.stop_reason {
            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
            Some(reason) => {
                panic!("expected first staged solve to stop at reset root, got {reason:?}")
            }
            None => panic!("first staged solve did not set a stop reason"),
        };
        assert_eq!(checkpointers.len(), 1);
        assert_eq!(
            checkpointers[0].terminal_reset_root_idx(),
            Some(first_root_idx)
        );

        // Stage 2: restart from the manually reset state up to the terminal root.
        let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
        let terminal_forward_solver = build_forward(Some(state_after_reset))
            .unwrap()
            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
            .unwrap();
        let terminal_root_idx = match forward_soln.stop_reason {
            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
            Some(reason) => {
                panic!("expected second staged solve to stop at terminal root, got {reason:?}")
            }
            None => panic!("second staged solve did not set a stop reason"),
        };
        assert_eq!(checkpointers.len(), 2);
        assert_eq!(
            checkpointers[1].terminal_reset_root_idx(),
            Some(terminal_root_idx)
        );

        let problem = terminal_forward_solver.problem();
        let final_forward_state = terminal_forward_solver.state_clone();
        let t_second_root = final_forward_state.as_ref().t;
        let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
        let out_norm = out_error
            .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
            .sqrt();
        assert!(
            out_norm < Eqn::T::from_f64(50.0).unwrap(),
            "forward integrated output mismatch at terminal root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
            final_forward_state.as_ref().g,
            expected_out.state,
        );
        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
        assert!(
            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
            "expected terminal root time ≈ {:?}, got {:?}",
            expected_out.t,
            t_second_root,
        );

        let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
        let mut adjoint_eqn = problem.adjoint_equations(checkpointers, adjoint_solver, None);
        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
        adjoint_state
            .as_mut()
            .state_mut_adjoint_terminal_root(
                &problem.eqn,
                terminal_root_idx,
                &final_forward_state,
                problem.integrate_out,
            )
            .unwrap();
        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
        // No observation times / dgdu terms: only the integrated output contributes.
        let (adjoint_state, _) = adjoint.solve_adjoint_backwards_pass(&[], &[]).unwrap();

        let t0 = problem.t0;
        let ctx = problem.context().clone();
        // Expected gradient: first component of the first point of each sensitivity series.
        let sens_points = soln.sens_solution_points.as_ref().unwrap();
        let expected_grad = Eqn::V::from_vec(
            sens_points
                .iter()
                .map(|pts| pts[0].state.get_index(0))
                .collect(),
            ctx.clone(),
        );
        let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
        // The backward pass must integrate all the way back to the initial time.
        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
        assert!(
            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
            "expected adjoint final time {:?}, got {:?}",
            t0,
            adjoint_state.as_ref().t,
        );
        adjoint_state.as_ref().sg[0].assert_eq_norm(
            &expected_grad,
            &atol,
            Eqn::T::from_f64(1e-6).unwrap(),
            Eqn::T::from_f64(60.0).unwrap(),
        );
    }
1809
    /// Checks the adjoint backward pass for a sum-of-squares objective when the
    /// forward solve is staged manually.
    ///
    /// Steps:
    /// 1. A dense `Solution` over `times` is filled in two stages:
    ///    * `solve_soln_with_checkpointing` up to the first (reset) root;
    ///    * `solve_soln`, restarted from `state_after_manual_reset`, which must
    ///      complete the solution with `TstopReached`.
    /// 2. A separate checkpointed solve from the same post-reset state, into
    ///    `terminal_soln`, runs to the terminal root and adds the second
    ///    checkpointer.
    /// 3. `dsum_squaresdp` turns the dense samples and `data` into per-time
    ///    gradient terms.
    /// 4. The adjoint is seeded at the terminal root and run backwards with
    ///    `times` and those terms. It must end at `t0`, and each `sg[j]` must
    ///    match column `j` of `dgdp_check`.
    ///
    /// When `use_replay_solver` is true, a clone of the terminal forward solver
    /// is passed to `adjoint_equations`.
    #[allow(clippy::too_many_arguments)]
    pub fn test_solve_soln_adjoint_sum_squares_with_single_reset_root<
        'a,
        Eqn,
        MethodF,
        MethodB,
        BuildForward,
        BuildAdjointState,
        BuildAdjointFromState,
    >(
        build_forward: BuildForward,
        soln: &OdeSolverSolution<Eqn::V>,
        build_adjoint_state: BuildAdjointState,
        build_adjoint_from_state: BuildAdjointFromState,
        use_replay_solver: bool,
        dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
        data: <Eqn::V as DefaultDenseMatrix>::M,
        times: &[Eqn::T],
    ) where
        Eqn: OdeEquationsImplicitAdjoint + 'a,
        Eqn::M: DefaultSolver,
        Eqn::V: DefaultDenseMatrix,
        MethodF: OdeSolverMethod<'a, Eqn>,
        MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
        BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
        BuildAdjointState:
            Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
        BuildAdjointFromState:
            Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
    {
        let expected_out = &soln.solution_points[0];
        let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
        // Dense solution sampled at the observation times.
        let mut forward_soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
        let mut checkpointers = Vec::new();

        // Stage 1: solve from the initial condition up to the first (reset) root.
        let first_forward_solver = build_forward(None)
            .unwrap()
            .solve_soln_with_checkpointing(&mut forward_soln, &mut checkpointers, None)
            .unwrap();
        let first_root_idx = match forward_soln.stop_reason {
            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
            Some(reason) => {
                panic!("expected first staged solve to stop at reset root, got {reason:?}")
            }
            None => panic!("first staged solve did not set a stop reason"),
        };
        assert_eq!(checkpointers.len(), 1);
        assert_eq!(
            checkpointers[0].terminal_reset_root_idx(),
            Some(first_root_idx)
        );

        // Stage 2a: finish the dense samples from the post-reset state (no checkpointing).
        let state_after_reset = state_after_manual_reset::<Eqn, MethodF>(&first_forward_solver);
        build_forward(Some(state_after_reset.clone()))
            .unwrap()
            .solve_soln(&mut forward_soln)
            .unwrap();
        assert!(forward_soln.is_complete());
        assert_eq!(
            forward_soln.stop_reason,
            Some(OdeSolverStopReason::TstopReached)
        );

        // Stage 2b: a separate checkpointed solve from the same post-reset state up
        // to the terminal root; this adds the second checkpointer.
        let mut terminal_soln = Solution::<Eqn::V>::new(forward_stop_time);
        let terminal_forward_solver = build_forward(Some(state_after_reset))
            .unwrap()
            .solve_soln_with_checkpointing(&mut terminal_soln, &mut checkpointers, None)
            .unwrap();
        let terminal_root_idx = match terminal_soln.stop_reason {
            Some(OdeSolverStopReason::RootFound(_, idx)) => idx,
            Some(reason) => {
                panic!("expected terminal staged solve to stop at root, got {reason:?}")
            }
            None => panic!("terminal staged solve did not set a stop reason"),
        };
        assert_eq!(checkpointers.len(), 2);
        assert_eq!(
            checkpointers.last().unwrap().terminal_reset_root_idx(),
            Some(terminal_root_idx)
        );

        // Per-observation gradient terms from the dense forward samples.
        let dgdu_eval = dsum_squaresdp(&forward_soln.ys, &data);
        let dgdu_eval_refs = dgdu_eval.iter().collect::<Vec<_>>();
        let problem = terminal_forward_solver.problem();
        let final_forward_state = terminal_forward_solver.state_clone();
        let t_second_root = final_forward_state.as_ref().t;
        let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
        assert!(
            (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
            "expected terminal root time ≈ {:?}, got {:?}",
            expected_out.t,
            t_second_root,
        );

        let adjoint_solver = use_replay_solver.then_some(terminal_forward_solver.clone());
        let mut adjoint_eqn =
            problem.adjoint_equations(checkpointers, adjoint_solver, Some(dgdu_eval_refs.len()));
        let mut adjoint_state = build_adjoint_state(&mut adjoint_eqn).unwrap();
        adjoint_state
            .as_mut()
            .state_mut_adjoint_terminal_root(
                &problem.eqn,
                terminal_root_idx,
                &final_forward_state,
                problem.integrate_out,
            )
            .unwrap();
        let adjoint = build_adjoint_from_state(adjoint_state, adjoint_eqn).unwrap();
        let (adjoint_state, _) = adjoint
            .solve_adjoint_backwards_pass(times, dgdu_eval_refs.as_slice())
            .unwrap();

        let t0 = problem.t0;
        let ctx = problem.context().clone();
        let nparams = dgdp_check.nrows();
        let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
        // The backward pass must integrate all the way back to the initial time.
        let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
        assert!(
            (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
            "expected adjoint final time {:?}, got {:?}",
            t0,
            adjoint_state.as_ref().t,
        );
        // `j` indexes both `sg` and the columns of `dgdp_check`.
        #[allow(clippy::needless_range_loop)]
        for j in 0..dgdp_check.ncols() {
            adjoint_state.as_ref().sg[j].assert_eq_norm(
                &dgdp_check.column(j).into_owned(),
                &atol,
                Eqn::T::from_f64(1e-6).unwrap(),
                Eqn::T::from_f64(260.0).unwrap(),
            );
        }
    }
1943}