#![expect(
non_snake_case,
reason = "Used for math symbols to match notation in Numerical Recipes"
)]
pub use nd::ArrayView1;
pub use nd::ArrayViewMut1;
use ndarray as nd;
use num_traits::cast;
/// Scalar type usable for the state vector and the step size of a [`System`].
///
/// Bundles the floating-point arithmetic, summation, in-place `+=` / `*=`,
/// debug printing and `ndarray` scalar broadcasting that the integrator uses.
/// Implemented here for `f32` and `f64`.
pub trait Float:
    num_traits::Float
    + nd::ScalarOperand
    + core::fmt::Debug
    + core::iter::Sum
    + core::ops::AddAssign
    + core::ops::MulAssign
{
}
impl Float for f32 {}
impl Float for f64 {}
/// A first-order ODE system `dy/dt = f(y)` to be advanced by an [`Integrator`].
///
/// The right-hand side receives only the state `y`; there is no time argument.
pub trait System {
    /// Scalar type of the state vector and of the step size.
    type Float: Float;
    /// Evaluates `f(y)` and stores the result in `dydt`.
    ///
    /// The integrator reuses its derivative buffers between calls, so `dydt` may
    /// hold stale values on entry. Implementations should overwrite every element
    /// rather than accumulate into it.
    fn system(&self, y: ArrayView1<'_, Self::Float>, dydt: ArrayViewMut1<'_, Self::Float>);
}
#[must_use]
#[derive(Debug)]
pub struct Stats<F: Float> {
pub num_system_evals: usize,
pub num_iterations: usize,
pub num_substeps: usize,
pub substep_size: F,
pub scaled_truncation_error: F,
}
#[must_use]
#[derive(Debug)]
pub struct FailedToConverge<F: Float> {
pub stats: Stats<F>,
}
/// Rule for the number of midpoint substeps `n_k` used in iteration `k` (0-based).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum StepSizePolicy {
    /// `n_k = 2 (k + 1)`: 2, 4, 6, 8, ...
    #[default]
    Linear,
    /// `n_k = 2^(k + 1)`: 2, 4, 8, 16, ...
    Exponential,
}
/// Chooses which two entries of the current extrapolation-tableau row are compared
/// to estimate the truncation error. The second entry of the pair is the value the
/// integrator returns.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ConvergencePolicy {
    /// Compare the last two (highest-order) entries of the row.
    #[default]
    AllIterations,
    /// Compare entries `num_iterations` and `num_iterations + 1` of the row,
    /// capping the extrapolation order. While the row is still too short, fall
    /// back to its last two entries.
    ///
    /// Entry `j` of row `k` is built from the midpoint estimates of rows
    /// `k - j ..= k`. This cap therefore bounds how far back, towards coarser
    /// substeps, the accepted value reaches.
    Window {
        /// Index of the first entry of the compared pair.
        num_iterations: core::num::NonZero<usize>,
    },
}
impl ConvergencePolicy {
    /// Returns the `[lower, accepted]` pair of entries of tableau row `Tk` whose
    /// difference is used as the truncation-error estimate.
    ///
    /// # Panics
    ///
    /// Panics if `Tk` has fewer than two entries.
    fn get_extrap_pair<'a, F: Float>(&self, Tk: &'a [nd::Array1<F>]) -> &'a [nd::Array1<F>; 2] {
        let windowed = match self {
            Self::AllIterations => None,
            // `get(w..)` + `first_chunk` avoids the `w + 2` overflow of an explicit range.
            Self::Window { num_iterations } => Tk
                .get(num_iterations.get()..)
                .and_then(|rest| rest.first_chunk::<2>()),
        };
        windowed
            .or_else(|| Tk.last_chunk::<2>())
            .expect("an extrapolation row needs at least two entries to estimate the error")
    }
}
/// Configuration of a Richardson-extrapolated modified-midpoint integrator for
/// systems of type `S`.
///
/// Start from [`Default`], adjust it with the `with_*` methods, then call
/// [`Integrator::step`].
pub struct Integrator<S: System> {
    /// Maximum number of iterations (midpoint sequences / tableau rows) per step.
    max_iterations: usize,
    /// Absolute term of the per-component error scale.
    abs_tol: S::Float,
    /// Relative term of the per-component error scale.
    rel_tol: S::Float,
    /// How the midpoint substep count grows from one iteration to the next.
    step_size_policy: StepSizePolicy,
    /// Which extrapolated estimates are compared to test convergence.
    convergence_policy: ConvergencePolicy,
}
impl<S: System> Default for Integrator<S>
where
S::Float: Float,
{
fn default() -> Self {
Self {
abs_tol: cast(1e-5).unwrap(),
rel_tol: cast(1e-5).unwrap(),
step_size_policy: StepSizePolicy::default(),
convergence_policy: ConvergencePolicy::default(),
max_iterations: 20,
}
}
}
impl<S: System> Integrator<S>
where
S::Float: Float,
{
pub fn with_abs_tol(self, abs_tol: S::Float) -> Self {
Self { abs_tol, ..self }
}
pub fn with_rel_tol(self, rel_tol: S::Float) -> Self {
Self { rel_tol, ..self }
}
pub fn with_step_size_policy(self, step_size_policy: StepSizePolicy) -> Self {
Self {
step_size_policy,
..self
}
}
pub fn with_convergence_policy(self, convergence_policy: ConvergencePolicy) -> Self {
Self {
convergence_policy,
..self
}
}
pub fn with_max_iterations(self, max_iterations: usize) -> Self {
Self {
max_iterations,
..self
}
}
pub fn step(
&self,
system: &S,
delta_t: S::Float,
y_init: nd::ArrayView1<S::Float>,
mut y_final: nd::ArrayViewMut1<S::Float>,
) -> Result<Stats<S::Float>, FailedToConverge<S::Float>> {
let mut evaluation_counter = SystemEvaluationCounter {
system,
num_system_evals: 0,
};
let f_init = {
let mut f_init = nd::Array1::zeros(y_init.raw_dim());
evaluation_counter.system(y_init, f_init.view_mut());
f_init
};
let compute_n = match self.step_size_policy {
StepSizePolicy::Linear => |k: usize| -> usize { 2 * (k + 1) },
StepSizePolicy::Exponential => {
|k: usize| -> usize { 2u32.pow((k + 1) as u32) as usize }
}
};
let mut T = Vec::<Vec<nd::Array1<S::Float>>>::new();
for k in 0..self.max_iterations {
let nk = compute_n(k);
let mut Tk = Vec::with_capacity(k + 1);
Tk.push(self.midpoint_step(&mut evaluation_counter, delta_t, nk, &f_init, y_init));
for j in 0..k {
let denominator = <S::Float as num_traits::Float>::powi(
cast::<_, S::Float>(nk).unwrap() / cast(compute_n(k - j - 1)).unwrap(),
2,
) - <S::Float as num_traits::One>::one();
Tk.push(&Tk[j] + (&Tk[j] - &T[k - 1][j]) / denominator);
}
if k > 0 {
let extrap_pair = self.convergence_policy.get_extrap_pair(&Tk);
let scaled_truncation_error = compute_scaled_truncation_error(
extrap_pair[0].view(),
extrap_pair[1].view(),
self.abs_tol,
self.rel_tol,
);
if scaled_truncation_error <= <S::Float as num_traits::One>::one() {
y_final.assign(&extrap_pair[1]);
return Ok(Stats {
num_system_evals: evaluation_counter.num_system_evals,
num_iterations: k,
num_substeps: nk,
substep_size: delta_t / cast::<_, S::Float>(nk).unwrap(),
scaled_truncation_error,
});
}
}
T.push(Tk);
}
let last_Tk = T.last().unwrap();
let extrap_pair = self.convergence_policy.get_extrap_pair(last_Tk);
y_final.assign(&extrap_pair[1]);
let scaled_truncation_error = compute_scaled_truncation_error(
extrap_pair[0].view(),
extrap_pair[1].view(),
self.abs_tol,
self.rel_tol,
);
let n = compute_n(self.max_iterations);
Err(FailedToConverge {
stats: Stats {
num_system_evals: evaluation_counter.num_system_evals,
num_iterations: self.max_iterations,
num_substeps: n,
substep_size: delta_t / cast(n).unwrap(),
scaled_truncation_error,
},
})
}
fn midpoint_step(
&self,
evaluation_counter: &mut SystemEvaluationCounter<S>,
delta_t: S::Float,
n: usize,
f_init: &nd::Array1<S::Float>,
y_init: nd::ArrayView1<S::Float>,
) -> nd::Array1<S::Float> {
let substep_size = delta_t / cast(n).unwrap();
let two_substep_size = cast::<_, S::Float>(2).unwrap() * substep_size;
let mut zi = y_init.to_owned();
let mut zip1 = &zi + f_init * substep_size;
let mut fi = f_init.clone();
for _i in 1..n {
core::mem::swap(&mut zi, &mut zip1);
evaluation_counter.system(zi.view(), fi.view_mut());
fi *= two_substep_size;
zip1 += &fi;
}
evaluation_counter.system(zip1.view(), fi.view_mut());
fi *= substep_size;
let mut result = zi;
result += &zip1;
result += &fi;
result *= cast::<_, S::Float>(0.5).unwrap();
result
}
}
/// Root-mean-square of the component-wise differences between two estimates,
/// each divided by its own error scale:
///
/// `sqrt(mean_i(((y_i - y_alt_i) / scale_i)^2))`, where
/// `scale_i = abs_tol + rel_tol * max(|y_i|, |y_alt_i|)`.
///
/// A value `<= 1` means the estimates agree within tolerance.
///
/// Components whose two estimates are exactly equal contribute zero, even when
/// `scale_i == 0` (e.g. `abs_tol == 0` and both values are 0). Otherwise they
/// would produce `0 / 0 = NaN`, and the step could never converge. An empty
/// state yields zero.
fn compute_scaled_truncation_error<F: Float>(
    y: nd::ArrayView1<F>,
    y_alt: nd::ArrayView1<F>,
    abs_tol: F,
    rel_tol: F,
) -> F {
    debug_assert_eq!(y.len(), y_alt.len(), "estimates must have the same length");
    if y.is_empty() {
        return F::zero();
    }
    let sum_of_squares = y
        .iter()
        .zip(y_alt.iter())
        .map(|(&yi, &yi_alt)| {
            let diff = yi - yi_alt;
            if diff == F::zero() {
                return F::zero();
            }
            let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
            diff.powi(2) / scale.powi(2)
        })
        .sum::<F>();
    (sum_of_squares / cast(y.len()).unwrap()).sqrt()
}
/// Wraps a [`System`] reference and counts how often its right-hand side is evaluated.
struct SystemEvaluationCounter<'a, S: System> {
    /// The wrapped system.
    system: &'a S,
    /// Number of evaluations forwarded so far.
    num_system_evals: usize,
}
impl<S: System> SystemEvaluationCounter<'_, S> {
    /// Bumps the evaluation count, then forwards to [`System::system`].
    fn system(&mut self, y: nd::ArrayView1<S::Float>, dydt: nd::ArrayViewMut1<S::Float>) {
        self.num_system_evals += 1;
        let inner: &S = self.system;
        inner.system(y, dydt);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn exp_system_high_precision() {
        // dy/dt = y, whose exact solution from y(0) = 1 is exp(t).
        struct ExpSystem;
        impl System for ExpSystem {
            type Float = f64;
            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                dydt.assign(&y);
            }
        }
        let system = ExpSystem;
        let integrator = Integrator::<ExpSystem>::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-14);
        let t_final = 0.2;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);
        let stats = integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();
        approx::assert_relative_eq!(t_final.exp(), y_final[[0]], max_relative = 1e-14);
        assert_eq!(stats.num_system_evals, 43);
        assert_eq!(stats.num_iterations, 5);
        assert_eq!(stats.num_substeps, 12);
        approx::assert_relative_eq!(stats.substep_size, t_final / 12.);
        assert!(stats.scaled_truncation_error < 1.);
    }
    #[test]
    fn exp_system_handle_nans() {
        // dy/dt = -y, but the derivative is NaN whenever |y| > 10. That happens for
        // the coarse midpoint sequences, which must then be kept out of the result.
        struct ExpSystem {
            hit_a_nan: core::cell::Cell<bool>,
        }
        impl System for ExpSystem {
            type Float = f64;
            fn system(&self, y: ArrayView1<Self::Float>, mut dydt: ArrayViewMut1<Self::Float>) {
                if y[0].abs() > 10. {
                    self.hit_a_nan.set(true);
                    dydt[0] = f64::NAN;
                } else {
                    dydt.assign(&(-&y));
                }
            }
        }
        let system = ExpSystem {
            hit_a_nan: core::cell::Cell::new(false),
        };
        let integrator = Integrator::<ExpSystem>::default()
            .with_abs_tol(0.)
            .with_rel_tol(1e-10)
            .with_step_size_policy(StepSizePolicy::Exponential)
            .with_convergence_policy(ConvergencePolicy::Window {
                num_iterations: core::num::NonZero::new(3).unwrap(),
            });
        let t_final = 5.;
        let y = ndarray::array![1.];
        let mut y_final = ndarray::Array::zeros([1]);
        let _stats = integrator
            .step(&system, t_final, y.view(), y_final.view_mut())
            .unwrap();
        approx::assert_relative_eq!((-t_final).exp(), y_final[[0]], max_relative = 1e-8);
        assert!(system.hit_a_nan.get());
    }
}