use crate::{
error::Error,
interpolate::Interpolation,
methods::{DormandPrince, ExplicitRungeKutta, Ordinary, h_init::InitialStepSize},
ode::{ODE, OrdinaryNumericalMethod},
stats::Evals,
status::Status,
traits::{Real, State},
utils::{constrain_step_size, validate_step_size_parameters},
};
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
{
fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
where
F: ODE<T, Y> + ?Sized,
{
let mut evals = Evals::new();
if self.h0 == T::zero() {
self.h0 = InitialStepSize::<Ordinary>::compute(
ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
&mut evals,
);
}
match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
Ok(h0) => self.h = (self.filter)(h0),
Err(status) => return Err(status),
}
self.stiffness_counter = 0;
self.t = t0;
self.y = y0.clone();
self.dydt = y0.zeros_like();
self.y_prev = y0.clone();
self.dydt_prev = y0.zeros_like();
self.k = core::array::from_fn(|_| y0.zeros_like());
self.cont = core::array::from_fn(|_| y0.zeros_like());
ode.diff(self.t, &self.y, &mut self.k[0]);
self.dydt = self.k[0].clone();
evals.function += 1;
self.t_prev = self.t;
self.y_prev = self.y.clone();
self.dydt_prev = self.dydt.clone();
self.status = Status::Initialized;
Ok(evals)
}
fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
where
F: ODE<T, Y> + ?Sized,
{
let mut evals = Evals::new();
if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
self.status = Status::Error(Error::StepSize {
t: self.t,
y: self.y.clone(),
});
return Err(Error::StepSize {
t: self.t,
y: self.y.clone(),
});
}
if self.steps >= self.max_steps {
self.status = Status::Error(Error::MaxSteps {
t: self.t,
y: self.y.clone(),
});
return Err(Error::MaxSteps {
t: self.t,
y: self.y.clone(),
});
}
self.steps += 1;
let mut y_stage = self.y.zeros_like();
for i in 1..self.stages {
y_stage = self.y.clone();
for j in 0..i {
y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
}
ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
}
let ysti = y_stage.clone();
let mut yseg = self.y.zeros_like();
for i in 0..self.stages {
yseg.add_scaled(self.b[i], &self.k[i]);
}
let y_new = self.y.plus_scaled(self.h, &yseg);
let t_new = self.t + self.h;
evals.function += self.stages - 1;
let er = self.er.unwrap();
let n = self.y.len();
let mut err2 = T::zero();
let mut err_state = self.y.zeros_like();
for (j, coefficient) in er.iter().enumerate().take(self.stages) {
err_state.add_scaled(*coefficient, &self.k[j]);
}
let mut err = self
.y
.error_norm(&y_new, &err_state, &self.atol, &self.rtol);
if let Some(bh) = &self.bh {
let mut err2_state = yseg.clone();
for (j, coefficient) in bh.iter().enumerate().take(self.stages) {
err2_state.add_scaled(-*coefficient, &self.k[j]);
}
err2 = self
.y
.error_norm(&y_new, &err2_state, &self.atol, &self.rtol);
}
let mut deno = err + T::from_f64(0.01).unwrap() * err2;
if deno <= T::zero() {
deno = T::one();
}
err = self.h.abs() * err * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
let order = T::from_usize(self.order).unwrap();
let error_exponent = T::one() / order;
let mut scale = self.safety_factor * err.powf(-error_exponent);
scale = scale.max(self.min_scale).min(self.max_scale);
if err <= T::one() {
ode.diff(t_new, &y_new, &mut self.dydt);
evals.function += 1;
let n_stiff_threshold = 100;
if self.steps.is_multiple_of(n_stiff_threshold) {
let stdnum = yseg.diff_norm_squared(&self.k[S - 1]);
let stden = self.dydt.diff_norm_squared(&ysti);
if stden > T::zero() {
let h_lamb = self.h * (stdnum / stden).sqrt();
if h_lamb > T::from_f64(6.1).unwrap() {
self.non_stiffness_counter = 0;
self.stiffness_counter += 1;
if self.stiffness_counter == 15 {
self.status = Status::Error(Error::Stiffness {
t: self.t,
y: self.y.clone(),
});
return Err(Error::Stiffness {
t: self.t,
y: self.y.clone(),
});
}
}
} else {
self.non_stiffness_counter += 1;
if self.non_stiffness_counter == 6 {
self.stiffness_counter = 0;
}
}
}
self.cont[0] = self.y.clone();
let ydiff = y_new.minus(&self.y);
self.cont[1] = ydiff.clone();
let mut bspl = ydiff.zeros_like();
bspl.add_scaled(self.h, &self.k[0]);
bspl.add_scaled(-T::one(), &ydiff);
self.cont[2] = bspl.clone();
let mut cont3 = ydiff;
cont3.add_scaled(-self.h, &self.dydt);
cont3.add_scaled(-T::one(), &bspl);
self.cont[3] = cont3;
if let Some(bi) = &self.bi {
if I > S {
self.k[self.stages] = self.dydt.clone();
for i in S + 1..I {
let mut y_stage = self.y.clone();
for j in 0..i {
y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
}
ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
evals.function += 1;
}
}
for i in 4..self.order {
self.cont[i].fill(T::zero());
for j in 0..self.dense_stages {
self.cont[i].add_scaled(bi[i][j], &self.k[j]);
}
self.cont[i].scale_by(self.h);
}
}
self.t_prev = self.t;
self.y_prev = self.y.clone();
self.dydt_prev = self.k[0].clone();
self.h_prev = self.h;
self.t = t_new;
self.y = y_new;
self.k[0] = self.dydt.clone();
if let Status::RejectedStep = self.status {
self.status = Status::Solving;
scale = scale.min(T::one());
}
} else {
self.status = Status::RejectedStep;
}
self.h *= scale;
self.h = constrain_step_size(self.h, self.h_min, self.h_max);
self.h = (self.filter)(self.h);
Ok(evals)
}
fn t(&self) -> T {
self.t
}
fn y(&self) -> &Y {
&self.y
}
fn t_prev(&self) -> T {
self.t_prev
}
fn y_prev(&self) -> &Y {
&self.y_prev
}
fn h(&self) -> T {
self.h
}
fn set_h(&mut self, h: T) {
self.h = (self.filter)(h);
}
fn status(&self) -> &Status<T, Y> {
&self.status
}
fn set_status(&mut self, status: Status<T, Y>) {
self.status = status;
}
}
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
    for ExplicitRungeKutta<Ordinary, DormandPrince, T, Y, O, S, I>
{
    /// Evaluates the continuous extension of the last accepted step at `t_interp`.
    ///
    /// Let `s = (t_interp - t_prev) / h_prev` and `s1 = 1 - s`. The result is
    /// `cont[0] + s * (cont[1] + s1 * (cont[2] + s * (cont[3] + s1 * (cont[4] + ...))))`.
    /// The multiplier after `cont[i]` is `s1` for odd `i` and `s` for even `i`
    /// (Hairer's CONTD5 / CONTD8 layout).
    ///
    /// # Errors
    /// Returns `Error::OutOfBounds` when `t_interp` lies outside the last step
    /// interval. Both forward (`t_prev <= t`) and backward (`t_prev > t`)
    /// orientation are accepted.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        let (lo, hi) = if self.t_prev <= self.t {
            (self.t_prev, self.t)
        } else {
            (self.t, self.t_prev)
        };
        if t_interp < lo || t_interp > hi {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_prev,
                t_curr: self.t,
            });
        }
        // Zero-length interval (no step taken yet, e.g. right after `init`): the
        // only admissible point is the current one. Avoid dividing by h_prev.
        if self.t_prev == self.t {
            return Ok(self.y.clone());
        }

        let s = (t_interp - self.t_prev) / self.h_prev;
        let s1 = T::one() - s;
        let ilast = self.cont.len() - 1;
        // Horner-style evaluation from the innermost coefficient outwards.
        let poly = (1..ilast)
            .rev()
            .fold(self.cont[ilast].clone(), |mut acc, i| {
                acc.scale_by(if i % 2 == 1 { s1 } else { s });
                acc.add_scaled(T::one(), &self.cont[i]);
                acc
            });
        Ok(self.cont[0].plus_scaled(s, &poly))
    }
}