use std::collections::VecDeque;
use crate::{
dde::{DDE, DelayNumericalMethod},
error::Error,
interpolate::{Interpolation, cubic_hermite_interpolate},
methods::{Delay, ExplicitRungeKutta, Fixed},
stats::Evals,
status::Status,
traits::{Real, State},
utils::validate_step_size_parameters,
};
impl<
const L: usize,
T: Real,
Y: State<T>,
H: Fn(T) -> Y,
const O: usize,
const S: usize,
const I: usize,
> DelayNumericalMethod<L, T, Y, H> for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
{
fn init<F>(&mut self, dde: &F, t0: T, tf: T, y0: &Y, phi: &H) -> Result<Evals, Error<T, Y>>
where
F: DDE<L, T, Y>,
{
let mut evals = Evals::new();
if L <= 0 {
return Err(Error::NoLags);
}
self.t0 = t0;
self.t = t0;
self.y = *y0;
self.t_prev = self.t;
self.y_prev = self.y;
self.status = Status::Initialized;
self.steps = 0;
self.history = VecDeque::new();
let mut delays = [T::zero(); L];
let mut y_delayed = [Y::zeros(); L];
dde.lags(self.t, &self.y, &mut delays);
for i in 0..L {
let t_delayed = self.t - delays[i];
if (t_delayed - t0) * (tf - t0).signum() > T::default_epsilon() {
return Err(Error::BadInput {
msg: format!(
"Initial delayed time {} is out of history range (t <= {}).",
t_delayed, t0
),
});
}
y_delayed[i] = phi(t_delayed);
}
dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
evals.function += 1;
self.dydt_prev = self.dydt; self.history.push_back((self.t, self.y, self.dydt));
if self.h0 == T::zero() {
let duration = (tf - t0).abs();
let default_steps = T::from_usize(100).unwrap();
self.h0 = duration / default_steps;
}
match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
Ok(h0) => self.h = h0,
Err(status) => return Err(status),
}
Ok(evals)
}
fn step<F>(&mut self, dde: &F, phi: &H) -> Result<Evals, Error<T, Y>>
where
F: DDE<L, T, Y>,
{
let mut evals = Evals::new();
if self.steps >= self.max_steps {
self.status = Status::Error(Error::MaxSteps {
t: self.t,
y: self.y,
});
return Err(Error::MaxSteps {
t: self.t,
y: self.y,
});
}
self.steps += 1;
let mut delays = [T::zero(); L];
let mut y_delayed = [Y::zeros(); L];
self.k[0] = self.dydt;
let mut min_delay_abs = T::infinity();
let y_pred_for_lags = self.y + self.k[0] * self.h;
dde.lags(self.t + self.h, &y_pred_for_lags, &mut delays);
for i in 0..L {
min_delay_abs = min_delay_abs.min(delays[i].abs());
}
let max_iter: usize = if min_delay_abs < self.h.abs() && min_delay_abs > T::zero() {
5
} else {
1
};
let mut y_next_candidate_iter = self.y; let mut dydt_next_candidate_iter = Y::zeros(); let mut y_prev_candidate_iter = self.y; let mut dde_iteration_failed = false;
for iter_idx in 0..max_iter {
if iter_idx > 0 {
y_prev_candidate_iter = y_next_candidate_iter;
}
for i in 1..self.stages {
let mut y_stage = self.y;
for j in 0..i {
y_stage += self.k[j] * (self.a[i][j] * self.h);
}
dde.lags(self.t + self.c[i] * self.h, &y_stage, &mut delays);
if let Err(e) =
self.lagvals(self.t + self.c[i] * self.h, &delays, &mut y_delayed, phi)
{
self.status = Status::Error(e.clone());
return Err(e);
}
dde.diff(
self.t + self.c[i] * self.h,
&y_stage,
&y_delayed,
&mut self.k[i],
);
}
evals.function += self.stages - 1;
let mut y_next = self.y;
for i in 0..self.stages {
y_next += self.k[i] * (self.b[i] * self.h);
}
if max_iter > 1 && iter_idx > 0 {
let mut dde_iteration_error = T::zero();
let n_dim = self.y.len();
for i_dim in 0..n_dim {
let scale = T::from_f64(1e-10).unwrap()
+ y_prev_candidate_iter
.get(i_dim)
.abs()
.max(y_next.get(i_dim).abs());
if scale > T::zero() {
let diff_val = y_next.get(i_dim) - y_prev_candidate_iter.get(i_dim);
dde_iteration_error += (diff_val / scale).powi(2);
}
}
if n_dim > 0 {
dde_iteration_error =
(dde_iteration_error / T::from_usize(n_dim).unwrap()).sqrt();
}
if dde_iteration_error <= T::from_f64(1e-6).unwrap() {
break;
}
if iter_idx == max_iter - 1 {
dde_iteration_failed = dde_iteration_error > T::from_f64(1e-6).unwrap();
}
}
y_next_candidate_iter = y_next;
dde.lags(self.t + self.h, &y_next_candidate_iter, &mut delays);
if let Err(e) = self.lagvals(self.t + self.h, &delays, &mut y_delayed, phi) {
self.status = Status::Error(e.clone());
return Err(e);
}
dde.diff(
self.t + self.h,
&y_next_candidate_iter,
&y_delayed,
&mut dydt_next_candidate_iter,
);
evals.function += 1;
}
if dde_iteration_failed {
let sign = self.h.signum();
self.h = (self.h.abs() * T::from_f64(0.5).unwrap()).max(self.h_min.abs()) * sign;
if L > 0
&& min_delay_abs > T::zero()
&& self.h.abs() < T::from_f64(2.0).unwrap() * min_delay_abs
{
self.h = min_delay_abs * sign;
}
self.status = Status::RejectedStep;
return Ok(evals);
}
self.t_prev = self.t;
self.y_prev = self.y;
self.dydt_prev = self.dydt;
self.t += self.h;
self.y = y_next_candidate_iter;
if self.fsal {
self.dydt = self.k[S - 1];
} else {
dde.lags(self.t, &self.y, &mut delays);
if let Err(e) = self.lagvals(self.t, &delays, &mut y_delayed, phi) {
self.status = Status::Error(e.clone());
return Err(e);
}
dde.diff(self.t, &self.y, &y_delayed, &mut self.dydt);
evals.function += 1;
}
if self.bi.is_some() {
for i in 0..(I - S) {
let mut y_stage_dense = self.y_prev;
for j in 0..self.stages + i {
y_stage_dense += self.k[j] * (self.a[self.stages + i][j] * self.h);
}
let t_stage = self.t_prev + self.c[self.stages + i] * self.h;
dde.lags(t_stage, &y_stage_dense, &mut delays);
if let Err(e) = self.lagvals(t_stage, &delays, &mut y_delayed, phi) {
self.status = Status::Error(e.clone());
return Err(e);
}
dde.diff(
self.t_prev + self.c[self.stages + i] * self.h,
&y_stage_dense,
&y_delayed,
&mut self.k[self.stages + i],
);
}
evals.function += I - S;
}
self.history.push_back((self.t, self.y, self.dydt));
if let Some(max_delay) = self.max_delay {
let cutoff_time = self.t - max_delay;
while let Some((t_front, _, _)) = self.history.get(1) {
if *t_front < cutoff_time {
self.history.pop_front();
} else {
break;
}
}
}
self.status = Status::Solving;
Ok(evals)
}
fn t(&self) -> T {
self.t
}
fn y(&self) -> &Y {
&self.y
}
fn t_prev(&self) -> T {
self.t_prev
}
fn y_prev(&self) -> &Y {
&self.y_prev
}
fn h(&self) -> T {
self.h
}
fn set_h(&mut self, h: T) {
self.h = h;
}
fn status(&self) -> &Status<T, Y> {
&self.status
}
fn set_status(&mut self, status: Status<T, Y>) {
self.status = status;
}
}
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
    ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
{
    /// Fills `y_delayed[k]` with the solution at the delayed time `t_stage - delays[k]`.
    ///
    /// Comparisons are made along the integration direction (`sign(h)`) with a
    /// `T::default_epsilon()` tolerance. Each delayed time is resolved from:
    ///
    /// 1. the history function `phi`, when it lies at or before `t0`;
    /// 2. the interpolant of the last accepted step, when it lies after `t_prev`.
    ///    This is the dense-output polynomial if `bi` is set, otherwise cubic
    ///    Hermite interpolation between `(t_prev, y_prev, dydt_prev)` and
    ///    `(t, y, dydt)`. Times beyond `t` are extrapolated by the same formula;
    /// 3. otherwise, cubic Hermite interpolation on the first pair of consecutive
    ///    history entries that brackets it.
    ///
    /// # Errors
    /// Returns `Error::InsufficientHistory` when case 3 finds no bracketing pair.
    pub fn lagvals<const L: usize, H>(
        &self,
        t_stage: T,
        delays: &[T; L],
        y_delayed: &mut [Y; L],
        phi: &H,
    ) -> Result<(), Error<T, Y>>
    where
        H: Fn(T) -> Y,
    {
        let dir = self.h.signum();

        for (lag, &delay) in delays.iter().enumerate() {
            let t_delayed = t_stage - delay;

            if (t_delayed - self.t0) * dir <= T::default_epsilon() {
                // At or before the initial time: the prescribed history applies.
                y_delayed[lag] = phi(t_delayed);
            } else if (t_delayed - self.t_prev) * dir > T::default_epsilon() {
                y_delayed[lag] = match self.bi.as_ref() {
                    Some(bi_coeffs) => {
                        // NOTE(review): `self.k` is overwritten stage by stage inside
                        // `step`, so during a step these weights may combine stages of
                        // the current and previous steps. Confirm this is acceptable.
                        // NOTE(review): the polynomial degree here is `dense_stages`,
                        // whereas `Interpolation::interpolate` uses `order`. Confirm
                        // which one matches the layout of `bi`.
                        let s = (t_delayed - self.t_prev) / self.h_prev;
                        let mut cont = [T::zero(); I];
                        for stage in 0..usize::min(I, bi_coeffs.len()) {
                            let row = &bi_coeffs[stage];
                            // Horner evaluation of the stage weight polynomial in `s`.
                            let mut acc = row[self.dense_stages - 1];
                            for power in (0..self.dense_stages - 1).rev() {
                                acc = acc * s + row[power];
                            }
                            cont[stage] = acc * s;
                        }
                        let mut y_interp = self.y_prev;
                        for stage in 0..usize::min(I, self.k.len()) {
                            y_interp += self.k[stage] * (cont[stage] * self.h_prev);
                        }
                        y_interp
                    }
                    None => cubic_hermite_interpolate(
                        self.t_prev,
                        self.t,
                        &self.y_prev,
                        &self.y,
                        &self.dydt_prev,
                        &self.dydt,
                        t_delayed,
                    ),
                };
            } else {
                // Older than the last step: search the stored history.
                let forward = dir > T::zero();
                let bracket = self
                    .history
                    .iter()
                    .zip(self.history.iter().skip(1))
                    .find(|&(left, right)| {
                        let (t_left, t_right) = (left.0, right.0);
                        if forward {
                            t_left <= t_delayed && t_delayed <= t_right
                        } else {
                            t_right <= t_delayed && t_delayed <= t_left
                        }
                    });
                match bracket {
                    Some(((t_left, y_left, dydt_left), (t_right, y_right, dydt_right))) => {
                        y_delayed[lag] = cubic_hermite_interpolate(
                            *t_left, *t_right, y_left, y_right, dydt_left, dydt_right,
                            t_delayed,
                        );
                    }
                    None => {
                        return Err(Error::InsufficientHistory {
                            t_delayed,
                            t_prev: self.t_prev,
                            t_curr: self.t,
                        });
                    }
                }
            }
        }
        Ok(())
    }
}
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
    for ExplicitRungeKutta<Delay, Fixed, T, Y, O, S, I>
{
    /// Evaluates the solution at `t_interp` inside the last accepted step `[t_prev, t]`.
    ///
    /// With dense-output coefficients (`bi`) the stage derivatives in `self.k` are
    /// combined with Horner-evaluated weight polynomials in
    /// `s = (t_interp - t_prev) / h_prev`. Without them, cubic Hermite
    /// interpolation between the two step endpoints is used.
    ///
    /// # Errors
    /// Returns `Error::OutOfBounds` when `t_interp` lies outside the last step,
    /// judged along the integration direction `sign(h)`.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        let direction = self.h.signum();
        let before_start = (t_interp - self.t_prev) * direction < T::zero();
        let after_end = (t_interp - self.t) * direction > T::zero();
        if before_start || after_end {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_prev,
                t_curr: self.t,
            });
        }

        match self.bi.as_ref() {
            None => Ok(cubic_hermite_interpolate(
                self.t_prev,
                self.t,
                &self.y_prev,
                &self.y,
                &self.dydt_prev,
                &self.dydt,
                t_interp,
            )),
            Some(bi) => {
                let s = (t_interp - self.t_prev) / self.h_prev;
                let top = self.order - 1;

                // Weight of each dense stage; stages past `dense_stages` stay zero.
                let mut cont = [T::zero(); I];
                for stage in 0..self.dense_stages {
                    let coeffs = &bi[stage];
                    let mut acc = coeffs[top];
                    for power in (0..top).rev() {
                        acc = acc * s + coeffs[power];
                    }
                    acc *= s;
                    cont[stage] = acc;
                }

                let mut y_interp = self.y_prev;
                for (stage, &weight) in cont.iter().enumerate() {
                    y_interp += self.k[stage] * weight * self.h_prev;
                }
                Ok(y_interp)
            }
        }
    }
}