use crate::{
error::Error,
interpolate::{Interpolation, linear_interpolate},
linalg::component_multiply,
methods::{ExplicitRungeKutta, Fixed, Stochastic},
sde::{SDE, StochasticNumericalMethod},
stats::Evals,
status::Status,
traits::{Real, State},
utils::validate_step_size_parameters,
};
/// Fixed-step explicit Runge–Kutta scheme applied to an SDE.
///
/// Each step integrates the drift with the tableau (`a`, `b`, `c`). It then adds
/// a diffusion increment `g(t_n, y_n) ∘ dW` (component-wise product), where the
/// diffusion is evaluated at the start of the step.
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
    StochasticNumericalMethod<T, Y> for ExplicitRungeKutta<Stochastic, Fixed, T, Y, O, S, I>
{
    /// Prepares the solver to integrate from `t0` to `tf` starting at `y0`.
    ///
    /// If no initial step size was configured (`h0 == 0`), a default of
    /// `|tf - t0| / 100` is used. The step size is then checked by
    /// `validate_step_size_parameters` against `h_min`/`h_max`, and the result
    /// becomes the working step `h`.
    ///
    /// # Errors
    /// Propagates the error returned by `validate_step_size_parameters` when the
    /// step-size configuration is rejected.
    fn init<F>(&mut self, sde: &mut F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
    where
        F: SDE<T, Y>,
    {
        let mut evals = Evals::new();

        if self.h0 == T::zero() {
            let duration = (tf - t0).abs();
            let default_steps = T::from_usize(100).unwrap();
            self.h0 = duration / default_steps;
        }
        self.h = validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf)?;

        self.steps = 0;
        self.t = t0;
        self.y = *y0;

        // Drift at the initial state; reused as the first stage of the first step.
        sde.drift(self.t, &self.y, &mut self.dydt);
        // NOTE(review): the diffusion value computed here is discarded; the call
        // only contributes to the evaluation count. Consider removing it.
        let mut diffusion = Y::zeros();
        sde.diffusion(self.t, &self.y, &mut diffusion);
        evals.function += 2;

        self.t_prev = self.t;
        self.y_prev = self.y;
        self.dydt_prev = self.dydt;
        self.status = Status::Initialized;
        Ok(evals)
    }

    /// Advances the solution by one fixed step of size `h`.
    ///
    /// The drift contribution is `h · Σ b_i k_i` over the Runge–Kutta stages. The
    /// stochastic contribution is `g(t_n, y_n) ∘ dW`, with `dW` produced by
    /// `sde.noise(h, ..)`.
    ///
    /// # Errors
    /// Returns [`Error::MaxSteps`] when `max_steps` steps have already been
    /// taken. The same error is also recorded in `status`.
    fn step<F>(&mut self, sde: &mut F) -> Result<Evals, Error<T, Y>>
    where
        F: SDE<T, Y>,
    {
        let mut evals = Evals::new();

        if self.steps >= self.max_steps {
            self.status = Status::Error(Error::MaxSteps {
                t: self.t,
                y: self.y,
            });
            return Err(Error::MaxSteps {
                t: self.t,
                y: self.y,
            });
        }
        self.steps += 1;

        self.t_prev = self.t;
        self.y_prev = self.y;
        self.dydt_prev = self.dydt;

        // Drift stages. Stage 0 reuses the drift already evaluated at (t_n, y_n).
        self.k[0] = self.dydt;
        for i in 1..self.stages {
            let mut y_stage = self.y;
            for j in 0..i {
                y_stage += self.k[j] * (self.a[i][j] * self.h);
            }
            sde.drift(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
        }
        evals.function += self.stages - 1;

        let mut drift_increment = Y::zeros();
        for i in 0..self.stages {
            drift_increment += self.k[i] * (self.b[i] * self.h);
        }

        // Diffusion is evaluated at the start of the step and scaled
        // component-wise by the Wiener increment.
        let mut diffusion = Y::zeros();
        sde.diffusion(self.t, &self.y, &mut diffusion);
        evals.function += 1;

        let mut dw = Y::zeros();
        sde.noise(self.h, &mut dw);
        let diffusion_increment = component_multiply(&diffusion, &dw);

        let y_next = self.y + drift_increment + diffusion_increment;
        self.t += self.h;
        self.y = y_next;

        // FSAL reuse of the last drift stage is not valid here. That stage was
        // evaluated at the deterministic stage value, which does not include the
        // diffusion increment just added to `y`. Re-evaluate the drift at the
        // actual new state so the next step starts from f(t_{n+1}, y_{n+1}).
        sde.drift(self.t, &self.y, &mut self.dydt);
        evals.function += 1;

        self.status = Status::Solving;
        Ok(evals)
    }

    /// Current time.
    fn t(&self) -> T {
        self.t
    }

    /// Current state.
    fn y(&self) -> &Y {
        &self.y
    }

    /// Time at the start of the most recent step.
    fn t_prev(&self) -> T {
        self.t_prev
    }

    /// State at the start of the most recent step.
    fn y_prev(&self) -> &Y {
        &self.y_prev
    }

    /// Current step size.
    fn h(&self) -> T {
        self.h
    }

    /// Overrides the step size used by subsequent calls to `step`.
    fn set_h(&mut self, h: T) {
        self.h = h;
    }

    /// Current solver status.
    fn status(&self) -> &Status<T, Y> {
        &self.status
    }

    /// Replaces the solver status.
    fn set_status(&mut self, status: Status<T, Y>) {
        self.status = status;
    }
}
/// Dense output for the fixed-step stochastic Runge–Kutta solver.
///
/// Values inside the most recent step window are produced by linear
/// interpolation between `y_prev` (at `t_prev`) and `y` (at `t`).
impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
    for ExplicitRungeKutta<Stochastic, Fixed, T, Y, O, S, I>
{
    /// Linearly interpolates the state at `t_interp`.
    ///
    /// # Errors
    /// Returns [`Error::OutOfBounds`] when `t_interp` is earlier than `t_prev`
    /// or later than the current time `t`.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        let (t_prev, t_curr) = (self.t_prev, self.t);

        let before_window = t_interp < t_prev;
        let after_window = t_interp > t_curr;
        if before_window || after_window {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev,
                t_curr,
            });
        }

        Ok(linear_interpolate(
            t_prev,
            t_curr,
            &self.y_prev,
            &self.y,
            t_interp,
        ))
    }
}