use crate::ode::controller::Controller;
use crate::ode::stepper::{Stepper, StepperData};
use crate::vector::Vector;
use num::Zero;
/// Adaptive explicit Runge–Kutta stepper using the Dormand–Prince 5(4) tableau.
///
/// Every trial step produces a fifth-order solution together with an embedded error
/// estimate. A [`Controller`] decides whether a trial step is accepted and which step
/// size to try next.
pub struct DormandPrince<const N: usize> {
    // --- position along the independent variable ---
    /// Current value of the independent variable.
    x: f64,
    /// Value of the independent variable before the last accepted step.
    x_old: f64,

    // --- step-size bookkeeping ---
    /// Size of the last accepted step.
    h_did: f64,
    /// Step size proposed by the controller after the last accepted step.
    h_next: f64,

    // --- state and derivatives ---
    /// Current state.
    y: Vector<N>,
    /// First-stage derivative for the next step; after an accepted step this is the
    /// derivative evaluated at the new `(x, y)`.
    dydx: Vector<N>,
    /// Candidate state produced by the most recent trial step.
    y_out: Vector<N>,
    /// Embedded error estimate of the most recent trial step.
    y_err: Vector<N>,
    /// Derivative at the end point of the most recent trial step.
    dydxnew: Vector<N>,

    // --- intermediate Runge–Kutta stages of the most recent trial step ---
    k2: Vector<N>,
    k3: Vector<N>,
    k4: Vector<N>,
    k5: Vector<N>,
    k6: Vector<N>,

    // --- error control ---
    /// Absolute tolerance used when scaling the error estimate.
    atol: f64,
    /// Relative tolerance used when scaling the error estimate.
    rtol: f64,
    /// Accept/reject logic and step-size proposals.
    controller: Controller,
    /// Shared stepper data: current trial step size `h_cur` and the right-hand side `derive`.
    data: StepperData<N>,
}
impl<const N: usize> DormandPrince<N> {
    /// Creates a stepper at `x = 0` with state `p0`.
    ///
    /// * `h` - step size passed to `StepperData::new`; trial steps read their size
    ///   back from `data.h_cur`.
    /// * `p0` - initial state.
    /// * `df` - right-hand side `f(x, y)` of the system `dy/dx = f(x, y)`.
    ///
    /// The absolute and relative tolerances default to `0.01`; use
    /// [`DormandPrince::with_tolerances`] to override them.
    pub fn new(h: f64, p0: Vector<N>, df: fn(f64, Vector<N>) -> Vector<N>) -> Self {
        let x = 0.0;
        DormandPrince {
            x,
            x_old: x,
            y: p0,
            // The first stage of the very first step needs f(x0, y0). Later steps reuse the
            // end-point derivative carried over in `step` (first-same-as-last), so this is
            // the only place where it must be evaluated explicitly.
            dydx: df(x, p0),
            h_did: 0.0,
            h_next: 0.0,
            y_out: Vector::zero(),
            y_err: Vector::zero(),
            k2: Vector::zero(),
            k3: Vector::zero(),
            k4: Vector::zero(),
            k5: Vector::zero(),
            k6: Vector::zero(),
            dydxnew: Vector::zero(),
            atol: 0.01,
            rtol: 0.01,
            controller: Controller::new(),
            data: StepperData::new(h, p0, df),
        }
    }

    /// Returns this stepper with absolute tolerance `atol` and relative tolerance `rtol`,
    /// replacing the `0.01` defaults set by [`DormandPrince::new`].
    pub fn with_tolerances(mut self, atol: f64, rtol: f64) -> Self {
        self.atol = atol;
        self.rtol = rtol;
        self
    }

    /// Scaled RMS norm of the embedded error estimate `y_err`.
    ///
    /// Component `i` is divided by `atol + rtol * max(|y_i|, |y_out_i|)` before the
    /// root-mean-square is taken. Whether the resulting value is acceptable is decided by
    /// the `Controller`. Returns `0.0` for a zero-dimensional state instead of computing
    /// `0 / 0 = NaN`.
    fn error(&self) -> f64 {
        let n = self.y.dim();
        if n == 0 {
            return 0.0;
        }
        let sum_sq: f64 = (0..n)
            .map(|i| {
                let scale = self.atol + self.rtol * f64::max(self.y[i].abs(), self.y_out[i].abs());
                (self.y_err[i] / scale).powi(2)
            })
            .sum();
        (sum_sq / n as f64).sqrt()
    }

    /// Performs one trial Dormand–Prince step of size `data.h_cur` starting from `(x, y)`.
    ///
    /// Uses `dydx` as the first stage and stores the intermediate stages in `k2`..`k6`.
    /// The fifth-order candidate goes to `y_out`, the derivative at the candidate end
    /// point to `dydxnew`, and the embedded error estimate (difference between the
    /// fourth- and fifth-order solutions) to `y_err`. `x`, `y` and `dydx` are left
    /// untouched, so a rejected trial can simply be retried with a different step size.
    fn dy(&mut self) {
        // Nodes c_i.
        const C2: f64 = 0.2;
        const C3: f64 = 0.3;
        const C4: f64 = 0.8;
        const C5: f64 = 8.0 / 9.0;
        // Runge–Kutta matrix a_ij.
        const A21: f64 = 0.2;
        const A31: f64 = 3.0 / 40.0;
        const A32: f64 = 9.0 / 40.0;
        const A41: f64 = 44.0 / 45.0;
        const A42: f64 = -56.0 / 15.0;
        const A43: f64 = 32.0 / 9.0;
        const A51: f64 = 19372.0 / 6561.0;
        const A52: f64 = -25360.0 / 2187.0;
        const A53: f64 = 64448.0 / 6561.0;
        const A54: f64 = -212.0 / 729.0;
        const A61: f64 = 9017.0 / 3168.0;
        const A62: f64 = -355.0 / 33.0;
        const A63: f64 = 46732.0 / 5247.0;
        const A64: f64 = 49.0 / 176.0;
        const A65: f64 = -5103.0 / 18656.0;
        // Fifth-order weights; a72 is zero, so the k2 term is omitted below.
        const A71: f64 = 35.0 / 384.0;
        const A73: f64 = 500.0 / 1113.0;
        const A74: f64 = 125.0 / 192.0;
        const A75: f64 = -2187.0 / 6784.0;
        const A76: f64 = 11.0 / 84.0;
        // Error-estimate coefficients (e2 is zero).
        const E1: f64 = 71.0 / 57600.0;
        const E3: f64 = -71.0 / 16695.0;
        const E4: f64 = 71.0 / 1920.0;
        const E5: f64 = -17253.0 / 339200.0;
        const E6: f64 = 22.0 / 525.0;
        const E7: f64 = -1.0 / 40.0;

        let h = self.data.h_cur;
        let x = self.x;
        let y = self.y;
        let k1 = self.dydx;
        let f = &self.data.derive;

        self.k2 = f(x + C2 * h, y + k1 * A21 * h);
        self.k3 = f(x + C3 * h, y + (k1 * A31 + self.k2 * A32) * h);
        self.k4 = f(
            x + C4 * h,
            y + (k1 * A41 + self.k2 * A42 + self.k3 * A43) * h,
        );
        self.k5 = f(
            x + C5 * h,
            y + (k1 * A51 + self.k2 * A52 + self.k3 * A53 + self.k4 * A54) * h,
        );
        let x_end = x + h;
        self.k6 = f(
            x_end,
            y + (k1 * A61 + self.k2 * A62 + self.k3 * A63 + self.k4 * A64 + self.k5 * A65) * h,
        );
        self.y_out =
            y + (k1 * A71 + self.k3 * A73 + self.k4 * A74 + self.k5 * A75 + self.k6 * A76) * h;
        self.dydxnew = f(x_end, self.y_out);
        self.y_err = (k1 * E1
            + self.k3 * E3
            + self.k4 * E4
            + self.k5 * E5
            + self.k6 * E6
            + self.dydxnew * E7)
            * h;
    }
}
impl<const N: usize> Stepper<N> for DormandPrince<N> {
    /// Advances the solution by one accepted adaptive step and returns the new state.
    ///
    /// Trial steps are retried, using the step size returned by the controller, until
    /// `Controller::success` accepts the scaled error. On acceptance the state,
    /// derivative and position are advanced, and the controller's proposed next step
    /// size becomes the trial step size for the following call.
    ///
    /// NOTE(review): if the controller never accepts (e.g. the error is NaN), this loops
    /// forever; that policy lives in `Controller`.
    fn step(&mut self) -> Vector<N> {
        loop {
            self.dy();
            let err = self.error();
            let (accepted, h_retry) = self.controller.success(err, self.data.h_cur);
            if accepted {
                break;
            }
            self.data.h_cur = h_retry;
        }
        // First-same-as-last: the end-point derivative is the next step's first stage.
        self.dydx = self.dydxnew;
        self.y = self.y_out;
        self.x_old = self.x;
        self.h_did = self.data.h_cur;
        self.x += self.h_did;
        self.h_next = self.controller.h_next;
        // Start the next call from the controller's proposal; otherwise the step size
        // could only ever shrink (on rejection) and never grow back. The guard keeps the
        // current step if the controller left an unusable (zero/non-finite) value.
        if self.h_next.is_finite() && self.h_next != 0.0 {
            self.data.h_cur = self.h_next;
        }
        self.y
    }
}