bulirsch/lib.rs
1//! Implementation of the Bulirsch-Stoer method for stepping ordinary differential equations.
2//!
3//! The [(Gragg-)Bulirsch-Stoer](https://en.wikipedia.org/wiki/Bulirsch%E2%80%93Stoer_algorithm)
4//! algorithm combines the (modified) midpoint method with Richardson extrapolation to accelerate
5//! convergence. It is an explicit method that does not require Jacobians.
6//!
7//! This crate's implementation contains simplistic adaptive step size routines with order
8//! estimation. Its API is designed to be useful in situations where an ODE is being integrated step
9//! by step with a prescribed time step, for example in integrated simulations of electromechanical
//! control systems with a fixed control cycle period. Only time-independent (autonomous) ODEs are
//! supported; this entails no loss of generality, since the state vector can be augmented with a
//! time variable if needed.
13//!
14//! The implementation follows:
15//! * Press, William H. Numerical Recipes 3rd Edition: The Art of Scientific Computing. Cambridge
16//! University Press, 2007. Ch. 17.3.2.
17//! * Deuflhard, Peter. "Order and stepsize control in extrapolation methods." Numerische Mathematik
18//! 41 (1983): 399-422.
19//!
20//! As an example, consider a simple oscillator system:
21//!
22//! ```
23//! // Define ODE.
24//! struct OscillatorSystem {
25//! omega: f64,
26//! }
27//!
28//! impl bulirsch::System for OscillatorSystem {
29//! type Float = f64;
30//!
31//! fn system(
32//! &self,
33//! y: bulirsch::ArrayView1<Self::Float>,
34//! mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
35//! ) {
36//! dydt[[0]] = y[[1]];
37//! dydt[[1]] = -self.omega.powi(2) * y[[0]];
38//! }
39//! }
40//!
41//! let system = OscillatorSystem { omega: 1.2 };
42//!
43//! // Set up the integrator.
44//! let mut integrator = bulirsch::Integrator::default()
45//! .with_abs_tol(1e-8)
46//! .with_rel_tol(1e-8)
47//! .into_adaptive();
48//!
49//! // Define initial conditions and provide solution storage.
50//! let delta_t: f64 = 10.2;
51//! let mut y = ndarray::array![1., 0.];
52//! let mut y_next = ndarray::Array::zeros(y.raw_dim());
53//!
54//! // Integrate for 10 steps.
55//! let num_steps = 10;
56//! for _ in 0..num_steps {
57//! integrator
58//! .step(&system, delta_t, y.view(), y_next.view_mut())
59//! .unwrap();
60//! y.assign(&y_next);
61//! }
62//!
63//! // Ensure result matches analytic solution.
64//! approx::assert_relative_eq!(
65//! (system.omega * delta_t * num_steps as f64).cos(),
66//! y_next[[0]],
67//! epsilon = 5e-7,
68//! max_relative = 5e-7,
69//! );
70//!
71//! // Check integration performance.
72//! assert_eq!(integrator.overall_stats().num_system_evals, 3770);
73//! approx::assert_relative_eq!(integrator.step_size().unwrap(), 2.10, epsilon = 1e-2);
74//! ```
75//!
//! Note that about 3.7k system evaluations were used. By contrast, the `ode_solvers::Dopri5`
//! algorithm needs roughly twice as many (7476) to pass the same accuracy check:
78//!
79//! ```
80//! struct OscillatorSystem {
81//! omega: f64,
82//! }
83//!
84//! impl ode_solvers::System<f64, ode_solvers::SVector<f64, 2>> for OscillatorSystem {
85//! fn system(
86//! &self,
87//! _x: f64,
88//! y: &ode_solvers::SVector<f64, 2>,
89//! dy: &mut ode_solvers::SVector<f64, 2>,
90//! ) {
91//! dy[0] = y[1];
92//! dy[1] = -self.omega.powi(2) * y[0];
93//! }
94//! }
95//!
96//! let omega = 1.2;
97//! let delta_t: f64 = 10.2;
98//! let mut num_system_eval = 0;
99//! let mut y = ode_solvers::Vector2::new(1., 0.);
100//! let num_steps = 10;
101//! for _ in 0..num_steps {
102//! let system = OscillatorSystem { omega };
103//! let mut solver = ode_solvers::Dopri5::new(
104//! system,
105//! 0.,
106//! delta_t,
107//! delta_t,
108//! y,
109//! 1e-8,
110//! 1e-8,
111//! );
112//! num_system_eval += solver.integrate().unwrap().num_eval;
113//! y = *solver.y_out().get(1).unwrap();
114//! }
115//! assert_eq!(num_system_eval, 7476);
116//!
117//! // Ensure result matches analytic solution.
118//! approx::assert_relative_eq!(
119//! (omega * delta_t * num_steps as f64).cos(),
120//! y[0],
121//! epsilon = 5e-7,
122//! max_relative = 5e-7,
123//! );
124//! ```
125//!
//! At the time of writing, the latest version of `ode_solvers` (0.6.1) moreover gives a
//! dramatically incorrect answer, likely due to a regression, so version 0.5 is used as a dev
//! dependency.
128
129#![expect(
130 non_snake_case,
131 reason = "Used for math symbols to match notation in Numerical Recipes"
132)]
133
134pub use nd::ArrayView1;
135pub use nd::ArrayViewMut1;
136use ndarray as nd;
137
138pub trait Float:
139 num_traits::Float
140 + core::iter::Sum
141 + core::ops::AddAssign
142 + core::ops::MulAssign
143 + core::fmt::Debug
144 + nd::ScalarOperand
145{
146}
147
148impl Float for f32 {}
149impl Float for f64 {}
150
151/// Trait for defining an ordinary differential equation system.
152pub trait System {
153 /// The floating point type.
154 type Float: Float;
155
156 /// Evaluate the ordinary differential equation and store the derivative in `dydt`.
157 fn system(&self, y: ArrayView1<Self::Float>, dydt: ArrayViewMut1<Self::Float>);
158}
159
160/// Error generated when integration produced a step size smaller than the minimum allowed step
161/// size.
162#[derive(Debug)]
163pub struct StepSizeUnderflow<F: Float>(F);
164
/// Statistics from taking an integration step.
///
/// Also used to accumulate totals across steps (see [`AdaptiveIntegrator::overall_stats`]);
/// [`Stats::default`] is the all-zero starting point.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of system function evaluations.
    pub num_system_evals: usize,
}
171
172/// An ODE integrator using the Bulirsch-Stoer algorithm with an adaptive step size and adaptive
173/// order.
174///
175/// Should be constructed using [`Integrator::into_adaptive`].
176#[derive(Clone)]
177pub struct AdaptiveIntegrator<F: Float> {
178 /// The underlying non-adaptive integrator.
179 integrator: Integrator<F>,
180
181 /// The current step size.
182 step_size: Option<F>,
183 /// The minimum step size to allow before returning [`StepSizeUnderflow`].
184 min_step_size: F,
185 /// The maximum step size to allow.
186 max_step_size: Option<F>,
187
188 /// The current estimated target number of iterations to use.
189 target_order: usize,
190 /// The maximum number of iterations to use.
191 max_order: usize,
192
193 /// Overall stats.
194 overall_stats: Stats,
195}
196
197impl<F: Float> AdaptiveIntegrator<F> {
198 /// Take a step using the Bulirsch-Stoer method.
199 ///
200 /// # Arguments
201 ///
202 /// * `system`: The ODE system.
203 /// * `delta_t`: The size of the prescribed time step to take.
204 /// * `y_init`: The initial state vector at the start of the time step.
205 /// * `y_final`: The vector into which to store the final computed state at the end of the time
206 /// step.
207 ///
208 /// # Result
209 ///
210 /// Stats providing information about integration performance, or an error if integration
211 /// failed.
212 ///
213 /// # Examples
214 ///
215 /// Note that if you're using e.g. `nalgebra` to define your ODE, you can bridge to [`ndarray`]
216 /// vectors using slices, as long as you're using `nalgebra`'s dynamically sized vectors. The
217 /// same applies to using [`Vec`]s, etc. For example, consider a simple oscillator system
218 /// defined using `nalgebra`:
219 ///
220 /// ```
221 /// // Define oscillator ODE.
222 /// #[derive(Clone, Copy)]
223 /// struct OscillatorSystem {
224 /// omega: f32,
225 /// }
226 ///
227 /// fn compute_dydt(
228 /// omega: f32,
229 /// y: nalgebra::DVectorView<f32>,
230 /// mut dydt: nalgebra::DVectorViewMut<f32>,
231 /// ) {
232 /// dydt[0] = y[1];
233 /// dydt[1] = -omega.powi(2) * y[0];
234 /// }
235 ///
236 /// impl bulirsch::System for OscillatorSystem {
237 /// type Float = f32;
238 ///
239 /// fn system(
240 /// &self,
241 /// y: bulirsch::ArrayView1<Self::Float>,
242 /// mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
243 /// ) {
244 /// let y_nalgebra = nalgebra::DVectorView::from_slice(
245 /// y.as_slice().unwrap(),
246 /// y.len(),
247 /// );
248 /// let dydt_nalgebra = nalgebra::DVectorViewMut::from_slice(
249 /// dydt.as_slice_mut().unwrap(),
250 /// y.len(),
251 /// );
252 /// compute_dydt(self.omega, y_nalgebra, dydt_nalgebra);
253 /// }
254 /// }
255 ///
256 /// // Instantiate system and integrator.
257 /// let system = OscillatorSystem { omega: 1.2 };
258 /// let mut integrator =
259 /// bulirsch::Integrator::default()
260 /// .with_abs_tol(1e-6)
261 /// .with_rel_tol(0.)
262 /// .into_adaptive();
263 ///
264 /// // Define initial conditions and integrate.
265 /// let mut y = ndarray::array![1., 0.];
266 /// let mut y_next = ndarray::Array1::zeros(y.raw_dim());
267 /// let delta_t = 0.6;
268 /// let num_steps = 10;
269 /// let mut num_system_evals = 0;
270 /// for _ in 0..num_steps {
271 /// num_system_evals += integrator
272 /// .step(&system, delta_t, y.view(), y_next.view_mut())
273 /// .unwrap()
274 /// .num_system_evals;
275 /// y.assign(&y_next);
276 /// }
277 ///
278 /// // Check against analytic solution.
279 /// let (sin, cos) = (delta_t * num_steps as f32 * system.omega).sin_cos();
280 /// approx::assert_relative_eq!(y_next[0], cos, epsilon = 1e-2);
281 /// approx::assert_relative_eq!(
282 /// y_next[1],
283 /// -system.omega * sin,
284 /// epsilon = 1e-2
285 /// );
286 ///
287 /// // Check integrator performance.
288 /// assert_eq!(num_system_evals, 310);
289 /// ```
290 pub fn step<S: System<Float = F>>(
291 &mut self,
292 system: &S,
293 delta_t: S::Float,
294 y_init: nd::ArrayView1<S::Float>,
295 mut y_final: nd::ArrayViewMut1<S::Float>,
296 ) -> Result<Stats, StepSizeUnderflow<F>> {
297 let mut step_size = *self.step_size.get_or_insert(delta_t);
298
299 let mut system = SystemEvaluationCounter {
300 system,
301 num_system_evals: 0,
302 };
303
304 // Iteratively take steps until taking a step would put us past the input `delta_t`. At that
305 // point, take an exact step to finish `delta_t`. Dynamically adjust the step size to
306 // control truncation error as we go.
307 let mut y_before_step = y_init.to_owned();
308 let mut y_after_step = y_init.to_owned();
309 let mut t = F::zero();
310 loop {
311 if step_size < self.min_step_size || !step_size.is_finite() {
312 return Err(StepSizeUnderflow(step_size));
313 }
314
315 // We set `next_t` to `None` if we're at the tail end of `delta_t` and are taking a
316 // smaller step than is optimal so we don't overshoot.
317 let next_t = if t < delta_t - step_size {
318 Some((t + step_size).min(delta_t))
319 } else {
320 None
321 };
322 step_size = step_size.min(delta_t - t);
323
324 let extrapolation_result = self.integrator.extrapolate(
325 &mut system,
326 step_size,
327 self.target_order,
328 y_before_step.view(),
329 y_after_step.view_mut(),
330 );
331
332 match (extrapolation_result.converged(), next_t) {
333 // The step was successful, and we're at the end of `delta_t`. Done.
334 (true, None) => {
335 // If the local step size is smaller than the internally
336 // tracked step size, then we are taking an intentionally
337 // shorter step to "finish off" integrating the interval and
338 // shouldn't modify step size.
339 if step_size >= cast::<_, F>(self.step_size.unwrap()) {
340 self.perform_step_size_control(&extrapolation_result, &mut step_size);
341 }
342 break;
343 }
344 // The step was successful, and we're not at the end of `delta_t`. Potentially
345 // adjust `target_order`, adjust step size, and continue.
346 (true, Some(next_t)) => {
347 self.perform_order_and_step_size_control(&extrapolation_result, &mut step_size);
348 t = next_t;
349 y_before_step.assign(&y_after_step);
350 }
351 // The step failed. Adjust step size, but for simplicity, unlike Numerical Recipes,
352 // don't try to adjust order. Try again.
353 (false, _) => {
354 self.perform_step_size_control(&extrapolation_result, &mut step_size);
355 }
356 }
357 }
358
359 y_final.assign(&y_after_step);
360 self.overall_stats.num_system_evals += system.num_system_evals;
361
362 Ok(Stats {
363 num_system_evals: system.num_system_evals,
364 })
365 }
366
367 /// Set the minimum step size to allow before returning [`StepSizeUnderflow`].
368 pub fn with_min_step_size(self, min_step_size: F) -> Self {
369 Self {
370 min_step_size,
371 ..self
372 }
373 }
374 /// Set the minimum step size to allow before returning [`StepSizeUnderflow`].
375 pub fn with_max_step_size(self, max_step_size: Option<F>) -> Self {
376 Self {
377 max_step_size,
378 ..self
379 }
380 }
381 /// Set the maximum "order" to use, i.e. max number of iterations per extrapolation.
382 pub fn with_max_order(self, max_order: usize) -> Self {
383 Self { max_order, ..self }
384 }
385
386 /// Get overall stats across all steps taken so far.
387 pub fn overall_stats(&self) -> &Stats {
388 &self.overall_stats
389 }
390 /// Get the current step size.
391 pub fn step_size(&self) -> Option<F> {
392 self.step_size
393 }
394 /// Get the current target order.
395 pub fn target_order(&self) -> usize {
396 self.target_order
397 }
398
399 fn compute_step_size_adjustment_factor(
400 extrapolation_result: &ExtrapolationResult<F>,
401 target_order: usize,
402 ) -> F {
403 let scaled_truncation_error = *extrapolation_result
404 .scaled_truncation_errors
405 .get(target_order)
406 .unwrap();
407
408 let safety_factor: F = cast(0.9);
409 let min_step_size_decrease_factor: F = cast(0.01);
410 let max_step_size_increase_factor = min_step_size_decrease_factor.recip();
411
412 if scaled_truncation_error > F::zero() {
413 // Eq. 2.14, Deuflhard.
414 (safety_factor / scaled_truncation_error.powf(F::one() / cast(2 * target_order + 1)))
415 .max(min_step_size_decrease_factor)
416 .min(max_step_size_increase_factor)
417 } else if scaled_truncation_error == F::zero() {
418 cast(2)
419 } else {
420 // Handle NaNs.
421 cast(0.5)
422 }
423 }
424
425 fn perform_step_size_control(
426 &mut self,
427 extrapolation_result: &ExtrapolationResult<F>,
428 step_size: &mut F,
429 ) {
430 let adjustment_factor =
431 Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);
432 *step_size *= adjustment_factor;
433
434 if let Some(max_step_size) = self.max_step_size {
435 *step_size = step_size.min(max_step_size);
436 }
437 self.step_size = Some(*step_size);
438 }
439
440 fn perform_order_and_step_size_control(
441 &mut self,
442 extrapolation_result: &ExtrapolationResult<F>,
443 step_size: &mut F,
444 ) {
445 let adjustment_factor =
446 Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);
447
448 // This follows eqs. 17.3.14 & 17.3.15 in Numerical Recipes.
449 if self.target_order > 0 {
450 let adjustment_factor_lower_order = Self::compute_step_size_adjustment_factor(
451 &extrapolation_result,
452 self.target_order - 1,
453 );
454
455 let work = cast::<_, F>(compute_work(self.target_order));
456 let work_per_step = work / *step_size / adjustment_factor;
457 let work_lower_order = cast::<_, F>(compute_work(self.target_order - 1));
458 let work_per_step_lower_order =
459 work_lower_order / *step_size / adjustment_factor_lower_order;
460
461 self.target_order = if work_per_step_lower_order < cast::<_, F>(0.8) * work_per_step
462 && self.target_order > 1
463 {
464 // Decrease order since a lower order requires less work.
465 *step_size *= adjustment_factor_lower_order;
466 self.target_order - 1
467 } else if work_per_step < cast::<_, F>(0.95) * work_per_step_lower_order
468 && self.target_order + 1 <= self.max_order
469 {
470 // Increase order since a higher order is heuristically indicated to require less
471 // work (even though we didn't extrapolate to this order, so can't tell for sure).
472 // We use 0.95 above instead of 0.9 from Numerical Recipes since it produced better
473 // performance on the tests.
474 let work_higher_order = cast::<_, F>(compute_work(self.target_order + 1));
475 *step_size *= adjustment_factor * work_higher_order / work;
476 self.target_order + 1
477 } else {
478 // Preserve order and only adjust step size.
479 *step_size *= adjustment_factor;
480 self.target_order
481 };
482 } else {
483 *step_size *= adjustment_factor;
484 }
485
486 if let Some(max_step_size) = self.max_step_size {
487 *step_size = step_size.min(max_step_size);
488 }
489 self.step_size = Some(*step_size);
490 }
491}
492
493/// An ODE integrator using the Bulirsch-Stoer algorithm with a fixed step size.
494///
495/// Used to construct an [`AdaptiveIntegrator`].
496#[derive(Clone)]
497pub struct Integrator<F: Float> {
498 /// The absolute tolerance.
499 abs_tol: F,
500 /// The relative tolerance.
501 rel_tol: F,
502}
503
504impl<F: Float> Default for Integrator<F> {
505 fn default() -> Self {
506 Self {
507 abs_tol: cast(1e-6),
508 rel_tol: cast(1e-6),
509 }
510 }
511}
512
513impl<F: Float> Integrator<F> {
514 /// Make an [`AdaptiveIntegrator`].
515 pub fn into_adaptive(self) -> AdaptiveIntegrator<F> {
516 AdaptiveIntegrator {
517 integrator: self,
518 step_size: None,
519 min_step_size: cast(1e-9),
520 max_step_size: None,
521 target_order: 3,
522 max_order: 10,
523 overall_stats: Stats {
524 num_system_evals: 0,
525 },
526 }
527 }
528
529 /// Set the absolute tolerance.
530 pub fn with_abs_tol(self, abs_tol: F) -> Self {
531 Self { abs_tol, ..self }
532 }
533 /// Set the relative tolerance.
534 pub fn with_rel_tol(self, rel_tol: F) -> Self {
535 Self { rel_tol, ..self }
536 }
537
538 /// Take a single extrapolating step, iteratively subdividing `step_size`.
539 fn extrapolate<S: System<Float = F>>(
540 &self,
541 system: &mut SystemEvaluationCounter<S>,
542 step_size: F,
543 order: usize,
544 y_init: nd::ArrayView1<F>,
545 mut y_final: nd::ArrayViewMut1<F>,
546 ) -> ExtrapolationResult<F> {
547 let f_init = {
548 let mut f_init = nd::Array1::zeros(y_init.raw_dim());
549 system.system(y_init, f_init.view_mut());
550 f_init
551 };
552
553 // Build up an extrapolation tableau.
554 let mut tableau = ExtrapolationTableau(Vec::<ExtrapolationTableauRow<_>>::new());
555 for k in 0..=order + 1 {
556 let nk = compute_n(k);
557 let tableau_row = {
558 let mut Tk = Vec::with_capacity(k + 1);
559 Tk.push(self.midpoint_step(system, step_size, nk, &f_init, y_init));
560 for j in 0..k {
561 // There is a mistake in Numerical Recipes eq. 17.3.8. See
562 // https://www.numerical.recipes/forumarchive/index.php/t-2256.html.
563 let denominator = <F as num_traits::Float>::powi(
564 cast::<_, F>(nk) / cast(compute_n(k - j - 1)),
565 2,
566 ) - <F as num_traits::One>::one();
567 Tk.push(&Tk[j] + (&Tk[j] - &tableau.0[k - 1].0[j]) / denominator);
568 }
569 ExtrapolationTableauRow(Tk)
570 };
571 tableau.0.push(tableau_row);
572 }
573
574 y_final.assign(&tableau.0.last().unwrap().estimate());
575 return ExtrapolationResult {
576 scaled_truncation_errors: tableau
577 .compute_scaled_truncation_errors(self.abs_tol, self.rel_tol),
578 };
579 }
580
581 fn midpoint_step<S: System<Float = F>>(
582 &self,
583 evaluation_counter: &mut SystemEvaluationCounter<S>,
584 step_size: F,
585 n: usize,
586 f_init: &nd::Array1<F>,
587 y_init: nd::ArrayView1<F>,
588 ) -> nd::Array1<F> {
589 let substep_size = step_size / cast(n);
590 let two_substep_size = cast::<_, F>(2) * substep_size;
591
592 // 0 1 2 3 4 5 6 n
593 // ..
594 // zi zip1
595 // zip1 zi
596 // zi zip1
597 // ..
598 // zi zip1
599 let mut zi = y_init.to_owned();
600 let mut zip1 = &zi + f_init * substep_size;
601 let mut fi = f_init.clone();
602
603 for _i in 1..n {
604 core::mem::swap(&mut zi, &mut zip1);
605 evaluation_counter.system(zi.view(), fi.view_mut());
606 fi *= two_substep_size;
607 zip1 += &fi;
608 }
609
610 evaluation_counter.system(zip1.view(), fi.view_mut());
611 fi *= substep_size;
612 let mut result = zi;
613 result += &zip1;
614 result += &fi;
615 result *= cast::<_, F>(0.5);
616 result
617 }
618}
619
620/// Statistics from taking an integration step.
621#[derive(Debug)]
622struct ExtrapolationResult<F: Float> {
623 /// The scaled (including absolute and relative tolerances) truncation errors for each
624 /// iteration.
625 ///
626 /// Each will be <= 1 if convergence was achieved or > 1 if convergence was not achieved.
627 scaled_truncation_errors: Vec<F>,
628}
629
630impl<F: Float> ExtrapolationResult<F> {
631 fn converged(&self) -> bool {
632 *self.scaled_truncation_errors.last().unwrap() < F::one()
633 }
634}
635
636struct SystemEvaluationCounter<'a, S: System> {
637 system: &'a S,
638 num_system_evals: usize,
639}
640
641impl<'a, S: System> SystemEvaluationCounter<'a, S> {
642 fn system(&mut self, y: nd::ArrayView1<S::Float>, dydt: nd::ArrayViewMut1<S::Float>) {
643 self.num_system_evals += 1;
644 <S as System>::system(&self.system, y, dydt);
645 }
646}
647
648struct ExtrapolationTableau<F: Float>(Vec<ExtrapolationTableauRow<F>>);
649
650impl<F: Float> ExtrapolationTableau<F> {
651 fn compute_scaled_truncation_errors(&self, abs_tol: F, rel_tol: F) -> Vec<F> {
652 self.0
653 .iter()
654 .skip(1)
655 .map(|row| row.compute_scaled_truncation_error(abs_tol, rel_tol))
656 .collect()
657 }
658}
659
660struct ExtrapolationTableauRow<F: Float>(Vec<nd::Array1<F>>);
661
662impl<F: Float> ExtrapolationTableauRow<F> {
663 fn compute_scaled_truncation_error(&self, abs_tol: F, rel_tol: F) -> F {
664 let extrap_pair = self.0.last_chunk::<2>().unwrap();
665 let y = &extrap_pair[0];
666 let y_alt = &extrap_pair[1];
667 (y.iter()
668 .zip(y_alt.iter())
669 .map(|(&yi, &yi_alt)| {
670 let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
671 (yi - yi_alt).powi(2) / scale.powi(2)
672 })
673 .sum::<F>()
674 / cast(y.len()))
675 .sqrt()
676 }
677
678 fn estimate(&self) -> &nd::Array1<F> {
679 self.0.last().unwrap()
680 }
681}
682
/// Step size policy: number of midpoint substeps used for tableau row `iteration`.
///
/// A simple linear sequence (2, 4, 6, ...), based on the results in Deuflhard.
fn compute_n(iteration: usize) -> usize {
    (iteration + 1) * 2
}
689
/// Cumulative sum of `compute_n` over iterations `0..=iteration`.
///
/// This is the number of system function evaluations the midpoint integrations spend when
/// extrapolating to a given order. Since `compute_n(i) = 2 * (i + 1)`, the sum has the closed
/// form `(iteration + 1) * (iteration + 2)`.
fn compute_work(iteration: usize) -> usize {
    (iteration + 1) * (iteration + 2)
}
696
697fn cast<T: num_traits::NumCast, F: Float>(num: T) -> F {
698 num_traits::cast(num).unwrap()
699}
700
#[cfg(test)]
mod tests {
    use super::*;
    use core::cell::Cell;

    /// `compute_work` must equal the cumulative sum of `compute_n`, i.e. the number of system
    /// evaluations spent by the midpoint integrations up to a given iteration.
    #[test]
    fn test_compute_work() {
        for iteration in 0..5 {
            let cumulative: usize = (0..=iteration).map(compute_n).sum();
            assert_eq!(compute_work(iteration), cumulative);
        }
    }

    /// dy/dt = y, whose solution with y(0) = 1 is exp(t).
    struct ExpSystem {}

    impl System for ExpSystem {
        type Float = f64;

        fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
            dydt.assign(&y);
        }
    }

    /// Adaptive integrator with a purely relative tolerance of 1e-14.
    fn high_precision_integrator() -> AdaptiveIntegrator<f64> {
        Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive()
    }

    /// Integrate `ExpSystem` from y = 1 over `t_final` in a single call, returning the final
    /// state together with the step statistics.
    fn integrate_exp(integrator: &mut AdaptiveIntegrator<f64>, t_final: f64) -> (f64, Stats) {
        let y_start = ndarray::array![1.];
        let mut y_end = ndarray::Array::zeros([1]);
        let stats = integrator
            .step(&ExpSystem {}, t_final, y_start.view(), y_end.view_mut())
            .unwrap();
        (y_end[[0]], stats)
    }

    /// Ensure we can solve an exponential system to high precision.
    #[test]
    fn test_exp_system_high_precision() {
        let mut integrator = high_precision_integrator();
        let t_final = 3.5;
        let (y_end, stats) = integrate_exp(&mut integrator, t_final);

        // The result matches the analytic solution to high precision.
        approx::assert_relative_eq!(f64::exp(t_final), y_end, max_relative = 5e-13);

        // Integration performance.
        assert_eq!(stats.num_system_evals, 437);
        approx::assert_relative_eq!(integrator.step_size().unwrap(), 1.84, epsilon = 1e-2);
    }

    /// Ensure the algorithm works even when the max order is smaller than optimal.
    #[test]
    fn test_exp_system_low_max_order() {
        let mut integrator = high_precision_integrator().with_max_order(1);
        let t_final = 3.5;
        let (y_end, _) = integrate_exp(&mut integrator, t_final);

        approx::assert_relative_eq!(f64::exp(t_final), y_end, max_relative = 5e-13);
    }

    /// Ensure the algorithm can handle NaNs.
    #[test]
    fn test_exp_system_handle_nans() {
        /// dy/dt = -y, except that the derivative is NaN whenever |y| > 10.
        struct ExpSystemWithNans {
            hit_a_nan: Cell<bool>,
        }

        impl System for ExpSystemWithNans {
            type Float = f64;

            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                if y[0].abs() > 10. {
                    self.hit_a_nan.set(true);
                    dydt[0] = f64::NAN;
                } else {
                    dydt.assign(&(-&y));
                }
            }
        }

        let system = ExpSystemWithNans {
            hit_a_nan: Cell::new(false),
        };
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-10)
            .into_adaptive();

        let t_final = 20.;
        let y_start = ndarray::array![1.];
        let mut y_end = ndarray::Array::zeros([1]);
        let stats = integrator
            .step(&system, t_final, y_start.view(), y_end.view_mut())
            .unwrap();

        // The result matches the analytic solution.
        approx::assert_relative_eq!((-t_final).exp(), y_end[[0]], max_relative = 1e-8);
        // At least one NaN was produced along the way.
        assert!(system.hit_a_nan.get());
        assert_eq!(stats.num_system_evals, 1085);
    }

    /// This is for interactive debugging as it has no asserts.
    #[test]
    fn test_varying_timescale() {
        /// Pendulum-like oscillator with a sharply peaked restoring force.
        struct SharpPendulumSystem {}

        impl System for SharpPendulumSystem {
            type Float = f64;

            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                dydt[[0]] = y[[1]];
                dydt[[1]] = -30. * y[[0]].sin().powi(31);
            }
        }

        let system = SharpPendulumSystem {};
        let mut integrator = Integrator::default().into_adaptive();

        let delta_t = 10.;
        let mut y = ndarray::array![1., 0.];
        let mut y_final = ndarray::Array::zeros(y.raw_dim());

        for _ in 0..100 {
            integrator
                .step(&system, delta_t, y.view(), y_final.view_mut())
                .unwrap();
            y.assign(&y_final);
            println!(
                "order: {} step_size: {} y: {y}",
                integrator.target_order(),
                integrator.step_size().unwrap()
            );
        }
    }

    /// Ensure we don't adapt timesteps out of the limits.
    #[test]
    fn test_step_size_limits() {
        let system = ExpSystem {};
        let mut integrator = Integrator::default().into_adaptive();

        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);

        // Ask the integrator to step forward a tiny fraction above the step size.
        integrator.step_size = Some(0.02);
        integrator.max_step_size = Some(0.04);
        integrator.min_step_size = 1E-3;
        let t_final = 0.02 + 1E-4;

        integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();

        // The adapted step size is still within the integrator limits.
        let step_size = integrator.step_size().unwrap();
        println!("Step size: {step_size}");
        assert!(integrator.min_step_size <= step_size);
        assert!(step_size <= integrator.max_step_size.unwrap());

        // Step again: since the first tail step was tiny, adaptation may grow the step size.
        integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();
        println!("Step size: {}", integrator.step_size().unwrap());
        assert!(integrator.step_size().unwrap() >= step_size);
    }
}