// bulirsch/lib.rs

//! Implementation of the Bulirsch-Stoer method for stepping ordinary differential equations.
//!
//! The [(Gragg-)Bulirsch-Stoer](https://en.wikipedia.org/wiki/Bulirsch%E2%80%93Stoer_algorithm)
//! algorithm combines the (modified) midpoint method with Richardson extrapolation to accelerate
//! convergence. It is an explicit method that does not require Jacobians.
//!
//! This crate's implementation contains simplistic adaptive step size routines with order
//! estimation. Its API is designed to be useful in situations where an ODE is being integrated step
//! by step with a prescribed time step, for example in integrated simulations of electromechanical
//! control systems with a fixed control cycle period. Only time-independent ODEs are supported, but
//! without loss of generality (since the state vector can be augmented with a time variable if
//! needed).
//!
//! The implementation follows:
//! * Press, William H. Numerical Recipes 3rd Edition: The Art of Scientific Computing. Cambridge
//!   University Press, 2007. Ch. 17.3.2.
//! * Deuflhard, Peter. "Order and stepsize control in extrapolation methods." Numerische Mathematik
//!   41 (1983): 399-422.
//!
//! As an example, consider a simple oscillator system:
//!
//! ```
//! // Define ODE.
//! struct OscillatorSystem {
//!     omega: f64,
//! }
//!
//! impl bulirsch::System for OscillatorSystem {
//!     type Float = f64;
//!
//!     fn system(
//!         &self,
//!         y: bulirsch::ArrayView1<Self::Float>,
//!         mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
//!     ) {
//!         dydt[[0]] = y[[1]];
//!         dydt[[1]] = -self.omega.powi(2) * y[[0]];
//!     }
//! }
//!
//! let system = OscillatorSystem { omega: 1.2 };
//!
//! // Set up the integrator.
//! let mut integrator = bulirsch::Integrator::default()
//!     .with_abs_tol(1e-8)
//!     .with_rel_tol(1e-8)
//!     .into_adaptive();
//!
//! // Define initial conditions and provide solution storage.
//! let delta_t: f64 = 10.2;
//! let mut y = ndarray::array![1., 0.];
//! let mut y_next = ndarray::Array::zeros(y.raw_dim());
//!
//! // Integrate for 10 steps.
//! let num_steps = 10;
//! for _ in 0..num_steps {
//!     integrator
//!         .step(&system, delta_t, y.view(), y_next.view_mut())
//!         .unwrap();
//!     y.assign(&y_next);
//! }
//!
//! // Ensure result matches analytic solution.
//! approx::assert_relative_eq!(
//!     (system.omega * delta_t * num_steps as f64).cos(),
//!     y_next[[0]],
//!     epsilon = 5e-7,
//!     max_relative = 5e-7,
//! );
//!
//! // Check integration performance.
//! assert_eq!(integrator.overall_stats().num_system_evals, 3770);
//! approx::assert_relative_eq!(integrator.step_size().unwrap(), 2.10, epsilon = 1e-2);
//! ```
//!
//! Note that 3.7k system evaluations have been used. By contrast, the `ode_solvers::Dopri5`
//! algorithm uses more:
//!
//! ```
//! struct OscillatorSystem {
//!     omega: f64,
//! }
//!
//! impl ode_solvers::System<f64, ode_solvers::SVector<f64, 2>> for OscillatorSystem {
//!     fn system(
//!         &self,
//!         _x: f64,
//!         y: &ode_solvers::SVector<f64, 2>,
//!         dy: &mut ode_solvers::SVector<f64, 2>,
//!     ) {
//!         dy[0] = y[1];
//!         dy[1] = -self.omega.powi(2) * y[0];
//!     }
//! }
//!
//! let omega = 1.2;
//! let delta_t: f64 = 10.2;
//! let mut num_system_eval = 0;
//! let mut y = ode_solvers::Vector2::new(1., 0.);
//! let num_steps = 10;
//! for _ in 0..num_steps {
//!     let system = OscillatorSystem { omega };
//!     let mut solver = ode_solvers::Dopri5::new(
//!         system,
//!         0.,
//!         delta_t,
//!         delta_t,
//!         y,
//!         1e-8,
//!         1e-8,
//!     );
//!     num_system_eval += solver.integrate().unwrap().num_eval;
//!     y = *solver.y_out().get(1).unwrap();
//! }
//! assert_eq!(num_system_eval, 7476);
//!
//! // Ensure result matches analytic solution.
//! approx::assert_relative_eq!(
//!     (omega * delta_t * num_steps as f64).cos(),
//!     y[0],
//!     epsilon = 5e-7,
//!     max_relative = 5e-7,
//! );
//! ```
//!
//! As of writing this, the latest version of `ode_solvers`, 0.6.1, also gives a dramatically
//! incorrect answer likely due to a regression. As a result we use version 0.5 as a dev dependency.

129#![expect(
130    non_snake_case,
131    reason = "Used for math symbols to match notation in Numerical Recipes"
132)]
133
134pub use nd::ArrayView1;
135pub use nd::ArrayViewMut1;
136use ndarray as nd;
137
/// The floating-point scalar type used by the integrators.
///
/// Combines the standard float operations ([`num_traits::Float`]) with summation, in-place
/// arithmetic, debug formatting, and `ndarray` scalar broadcasting ([`nd::ScalarOperand`]) —
/// everything the extrapolation and step-control code below requires.
pub trait Float:
    num_traits::Float
    + core::iter::Sum
    + core::ops::AddAssign
    + core::ops::MulAssign
    + core::fmt::Debug
    + nd::ScalarOperand
{
}
147
// The two primitive float types already satisfy every super-trait bound.
impl Float for f32 {}
impl Float for f64 {}
150
/// Trait for defining an ordinary differential equation system.
pub trait System {
    /// The floating point type.
    type Float: Float;

    /// Evaluate the ordinary differential equation and store the derivative in `dydt`.
    ///
    /// Only autonomous (time-independent) systems are expressible: the derivative depends on `y`
    /// alone. `dydt` has the same length as `y`; implementations should overwrite every element,
    /// since the integrator reuses derivative buffers between calls.
    fn system(&self, y: ArrayView1<Self::Float>, dydt: ArrayViewMut1<Self::Float>);
}
159
160/// Error generated when integration produced a step size smaller than the minimum allowed step
161/// size.
162#[derive(Debug)]
163pub struct StepSizeUnderflow<F: Float>(F);
164
/// Statistics from taking an integration step.
///
/// Also used cumulatively by [`AdaptiveIntegrator::overall_stats`].
#[derive(Clone, Debug)]
pub struct Stats {
    /// Number of system function evaluations.
    pub num_system_evals: usize,
}
171
/// An ODE integrator using the Bulirsch-Stoer algorithm with an adaptive step size and adaptive
/// order.
///
/// Should be constructed using [`Integrator::into_adaptive`].
#[derive(Clone)]
pub struct AdaptiveIntegrator<F: Float> {
    /// The underlying non-adaptive integrator.
    integrator: Integrator<F>,

    /// The current step size.
    ///
    /// `None` until the first call to [`Self::step`] seeds it with the prescribed `delta_t`;
    /// thereafter it persists between steps so adaptation carries over.
    step_size: Option<F>,
    /// The minimum step size to allow before returning [`StepSizeUnderflow`].
    min_step_size: F,
    /// The maximum step size to allow. `None` means no upper limit.
    max_step_size: Option<F>,

    /// The current estimated target number of iterations to use.
    target_order: usize,
    /// The maximum number of iterations to use.
    max_order: usize,

    /// Overall stats, accumulated across every [`Self::step`] call.
    overall_stats: Stats,
}
196
197impl<F: Float> AdaptiveIntegrator<F> {
198    /// Take a step using the Bulirsch-Stoer method.
199    ///
200    /// # Arguments
201    ///
202    /// * `system`: The ODE system.
203    /// * `delta_t`: The size of the prescribed time step to take.
204    /// * `y_init`: The initial state vector at the start of the time step.
205    /// * `y_final`: The vector into which to store the final computed state at the end of the time
206    ///   step.
207    ///
208    /// # Result
209    ///
210    /// Stats providing information about integration performance, or an error if integration
211    /// failed.
212    ///
213    /// # Examples
214    ///
215    /// Note that if you're using e.g. `nalgebra` to define your ODE, you can bridge to [`ndarray`]
216    /// vectors using slices, as long as you're using `nalgebra`'s dynamically sized vectors. The
217    /// same applies to using [`Vec`]s, etc. For example, consider a simple oscillator system
218    /// defined using `nalgebra`:
219    ///
220    /// ```
221    /// // Define oscillator ODE.
222    /// #[derive(Clone, Copy)]
223    /// struct OscillatorSystem {
224    ///     omega: f32,
225    /// }
226    ///
227    /// fn compute_dydt(
228    ///     omega: f32,
229    ///     y: nalgebra::DVectorView<f32>,
230    ///     mut dydt: nalgebra::DVectorViewMut<f32>,
231    /// ) {
232    ///     dydt[0] = y[1];
233    ///     dydt[1] = -omega.powi(2) * y[0];
234    /// }
235    ///
236    /// impl bulirsch::System for OscillatorSystem {
237    ///     type Float = f32;
238    ///
239    ///     fn system(
240    ///         &self,
241    ///         y: bulirsch::ArrayView1<Self::Float>,
242    ///         mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
243    ///     ) {
244    ///         let y_nalgebra = nalgebra::DVectorView::from_slice(
245    ///             y.as_slice().unwrap(),
246    ///             y.len(),
247    ///         );
248    ///         let dydt_nalgebra = nalgebra::DVectorViewMut::from_slice(
249    ///             dydt.as_slice_mut().unwrap(),
250    ///             y.len(),
251    ///         );
252    ///         compute_dydt(self.omega, y_nalgebra, dydt_nalgebra);
253    ///     }
254    /// }
255    ///
256    /// // Instantiate system and integrator.
257    /// let system = OscillatorSystem { omega: 1.2 };
258    /// let mut integrator =
259    ///     bulirsch::Integrator::default()
260    ///         .with_abs_tol(1e-6)
261    ///         .with_rel_tol(0.)
262    ///         .into_adaptive();
263    ///
264    /// // Define initial conditions and integrate.
265    /// let mut y = ndarray::array![1., 0.];
266    /// let mut y_next = ndarray::Array1::zeros(y.raw_dim());
267    /// let delta_t = 0.6;
268    /// let num_steps = 10;
269    /// let mut num_system_evals = 0;
270    /// for _ in 0..num_steps {
271    ///     num_system_evals += integrator
272    ///         .step(&system, delta_t, y.view(), y_next.view_mut())
273    ///         .unwrap()
274    ///         .num_system_evals;
275    ///     y.assign(&y_next);
276    /// }
277    ///
278    /// // Check against analytic solution.
279    /// let (sin, cos) = (delta_t * num_steps as f32 * system.omega).sin_cos();
280    /// approx::assert_relative_eq!(y_next[0], cos, epsilon = 1e-2);
281    /// approx::assert_relative_eq!(
282    ///     y_next[1],
283    ///     -system.omega * sin,
284    ///     epsilon = 1e-2
285    /// );
286    ///
287    /// // Check integrator performance.
288    /// assert_eq!(num_system_evals, 310);
289    /// ```
    pub fn step<S: System<Float = F>>(
        &mut self,
        system: &S,
        delta_t: S::Float,
        y_init: nd::ArrayView1<S::Float>,
        mut y_final: nd::ArrayViewMut1<S::Float>,
    ) -> Result<Stats, StepSizeUnderflow<F>> {
        // On the very first call there is no remembered step size; seed it with the full
        // prescribed interval and let adaptation shrink it as needed.
        let mut step_size = *self.step_size.get_or_insert(delta_t);

        // Wrap the system so every derivative evaluation is counted toward the returned stats.
        let mut system = SystemEvaluationCounter {
            system,
            num_system_evals: 0,
        };

        // Iteratively take steps until taking a step would put us past the input `delta_t`. At that
        // point, take an exact step to finish `delta_t`. Dynamically adjust the step size to
        // control truncation error as we go.
        let mut y_before_step = y_init.to_owned();
        let mut y_after_step = y_init.to_owned();
        let mut t = F::zero();
        loop {
            // Bail out if adaptation drove the step size below the allowed minimum, or to a
            // non-finite value.
            if step_size < self.min_step_size || !step_size.is_finite() {
                return Err(StepSizeUnderflow(step_size));
            }

            // We set `next_t` to `None` if we're at the tail end of `delta_t` and are taking a
            // smaller step than is optimal so we don't overshoot.
            let next_t = if t < delta_t - step_size {
                Some((t + step_size).min(delta_t))
            } else {
                None
            };
            step_size = step_size.min(delta_t - t);

            // Attempt a single extrapolated step of the current size at the current target order.
            let extrapolation_result = self.integrator.extrapolate(
                &mut system,
                step_size,
                self.target_order,
                y_before_step.view(),
                y_after_step.view_mut(),
            );

            match (extrapolation_result.converged(), next_t) {
                // The step was successful, and we're at the end of `delta_t`. Done.
                (true, None) => {
                    // If the local step size is smaller than the internally
                    // tracked step size, then we are taking an intentionally
                    // shorter step to "finish off" integrating the interval and
                    // shouldn't modify step size.
                    if step_size >= cast::<_, F>(self.step_size.unwrap()) {
                        self.perform_step_size_control(&extrapolation_result, &mut step_size);
                    }
                    break;
                }
                // The step was successful, and we're not at the end of `delta_t`. Potentially
                // adjust `target_order`, adjust step size, and continue.
                (true, Some(next_t)) => {
                    self.perform_order_and_step_size_control(&extrapolation_result, &mut step_size);
                    t = next_t;
                    y_before_step.assign(&y_after_step);
                }
                // The step failed. Adjust step size, but for simplicity, unlike Numerical Recipes,
                // don't try to adjust order. Try again.
                (false, _) => {
                    self.perform_step_size_control(&extrapolation_result, &mut step_size);
                }
            }
        }

        // Copy the converged state out and fold this step's evaluation count into the
        // integrator-lifetime totals.
        y_final.assign(&y_after_step);
        self.overall_stats.num_system_evals += system.num_system_evals;

        Ok(Stats {
            num_system_evals: system.num_system_evals,
        })
    }
366
367    /// Set the minimum step size to allow before returning [`StepSizeUnderflow`].
368    pub fn with_min_step_size(self, min_step_size: F) -> Self {
369        Self {
370            min_step_size,
371            ..self
372        }
373    }
374    /// Set the minimum step size to allow before returning [`StepSizeUnderflow`].
375    pub fn with_max_step_size(self, max_step_size: Option<F>) -> Self {
376        Self {
377            max_step_size,
378            ..self
379        }
380    }
381    /// Set the maximum "order" to use, i.e. max number of iterations per extrapolation.
382    pub fn with_max_order(self, max_order: usize) -> Self {
383        Self { max_order, ..self }
384    }
385
    /// Get overall stats across all steps taken so far.
    pub fn overall_stats(&self) -> &Stats {
        &self.overall_stats
    }
    /// Get the current step size.
    ///
    /// `None` until the first call to [`Self::step`] seeds it.
    pub fn step_size(&self) -> Option<F> {
        self.step_size
    }
    /// Get the current target order.
    pub fn target_order(&self) -> usize {
        self.target_order
    }
398
399    fn compute_step_size_adjustment_factor(
400        extrapolation_result: &ExtrapolationResult<F>,
401        target_order: usize,
402    ) -> F {
403        let scaled_truncation_error = *extrapolation_result
404            .scaled_truncation_errors
405            .get(target_order)
406            .unwrap();
407
408        let safety_factor: F = cast(0.9);
409        let min_step_size_decrease_factor: F = cast(0.01);
410        let max_step_size_increase_factor = min_step_size_decrease_factor.recip();
411
412        if scaled_truncation_error > F::zero() {
413            // Eq. 2.14, Deuflhard.
414            (safety_factor / scaled_truncation_error.powf(F::one() / cast(2 * target_order + 1)))
415                .max(min_step_size_decrease_factor)
416                .min(max_step_size_increase_factor)
417        } else if scaled_truncation_error == F::zero() {
418            cast(2)
419        } else {
420            // Handle NaNs.
421            cast(0.5)
422        }
423    }
424
425    fn perform_step_size_control(
426        &mut self,
427        extrapolation_result: &ExtrapolationResult<F>,
428        step_size: &mut F,
429    ) {
430        let adjustment_factor =
431            Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);
432        *step_size *= adjustment_factor;
433
434        if let Some(max_step_size) = self.max_step_size {
435            *step_size = step_size.min(max_step_size);
436        }
437        self.step_size = Some(*step_size);
438    }
439
    /// Adjust both the target extrapolation order and the step size after a successful step,
    /// choosing whichever neighboring order is estimated to need the least work per unit time.
    fn perform_order_and_step_size_control(
        &mut self,
        extrapolation_result: &ExtrapolationResult<F>,
        step_size: &mut F,
    ) {
        let adjustment_factor =
            Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);

        // This follows eqs. 17.3.14 & 17.3.15 in Numerical Recipes.
        if self.target_order > 0 {
            let adjustment_factor_lower_order = Self::compute_step_size_adjustment_factor(
                &extrapolation_result,
                self.target_order - 1,
            );

            // Estimated system evaluations per unit of integration time at the current order and
            // one order lower, each assuming its own step size adjustment.
            let work = cast::<_, F>(compute_work(self.target_order));
            let work_per_step = work / *step_size / adjustment_factor;
            let work_lower_order = cast::<_, F>(compute_work(self.target_order - 1));
            let work_per_step_lower_order =
                work_lower_order / *step_size / adjustment_factor_lower_order;

            self.target_order = if work_per_step_lower_order < cast::<_, F>(0.8) * work_per_step
                && self.target_order > 1
            {
                // Decrease order since a lower order requires less work.
                *step_size *= adjustment_factor_lower_order;
                self.target_order - 1
            } else if work_per_step < cast::<_, F>(0.95) * work_per_step_lower_order
                && self.target_order + 1 <= self.max_order
            {
                // Increase order since a higher order is heuristically indicated to require less
                // work (even though we didn't extrapolate to this order, so can't tell for sure).
                // We use 0.95 above instead of 0.9 from Numerical Recipes since it produced better
                // performance on the tests.
                let work_higher_order = cast::<_, F>(compute_work(self.target_order + 1));
                *step_size *= adjustment_factor * work_higher_order / work;
                self.target_order + 1
            } else {
                // Preserve order and only adjust step size.
                *step_size *= adjustment_factor;
                self.target_order
            };
        } else {
            // At order zero there is no lower order to compare against; just rescale.
            *step_size *= adjustment_factor;
        }

        // Respect the configured upper limit and persist the adapted step size.
        if let Some(max_step_size) = self.max_step_size {
            *step_size = step_size.min(max_step_size);
        }
        self.step_size = Some(*step_size);
    }
491}
492
/// An ODE integrator using the Bulirsch-Stoer algorithm with a fixed step size.
///
/// Holds only the error tolerances; used to construct an [`AdaptiveIntegrator`].
#[derive(Clone)]
pub struct Integrator<F: Float> {
    /// The absolute tolerance.
    abs_tol: F,
    /// The relative tolerance.
    rel_tol: F,
}
503
504impl<F: Float> Default for Integrator<F> {
505    fn default() -> Self {
506        Self {
507            abs_tol: cast(1e-6),
508            rel_tol: cast(1e-6),
509        }
510    }
511}
512
513impl<F: Float> Integrator<F> {
514    /// Make an [`AdaptiveIntegrator`].
515    pub fn into_adaptive(self) -> AdaptiveIntegrator<F> {
516        AdaptiveIntegrator {
517            integrator: self,
518            step_size: None,
519            min_step_size: cast(1e-9),
520            max_step_size: None,
521            target_order: 3,
522            max_order: 10,
523            overall_stats: Stats {
524                num_system_evals: 0,
525            },
526        }
527    }
528
529    /// Set the absolute tolerance.
530    pub fn with_abs_tol(self, abs_tol: F) -> Self {
531        Self { abs_tol, ..self }
532    }
533    /// Set the relative tolerance.
534    pub fn with_rel_tol(self, rel_tol: F) -> Self {
535        Self { rel_tol, ..self }
536    }
537
538    /// Take a single extrapolating step, iteratively subdividing `step_size`.
539    fn extrapolate<S: System<Float = F>>(
540        &self,
541        system: &mut SystemEvaluationCounter<S>,
542        step_size: F,
543        order: usize,
544        y_init: nd::ArrayView1<F>,
545        mut y_final: nd::ArrayViewMut1<F>,
546    ) -> ExtrapolationResult<F> {
547        let f_init = {
548            let mut f_init = nd::Array1::zeros(y_init.raw_dim());
549            system.system(y_init, f_init.view_mut());
550            f_init
551        };
552
553        // Build up an extrapolation tableau.
554        let mut tableau = ExtrapolationTableau(Vec::<ExtrapolationTableauRow<_>>::new());
555        for k in 0..=order + 1 {
556            let nk = compute_n(k);
557            let tableau_row = {
558                let mut Tk = Vec::with_capacity(k + 1);
559                Tk.push(self.midpoint_step(system, step_size, nk, &f_init, y_init));
560                for j in 0..k {
561                    // There is a mistake in Numerical Recipes eq. 17.3.8. See
562                    // https://www.numerical.recipes/forumarchive/index.php/t-2256.html.
563                    let denominator = <F as num_traits::Float>::powi(
564                        cast::<_, F>(nk) / cast(compute_n(k - j - 1)),
565                        2,
566                    ) - <F as num_traits::One>::one();
567                    Tk.push(&Tk[j] + (&Tk[j] - &tableau.0[k - 1].0[j]) / denominator);
568                }
569                ExtrapolationTableauRow(Tk)
570            };
571            tableau.0.push(tableau_row);
572        }
573
574        y_final.assign(&tableau.0.last().unwrap().estimate());
575        return ExtrapolationResult {
576            scaled_truncation_errors: tableau
577                .compute_scaled_truncation_errors(self.abs_tol, self.rel_tol),
578        };
579    }
580
    /// Advance `y_init` by `step_size` using the modified midpoint method with `n` substeps.
    ///
    /// `f_init` must be the derivative already evaluated at `y_init` (precomputed by the caller
    /// so it can be shared across tableau rows). Returns the smoothed endpoint estimate
    /// `0.5 * (z_{n-1} + z_n + h * f(z_n))`.
    fn midpoint_step<S: System<Float = F>>(
        &self,
        evaluation_counter: &mut SystemEvaluationCounter<S>,
        step_size: F,
        n: usize,
        f_init: &nd::Array1<F>,
        y_init: nd::ArrayView1<F>,
    ) -> nd::Array1<F> {
        let substep_size = step_size / cast(n);
        let two_substep_size = cast::<_, F>(2) * substep_size;

        // The two state buffers leapfrog over each other via swaps, so no per-substep
        // allocation is needed:
        // 0    1    2    3    4    5    6    n
        //                  ..
        //           zi  zip1
        //           zip1 zi
        //                zi zip1
        //                  ..
        //                               zi  zip1
        let mut zi = y_init.to_owned();
        // First substep is a plain Euler step using the precomputed derivative.
        let mut zip1 = &zi + f_init * substep_size;
        let mut fi = f_init.clone();

        for _i in 1..n {
            core::mem::swap(&mut zi, &mut zip1);
            evaluation_counter.system(zi.view(), fi.view_mut());
            // Midpoint update: z_{i+1} = z_{i-1} + 2h * f(z_i); `fi` is scaled in place to
            // avoid a temporary array.
            fi *= two_substep_size;
            zip1 += &fi;
        }

        // Final smoothing combination of the last two states and the endpoint derivative.
        evaluation_counter.system(zip1.view(), fi.view_mut());
        fi *= substep_size;
        let mut result = zi;
        result += &zip1;
        result += &fi;
        result *= cast::<_, F>(0.5);
        result
    }
618}
619
/// The outcome of a single extrapolation attempt ([`Integrator::extrapolate`]).
///
/// (Not general step statistics — the previous doc was copy-pasted from [`Stats`].)
#[derive(Debug)]
struct ExtrapolationResult<F: Float> {
    /// The scaled (including absolute and relative tolerances) truncation errors for each
    /// iteration.
    ///
    /// Each will be <= 1 if convergence was achieved or > 1 if convergence was not achieved.
    scaled_truncation_errors: Vec<F>,
}
629
630impl<F: Float> ExtrapolationResult<F> {
631    fn converged(&self) -> bool {
632        *self.scaled_truncation_errors.last().unwrap() < F::one()
633    }
634}
635
/// Wrapper around a [`System`] that counts derivative evaluations for reporting in [`Stats`].
struct SystemEvaluationCounter<'a, S: System> {
    /// The wrapped ODE system.
    system: &'a S,
    /// Number of evaluations made through [`Self::system`] so far.
    num_system_evals: usize,
}
640
641impl<'a, S: System> SystemEvaluationCounter<'a, S> {
642    fn system(&mut self, y: nd::ArrayView1<S::Float>, dydt: nd::ArrayViewMut1<S::Float>) {
643        self.num_system_evals += 1;
644        <S as System>::system(&self.system, y, dydt);
645    }
646}
647
648struct ExtrapolationTableau<F: Float>(Vec<ExtrapolationTableauRow<F>>);
649
650impl<F: Float> ExtrapolationTableau<F> {
651    fn compute_scaled_truncation_errors(&self, abs_tol: F, rel_tol: F) -> Vec<F> {
652        self.0
653            .iter()
654            .skip(1)
655            .map(|row| row.compute_scaled_truncation_error(abs_tol, rel_tol))
656            .collect()
657    }
658}
659
/// One row of the extrapolation tableau: the midpoint result plus successive extrapolants.
struct ExtrapolationTableauRow<F: Float>(Vec<nd::Array1<F>>);

impl<F: Float> ExtrapolationTableauRow<F> {
    /// RMS over components of the difference between the last two extrapolants, each component
    /// scaled by `abs_tol + rel_tol * max(|y|, |y_alt|)`.
    ///
    /// Panics if the row has fewer than two entries.
    fn compute_scaled_truncation_error(&self, abs_tol: F, rel_tol: F) -> F {
        let extrap_pair = self.0.last_chunk::<2>().unwrap();
        let y = &extrap_pair[0];
        let y_alt = &extrap_pair[1];
        (y.iter()
            .zip(y_alt.iter())
            .map(|(&yi, &yi_alt)| {
                // Mixed absolute/relative scale, using the larger magnitude of the pair.
                let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
                (yi - yi_alt).powi(2) / scale.powi(2)
            })
            .sum::<F>()
            / cast(y.len()))
        .sqrt()
    }

    /// The best (highest-order) estimate in this row.
    fn estimate(&self) -> &nd::Array1<F> {
        self.0.last().unwrap()
    }
}
682
/// Step size policy.
///
/// We use a simple linear policy based on the results in Deuflhard: the `k`-th tableau row
/// subdivides the step into 2, 4, 6, ... midpoint substeps.
fn compute_n(iteration: usize) -> usize {
    (iteration + 1) * 2
}
689
/// Cumulative sum of `compute_n`.
///
/// The amount of system function evaluations required to extrapolate to a given order:
/// `sum_{i=0..=k} 2 * (i + 1)`, which closes to `(k + 1) * (k + 2)`.
fn compute_work(iteration: usize) -> usize {
    (iteration + 1) * (iteration + 2)
}
696
697fn cast<T: num_traits::NumCast, F: Float>(num: T) -> F {
698    num_traits::cast(num).unwrap()
699}
700
701#[cfg(test)]
702mod tests {
703    use super::*;
704
705    /// Test that the computation of "work" (i.e. number of system evaluations) is correct.
706    #[test]
707    fn test_compute_work() {
708        for iteration in 0..5 {
709            assert_eq!(
710                compute_work(iteration),
711                (0..=iteration).map(compute_n).sum()
712            );
713        }
714    }
715
    /// Simple test system: dy/dt = y, whose solution is `exp(t)`.
    struct ExpSystem {}

    impl System for ExpSystem {
        type Float = f64;

        fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
            dydt.assign(&y);
        }
    }
725
    /// Ensure we can solve an exponential system to high precision.
    #[test]
    fn test_exp_system_high_precision() {
        let system = ExpSystem {};

        // Set up integrator with tolerance parameters.
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive();

        // Define initial conditions and provide solution storage.
        let t_final = 3.5;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);

        // Integrate.
        let stats = integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();

        // Ensure result matches analytic solution to high precision.
        approx::assert_relative_eq!(t_final.exp(), y_final[[0]], max_relative = 5e-13);

        // Check integration performance (exact counts act as a regression guard).
        assert_eq!(stats.num_system_evals, 437);
        approx::assert_relative_eq!(integrator.step_size().unwrap(), 1.84, epsilon = 1e-2);
    }
754
    /// Ensure the algorithm works even when the max order is smaller than optimal.
    #[test]
    fn test_exp_system_low_max_order() {
        let system = ExpSystem {};

        // Set up integrator with tolerance parameters and an artificially low order cap.
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive()
            .with_max_order(1);

        // Define initial conditions and provide solution storage.
        let t_final = 3.5;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);

        // Integrate.
        integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();

        // Ensure result matches analytic solution to high precision.
        approx::assert_relative_eq!(t_final.exp(), y_final[[0]], max_relative = 5e-13);
    }
780
781    /// Ensure the algorithm can handle NaNs.
782    #[test]
783    fn test_exp_system_handle_nans() {
784        struct ExpSystemWithNans {
785            hit_a_nan: core::cell::RefCell<bool>,
786        }
787
788        impl System for ExpSystemWithNans {
789            type Float = f64;
790
791            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
792                if y[0].abs() > 10. {
793                    *self.hit_a_nan.borrow_mut() = true;
794                    dydt[0] = core::f64::NAN;
795                } else {
796                    dydt.assign(&(-&y));
797                }
798            }
799        }
800
801        let system = ExpSystemWithNans {
802            hit_a_nan: false.into(),
803        };
804
805        // Set up integrator with tolerance parameters.
806        let mut integrator = Integrator::default()
807            .with_abs_tol(0.)
808            .with_rel_tol(1e-10)
809            .into_adaptive();
810
811        // Define initial conditions and provide solution storage.
812        let t_final = 20.;
813        let y = ndarray::array![1.];
814        let mut y_final = ndarray::Array::zeros([1]);
815
816        // Integrate.
817        let stats = integrator
818            .step(&system, t_final, y.view(), y_final.view_mut())
819            .unwrap();
820
821        // Ensure result matches analytic solution.
822        approx::assert_relative_eq!((-t_final).exp(), y_final[[0]], max_relative = 1e-8);
823
824        // Ensure we hit at least one NaN.
825        assert!(*system.hit_a_nan.borrow());
826
827        assert_eq!(stats.num_system_evals, 1085);
828    }
829
    /// This is for interactive debugging as it has no asserts.
    #[test]
    fn test_varying_timescale() {
        /// Pendulum with a sharply peaked restoring force, so the solution's timescale varies
        /// strongly over a period — a stress test for the step size adaptation.
        struct SharpPendulumSystem {}

        impl System for SharpPendulumSystem {
            type Float = f64;

            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                dydt[[0]] = y[[1]];
                dydt[[1]] = -30. * y[[0]].sin().powi(31);
            }
        }

        let system = SharpPendulumSystem {};

        let mut integrator = Integrator::default().into_adaptive();

        let delta_t = 10.;
        let num_steps = 100;
        let mut y = ndarray::array![1., 0.];
        let mut y_final = ndarray::Array::zeros(y.raw_dim());

        for _ in 0..num_steps {
            integrator
                .step(&system, delta_t, y.view(), y_final.view_mut())
                .unwrap();
            y.assign(&y_final);
            // Print the adapted parameters for manual inspection.
            println!(
                "order: {} step_size: {} y: {y}",
                integrator.target_order(),
                integrator.step_size().unwrap()
            );
        }
    }
865
    /// Ensure we don't adapt timesteps out of the limits.
    #[test]
    fn test_step_size_limits() {
        let system = ExpSystem {};

        // Set up integrator with tolerance parameters.
        let mut integrator = Integrator::default().into_adaptive();

        // Define initial conditions and provide solution storage.
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);

        // Ask the integrator to step forward a tiny fraction above the step size.
        integrator.step_size = Some(0.02);
        integrator.max_step_size = Some(0.04);
        integrator.min_step_size = 1E-3;
        let t_final = 0.02 + 1E-4;
        integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();

        // Check that the step size we adapted to is still within the integrator limits.
        let step_size = integrator.step_size().unwrap();
        println!("Step size: {step_size}");
        assert!(integrator.min_step_size <= step_size);
        assert!(step_size <= integrator.max_step_size.unwrap());

        // Step the integrator again.
        integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();
        // Since our first step was tiny, adaptation is allowed to grow our step size.
        println!("Step size: {}", integrator.step_size().unwrap());
        assert!(integrator.step_size().unwrap() >= step_size);
    }
901}