use crate::{
control::ControlFlag,
error::Error,
interpolate::Interpolation,
sde::{SDE, StochasticNumericalMethod},
solout::*,
solution::Solution,
status::Status,
traits::{Real, State},
};
/// Drives a stochastic numerical method from `t0` to `tf`, reporting every
/// step to `solout`.
///
/// The solver is initialised at `(t0, y0)`, the initial point is handed to
/// `solout`, and then `solver.step` is called repeatedly. The step size is
/// clamped so that the final step does not go past `tf`. After every step the
/// current and previous states are passed to `solout`, whose returned
/// [`ControlFlag`] decides what happens next:
///
/// * `Continue` – keep integrating.
/// * `ModifyState(tm, ym)` – re-initialise the solver at `(tm, ym)` and keep
///   integrating towards the same `tf`.
/// * `Terminate` – stop immediately and return the solution with
///   [`Status::Interrupted`].
///
/// # Arguments
///
/// * `solver` – stepping method; also passed to `solout` as the interpolator.
/// * `sde` – the stochastic differential equation being integrated.
/// * `t0`, `tf` – initial and final time; `tf` may be smaller than `t0`.
/// * `y0` – initial state.
/// * `solout` – output handler invoked at the initial point and after each step.
///
/// # Returns
///
/// The accumulated [`Solution`]. Its status is [`Status::Complete`] when the
/// solver time is within `10 * T::default_epsilon()` of `tf`, or
/// [`Status::Interrupted`] when `solout` requested termination.
///
/// # Errors
///
/// * [`Error::BadInput`] when `tf` is neither greater nor less than `t0`
///   (equal endpoints, or unordered values such as NaN).
/// * Any error returned by `solver.init` or `solver.step` is propagated
///   unchanged.
pub fn solve_sde<T, Y, S, F, O>(
    solver: &mut S,
    sde: &mut F,
    t0: T,
    tf: T,
    y0: &Y,
    solout: &mut O,
) -> Result<Solution<T, Y>, Error<T, Y>>
where
    T: Real,
    Y: State<T>,
    F: SDE<T, Y> + ?Sized,
    S: StochasticNumericalMethod<T, Y> + Interpolation<T, Y> + ?Sized,
    O: Solout<T, Y> + ?Sized,
{
    let mut solution = Solution::new();

    // Timing is skipped on wasm32 targets.
    #[cfg(not(target_arch = "wasm32"))]
    solution.timer.start();

    // Compare the endpoints directly instead of inspecting `signum`: with
    // IEEE semantics `signum(0.0)` is `1.0`, which would let `tf == t0`
    // through. Unordered inputs (NaN) fail both comparisons and are rejected.
    let integration_direction = if tf > t0 {
        T::one()
    } else if tf < t0 {
        T::zero() - T::one()
    } else {
        return Err(Error::BadInput {
            msg: "Final time tf must be different from initial time t0.".to_string(),
        });
    };

    solution.evals += solver.init(sde, t0, tf, y0)?;

    // Report the initial point. The states are cloned because `solver` is
    // handed to `solout` mutably in the same call.
    {
        let y_curr = solver.y().clone();
        let y_prev = solver.y_prev().clone();
        match solout.solout(
            solver.t(),
            solver.t_prev(),
            &y_curr,
            &y_prev,
            solver,
            &mut solution,
        ) {
            ControlFlag::Continue => {}
            ControlFlag::ModifyState(tm, ym) => {
                solution.evals += solver.init(sde, tm, tf, &ym)?;
            }
            ControlFlag::Terminate => {
                solution.status = Status::Interrupted;
                #[cfg(not(target_arch = "wasm32"))]
                solution.timer.complete();
                return Ok(solution);
            }
        }
    }

    solver.set_status(Status::Solving);
    solution.status = Status::Solving;

    // Absolute tolerance used to decide that `tf` has been reached.
    let end_tolerance =
        T::default_epsilon() * T::from_f64(10.0).expect("10.0 must be representable in T");

    loop {
        // Shorten the step if the next one would pass `tf`.
        if (solver.t() + solver.h() - tf) * integration_direction > T::zero() {
            let h_new = tf - solver.t();
            if h_new.abs() < end_tolerance {
                // Already at `tf` up to tolerance; nothing left to integrate.
                solver.set_status(Status::Complete);
                solution.status = Status::Complete;
                #[cfg(not(target_arch = "wasm32"))]
                solution.timer.complete();
                return Ok(solution);
            }
            solver.set_h(h_new);
        }

        solution.evals += solver.step(sde)?;
        solution.steps.accepted += 1;

        let y_curr = solver.y().clone();
        let y_prev = solver.y_prev().clone();
        match solout.solout(
            solver.t(),
            solver.t_prev(),
            &y_curr,
            &y_prev,
            solver,
            &mut solution,
        ) {
            ControlFlag::Continue => {}
            ControlFlag::ModifyState(tm, ym) => {
                solution.evals += solver.init(sde, tm, tf, &ym)?;
            }
            ControlFlag::Terminate => {
                solution.status = Status::Interrupted;
                #[cfg(not(target_arch = "wasm32"))]
                solution.timer.complete();
                return Ok(solution);
            }
        }

        if (tf - solver.t()).abs() <= end_tolerance {
            break;
        }
    }

    solver.set_status(Status::Complete);
    solution.status = Status::Complete;
    #[cfg(not(target_arch = "wasm32"))]
    solution.timer.complete();
    Ok(solution)
}