1pub mod adjoint;
2pub mod bdf;
3pub mod bdf_state;
4pub mod builder;
5pub mod checkpointing;
6pub mod config;
7pub mod explicit_rk;
8pub mod jacobian_update;
9pub mod method;
10pub mod problem;
11pub mod runge_kutta;
12pub mod sde;
13pub mod sdirk;
14pub mod sdirk_state;
15pub mod sensitivities;
16pub mod state;
17pub mod tableau;
18
19#[cfg(test)]
20mod tests {
21 use std::rc::Rc;
22
23 use self::problem::OdeSolverSolution;
24 use nalgebra::ComplexField;
25
26 use super::*;
27 use crate::error::{DiffsolError, OdeSolverError};
28 use crate::matrix::Matrix;
29 use crate::op::unit::UnitCallable;
30 use crate::op::ParameterisedOp;
31 use crate::{
32 op::OpStatistics, AdjointOdeSolverMethod, Context, DenseMatrix, MatrixCommon, MatrixRef,
33 NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit, OdeEquationsImplicitAdjoint,
34 OdeEquationsRef, OdeSolverConfig, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
35 OdeSolverStopReason, Scale, VectorRef, VectorView, VectorViewMut,
36 };
37 use crate::{
38 ConstantOp, ConstantOpSens, DefaultDenseMatrix, DefaultSolver, LinearSolver, NonLinearOp,
39 NonLinearOpSens, Op, Vector,
40 };
41 use num_traits::{FromPrimitive, One, Zero};
42
43 pub fn test_ode_solver<'a, M, Eqn, Method>(
44 method: &mut Method,
45 solution: OdeSolverSolution<M::V>,
46 override_tol: Option<M::T>,
47 use_tstop: bool,
48 solve_for_sensitivities: bool,
49 ) -> Eqn::V
50 where
51 M: Matrix,
52 Eqn: OdeEquations<M = M, T = M::T, V = M::V> + 'a,
53 Method: OdeSolverMethod<'a, Eqn>,
54 {
55 let have_root = method.problem().eqn.root().is_some();
56 for (i, point) in solution.solution_points.iter().enumerate() {
57 let (soln, sens_soln) = if use_tstop {
58 match method.set_stop_time(point.t) {
59 Ok(_) => loop {
60 match method.step() {
61 Ok(OdeSolverStopReason::RootFound(_)) => {
62 assert!(have_root);
63 return method.state().y.clone();
64 }
65 Ok(OdeSolverStopReason::TstopReached) => {
66 break (method.state().y.clone(), method.state().s.to_vec());
67 }
68 _ => (),
69 }
70 },
71 Err(_) => (method.state().y.clone(), method.state().s.to_vec()),
72 }
73 } else {
74 while method.state().t.abs() < point.t.abs() {
75 if let OdeSolverStopReason::RootFound(t) = method.step().unwrap() {
76 assert!(have_root);
77 return method.interpolate(t).unwrap();
78 }
79 }
80 let soln = method.interpolate(point.t).unwrap();
81 let sens_soln = method.interpolate_sens(point.t).unwrap();
82 (soln, sens_soln)
83 };
84 let soln = if let Some(out) = method.problem().eqn.out() {
85 out.call(&soln, point.t)
86 } else {
87 soln
88 };
89 assert_eq!(
90 soln.len(),
91 point.state.len(),
92 "soln.len() != point.state.len()"
93 );
94 if let Some(override_tol) = override_tol {
95 soln.assert_eq_st(&point.state, override_tol);
96 } else {
97 let (rtol, atol) = if method.problem().eqn.out().is_some() {
98 (solution.rtol, &solution.atol)
100 } else {
101 (method.problem().rtol, &method.problem().atol)
102 };
103 let error = soln.clone() - &point.state;
104 let error_norm = error.squared_norm(&point.state, atol, rtol).sqrt();
105 assert!(
106 error_norm < M::T::from_f64(20.0).unwrap(),
107 "error_norm: {} at t = {}. soln: {:?}, expected: {:?}",
108 error_norm,
109 point.t,
110 soln,
111 point.state
112 );
113 if solve_for_sensitivities {
114 if let Some(sens_soln_points) = solution.sens_solution_points.as_ref() {
115 for (j, sens_points) in sens_soln_points.iter().enumerate() {
116 let sens_point = &sens_points[i];
117 let sens_soln = &sens_soln[j];
118 let error = sens_soln.clone() - &sens_point.state;
119 let error_norm =
120 error.squared_norm(&sens_point.state, atol, rtol).sqrt();
121 assert!(
122 error_norm < M::T::from_f64(29.0).unwrap(),
123 "error_norm: {error_norm} at t = {}, sens index: {j}. soln: {sens_soln:?}, expected: {:?}",
124 point.t,
125 sens_point.state
126 );
127 }
128 }
129 }
130 }
131 }
132 method.state().y.clone()
133 }
134
135 pub fn setup_test_adjoint<'a, LS, Eqn>(
136 problem: &'a mut OdeSolverProblem<Eqn>,
137 soln: OdeSolverSolution<Eqn::V>,
138 ) -> <Eqn::V as DefaultDenseMatrix>::M
139 where
140 Eqn: OdeEquationsImplicitAdjoint + 'a,
141 LS: LinearSolver<Eqn::M>,
142 Eqn::V: DefaultDenseMatrix,
143 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
144 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
145 {
146 let nparams = problem.eqn.nparams();
147 let nout = problem.eqn.nout();
148 let ctx = problem.eqn.context();
149 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
150 let final_time = soln.solution_points.last().unwrap().t;
151 let mut p_0 = Eqn::V::zeros(nparams, ctx.clone());
152 problem.eqn.get_params(&mut p_0);
153 let h_base = Eqn::T::from_f64(1e-10).unwrap();
154 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
155 h.axpy(h_base, &p_0, Eqn::T::one());
156 let p_base = p_0.clone();
157 for i in 0..nparams {
158 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
159 problem.eqn.set_params(&p_0);
160 let g_pos = {
161 let mut s = problem.bdf::<LS>().unwrap();
162 s.set_stop_time(final_time).unwrap();
163 while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
164 s.state().g.clone()
165 };
166
167 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
168 problem.eqn.set_params(&p_0);
169 let g_neg = {
170 let mut s = problem.bdf::<LS>().unwrap();
171 s.set_stop_time(final_time).unwrap();
172 while s.step().unwrap() != OdeSolverStopReason::TstopReached {}
173 s.state().g.clone()
174 };
175 p_0.set_index(i, p_base.get_index(i));
176
177 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
178 for j in 0..nout {
179 dgdp.set_index(i, j, delta.get_index(j));
180 }
181 }
182 problem.eqn.set_params(&p_base);
183 dgdp
184 }
185
186 pub(crate) fn sum_squares<DM>(soln: &DM, data: &DM) -> DM::V
189 where
190 DM: DenseMatrix,
191 {
192 let mut ret = DM::V::zeros(2, soln.context().clone());
193 for j in 0..soln.ncols() {
194 let soln_j = soln.column(j);
195 let data_j = data.column(j);
196 let delta = soln_j - data_j;
197 ret.set_index(0, ret.get_index(0) + delta.norm(2).powi(2));
198 ret.set_index(1, ret.get_index(1) + delta.norm(4).powi(4));
199 }
200 ret
201 }
202
203 pub(crate) fn dsum_squaresdp<DM>(soln: &DM, data: &DM) -> Vec<DM>
206 where
207 DM: DenseMatrix,
208 {
209 let delta = soln.clone() - data;
210 let mut delta3 = delta.clone();
211 for j in 0..delta3.ncols() {
212 let delta_col = delta.column(j).into_owned();
213
214 let mut delta3_col = delta_col.clone();
215 delta3_col.component_mul_assign(&delta_col);
216 delta3_col.component_mul_assign(&delta_col);
217
218 delta3.column_mut(j).copy_from(&delta3_col);
219 }
220 let ret = vec![
221 delta * Scale(DM::T::from_f64(2.).unwrap()),
222 delta3 * Scale(DM::T::from_f64(4.).unwrap()),
223 ];
224 ret
225 }
226
227 pub fn setup_test_adjoint_sum_squares<'a, LS, Eqn>(
228 problem: &'a mut OdeSolverProblem<Eqn>,
229 times: &[Eqn::T],
230 ) -> (
231 <Eqn::V as DefaultDenseMatrix>::M,
232 <Eqn::V as DefaultDenseMatrix>::M,
233 )
234 where
235 Eqn: OdeEquationsImplicitAdjoint + 'a,
236 LS: LinearSolver<Eqn::M>,
237 Eqn::V: DefaultDenseMatrix,
238 for<'b> &'b Eqn::V: VectorRef<Eqn::V>,
239 for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
240 {
241 let nparams = problem.eqn.nparams();
242 let nout = 2;
243 let ctx = problem.eqn.context();
244 let mut dgdp = <Eqn::V as DefaultDenseMatrix>::M::zeros(nparams, nout, ctx.clone());
245
246 let mut p_0 = ctx.vector_zeros(nparams);
247 problem.eqn.get_params(&mut p_0);
248 let h_base = Eqn::T::from_f64(1e-10).unwrap();
249 let mut h = Eqn::V::from_element(nparams, h_base, ctx.clone());
250 h.axpy(h_base, &p_0, Eqn::T::one());
251 let mut p_data = p_0.clone();
252 p_data.axpy(Eqn::T::from_f64(0.1).unwrap(), &p_0, Eqn::T::one());
253 let p_base = p_0.clone();
254
255 problem.eqn.set_params(&p_data);
256 let data = {
257 let mut s = problem.bdf::<LS>().unwrap();
258 s.solve_dense(times).unwrap()
259 };
260
261 for i in 0..nparams {
262 p_0.set_index(i, p_base.get_index(i) + h.get_index(i));
263 problem.eqn.set_params(&p_0);
264 let g_pos = {
265 let mut s = problem.bdf::<LS>().unwrap();
266 let v = s.solve_dense(times).unwrap();
267 sum_squares(&v, &data)
268 };
269
270 p_0.set_index(i, p_base.get_index(i) - h.get_index(i));
271 problem.eqn.set_params(&p_0);
272 let g_neg = {
273 let mut s = problem.bdf::<LS>().unwrap();
274 let v = s.solve_dense(times).unwrap();
275 sum_squares(&v, &data)
276 };
277
278 p_0.set_index(i, p_base.get_index(i));
279
280 let delta = (g_pos - g_neg) / Scale(Eqn::T::from_f64(2.).unwrap() * h.get_index(i));
281 for j in 0..nout {
282 dgdp.set_index(i, j, delta.get_index(j));
283 }
284 }
285 problem.eqn.set_params(&p_base);
286 (dgdp, data)
287 }
288
289 pub fn test_adjoint_sum_squares<'a, Eqn, SolverF, SolverB>(
290 backwards_solver: SolverB,
291 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
292 forwards_soln: <Eqn::V as DefaultDenseMatrix>::M,
293 data: <Eqn::V as DefaultDenseMatrix>::M,
294 times: &[Eqn::T],
295 ) where
296 SolverF: OdeSolverMethod<'a, Eqn>,
297 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
298 Eqn: OdeEquationsImplicitAdjoint + 'a,
299 Eqn::V: DefaultDenseMatrix,
300 Eqn::M: DefaultSolver,
301 {
302 let nparams = dgdp_check.nrows();
303 let dgdu = dsum_squaresdp(&forwards_soln, &data);
304
305 let atol = Eqn::V::from_element(
306 nparams,
307 Eqn::T::from_f64(1e-6).unwrap(),
308 data.context().clone(),
309 );
310 let rtol = Eqn::T::from_f64(1e-6).unwrap();
311 let state = backwards_solver
312 .solve_adjoint_backwards_pass(times, dgdu.iter().collect::<Vec<_>>().as_slice())
313 .unwrap();
314 let gs_adj = state.into_common().sg;
315 #[allow(clippy::needless_range_loop)]
316 for j in 0..dgdp_check.ncols() {
317 gs_adj[j].assert_eq_norm(
318 &dgdp_check.column(j).into_owned(),
319 &atol,
320 rtol,
321 Eqn::T::from_f64(260.).unwrap(),
322 );
323 }
324 }
325
326 pub fn test_adjoint<'a, Eqn, SolverF, SolverB>(
327 backwards_solver: SolverB,
328 dgdp_check: <Eqn::V as DefaultDenseMatrix>::M,
329 ) where
330 SolverF: OdeSolverMethod<'a, Eqn>,
331 SolverB: AdjointOdeSolverMethod<'a, Eqn, SolverF>,
332 Eqn: OdeEquationsImplicitAdjoint + 'a,
333 Eqn::V: DefaultDenseMatrix,
334 Eqn::M: DefaultSolver,
335 {
336 let nout = backwards_solver.problem().eqn.nout();
337 let atol = Eqn::V::from_element(
338 nout,
339 Eqn::T::from_f64(1e-6).unwrap(),
340 dgdp_check.context().clone(),
341 );
342 let rtol = Eqn::T::from_f64(1e-6).unwrap();
343 let state = backwards_solver
344 .solve_adjoint_backwards_pass(&[], &[])
345 .unwrap();
346 let gs_adj = state.into_common().sg;
347 #[allow(clippy::needless_range_loop)]
348 for j in 0..dgdp_check.ncols() {
349 gs_adj[j].assert_eq_norm(
350 &dgdp_check.column(j).into_owned(),
351 &atol,
352 rtol,
353 Eqn::T::from_f64(40.).unwrap(),
354 );
355 }
356 }
357
358 pub struct TestEqnInit<M: Matrix> {
359 ctx: M::C,
360 }
361
362 impl<M: Matrix> Op for TestEqnInit<M> {
363 type T = M::T;
364 type V = M::V;
365 type M = M;
366 type C = M::C;
367
368 fn nout(&self) -> usize {
369 1
370 }
371 fn nparams(&self) -> usize {
372 1
373 }
374 fn nstates(&self) -> usize {
375 1
376 }
377 fn context(&self) -> &Self::C {
378 &self.ctx
379 }
380 }
381
382 impl<M: Matrix> ConstantOp for TestEqnInit<M> {
383 fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
384 y.fill(M::T::one());
385 }
386 }
387
388 impl<M: Matrix> ConstantOpSens for TestEqnInit<M> {
389 fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
390 sens.fill(M::T::zero());
391 }
392 }
393
394 pub struct TestEqnRhs<M: Matrix> {
395 ctx: M::C,
396 }
397
398 impl<M: Matrix> Op for TestEqnRhs<M> {
399 type T = M::T;
400 type V = M::V;
401 type M = M;
402 type C = M::C;
403
404 fn nout(&self) -> usize {
405 1
406 }
407 fn nparams(&self) -> usize {
408 1
409 }
410 fn nstates(&self) -> usize {
411 1
412 }
413 fn context(&self) -> &Self::C {
414 &self.ctx
415 }
416 }
417
418 impl<M: Matrix> NonLinearOp for TestEqnRhs<M> {
419 fn call_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::V) {
420 y.fill(M::T::zero());
421 }
422 }
423
424 impl<M: Matrix> NonLinearOpJacobian for TestEqnRhs<M> {
425 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, y: &mut Self::V) {
426 y.fill(M::T::zero());
427 }
428 }
429
430 impl<M: Matrix> NonLinearOpSens for TestEqnRhs<M> {
431 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
432 sens.fill(M::T::zero());
433 }
434 }
435
436 pub struct TestEqnOut<M: Matrix> {
437 ctx: M::C,
438 }
439
440 impl<M: Matrix> Op for TestEqnOut<M> {
441 type T = M::T;
442 type V = M::V;
443 type M = M;
444 type C = M::C;
445
446 fn nout(&self) -> usize {
447 1
448 }
449 fn nparams(&self) -> usize {
450 1
451 }
452 fn nstates(&self) -> usize {
453 1
454 }
455 fn context(&self) -> &Self::C {
456 &self.ctx
457 }
458 }
459
460 impl<M: Matrix> NonLinearOp for TestEqnOut<M> {
461 fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
462 y.copy_from(x);
463 }
464 }
465
466 impl<M: Matrix> NonLinearOpJacobian for TestEqnOut<M> {
467 fn jac_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
468 y.copy_from(v);
469 }
470 }
471
472 impl<M: Matrix> NonLinearOpSens for TestEqnOut<M> {
473 fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, sens: &mut Self::V) {
474 sens.fill(M::T::zero());
475 }
476 }
477
478 pub struct TestEqn<M: Matrix> {
479 rhs: Rc<TestEqnRhs<M>>,
480 init: Rc<TestEqnInit<M>>,
481 out: Rc<TestEqnOut<M>>,
482 ctx: M::C,
483 }
484
485 impl<M: Matrix> TestEqn<M> {
486 pub fn new() -> Self {
487 let ctx = M::C::default();
488 Self {
489 rhs: Rc::new(TestEqnRhs { ctx: ctx.clone() }),
490 init: Rc::new(TestEqnInit { ctx: ctx.clone() }),
491 out: Rc::new(TestEqnOut { ctx: ctx.clone() }),
492 ctx,
493 }
494 }
495 }
496
497 impl<M: Matrix> Op for TestEqn<M> {
498 type T = M::T;
499 type V = M::V;
500 type M = M;
501 type C = M::C;
502 fn nout(&self) -> usize {
503 1
504 }
505 fn nparams(&self) -> usize {
506 1
507 }
508 fn nstates(&self) -> usize {
509 1
510 }
511 fn statistics(&self) -> crate::op::OpStatistics {
512 OpStatistics::default()
513 }
514 fn context(&self) -> &Self::C {
515 &self.ctx
516 }
517 }
518
519 impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn<M> {
520 type Rhs = &'a TestEqnRhs<M>;
521 type Mass = ParameterisedOp<'a, UnitCallable<M>>;
522 type Root = ParameterisedOp<'a, UnitCallable<M>>;
523 type Init = &'a TestEqnInit<M>;
524 type Out = &'a TestEqnOut<M>;
525 }
526
527 impl<M: Matrix> OdeEquations for TestEqn<M> {
528 fn rhs(&self) -> &TestEqnRhs<M> {
529 &self.rhs
530 }
531
532 fn mass(&self) -> Option<<Self as OdeEquationsRef<'_>>::Mass> {
533 None
534 }
535
536 fn root(&self) -> Option<<Self as OdeEquationsRef<'_>>::Root> {
537 None
538 }
539
540 fn init(&self) -> &TestEqnInit<M> {
541 &self.init
542 }
543
544 fn out(&self) -> Option<<Self as OdeEquationsRef<'_>>::Out> {
545 Some(&self.out)
546 }
547 fn set_params(&mut self, _p: &Self::V) {
548 unimplemented!()
549 }
550 fn get_params(&self, _p: &mut Self::V) {
551 unimplemented!()
552 }
553 }
554
555 pub fn test_problem<M: Matrix>(integrate_out: bool) -> OdeSolverProblem<TestEqn<M>> {
556 let eqn = TestEqn::<M>::new();
557 let atol = eqn
558 .context()
559 .vector_from_element(1, M::T::from_f64(1e-6).unwrap());
560 OdeSolverProblem::new(
561 eqn,
562 M::T::from_f64(1e-6).unwrap(),
563 atol,
564 None,
565 None,
566 None,
567 None,
568 None,
569 None,
570 M::T::zero(),
571 M::T::one(),
572 integrate_out,
573 Default::default(),
574 Default::default(),
575 )
576 .unwrap()
577 }
578
    /// Exercises the interpolation API of a solver for `TestEqn`.
    ///
    /// The checks run in this order:
    /// 1. Interpolation before any step.
    /// 2. Interpolation within and beyond the first step.
    /// 3. Rejection of wrongly-sized output buffers.
    /// 4. Interpolation after the state has been overwritten through `state_mut`.
    ///
    /// What `interpolate_out` / `interpolate_sens` are expected to return depends on two
    /// things: whether the problem integrates the output (`integrate_out`), and whether
    /// sensitivities are being integrated (`state().s` non-empty).
    pub fn test_interpolate<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
        let state = s.checkpoint();
        let integrating_sens = !s.state().s.is_empty();
        let integrating_out = s.problem().integrate_out;
        let t0 = state.as_ref().t;
        // A time far ahead of anything the solver has reached.
        let t1 = t0 + M::T::from_f64(1e6).unwrap();
        // Before any step, interpolating at the current time reproduces the state...
        s.interpolate(t0)
            .unwrap()
            .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
        // ...but times ahead of the solver are rejected.
        assert!(s.interpolate(t1).is_err());
        assert!(s.interpolate_out(t1).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(t1).is_err());
        } else {
            // Without sensitivities, `interpolate_sens` is expected to succeed.
            // NOTE(review): this checks t0 rather than t1; t + t1 is covered below.
            assert!(s.interpolate_sens(t0).is_ok());
        }
        // After one step, both the new current time and the midpoint of the step are valid.
        s.step().unwrap();
        let tmid = t0 + (s.state().t - t0) / M::T::from_f64(2.0).unwrap();
        assert!(s.interpolate(s.state().t).is_ok());
        assert!(s.interpolate(tmid).is_ok());
        // Output interpolation is only available when the output is being integrated.
        if integrating_out {
            assert!(s.interpolate_out(s.state().t).is_ok());
        } else {
            assert!(s.interpolate_out(s.state().t).is_err());
        }
        assert!(s.interpolate_sens(s.state().t).is_ok());
        // Times beyond the current step are still rejected.
        assert!(s.interpolate(s.state().t + t1).is_err());
        assert!(s.interpolate_out(s.state().t + t1).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(s.state().t + t1).is_err());
        } else {
            assert!(s.interpolate_sens(s.state().t + t1).is_ok());
        }

        // Wrongly-sized buffers are rejected. TestEqn has one state and one output, so
        // length-2 buffers are wrong.
        let mut y_wrong_length = M::V::zeros(2, s.problem().context().clone());
        assert!(s
            .interpolate_inplace(s.state().t, &mut y_wrong_length)
            .is_err());
        let mut g_wrong_length = M::V::zeros(2, s.problem().context().clone());
        assert!(s
            .interpolate_out_inplace(s.state().t, &mut g_wrong_length)
            .is_err());
        // Two sensitivity vectors where TestEqn has a single parameter. This is expected
        // to fail whether or not sensitivities are integrated.
        let mut s_wrong_length = vec![
            M::V::zeros(1, s.problem().context().clone()),
            M::V::zeros(1, s.problem().context().clone()),
        ];
        assert!(s
            .interpolate_sens_inplace(s.state().t, &mut s_wrong_length)
            .is_err());
        // When integrating sensitivities, pass one vector of the wrong length, which must
        // be rejected. Otherwise pass an empty list, which must be accepted.
        let mut s_wrong_vec_length = if integrating_sens {
            vec![M::V::zeros(2, s.problem().context().clone())]
        } else {
            vec![]
        };
        if integrating_sens {
            assert!(s
                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
                .is_err());
        } else {
            assert!(s
                .interpolate_sens_inplace(s.state().t, &mut s_wrong_vec_length)
                .is_ok());
        }

        // Overwrite the state through `state_mut`. The current time can still be
        // interpolated, but `tmid` (inside the last step) now errors. Presumably
        // modifying the state invalidates the step's interpolation data; confirm against
        // the solver implementations.
        s.state_mut().y.fill(M::T::from_f64(3.0).unwrap());
        assert!(s.interpolate(s.state().t).is_ok());
        if integrating_out {
            assert!(s.interpolate_out(s.state().t).is_ok());
        }
        if integrating_sens {
            assert!(s.interpolate_sens(s.state().t).is_ok());
        }
        assert!(s.interpolate(tmid).is_err());
        assert!(s.interpolate_out(tmid).is_err());
        if integrating_sens {
            assert!(s.interpolate_sens(tmid).is_err());
        } else {
            assert!(s.interpolate_sens(tmid).is_ok());
        }
    }
659
660 pub fn test_config<'a, Eqn: OdeEquations + 'a, Method: OdeSolverMethod<'a, Eqn>>(
661 mut s: Method,
662 ) {
663 *s.config_mut().as_base_mut().minimum_timestep = Eqn::T::from_f64(1.0e8).unwrap();
664 assert_eq!(
665 *s.config().as_base_ref().minimum_timestep,
666 Eqn::T::from_f64(1.0e8).unwrap()
667 );
668 *s.state_mut().h = Eqn::T::from_f64(0.1).unwrap();
670
671 let mut failed = false;
672 for _ in 0..10 {
673 if let Err(DiffsolError::OdeSolverError(OdeSolverError::StepSizeTooSmall { time: _ })) =
674 s.step()
675 {
676 failed = true;
677 break;
678 }
679 }
680 assert!(failed);
681 }
682
683 pub fn test_state_mut<'a, M: Matrix, Method: OdeSolverMethod<'a, TestEqn<M>>>(mut s: Method) {
684 let state = s.checkpoint();
685 let state2 = s.state();
686 state2
687 .y
688 .assert_eq_st(state.as_ref().y, M::T::from_f64(1e-9).unwrap());
689 s.state_mut()
690 .y
691 .set_index(0, M::T::from_f64(std::f64::consts::PI).unwrap());
692 assert_eq!(
693 s.state_mut().y.get_index(0),
694 M::T::from_f64(std::f64::consts::PI).unwrap()
695 );
696 }
697
    /// Builds the bouncing-ball problem from DiffSL source.
    ///
    /// The states are the height `x` (starting at `h = 10.0`) and the velocity `v`
    /// (starting at 0). The right-hand side `F_i` gives `dx/dt = v` and `dv/dt = -g` with
    /// `g = 9.81`. The `stop` function is `x`, so a root is reported when the height
    /// crosses zero (the ball reaches the ground).
    #[cfg(feature = "diffsl-cranelift")]
    pub fn test_ball_bounce_problem<M: crate::MatrixHost<T = f64>>(
    ) -> OdeSolverProblem<crate::DiffSl<M, crate::CraneliftJitModule>> {
        crate::OdeBuilder::<M>::new()
            .build_from_diffsl(
                "
            g { 9.81 } h { 10.0 }
            u_i {
                x = h,
                v = 0,
            }
            F_i {
                v,
                -g,
            }
            stop {
                x,
            }
        ",
            )
            .unwrap()
    }
720
721 #[cfg(feature = "diffsl-cranelift")]
722 pub fn test_ball_bounce<'a, M, Method>(mut solver: Method) -> (Vec<f64>, Vec<f64>, Vec<f64>)
723 where
724 M: crate::MatrixHost<T = f64>,
725 M: DefaultSolver<T = f64>,
726 M::V: DefaultDenseMatrix<T = f64>,
727 Method: OdeSolverMethod<'a, crate::DiffSl<M, crate::CraneliftJitModule>>,
728 {
729 let e = 0.8;
730
731 let final_time = 2.5;
732
733 solver.set_stop_time(final_time).unwrap();
735 loop {
736 match solver.step() {
737 Ok(OdeSolverStopReason::InternalTimestep) => (),
738 Ok(OdeSolverStopReason::RootFound(t)) => {
739 let mut y = solver.interpolate(t).unwrap();
741
742 y.set_index(1, y.get_index(1) * -e);
744
745 y.set_index(0, y.get_index(0).max(f64::EPSILON));
747
748 solver.state_mut().y.copy_from(&y);
750 solver.state_mut().dy.set_index(0, y.get_index(1));
751 *solver.state_mut().t = t;
752
753 break;
754 }
755 Ok(OdeSolverStopReason::TstopReached) => break,
756 Err(_) => panic!("unexpected solver error"),
757 }
758 }
759 let mut x = vec![];
761 let mut v = vec![];
762 let mut t = vec![];
763 for _ in 0..3 {
764 let ret = solver.step();
765 x.push(solver.state().y.get_index(0));
766 v.push(solver.state().y.get_index(1));
767 t.push(solver.state().t);
768 match ret {
769 Ok(OdeSolverStopReason::InternalTimestep) => (),
770 Ok(OdeSolverStopReason::RootFound(_)) => {
771 panic!("should be an internal timestep but found a root")
772 }
773 Ok(OdeSolverStopReason::TstopReached) => break,
774 _ => panic!("should be an internal timestep"),
775 }
776 }
777 (x, v, t)
778 }
779
780 pub fn test_checkpointing<'a, M, Method, Eqn>(
781 soln: OdeSolverSolution<M::V>,
782 mut solver1: Method,
783 mut solver2: Method,
784 ) where
785 M: Matrix + DefaultSolver,
786 Method: OdeSolverMethod<'a, Eqn>,
787 Eqn: OdeEquationsImplicit<M = M, T = M::T, V = M::V> + 'a,
788 {
789 let half_i = soln.solution_points.len() / 2;
790 let half_t = soln.solution_points[half_i].t;
791 while solver1.state().t <= half_t {
792 solver1.step().unwrap();
793 }
794 let checkpoint = solver1.checkpoint();
795 let checkpoint_t = checkpoint.as_ref().t;
796 solver2.set_state(checkpoint);
797
798 for point in soln.solution_points.iter().skip(half_i + 1) {
800 if point.t < checkpoint_t {
802 continue;
803 }
804 while solver2.state().t < point.t {
805 solver1.step().unwrap();
806 solver2.step().unwrap();
807 let time_error = (solver1.state().t - solver2.state().t).abs()
808 / (solver1.state().t.abs() * solver1.problem().rtol
809 + solver1.problem().atol.get_index(0));
810 assert!(
811 time_error < M::T::from_f64(20.0).unwrap(),
812 "time_error: {} at t = {}",
813 time_error,
814 solver1.state().t
815 );
816 solver1.state().y.assert_eq_norm(
817 solver2.state().y,
818 &solver1.problem().atol,
819 solver1.problem().rtol,
820 M::T::from_f64(20.0).unwrap(),
821 );
822 }
823 let soln = solver1.interpolate(point.t).unwrap();
824 soln.assert_eq_norm(
825 &point.state,
826 &solver1.problem().atol,
827 solver1.problem().rtol,
828 M::T::from_f64(15.0).unwrap(),
829 );
830 let soln = solver2.interpolate(point.t).unwrap();
831 soln.assert_eq_norm(
832 &point.state,
833 &solver1.problem().atol,
834 solver1.problem().rtol,
835 M::T::from_f64(15.0).unwrap(),
836 );
837 }
838 }
839
840 pub fn test_state_mut_on_problem<'a, Eqn, Method>(
841 mut s: Method,
842 soln: OdeSolverSolution<Eqn::V>,
843 ) where
844 Eqn: OdeEquationsImplicit + 'a,
845 Method: OdeSolverMethod<'a, Eqn>,
846 Eqn::V: DefaultDenseMatrix,
847 {
848 let state = s.checkpoint();
850 s.solve(Eqn::T::one()).unwrap();
851
852 s.state_mut().y.copy_from(state.as_ref().y);
854 s.state_mut().dy.copy_from(state.as_ref().dy);
855 *s.state_mut().t = state.as_ref().t;
856
857 for point in soln.solution_points.iter() {
859 while s.state().t < point.t {
860 s.step().unwrap();
861 }
862 let soln = s.interpolate(point.t).unwrap();
863 let error = soln.clone() - &point.state;
864 let error_norm = error
865 .squared_norm(&error, &s.problem().atol, s.problem().rtol)
866 .sqrt();
867 assert!(
868 error_norm < Eqn::T::from_f64(19.0).unwrap(),
869 "error_norm: {} at t = {}",
870 error_norm,
871 point.t
872 );
873 }
874 }
875}