bulirsch/lib.rs
1//! Implementation of the Bulirsch-Stoer method for stepping ordinary differential equations.
2//!
3//! The [(Gragg-)Bulirsch-Stoer](https://en.wikipedia.org/wiki/Bulirsch%E2%80%93Stoer_algorithm)
4//! algorithm combines the (modified) midpoint method with Richardson extrapolation to accelerate
5//! convergence. It is an explicit method that does not require Jacobians.
6//!
7//! This crate's implementation contains simplistic adaptive step size routines with order
8//! estimation. Its API is designed to be useful in situations where an ODE is being integrated step
9//! by step with a prescribed time step, for example in integrated simulations of electromechanical
10//! control systems with a fixed control cycle period. Only time-independent ODEs are supported, but
11//! without loss of generality (since the state vector can be augmented with a time variable if
12//! needed).
13//!
14//! The implementation follows:
15//! * Press, William H. Numerical Recipes 3rd Edition: The Art of Scientific Computing. Cambridge
16//! University Press, 2007. Ch. 17.3.2.
17//! * Deuflhard, Peter. "Order and stepsize control in extrapolation methods." Numerische Mathematik
18//! 41 (1983): 399-422.
19//!
20//! As an example, consider a simple oscillator system:
21//!
22//! ```
23//! // Define ODE.
24//! struct OscillatorSystem {
25//! omega: f64,
26//! }
27//!
28//! impl bulirsch::System for OscillatorSystem {
29//! type Float = f64;
30//!
31//! fn system(
32//! &self,
33//! y: bulirsch::ArrayView1<Self::Float>,
34//! mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
35//! ) {
36//! dydt[[0]] = y[[1]];
37//! dydt[[1]] = -self.omega.powi(2) * y[[0]];
38//! }
39//! }
40//!
41//! let system = OscillatorSystem { omega: 1.2 };
42//!
43//! // Set up the integrator.
44//! let mut integrator = bulirsch::Integrator::default()
45//! .with_abs_tol(1e-8)
46//! .with_rel_tol(1e-8)
47//! .into_adaptive();
48//!
49//! // Define initial conditions and provide solution storage.
50//! let delta_t: f64 = 10.2;
51//! let mut y = ndarray::array![1., 0.];
52//! let mut y_next = ndarray::Array::zeros(y.raw_dim());
53//!
54//! // Integrate for 10 steps.
55//! let num_steps = 10;
56//! for _ in 0..num_steps {
57//! integrator
58//! .step(&system, delta_t, y.view(), y_next.view_mut())
59//! .unwrap();
60//! y.assign(&y_next);
61//! }
62//!
63//! // Ensure result matches analytic solution.
64//! approx::assert_relative_eq!(
65//! (system.omega * delta_t * num_steps as f64).cos(),
66//! y_next[[0]],
67//! epsilon = 5e-7,
68//! max_relative = 5e-7,
69//! );
70//!
71//! // Check integration performance.
72//! assert_eq!(integrator.overall_stats().num_system_evals, 3843);
73//! approx::assert_relative_eq!(integrator.step_size().unwrap(), 2.14, epsilon = 1e-2);
74//! ```
75//!
//! Note that roughly 3.8k system evaluations have been used. By contrast, the
//! `ode_solvers::Dopri5` algorithm uses almost twice as many:
78//!
79//! ```
80//! struct OscillatorSystem {
81//! omega: f64,
82//! }
83//!
84//! impl ode_solvers::System<f64, ode_solvers::SVector<f64, 2>> for OscillatorSystem {
85//! fn system(
86//! &self,
87//! _x: f64,
88//! y: &ode_solvers::SVector<f64, 2>,
89//! dy: &mut ode_solvers::SVector<f64, 2>,
90//! ) {
91//! dy[0] = y[1];
92//! dy[1] = -self.omega.powi(2) * y[0];
93//! }
94//! }
95//!
96//! let omega = 1.2;
97//! let delta_t: f64 = 10.2;
98//! let mut num_system_eval = 0;
99//! let mut y = ode_solvers::Vector2::new(1., 0.);
100//! let num_steps = 10;
101//! for _ in 0..num_steps {
102//! let system = OscillatorSystem { omega };
103//! let mut solver = ode_solvers::Dopri5::new(
104//! system,
105//! 0.,
106//! delta_t,
107//! delta_t,
108//! y,
109//! 1e-8,
110//! 1e-8,
111//! );
112//! num_system_eval += solver.integrate().unwrap().num_eval;
113//! y = *solver.y_out().get(1).unwrap();
114//! }
115//! assert_eq!(num_system_eval, 7476);
116//!
117//! // Ensure result matches analytic solution.
118//! approx::assert_relative_eq!(
119//! (omega * delta_t * num_steps as f64).cos(),
120//! y[0],
121//! epsilon = 5e-7,
122//! max_relative = 5e-7,
123//! );
124//! ```
125//!
126//! As of writing this, the latest version of `ode_solvers`, 0.6.1, also gives a dramatically
127//! incorrect answer likely due to a regression. As a result we use version 0.5 as a dev dependency.
128
129#![expect(
130 non_snake_case,
131 reason = "Used for math symbols to match notation in Numerical Recipes"
132)]
133
134pub use nd::ArrayView1;
135pub use nd::ArrayViewMut1;
136use ndarray as nd;
137
/// Floating-point scalar trait required by the integrator.
///
/// Bundles the numeric operations ([`num_traits::Float`]), summation, in-place arithmetic, and
/// debug formatting the Bulirsch-Stoer routines rely on, plus [`nd::ScalarOperand`] so scalars
/// can be combined elementwise with `ndarray` arrays.
pub trait Float:
    num_traits::Float
    + core::iter::Sum
    + core::ops::AddAssign
    + core::ops::MulAssign
    + core::fmt::Debug
    + nd::ScalarOperand
{
}

// Blanket implementations for the two primitive float types.
impl Float for f32 {}
impl Float for f64 {}
150
/// Trait for defining an ordinary differential equation system.
///
/// Only time-independent (autonomous) systems are expressible; augment the state vector with a
/// time variable if explicit time dependence is needed (see the crate-level docs).
pub trait System {
    /// The floating point type.
    type Float: Float;

    /// Evaluate the ordinary differential equation and store the derivative in `dydt`.
    ///
    /// `dydt` has the same length as `y`; implementations should write every element.
    fn system(&self, y: ArrayView1<Self::Float>, dydt: ArrayViewMut1<Self::Float>);
}
159
160/// Error generated when integration produced a step size smaller than the minimum allowed step
161/// size.
162#[derive(Debug)]
163pub struct StepSizeUnderflow<F: Float>(F);
164
/// Statistics from taking an integration step.
#[derive(Clone, Debug)]
pub struct Stats {
    /// Number of system function evaluations.
    ///
    /// Every call to [`System::system`] counts once, including evaluations spent on rejected
    /// (re-tried) internal steps.
    pub num_system_evals: usize,
}
171
/// An ODE integrator using the Bulirsch-Stoer algorithm with an adaptive step size and adaptive
/// order.
///
/// Should be constructed using [`Integrator::into_adaptive`].
#[derive(Clone)]
pub struct AdaptiveIntegrator<F: Float> {
    /// The underlying non-adaptive integrator (holds the tolerance settings).
    integrator: Integrator<F>,

    /// The current step size.
    ///
    /// `None` until the first call to [`AdaptiveIntegrator::step`], which seeds it with the
    /// requested `delta_t`; afterwards it carries over between calls.
    step_size: Option<F>,
    /// The minimum step size to allow before returning [`StepSizeUnderflow`].
    min_step_size: F,
    /// The maximum step size to allow. `None` means unbounded.
    max_step_size: Option<F>,

    /// The current estimated target number of iterations to use.
    target_order: usize,
    /// The maximum number of iterations to use.
    max_order: usize,

    /// Overall stats accumulated across all steps taken so far.
    overall_stats: Stats,
}
196
impl<F: Float> AdaptiveIntegrator<F> {
    /// Take a step using the Bulirsch-Stoer method.
    ///
    /// # Arguments
    ///
    /// * `system`: The ODE system.
    /// * `delta_t`: The size of the prescribed time step to take.
    /// * `y_init`: The initial state vector at the start of the time step.
    /// * `y_final`: The vector into which to store the final computed state at the end of the time
    ///   step.
    ///
    /// # Result
    ///
    /// Stats providing information about integration performance, or an error if integration
    /// failed.
    ///
    /// # Examples
    ///
    /// Note that if you're using e.g. `nalgebra` to define your ODE, you can bridge to [`ndarray`]
    /// vectors using slices, as long as you're using `nalgebra`'s dynamically sized vectors. The
    /// same applies to using [`Vec`]s, etc. For example, consider a simple oscillator system
    /// defined using `nalgebra`:
    ///
    /// ```
    /// // Define oscillator ODE.
    /// #[derive(Clone, Copy)]
    /// struct OscillatorSystem {
    ///     omega: f32,
    /// }
    ///
    /// fn compute_dydt(
    ///     omega: f32,
    ///     y: nalgebra::DVectorView<f32>,
    ///     mut dydt: nalgebra::DVectorViewMut<f32>,
    /// ) {
    ///     dydt[0] = y[1];
    ///     dydt[1] = -omega.powi(2) * y[0];
    /// }
    ///
    /// impl bulirsch::System for OscillatorSystem {
    ///     type Float = f32;
    ///
    ///     fn system(
    ///         &self,
    ///         y: bulirsch::ArrayView1<Self::Float>,
    ///         mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
    ///     ) {
    ///         let y_nalgebra = nalgebra::DVectorView::from_slice(
    ///             y.as_slice().unwrap(),
    ///             y.len(),
    ///         );
    ///         let dydt_nalgebra = nalgebra::DVectorViewMut::from_slice(
    ///             dydt.as_slice_mut().unwrap(),
    ///             y.len(),
    ///         );
    ///         compute_dydt(self.omega, y_nalgebra, dydt_nalgebra);
    ///     }
    /// }
    ///
    /// // Instantiate system and integrator.
    /// let system = OscillatorSystem { omega: 1.2 };
    /// let mut integrator =
    ///     bulirsch::Integrator::default()
    ///         .with_abs_tol(1e-6)
    ///         .with_rel_tol(0.)
    ///         .into_adaptive();
    ///
    /// // Define initial conditions and integrate.
    /// let mut y = ndarray::array![1., 0.];
    /// let mut y_next = ndarray::Array1::zeros(y.raw_dim());
    /// let delta_t = 0.6;
    /// let num_steps = 10;
    /// let mut num_system_evals = 0;
    /// for _ in 0..num_steps {
    ///     num_system_evals += integrator
    ///         .step(&system, delta_t, y.view(), y_next.view_mut())
    ///         .unwrap()
    ///         .num_system_evals;
    ///     y.assign(&y_next);
    /// }
    ///
    /// // Check against analytic solution.
    /// let (sin, cos) = (delta_t * num_steps as f32 * system.omega).sin_cos();
    /// approx::assert_relative_eq!(y_next[0], cos, epsilon = 1e-2);
    /// approx::assert_relative_eq!(
    ///     y_next[1],
    ///     -system.omega * sin,
    ///     epsilon = 1e-2
    /// );
    ///
    /// // Check integrator performance.
    /// assert_eq!(num_system_evals, 310);
    /// ```
    pub fn step<S: System<Float = F>>(
        &mut self,
        system: &S,
        delta_t: S::Float,
        y_init: nd::ArrayView1<S::Float>,
        mut y_final: nd::ArrayViewMut1<S::Float>,
    ) -> Result<Stats, StepSizeUnderflow<F>> {
        // Seed the step size with the full prescribed step on the very first call; afterwards,
        // reuse the step size adapted during previous calls.
        let mut step_size = if let Some(step_size) = self.step_size {
            step_size
        } else {
            delta_t
        };

        // Wrap the system so every derivative evaluation is counted.
        let mut system = SystemEvaluationCounter {
            system,
            num_system_evals: 0,
        };

        // Iteratively take steps until taking a step would put us past the input `delta_t`. At that
        // point, take an exact step to finish `delta_t`. Dynamically adjust the step size to
        // control truncation error as we go.
        let mut y_before_step = y_init.to_owned();
        let mut y_after_step = y_init.to_owned();
        let mut t = F::zero();
        loop {
            // NaN/infinite step sizes also land here since `is_finite` fails for them.
            if step_size < self.min_step_size || !step_size.is_finite() {
                return Err(StepSizeUnderflow(step_size));
            }

            // We set `next_t` to `None` if we're at the tail end of `delta_t` and are taking a
            // smaller step than is optimal so we don't overshoot.
            let next_t = if t < delta_t - step_size {
                Some((t + step_size).min(delta_t))
            } else {
                None
            };
            // Clamp the step so we never integrate past `delta_t`.
            step_size = step_size.min(delta_t - t);

            let extrapolation_result = self.integrator.extrapolate(
                &mut system,
                step_size,
                self.target_order,
                y_before_step.view(),
                y_after_step.view_mut(),
            );

            match (extrapolation_result.converged(), next_t) {
                // The step was successful, and we're at the end of `delta_t`. Done.
                (true, None) => {
                    self.perform_step_size_control(&extrapolation_result, &mut step_size);
                    break;
                }
                // The step was successful, and we're not at the end of `delta_t`. Potentially
                // adjust `target_order`, adjust step size, and continue.
                (true, Some(next_t)) => {
                    self.perform_order_and_step_size_control(&extrapolation_result, &mut step_size);
                    t = next_t;
                    y_before_step.assign(&y_after_step);
                }
                // The step failed. Adjust step size, but for simplicity, unlike Numerical Recipes,
                // don't try to adjust order. Try again.
                (false, _) => {
                    self.perform_step_size_control(&extrapolation_result, &mut step_size);
                }
            }
        }

        // Persist the adapted step size for the next call and write out the result.
        self.step_size = Some(step_size);
        y_final.assign(&y_after_step);
        self.overall_stats.num_system_evals += system.num_system_evals;

        Ok(Stats {
            num_system_evals: system.num_system_evals,
        })
    }

    /// Set the minimum step size to allow before returning [`StepSizeUnderflow`].
    pub fn with_min_step_size(self, min_step_size: F) -> Self {
        Self {
            min_step_size,
            ..self
        }
    }
    /// Set the maximum step size to allow; `None` leaves the step size unbounded.
    pub fn with_max_step_size(self, max_step_size: Option<F>) -> Self {
        Self {
            max_step_size,
            ..self
        }
    }
    /// Set the maximum "order" to use, i.e. max number of iterations per extrapolation.
    pub fn with_max_order(self, max_order: usize) -> Self {
        Self { max_order, ..self }
    }

    /// Get overall stats across all steps taken so far.
    pub fn overall_stats(&self) -> &Stats {
        &self.overall_stats
    }
    /// Get the current step size, or `None` if no step has been taken yet.
    pub fn step_size(&self) -> Option<F> {
        self.step_size
    }
    /// Get the current target order.
    pub fn target_order(&self) -> usize {
        self.target_order
    }

    /// Compute the multiplicative factor by which to scale the step size, based on the scaled
    /// truncation error observed at `target_order`.
    fn compute_step_size_adjustment_factor(
        extrapolation_result: &ExtrapolationResult<F>,
        target_order: usize,
    ) -> F {
        let scaled_truncation_error = *extrapolation_result
            .scaled_truncation_errors
            .get(target_order)
            .unwrap();

        let safety_factor: F = cast(0.9);
        // Clamp the adjustment to [0.01, 100] to avoid wild step size swings.
        let min_step_size_decrease_factor: F = cast(0.01);
        let max_step_size_increase_factor = min_step_size_decrease_factor.recip();

        if scaled_truncation_error > F::zero() {
            // Eq. 2.14, Deuflhard.
            (safety_factor / scaled_truncation_error.powf(F::one() / cast(2 * target_order + 1)))
                .max(min_step_size_decrease_factor)
                .min(max_step_size_increase_factor)
        } else if scaled_truncation_error == F::zero() {
            // Exactly zero truncation error: grow the step.
            cast(2)
        } else {
            // Handle NaNs (all comparisons with NaN are false, so we land here): shrink the step.
            cast(0.5)
        }
    }

    /// Adjust only the step size (used after failed steps and after the final step of `delta_t`).
    fn perform_step_size_control(
        &self,
        extrapolation_result: &ExtrapolationResult<F>,
        step_size: &mut F,
    ) {
        let adjustment_factor =
            Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);
        *step_size *= adjustment_factor;

        // Clamp to the configured maximum, if any.
        if let Some(max_step_size) = self.max_step_size {
            *step_size = step_size.min(max_step_size);
        }
    }

    /// Adjust both the target order and the step size after a successful intermediate step.
    fn perform_order_and_step_size_control(
        &mut self,
        extrapolation_result: &ExtrapolationResult<F>,
        step_size: &mut F,
    ) {
        let adjustment_factor =
            Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);

        // This follows eqs. 17.3.14 & 17.3.15 in Numerical Recipes.
        if self.target_order > 0 {
            let adjustment_factor_lower_order = Self::compute_step_size_adjustment_factor(
                &extrapolation_result,
                self.target_order - 1,
            );

            // Compare the estimated work (system evaluations) per unit of integrated time at the
            // current order versus one order lower.
            let work = cast::<_, F>(compute_work(self.target_order));
            let work_per_step = work / *step_size / adjustment_factor;
            let work_lower_order = cast::<_, F>(compute_work(self.target_order - 1));
            let work_per_step_lower_order =
                work_lower_order / *step_size / adjustment_factor_lower_order;

            self.target_order = if work_per_step_lower_order < cast::<_, F>(0.8) * work_per_step
                && self.target_order > 1
            {
                // Decrease order since a lower order requires less work.
                *step_size *= adjustment_factor_lower_order;
                self.target_order - 1
            } else if work_per_step < cast::<_, F>(0.95) * work_per_step_lower_order
                && self.target_order + 1 <= self.max_order
            {
                // Increase order since a higher order is heuristically indicated to require less
                // work (even though we didn't extrapolate to this order, so can't tell for sure).
                // We use 0.95 above instead of 0.9 from Numerical Recipes since it produced better
                // performance on the tests.
                let work_higher_order = cast::<_, F>(compute_work(self.target_order + 1));
                *step_size *= adjustment_factor * work_higher_order / work;
                self.target_order + 1
            } else {
                // Preserve order and only adjust step size.
                *step_size *= adjustment_factor;
                self.target_order
            };
        } else {
            *step_size *= adjustment_factor;
        }

        // Clamp to the configured maximum, if any.
        if let Some(max_step_size) = self.max_step_size {
            *step_size = step_size.min(max_step_size);
        }
    }
}
489
/// An ODE integrator using the Bulirsch-Stoer algorithm with a fixed step size.
///
/// Used to construct an [`AdaptiveIntegrator`].
#[derive(Clone)]
pub struct Integrator<F: Float> {
    /// The absolute tolerance.
    abs_tol: F,
    /// The relative tolerance.
    rel_tol: F,
}
500
501impl<F: Float> Default for Integrator<F> {
502 fn default() -> Self {
503 Self {
504 abs_tol: cast(1e-6),
505 rel_tol: cast(1e-6),
506 }
507 }
508}
509
impl<F: Float> Integrator<F> {
    /// Make an [`AdaptiveIntegrator`] wrapping this integrator, with default adaptivity settings
    /// (initial target order 3, max order 10, minimum step size `1e-9`, unbounded maximum).
    pub fn into_adaptive(self) -> AdaptiveIntegrator<F> {
        AdaptiveIntegrator {
            integrator: self,
            step_size: None,
            min_step_size: cast(1e-9),
            max_step_size: None,
            target_order: 3,
            max_order: 10,
            overall_stats: Stats {
                num_system_evals: 0,
            },
        }
    }

    /// Set the absolute tolerance.
    pub fn with_abs_tol(self, abs_tol: F) -> Self {
        Self { abs_tol, ..self }
    }
    /// Set the relative tolerance.
    pub fn with_rel_tol(self, rel_tol: F) -> Self {
        Self { rel_tol, ..self }
    }

    /// Take a single extrapolating step, iteratively subdividing `step_size`.
    ///
    /// Builds tableau rows `0..=order + 1` so that a scaled truncation error is available for
    /// every order up to and including `order` (the first row carries no error estimate).
    fn extrapolate<S: System<Float = F>>(
        &self,
        system: &mut SystemEvaluationCounter<S>,
        step_size: F,
        order: usize,
        y_init: nd::ArrayView1<F>,
        mut y_final: nd::ArrayViewMut1<F>,
    ) -> ExtrapolationResult<F> {
        // The derivative at the start point is shared by all midpoint sequences below.
        let f_init = {
            let mut f_init = nd::Array1::zeros(y_init.raw_dim());
            system.system(y_init, f_init.view_mut());
            f_init
        };

        // Build up an extrapolation tableau.
        let mut tableau = ExtrapolationTableau(Vec::<ExtrapolationTableauRow<_>>::new());
        for k in 0..=order + 1 {
            let nk = compute_n(k);
            let tableau_row = {
                let mut Tk = Vec::with_capacity(k + 1);
                // First entry: modified midpoint result with `nk` substeps.
                Tk.push(self.midpoint_step(system, step_size, nk, &f_init, y_init));
                // Remaining entries: Richardson extrapolation against the previous row.
                for j in 0..k {
                    // There is a mistake in Numerical Recipes eq. 17.3.8. See
                    // https://www.numerical.recipes/forumarchive/index.php/t-2256.html.
                    let denominator = <F as num_traits::Float>::powi(
                        cast::<_, F>(nk) / cast(compute_n(k - j - 1)),
                        2,
                    ) - <F as num_traits::One>::one();
                    Tk.push(&Tk[j] + (&Tk[j] - &tableau.0[k - 1].0[j]) / denominator);
                }
                ExtrapolationTableauRow(Tk)
            };
            tableau.0.push(tableau_row);
        }

        // The final entry of the last row is the highest-order estimate.
        y_final.assign(&tableau.0.last().unwrap().estimate());
        return ExtrapolationResult {
            scaled_truncation_errors: tableau
                .compute_scaled_truncation_errors(self.abs_tol, self.rel_tol),
        };
    }

    /// Advance `y_init` by `step_size` using the modified midpoint method with `n` substeps.
    ///
    /// `f_init` must be the system derivative at `y_init`, computed once by the caller and
    /// shared across all rows of the tableau.
    fn midpoint_step<S: System<Float = F>>(
        &self,
        evaluation_counter: &mut SystemEvaluationCounter<S>,
        step_size: F,
        n: usize,
        f_init: &nd::Array1<F>,
        y_init: nd::ArrayView1<F>,
    ) -> nd::Array1<F> {
        let substep_size = step_size / cast(n);
        let two_substep_size = cast::<_, F>(2) * substep_size;

        // Leapfrog over the substeps; `zi` and `zip1` alternate roles as z_i and z_{i+1}:
        // 0 1 2 3 4 5 6 ... n
        // zi zip1
        //    zip1 zi
        //         zi zip1
        //              ...
        //               zi zip1
        let mut zi = y_init.to_owned();
        let mut zip1 = &zi + f_init * substep_size;
        let mut fi = f_init.clone();

        for _i in 1..n {
            core::mem::swap(&mut zi, &mut zip1);
            evaluation_counter.system(zi.view(), fi.view_mut());
            // z_{i+1} = z_{i-1} + 2h f(z_i); `fi` is scaled in place to avoid an allocation.
            fi *= two_substep_size;
            zip1 += &fi;
        }

        // Final smoothing step: 0.5 * (z_{n-1} + z_n + h f(z_n)).
        evaluation_counter.system(zip1.view(), fi.view_mut());
        fi *= substep_size;
        let mut result = zi;
        result += &zip1;
        result += &fi;
        result *= cast::<_, F>(0.5);
        result
    }
}
616
/// The outcome of a single extrapolation attempt.
#[derive(Debug)]
struct ExtrapolationResult<F: Float> {
    /// The scaled (including absolute and relative tolerances) truncation errors for each
    /// iteration.
    ///
    /// Values below 1 indicate the corresponding order met the tolerance; see
    /// [`ExtrapolationResult::converged`].
    scaled_truncation_errors: Vec<F>,
}
626
627impl<F: Float> ExtrapolationResult<F> {
628 fn converged(&self) -> bool {
629 *self.scaled_truncation_errors.last().unwrap() < F::one()
630 }
631}
632
/// Wrapper around a [`System`] that counts how many times the system function is evaluated.
struct SystemEvaluationCounter<'a, S: System> {
    /// The wrapped ODE system.
    system: &'a S,
    /// Number of calls to [`SystemEvaluationCounter::system`] made so far.
    num_system_evals: usize,
}
637
638impl<'a, S: System> SystemEvaluationCounter<'a, S> {
639 fn system(&mut self, y: nd::ArrayView1<S::Float>, dydt: nd::ArrayViewMut1<S::Float>) {
640 self.num_system_evals += 1;
641 <S as System>::system(&self.system, y, dydt);
642 }
643}
644
645struct ExtrapolationTableau<F: Float>(Vec<ExtrapolationTableauRow<F>>);
646
647impl<F: Float> ExtrapolationTableau<F> {
648 fn compute_scaled_truncation_errors(&self, abs_tol: F, rel_tol: F) -> Vec<F> {
649 self.0
650 .iter()
651 .skip(1)
652 .map(|row| row.compute_scaled_truncation_error(abs_tol, rel_tol))
653 .collect()
654 }
655}
656
657struct ExtrapolationTableauRow<F: Float>(Vec<nd::Array1<F>>);
658
659impl<F: Float> ExtrapolationTableauRow<F> {
660 fn compute_scaled_truncation_error(&self, abs_tol: F, rel_tol: F) -> F {
661 let extrap_pair = self.0.last_chunk::<2>().unwrap();
662 let y = &extrap_pair[0];
663 let y_alt = &extrap_pair[1];
664 (y.iter()
665 .zip(y_alt.iter())
666 .map(|(&yi, &yi_alt)| {
667 let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
668 (yi - yi_alt).powi(2) / scale.powi(2)
669 })
670 .sum::<F>()
671 / cast(y.len()))
672 .sqrt()
673 }
674
675 fn estimate(&self) -> &nd::Array1<F> {
676 self.0.last().unwrap()
677 }
678}
679
/// Step size policy.
///
/// We use a simple linear policy based on the results in Deuflhard: iteration `k` performs
/// `2, 4, 6, ...` midpoint substeps.
fn compute_n(iteration: usize) -> usize {
    2 * iteration + 2
}
686
/// Cumulative sum of `compute_n`.
///
/// The amount of system function evaluations required to extrapolate to a given order; closed
/// form of `compute_n(0) + compute_n(1) + ... + compute_n(iteration)`.
fn compute_work(iteration: usize) -> usize {
    (iteration + 1) * (iteration + 2)
}
693
694fn cast<T: num_traits::NumCast, F: Float>(num: T) -> F {
695 num_traits::cast(num).unwrap()
696}
697
#[cfg(test)]
mod tests {
    use super::*;

    /// Test that the computation of "work" (i.e. number of system evaluations) is correct.
    #[test]
    fn test_compute_work() {
        for iteration in 0..5 {
            assert_eq!(
                compute_work(iteration),
                (0..=iteration).map(compute_n).sum()
            );
        }
    }

    /// A one-dimensional system with analytic solution `y(t) = y(0) * exp(t)`.
    struct ExpSystem {}

    impl System for ExpSystem {
        type Float = f64;

        fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
            dydt.assign(&y);
        }
    }

    /// Ensure we can solve an exponential system to high precision.
    #[test]
    fn test_exp_system_high_precision() {
        let system = ExpSystem {};

        // Set up integrator with tolerance parameters.
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive();

        // Define initial conditions and provide solution storage.
        let t_final = 3.5;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);

        // Integrate.
        let stats = integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();

        // Ensure result matches analytic solution to high precision.
        approx::assert_relative_eq!(t_final.exp(), y_final[[0]], max_relative = 5e-13);

        // Check integration performance.
        assert_eq!(stats.num_system_evals, 437);
        approx::assert_relative_eq!(integrator.step_size().unwrap(), 0.28, epsilon = 1e-2);
    }

    /// Ensure the algorithm works even when the max order is smaller than optimal.
    #[test]
    fn test_exp_system_low_max_order() {
        let system = ExpSystem {};

        // Set up integrator with tolerance parameters.
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive()
            .with_max_order(1);

        // Define initial conditions and provide solution storage.
        let t_final = 3.5;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);

        // Integrate.
        integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();

        // Ensure result matches analytic solution to high precision.
        approx::assert_relative_eq!(t_final.exp(), y_final[[0]], max_relative = 5e-13);
    }

    /// Ensure the algorithm can handle NaNs.
    #[test]
    fn test_exp_system_handle_nans() {
        // A decaying exponential system that returns NaN derivatives far from the origin and
        // records whether that ever happened.
        struct ExpSystemWithNans {
            hit_a_nan: core::cell::RefCell<bool>,
        }

        impl System for ExpSystemWithNans {
            type Float = f64;

            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                if y[0].abs() > 10. {
                    *self.hit_a_nan.borrow_mut() = true;
                    // Associated constant; `core::f64::NAN` is deprecated.
                    dydt[0] = f64::NAN;
                } else {
                    dydt.assign(&(-&y));
                }
            }
        }

        let system = ExpSystemWithNans {
            hit_a_nan: false.into(),
        };

        // Set up integrator with tolerance parameters.
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-10)
            .into_adaptive();

        // Define initial conditions and provide solution storage.
        let t_final = 20.;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);

        // Integrate.
        let stats = integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();

        // Ensure result matches analytic solution.
        approx::assert_relative_eq!((-t_final).exp(), y_final[[0]], max_relative = 1e-8);

        // Ensure we hit at least one NaN.
        assert!(*system.hit_a_nan.borrow());

        assert_eq!(stats.num_system_evals, 1085);
    }

    /// This is for interactive debugging as it has no asserts.
    #[test]
    fn test_varying_timescale() {
        struct SharpPendulumSystem {}

        impl System for SharpPendulumSystem {
            type Float = f64;

            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                dydt[[0]] = y[[1]];
                dydt[[1]] = -30. * y[[0]].sin().powi(31);
            }
        }

        let system = SharpPendulumSystem {};

        let mut integrator = Integrator::default().into_adaptive();

        let delta_t = 10.;
        let num_steps = 100;
        let mut y = ndarray::array![1., 0.];
        let mut y_final = ndarray::Array::zeros(y.raw_dim());

        for _ in 0..num_steps {
            integrator
                .step(&system, delta_t, y.view(), y_final.view_mut())
                .unwrap();
            y.assign(&y_final);
            println!(
                "order: {} step_size: {} y: {y}",
                integrator.target_order(),
                integrator.step_size().unwrap()
            );
        }
    }
}