#![expect(
non_snake_case,
reason = "Used for math symbols to match notation in Numerical Recipes"
)]
pub use nd::ArrayView1;
pub use nd::ArrayViewMut1;
use ndarray as nd;
use num_traits::cast;
/// Floating-point scalar type the integrators can work with.
///
/// Bundles the arithmetic, summation, debug-formatting and `ndarray` scalar-operand
/// capabilities the integrators rely on. Implemented for [`f32`] and [`f64`].
pub trait Float
where
    Self: num_traits::Float
        + nd::ScalarOperand
        + core::ops::AddAssign
        + core::ops::MulAssign
        + core::iter::Sum
        + core::fmt::Debug,
{
}

impl Float for f64 {}

impl Float for f32 {}
pub trait System {
type Float: Float;
fn system(&self, y: ArrayView1<Self::Float>, dydt: ArrayViewMut1<Self::Float>);
}
#[derive(Debug)]
pub struct StepSizeUnderflow<F: Float>(F);
/// Work statistics reported by [`AdaptiveStepSizeIntegrator`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Stats {
    /// Number of evaluations of [`System::system`].
    pub num_system_evals: usize,
}
/// Extrapolation integrator that adapts its step size to meet the configured tolerances.
///
/// Obtained from [`Integrator::into_adaptive`].
pub struct AdaptiveStepSizeIntegrator<F: Float> {
    /// Fixed-step extrapolation integrator used for every individual step.
    integrator: Integrator<F>,
    /// Step size the next call to `step` starts from; `None` until a step has been completed.
    step_size: Option<F>,
    /// Step sizes below this value (or non-finite ones) abort with [`StepSizeUnderflow`].
    minimum_step_size: F,
    /// Index of the truncation error estimate that drives the step-size controller.
    target_num_iterations: usize,
    /// Statistics accumulated across calls to `step`.
    overall_stats: Stats,
}
impl<F: Float> AdaptiveStepSizeIntegrator<F> {
    /// Integrates `system` over a time interval of length `delta_t`, starting from `y_init`,
    /// and writes the final state into `y_final`.
    ///
    /// The interval is covered by consecutive extrapolation steps.
    ///
    /// - The first step uses the step size left by the previous successful call, or `delta_t`
    ///   if there is none.
    /// - After every attempt, accepted or rejected, the step size is rescaled from that
    ///   attempt's truncation error estimate. Rejected attempts are retried with the rescaled
    ///   size.
    /// - The last step is shortened so that the interval ends exactly at `delta_t`.
    ///
    /// A `delta_t` of exactly zero copies `y_init` into `y_final` without evaluating the system.
    ///
    /// Returns the statistics of this call. They are also added to [`Self::overall_stats`].
    ///
    /// # Errors
    ///
    /// Returns [`StepSizeUnderflow`] with the offending step size if the step size drops below
    /// the minimum step size or becomes non-finite. In that case `y_final` is left untouched,
    /// and the system evaluations already spent are still added to the overall statistics.
    ///
    /// # Panics
    ///
    /// Panics if an attempted step yields no truncation error estimate at index
    /// `target_num_iterations`, i.e. if `target_num_iterations + 1 >= max_iterations`.
    pub fn step<S: System<Float = F>>(
        &mut self,
        system: &S,
        delta_t: S::Float,
        y_init: nd::ArrayView1<S::Float>,
        mut y_final: nd::ArrayViewMut1<S::Float>,
    ) -> Result<Stats, StepSizeUnderflow<F>> {
        if delta_t == F::zero() {
            // Nothing to integrate. Without this shortcut a fresh integrator would start with a
            // step size of zero and immediately report an underflow.
            y_final.assign(&y_init);
            return Ok(Stats {
                num_system_evals: 0,
            });
        }
        let mut step_size = self.step_size.unwrap_or(delta_t);
        let mut system = SystemEvaluationCounter {
            system,
            num_system_evals: 0,
        };
        let mut y_before_step = y_init.to_owned();
        let mut y_after_step = y_init.to_owned();
        let mut t = F::zero();
        loop {
            if step_size < self.minimum_step_size || !step_size.is_finite() {
                // The evaluations of this failed call were still performed; account for them.
                self.overall_stats.num_system_evals += system.num_system_evals;
                return Err(StepSizeUnderflow(step_size));
            }
            // `None` marks the final step, which is clipped to end exactly at `delta_t`.
            let next_t = if t < delta_t - step_size {
                Some((t + step_size).min(delta_t))
            } else {
                None
            };
            let step_result = self.integrator.step(
                &mut system,
                step_size.min(delta_t - t),
                y_before_step.view(),
                y_after_step.view_mut(),
            );
            // Accepted and rejected attempts both carry error estimates used to rescale the step.
            let (Ok(extrapolation_stats) | Err(extrapolation_stats)) = &step_result;
            let scaled_truncation_error = *extrapolation_stats
                .scaled_truncation_errors
                .get(self.target_num_iterations)
                .expect("target_num_iterations + 1 must be less than max_iterations");
            let adjustment_factor =
                self.compute_step_size_adjustment_factor(scaled_truncation_error);
            match (step_result.is_ok(), next_t) {
                (true, None) => {
                    break;
                }
                (true, Some(next_t)) => {
                    t = next_t;
                    step_size *= adjustment_factor;
                    y_before_step.assign(&y_after_step);
                }
                (false, _) => {
                    step_size *= adjustment_factor;
                }
            }
        }
        self.step_size = Some(step_size);
        y_final.assign(&y_after_step);
        self.overall_stats.num_system_evals += system.num_system_evals;
        Ok(Stats {
            num_system_evals: system.num_system_evals,
        })
    }

    /// Sets the number of extrapolation iterations the step-size controller aims for.
    ///
    /// The wrapped integrator's `min_iterations` is set to the same value. The step size is
    /// rescaled from the truncation error estimate at index `target_num_iterations`.
    ///
    /// # Panics
    ///
    /// Panics unless `target_num_iterations < max_iterations - 1`. A step only produces error
    /// estimates for tableau rows `1..max_iterations`, so the estimate at index
    /// `target_num_iterations` (row `target_num_iterations + 1`) must exist.
    pub fn with_target_num_iterations(self, target_num_iterations: usize) -> Self {
        let max_iterations = self.integrator.max_iterations;
        // `saturating_sub` avoids the overflow that `target_num_iterations + 1` could hit.
        assert!(
            target_num_iterations < max_iterations.saturating_sub(1),
            "target_num_iterations ({target_num_iterations}) must be less than \
             max_iterations - 1 (max_iterations = {max_iterations})"
        );
        Self {
            target_num_iterations,
            integrator: Integrator {
                min_iterations: target_num_iterations,
                ..self.integrator
            },
            ..self
        }
    }

    /// Sets the step size below which integration aborts with [`StepSizeUnderflow`].
    pub fn with_minimum_step_size(self, minimum_step_size: F) -> Self {
        Self {
            minimum_step_size,
            ..self
        }
    }

    /// Statistics accumulated over all calls to [`Self::step`], including calls that failed
    /// with [`StepSizeUnderflow`].
    pub fn overall_stats(&self) -> &Stats {
        &self.overall_stats
    }

    /// The step size the next call to [`Self::step`] starts from. `None` until a call with a
    /// non-zero `delta_t` has succeeded.
    pub fn step_size(&self) -> Option<F> {
        self.step_size
    }

    /// Factor by which the step size is multiplied, given the scaled truncation error at the
    /// target iteration.
    ///
    /// - A positive error gives `0.95 / err^(1 / (2 * target + 1))`, clamped to `[0.01, 100]`.
    /// - A finite non-positive (i.e. zero) error keeps the step size.
    /// - Anything else (NaN) halves it.
    fn compute_step_size_adjustment_factor(&self, scaled_truncation_error: F) -> F {
        let safety_factor: F = cast(0.95).unwrap();
        let min_step_size_decrease_factor: F = cast(0.01).unwrap();
        let max_step_size_increase_factor = min_step_size_decrease_factor.recip();
        if scaled_truncation_error > F::zero() {
            // NOTE(review): this estimate comes from tableau row `target_num_iterations + 1`.
            // Numerical Recipes' controller uses the exponent 1 / (2k + 1) with k being that row
            // index. Confirm the intended exponent; changing it changes step-size behaviour.
            (safety_factor
                / scaled_truncation_error
                    .powf(F::one() / cast(2 * self.target_num_iterations + 1).unwrap()))
            .max(min_step_size_decrease_factor)
            .min(max_step_size_increase_factor)
        } else if scaled_truncation_error.is_finite() {
            F::one()
        } else {
            // The error is NaN, e.g. because the system produced NaNs: back off.
            cast(0.5).unwrap()
        }
    }
}
/// Fixed-step extrapolation integrator: modified midpoint sequences combined by Richardson
/// extrapolation (the Bulirsch–Stoer scheme as described in Numerical Recipes).
///
/// Configure it with the `with_*` builder methods. Turn it into an adaptive integrator with
/// [`Integrator::into_adaptive`].
#[derive(Debug, Clone)]
pub struct Integrator<F: Float> {
    /// Absolute error tolerance used to scale the truncation error estimate.
    abs_tol: F,
    /// Relative error tolerance used to scale the truncation error estimate.
    rel_tol: F,
    /// A step is only accepted at a tableau row index strictly greater than this.
    min_iterations: usize,
    /// Maximum number of tableau rows (midpoint sequences) computed per step.
    max_iterations: usize,
}
impl<F: Float> Default for Integrator<F> {
fn default() -> Self {
Self {
abs_tol: cast(1e-5).unwrap(),
rel_tol: cast(1e-5).unwrap(),
min_iterations: 2,
max_iterations: 10,
}
}
}
impl<F: Float> Integrator<F> {
    /// Converts this integrator into an [`AdaptiveStepSizeIntegrator`].
    ///
    /// The adaptive integrator starts with:
    ///
    /// - a target of 3 iterations, which also becomes this integrator's `min_iterations`;
    /// - a minimum step size of `1e-6`;
    /// - no stored step size, so the first step attempts the whole interval;
    /// - zeroed statistics.
    ///
    /// NOTE(review): the target of 3 needs `max_iterations >= 5`. With fewer iterations the
    /// adaptive step finds no error estimate for the target and panics, unless the target is
    /// lowered with `with_target_num_iterations`.
    pub fn into_adaptive(self) -> AdaptiveStepSizeIntegrator<F> {
        let target_num_iterations = 3;
        AdaptiveStepSizeIntegrator {
            integrator: Self {
                min_iterations: target_num_iterations,
                ..self
            },
            step_size: None,
            minimum_step_size: cast(1e-6).unwrap(),
            target_num_iterations,
            overall_stats: Stats {
                num_system_evals: 0,
            },
        }
    }

    /// Sets the absolute error tolerance.
    pub fn with_abs_tol(self, abs_tol: F) -> Self {
        Self { abs_tol, ..self }
    }

    /// Sets the relative error tolerance.
    pub fn with_rel_tol(self, rel_tol: F) -> Self {
        Self { rel_tol, ..self }
    }

    /// Sets the maximum number of tableau rows (midpoint sequences) computed per step.
    pub fn with_max_iterations(self, max_iterations: usize) -> Self {
        Self {
            max_iterations,
            ..self
        }
    }

    /// Performs one extrapolation step of size `step_size` starting from `y_init`.
    ///
    /// Row `k` of the tableau starts with the modified-midpoint estimate using `2 * (k + 1)`
    /// substeps, followed by `k` Richardson extrapolations. The step is accepted at the first
    /// row `k > min_iterations` whose scaled truncation error is at most one. That row's
    /// most-extrapolated entry is then written to `y_final`. If no row among `max_iterations`
    /// rows is accepted, `y_final` is left untouched.
    ///
    /// Both outcomes carry the per-row error estimates (entry `i` belongs to row `i + 1`):
    /// `Ok` on acceptance, `Err` otherwise.
    fn step<S: System<Float = F>>(
        &self,
        system: &mut SystemEvaluationCounter<S>,
        step_size: F,
        y_init: nd::ArrayView1<F>,
        mut y_final: nd::ArrayViewMut1<F>,
    ) -> Result<ExtrapolationStats<F>, ExtrapolationStats<F>> {
        // Every midpoint sequence starts from the same derivative, so evaluate it only once.
        let f_init = {
            let mut f_init = nd::Array1::zeros(y_init.raw_dim());
            system.system(y_init, f_init.view_mut());
            f_init
        };
        // Number of midpoint substeps used for tableau row `k`: 2, 4, 6, ...
        let compute_n = |k: usize| -> usize { 2 * (k + 1) };
        let mut tableau = ExtrapolationTableau(Vec::<ExtrapolationTableauRow<_>>::with_capacity(
            self.max_iterations,
        ));
        for k in 0..self.max_iterations {
            let nk = compute_n(k);
            // Invariant across the extrapolation loop below.
            let nk_float: F = cast(nk).unwrap();
            let tableau_row = {
                let mut Tk = Vec::with_capacity(k + 1);
                Tk.push(self.midpoint_step(system, step_size, nk, &f_init, y_init));
                // T[k][j + 1] = T[k][j] + (T[k][j] - T[k - 1][j]) / ((n_k / n_{k - j - 1})^2 - 1)
                for j in 0..k {
                    let denominator = <F as num_traits::Float>::powi(
                        nk_float / cast(compute_n(k - j - 1)).unwrap(),
                        2,
                    ) - <F as num_traits::One>::one();
                    Tk.push(&Tk[j] + (&Tk[j] - &tableau.0[k - 1].0[j]) / denominator);
                }
                ExtrapolationTableauRow(Tk)
            };
            tableau.0.push(tableau_row);
            // Rows up to `min_iterations` are never accepted, so their error is not needed here.
            // `k > min_iterations` also implies `k > 0`, i.e. the row has two entries to compare.
            if k > self.min_iterations {
                let newest_row = tableau.0.last().expect("a row was pushed just above");
                let scaled_truncation_error =
                    newest_row.compute_scaled_truncation_error(self.abs_tol, self.rel_tol);
                if scaled_truncation_error <= <F as num_traits::One>::one() {
                    y_final.assign(newest_row.estimate());
                    return Ok(ExtrapolationStats {
                        scaled_truncation_errors: tableau
                            .compute_scaled_truncation_errors(self.abs_tol, self.rel_tol),
                    });
                }
            }
        }
        Err(ExtrapolationStats {
            scaled_truncation_errors: tableau
                .compute_scaled_truncation_errors(self.abs_tol, self.rel_tol),
        })
    }

    /// Modified midpoint method: advances `y_init` by `step_size` using `n` substeps.
    ///
    /// `f_init` must be the derivative at `y_init`. Performs `n` system evaluations: `n - 1` in
    /// the leapfrog loop and one for the final averaging step.
    fn midpoint_step<S: System<Float = F>>(
        &self,
        evaluation_counter: &mut SystemEvaluationCounter<S>,
        step_size: F,
        n: usize,
        f_init: &nd::Array1<F>,
        y_init: nd::ArrayView1<F>,
    ) -> nd::Array1<F> {
        let substep_size = step_size / cast(n).unwrap();
        let two_substep_size = cast::<_, F>(2).unwrap() * substep_size;
        // z_0 = y, z_1 = z_0 + h f(z_0)
        let mut zi = y_init.to_owned();
        let mut zip1 = &zi + f_init * substep_size;
        let mut fi = f_init.clone();
        // z_{i+1} = z_{i-1} + 2h f(z_i), computed in place in the buffer that held z_{i-1}.
        for _i in 1..n {
            core::mem::swap(&mut zi, &mut zip1);
            evaluation_counter.system(zi.view(), fi.view_mut());
            fi *= two_substep_size;
            zip1 += &fi;
        }
        // Final averaging: (z_{n-1} + z_n + h f(z_n)) / 2
        evaluation_counter.system(zip1.view(), fi.view_mut());
        fi *= substep_size;
        let mut result = zi;
        result += &zip1;
        result += &fi;
        result *= cast::<_, F>(0.5).unwrap();
        result
    }
}
/// Error estimates produced by one extrapolation step.
#[derive(Debug)]
struct ExtrapolationStats<F: Float> {
    /// Scaled truncation error of each tableau row after the first; entry `i` belongs to
    /// row `i + 1`.
    scaled_truncation_errors: Vec<F>,
}
/// Borrowed [`System`] wrapper that counts how often the system is evaluated.
struct SystemEvaluationCounter<'a, S: System> {
    /// The wrapped system.
    system: &'a S,
    /// Number of evaluations performed through this wrapper so far.
    num_system_evals: usize,
}
impl<S: System> SystemEvaluationCounter<'_, S> {
    /// Evaluates the wrapped system at `y`, writing the derivative into `dydt`, and increments
    /// the evaluation count.
    fn system(&mut self, y: nd::ArrayView1<S::Float>, dydt: nd::ArrayViewMut1<S::Float>) {
        self.num_system_evals += 1;
        // `self.system` is already `&S`; no extra borrow needed.
        S::system(self.system, y, dydt);
    }
}
/// Richardson extrapolation tableau of a single step; row `k` holds the midpoint estimate with
/// `2 * (k + 1)` substeps followed by its `k` extrapolations.
struct ExtrapolationTableau<F: Float>(Vec<ExtrapolationTableauRow<F>>);
impl<F: Float> ExtrapolationTableau<F> {
    /// Scaled truncation error of every row except the first, which has no error estimate.
    ///
    /// Entry `i` of the result belongs to row `i + 1`; an empty or single-row tableau yields an
    /// empty vector.
    fn compute_scaled_truncation_errors(&self, abs_tol: F, rel_tol: F) -> Vec<F> {
        let mut errors = Vec::with_capacity(self.0.len().saturating_sub(1));
        for row in self.0.iter().skip(1) {
            errors.push(row.compute_scaled_truncation_error(abs_tol, rel_tol));
        }
        errors
    }
}
/// One tableau row: a midpoint estimate followed by successively higher-order extrapolations.
struct ExtrapolationTableauRow<F: Float>(Vec<nd::Array1<F>>);
impl<F: Float> ExtrapolationTableauRow<F> {
    /// Root-mean-square difference between the row's two most-extrapolated entries.
    ///
    /// Each component is scaled by `abs_tol + rel_tol * max(|y_i|, |y_alt_i|)`, so a result of at
    /// most one means the difference is within tolerance.
    ///
    /// Components whose two estimates agree exactly contribute zero, even when their scale is
    /// zero (e.g. `abs_tol == 0` and a component that is exactly zero). Without this, `0 / 0`
    /// would turn the whole estimate into NaN and every step would be rejected. An empty state
    /// vector yields zero for the same reason.
    ///
    /// # Panics
    ///
    /// Panics if the row has fewer than two entries.
    fn compute_scaled_truncation_error(&self, abs_tol: F, rel_tol: F) -> F {
        let [y, y_alt] = self
            .0
            .last_chunk::<2>()
            .expect("a truncation error estimate needs at least two entries in the row");
        if y.is_empty() {
            return F::zero();
        }
        let sum_of_squares = y
            .iter()
            .zip(y_alt.iter())
            .map(|(&yi, &yi_alt)| {
                let difference = yi - yi_alt;
                if difference == F::zero() {
                    // Exact agreement is within any tolerance; avoids 0 / 0 when the scale is 0.
                    return F::zero();
                }
                let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
                difference.powi(2) / scale.powi(2)
            })
            .sum::<F>();
        (sum_of_squares / cast(y.len()).unwrap()).sqrt()
    }

    /// The row's most-extrapolated entry, used as the result of an accepted step.
    ///
    /// # Panics
    ///
    /// Panics if the row is empty.
    fn estimate(&self) -> &nd::Array1<F> {
        self.0.last().expect("tableau rows are never empty")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Integrates `y' = y` from `y(0) = 1` to `t = 3.5` with a tight relative tolerance and
    /// pins the accuracy, the work done and the final step size.
    #[test]
    fn exp_system_high_precision() {
        struct ExpSystem;
        impl System for ExpSystem {
            type Float = f64;
            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                dydt.assign(&y);
            }
        }
        let system = ExpSystem;
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14)
            .into_adaptive()
            .with_target_num_iterations(4);
        let t_final = 3.5;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);
        let stats = integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();
        approx::assert_relative_eq!(t_final.exp(), y_final[[0]], max_relative = 5e-13);
        assert_eq!(stats.num_system_evals, 541);
        approx::assert_relative_eq!(integrator.step_size().unwrap(), 0.385, epsilon = 1e-3);
    }

    /// Integrates `y' = -y` with a right-hand side that returns NaN whenever `|y| > 10`. The
    /// first step over the whole interval diverges into that region; the integrator must
    /// back off and still reach an accurate result.
    #[test]
    fn exp_system_handle_nans() {
        struct ExpSystem {
            hit_a_nan: core::cell::Cell<bool>,
        }
        impl System for ExpSystem {
            type Float = f64;
            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                if y[0].abs() > 10. {
                    self.hit_a_nan.set(true);
                    dydt[0] = f64::NAN;
                } else {
                    dydt.assign(&(-&y));
                }
            }
        }
        let system = ExpSystem {
            hit_a_nan: core::cell::Cell::new(false),
        };
        let mut integrator = Integrator::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-10)
            .into_adaptive()
            .with_target_num_iterations(4);
        let t_final = 20.;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);
        let stats = integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();
        approx::assert_relative_eq!((-t_final).exp(), y_final[[0]], max_relative = 1e-8);
        assert!(system.hit_a_nan.get());
        assert_eq!(stats.num_system_evals, 1404);
    }
}