#![expect(
non_snake_case,
reason = "Used for math symbols to match notation in Numerical Recipes"
)]
pub use nd::ArrayView1;
pub use nd::ArrayViewMut1;
use ndarray as nd;
use num_traits::cast;
/// Scalar type the integrator can compute with.
///
/// Combines `num_traits::Float` with summation (used when reducing the
/// error norm) and in-place addition (used when updating state arrays).
/// Implemented for `f32` and `f64`.
pub trait Float
where
    Self: num_traits::Float + core::iter::Sum + core::ops::AddAssign,
{
}

impl Float for f32 {}

impl Float for f64 {}
/// A first-order ODE system `dy/dt = f(y)` with no explicit time argument.
pub trait System {
    /// Scalar type of the state vector.
    type Float: Float;

    /// Evaluates the derivative at state `y`, writing it into `dydt`.
    ///
    /// The integrator may pass `dydt` buffers that still contain earlier
    /// derivative values, so implementations should write every element
    /// rather than accumulate into it.
    fn system(&self, y: ArrayView1<Self::Float>, dydt: ArrayViewMut1<Self::Float>);
}
/// Diagnostics produced by one call to `Integrator::step`.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Stats<F: Float> {
    /// Total number of derivative evaluations, including the evaluation at
    /// the start of the step.
    pub num_system_evals: usize,
    /// Iteration count reported by the integrator (see `Integrator::step`
    /// for the exact convention on success and failure).
    pub num_iterations: usize,
    /// Number of substeps in the last modified-midpoint sequence.
    pub num_midpoint_substeps: usize,
    /// Step size divided by `num_midpoint_substeps`.
    pub midpoint_substep_size: F,
    /// Scaled RMS difference between the two most extrapolated estimates;
    /// a value `<= 1` means the tolerance was met.
    pub scaled_truncation_error: F,
}
#[must_use]
#[derive(Debug)]
pub struct FailedToConverge<F: Float> {
pub stats: Stats<F>,
}
pub struct Integrator<S: System> {
abs_tol: S::Float,
rel_tol: S::Float,
max_iterations: usize,
}
impl<S: System> Default for Integrator<S>
where
    S::Float: Float + nd::ScalarOperand,
{
    /// Uses `1e-5` for both the absolute and relative tolerance and allows
    /// at most 20 extrapolation rows per step.
    fn default() -> Self {
        let tolerance: S::Float = cast(1e-5).unwrap();
        Self {
            max_iterations: 20,
            rel_tol: tolerance,
            abs_tol: tolerance,
        }
    }
}
impl<S: System> Integrator<S>
where
    S::Float: Float + nd::ScalarOperand,
{
    /// Returns this configuration with the absolute tolerance replaced.
    #[must_use]
    pub fn with_abs_tol(self, abs_tol: S::Float) -> Self {
        Self { abs_tol, ..self }
    }

    /// Returns this configuration with the relative tolerance replaced.
    #[must_use]
    pub fn with_rel_tol(self, rel_tol: S::Float) -> Self {
        Self { rel_tol, ..self }
    }

    /// Returns this configuration with the maximum number of extrapolation
    /// rows per step replaced.
    #[must_use]
    pub fn with_max_iterations(self, max_iterations: usize) -> Self {
        Self {
            max_iterations,
            ..self
        }
    }

    /// Advances `y_init` by a step of size `delta_t`, writing the result
    /// into `y_final`.
    ///
    /// Row `k` of the extrapolation tableau starts from a modified-midpoint
    /// solution with `2 * (k + 1)` substeps and is refined with Neville's
    /// recurrence against row `k - 1`. Once a row has at least two entries,
    /// the scaled RMS difference of its last two entries is checked. When
    /// that difference is `<= 1`, the most extrapolated entry is accepted.
    ///
    /// On success, `num_iterations` in the returned stats is the index `k`
    /// of the accepted row.
    ///
    /// # Errors
    ///
    /// Returns [`FailedToConverge`] if no row meets the tolerance within
    /// `max_iterations` rows. In that case `y_final` is left untouched and
    /// `num_iterations` is `max_iterations`. The substep count and size
    /// describe the last midpoint sequence actually computed.
    ///
    /// - If fewer than two rows were built (`max_iterations < 2`), no error
    ///   estimate exists and `scaled_truncation_error` is infinite.
    /// - If no row was built (`max_iterations == 0`), `num_midpoint_substeps`
    ///   is 0 and `midpoint_substep_size` is NaN.
    ///
    /// NOTE(review): `num_iterations` counts from 0 on success but reports
    /// the number of rows built on failure; confirm the convention callers
    /// expect.
    pub fn step(
        &self,
        system: &S,
        delta_t: S::Float,
        y_init: nd::ArrayView1<S::Float>,
        mut y_final: nd::ArrayViewMut1<S::Float>,
    ) -> Result<Stats<S::Float>, FailedToConverge<S::Float>> {
        let one = <S::Float as num_traits::One>::one();
        let mut evaluation_counter = EvaluationCounter {
            system,
            num_system_evals: 0,
        };
        // Every midpoint sequence starts from the same derivative, so it is
        // evaluated once and shared.
        let f_init = {
            let mut f_init = nd::Array1::zeros(y_init.raw_dim());
            evaluation_counter.system(y_init, f_init.view_mut());
            f_init
        };
        let compute_n = |k: usize| -> usize { 2 * (k + 1) };
        // Neville's recurrence only reads the previous row, so only that row
        // is kept instead of the whole tableau.
        let mut T_prev = Vec::<nd::Array1<S::Float>>::new();
        // Stays infinite until some row has two entries to compare.
        let mut scaled_truncation_error =
            <S::Float as num_traits::Float>::infinity();
        for k in 0..self.max_iterations {
            let n = compute_n(k);
            let mut Tk = Vec::with_capacity(k + 1);
            Tk.push(self.midpoint_step(
                &mut evaluation_counter,
                delta_t,
                n,
                &f_init,
                y_init,
            ));
            for j in 0..k {
                let ratio = cast::<_, S::Float>(n).unwrap()
                    / cast(compute_n(k - j - 1)).unwrap();
                let denominator =
                    <S::Float as num_traits::Float>::powi(ratio, 2) - one;
                let refined = &Tk[j] + (&Tk[j] - &T_prev[j]) / denominator;
                Tk.push(refined);
            }
            if let [.., penultimate, last] = Tk.as_slice() {
                scaled_truncation_error = compute_scaled_truncation_error(
                    penultimate.view(),
                    last.view(),
                    self.abs_tol,
                    self.rel_tol,
                );
                if scaled_truncation_error <= one {
                    y_final.assign(last);
                    return Ok(Stats {
                        num_system_evals: evaluation_counter.num_system_evals,
                        num_iterations: k,
                        num_midpoint_substeps: n,
                        midpoint_substep_size: delta_t
                            / cast::<_, S::Float>(n).unwrap(),
                        scaled_truncation_error,
                    });
                }
            }
            T_prev = Tk;
        }
        // The last midpoint sequence computed belongs to row
        // `max_iterations - 1`, not `max_iterations`.
        let n = self.max_iterations.checked_sub(1).map_or(0, compute_n);
        let midpoint_substep_size = if n == 0 {
            <S::Float as num_traits::Float>::nan()
        } else {
            delta_t / cast::<_, S::Float>(n).unwrap()
        };
        Err(FailedToConverge {
            stats: Stats {
                num_system_evals: evaluation_counter.num_system_evals,
                num_iterations: self.max_iterations,
                num_midpoint_substeps: n,
                midpoint_substep_size,
                scaled_truncation_error,
            },
        })
    }

    /// Runs the modified midpoint method over `delta_t` with `n` equal
    /// substeps, starting from `y_init`, whose derivative is `f_init`.
    ///
    /// Applies `z_{m+1} = z_{m-1} + 2h f(z_m)` and finishes with
    /// `(z_{n-1} + z_n + h f(z_n)) / 2`, using `n` system evaluations.
    fn midpoint_step(
        &self,
        evaluation_counter: &mut EvaluationCounter<S>,
        delta_t: S::Float,
        n: usize,
        f_init: &nd::Array1<S::Float>,
        y_init: nd::ArrayView1<S::Float>,
    ) -> nd::Array1<S::Float> {
        let substep_size = delta_t / cast(n).unwrap();
        // Loop-invariant factor of the leapfrog update.
        let two_substeps = cast::<_, S::Float>(2.).unwrap() * substep_size;
        // `zi` holds z_m and `zip1` holds z_{m+1}.
        let mut zi = y_init.to_owned();
        let mut zip1 = &zi + f_init * substep_size;
        // Derivative scratch buffer, refilled by every evaluation below.
        let mut fi = f_init.clone();
        for _i in 1..n {
            std::mem::swap(&mut zi, &mut zip1);
            evaluation_counter.system(zi.view(), fi.view_mut());
            // In place: zip1 (currently z_{m-1}) += 2h * f(z_m).
            zip1.scaled_add(two_substeps, &fi);
        }
        evaluation_counter.system(zip1.view(), fi.view_mut());
        (&zi + &zip1 + fi * substep_size) * cast::<_, S::Float>(0.5).unwrap()
    }
}
/// Scaled root-mean-square difference between two estimates of one state
/// vector.
///
/// Each component difference is divided by
/// `abs_tol + rel_tol * max(|y_i|, |y_alt_i|)`. The result is the RMS of
/// those ratios, so a value `<= 1` means the estimates agree within
/// tolerance.
///
/// Components whose estimates are identical contribute zero, even when their
/// scale is zero (e.g. `abs_tol == 0` and both values are `0`). An empty
/// state vector yields zero. Without these rules both cases would produce
/// NaN and never pass the `<= 1` test.
fn compute_scaled_truncation_error<F: Float + core::iter::Sum>(
    y: nd::ArrayView1<F>,
    y_alt: nd::ArrayView1<F>,
    abs_tol: F,
    rel_tol: F,
) -> F {
    debug_assert_eq!(
        y.len(),
        y_alt.len(),
        "estimates must have the same length"
    );
    if y.is_empty() {
        return F::zero();
    }
    let sum_of_squares = y
        .iter()
        .zip(y_alt.iter())
        .map(|(&yi, &yi_alt)| {
            let difference = yi - yi_alt;
            if difference == F::zero() {
                return F::zero();
            }
            let scale = abs_tol + rel_tol * yi_alt.abs().max(yi.abs());
            // Divide before squaring so a tiny scale cannot underflow to 0
            // (or a large one overflow) before the division.
            let ratio = difference / scale;
            ratio * ratio
        })
        .sum::<F>();
    (sum_of_squares / cast(y.len()).unwrap()).sqrt()
}
/// Wraps a [`System`] and counts how many times its derivative is evaluated.
struct EvaluationCounter<'a, S: System> {
    /// The wrapped system.
    system: &'a S,
    /// Number of evaluations made through this counter so far.
    num_system_evals: usize,
}

impl<S: System> EvaluationCounter<'_, S> {
    /// Evaluates the wrapped system at `y` into `dydt` and increments the
    /// evaluation count.
    fn system(
        &mut self,
        y: nd::ArrayView1<S::Float>,
        dydt: nd::ArrayViewMut1<S::Float>,
    ) {
        self.num_system_evals += 1;
        // Fully qualified because this method shares the name `system`.
        <S as System>::system(self.system, y, dydt);
    }
}