#![expect(
non_snake_case,
reason = "Used for math symbols to match notation in Numerical Recipes"
)]
pub use nd::ArrayView1;
pub use nd::ArrayViewMut1;
use ndarray as nd;
/// Floating-point scalar type usable by the integrators in this file.
///
/// Bundles `num_traits::Float` with the summation, compound-assignment,
/// debug-formatting and `ndarray` scalar-operand bounds listed below, so that
/// generic code only has to name a single trait.
pub trait Float:
num_traits::Float
+ core::iter::Sum
+ core::ops::AddAssign
+ core::ops::MulAssign
+ core::fmt::Debug
+ nd::ScalarOperand
{
}
// The trait has no items of its own, so the impls for the primitive floats are empty.
impl Float for f32 {}
impl Float for f64 {}
/// Right-hand side of the differential equation integrated by this file.
///
/// The integrators call [`System::system`] with the current state and use the
/// values written to `dydt` as the state's rate of change. No time argument is
/// passed to the callback.
pub trait System {
/// Scalar type of the state and of its derivative.
type Float: Float;
/// Reads the state `y` and writes its derivative into `dydt`.
///
/// The integrators reuse the `dydt` buffer between calls, so an
/// implementation should overwrite every element it is responsible for.
fn system(&self, y: ArrayView1<Self::Float>, dydt: ArrayViewMut1<Self::Float>);
}
#[derive(Debug)]
pub struct StepSizeUnderflow<F: Float>(F);
/// Work counters reported by the adaptive integrator.
#[derive(Clone, Debug)]
pub struct Stats {
/// Number of times the user's `System::system` callback was invoked.
pub num_system_evals: usize,
}
/// Extrapolation integrator that chooses its own step size and order.
///
/// Obtained from `Integrator::into_adaptive`; the solution is advanced with
/// `step`, and the chosen step size and order are remembered between calls.
#[derive(Clone)]
pub struct AdaptiveIntegrator<F: Float> {
/// Fixed-order extrapolation engine holding the error tolerances.
integrator: Integrator<F>,
/// Step size to try next; `None` until the first call to `step` seeds it.
step_size: Option<F>,
/// Attempting a step smaller than this makes `step` fail.
min_step_size: F,
/// Optional cap applied whenever the step size is re-computed.
max_step_size: Option<F>,
/// Order handed to the extrapolation for the next attempt.
target_order: usize,
/// Order control never raises `target_order` above this value.
max_order: usize,
/// Running total over the calls to `step`.
overall_stats: Stats,
}
impl<F: Float> AdaptiveIntegrator<F> {
/// Integrates `system` over an interval of length `delta_t`, starting from
/// `y_init`, and writes the resulting state into `y_final`.
///
/// The interval is covered by one or more extrapolation attempts at the
/// current target order. A converged attempt is accepted; unless it was the
/// last one it also triggers combined order and step-size control. A
/// rejected attempt only rescales the step size and is retried from the same
/// state. The step size and order chosen here are kept for the next call.
///
/// Returns the number of system evaluations spent by this call.
///
/// # Errors
///
/// Returns [`StepSizeUnderflow`] when the step size about to be attempted is
/// below the minimum step size or not finite. `y_final` is left untouched in
/// that case, but the evaluations already spent are still added to the
/// overall statistics.
pub fn step<S: System<Float = F>>(
    &mut self,
    system: &S,
    delta_t: S::Float,
    y_init: nd::ArrayView1<S::Float>,
    mut y_final: nd::ArrayViewMut1<S::Float>,
) -> Result<Stats, StepSizeUnderflow<F>> {
    // On the very first call no step size is known yet: try the whole interval.
    let mut step_size = *self.step_size.get_or_insert(delta_t);
    let mut system = SystemEvaluationCounter {
        system,
        num_system_evals: 0,
    };
    let mut y_before_step = y_init.to_owned();
    let mut y_after_step = y_init.to_owned();
    // Time elapsed since the start of this call.
    let mut t = F::zero();
    loop {
        if step_size < self.min_step_size || !step_size.is_finite() {
            // Rejected attempts cost evaluations too; account for them before bailing out.
            self.overall_stats.num_system_evals += system.num_system_evals;
            return Err(StepSizeUnderflow(step_size));
        }
        // `Some` while a full step still ends strictly before `delta_t`;
        // `None` marks the final step of this call.
        let next_t = if t < delta_t - step_size {
            Some((t + step_size).min(delta_t))
        } else {
            None
        };
        // The final step is shortened so that it lands exactly on `delta_t`.
        step_size = step_size.min(delta_t - t);
        let extrapolation_result = self.integrator.extrapolate(
            &mut system,
            step_size,
            self.target_order,
            y_before_step.view(),
            y_after_step.view_mut(),
        );
        match (extrapolation_result.converged(), next_t) {
            (true, None) => {
                // Only learn from the final step if it was not cut short
                // relative to the stored step size.
                if matches!(self.step_size, Some(stored) if step_size >= stored) {
                    self.perform_step_size_control(&extrapolation_result, &mut step_size);
                }
                break;
            }
            (true, Some(next_t)) => {
                self.perform_order_and_step_size_control(&extrapolation_result, &mut step_size);
                t = next_t;
                y_before_step.assign(&y_after_step);
            }
            (false, _) => {
                // Rejected: rescale the step and retry from the same state.
                self.perform_step_size_control(&extrapolation_result, &mut step_size);
            }
        }
    }
    y_final.assign(&y_after_step);
    self.overall_stats.num_system_evals += system.num_system_evals;
    Ok(Stats {
        num_system_evals: system.num_system_evals,
    })
}
/// Returns the integrator with its minimum step size replaced by `min_step_size`.
///
/// `step` fails once it would have to attempt a step smaller than this.
pub fn with_min_step_size(mut self, min_step_size: F) -> Self {
    self.min_step_size = min_step_size;
    self
}
/// Returns the integrator with its step-size cap replaced by `max_step_size`.
///
/// `None` removes the cap. The cap is applied whenever step-size control
/// re-computes the step size.
pub fn with_max_step_size(mut self, max_step_size: Option<F>) -> Self {
    self.max_step_size = max_step_size;
    self
}
/// Returns the integrator with its maximum order replaced by `max_order`.
///
/// Only the field is updated; the current target order is left as it is.
pub fn with_max_order(mut self, max_order: usize) -> Self {
    self.max_order = max_order;
    self
}
/// Returns the statistics accumulated by `step` on this integrator so far.
pub fn overall_stats(&self) -> &Stats {
&self.overall_stats
}
/// Returns the step size stored for the next attempt, or `None` if none has
/// been chosen yet.
pub fn step_size(&self) -> Option<F> {
self.step_size
}
/// Returns the order that the next extrapolation attempt will use.
pub fn target_order(&self) -> usize {
self.target_order
}
fn compute_step_size_adjustment_factor(
extrapolation_result: &ExtrapolationResult<F>,
target_order: usize,
) -> F {
let scaled_truncation_error = *extrapolation_result
.scaled_truncation_errors
.get(target_order)
.unwrap();
let safety_factor: F = cast(0.9);
let min_step_size_decrease_factor: F = cast(0.01);
let max_step_size_increase_factor = min_step_size_decrease_factor.recip();
if scaled_truncation_error > F::zero() {
(safety_factor / scaled_truncation_error.powf(F::one() / cast(2 * target_order + 1)))
.max(min_step_size_decrease_factor)
.min(max_step_size_increase_factor)
} else if scaled_truncation_error == F::zero() {
cast(2)
} else {
cast(0.5)
}
}
/// Rescales `step_size` without touching the target order.
///
/// Multiplies `step_size` by the adjustment factor for the current target
/// order, applies the optional maximum step size, and stores the result as
/// the integrator's step size for later attempts.
fn perform_step_size_control(
    &mut self,
    extrapolation_result: &ExtrapolationResult<F>,
    step_size: &mut F,
) {
    // `extrapolation_result` is already a reference; pass it on as is.
    let adjustment_factor =
        Self::compute_step_size_adjustment_factor(extrapolation_result, self.target_order);
    *step_size *= adjustment_factor;
    if let Some(max_step_size) = self.max_step_size {
        *step_size = step_size.min(max_step_size);
    }
    self.step_size = Some(*step_size);
}
/// Chooses the target order for the next attempt and rescales `step_size`.
///
/// Compares the work per unit step at the current order with that at the
/// next lower order:
/// - the order drops by one (never below 1) when the lower order needs less
///   than 80 % of the work;
/// - the order rises by one (never above `max_order`) when the current order
///   needs less than 95 % of the lower order's work;
/// - otherwise the order is kept.
///
/// At order 0 no comparison is possible and only the step size is rescaled.
/// The optional maximum step size is applied afterwards, and the result is
/// stored as the integrator's step size.
fn perform_order_and_step_size_control(
    &mut self,
    extrapolation_result: &ExtrapolationResult<F>,
    step_size: &mut F,
) {
    let adjustment_factor =
        Self::compute_step_size_adjustment_factor(extrapolation_result, self.target_order);
    if self.target_order > 0 {
        let adjustment_factor_lower_order = Self::compute_step_size_adjustment_factor(
            extrapolation_result,
            self.target_order - 1,
        );
        let work = cast::<_, F>(compute_work(self.target_order));
        let work_per_step = work / *step_size / adjustment_factor;
        let work_lower_order = cast::<_, F>(compute_work(self.target_order - 1));
        let work_per_step_lower_order =
            work_lower_order / *step_size / adjustment_factor_lower_order;
        self.target_order = if work_per_step_lower_order < cast::<_, F>(0.8) * work_per_step
            && self.target_order > 1
        {
            *step_size *= adjustment_factor_lower_order;
            self.target_order - 1
        } else if work_per_step < cast::<_, F>(0.95) * work_per_step_lower_order
            && self.target_order < self.max_order
        {
            // No error estimate exists for the higher order yet, so its step
            // size is extrapolated from the current one by the work ratio.
            let work_higher_order = cast::<_, F>(compute_work(self.target_order + 1));
            *step_size *= adjustment_factor * work_higher_order / work;
            self.target_order + 1
        } else {
            *step_size *= adjustment_factor;
            self.target_order
        };
    } else {
        *step_size *= adjustment_factor;
    }
    if let Some(max_step_size) = self.max_step_size {
        *step_size = step_size.min(max_step_size);
    }
    self.step_size = Some(*step_size);
}
}
/// Fixed-order extrapolation integrator configured by two error tolerances.
///
/// Use `into_adaptive` to obtain an integrator that selects its own step
/// size and order.
#[derive(Clone, Debug)]
pub struct Integrator<F: Float> {
    /// Absolute part of the error scale.
    abs_tol: F,
    /// Relative part of the error scale, multiplied by the solution magnitude.
    rel_tol: F,
}
impl<F: Float> Default for Integrator<F> {
    /// Creates an integrator whose absolute and relative tolerances are both `1e-6`.
    fn default() -> Self {
        let tolerance: F = cast(1e-6);
        Self {
            abs_tol: tolerance,
            rel_tol: tolerance,
        }
    }
}
impl<F: Float> Integrator<F> {
/// Wraps this integrator in an [`AdaptiveIntegrator`] with default settings.
///
/// The defaults are: no stored step size, minimum step size `1e-9`, no
/// maximum step size, target order 3, maximum order 10, and zeroed statistics.
pub fn into_adaptive(self) -> AdaptiveIntegrator<F> {
    let overall_stats = Stats {
        num_system_evals: 0,
    };
    AdaptiveIntegrator {
        overall_stats,
        max_order: 10,
        target_order: 3,
        max_step_size: None,
        min_step_size: cast(1e-9),
        step_size: None,
        integrator: self,
    }
}
/// Returns the integrator with its absolute tolerance replaced by `abs_tol`.
pub fn with_abs_tol(mut self, abs_tol: F) -> Self {
    self.abs_tol = abs_tol;
    self
}
/// Returns the integrator with its relative tolerance replaced by `rel_tol`.
pub fn with_rel_tol(mut self, rel_tol: F) -> Self {
    self.rel_tol = rel_tol;
    self
}
fn extrapolate<S: System<Float = F>>(
&self,
system: &mut SystemEvaluationCounter<S>,
step_size: F,
order: usize,
y_init: nd::ArrayView1<F>,
mut y_final: nd::ArrayViewMut1<F>,
) -> ExtrapolationResult<F> {
let f_init = {
let mut f_init = nd::Array1::zeros(y_init.raw_dim());
system.system(y_init, f_init.view_mut());
f_init
};
let mut tableau = ExtrapolationTableau(Vec::<ExtrapolationTableauRow<_>>::new());
for k in 0..=order + 1 {
let nk = compute_n(k);
let tableau_row = {
let mut Tk = Vec::with_capacity(k + 1);
Tk.push(self.midpoint_step(system, step_size, nk, &f_init, y_init));
for j in 0..k {
let denominator = <F as num_traits::Float>::powi(
cast::<_, F>(nk) / cast(compute_n(k - j - 1)),
2,
) - <F as num_traits::One>::one();
Tk.push(&Tk[j] + (&Tk[j] - &tableau.0[k - 1].0[j]) / denominator);
}
ExtrapolationTableauRow(Tk)
};
tableau.0.push(tableau_row);
}
y_final.assign(&tableau.0.last().unwrap().estimate());
return ExtrapolationResult {
scaled_truncation_errors: tableau
.compute_scaled_truncation_errors(self.abs_tol, self.rel_tol),
};
}
/// Advances `y_init` by `step_size` using the modified midpoint method with
/// `n` substeps.
///
/// `f_init` must hold the derivative at `y_init`; it is used for the first
/// substep instead of being re-evaluated. The system is evaluated `n` times.
fn midpoint_step<S: System<Float = F>>(
    &self,
    evaluation_counter: &mut SystemEvaluationCounter<S>,
    step_size: F,
    n: usize,
    f_init: &nd::Array1<F>,
    y_init: nd::ArrayView1<F>,
) -> nd::Array1<F> {
    let h = step_size / cast(n);
    let two_h = cast::<_, F>(2) * h;
    // `z_mid` holds z_m and `z_next` holds z_{m+1}; start with z_0 and z_1 = z_0 + h f(z_0).
    let mut z_mid = y_init.to_owned();
    let mut z_next = &z_mid + f_init * h;
    // Scratch buffer for derivative evaluations; its initial contents are overwritten.
    let mut slope = f_init.clone();
    for _ in 1..n {
        // After the swap `z_next` holds z_{m-1} and is advanced in place:
        // z_{m+1} = z_{m-1} + 2 h f(z_m).
        core::mem::swap(&mut z_mid, &mut z_next);
        evaluation_counter.system(z_mid.view(), slope.view_mut());
        slope *= two_h;
        z_next += &slope;
    }
    // Final smoothing: (z_{n-1} + z_n + h f(z_n)) / 2.
    evaluation_counter.system(z_next.view(), slope.view_mut());
    slope *= h;
    z_mid += &z_next;
    z_mid += &slope;
    z_mid *= cast::<_, F>(0.5);
    z_mid
}
}
/// Outcome of one extrapolation attempt.
#[derive(Debug)]
struct ExtrapolationResult<F: Float> {
    /// Scaled truncation error of every tableau row after the first.
    scaled_truncation_errors: Vec<F>,
}
impl<F: Float> ExtrapolationResult<F> {
    /// Reports whether the error of the last row is below one.
    ///
    /// Panics if no errors were recorded.
    fn converged(&self) -> bool {
        let final_error = self.scaled_truncation_errors.last().copied().unwrap();
        final_error < F::one()
    }
}
/// Wraps a [`System`] and counts how often its right-hand side is evaluated.
struct SystemEvaluationCounter<'a, S: System> {
    /// The wrapped system.
    system: &'a S,
    /// Number of calls forwarded so far.
    num_system_evals: usize,
}
impl<'a, S: System> SystemEvaluationCounter<'a, S> {
    /// Increments the counter, then forwards the call to the wrapped system.
    fn system(&mut self, y: nd::ArrayView1<S::Float>, dydt: nd::ArrayViewMut1<S::Float>) {
        self.num_system_evals += 1;
        S::system(self.system, y, dydt);
    }
}
/// Extrapolation tableau, stored row by row.
struct ExtrapolationTableau<F: Float>(Vec<ExtrapolationTableauRow<F>>);
impl<F: Float> ExtrapolationTableau<F> {
    /// Returns the scaled truncation error of every row except the first.
    ///
    /// The first row is skipped because it holds a single entry and so
    /// offers no pair of estimates to compare.
    fn compute_scaled_truncation_errors(&self, abs_tol: F, rel_tol: F) -> Vec<F> {
        let mut errors = Vec::with_capacity(self.0.len().saturating_sub(1));
        for row in self.0.iter().skip(1) {
            errors.push(row.compute_scaled_truncation_error(abs_tol, rel_tol));
        }
        errors
    }
}
/// One row of the extrapolation tableau; later entries are more extrapolated.
struct ExtrapolationTableauRow<F: Float>(Vec<nd::Array1<F>>);
impl<F: Float> ExtrapolationTableauRow<F> {
    /// Returns the root-mean-square difference between the last two entries
    /// of the row, with each component scaled by
    /// `abs_tol + rel_tol * max(|a|, |b|)`.
    ///
    /// Panics if the row holds fewer than two entries.
    fn compute_scaled_truncation_error(&self, abs_tol: F, rel_tol: F) -> F {
        let [y, y_alt] = self.0.last_chunk::<2>().unwrap();
        let squared_scaled_difference = |(&yi, &yi_alt): (&F, &F)| {
            let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
            (yi - yi_alt).powi(2) / scale.powi(2)
        };
        let sum_of_squares = y
            .iter()
            .zip(y_alt.iter())
            .map(squared_scaled_difference)
            .sum::<F>();
        let mean_square = sum_of_squares / cast::<_, F>(y.len());
        mean_square.sqrt()
    }
    /// Returns the last entry of the row.
    ///
    /// Panics if the row is empty.
    fn estimate(&self) -> &nd::Array1<F> {
        self.0.last().unwrap()
    }
}
/// Returns the number of midpoint substeps for tableau row `iteration`:
/// the even sequence 2, 4, 6, ….
fn compute_n(iteration: usize) -> usize {
    let one_based_row = iteration + 1;
    one_based_row * 2
}
/// Returns the total number of midpoint substeps for tableau rows
/// `0..=iteration`, i.e. the sum of `2 * (k + 1)` for `k` in that range.
///
/// Uses the closed form `(iteration + 1) * (iteration + 2)`, which avoids the
/// doubled intermediate product of the expanded expression.
fn compute_work(iteration: usize) -> usize {
    (iteration + 1) * (iteration + 2)
}
/// Converts a numeric value into the float type `F`.
///
/// Panics if the value cannot be represented in `F`.
fn cast<T: num_traits::NumCast, F: Float>(num: T) -> F {
    <F as num_traits::NumCast>::from(num).unwrap()
}
#[cfg(test)]
mod tests {
use super::*;
/// `compute_work` must equal the running sum of `compute_n`.
#[test]
fn test_compute_work() {
    for iteration in 0..5 {
        let summed_substeps: usize = (0..=iteration).map(compute_n).sum();
        assert_eq!(compute_work(iteration), summed_substeps);
    }
}
/// Test system with `dy/dt = y`, whose exact solution is an exponential.
struct ExpSystem {}
impl System for ExpSystem {
type Float = f64;
fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
// The derivative equals the state itself.
dydt.assign(&y);
}
}
/// Integrates `dy/dt = y` with a tight relative tolerance and pins the
/// result, the evaluation count and the final step size.
#[test]
fn test_exp_system_high_precision() {
    let t_final: f64 = 3.5;
    let y_start = ndarray::array![1.];
    let mut y_end = ndarray::Array::zeros([1]);
    let mut adaptive = Integrator::default()
        .with_abs_tol(0.)
        .with_rel_tol(1e-14)
        .into_adaptive();
    let stats = adaptive
        .step(&ExpSystem {}, t_final, y_start.view(), y_end.view_mut())
        .unwrap();
    approx::assert_relative_eq!(t_final.exp(), y_end[[0]], max_relative = 5e-13);
    assert_eq!(stats.num_system_evals, 437);
    approx::assert_relative_eq!(adaptive.step_size().unwrap(), 1.84, epsilon = 1e-2);
}
/// Same problem as the high-precision test, but with the maximum order set to 1.
#[test]
fn test_exp_system_low_max_order() {
    let t_final: f64 = 3.5;
    let y_start = ndarray::array![1.];
    let mut y_end = ndarray::Array::zeros([1]);
    let mut adaptive = Integrator::default()
        .with_abs_tol(0.)
        .with_rel_tol(1e-14)
        .into_adaptive()
        .with_max_order(1);
    adaptive
        .step(&ExpSystem {}, t_final, y_start.view(), y_end.view_mut())
        .unwrap();
    approx::assert_relative_eq!(t_final.exp(), y_end[[0]], max_relative = 5e-13);
}
/// The integrator must recover when the system returns NaN for states that
/// an oversized trial step wanders into.
#[test]
fn test_exp_system_handle_nans() {
    /// `dy/dt = -y`, except that the derivative is NaN once `|y| > 10`.
    struct ExpSystemWithNans {
        /// Set as soon as the NaN branch has been taken at least once.
        hit_a_nan: core::cell::Cell<bool>,
    }
    impl System for ExpSystemWithNans {
        type Float = f64;
        fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
            if y[0].abs() > 10. {
                self.hit_a_nan.set(true);
                dydt[0] = f64::NAN;
            } else {
                dydt.assign(&(-&y));
            }
        }
    }
    let system = ExpSystemWithNans {
        hit_a_nan: core::cell::Cell::new(false),
    };
    let mut integrator = Integrator::default()
        .with_abs_tol(0.)
        .with_rel_tol(1e-10)
        .into_adaptive();
    let t_final = 20.;
    let y = ndarray::array![1.];
    let mut y_final = ndarray::Array::zeros([1]);
    let stats = integrator
        .step(&system, t_final, y.view(), y_final.view_mut())
        .unwrap();
    approx::assert_relative_eq!((-t_final).exp(), y_final[[0]], max_relative = 1e-8);
    // The test is only meaningful if the NaN branch was actually exercised.
    assert!(system.hit_a_nan.get());
    assert_eq!(stats.num_system_evals, 1085);
}
/// Smoke test: repeatedly steps a system with sharply varying derivatives
/// and prints the chosen order and step size. It only checks that no step fails.
#[test]
fn test_varying_timescale() {
    /// Second-order oscillator with restoring force `-30 * sin(x)^31`.
    struct SharpPendulumSystem {}
    impl System for SharpPendulumSystem {
        type Float = f64;
        fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
            let (position, velocity) = (y[[0]], y[[1]]);
            dydt[[0]] = velocity;
            dydt[[1]] = -30. * position.sin().powi(31);
        }
    }
    let pendulum = SharpPendulumSystem {};
    let mut adaptive = Integrator::default().into_adaptive();
    let interval = 10.;
    let mut y = ndarray::array![1., 0.];
    let mut y_next = ndarray::Array::zeros(y.raw_dim());
    for _ in 0..100 {
        adaptive
            .step(&pendulum, interval, y.view(), y_next.view_mut())
            .unwrap();
        y.assign(&y_next);
        println!(
            "order: {} step_size: {} y: {y}",
            adaptive.target_order(),
            adaptive.step_size().unwrap()
        );
    }
}
/// The stored step size must stay within the configured limits even when the
/// final step of a call is cut short, and must not shrink on a repeat call.
#[test]
fn test_step_size_limits() {
    let mut integrator = Integrator::default().into_adaptive();
    integrator.min_step_size = 1E-3;
    integrator.max_step_size = Some(0.04);
    integrator.step_size = Some(0.02);
    // Slightly longer than one stored step, so the last step is much shorter
    // than the minimum step size.
    let t_final = 0.02 + 1E-4;
    let y_start = ndarray::array![1.];
    let mut y_end = ndarray::Array::zeros([1]);

    integrator
        .step(&ExpSystem {}, t_final, y_start.view(), y_end.view_mut())
        .unwrap();
    let step_size = integrator.step_size().unwrap();
    println!("Step size: {step_size}");
    assert!(integrator.min_step_size <= step_size);
    assert!(step_size <= integrator.max_step_size.unwrap());

    integrator
        .step(&ExpSystem {}, t_final, y_start.view(), y_end.view_mut())
        .unwrap();
    println!("Step size: {}", integrator.step_size().unwrap());
    assert!(integrator.step_size().unwrap() >= step_size);
}
}