use crate::{
error::Error,
interpolate::{Interpolation, cubic_hermite_interpolate},
methods::{Fixed, Ordinary},
ode::{ODE, OrdinaryNumericalMethod},
stats::Evals,
status::Status,
traits::{Real, State},
utils::validate_step_size_parameters,
};
use super::AdamsPredictorCorrector;
impl<T: Real, Y: State<T>> AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4> {
    /// Creates a fixed-step, fourth-order Adams predictor-corrector using step size `h`.
    ///
    /// Every other field starts from the type's `Default` value. The step size is
    /// passed through `validate_step_size_parameters` (and replaced by its result)
    /// when `init` is called.
    pub fn f4(h: T) -> Self {
        let defaults = Self::default();
        Self { h, ..defaults }
    }
}
impl<T: Real, Y: State<T>> OrdinaryNumericalMethod<T, Y>
    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
{
    /// Validates the step size and seeds the four-point history with three classic
    /// RK4 steps.
    ///
    /// After this call:
    /// - `t_prev[0..4]`/`y_prev[0..4]` hold the states at `t0, t0+h, t0+2h, t0+3h`;
    /// - `t()`/`y()` are the state at `t0 + 3h`;
    /// - the dense-output interval is `[t0, t0 + 3h]`, with slopes `f(t0, y0)` and
    ///   `f(t0 + 3h, y(t0 + 3h))` at its ends.
    ///
    /// # Errors
    /// Propagates any error from `validate_step_size_parameters`.
    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y> + ?Sized,
    {
        let mut evals = Evals::new();

        self.h = validate_step_size_parameters::<T, Y>(
            self.h,
            T::zero(),
            T::infinity(),
            t0,
            tf,
        )?;

        // Current state, which also serves as the left end of the first
        // dense-output interval.
        self.t = t0;
        self.y = y0.clone();
        self.t_old = t0;
        self.y_old = y0.clone();
        self.dydt = y0.zeros_like();
        self.dydt_old = y0.zeros_like();

        // History buffers and RK4 stage scratch space.
        self.y_prev = core::array::from_fn(|_| y0.zeros_like());
        self.k = core::array::from_fn(|_| y0.zeros_like());
        self.t_prev[0] = t0;
        self.y_prev[0] = y0.clone();

        // Three classic RK4 steps fill history slots 1..=3.
        let two = T::from_f64(2.0).unwrap();
        let six = T::from_f64(6.0).unwrap();
        for i in 1..=3 {
            ode.diff(self.t, &self.y, &mut self.k[0]);
            ode.diff(
                self.t + self.h / two,
                &self.y.plus_scaled(self.h / two, &self.k[0]),
                &mut self.k[1],
            );
            ode.diff(
                self.t + self.h / two,
                &self.y.plus_scaled(self.h / two, &self.k[1]),
                &mut self.k[2],
            );
            ode.diff(
                self.t + self.h,
                &self.y.plus_scaled(self.h, &self.k[2]),
                &mut self.k[3],
            );
            self.y.add_scaled(self.h / six, &self.k[0]);
            self.y.add_scaled(two * self.h / six, &self.k[1]);
            self.y.add_scaled(two * self.h / six, &self.k[2]);
            self.y.add_scaled(self.h / six, &self.k[3]);
            self.t += self.h;
            self.t_prev[i] = self.t;
            self.y_prev[i] = self.y.clone();
            evals.function += 4;

            if i == 1 {
                // k[0] of the first RK4 step is f(t0, y0): the slope at the left
                // end of the dense-output interval (t_old = t0, y_old = y0).
                self.dydt_old = self.k[0].clone();
            }
        }

        // The right-end slope must belong to the *current* state (t0 + 3h), not to
        // t0. Otherwise both the start-up interpolant and the first step's
        // interpolant (which copies `dydt` into `dydt_old`) use a wrong slope.
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.function += 1;

        self.status = Status::Initialized;
        Ok(evals)
    }

    /// Advances one step of size `h` with an AB4 predictor and an AM corrector.
    ///
    /// The four history slopes are re-evaluated from `y_prev` every step. The
    /// predicted point and the accepted point are each evaluated once, giving 6
    /// function evaluations per step.
    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y> + ?Sized,
    {
        let mut evals = Evals::new();

        // The current point becomes the left end of this step's dense-output interval.
        self.t_old = self.t;
        self.y_old = self.y.clone();
        self.dydt_old = self.dydt.clone();

        // History slopes, newest first:
        // k[0] = f_n, k[1] = f_{n-1}, k[2] = f_{n-2}, k[3] = f_{n-3}.
        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);

        let c = |x: f64| T::from_f64(x).unwrap();
        let h = self.h;
        let d = c(24.0);

        // Adams-Bashforth 4 predictor:
        // y* = y_n + h/24 * (55 f_n - 59 f_{n-1} + 37 f_{n-2} - 9 f_{n-3})
        let predictor = self.y_prev[3].plus_linear_combination(&[
            (&self.k[0], h * c(55.0) / d),
            (&self.k[1], -h * c(59.0) / d),
            (&self.k[2], h * c(37.0) / d),
            (&self.k[3], -h * c(9.0) / d),
        ]);

        // f_{n-3} is no longer needed, so k[3] now holds f* = f(t_n + h, y*).
        ode.diff(self.t + h, &predictor, &mut self.k[3]);

        // Adams-Moulton corrector:
        // y_{n+1} = y_n + h/24 * (9 f* + 19 f_n - 5 f_{n-1} + f_{n-2})
        let corrector = self.y_prev[3].plus_linear_combination(&[
            (&self.k[3], h * c(9.0) / d),
            (&self.k[0], h * c(19.0) / d),
            (&self.k[1], -h * c(5.0) / d),
            (&self.k[2], h / d),
        ]);

        self.t += h;
        self.y = corrector;
        // Slope at the accepted point: the right end of the dense-output interval.
        ode.diff(self.t, &self.y, &mut self.dydt);
        // 4 history slopes + 1 predictor slope + 1 slope at the new point.
        evals.function += 6;

        // Drop the oldest history entry and append the new point.
        self.t_prev.copy_within(1..4, 0);
        self.y_prev.rotate_left(1);
        self.t_prev[3] = self.t;
        self.y_prev[3] = self.y.clone();

        Ok(evals)
    }

    /// Current time.
    fn t(&self) -> T {
        self.t
    }

    /// Current state.
    fn y(&self) -> &Y {
        &self.y
    }

    /// Start time of the most recent step (the left end of the dense-output
    /// interval), not the oldest multistep history entry.
    fn t_prev(&self) -> T {
        self.t_old
    }

    /// State at the start of the most recent step.
    fn y_prev(&self) -> &Y {
        &self.y_old
    }

    /// Fixed step size.
    fn h(&self) -> T {
        self.h
    }

    /// Overwrites the step size used by subsequent steps.
    fn set_h(&mut self, h: T) {
        self.h = h;
    }

    /// Current solver status.
    fn status(&self) -> &Status<T, Y> {
        &self.status
    }

    /// Overwrites the solver status.
    fn set_status(&mut self, status: Status<T, Y>) {
        self.status = status;
    }
}
impl<T: Real, Y: State<T>> Interpolation<T, Y>
    for AdamsPredictorCorrector<Ordinary, Fixed, T, Y, 4>
{
    /// Evaluates the cubic Hermite dense output at `t_interp`.
    ///
    /// The interpolant is built from `(t_old, y_old, dydt_old)` and
    /// `(t, y, dydt)`, i.e. the two ends of the most recent step. `t_interp`
    /// must therefore lie in that interval. Either orientation is accepted, so
    /// backward integration (`h < 0`) works too.
    ///
    /// # Errors
    /// Returns `Error::OutOfBounds` if `t_interp` is outside the last step's
    /// interval or is NaN.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        // Orient the interval so the check works for forward and backward steps.
        let (lo, hi) = if self.t_old <= self.t {
            (self.t_old, self.t)
        } else {
            (self.t, self.t_old)
        };

        // `contains` is false for NaN, so NaN inputs are rejected as well. The
        // lower bound is `t_old` (not the oldest history time `t_prev[0]`)
        // because the Hermite data only describes [t_old, t]; anything outside
        // would be silent extrapolation.
        if !(lo..=hi).contains(&t_interp) {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_old,
                t_curr: self.t,
            });
        }

        let y_interp = cubic_hermite_interpolate(
            self.t_old,
            self.t,
            &self.y_old,
            &self.y,
            &self.dydt_old,
            &self.dydt,
            t_interp,
        );
        Ok(y_interp)
    }
}