#[cfg(test)]
mod test;
use super::{
super::{Tensor, TensorArray, TensorRank0, TensorRank0List},
Explicit, IntegrationError, OdeSolver,
};
use crate::{ABS_TOL, REL_TOL};
use std::ops::{Mul, Sub};
/// Adaptive explicit Runge–Kutta integrator.
///
/// The solver is configured by two error tolerances and two step-size
/// scaling factors.
#[derive(Debug)]
pub struct Ode23 {
    /// A trial step is accepted when the norm of its local error estimate is below this value.
    pub abs_tol: TensorRank0,
    /// Multiplier applied to the step size after a rejected trial step.
    pub dec_fac: TensorRank0,
    /// Multiplier applied to the step size after an accepted step.
    pub inc_fac: TensorRank0,
    /// A trial step is also accepted when its error estimate, divided by the
    /// norm of the trial solution, is below this value.
    pub rel_tol: TensorRank0,
}
impl Default for Ode23 {
fn default() -> Self {
Self {
abs_tol: ABS_TOL,
dec_fac: 0.5,
inc_fac: 1.1,
rel_tol: REL_TOL,
}
}
}
/// Adaptive explicit integration using the coefficients of the
/// Bogacki–Shampine 3(2) embedded Runge–Kutta pair.
impl<Y, U, const W: usize> Explicit<Y, U, W> for Ode23
where
    Y: Tensor,
    for<'a> &'a Y: Mul<TensorRank0, Output = Y> + Sub<&'a Y, Output = Y>,
    U: Tensor<Item = Y> + TensorArray,
{
    /// Integrates `dy/dt = function(t, y)` from `initial_time` /
    /// `initial_condition` and returns the solution sampled at `evaluation_times`.
    ///
    /// Each trial step of size `dt` evaluates three stages:
    /// - `k_2` at `t + dt/2`;
    /// - `k_3` at `t + 3dt/4`;
    /// - `k_4` at `t + dt`, evaluated on the third-order trial solution
    ///   `y + dt (2 k_1 + 3 k_2 + 4 k_3) / 9`.
    ///
    /// The trial step is accepted when the norm `e` of the embedded error
    /// estimate is below `abs_tol`, or when `e / |y_trial|` is below `rel_tol`.
    /// After an accepted step, `dt` is multiplied by `inc_fac`; after a rejected
    /// step, by `dec_fac`. On acceptance, `k_4` is reused as the next step's
    /// `k_1` (first-same-as-last).
    ///
    /// Evaluation times falling inside an accepted step are filled by linear
    /// interpolation between the step's start and end states.
    ///
    /// # Errors
    ///
    /// - Propagates any error returned by `self.setup`.
    /// - Returns the error converted from `"not ok"` if an evaluation time is
    ///   reached when no solution slot remains.
    ///
    /// NOTE(review): if every trial step is rejected (e.g. `function` yields
    /// NaN, so neither tolerance comparison holds), `dt` keeps shrinking and
    /// this loop never terminates — there is no minimum-step guard here.
    fn integrate(
        &self,
        function: impl Fn(&TensorRank0, &Y) -> Y,
        initial_time: TensorRank0,
        initial_condition: Y,
        evaluation_times: &TensorRank0List<W>,
    ) -> Result<U, IntegrationError<W>> {
        let mut k_1 = function(&initial_time, &initial_condition);
        let mut solution = U::zero();
        {
            // `setup` comes from a solver trait not visible here.
            // Presumably it yields:
            // - the pending evaluation times,
            // - the initial step size, time, and state,
            // - an iterator over the slots of `solution` to fill.
            // TODO confirm.
            let (mut eval_times, mut dt, mut t, mut y, mut y_sol) = self.setup(
                initial_time,
                initial_condition,
                evaluation_times,
                &mut solution,
            )?;
            while eval_times.peek().is_some() {
                let k_2 = function(&(t + 0.5 * dt), &(&k_1 * (0.5 * dt) + &y));
                let k_3 = function(&(t + 0.75 * dt), &(&k_2 * (0.75 * dt) + &y));
                // Third-order trial solution: y + dt * (2 k_1 + 3 k_2 + 4 k_3) / 9.
                let y_trial = (&k_1 * 2.0 + &k_2 * 3.0 + &k_3 * 4.0) * (dt / 9.0) + &y;
                let k_4 = function(&(t + dt), &y_trial);
                // Norm of (third-order - embedded second-order) solution:
                // dt * (-5 k_1 + 6 k_2 + 8 k_3 - 9 k_4) / 72.
                let e =
                    ((&k_1 * -5.0 + k_2 * 6.0 + k_3 * 8.0 + &k_4 * -9.0) * (dt / 72.0)).norm();
                if e < self.abs_tol || e / y_trial.norm() < self.rel_tol {
                    // Sample every evaluation time inside the accepted step
                    // [t, t + dt]. The bound must use the step's END because `t`
                    // has not been advanced yet. Testing `t > eval_time` would only
                    // catch times an earlier step already passed, and would
                    // extrapolate backward to them with this step's slope.
                    while let Some(eval_time) =
                        eval_times.next_if(|&eval_time| t + dt >= eval_time)
                    {
                        *y_sol.next().ok_or("not ok")? =
                            (&y_trial - &y) / dt * (eval_time - t) + &y;
                    }
                    k_1 = k_4;
                    t += dt;
                    dt *= self.inc_fac;
                    y = y_trial;
                } else {
                    dt *= self.dec_fac;
                }
            }
        }
        Ok(solution)
    }
}