use crate::{SolverError, SolverResult};
use nalgebra::{ClosedAddAssign, ClosedMulAssign};
use num_traits::Float;
/// Explicit one-step integration scheme used by [`ODESolver`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ODESolverMethod {
    /// Forward (explicit) Euler method, first order.
    EulerForward,
    /// Heun's method (explicit trapezoidal rule), second order.
    Heun,
    /// Classic fourth-order Runge–Kutta method.
    RungeKutta4,
}
/// Fixed-step explicit solver for the initial value problem `y' = f(x, y)`, `y(x0) = y0`.
///
/// `T` is the scalar type of the independent variable and of the step size,
/// `V` is the state type and `F` is the right-hand side `f`.
pub struct ODESolver<T, V, F> {
    // --- the equation and how one step of it is taken ---
    /// Right-hand side of the equation, evaluated as `f(x, y)`.
    f: F,
    /// One-step integration routine, called as `method(self, x, y)` and
    /// returning the state one step further.
    method: fn(&Self, T, V) -> V,

    // --- initial condition ---
    /// Starting value of the independent variable.
    x0: T,
    /// State at `x0`.
    y0: V,

    // --- step size ---
    /// Step size.
    h: T,
    /// Cached half of `h`, stored alongside it so it need not be recomputed per step.
    half_h: T,
}
impl<T, V, F> ODESolver<T, V, F>
where
    T: Float,
    V: Copy + ClosedAddAssign + ClosedMulAssign<T>,
    F: Fn(T, V) -> V,
{
    /// Creates a solver for `y' = f(x, y)` with `y(x0) = y0` and step size `h`.
    ///
    /// The classic fourth-order Runge–Kutta method is selected by default; use
    /// [`with_method`](Self::with_method) to choose another scheme.
    pub fn new(f: F, x0: T, y0: V, h: T) -> Self {
        Self {
            f,
            x0,
            y0,
            h,
            half_h: Self::half(h),
            method: Self::rk4_step,
        }
    }

    /// Integrates from `x0` towards `x_end` and returns the approximate state reached.
    ///
    /// The number of steps is `(x_end - x0) / h` rounded down. A ratio that falls
    /// only a few ULPs short of an integer is treated as that integer. This absorbs
    /// the rounding error from [`with_steps`](Self::with_steps), which computes `h`
    /// as `(x_end - x0) / steps`. Exactly that many steps are taken.
    ///
    /// # Errors
    ///
    /// Returns [`SolverError::IncorrectInput`] when the step count is not a
    /// positive finite integer. For example, `x_end == x0`, `h` points away from
    /// `x_end`, or `h` is zero or non-finite.
    pub fn solve(&self, x_end: T) -> SolverResult<V> {
        let ratio = (x_end - self.x0) / self.h;
        // A few relative ULPs of slack so that e.g. 2.9999999999999996 counts as 3.
        let tolerance = T::epsilon()
            * ratio.abs()
            * T::from(4_f64).expect("4.0 must convert to the scalar type");
        let steps = T::to_usize(&(ratio + tolerance).floor()).unwrap_or(0);
        if steps == 0 {
            return Err(SolverError::IncorrectInput {
                details: "the number of steps should be positive",
            });
        }

        let mut y = self.y0;
        for i in 0..steps {
            // Derive x from the step index rather than accumulating `x += h`,
            // which would drift by one rounding error per step.
            let x = self.x0 + T::from(i).expect("step index must convert to the scalar type") * self.h;
            y = (self.method)(self, x, y);
        }
        Ok(y)
    }

    /// Sets the step size to `h`.
    pub fn with_step_size(&mut self, h: T) -> &mut Self {
        self.h = h;
        self.half_h = Self::half(h);
        self
    }

    /// Sets the step size so that `steps` equal steps span `x0..x_end`.
    ///
    /// `steps == 0` produces a non-finite step size, which [`solve`](Self::solve)
    /// subsequently rejects with an error.
    pub fn with_steps(&mut self, x_end: T, steps: usize) -> &mut Self {
        self.h = (x_end - self.x0)
            / T::from(steps).expect("step count must convert to the scalar type");
        self.half_h = Self::half(self.h);
        self
    }

    /// Selects the one-step integration scheme used by [`solve`](Self::solve).
    pub fn with_method(&mut self, method: ODESolverMethod) -> &mut Self {
        self.method = match method {
            ODESolverMethod::EulerForward => Self::euler_step,
            ODESolverMethod::Heun => Self::heun_step,
            ODESolverMethod::RungeKutta4 => Self::rk4_step,
        };
        self
    }

    /// Returns `value / 2` in the scalar type.
    fn half(value: T) -> T {
        value / T::from(2_f64).expect("2.0 must convert to the scalar type")
    }

    /// Forward Euler step: `y + h * f(x, y)`.
    fn euler_step(&self, x: T, y: V) -> V {
        y + (self.f)(x, y) * self.h
    }

    /// Heun step (explicit trapezoidal rule).
    ///
    /// The predictor is `y1 = y + h * k1` with `k1 = f(x, y)`. The corrector
    /// averages the two slopes: `y + h/2 * (k1 + f(x + h, y1))`.
    fn heun_step(&self, x: T, y: V) -> V {
        let k1 = (self.f)(x, y);
        let predicted = y + k1 * self.h;
        let k2 = (self.f)(x + self.h, predicted);
        // Average of the slopes, not of the states.
        y + (k1 + k2) * self.half_h
    }

    /// Classic fourth-order Runge–Kutta step.
    ///
    /// The update is `y + h/6 * (k1 + 2 k2 + 2 k3 + k4)`, using the standard
    /// stages evaluated at `x`, `x + h/2` (twice) and `x + h`.
    fn rk4_step(&self, x: T, y: V) -> V {
        let k1 = (self.f)(x, y);
        let k2 = (self.f)(x + self.half_h, y + k1 * self.half_h);
        let k3 = (self.f)(x + self.half_h, y + k2 * self.half_h);
        let k4 = (self.f)(x + self.h, y + k3 * self.h);
        y + (k1 + k2 + k2 + k3 + k3 + k4)
            * (self.h / T::from(6_f64).expect("6.0 must convert to the scalar type"))
    }
}