1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod problem;
11pub mod runge_kutta;
12pub mod sde;
13pub mod sdirk;
14pub mod sdirk_state;
15pub mod sensitivities;
16pub mod solution;
17pub mod state;
18pub mod tableau;
19
20#[cfg(test)]
21mod tests {
22 use std::rc::Rc;
23
24 use self::problem::OdeSolverSolution;
25
26 use super::*;
27 use crate::error::{DiffsolError, OdeSolverError};
28 use crate::matrix::Matrix;
29 use crate::ode_solver::sensitivities::SensitivitiesOdeSolverMethod;
30 use crate::ode_solver::solution::Solution;
31 use crate::op::unit::UnitCallable;
32 use crate::op::ParameterisedOp;
33 use crate::Scalar;
34 use crate::{
35 ode_equations::{OdeEquationsImplicitAdjointWithReset, OdeEquationsImplicitSensWithReset},
36 op::OpStatistics,
37 AdjointEquations, AdjointOdeSolverMethod, Context, DenseMatrix, MatrixCommon, MatrixRef,
38 NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit,
39 OdeEquationsImplicitAdjoint, OdeEquationsRef, OdeSolverConfig, OdeSolverMethod,
40 OdeSolverProblem, OdeSolverState, OdeSolverStopReason, Scale, VectorRef, VectorView,
41 VectorViewMut,
42 };
43 use crate::{
44 ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver,
45 NonLinearOpSens, Op, Vector,
46 };
47 use num_traits::{FromPrimitive, One, Signed, Zero};
48
49 pub fn test_ode_solver<'a, M, Eqn, Method>(
50 method: &mut Method,
51 solution: OdeSolverSolution<M::V>,
52 override_tol: Option<M::T>,
53 use_tstop: bool,
54 solve_for_sensitivities: bool,
55 ) -> Eqn::V
56 where
57 M: Matrix,
58 Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
59 Method: OdeSolverMethod<'a, Eqn>,
60 {
61 let have_root = method.problem().eqn.root().is_some();
62 for (i, point) in solution.solution_points.iter().enumerate() {
63 let (soln, sens_soln) = if use_tstop {
64 match method.set_stop_time(point.t) {
65 Ok(_) => loop {
66 match method.step() {
67 Ok(OdeSolverStopReason::RootFound(_, _)) => {
68 assert!(have_root);
69 return method.state().y.clone();
70 }
71 Ok(OdeSolverStopReason::TstopReached) => {
72 break (method.state().y.clone(), method.state().s.to_vec());
73 }
74 _ => (),
75 }
76 },
77 Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
78 }
79 } else {
80 while method.state().t.abs() < point.t.abs() {
81 if let OdeSolverStopReason::RootFound(t, _) = method.step().unwrap() {
82 assert!(have_root);
83 return method.interpolate(t).unwrap();
84 }
85 }
86 let soln = method.interpolate(point.t).unwrap();
87 let sens_soln = method.interpolate_sens(point.t).unwrap();
88 (soln, sens_soln)
89 };
90 let soln = if let Some(out) = method.problem().eqn.out() {
91 out.call(&soln, point.t)
92 } else {
93 soln
94 };
95 assert_eq!(
96 soln.len(),
97 point.state.len(),
98 "soln.len() != point.state.len()"
99 );
100 if let Some(override_tol) = override_tol {
101 soln.assert_eq_st(&point.state, override_tol);
102 } else {
103 let (rtol, atol) = if method.problem().eqn.out().is_some() {
104 (solution.rtol, &solution.atol)
106 } else {
107 (method.problem().rtol, &method.problem().atol)
108 };
109 let error = soln.clone() - &point.state;
110 let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
111 assert!(
112 error_norm < M::T::from_f64(20.0).unwrap(),
113 "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
114 error_norm,
115 point.t,
116 soln,
117 point.state
118 );
119 if solve_for_sensitivities {
120 if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
121 for (j, sens_points) in sens_soln_points.iter().enumerate() {
122 let sens_point = &sens_points[i];
123 let sens_soln = &sens_soln[j];
124 let error = sens_soln.clone() - &sens_point.state;
125 let error_norm =
126 error.squared_norm(&sens_point.state, atol, rtol).sqrt();
127 assert!(
128 error_norm < M::T::from_f64(29.0).unwrap(),
129 "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
130 point.t,
131 sens_point.state
132 );
133 }
134 }
135 }
136 }
137 }
138 method.state().y.clone()
139 }
140
141 pub fn setup_test_adjoint<'a, LS, Eqn>(
142 problem: &'a mut OdeSolverProblem<Eqn>,
143 soln: OdeSolverSolution<Eqn::V>,
144 ) -> <Eqn::V as DefaultDenseMatrix>::M
145 where
146 Eqn: OdeEquationsImplicitAdjoint + 'a,
147 LS: LinearSolver<Eqn::M>,
148 Eqn::V: DefaultDenseMatrix,
149 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
150 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
151 {
152 let nparams = problem.eqn.nparams();
153 let nout = problem.eqn.nout();
154 let ctx = problem.eqn.context();
155 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
156 let final_time = soln.solution_points.last().unwrap().t;
157 let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
158 problem.eqn.get_params(&mut p_0);
159 let h_base = Eqn::T::from_f64(1e-10).unwrap();
160 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
161 h.axpy(h_base, &p_0, Eqn::T::one());
162 let p_base = p_0.clone();
163 for i in 0..nparams {
164 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
165 problem.eqn.set_params(&p_0);
166 let g_pos = {
167 let mut s = problem.bdf::<LS>().unwrap();
168 s.set_stop_time(final_time).unwrap();
169 while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
170 s.state().g.clone()
171 };
172
173 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
174 problem.eqn.set_params(&p_0);
175 let g_neg = {
176 let mut s = problem.bdf::<LS>().unwrap();
177 s.set_stop_time(final_time).unwrap();
178 while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
179 s.state().g.clone()
180 };
181 p_0.set_index(i, p_base.get_index(i));
182
183 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
184 for j in 0..nout {
185 dgdp.set_index(i, j, delta.get_index(j));
186 }
187 }
188 problem.eqn.set_params(&p_base);
189 dgdp
190 }
191
192 pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
195 where
196 DM: DenseMatrix,
197 {
198 let mut ret = DM::V::zeros(2, soln.context().clone());
199 for j in 0..soln.ncols() {
200 let soln_j = soln.column(j);
201 let data_j = data.column(j);
202 let delta = soln_j - data_j;
203 let norm2 = delta.norm(2);
204 ret.set_index(0, ret.get_index(0) + norm2 * norm2);
205 let norm4 = delta.norm(4);
206 let norm4_sq = norm4 * norm4;
207 ret.set_index(1, ret.get_index(1) + norm4_sq * norm4_sq);
208 }
209 ret
210 }
211
212 pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
215 where
216 DM: DenseMatrix,
217 {
218 let delta = soln.clone() - data;
219 let mut delta3 = delta.clone();
220 for j in 0..delta3.ncols() {
221 let delta_col = delta.column(j).into_owned();
222
223 let mut delta3_col = delta_col.clone();
224 delta3_col.component_mul_assign(&delta_col);
225 delta3_col.component_mul_assign(&delta_col);
226
227 delta3.column_mut(j).copy_from(&delta3_col);
228 }
229 let ret = vec![
230 delta * Scale(DM::T::from_f64(2.).unwrap()),
231 delta3 * Scale(DM::T::from_f64(4.).unwrap()),
232 ];
233 ret
234 }
235
236 pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
237 problem: &'a mut OdeSolverProblem<Eqn>,
238 times: &[Eqn::T],
239 ) -> (
240 <Eqn::V as DefaultDenseMatrix>::M,
241 <Eqn::V as DefaultDenseMatrix>::M,
242 )
243 where
244 Eqn: OdeEquationsImplicitAdjoint + 'a,
245 LS: LinearSolver<Eqn::M>,
246 Eqn::V: DefaultDenseMatrix,
247 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
248 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
249 {
250 let nparams = problem.eqn.nparams();
251 let nout = 2;
252 let ctx = problem.eqn.context();
253 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
254
255 let mut p_0 = ctx.vector_zeros(nparams);
256 problem.eqn.get_params(&mut p_0);
257 let h_base = Eqn::T::from_f64(1e-10).unwrap();
258 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
259 h.axpy(h_base, &p_0, Eqn::T::one());
260 let mut p_data = p_0.clone();
261 p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
262 let p_base = p_0.clone();
263
264 problem.eqn.set_params(&p_data);
265 let data = {
266 let mut s = problem.bdf::<LS>().unwrap();
267 s.solve_dense(times).unwrap().0
268 };
269
270 for i in 0..nparams {
271 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
272 problem.eqn.set_params(&p_0);
273 let g_pos = {
274 let mut s = problem.bdf::<LS>().unwrap();
275 let v = s.solve_dense(times).unwrap().0;
276 sum_squares(&v, &data)
277 };
278
279 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
280 problem.eqn.set_params(&p_0);
281 let g_neg = {
282 let mut s = problem.bdf::<LS>().unwrap();
283 let v = s.solve_dense(times).unwrap().0;
284 sum_squares(&v, &data)
285 };
286
287 p_0.set_index(i, p_base.get_index(i));
288
289 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
290 for j in 0..nout {
291 dgdp.set_index(i, j, delta.get_index(j));
292 }
293 }
294 problem.eqn.set_params(&p_base);
295 (dgdp, data)
296 }
297
298 pub fn single_reset_root_discrete_times<T: Scalar>(t_stop: T) -> Vec<T> {
299 let t_root = t_stop / T::from_f64(2.0).unwrap();
300 [0.25, 0.75, 1.25, 1.75]
301 .into_iter()
302 .map(|factor| t_root * T::from_f64(factor).unwrap())
303 .collect()
304 }
305
306 fn solve_dense_with_single_reset_root<'a, Eqn, Method, BuildForward>(
307 build_forward: BuildForward,
308 times: &[Eqn::T],
309 ) -> <Eqn::V as DefaultDenseMatrix>::M
310 where
311 Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
312 Eqn::V: DefaultDenseMatrix,
313 Method: OdeSolverMethod<'a, Eqn>,
314 BuildForward: Fn(Option<Method::State>) -> Result<Method, DiffsolError>,
315 {
316 let mut soln = Solution::<Eqn::V>::new_dense(times.to_vec()).unwrap();
317 let first_forward_solver = build_forward(None).unwrap().solve_soln(&mut soln).unwrap();
318 match soln.stop_reason {
319 Some(OdeSolverStopReason::RootFound(_, 0)) => {}
320 Some(OdeSolverStopReason::RootFound(_, idx)) => {
321 panic!("expected first solve_soln() segment to stop on root 0, got root {idx}")
322 }
323 Some(OdeSolverStopReason::TstopReached) => {
324 panic!("expected first solve_soln() segment to stop on the interior root")
325 }
326 Some(OdeSolverStopReason::InternalTimestep) | None => {
327 panic!("first solve_soln() segment did not finish with a terminal stop reason")
328 }
329 }
330
331 let mut state_after_reset = first_forward_solver.state_clone();
332 {
333 let problem = first_forward_solver.problem();
334 let reset_fn = problem.eqn.reset().unwrap();
335 state_after_reset
336 .state_mut_op(&problem.eqn, &reset_fn)
337 .unwrap();
338 }
339
340 build_forward(Some(state_after_reset))
341 .unwrap()
342 .solve_soln(&mut soln)
343 .unwrap();
344 assert!(
345 soln.is_complete(),
346 "expected stitched solve_soln() output to cover all requested observation times",
347 );
348 soln.ys
349 }
350
351 pub fn setup_test_adjoint_sum_squares_with_single_reset_root<'a, LS, Eqn>(
352 problem: &'a mut OdeSolverProblem<Eqn>,
353 times: &[Eqn::T],
354 ) -> (
355 <Eqn::V as DefaultDenseMatrix>::M,
356 <Eqn::V as DefaultDenseMatrix>::M,
357 )
358 where
359 Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
360 LS: LinearSolver<Eqn::M>,
361 Eqn::V: DefaultDenseMatrix,
362 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
363 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
364 {
365 let nparams = problem.eqn.nparams();
366 let nout = 2;
367 let ctx = problem.eqn.context();
368 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
369
370 let mut p_0 = ctx.vector_zeros(nparams);
371 problem.eqn.get_params(&mut p_0);
372 let h_base = Eqn::T::from_f64(1e-10).unwrap();
373 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
374 h.axpy(h_base, &p_0, Eqn::T::one());
375 let mut p_data = p_0.clone();
376 p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
377 let p_base = p_0.clone();
378
379 problem.eqn.set_params(&p_data);
380 let data = solve_dense_with_single_reset_root::<Eqn, _, _>(
381 |state| match state {
382 Some(state) => problem.bdf_solver(state),
383 None => problem.bdf::<LS>(),
384 },
385 times,
386 );
387
388 for i in 0..nparams {
389 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
390 problem.eqn.set_params(&p_0);
391 let g_pos = {
392 let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
393 |state| match state {
394 Some(state) => problem.bdf_solver(state),
395 None => problem.bdf::<LS>(),
396 },
397 times,
398 );
399 sum_squares(&v, &data)
400 };
401
402 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
403 problem.eqn.set_params(&p_0);
404 let g_neg = {
405 let v = solve_dense_with_single_reset_root::<Eqn, _, _>(
406 |state| match state {
407 Some(state) => problem.bdf_solver(state),
408 None => problem.bdf::<LS>(),
409 },
410 times,
411 );
412 sum_squares(&v, &data)
413 };
414
415 p_0.set_index(i, p_base.get_index(i));
416
417 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
418 for j in 0..nout {
419 dgdp.set_index(i, j, delta.get_index(j));
420 }
421 }
422 problem.eqn.set_params(&p_base);
423 (dgdp, data)
424 }
425
426 pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
427 backwards_solver: SolverB,
428 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
429 forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
430 data: <Eqn::V as DefaultDenseMatrix>::M,
431 times: &[Eqn::T],
432 ) where
433 SolverF: OdeSolverMethod<'a, Eqn>,
434 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
435 Eqn: OdeEquationsImplicitAdjoint + 'a,
436 Eqn::V: DefaultDenseMatrix,
437 Eqn::M: DefaultSolver,
438 {
439 let nparams = dgdp_check.nrows();
440 let dgdu = dsum_squaresdp(&forwards_soln, &data);
441
442 let atol = Eqn::V::from_element(
443 nparams,
444 Eqn::T::from_f64(1e-6).unwrap(),
445 data.context().clone(),
446 );
447 let rtol = Eqn::T::from_f64(1e-6).unwrap();
448 let state = backwards_solver
449 .solve_adjoint_backwards_pass(None, times, dgdu.iter().collect::<Vec<_>>().as_slice())
450 .unwrap();
451 let gs_adj = state.into_common().sg;
452 #[allow(clippy::needless_range_loop)]
453 for j in 0..dgdp_check.ncols() {
454 gs_adj[j].assert_eq_norm(
455 &dgdp_check.column(j).into_owned(),
456 &atol,
457 rtol,
458 Eqn::T::from_f64(260.).unwrap(),
459 );
460 }
461 }
462
463 pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
464 backwards_solver: SolverB,
465 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
466 ) where
467 SolverF: OdeSolverMethod<'a, Eqn>,
468 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
469 Eqn: OdeEquationsImplicitAdjoint + 'a,
470 Eqn::V: DefaultDenseMatrix,
471 Eqn::M: DefaultSolver,
472 {
473 let nout = backwards_solver.problem().eqn.nout();
474 let atol = Eqn::V::from_element(
475 nout,
476 Eqn::T::from_f64(1e-6).unwrap(),
477 dgdp_check.context().clone(),
478 );
479 let rtol = Eqn::T::from_f64(1e-6).unwrap();
480 let state = backwards_solver
481 .solve_adjoint_backwards_pass(None, &[], &[])
482 .unwrap();
483 let gs_adj = state.into_common().sg;
484 #[allow(clippy::needless_range_loop)]
485 for j in 0..dgdp_check.ncols() {
486 gs_adj[j].assert_eq_norm(
487 &dgdp_check.column(j).into_owned(),
488 &atol,
489 rtol,
490 Eqn::T::from_f64(40.).unwrap(),
491 );
492 }
493 }
494
495 pub struct TestEqnInit<M: Matrix> {
496 ctx: M::C,
497 }
498
499 impl<M: Matrix> Op for TestEqnInit<M> {
500 type T = M::T;
501 type V = M::V;
502 type M = M;
503 type C = M::C;
504
505 fn nout(&self) -> usize {
506 1
507 }
508 fn nparams(&self) -> usize {
509 1
510 }
511 fn nstates(&self) -> usize {
512 1
513 }
514 fn context(&self) -> &Self::C {
515 &self.ctx
516 }
517 }
518
519 impl<M: Matrix> ConstantOp for TestEqnInit<M> {
520 fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
521 y.fill(M::T::one());
522 }
523 }
524
525 impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
526 fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
527 sens.fill(M::T::zero());
528 }
529 }
530
531 pub struct TestEqnRhs<M: Matrix> {
532 ctx: M::C,
533 }
534
535 impl<M: Matrix> Op for TestEqnRhs<M> {
536 type T = M::T;
537 type V = M::V;
538 type M = M;
539 type C = M::C;
540
541 fn nout(&self) -> usize {
542 1
543 }
544 fn nparams(&self) -> usize {
545 1
546 }
547 fn nstates(&self) -> usize {
548 1
549 }
550 fn context(&self) -> &Self::C {
551 &self.ctx
552 }
553 }
554
555 impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
556 fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
557 y.fill(M::T::zero());
558 }
559 }
560
561 impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
562 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
563 y.fill(M::T::zero());
564 }
565 }
566
567 impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
568 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
569 sens.fill(M::T::zero());
570 }
571 }
572
573 pub struct TestEqnOut<M: Matrix> {
574 ctx: M::C,
575 }
576
577 impl<M: Matrix> Op for TestEqnOut<M> {
578 type T = M::T;
579 type V = M::V;
580 type M = M;
581 type C = M::C;
582
583 fn nout(&self) -> usize {
584 1
585 }
586 fn nparams(&self) -> usize {
587 1
588 }
589 fn nstates(&self) -> usize {
590 1
591 }
592 fn context(&self) -> &Self::C {
593 &self.ctx
594 }
595 }
596
597 impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
598 fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
599 y.copy_from(x);
600 }
601 }
602
603 impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
604 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
605 y.copy_from(v);
606 }
607 }
608
609 impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
610 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
611 sens.fill(M::T::zero());
612 }
613 }
614
615 pub struct TestEqn<M: Matrix> {
616 rhs: Rc<TestEqnRhs<M>>,
617 init: Rc<TestEqnInit<M>>,
618 out: Rc<TestEqnOut<M>>,
619 ctx: M::C,
620 }
621
622 impl<M: Matrix> TestEqn<M> {
623 pub fn new() -> Self {
624 let ctx = M::C::default();
625 Self {
626 rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
627 init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
628 out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
629 ctx,
630 }
631 }
632 }
633
634 impl<M: Matrix> Op for TestEqn<M> {
635 type T = M::T;
636 type V = M::V;
637 type M = M;
638 type C = M::C;
639 fn nout(&self) -> usize {
640 1
641 }
642 fn nparams(&self) -> usize {
643 1
644 }
645 fn nstates(&self) -> usize {
646 1
647 }
648 fn statistics(&self) -> crate::op::OpStatistics {
649 OpStatistics::default()
650 }
651 fn context(&self) -> &Self::C {
652 &self.ctx
653 }
654 }
655
656 impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
657 type Rhs = &'a TestEqnRhs<M>;
658 type Mass = ParameterisedOp<'a, UnitCallable<M>>;
659 type Root = ParameterisedOp<'a, UnitCallable<M>>;
660 type Init = &'a TestEqnInit<M>;
661 type Out = &'a TestEqnOut<M>;
662 type Reset = ParameterisedOp<'a, UnitCallable<M>>;
663 }
664
665 impl<M: Matrix> OdeEquations for TestEqn<M> {
666 fn rhs(&self) -> &TestEqnRhs<M> {
667 &self.rhs
668 }
669
670 fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
671 None
672 }
673
674 fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
675 None
676 }
677
678 fn init(&self) -> &TestEqnInit<M> {
679 &self.init
680 }
681
682 fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
683 Some(&self.out)
684 }
685 fn set_params(&mut self, _p: &Self::V) {
686 unimplemented!()
687 }
688 fn get_params(&self, _p: &mut Self::V) {
689 unimplemented!()
690 }
691 }
692
693 pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
694 let eqn = TestEqn::<M>::new();
695 let atol = eqn
696 .context()
697 .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
698 OdeSolverProblem::new(
699 eqn,
700 M::T::from_f64(1e-6).unwrap(),
701 atol,
702 None,
703 None,
704 None,
705 None,
706 None,
707 None,
708 M::T::zero(),
709 M::T::one(),
710 integrate_out,
711 Default::default(),
712 Default::default(),
713 )
714 .unwrap()
715 }
716
717 pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
718 let state = s.checkpoint();
719 let integrating_sens = !s.state().s.is_empty();
720 let integrating_out = s.problem().integrate_out;
721 let t0 = state.as_ref().t;
722 let t1 = t0 + M::T::from_f64(1e6).unwrap();
723 s.interpolate(t0)
724 .unwrap()
725 .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
726 assert!(s.interpolate(t1).is_err());
727 assert!(s.interpolate_out(t1).is_err());
728 if integrating_sens {
729 assert!(s.interpolate_sens(t1).is_err());
730 } else {
731 assert!(s.interpolate_sens(t0).is_ok());
732 }
733 s.step().unwrap();
734 let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
735 assert!(s.interpolate(s.state().t).is_ok());
736 assert!(s.interpolate(tmid).is_ok());
737 if integrating_out {
738 assert!(s.interpolate_out(s.state().t).is_ok());
739 } else {
740 assert!(s.interpolate_out(s.state().t).is_err());
741 }
742 assert!(s.interpolate_sens(s.state().t).is_ok());
743 assert!(s.interpolate(s.state().t + t1).is_err());
744 assert!(s.interpolate_out(s.state().t + t1).is_err());
745 if integrating_sens {
746 assert!(s.interpolate_sens(s.state().t + t1).is_err());
747 } else {
748 assert!(s.interpolate_sens(s.state().t + t1).is_ok());
749 }
750
751 let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
752 assert!(s
753 .interpolate_inplace(s.state().t, &mut y_wrong_length)
754 .is_err());
755 let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
756 assert!(s
757 .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
758 .is_err());
759 let mut s_wrong_length = vec![
760 M::V::zeros(1, s.problem().context().clone()),
761 M::V::zeros(1, s.problem().context().clone()),
762 ];
763 assert!(s
764 .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
765 .is_err());
766 let mut s_wrong_vec_length = if integrating_sens {
767 vec![M::V::zeros(2, s.problem().context().clone())]
768 } else {
769 vec![]
770 };
771 if integrating_sens {
772 assert!(s
773 .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
774 .is_err());
775 } else {
776 assert!(s
777 .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
778 .is_ok());
779 }
780
781 s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
782 assert!(s.interpolate(s.state().t).is_ok());
783 if integrating_out {
784 assert!(s.interpolate_out(s.state().t).is_ok());
785 }
786 if integrating_sens {
787 assert!(s.interpolate_sens(s.state().t).is_ok());
788 }
789 assert!(s.interpolate(tmid).is_err());
790 assert!(s.interpolate_out(tmid).is_err());
791 if integrating_sens {
792 assert!(s.interpolate_sens(tmid).is_err());
793 } else {
794 assert!(s.interpolate_sens(tmid).is_ok());
795 }
796 }
797
798 pub fn test_interpolate_dy<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(
799 mut s: Method,
800 ) {
801 let t_future = s.state().t + M::T::from_f64(1e6).unwrap();
803 assert!(s.interpolate_dy(t_future).is_err());
804
805 let t0 = s.state().t;
806 s.step().unwrap();
807 let t1 = s.state().t;
808 let dt = t1 - t0;
809 let tmid = t0 + dt / M::T::from_f64(2.0).unwrap();
810
811 let mut dy_wrong = M::V::zeros(2, s.problem().context().clone());
813 assert!(s.interpolate_dy_inplace(t1, &mut dy_wrong).is_err());
814
815 assert!(s.interpolate_dy(t1 + M::T::from_f64(1e6).unwrap()).is_err());
817
818 let eps = dt.abs() * M::T::from_f64(1e-5).unwrap();
820 let y_plus = s.interpolate(tmid + eps).unwrap();
821 let y_minus = s.interpolate(tmid - eps).unwrap();
822 let fd_dy = (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps));
823 let dy = s.interpolate_dy(tmid).unwrap();
824 dy.assert_eq_norm(
825 &fd_dy,
826 &s.problem().atol,
827 s.problem().rtol,
828 M::T::from_f64(1e3).unwrap(),
829 );
830
831 let t1 = s.state().t;
833 s.step().unwrap();
834 let t2 = s.state().t;
835 let dt2 = t2 - t1;
836 let tmid2 = t1 + dt2 / M::T::from_f64(2.0).unwrap();
837 let eps2 = dt2.abs() * M::T::from_f64(1e-5).unwrap();
838 let y_plus = s.interpolate(tmid2 + eps2).unwrap();
839 let y_minus = s.interpolate(tmid2 - eps2).unwrap();
840 let fd_dy2 =
841 (y_plus - y_minus) * Scale(M::T::one() / (M::T::from_f64(2.0).unwrap() * eps2));
842 let dy2 = s.interpolate_dy(tmid2).unwrap();
843 dy2.assert_eq_norm(
844 &fd_dy2,
845 &s.problem().atol,
846 s.problem().rtol,
847 M::T::from_f64(1e3).unwrap(),
848 );
849 }
850
851 pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
852 mut s: Method,
853 ) {
854 *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
855 assert_eq!(
856 *s.config().as_base_ref().minimum_timestep,
857 Eqn::T::from_f64(1.0e8).unwrap()
858 );
859 *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
861
862 let mut failed = false;
863 for _ in 0..10 {
864 if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
865 s.step()
866 {
867 failed = true;
868 break;
869 }
870 }
871 assert!(failed);
872 }
873
874 pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
875 let state = s.checkpoint();
876 let state2 = s.state();
877 state2
878 .y
879 .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
880 s.state_mut()
881 .y
882 .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
883 assert_eq!(
884 s.state_mut().y.get_index(0),
885 M::T::from_f64(std::f64::consts::PI).unwrap()
886 );
887 }
888
/// Builds the bouncing-ball problem from DiffSL source: height `x` starts at `h`, velocity `v`
/// starts at zero, gravity `g` accelerates the ball downwards and the stop function is `x`.
///
/// # Panics
///
/// Panics if the DiffSL source fails to build.
#[cfg(feature = "diffsl-cranelift")]
pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
    let source = "
            g { 9.81 } h { 10.0 }
            u_i {
                x = h,
                v = 0,
            }
            F_i {
                v,
                -g,
            }
            stop {
                x,
            }
        ";
    let builder = crate::OdeBuilder::<M>::new();
    builder.build_from_diffsl(source).unwrap()
}
911
/// Integrates the bouncing-ball problem up to the first root (or the stop time `2.5`), applies
/// the bounce, and then records up to three further steps.
///
/// At the root the velocity is multiplied by `-0.8`, the height is clamped to at least
/// `f64::EPSILON`, and the solver state (`y`, the first entry of `dy`, and `t`) is overwritten
/// accordingly. Returns the heights, velocities and times recorded after each of the
/// subsequent steps.
///
/// # Panics
///
/// Panics on a solver error, or if one of the recorded steps finds another root.
#[cfg(feature = "diffsl-cranelift")]
pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
where
    M: crate::MatrixHost<T = f64>,
    M: DefaultSolver<T = f64>,
    M::V: DefaultDenseMatrix<T = f64>,
    Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
{
    let restitution = 0.8;
    let t_final = 2.5;

    solver.set_stop_time(t_final).unwrap();
    loop {
        match solver.step() {
            Err(_) => panic!("unexpected solver error"),
            Ok(OdeSolverStopReason::TstopReached) => break,
            Ok(OdeSolverStopReason::InternalTimestep) => (),
            Ok(OdeSolverStopReason::RootFound(t_root, _)) => {
                let mut bounced = solver.interpolate(t_root).unwrap();

                // Reverse and damp the velocity.
                bounced.set_index(1, bounced.get_index(1) * -restitution);

                // Keep the height strictly positive.
                bounced.set_index(0, bounced.get_index(0).max(f64::EPSILON));

                // Restart the solver from the bounced state at the root time.
                solver.state_mut().y.copy_from(&bounced);
                solver.state_mut().dy.set_index(0, bounced.get_index(1));
                *solver.state_mut().t = t_root;

                break;
            }
        }
    }

    let mut heights = vec![];
    let mut velocities = vec![];
    let mut times = vec![];
    for _ in 0..3 {
        let outcome = solver.step();
        // Record the state even if this step turns out to be the last one.
        heights.push(solver.state().y.get_index(0));
        velocities.push(solver.state().y.get_index(1));
        times.push(solver.state().t);
        match outcome {
            Ok(OdeSolverStopReason::TstopReached) => break,
            Ok(OdeSolverStopReason::InternalTimestep) => (),
            Ok(OdeSolverStopReason::RootFound(_, _)) => {
                panic!("should be an internal timestep but found a root")
            }
            Err(_) => panic!("should be an internal timestep"),
        }
    }
    (heights, velocities, times)
}
970
971 pub fn test_checkpointing<'a, M, Method, Eqn>(
972 soln: OdeSolverSolution<M::V>,
973 mut solver1: Method,
974 mut solver2: Method,
975 ) where
976 M: Matrix + DefaultSolver,
977 Method: OdeSolverMethod<'a, Eqn>,
978 Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
979 {
980 let half_i = soln.solution_points.len() / 2;
981 let half_t = soln.solution_points[half_i].t;
982 while solver1.state().t <= half_t {
983 solver1.step().unwrap();
984 }
985 let checkpoint = solver1.checkpoint();
986 let checkpoint_t = checkpoint.as_ref().t;
987 solver2.set_state(checkpoint);
988
989 for point in soln.solution_points.iter().skip(half_i + 1) {
991 if point.t < checkpoint_t {
993 continue;
994 }
995 while solver2.state().t < point.t {
996 solver1.step().unwrap();
997 solver2.step().unwrap();
998 let time_error = (solver1.state().t - solver2.state().t).abs()
999 / (solver1.state().t.abs() * solver1.problem().rtol
1000 + solver1.problem().atol.get_index(0));
1001 assert!(
1002 time_error < M::T::from_f64(20.0).unwrap(),
1003 "time_error: {} at t = {}",
1004 time_error,
1005 solver1.state().t
1006 );
1007 solver1.state().y.assert_eq_norm(
1008 solver2.state().y,
1009 &solver1.problem().atol,
1010 solver1.problem().rtol,
1011 M::T::from_f64(20.0).unwrap(),
1012 );
1013 }
1014 let soln = solver1.interpolate(point.t).unwrap();
1015 soln.assert_eq_norm(
1016 &point.state,
1017 &solver1.problem().atol,
1018 solver1.problem().rtol,
1019 M::T::from_f64(15.0).unwrap(),
1020 );
1021 let soln = solver2.interpolate(point.t).unwrap();
1022 soln.assert_eq_norm(
1023 &point.state,
1024 &solver1.problem().atol,
1025 solver1.problem().rtol,
1026 M::T::from_f64(15.0).unwrap(),
1027 );
1028 }
1029 }
1030
1031 pub fn test_state_mut_on_problem<'a, Eqn, Method>(
1032 mut s: Method,
1033 soln: OdeSolverSolution<Eqn::V>,
1034 ) where
1035 Eqn: OdeEquationsImplicit + 'a,
1036 Method: OdeSolverMethod<'a, Eqn>,
1037 Eqn::V: DefaultDenseMatrix,
1038 {
1039 let state = s.checkpoint();
1041 s.solve(Eqn::T::one()).unwrap();
1042
1043 s.state_mut().y.copy_from(state.as_ref().y);
1045 s.state_mut().dy.copy_from(state.as_ref().dy);
1046 *s.state_mut().t = state.as_ref().t;
1047
1048 for point in soln.solution_points.iter() {
1050 while s.state().t < point.t {
1051 s.step().unwrap();
1052 }
1053 let soln = s.interpolate(point.t).unwrap();
1054 let error = soln.clone() - &point.state;
1055 let error_norm = error
1056 .squared_norm(&error, &s.problem().atol, s.problem().rtol)
1057 .sqrt();
1058 assert!(
1059 error_norm < Eqn::T::from_f64(19.0).unwrap(),
1060 "error_norm: {} at t = {}",
1061 error_norm,
1062 point.t
1063 );
1064 }
1065 }
1066
1067 pub fn test_root_found_index<'a, Eqn, Method>(
1076 mut solver: Method,
1077 soln: &OdeSolverSolution<Eqn::V>,
1078 expected_root_index: usize,
1079 tol: Eqn::T,
1080 ) where
1081 Eqn: OdeEquations + 'a,
1082 Method: OdeSolverMethod<'a, Eqn>,
1083 {
1084 let t_root_expected = soln.solution_points[0].t;
1085 solver
1086 .set_stop_time(Eqn::T::from_f64(100.0).unwrap())
1087 .unwrap();
1088 loop {
1089 match solver.step().unwrap() {
1090 OdeSolverStopReason::RootFound(t, index) => {
1092 assert_eq!(
1093 index, expected_root_index,
1094 "expected root index {expected_root_index} but got {index}",
1095 );
1096 assert!(
1097 (t - t_root_expected).abs() < tol,
1098 "expected t ≈ {t_root_expected:?}, got {t:?}",
1099 );
1100 break;
1101 }
1102 OdeSolverStopReason::TstopReached => {
1103 panic!("reached tstop without finding a root")
1104 }
1105 OdeSolverStopReason::InternalTimestep => {}
1106 }
1107 }
1108 }
1109
1110 pub fn test_solve_with_reset<'a, Eqn, Method>(
1116 mut solver: Method,
1117 soln: &OdeSolverSolution<Eqn::V>,
1118 ) where
1119 Eqn: OdeEquations + 'a,
1120 Eqn::V: DefaultDenseMatrix,
1121 Method: OdeSolverMethod<'a, Eqn>,
1122 {
1123 let final_time = Eqn::T::from_f64(100.0).unwrap();
1124 let (_ys_first, ts_first, stop_reason_first) = solver.solve(final_time).unwrap();
1125 assert!(matches!(
1126 stop_reason_first,
1127 OdeSolverStopReason::RootFound(_, _)
1128 ));
1129 let t_first_root = *ts_first.last().unwrap();
1130 assert!(
1131 t_first_root < final_time,
1132 "expected first solve() call to stop at a root before final_time"
1133 );
1134
1135 let mut state = solver.state_clone();
1137 {
1138 let problem = solver.problem();
1139 if let Some(reset_fn) = problem.eqn.reset() {
1140 state.state_mut_op(&problem.eqn, &reset_fn).unwrap();
1141 }
1142 }
1143 solver.set_state(state);
1144 let (ys_second, ts_second, stop_reason_second) = solver.solve(final_time).unwrap();
1145 assert!(matches!(
1146 stop_reason_second,
1147 OdeSolverStopReason::RootFound(_, _)
1148 ));
1149
1150 let expected = &soln.solution_points[0];
1151 let t_second_root = *ts_second.last().unwrap();
1152 let time_tol = soln.rtol * expected.t.abs() + soln.atol.get_index(0);
1153 assert!(
1154 (t_second_root - expected.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1155 "expected second root time ≈ {:?}, got {:?}",
1156 expected.t,
1157 t_second_root,
1158 );
1159
1160 let last_col = ts_second.len() - 1;
1161 let n = expected.state.len();
1162 let ctx = soln.atol.context().clone();
1163 let mut actual = Eqn::V::zeros(n, ctx);
1164 for j in 0..n {
1165 actual.set_index(j, ys_second.get_index(j, last_col));
1166 }
1167 let error = actual - &expected.state;
1168 let error_norm = error
1169 .squared_norm(&expected.state, &soln.atol, soln.rtol)
1170 .sqrt();
1171 let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1172 assert!(
1173 error_norm < error_threshold,
1174 "second-root state mismatch: WRMS error norm {error_norm:?} ≥ {error_threshold:?}",
1175 );
1176 }
1177
1178 pub fn test_solve_dense_with_reset<'a, Eqn, Method>(
1187 mut solver: Method,
1188 soln: &OdeSolverSolution<Eqn::V>,
1189 ) where
1190 Eqn: OdeEquations + 'a,
1191 Eqn::V: DefaultDenseMatrix,
1192 Method: OdeSolverMethod<'a, Eqn>,
1193 {
1194 let t_stop = soln.solution_points[0].t;
1195
1196 let n_steps = 20usize;
1197 let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1198 let dt = final_time / Eqn::T::from_f64(n_steps as f64).unwrap();
1199 let t_eval: Vec<Eqn::T> = (0..=n_steps)
1200 .map(|i| dt * Eqn::T::from_f64(i as f64).unwrap())
1201 .collect();
1202
1203 let (ret_first, stop_reason_first) = solver.solve_dense(&t_eval).unwrap();
1204 assert!(matches!(
1205 stop_reason_first,
1206 OdeSolverStopReason::RootFound(_, _)
1207 ));
1208 let ncols_first = ret_first.ncols();
1209
1210 assert!(
1212 ncols_first < t_eval.len(),
1213 "expected first solve_dense() call to stop at a root"
1214 );
1215 let t_first_root = solver.state().t;
1216
1217 let mut state = solver.state_clone();
1219 {
1220 let problem = solver.problem();
1221 if let Some(reset_fn) = problem.eqn.reset() {
1222 state.state_mut_op(&problem.eqn, &reset_fn).unwrap();
1223 }
1224 }
1225 solver.set_state(state);
1226
1227 let t_eval_after_reset: Vec<Eqn::T> = t_eval
1229 .iter()
1230 .copied()
1231 .filter(|&t| t > t_first_root)
1232 .collect();
1233 assert!(
1234 !t_eval_after_reset.is_empty(),
1235 "expected at least one evaluation time after first root"
1236 );
1237
1238 let (ret_second, stop_reason_second) = solver.solve_dense(&t_eval_after_reset).unwrap();
1239 assert!(matches!(
1240 stop_reason_second,
1241 OdeSolverStopReason::RootFound(_, _)
1242 ));
1243 let ncols = ret_second.ncols();
1244
1245 assert!(
1247 ncols < t_eval_after_reset.len(),
1248 "expected early stop after manual reset: ncols ({ncols}) should be < t_eval_after_reset.len() ({})",
1249 t_eval_after_reset.len(),
1250 );
1251
1252 let error_threshold = Eqn::T::from_f64(20.0).unwrap();
1253
1254 let last_col = ncols - 1;
1256 let actual = ret_second.column(last_col).into_owned();
1257 let error = actual - &soln.solution_points[0].state;
1258 let error_norm = error
1259 .squared_norm(&soln.solution_points[0].state, &soln.atol, soln.rtol)
1260 .sqrt();
1261 assert!(
1262 error_norm < error_threshold,
1263 "second-root stop state (soln[0], t ≈ {:?}) not found in last column ({last_col}); \
1264 WRMS norm {error_norm:?} ≥ {error_threshold:?}",
1265 t_stop,
1266 );
1267
1268 let t_second_root = solver.state().t;
1269 let time_tol = soln.rtol * t_stop.abs() + soln.atol.get_index(0);
1270 assert!(
1271 (t_second_root - t_stop).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1272 "expected second root time ≈ {:?}, got {:?}",
1273 t_stop,
1274 t_second_root,
1275 );
1276 }
1277
1278 pub fn test_solve_dense_sensitivities_with_reset<'a, Eqn, Method>(
1284 mut solver: Method,
1285 soln: &OdeSolverSolution<Eqn::V>,
1286 ) where
1287 Eqn: OdeEquationsImplicitSensWithReset + 'a,
1288 Eqn::V: DefaultDenseMatrix,
1289 Method: SensitivitiesOdeSolverMethod<'a, Eqn>,
1290 {
1291 let t_stop = soln.solution_points[0].t;
1292
1293 let n_steps = 20usize;
1294 let final_time = t_stop * Eqn::T::from_f64(2.0).unwrap();
1295 let dt = final_time / Eqn::T::from_f64(n_steps as f64).unwrap();
1296 let t_eval: Vec<Eqn::T> = (0..=n_steps)
1297 .map(|i| dt * Eqn::T::from_f64(i as f64).unwrap())
1298 .collect();
1299
1300 let (ret_first, _ret_sens_first, stop_reason_first) =
1301 solver.solve_dense_sensitivities(&t_eval).unwrap();
1302 assert!(matches!(
1303 stop_reason_first,
1304 OdeSolverStopReason::RootFound(_, _)
1305 ));
1306 let ncols_first = ret_first.ncols();
1307
1308 assert!(
1310 ncols_first < t_eval.len(),
1311 "expected first solve_dense_sensitivities() call to stop at a root"
1312 );
1313 let t_first_root = solver.state().t;
1314
1315 let first_root_idx = match stop_reason_first {
1316 OdeSolverStopReason::RootFound(_, root_idx) => root_idx,
1317 _ => unreachable!("expected first sensitivity solve to stop on a root"),
1318 };
1319
1320 let mut state = solver.state_clone();
1322 {
1323 let problem = solver.problem();
1324 let reset_fn = problem.eqn.reset().unwrap();
1325 let root_fn = problem.eqn.root().unwrap();
1326 state
1327 .state_mut_op_with_sens_and_reset(&problem.eqn, &reset_fn, &root_fn, first_root_idx)
1328 .unwrap();
1329 }
1330 solver.set_state(state);
1331
1332 let t_eval_after_reset: Vec<Eqn::T> = t_eval
1334 .iter()
1335 .copied()
1336 .filter(|&t| t > t_first_root)
1337 .collect();
1338 assert!(
1339 !t_eval_after_reset.is_empty(),
1340 "expected at least one evaluation time after first root"
1341 );
1342
1343 let (ret_second, ret_sens_second, stop_reason_second) = solver
1344 .solve_dense_sensitivities(&t_eval_after_reset)
1345 .unwrap();
1346 assert!(matches!(
1347 stop_reason_second,
1348 OdeSolverStopReason::RootFound(_, _)
1349 ));
1350 let ncols = ret_second.ncols();
1351
1352 assert!(
1354 ncols < t_eval_after_reset.len(),
1355 "expected early stop after manual reset: ncols ({ncols}) should be < t_eval_after_reset.len() ({})",
1356 t_eval_after_reset.len(),
1357 );
1358
1359 let expected = &soln.solution_points[0];
1361 let error_threshold = Eqn::T::from_f64(100.0).unwrap();
1362 let sens_points = soln.sens_solution_points.as_ref().unwrap();
1363
1364 let last_col = ncols - 1;
1365 let ey = ret_second.column(last_col).into_owned() - &expected.state;
1366 let mut combined = ey.squared_norm(&expected.state, &soln.atol, soln.rtol);
1367 for (param_j, sens_pts_j) in sens_points.iter().enumerate() {
1368 let expected_s = &sens_pts_j[0].state;
1369 let es = ret_sens_second[param_j].column(last_col).into_owned() - expected_s;
1370 combined += es.squared_norm(expected_s, &soln.atol, soln.rtol);
1371 }
1372 let norm = combined.sqrt();
1373 assert!(
1374 norm < error_threshold,
1375 "t_stop solution not found in last column; combined WRMS {norm:?} ≥ {error_threshold:?}",
1376 );
1377
1378 let t_second_root = solver.state().t;
1379 let time_tol = soln.rtol * t_stop.abs() + soln.atol.get_index(0);
1380 assert!(
1381 (t_second_root - t_stop).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1382 "expected second root time ≈ {:?}, got {:?}",
1383 t_stop,
1384 t_second_root,
1385 );
1386 }
1387
1388 pub fn test_solve_adjoint_with_single_reset_root<
1389 'a,
1390 Eqn,
1391 MethodF,
1392 MethodB,
1393 BuildForward,
1394 BuildAdjointState,
1395 BuildAdjointFromState,
1396 >(
1397 build_forward: BuildForward,
1398 soln: &OdeSolverSolution<Eqn::V>,
1399 build_adjoint_state: BuildAdjointState,
1400 build_adjoint_from_state: BuildAdjointFromState,
1401 ) where
1402 Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
1403 Eqn::M: DefaultSolver,
1404 Eqn::V: DefaultDenseMatrix,
1405 MethodF: OdeSolverMethod<'a, Eqn>,
1406 MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1407 BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1408 BuildAdjointState:
1409 Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1410 BuildAdjointFromState:
1411 Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1412 {
1413 let expected_out = &soln.solution_points[0];
1414 let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1415
1416 let mut first_forward_solver = build_forward(None).unwrap();
1417 let (pre_reset_checkpointer, _, _, _) = first_forward_solver
1418 .solve_with_checkpointing(forward_stop_time, None)
1419 .unwrap();
1420 let fwd_state_minus = first_forward_solver.into_state();
1421 let mut state_after_reset = fwd_state_minus.clone();
1422 let problem = pre_reset_checkpointer.problem();
1423 let reset_fn = problem.eqn.reset().unwrap();
1424 state_after_reset
1425 .state_mut_op(&problem.eqn, &reset_fn)
1426 .unwrap();
1427 let fwd_state_plus = state_after_reset.clone();
1428
1429 let mut second_forward_solver = build_forward(Some(state_after_reset)).unwrap();
1430 let (post_reset_checkpointer, _, _, post_reset_stop_reason) = second_forward_solver
1431 .solve_with_checkpointing(forward_stop_time, None)
1432 .unwrap();
1433 let final_forward_state = second_forward_solver.into_state();
1434 let t_second_root = final_forward_state.as_ref().t;
1435
1436 let out_error = final_forward_state.as_ref().g.clone() - &expected_out.state;
1437 let out_norm = out_error
1438 .squared_norm(&expected_out.state, &soln.atol, soln.rtol)
1439 .sqrt();
1440 assert!(
1441 out_norm < Eqn::T::from_f64(50.0).unwrap(),
1442 "forward integrated output mismatch at second root: actual {:?}, expected {:?}, WRMS {out_norm:?}",
1443 final_forward_state.as_ref().g,
1444 expected_out.state,
1445 );
1446 let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1447 assert!(
1448 (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1449 "expected second root time ≈ {:?}, got {:?}",
1450 expected_out.t,
1451 t_second_root,
1452 );
1453
1454 let mut post_reset_adjoint_eqn =
1455 problem.adjoint_equations(post_reset_checkpointer.clone(), None);
1456 let mut post_reset_adjoint_state =
1457 build_adjoint_state(&mut post_reset_adjoint_eqn).unwrap();
1458 let post_reset_root_idx = match post_reset_stop_reason {
1459 OdeSolverStopReason::RootFound(_, idx) => idx,
1460 OdeSolverStopReason::TstopReached => {
1461 panic!("expected second forward segment to stop on a root, got TstopReached")
1462 }
1463 OdeSolverStopReason::InternalTimestep => {
1464 panic!("expected second forward segment to stop on a root, got InternalTimestep")
1465 }
1466 };
1467 post_reset_adjoint_state
1468 .state_mut_adjoint_terminal_root(
1469 &mut post_reset_adjoint_eqn,
1470 post_reset_root_idx,
1471 &final_forward_state,
1472 )
1473 .unwrap();
1474 let post_reset_adjoint =
1475 build_adjoint_from_state(post_reset_adjoint_state, post_reset_adjoint_eqn).unwrap();
1476 let mut adjoint_state = post_reset_adjoint
1477 .solve_adjoint_backwards_pass(Some(fwd_state_minus.as_ref().t), &[], &[])
1478 .unwrap();
1479 let t0 = pre_reset_checkpointer.problem().t0;
1480 let ctx = pre_reset_checkpointer.problem().context().clone();
1481 let reset_problem = pre_reset_checkpointer.problem();
1482 let mut pre_reset_adjoint_eqn = problem.adjoint_equations(pre_reset_checkpointer, None);
1483 {
1484 let reset_fn = reset_problem.eqn.reset().unwrap();
1485 let root_fn = reset_problem.eqn.root().unwrap();
1486 adjoint_state
1487 .state_mut_op_with_adjoint_and_reset(
1488 &mut pre_reset_adjoint_eqn,
1489 &reset_fn,
1490 &root_fn,
1491 0,
1492 &fwd_state_minus,
1493 &fwd_state_plus,
1494 )
1495 .unwrap();
1496 }
1497 let pre_reset_adjoint =
1498 build_adjoint_from_state(adjoint_state, pre_reset_adjoint_eqn).unwrap();
1499 let adjoint_state = pre_reset_adjoint
1500 .solve_adjoint_backwards_pass(None, &[], &[])
1501 .unwrap();
1502
1503 let sens_points = soln.sens_solution_points.as_ref().unwrap();
1504 let expected_grad = Eqn::V::from_vec(
1505 sens_points
1506 .iter()
1507 .map(|pts| pts[0].state.get_index(0))
1508 .collect(),
1509 ctx.clone(),
1510 );
1511 let atol = Eqn::V::from_element(expected_grad.len(), Eqn::T::from_f64(1e-6).unwrap(), ctx);
1512 let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1513 assert!(
1514 (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1515 "expected adjoint final time {:?}, got {:?}",
1516 t0,
1517 adjoint_state.as_ref().t,
1518 );
1519 adjoint_state.as_ref().sg[0].assert_eq_norm(
1520 &expected_grad,
1521 &atol,
1522 Eqn::T::from_f64(1e-6).unwrap(),
1523 Eqn::T::from_f64(60.0).unwrap(),
1524 );
1525 }
1526
1527 pub fn test_solve_adjoint_sum_squares_with_single_reset_root<
1528 'a,
1529 Eqn,
1530 MethodF,
1531 MethodB,
1532 BuildForward,
1533 BuildAdjointState,
1534 BuildAdjointFromState,
1535 >(
1536 build_forward: BuildForward,
1537 soln: &OdeSolverSolution<Eqn::V>,
1538 build_adjoint_state: BuildAdjointState,
1539 build_adjoint_from_state: BuildAdjointFromState,
1540 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
1541 data: <Eqn::V as DefaultDenseMatrix>::M,
1542 times: &[Eqn::T],
1543 ) where
1544 Eqn: OdeEquationsImplicitAdjointWithReset + 'a,
1545 Eqn::M: DefaultSolver,
1546 Eqn::V: DefaultDenseMatrix,
1547 MethodF: OdeSolverMethod<'a, Eqn>,
1548 MethodB: AdjointOdeSolverMethod<'a, Eqn, MethodF, State = MethodF::State>,
1549 BuildForward: Fn(Option<MethodF::State>) -> Result<MethodF, DiffsolError>,
1550 BuildAdjointState:
1551 Fn(&mut AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodF::State, DiffsolError>,
1552 BuildAdjointFromState:
1553 Fn(MethodF::State, AdjointEquations<'a, Eqn, MethodF>) -> Result<MethodB, DiffsolError>,
1554 {
1555 let expected_out = &soln.solution_points[0];
1556 let forward_stop_time = expected_out.t + Eqn::T::from_f64(1.0).unwrap();
1557 let forwards_soln =
1558 solve_dense_with_single_reset_root::<Eqn, MethodF, _>(&build_forward, times);
1559 assert_eq!(
1560 forwards_soln.ncols(),
1561 times.len(),
1562 "expected stitched forward samples to cover every requested observation time",
1563 );
1564 let dgdu = dsum_squaresdp(&forwards_soln, &data);
1565 let dgdu_refs = dgdu.iter().collect::<Vec<_>>();
1566
1567 let mut first_forward_solver = build_forward(None).unwrap();
1568 let (pre_reset_checkpointer, _, _, pre_reset_stop_reason) = first_forward_solver
1569 .solve_with_checkpointing(forward_stop_time, None)
1570 .unwrap();
1571 let fwd_state_minus = first_forward_solver.into_state();
1572 match pre_reset_stop_reason {
1573 OdeSolverStopReason::RootFound(_, 0) => {}
1574 OdeSolverStopReason::RootFound(_, idx) => {
1575 panic!("expected first checkpointed segment to stop on root 0, got root {idx}")
1576 }
1577 OdeSolverStopReason::TstopReached => {
1578 panic!("expected first checkpointed segment to stop on the interior root")
1579 }
1580 OdeSolverStopReason::InternalTimestep => {
1581 panic!("first checkpointed segment ended without a terminal stop reason")
1582 }
1583 }
1584
1585 let mut state_after_reset = fwd_state_minus.clone();
1586 let problem = pre_reset_checkpointer.problem();
1587 let reset_fn = problem.eqn.reset().unwrap();
1588 state_after_reset
1589 .state_mut_op(&problem.eqn, &reset_fn)
1590 .unwrap();
1591 let fwd_state_plus = state_after_reset.clone();
1592
1593 let mut second_forward_solver = build_forward(Some(state_after_reset)).unwrap();
1594 let (post_reset_checkpointer, _, _, post_reset_stop_reason) = second_forward_solver
1595 .solve_with_checkpointing(forward_stop_time, None)
1596 .unwrap();
1597 let final_forward_state = second_forward_solver.into_state();
1598 let t_second_root = final_forward_state.as_ref().t;
1599
1600 let time_tol = soln.rtol * expected_out.t.abs() + soln.atol.get_index(0);
1601 assert!(
1602 (t_second_root - expected_out.t).abs() < Eqn::T::from_f64(30.0).unwrap() * time_tol,
1603 "expected second root time ≈ {:?}, got {:?}",
1604 expected_out.t,
1605 t_second_root,
1606 );
1607
1608 let mut post_reset_adjoint_eqn =
1609 problem.adjoint_equations(post_reset_checkpointer.clone(), Some(dgdu.len()));
1610 let mut post_reset_adjoint_state =
1611 build_adjoint_state(&mut post_reset_adjoint_eqn).unwrap();
1612 let post_reset_root_idx = match post_reset_stop_reason {
1613 OdeSolverStopReason::RootFound(_, idx) => idx,
1614 OdeSolverStopReason::TstopReached => {
1615 panic!("expected second forward segment to stop on a root, got TstopReached")
1616 }
1617 OdeSolverStopReason::InternalTimestep => {
1618 panic!("expected second forward segment to stop on a root, got InternalTimestep")
1619 }
1620 };
1621 post_reset_adjoint_state
1622 .state_mut_adjoint_terminal_root(
1623 &mut post_reset_adjoint_eqn,
1624 post_reset_root_idx,
1625 &final_forward_state,
1626 )
1627 .unwrap();
1628 let post_reset_adjoint =
1629 build_adjoint_from_state(post_reset_adjoint_state, post_reset_adjoint_eqn).unwrap();
1630 let mut adjoint_state = post_reset_adjoint
1631 .solve_adjoint_backwards_pass(
1632 Some(fwd_state_minus.as_ref().t),
1633 times,
1634 dgdu_refs.as_slice(),
1635 )
1636 .unwrap();
1637
1638 let t0 = pre_reset_checkpointer.problem().t0;
1639 let ctx = pre_reset_checkpointer.problem().context().clone();
1640 let reset_problem = pre_reset_checkpointer.problem();
1641 let mut pre_reset_adjoint_eqn =
1642 problem.adjoint_equations(pre_reset_checkpointer, Some(dgdu.len()));
1643 {
1644 let reset_fn = reset_problem.eqn.reset().unwrap();
1645 let root_fn = reset_problem.eqn.root().unwrap();
1646 adjoint_state
1647 .state_mut_op_with_adjoint_and_reset(
1648 &mut pre_reset_adjoint_eqn,
1649 &reset_fn,
1650 &root_fn,
1651 0,
1652 &fwd_state_minus,
1653 &fwd_state_plus,
1654 )
1655 .unwrap();
1656 }
1657 let pre_reset_adjoint =
1658 build_adjoint_from_state(adjoint_state, pre_reset_adjoint_eqn).unwrap();
1659 let adjoint_state = pre_reset_adjoint
1660 .solve_adjoint_backwards_pass(None, times, dgdu_refs.as_slice())
1661 .unwrap();
1662
1663 let nparams = dgdp_check.nrows();
1664 let atol = Eqn::V::from_element(nparams, Eqn::T::from_f64(1e-6).unwrap(), ctx);
1665 let t0_tol = Eqn::T::from_f64(10.0).unwrap() * Eqn::T::EPSILON;
1666 assert!(
1667 (adjoint_state.as_ref().t - t0).abs() <= t0_tol,
1668 "expected adjoint final time {:?}, got {:?}",
1669 t0,
1670 adjoint_state.as_ref().t,
1671 );
1672 #[allow(clippy::needless_range_loop)]
1673 for j in 0..dgdp_check.ncols() {
1674 adjoint_state.as_ref().sg[j].assert_eq_norm(
1675 &dgdp_check.column(j).into_owned(),
1676 &atol,
1677 Eqn::T::from_f64(1e-6).unwrap(),
1678 Eqn::T::from_f64(260.0).unwrap(),
1679 );
1680 }
1681 }
1682}