use std::collections::VecDeque;
use crate::{
dde::{DDE, DelayNumericalMethod},
error::Error,
interpolate::{Interpolation, cubic_hermite_interpolate},
methods::{Delay, DormandPrince, ExplicitRungeKutta, h_init::InitialStepSize},
stats::Evals,
status::Status,
traits::{Real, State},
utils::{constrain_step_size, validate_step_size_parameters},
};
impl<
const L: usize,
T: Real,
Y: State<T>,
H: Fn(T) -> Y,
const O: usize,
const S: usize,
const I: usize,
> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
{
fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
where
F: DDE<L, T, Y> + ?Sized,
{
let mut evals = Evals::new();
if L == 0 {
return Err(Error::NoLags);
}
self.t0 = t0;
self.t = t0;
self.y = y0.clone();
self.dydt = y0.zeros_like();
self.y_prev = y0.clone();
self.dydt_prev = y0.zeros_like();
self.k = core::array::from_fn(|_| y0.zeros_like());
self.cont = core::array::from_fn(|_| y0.zeros_like());
self.t_prev = self.t;
self.y_prev = self.y.clone();
self.status = Status::Initialized;
self.steps = 0;
self.stiffness_counter = 0;
self.non_stiffness_counter = 0;
self.history = VecDeque::new();
let mut delays = [T::zero(); L];
let mut y_delayed = core::array::from_fn(|_| y0.zeros_like());
dde.lags(self.t, &self.y, &mut delays);
for i in 0..L {
let t_delayed = self.t - delays[i];
if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
return Err(Error::BadInput {
msg: format!("Delayed time {} is beyond initial time {}", t_delayed, t0),
});
}
y_delayed[i] = phi(t_delayed);
}
dde.diff(self.t, &self.y, &y_delayed, &mut self.k[0]);
self.dydt = self.k[0].clone();
evals.function += 1;
self.dydt_prev = self.dydt.clone();
self.history
.push_back((self.t, self.y.clone(), self.dydt.clone()));
if self.h0 == T::zero() {
self.h0 = InitialStepSize::<Delay>::compute(
dde, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max, phi,
&self.k[0], &mut evals,
);
}
match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
Ok(h0) => self.h = (self.filter)(h0),
Err(status) => return Err(status),
}
Ok(evals)
}
fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
where
F: DDE<L, T, Y> + ?Sized,
{
let mut evals = Evals::new();
if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
self.status = Status::Error(Error::StepSize {
t: self.t,
y: self.y.clone(),
});
return Err(Error::StepSize {
t: self.t,
y: self.y.clone(),
});
}
if self.steps >= self.max_steps {
self.status = Status::Error(Error::MaxSteps {
t: self.t,
y: self.y.clone(),
});
return Err(Error::MaxSteps {
t: self.t,
y: self.y.clone(),
});
}
self.steps += 1;
let mut delays = [T::zero(); L];
let mut y_delayed = core::array::from_fn(|_| self.y.zeros_like());
let mut min_delay_abs = T::infinity();
let y_pred_for_lags = self.y.plus_scaled(self.h, &self.k[0]);
dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
for i in 0..L {
min_delay_abs = min_delay_abs.min(delays[i].abs());
}
let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
5
} else {
1
};
let mut y_next_est = self.y.clone();
let mut y_next_est_prev = self.y.clone();
let mut dde_iter_failed = false;
let mut err_norm: T = T::zero();
let mut y_last_stage = self.y.zeros_like();
for it in 0..max_iter {
if it > 0 {
y_next_est_prev = y_next_est.clone();
}
let mut y_stage = self.y.zeros_like();
for i in 1..self.stages {
y_stage = self.y.clone();
for j in 0..i {
y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
}
dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
if let Err(e) =
self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
{
self.status = Status::Error(e.clone());
return Err(e);
}
dde.diff(
self.t + self.c[i] * self.h,
&y_stage,
&y_delayed,
&mut self.k[i],
);
}
evals.function += self.stages - 1;
y_last_stage = y_stage.clone();
let mut yseg = self.y.zeros_like();
for i in 0..self.stages {
yseg.add_scaled(self.b[i], &self.k[i]);
}
let y_new = self.y.plus_scaled(self.h, &yseg);
let er = self.er.unwrap();
let n = self.y.len();
let mut err2 = T::zero();
let mut err_state = self.y.zeros_like();
for (j, coefficient) in er.iter().enumerate().take(self.stages) {
err_state.add_scaled(*coefficient, &self.k[j]);
}
let err_val = self
.y
.error_norm(&y_new, &err_state, &self.atol, &self.rtol);
if let Some(bh) = &self.bh {
let mut err2_state = yseg.clone();
for (j, coefficient) in bh.iter().enumerate().take(self.stages) {
err2_state.add_scaled(-*coefficient, &self.k[j]);
}
err2 = self
.y
.error_norm(&y_new, &err2_state, &self.atol, &self.rtol);
}
let mut deno = err_val + T::from_f64(0.01).unwrap() * err2;
if deno <= T::zero() {
deno = T::one();
}
err_norm =
self.h.abs() * err_val * (T::one() / (deno * T::from_usize(n).unwrap())).sqrt();
if max_iter > 1 && it > 0 {
let n_dim = self.y.len();
let iter_diff = y_new.minus(&y_next_est_prev);
let mut dde_iteration_error =
y_next_est_prev.error_norm(&y_new, &iter_diff, &self.atol, &self.rtol);
if n_dim > 0 {
dde_iteration_error =
(dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
}
if dde_iteration_error <= self.rtol.average() * T::from_f64(0.1).unwrap() {
break;
}
if it == max_iter - 1 {
dde_iter_failed =
dde_iteration_error > self.rtol.average() * T::from_f64(0.1).unwrap();
}
}
y_next_est = y_new.clone();
}
if dde_iter_failed {
let sign = self.h.signum();
self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
if L > 0
&& min_delay_abs > T::zero()
&& self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
{
self.h = min_delay_abs * sign;
}
self.h = constrain_step_size(self.h, self.h_min, self.h_max);
self.h = (self.filter)(self.h);
self.status = Status::RejectedStep;
return Ok(evals);
}
let order = T::from_usize(self.order).unwrap();
let error_exponent = T::one() / order;
let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
scale = scale.max(self.min_scale).min(self.max_scale);
if err_norm <= T::one() {
let y_new = y_next_est.clone();
let t_new = self.t + self.h;
dde.lags(t_new, &y_new, &mut delays);
if let Err(e) = self.lagvals(t_new, &delays, &mut y_delayed, phi) {
self.status = Status::Error(e.clone());
return Err(e);
}
dde.diff(t_new, &y_new, &y_delayed, &mut self.dydt);
evals.function += 1;
let n_stiff_threshold = 100;
if self.steps.is_multiple_of(n_stiff_threshold) {
let mut yseg = self.y.zeros_like();
for i in 0..self.stages {
yseg.add_scaled(self.b[i], &self.k[i]);
}
let stdnum = yseg.diff_norm_squared(&self.k[S - 1]);
let stden = self.dydt.diff_norm_squared(&y_last_stage);
if stden > T::zero() {
let h_lamb = self.h * (stdnum / stden).sqrt();
if h_lamb > T::from_f64(6.1).unwrap() {
self.non_stiffness_counter = 0;
self.stiffness_counter += 1;
if self.stiffness_counter == 15 {
self.status = Status::Error(Error::Stiffness {
t: self.t,
y: self.y.clone(),
});
return Err(Error::Stiffness {
t: self.t,
y: self.y.clone(),
});
}
}
} else {
self.non_stiffness_counter += 1;
if self.non_stiffness_counter == 6 {
self.stiffness_counter = 0;
}
}
}
self.cont[0] = self.y.clone();
let ydiff = y_new.minus(&self.y);
self.cont[1] = ydiff.clone();
let mut bspl = ydiff.zeros_like();
bspl.add_scaled(self.h, &self.k[0]);
bspl.add_scaled(-T::one(), &ydiff);
self.cont[2] = bspl.clone();
let mut cont3 = ydiff;
cont3.add_scaled(-self.h, &self.dydt);
cont3.add_scaled(-T::one(), &bspl);
self.cont[3] = cont3;
if self.bi.is_some() {
if I > S {
self.k[self.stages] = self.dydt.clone();
for i in S + 1..I {
let mut y_stage = self.y.clone();
for j in 0..i {
y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
}
let t_stage = self.t + self.c[i] * self.h;
dde.lags(t_stage, &y_stage, &mut delays);
if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
self.status = Status::Error(e.clone());
return Err(e);
}
dde.diff(t_stage, &y_stage, &y_delayed, &mut self.k[i]);
evals.function += 1;
}
}
for i in 4..self.order {
self.cont[i].fill(T::zero());
for j in 0..self.dense_stages {
let bi = self.bi.as_ref().expect("dense output coefficients checked");
self.cont[i].add_scaled(bi[i][j], &self.k[j]);
}
self.cont[i].scale_by(self.h);
}
}
self.t_prev = self.t;
self.y_prev = self.y.clone();
self.dydt_prev = self.k[0].clone();
self.h_prev = self.h;
self.t = t_new;
self.y = y_new;
self.k[0] = self.dydt.clone();
self.history
.push_back((self.t, self.y.clone(), self.dydt.clone()));
if let Some(max_delay) = self.max_delay {
let cutoff_time = self.t - max_delay;
while let Some((t_front, _, _)) = self.history.get(1) {
if *t_front < cutoff_time {
self.history.pop_front();
} else {
break;
}
}
}
if let Status::RejectedStep = self.status {
self.status = Status::Solving;
scale = scale.min(T::one());
}
} else {
self.status = Status::RejectedStep;
}
self.h *= scale;
self.h = constrain_step_size(self.h, self.h_min, self.h_max);
self.h = (self.filter)(self.h);
Ok(evals)
}
/// Returns the current value of the independent variable, `self.t`.
fn t(&self) -> T {
self.t
}
/// Returns a reference to the current state, `self.y`.
fn y(&self) -> &Y {
&self.y
}
/// Returns the stored previous value of the independent variable, `self.t_prev`.
fn t_prev(&self) -> T {
self.t_prev
}
/// Returns a reference to the stored previous state, `self.y_prev`.
fn y_prev(&self) -> &Y {
&self.y_prev
}
/// Returns the step size currently stored in `self.h`.
fn h(&self) -> T {
self.h
}
/// Stores a new step size after passing `h` through the solver's `filter` callable.
fn set_h(&mut self, h: T) {
self.h = (self.filter)(h);
}
/// Returns a reference to the solver's current status.
fn status(&self) -> &Status<T, Y> {
&self.status
}
/// Replaces the solver's status with `status`.
fn set_status(&mut self, status: Status<T, Y>) {
self.status = status;
}
}
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
    ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
{
    /// Fills `yd[i]` with the state at the delayed time `t_stage - lags[i]`.
    ///
    /// The source of each delayed state depends on where the delayed time falls,
    /// measured in the direction given by the sign of `self.h`:
    ///
    /// 1. at or before `t0` (within `T::default_epsilon()`): the history function `phi`;
    /// 2. past `t_prev`: the polynomial stored in `self.cont` when `self.bi` is set,
    ///    otherwise a cubic Hermite interpolant between `t_prev` and `t`;
    /// 3. otherwise: a cubic Hermite interpolant over the first pair of consecutive
    ///    history entries whose times bracket the delayed time.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InsufficientHistory`] when case 3 finds no bracketing pair.
    /// Entries of `yd` written before the failing lag keep their new values.
    fn lagvals<const L: usize, H>(
        &mut self,
        t_stage: T,
        lags: &[T; L],
        yd: &mut [Y; L],
        phi: &H,
    ) -> Result<(), Error<T, Y>>
    where
        H: Fn(T) -> Y,
    {
        let direction = self.h.signum();
        let tol = T::default_epsilon();
        let forward = direction > T::zero();

        for (slot, &lag) in yd.iter_mut().zip(lags.iter()) {
            let t_delayed = t_stage - lag;

            // Case 1: still inside the initial history.
            if (t_delayed - self.t0) * direction <= tol {
                *slot = phi(t_delayed);
                continue;
            }

            // Case 2: past the start of the most recently stored step.
            if (t_delayed - self.t_prev) * direction > tol {
                *slot = if self.bi.is_some() {
                    let theta = (t_delayed - self.t_prev) / self.h_prev;
                    let one_minus_theta = T::one() - theta;
                    let last = self.cont.len() - 1;

                    // Nested (Horner-like) evaluation from the highest coefficient down.
                    // NOTE(review): the factor parity is taken from `last - idx` for
                    // idx >= 4 and from `idx` below that; the two rules coincide only
                    // when `last` is even — confirm against the coefficient layout.
                    let mut acc = self.cont[last].clone();
                    for idx in (1..last).rev() {
                        let odd = if idx >= 4 {
                            (last - idx) % 2 == 1
                        } else {
                            idx % 2 == 1
                        };
                        acc.scale_by(if odd { one_minus_theta } else { theta });
                        acc.add_scaled(T::one(), &self.cont[idx]);
                    }
                    self.cont[0].plus_scaled(theta, &acc)
                } else {
                    cubic_hermite_interpolate(
                        self.t_prev,
                        self.t,
                        &self.y_prev,
                        &self.y,
                        &self.dydt_prev,
                        &self.dydt,
                        t_delayed,
                    )
                };
                continue;
            }

            // Case 3: look the interval up in the stored history.
            let bracket = self
                .history
                .iter()
                .zip(self.history.iter().skip(1))
                .find(|(left, right)| {
                    let (lo, hi) = if forward {
                        (left.0, right.0)
                    } else {
                        (right.0, left.0)
                    };
                    lo <= t_delayed && t_delayed <= hi
                });

            match bracket {
                Some((left, right)) => {
                    *slot = cubic_hermite_interpolate(
                        left.0, right.0, &left.1, &right.1, &left.2, &right.2, t_delayed,
                    );
                }
                None => {
                    return Err(Error::InsufficientHistory {
                        t_delayed,
                        t_prev: self.t_prev,
                        t_curr: self.t,
                    });
                }
            }
        }

        Ok(())
    }
}
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
    for ExplicitRungeKutta<Delay, DormandPrince, T, Y, O, S, I>
{
    /// Evaluates the stored dense-output polynomial at `t_interp`.
    ///
    /// `t_interp` must lie between `t_prev` and `t` (inclusive), with the direction
    /// taken from the sign of `t - t_prev`. The polynomial is evaluated in the
    /// normalized variable `theta = (t_interp - t_prev) / h_prev` from the
    /// coefficients in `self.cont`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OutOfBounds`] when `t_interp` is outside `[t_prev, t]`.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        let direction = (self.t - self.t_prev).signum();
        let before_start = (t_interp - self.t_prev) * direction < T::zero();
        let past_end = (t_interp - self.t) * direction > T::zero();
        if before_start || past_end {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_prev,
                t_curr: self.t,
            });
        }

        let theta = (t_interp - self.t_prev) / self.h_prev;
        let one_minus_theta = T::one() - theta;
        let last = self.cont.len() - 1;

        // Nested (Horner-like) evaluation from the highest coefficient down.
        // NOTE(review): the factor parity is taken from `last - idx` for idx >= 4 and
        // from `idx` below that; the two rules coincide only when `last` is even —
        // confirm against the coefficient layout.
        let mut acc = self.cont[last].clone();
        for idx in (1..last).rev() {
            let odd = if idx >= 4 {
                (last - idx) % 2 == 1
            } else {
                idx % 2 == 1
            };
            acc.scale_by(if odd { one_minus_theta } else { theta });
            acc.add_scaled(T::one(), &self.cont[idx]);
        }

        Ok(self.cont[0].plus_scaled(theta, &acc))
    }
}