use crate::traits::FloatScalar;
use crate::Matrix;
/// Advances `y` by one classical fourth-order Runge–Kutta step of size `h`.
///
/// The derivative callback `f(t, y)` is evaluated four times, in order: at
/// the start of the step, twice at the midpoint `t + h/2`, and once at the
/// end `t + h`. The four slopes are blended with weights 1/6, 1/3, 1/3, 1/6.
///
/// # Parameters
/// - `t`: time at the start of the step.
/// - `y`: state at time `t`.
/// - `h`: step size; it is used as given, so its sign decides the direction.
/// - `f`: right-hand side of the ODE, returning the derivative of the state.
///
/// # Returns
/// The estimated state at `t + h`.
///
/// # Panics
/// Panics if `T::from` returns `None` for any of the constants `0.5`, `1/6`
/// or `1/3`.
pub fn rk4_step<T: FloatScalar, const M: usize, const N: usize>(
    t: T,
    y: &Matrix<T, M, N>,
    h: T,
    mut f: impl FnMut(T, &Matrix<T, M, N>) -> Matrix<T, M, N>,
) -> Matrix<T, M, N> {
    // Tableau constants, converted into the scalar type.
    let one_half = T::from(0.5).unwrap();
    let one_sixth = T::from(1.0 / 6.0).unwrap();
    let one_third = T::from(1.0 / 3.0).unwrap();

    // Stage 1: slope at the start of the interval.
    let slope_start = f(t, y);

    // Stage 2: slope at the midpoint, reached by following the start slope.
    let t_mid_first = t + h * one_half;
    let y_mid_first = *y + slope_start * h * one_half;
    let slope_mid_first = f(t_mid_first, &y_mid_first);

    // Stage 3: slope at the midpoint again, reached via the stage-2 slope.
    let t_mid_second = t + h * one_half;
    let y_mid_second = *y + slope_mid_first * h * one_half;
    let slope_mid_second = f(t_mid_second, &y_mid_second);

    // Stage 4: slope at the end of the interval, reached via the stage-3 slope.
    let t_end = t + h;
    let y_end = *y + slope_mid_second * h;
    let slope_end = f(t_end, &y_end);

    // Weighted average of the four slopes, accumulated left to right.
    let blended = slope_start * one_sixth
        + slope_mid_first * one_third
        + slope_mid_second * one_third
        + slope_end * one_sixth;

    *y + blended * h
}
/// Integrates `y' = f(t, y)` from `t0` to `tf` with fixed-size RK4 steps and
/// returns the state at `tf`.
///
/// Only the magnitude of `dt` is used; the direction of integration is taken
/// from the ordering of `t0` and `tf`, so integrating backwards in time
/// (`tf < t0`) is supported. The last step is shortened so that it ends at
/// `tf` instead of overshooting it.
///
/// If `tf == t0` the interval is empty and `y0` is returned unchanged without
/// evaluating `f`.
///
/// # Parameters
/// - `t0`: initial time.
/// - `tf`: final time.
/// - `dt`: nominal step size (its sign is ignored).
/// - `y0`: state at `t0`.
/// - `f`: right-hand side of the ODE, returning the derivative of the state.
///
/// # Panics
/// - If `t0` and `tf` cannot be ordered (one of them is NaN).
/// - If a step fails to advance `t`, which happens when `dt` is zero, NaN, or
///   too small relative to the current `t`. Without this check the loop would
///   never terminate.
/// - On any panic raised by [`rk4_step`].
pub fn rk4<T: FloatScalar, const M: usize, const N: usize>(
    t0: T,
    tf: T,
    dt: T,
    y0: &Matrix<T, M, N>,
    mut f: impl FnMut(T, &Matrix<T, M, N>) -> Matrix<T, M, N>,
) -> Matrix<T, M, N> {
    // Empty interval: nothing to integrate.
    if tf == t0 {
        return *y0;
    }

    let forward = tf > t0;
    // After the equality check above, "neither greater nor less" can only
    // mean a NaN endpoint, for which no stopping condition would ever hold.
    assert!(forward || tf < t0, "rk4: t0 and tf must not be NaN");

    // Signed step: magnitude from `dt`, direction from the interval.
    let h = if forward { dt.abs() } else { -dt.abs() };

    let mut t = t0;
    let mut y = *y0;
    loop {
        let t_next = t + h;

        // Guarantee termination: every full step must move `t` strictly
        // towards `tf`. The comparison is also false for a NaN step.
        let advances = if forward { t_next > t } else { t_next < t };
        assert!(advances, "rk4: dt is zero, NaN, or too small to advance t");

        let overshoots = if forward { t_next > tf } else { t_next < tf };
        if overshoots {
            // Final, shortened step. Returning here (instead of re-deriving
            // `t` as `t + (tf - t)`) avoids an extra sliver step when that
            // sum rounds to just short of `tf`.
            return rk4_step(t, &y, tf - t, &mut f);
        }

        y = rk4_step(t, &y, h, &mut f);
        t = t_next;

        // A full step landed exactly on the end point.
        if t == tf {
            return y;
        }
    }
}