bulirsch/
lib.rs

1//! Implementation of the Bulirsch-Stoer method for stepping ordinary differential equations.
2//!
3//! The [(Gragg-)Bulirsch-Stoer](https://en.wikipedia.org/wiki/Bulirsch%E2%80%93Stoer_algorithm)
4//! algorithm combines the (modified) midpoint method with Richardson extrapolation to accelerate
5//! convergence. It is an explicit method that does not require Jacobians.
6//!
7//! This crate's implementation contains simplistic adaptive step size routines with order
8//! estimation. Its API is designed to be useful in situations where an ODE is being integrated step
9//! by step with a prescribed time step, for example in integrated simulations of electromechanical
//! control systems with a fixed control cycle period. Only time-independent (autonomous) ODEs are
//! supported; this is without loss of generality, since the state vector can be augmented with a
//! time variable if needed.
13//!
14//! The implementation follows:
15//! * Press, William H. Numerical Recipes 3rd Edition: The Art of Scientific Computing. Cambridge
16//!   University Press, 2007. Ch. 17.3.2.
17//! * Deuflhard, Peter. "Order and stepsize control in extrapolation methods." Numerische Mathematik
18//!   41 (1983): 399-422.
19//!
20//! As an example, consider a simple oscillator system:
21//!
22//! ```
23//! // Define ODE.
24//! struct OscillatorSystem {
25//!     omega: f64,
26//! }
27//!
28//! impl bulirsch::System for OscillatorSystem {
29//!     type Float = f64;
30//!
31//!     fn system(
32//!         &self,
33//!         y: bulirsch::ArrayView1<Self::Float>,
34//!         mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
35//!     ) {
36//!         dydt[[0]] = y[[1]];
37//!         dydt[[1]] = -self.omega.powi(2) * y[[0]];
38//!     }
39//! }
40//!
41//! let system = OscillatorSystem { omega: 1.2 };
42//!
43//! // Set up the integrator.
44//! let mut integrator = bulirsch::Integrator::default()
45//!     .with_abs_tol(1e-8)
46//!     .with_rel_tol(1e-8)
47//!     .into_adaptive();
48//!
49//! // Define initial conditions and provide solution storage.
50//! let delta_t: f64 = 10.2;
51//! let mut y = ndarray::array![1., 0.];
52//! let mut y_next = ndarray::Array::zeros(y.raw_dim());
53//!
54//! // Integrate for 10 steps.
55//! let num_steps = 10;
56//! for _ in 0..num_steps {
57//!     integrator
58//!         .step(&system, delta_t, y.view(), y_next.view_mut())
59//!         .unwrap();
60//!     y.assign(&y_next);
61//! }
62//!
63//! // Ensure result matches analytic solution.
64//! approx::assert_relative_eq!(
65//!     (system.omega * delta_t * num_steps as f64).cos(),
66//!     y_next[[0]],
67//!     epsilon = 5e-7,
68//!     max_relative = 5e-7,
69//! );
70//!
71//! // Check integration performance.
72//! assert_eq!(integrator.overall_stats().num_system_evals, 3843);
73//! approx::assert_relative_eq!(integrator.step_size().unwrap(), 2.14, epsilon = 1e-2);
74//! ```
75//!
//! Note that about 3.8k system evaluations have been used. By contrast, the `ode_solvers::Dopri5`
//! algorithm uses roughly twice as many:
78//!
79//! ```
80//! struct OscillatorSystem {
81//!     omega: f64,
82//! }
83//!
84//! impl ode_solvers::System<f64, ode_solvers::SVector<f64, 2>> for OscillatorSystem {
85//!     fn system(
86//!         &self,
87//!         _x: f64,
88//!         y: &ode_solvers::SVector<f64, 2>,
89//!         dy: &mut ode_solvers::SVector<f64, 2>,
90//!     ) {
91//!         dy[0] = y[1];
92//!         dy[1] = -self.omega.powi(2) * y[0];
93//!     }
94//! }
95//!
96//! let omega = 1.2;
97//! let delta_t: f64 = 10.2;
98//! let mut num_system_eval = 0;
99//! let mut y = ode_solvers::Vector2::new(1., 0.);
100//! let num_steps = 10;
101//! for _ in 0..num_steps {
102//!     let system = OscillatorSystem { omega };
103//!     let mut solver = ode_solvers::Dopri5::new(
104//!         system,
105//!         0.,
106//!         delta_t,
107//!         delta_t,
108//!         y,
109//!         1e-8,
110//!         1e-8,
111//!     );
112//!     num_system_eval += solver.integrate().unwrap().num_eval;
113//!     y = *solver.y_out().get(1).unwrap();
114//! }
115//! assert_eq!(num_system_eval, 7476);
116//!
117//! // Ensure result matches analytic solution.
118//! approx::assert_relative_eq!(
119//!     (omega * delta_t * num_steps as f64).cos(),
120//!     y[0],
121//!     epsilon = 5e-7,
122//!     max_relative = 5e-7,
123//! );
124//! ```
125//!
//! At the time of writing, the latest version of `ode_solvers`, 0.6.1, gives a dramatically
//! incorrect answer, likely due to a regression, so we use version 0.5 as a dev dependency.
128
129#![expect(
130    non_snake_case,
131    reason = "Used for math symbols to match notation in Numerical Recipes"
132)]
133
134pub use nd::ArrayView1;
135pub use nd::ArrayViewMut1;
136use ndarray as nd;
137
138pub trait Float:
139    num_traits::Float
140    + core::iter::Sum
141    + core::ops::AddAssign
142    + core::ops::MulAssign
143    + core::fmt::Debug
144    + nd::ScalarOperand
145{
146}
147
148impl Float for f32 {}
149impl Float for f64 {}
150
151/// Trait for defining an ordinary differential equation system.
152pub trait System {
153    /// The floating point type.
154    type Float: Float;
155
156    /// Evaluate the ordinary differential equation and store the derivative in `dydt`.
157    fn system(&self, y: ArrayView1<Self::Float>, dydt: ArrayViewMut1<Self::Float>);
158}
159
160/// Error generated when integration produced a step size smaller than the minimum allowed step
161/// size.
162#[derive(Debug)]
163pub struct StepSizeUnderflow<F: Float>(F);
164
/// Statistics from taking an integration step (or accumulated over several steps).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of system function evaluations.
    pub num_system_evals: usize,
}
171
172/// An ODE integrator using the Bulirsch-Stoer algorithm with an adaptive step size and adaptive
173/// order.
174///
175/// Should be constructed using [`Integrator::into_adaptive`].
176#[derive(Clone)]
177pub struct AdaptiveIntegrator<F: Float> {
178    /// The underlying non-adaptive integrator.
179    integrator: Integrator<F>,
180
181    /// The current step size.
182    step_size: Option<F>,
183    /// The minimum step size to allow before returning [`StepSizeUnderflow`].
184    min_step_size: F,
185    /// The maximum step size to allow.
186    max_step_size: Option<F>,
187
188    /// The current estimated target number of iterations to use.
189    target_order: usize,
190    /// The maximum number of iterations to use.
191    max_order: usize,
192
193    /// Overall stats.
194    overall_stats: Stats,
195}
196
197impl<F: Float> AdaptiveIntegrator<F> {
198    /// Take a step using the Bulirsch-Stoer method.
199    ///
200    /// # Arguments
201    ///
202    /// * `system`: The ODE system.
203    /// * `delta_t`: The size of the prescribed time step to take.
204    /// * `y_init`: The initial state vector at the start of the time step.
205    /// * `y_final`: The vector into which to store the final computed state at the end of the time
206    ///   step.
207    ///
208    /// # Result
209    ///
210    /// Stats providing information about integration performance, or an error if integration
211    /// failed.
212    ///
213    /// # Examples
214    ///
215    /// Note that if you're using e.g. `nalgebra` to define your ODE, you can bridge to [`ndarray`]
216    /// vectors using slices, as long as you're using `nalgebra`'s dynamically sized vectors. The
217    /// same applies to using [`Vec`]s, etc. For example, consider a simple oscillator system
218    /// defined using `nalgebra`:
219    ///
220    /// ```
221    /// // Define oscillator ODE.
222    /// #[derive(Clone, Copy)]
223    /// struct OscillatorSystem {
224    ///     omega: f32,
225    /// }
226    ///
227    /// fn compute_dydt(
228    ///     omega: f32,
229    ///     y: nalgebra::DVectorView<f32>,
230    ///     mut dydt: nalgebra::DVectorViewMut<f32>,
231    /// ) {
232    ///     dydt[0] = y[1];
233    ///     dydt[1] = -omega.powi(2) * y[0];
234    /// }
235    ///
236    /// impl bulirsch::System for OscillatorSystem {
237    ///     type Float = f32;
238    ///
239    ///     fn system(
240    ///         &self,
241    ///         y: bulirsch::ArrayView1<Self::Float>,
242    ///         mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
243    ///     ) {
244    ///         let y_nalgebra = nalgebra::DVectorView::from_slice(
245    ///             y.as_slice().unwrap(),
246    ///             y.len(),
247    ///         );
248    ///         let dydt_nalgebra = nalgebra::DVectorViewMut::from_slice(
249    ///             dydt.as_slice_mut().unwrap(),
250    ///             y.len(),
251    ///         );
252    ///         compute_dydt(self.omega, y_nalgebra, dydt_nalgebra);
253    ///     }
254    /// }
255    ///
256    /// // Instantiate system and integrator.
257    /// let system = OscillatorSystem { omega: 1.2 };
258    /// let mut integrator =
259    ///     bulirsch::Integrator::default()
260    ///         .with_abs_tol(1e-6)
261    ///         .with_rel_tol(0.)
262    ///         .into_adaptive();
263    ///
264    /// // Define initial conditions and integrate.
265    /// let mut y = ndarray::array![1., 0.];
266    /// let mut y_next = ndarray::Array1::zeros(y.raw_dim());
267    /// let delta_t = 0.6;
268    /// let num_steps = 10;
269    /// let mut num_system_evals = 0;
270    /// for _ in 0..num_steps {
271    ///     num_system_evals += integrator
272    ///         .step(&system, delta_t, y.view(), y_next.view_mut())
273    ///         .unwrap()
274    ///         .num_system_evals;
275    ///     y.assign(&y_next);
276    /// }
277    ///
278    /// // Check against analytic solution.
279    /// let (sin, cos) = (delta_t * num_steps as f32 * system.omega).sin_cos();
280    /// approx::assert_relative_eq!(y_next[0], cos, epsilon = 1e-2);
281    /// approx::assert_relative_eq!(
282    ///     y_next[1],
283    ///     -system.omega * sin,
284    ///     epsilon = 1e-2
285    /// );
286    ///
287    /// // Check integrator performance.
288    /// assert_eq!(num_system_evals, 310);
289    /// ```
290    pub fn step<S: System<Float = F>>(
291        &mut self,
292        system: &S,
293        delta_t: S::Float,
294        y_init: nd::ArrayView1<S::Float>,
295        mut y_final: nd::ArrayViewMut1<S::Float>,
296    ) -> Result<Stats, StepSizeUnderflow<F>> {
297        let mut step_size = if let Some(step_size) = self.step_size {
298            step_size
299        } else {
300            delta_t
301        };
302
303        let mut system = SystemEvaluationCounter {
304            system,
305            num_system_evals: 0,
306        };
307
308        // Iteratively take steps until taking a step would put us past the input `delta_t`. At that
309        // point, take an exact step to finish `delta_t`. Dynamically adjust the step size to
310        // control truncation error as we go.
311        let mut y_before_step = y_init.to_owned();
312        let mut y_after_step = y_init.to_owned();
313        let mut t = F::zero();
314        loop {
315            if step_size < self.min_step_size || !step_size.is_finite() {
316                return Err(StepSizeUnderflow(step_size));
317            }
318
319            // We set `next_t` to `None` if we're at the tail end of `delta_t` and are taking a
320            // smaller step than is optimal so we don't overshoot.
321            let next_t = if t < delta_t - step_size {
322                Some((t + step_size).min(delta_t))
323            } else {
324                None
325            };
326            step_size = step_size.min(delta_t - t);
327
328            let extrapolation_result = self.integrator.extrapolate(
329                &mut system,
330                step_size,
331                self.target_order,
332                y_before_step.view(),
333                y_after_step.view_mut(),
334            );
335
336            match (extrapolation_result.converged(), next_t) {
337                // The step was successful, and we're at the end of `delta_t`. Done.
338                (true, None) => {
339                    self.perform_step_size_control(&extrapolation_result, &mut step_size);
340                    break;
341                }
342                // The step was successful, and we're not at the end of `delta_t`. Potentially
343                // adjust `target_order`, adjust step size, and continue.
344                (true, Some(next_t)) => {
345                    self.perform_order_and_step_size_control(&extrapolation_result, &mut step_size);
346                    t = next_t;
347                    y_before_step.assign(&y_after_step);
348                }
349                // The step failed. Adjust step size, but for simplicity, unlike Numerical Recipes,
350                // don't try to adjust order. Try again.
351                (false, _) => {
352                    self.perform_step_size_control(&extrapolation_result, &mut step_size);
353                }
354            }
355        }
356
357        self.step_size = Some(step_size);
358        y_final.assign(&y_after_step);
359        self.overall_stats.num_system_evals += system.num_system_evals;
360
361        Ok(Stats {
362            num_system_evals: system.num_system_evals,
363        })
364    }
365
366    /// Set the minimum step size to allow before returning [`StepSizeUnderflow`].
367    pub fn with_min_step_size(self, min_step_size: F) -> Self {
368        Self {
369            min_step_size,
370            ..self
371        }
372    }
373    /// Set the minimum step size to allow before returning [`StepSizeUnderflow`].
374    pub fn with_max_step_size(self, max_step_size: Option<F>) -> Self {
375        Self {
376            max_step_size,
377            ..self
378        }
379    }
380    /// Set the maximum "order" to use, i.e. max number of iterations per extrapolation.
381    pub fn with_max_order(self, max_order: usize) -> Self {
382        Self { max_order, ..self }
383    }
384
385    /// Get overall stats across all steps taken so far.
386    pub fn overall_stats(&self) -> &Stats {
387        &self.overall_stats
388    }
389    /// Get the current step size.
390    pub fn step_size(&self) -> Option<F> {
391        self.step_size
392    }
393    /// Get the current target order.
394    pub fn target_order(&self) -> usize {
395        self.target_order
396    }
397
398    fn compute_step_size_adjustment_factor(
399        extrapolation_result: &ExtrapolationResult<F>,
400        target_order: usize,
401    ) -> F {
402        let scaled_truncation_error = *extrapolation_result
403            .scaled_truncation_errors
404            .get(target_order)
405            .unwrap();
406
407        let safety_factor: F = cast(0.9);
408        let min_step_size_decrease_factor: F = cast(0.01);
409        let max_step_size_increase_factor = min_step_size_decrease_factor.recip();
410
411        if scaled_truncation_error > F::zero() {
412            // Eq. 2.14, Deuflhard.
413            (safety_factor / scaled_truncation_error.powf(F::one() / cast(2 * target_order + 1)))
414                .max(min_step_size_decrease_factor)
415                .min(max_step_size_increase_factor)
416        } else if scaled_truncation_error == F::zero() {
417            cast(2)
418        } else {
419            // Handle NaNs.
420            cast(0.5)
421        }
422    }
423
424    fn perform_step_size_control(
425        &self,
426        extrapolation_result: &ExtrapolationResult<F>,
427        step_size: &mut F,
428    ) {
429        let adjustment_factor =
430            Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);
431        *step_size *= adjustment_factor;
432
433        if let Some(max_step_size) = self.max_step_size {
434            *step_size = step_size.min(max_step_size);
435        }
436    }
437
438    fn perform_order_and_step_size_control(
439        &mut self,
440        extrapolation_result: &ExtrapolationResult<F>,
441        step_size: &mut F,
442    ) {
443        let adjustment_factor =
444            Self::compute_step_size_adjustment_factor(&extrapolation_result, self.target_order);
445
446        // This follows eqs. 17.3.14 & 17.3.15 in Numerical Recipes.
447        if self.target_order > 0 {
448            let adjustment_factor_lower_order = Self::compute_step_size_adjustment_factor(
449                &extrapolation_result,
450                self.target_order - 1,
451            );
452
453            let work = cast::<_, F>(compute_work(self.target_order));
454            let work_per_step = work / *step_size / adjustment_factor;
455            let work_lower_order = cast::<_, F>(compute_work(self.target_order - 1));
456            let work_per_step_lower_order =
457                work_lower_order / *step_size / adjustment_factor_lower_order;
458
459            self.target_order = if work_per_step_lower_order < cast::<_, F>(0.8) * work_per_step
460                && self.target_order > 1
461            {
462                // Decrease order since a lower order requires less work.
463                *step_size *= adjustment_factor_lower_order;
464                self.target_order - 1
465            } else if work_per_step < cast::<_, F>(0.95) * work_per_step_lower_order
466                && self.target_order + 1 <= self.max_order
467            {
468                // Increase order since a higher order is heuristically indicated to require less
469                // work (even though we didn't extrapolate to this order, so can't tell for sure).
470                // We use 0.95 above instead of 0.9 from Numerical Recipes since it produced better
471                // performance on the tests.
472                let work_higher_order = cast::<_, F>(compute_work(self.target_order + 1));
473                *step_size *= adjustment_factor * work_higher_order / work;
474                self.target_order + 1
475            } else {
476                // Preserve order and only adjust step size.
477                *step_size *= adjustment_factor;
478                self.target_order
479            };
480        } else {
481            *step_size *= adjustment_factor;
482        }
483
484        if let Some(max_step_size) = self.max_step_size {
485            *step_size = step_size.min(max_step_size);
486        }
487    }
488}
489
490/// An ODE integrator using the Bulirsch-Stoer algorithm with a fixed step size.
491///
492/// Used to construct an [`AdaptiveIntegrator`].
493#[derive(Clone)]
494pub struct Integrator<F: Float> {
495    /// The absolute tolerance.
496    abs_tol: F,
497    /// The relative tolerance.
498    rel_tol: F,
499}
500
501impl<F: Float> Default for Integrator<F> {
502    fn default() -> Self {
503        Self {
504            abs_tol: cast(1e-6),
505            rel_tol: cast(1e-6),
506        }
507    }
508}
509
510impl<F: Float> Integrator<F> {
511    /// Make an [`AdaptiveIntegrator`].
512    pub fn into_adaptive(self) -> AdaptiveIntegrator<F> {
513        AdaptiveIntegrator {
514            integrator: self,
515            step_size: None,
516            min_step_size: cast(1e-9),
517            max_step_size: None,
518            target_order: 3,
519            max_order: 10,
520            overall_stats: Stats {
521                num_system_evals: 0,
522            },
523        }
524    }
525
526    /// Set the absolute tolerance.
527    pub fn with_abs_tol(self, abs_tol: F) -> Self {
528        Self { abs_tol, ..self }
529    }
530    /// Set the relative tolerance.
531    pub fn with_rel_tol(self, rel_tol: F) -> Self {
532        Self { rel_tol, ..self }
533    }
534
535    /// Take a single extrapolating step, iteratively subdividing `step_size`.
536    fn extrapolate<S: System<Float = F>>(
537        &self,
538        system: &mut SystemEvaluationCounter<S>,
539        step_size: F,
540        order: usize,
541        y_init: nd::ArrayView1<F>,
542        mut y_final: nd::ArrayViewMut1<F>,
543    ) -> ExtrapolationResult<F> {
544        let f_init = {
545            let mut f_init = nd::Array1::zeros(y_init.raw_dim());
546            system.system(y_init, f_init.view_mut());
547            f_init
548        };
549
550        // Build up an extrapolation tableau.
551        let mut tableau = ExtrapolationTableau(Vec::<ExtrapolationTableauRow<_>>::new());
552        for k in 0..=order + 1 {
553            let nk = compute_n(k);
554            let tableau_row = {
555                let mut Tk = Vec::with_capacity(k + 1);
556                Tk.push(self.midpoint_step(system, step_size, nk, &f_init, y_init));
557                for j in 0..k {
558                    // There is a mistake in Numerical Recipes eq. 17.3.8. See
559                    // https://www.numerical.recipes/forumarchive/index.php/t-2256.html.
560                    let denominator = <F as num_traits::Float>::powi(
561                        cast::<_, F>(nk) / cast(compute_n(k - j - 1)),
562                        2,
563                    ) - <F as num_traits::One>::one();
564                    Tk.push(&Tk[j] + (&Tk[j] - &tableau.0[k - 1].0[j]) / denominator);
565                }
566                ExtrapolationTableauRow(Tk)
567            };
568            tableau.0.push(tableau_row);
569        }
570
571        y_final.assign(&tableau.0.last().unwrap().estimate());
572        return ExtrapolationResult {
573            scaled_truncation_errors: tableau
574                .compute_scaled_truncation_errors(self.abs_tol, self.rel_tol),
575        };
576    }
577
578    fn midpoint_step<S: System<Float = F>>(
579        &self,
580        evaluation_counter: &mut SystemEvaluationCounter<S>,
581        step_size: F,
582        n: usize,
583        f_init: &nd::Array1<F>,
584        y_init: nd::ArrayView1<F>,
585    ) -> nd::Array1<F> {
586        let substep_size = step_size / cast(n);
587        let two_substep_size = cast::<_, F>(2) * substep_size;
588
589        // 0    1    2    3    4    5    6    n
590        //                  ..
591        //           zi  zip1
592        //           zip1 zi
593        //                zi zip1
594        //                  ..
595        //                               zi  zip1
596        let mut zi = y_init.to_owned();
597        let mut zip1 = &zi + f_init * substep_size;
598        let mut fi = f_init.clone();
599
600        for _i in 1..n {
601            core::mem::swap(&mut zi, &mut zip1);
602            evaluation_counter.system(zi.view(), fi.view_mut());
603            fi *= two_substep_size;
604            zip1 += &fi;
605        }
606
607        evaluation_counter.system(zip1.view(), fi.view_mut());
608        fi *= substep_size;
609        let mut result = zi;
610        result += &zip1;
611        result += &fi;
612        result *= cast::<_, F>(0.5);
613        result
614    }
615}
616
617/// Statistics from taking an integration step.
618#[derive(Debug)]
619struct ExtrapolationResult<F: Float> {
620    /// The scaled (including absolute and relative tolerances) truncation errors for each
621    /// iteration.
622    ///
623    /// Each will be <= 1 if convergence was achieved or > 1 if convergence was not achieved.
624    scaled_truncation_errors: Vec<F>,
625}
626
627impl<F: Float> ExtrapolationResult<F> {
628    fn converged(&self) -> bool {
629        *self.scaled_truncation_errors.last().unwrap() < F::one()
630    }
631}
632
633struct SystemEvaluationCounter<'a, S: System> {
634    system: &'a S,
635    num_system_evals: usize,
636}
637
638impl<'a, S: System> SystemEvaluationCounter<'a, S> {
639    fn system(&mut self, y: nd::ArrayView1<S::Float>, dydt: nd::ArrayViewMut1<S::Float>) {
640        self.num_system_evals += 1;
641        <S as System>::system(&self.system, y, dydt);
642    }
643}
644
645struct ExtrapolationTableau<F: Float>(Vec<ExtrapolationTableauRow<F>>);
646
647impl<F: Float> ExtrapolationTableau<F> {
648    fn compute_scaled_truncation_errors(&self, abs_tol: F, rel_tol: F) -> Vec<F> {
649        self.0
650            .iter()
651            .skip(1)
652            .map(|row| row.compute_scaled_truncation_error(abs_tol, rel_tol))
653            .collect()
654    }
655}
656
657struct ExtrapolationTableauRow<F: Float>(Vec<nd::Array1<F>>);
658
659impl<F: Float> ExtrapolationTableauRow<F> {
660    fn compute_scaled_truncation_error(&self, abs_tol: F, rel_tol: F) -> F {
661        let extrap_pair = self.0.last_chunk::<2>().unwrap();
662        let y = &extrap_pair[0];
663        let y_alt = &extrap_pair[1];
664        (y.iter()
665            .zip(y_alt.iter())
666            .map(|(&yi, &yi_alt)| {
667                let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
668                (yi - yi_alt).powi(2) / scale.powi(2)
669            })
670            .sum::<F>()
671            / cast(y.len()))
672        .sqrt()
673    }
674
675    fn estimate(&self) -> &nd::Array1<F> {
676        self.0.last().unwrap()
677    }
678}
679
/// Number of midpoint substeps used for tableau row `iteration`: 2, 4, 6, 8, ...
///
/// This simple linear sequence is based on the results in Deuflhard.
fn compute_n(iteration: usize) -> usize {
    (iteration + 1) * 2
}
686
/// Cumulative sum of `compute_n` over rows `0..=iteration`.
///
/// Used as the relative cost, in system evaluations, of extrapolating with the given number of
/// rows. Since `compute_n(i) = 2 (i + 1)`, the sum has the closed form
/// `(iteration + 1) * (iteration + 2)`.
fn compute_work(iteration: usize) -> usize {
    (iteration + 1) * (iteration + 2)
}
693
694fn cast<T: num_traits::NumCast, F: Float>(num: T) -> F {
695    num_traits::cast(num).unwrap()
696}
697
#[cfg(test)]
mod tests {
    use super::*;
    use core::cell::Cell;

    /// `compute_work` must equal the running total of `compute_n` over the tableau rows.
    #[test]
    fn test_compute_work() {
        let mut running_total = 0;
        for iteration in 0..5 {
            running_total += compute_n(iteration);
            assert_eq!(compute_work(iteration), running_total);
        }
    }

    /// Exponential growth, `dy/dt = y`.
    struct ExpSystem;

    impl System for ExpSystem {
        type Float = f64;

        fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
            dydt.assign(&y);
        }
    }

    /// Integrate `ExpSystem` from `y(0) = 1` over `[0, t_final]` with a single call to `step`,
    /// returning the final value and that call's stats.
    fn integrate_exp(integrator: &mut AdaptiveIntegrator<f64>, t_final: f64) -> (f64, Stats) {
        let y_init = ndarray::array![1.];
        let mut y_final = ndarray::Array1::<f64>::zeros(1);
        let stats = integrator
            .step(&ExpSystem, t_final, y_init.view(), y_final.view_mut())
            .unwrap();
        (y_final[0], stats)
    }

    /// The exponential system can be solved to high precision.
    #[test]
    fn test_exp_system_high_precision() {
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive();
        let t_final: f64 = 3.5;

        let (y_final, stats) = integrate_exp(&mut integrator, t_final);

        // Matches the analytic solution to high precision...
        approx::assert_relative_eq!(t_final.exp(), y_final, max_relative = 5e-13);
        // ...with the expected amount of work.
        assert_eq!(stats.num_system_evals, 437);
        approx::assert_relative_eq!(integrator.step_size().unwrap(), 0.28, epsilon = 1e-2);
    }

    /// The algorithm still works when the max order is smaller than optimal.
    #[test]
    fn test_exp_system_low_max_order() {
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive()
            .with_max_order(1);
        let t_final: f64 = 3.5;

        let (y_final, _) = integrate_exp(&mut integrator, t_final);

        approx::assert_relative_eq!(t_final.exp(), y_final, max_relative = 5e-13);
    }

    /// The algorithm recovers when the system returns NaNs for overly large steps.
    #[test]
    fn test_exp_system_handle_nans() {
        /// Exponential decay, `dy/dt = -y`, that returns NaN whenever `|y| > 10`.
        struct ExpSystemWithNans {
            hit_a_nan: Cell<bool>,
        }

        impl System for ExpSystemWithNans {
            type Float = f64;

            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                if y[0].abs() > 10. {
                    self.hit_a_nan.set(true);
                    dydt[0] = f64::NAN;
                } else {
                    dydt.assign(&(-&y));
                }
            }
        }

        let system = ExpSystemWithNans {
            hit_a_nan: Cell::new(false),
        };
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-10)
            .into_adaptive();
        let t_final: f64 = 20.;
        let y_init = ndarray::array![1.];
        let mut y_final = ndarray::Array1::<f64>::zeros(1);

        let stats = integrator
            .step(&system, t_final, y_init.view(), y_final.view_mut())
            .unwrap();

        approx::assert_relative_eq!((-t_final).exp(), y_final[0], max_relative = 1e-8);
        // The NaN path must actually have been exercised.
        assert!(system.hit_a_nan.get());
        assert_eq!(stats.num_system_evals, 1085);
    }

    /// Smoke test on a pendulum-like system whose restoring force is sharply peaked (it scales
    /// as `sin(y0)^31`). Intended for interactive debugging: it only checks that every step
    /// succeeds.
    #[test]
    fn test_varying_timescale() {
        struct SharpPendulumSystem;

        impl System for SharpPendulumSystem {
            type Float = f64;

            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                dydt[[0]] = y[[1]];
                dydt[[1]] = -30. * y[[0]].sin().powi(31);
            }
        }

        let mut integrator = Integrator::default().into_adaptive();
        let delta_t = 10.;
        let num_steps = 100;
        let mut y = ndarray::array![1., 0.];
        let mut y_final = ndarray::Array::zeros(y.raw_dim());

        for _ in 0..num_steps {
            integrator
                .step(&SharpPendulumSystem, delta_t, y.view(), y_final.view_mut())
                .unwrap();
            y.assign(&y_final);
            println!(
                "order: {} step_size: {} y: {y}",
                integrator.target_order(),
                integrator.step_size().unwrap()
            );
        }
    }
}