use crate::{
error::Error,
interpolate::{Interpolation, cubic_hermite_interpolate},
methods::{Fixed, Ordinary},
ode::{ODE, OrdinaryNumericalMethod},
stats::Evals,
status::Status,
traits::{Real, State},
utils::validate_step_size_parameters,
};
use super::AdamsPredictorCorrector;
impl<T: Real, Y: State<T>> AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4> {
    /// Builds a fixed-step, fourth-order Adams predictor-corrector that advances with
    /// step size `h`; every other field starts from its `Default` value.
    pub fn f4(h: T) -> Self {
        let mut solver = Self::default();
        solver.h = h;
        solver
    }
}
impl<T: Real, Y: State<T>> OrdinaryNumericalMethod<T, Y>
    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
{
    /// Validates the step size and bootstraps the four-point history required by the
    /// Adams-Bashforth-Moulton scheme using three classical RK4 steps.
    ///
    /// On success the solver sits at `t0 + 3h`, with `t_prev`/`y_prev` holding the states
    /// at `t0, t0 + h, t0 + 2h, t0 + 3h`. The "old" point (`t_old`, `y_old`, `dydt_old`)
    /// is the initial point and `dydt` is the derivative at the current point, so both
    /// ends of the interpolation interval carry matching slope data.
    ///
    /// # Errors
    /// Propagates any error returned by `validate_step_size_parameters`.
    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y>,
    {
        let mut evals = Evals::new();

        self.h =
            validate_step_size_parameters::<T, Y>(self.h, T::zero(), T::infinity(), t0, tf)?;

        self.t = t0;
        self.y = *y0;
        self.t_prev[0] = t0;
        self.y_prev[0] = *y0;

        // The initial point is the left end of the first interpolation interval.
        self.t_old = self.t;
        self.y_old = self.y;

        let two = T::from_f64(2.0).unwrap();
        let six = T::from_f64(6.0).unwrap();
        let half_h = self.h / two;

        // Fill history slots 1..=3 with classical RK4 steps.
        for i in 1..=3 {
            ode.diff(self.t, &self.y, &mut self.k[0]);
            if i == 1 {
                // f(t0, y0): the slope paired with t_old / y_old.
                self.dydt_old = self.k[0];
            }
            ode.diff(
                self.t + half_h,
                &(self.y + self.k[0] * half_h),
                &mut self.k[1],
            );
            ode.diff(
                self.t + half_h,
                &(self.y + self.k[1] * half_h),
                &mut self.k[2],
            );
            ode.diff(
                self.t + self.h,
                &(self.y + self.k[2] * self.h),
                &mut self.k[3],
            );
            self.y += (self.k[0] + self.k[1] * two + self.k[2] * two + self.k[3]) * (self.h / six);
            self.t += self.h;
            self.t_prev[i] = self.t;
            self.y_prev[i] = self.y;
            evals.function += 4;
        }

        // Slope at the current point (t0 + 3h). It is the right-hand slope for dense
        // output and is reused by `step` as f_n = f(t_prev[3], y_prev[3]).
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.function += 1;

        self.status = Status::Initialized;
        Ok(evals)
    }

    /// Advances one step with the 4th-order Adams-Bashforth predictor followed by the
    /// Adams-Moulton corrector, then evaluates the derivative at the new point (PECE).
    ///
    /// Relies on `dydt` holding f(t_prev[3], y_prev[3]); `init` and every `step`
    /// maintain that invariant.
    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y>,
    {
        let mut evals = Evals::new();

        // The current point becomes the left end of the interpolation interval.
        self.t_old = self.t;
        self.y_old = self.y;
        self.dydt_old = self.dydt;

        let c5 = T::from_f64(5.0).unwrap();
        let c9 = T::from_f64(9.0).unwrap();
        let c19 = T::from_f64(19.0).unwrap();
        let c24 = T::from_f64(24.0).unwrap();
        let c37 = T::from_f64(37.0).unwrap();
        let c55 = T::from_f64(55.0).unwrap();
        let c59 = T::from_f64(59.0).unwrap();

        // History slopes: k[0] = f_n, k[1] = f_{n-1}, k[2] = f_{n-2}, k[3] = f_{n-3}.
        // f_n is already cached in `dydt`, so it is not re-evaluated.
        self.k[0] = self.dydt;
        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);

        // Adams-Bashforth 4 predictor.
        let predictor = self.y_prev[3]
            + (self.k[0] * c55 - self.k[1] * c59 + self.k[2] * c37 - self.k[3] * c9) * self.h
                / c24;

        // f_{n+1} at the predicted state; k[3] (f_{n-3}) is no longer needed.
        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);

        // Adams-Moulton corrector.
        let corrector = self.y_prev[3]
            + (self.k[3] * c9 + self.k[0] * c19 - self.k[1] * c5 + self.k[2]) * (self.h / c24);

        self.t += self.h;
        self.y = corrector;
        // Final evaluation: right-hand slope for interpolation and f_n for the next step.
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.function += 5;

        // Shift the history window and append the new point.
        self.t_prev.copy_within(1..4, 0);
        self.y_prev.copy_within(1..4, 0);
        self.t_prev[3] = self.t;
        self.y_prev[3] = self.y;

        Ok(evals)
    }

    /// Current time.
    fn t(&self) -> T {
        self.t
    }

    /// Current state.
    fn y(&self) -> &Y {
        &self.y
    }

    /// Time at the start of the most recent step (`t_old`), not the multistep history.
    fn t_prev(&self) -> T {
        self.t_old
    }

    /// State at the start of the most recent step (`y_old`).
    fn y_prev(&self) -> &Y {
        &self.y_old
    }

    /// Step size.
    fn h(&self) -> T {
        self.h
    }

    /// Overwrites the step size.
    fn set_h(&mut self, h: T) {
        self.h = h;
    }

    /// Current solver status.
    fn status(&self) -> &Status<T, Y> {
        &self.status
    }

    /// Overwrites the solver status.
    fn set_status(&mut self, status: Status<T, Y>) {
        self.status = status;
    }
}
impl<T: Real, Y: State<T>> Interpolation<T, Y>
    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
{
    /// Dense output over the most recent step using cubic Hermite interpolation.
    ///
    /// The interpolant is built from (`t_old`, `y_old`, `dydt_old`) and
    /// (`t`, `y`, `dydt`), so it is only valid between `t_old` and `t`. The bounds
    /// check therefore uses `t_old`, not the oldest multistep history point.
    ///
    /// # Errors
    /// Returns `Error::OutOfBounds` when `t_interp` lies outside the last step.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        // Order the endpoints so the check works for either integration direction.
        let (lo, hi) = if self.t_old <= self.t {
            (self.t_old, self.t)
        } else {
            (self.t, self.t_old)
        };
        if t_interp < lo || t_interp > hi {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_old,
                t_curr: self.t,
            });
        }

        Ok(cubic_hermite_interpolate(
            self.t_old,
            self.t,
            &self.y_old,
            &self.y,
            &self.dydt_old,
            &self.dydt,
            t_interp,
        ))
    }
}