ordinary-diffeq 0.2.3

A library for solving differential equations, modeled on the DifferentialEquations.jl Julia library.
Documentation
use nalgebra::SVector;

use super::super::ode::ODE;
use super::Integrator;

/// Coefficient tables for an explicit, embedded Runge–Kutta method of the
/// Dormand–Prince family.
///
/// All tables are exposed as flat slices so a generic stepping routine can
/// index them by stage number without knowing the tableau shape at compile time.
pub trait DormandPrinceIntegrator<'a> {
    /// Stage coefficients `a_ij` of the strictly lower-triangular Butcher matrix,
    /// packed row after row: the row for stage `i` (with `i >= 1`) starts at
    /// index `i * (i - 1) / 2` and holds `i` entries.
    const A: &'a [f64];

    /// Two weight vectors stored back to back. The first `STAGES` entries advance
    /// the solution. The next `STAGES` entries are the embedded weights; their
    /// difference from the first set forms the local error estimate.
    const B: &'a [f64];

    /// Stage nodes `c_i`: the fraction of the step size at which stage `i` is evaluated.
    const C: &'a [f64];

    /// Per-stage coefficients used to build the highest-order dense-output term.
    const D: &'a [f64];
}

/// Adaptive Dormand–Prince 5(4) explicit Runge–Kutta integrator for a
/// `D`-dimensional state vector.
///
/// Step error is measured against a per-component absolute tolerance combined
/// with one relative tolerance shared by every component.
#[derive(Debug, Clone, Copy)]
pub struct DormandPrince45<const D: usize> {
    /// Absolute tolerance, one entry per state component.
    a_tol: SVector<f64, D>,
    /// Relative tolerance applied uniformly to all components.
    r_tol: f64,
}

impl<const D: usize> DormandPrince45<D>
where
    DormandPrince45<D>: Integrator<D>,
{
    /// Creates an integrator whose absolute and relative tolerances are both `1e-8`.
    pub fn new() -> Self {
        Self {
            a_tol: SVector::<f64, D>::from_element(1e-8),
            r_tol: 1e-8,
        }
    }

    /// Returns a copy of this integrator with every component's absolute
    /// tolerance set to `a_tol`.
    ///
    /// Builder-style: the receiver is taken by value (the type is `Copy`) and the
    /// updated integrator is returned. Discarding the result leaves nothing
    /// changed, hence `#[must_use]`.
    #[must_use]
    pub fn a_tol(self, a_tol: f64) -> Self {
        Self {
            a_tol: SVector::<f64, D>::from_element(a_tol),
            ..self
        }
    }

    /// Returns a copy of this integrator with a separate absolute tolerance for
    /// each state component.
    ///
    /// Builder-style. See [`DormandPrince45::a_tol`].
    #[must_use]
    pub fn a_tol_full(self, a_tol: SVector<f64, D>) -> Self {
        Self { a_tol, ..self }
    }

    /// Returns a copy of this integrator with the relative tolerance set to `r_tol`.
    ///
    /// Builder-style. See [`DormandPrince45::a_tol`].
    #[must_use]
    pub fn r_tol(self, r_tol: f64) -> Self {
        Self { r_tol, ..self }
    }
}

impl<'a, const D: usize> DormandPrinceIntegrator<'a> for DormandPrince45<D> {
    /// Lower-triangular stage matrix, one commented group per stage row.
    const A: &'a [f64] = &[
        // k[1], c = 1/5
        1.0 / 5.0,
        // k[2], c = 3/10
        3.0 / 40.0,
        9.0 / 40.0,
        // k[3], c = 4/5
        44.0 / 45.0,
        -56.0 / 15.0,
        32.0 / 9.0,
        // k[4], c = 8/9
        19_372.0 / 6_561.0,
        -25_360.0 / 2_187.0,
        64_448.0 / 6_561.0,
        -212.0 / 729.0,
        // k[5], c = 1
        9_017.0 / 3_168.0,
        -355.0 / 33.0,
        46_732.0 / 5_247.0,
        49.0 / 176.0,
        -5_103.0 / 18_656.0,
        // k[6], c = 1 — identical to the first six solution weights in `B`
        35.0 / 384.0,
        0.0,
        500.0 / 1_113.0,
        125.0 / 192.0,
        -2_187.0 / 6_784.0,
        11.0 / 84.0,
    ];

    /// Solution weights followed by embedded (error-estimate) weights, 7 each.
    const B: &'a [f64] = &[
        // Weights that advance the solution, for k[0]..k[6]
        35.0 / 384.0,
        0.0,
        500.0 / 1_113.0,
        125.0 / 192.0,
        -2_187.0 / 6_784.0,
        11.0 / 84.0,
        0.0,
        // Embedded weights, for k[0]..k[6]
        5_179.0 / 57_600.0,
        0.0,
        7_571.0 / 16_695.0,
        393.0 / 640.0,
        -92_097.0 / 339_200.0,
        187.0 / 2_100.0,
        1.0 / 40.0,
    ];

    /// Stage nodes for k[0]..k[6].
    const C: &'a [f64] = &[
        0.0,
        1.0 / 5.0,
        3.0 / 10.0,
        4.0 / 5.0,
        8.0 / 9.0,
        1.0,
        1.0,
    ];

    /// Dense-output coefficients for k[0]..k[6] (k[1] does not contribute).
    const D: &'a [f64] = &[
        -12_715_105_075.0 / 11_282_082_432.0,
        0.0,
        87_487_479_700.0 / 32_700_410_799.0,
        -10_690_763_975.0 / 1_880_347_072.0,
        701_980_252_875.0 / 199_316_789_632.0,
        -1_453_857_185.0 / 822_651_844.0,
        69_997_945.0 / 29_380_423.0,
    ];
}

impl<'a, const D: usize> Integrator<D> for DormandPrince45<D>
where
    DormandPrince45<D>: DormandPrinceIntegrator<'a>,
{
    const ORDER: usize = 5;
    const STAGES: usize = 7;
    const ADAPTIVE: bool = true;
    const DENSE: bool = true;

    /// Advances `ode` by one Dormand–Prince 5(4) step of size `h`.
    ///
    /// Returns a tuple of:
    /// * the state at `ode.t + h`, computed with the first `STAGES` weights of `B`;
    /// * `Some(err)`: the embedded error estimate divided component-wise by
    ///   `a_tol + r_tol * max(|y_n|, |y_{n+1}|)`, then reduced with the Euclidean norm;
    /// * `Some(dense)`: five coefficient vectors consumed by `interpolate`.
    fn step<P>(
        &self,
        ode: &ODE<D, P>,
        h: f64,
    ) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>) {
        let mut k: Vec<SVector<f64, D>> = vec![SVector::<f64, D>::zeros(); Self::STAGES];
        let mut next_y = ode.y;
        let mut err = SVector::<f64, D>::zeros();
        let mut rcont5 = SVector::<f64, D>::zeros();

        // Stage 0 is evaluated at the start of the step and needs no A row.
        k[0] = (ode.f)(ode.t, ode.y, &ode.params);
        next_y += k[0] * Self::B[0] * h;
        err += k[0] * (Self::B[0] - Self::B[Self::STAGES]) * h;
        let rcont1 = ode.y;
        rcont5 += k[0] * h * Self::D[0];

        for i in 1..Self::STAGES {
            // Row i of the packed lower-triangular A matrix starts at i*(i-1)/2.
            let mut y_term = SVector::<f64, D>::zeros();
            for (j, item) in k.iter().enumerate().take(i) {
                y_term += item * Self::A[(i * (i - 1)) / 2 + j];
            }
            k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params);

            // Accumulate the solution, the embedded-error difference and the
            // dense-output term from this stage.
            next_y += k[i] * h * Self::B[i];
            err += k[i] * (Self::B[i] - Self::B[i + Self::STAGES]) * h;
            rcont5 += k[i] * h * Self::D[i];
        }

        // Remaining dense-output coefficients. k[STAGES - 1] is evaluated at
        // (t + h, next_y) because the last A row equals the solution weights.
        let rcont2 = next_y - ode.y;
        let rcont3 = h * k[0] - rcont2;
        let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3;

        // Error scale: atol + rtol * max(|y_n|, |y_{n+1}|). Scaling by the signed
        // state let negative components shrink the tolerance to zero or below,
        // which produced inf/NaN or a meaningless error norm.
        let magnitude = ode.y.zip_map(&next_y, |y0: f64, y1: f64| y0.abs().max(y1.abs()));
        let tol = self.a_tol + magnitude * self.r_tol;
        // NOTE(review): this is the plain Euclidean norm, which for D > 1 is
        // stricter than the RMS norm used by Hairer's DOPRI5 and
        // DifferentialEquations.jl. Kept unchanged; confirm intent.
        let err_norm = err.component_div(&tol).norm();

        let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5];
        (next_y, Some(err_norm), Some(rcont))
    }

    /// Evaluates the dense-output polynomial at time `t` within a completed step.
    ///
    /// `dense` must hold the five coefficient vectors produced by `step` for the
    /// step spanning `[t_start, t_end]`; indices 0..=4 are read. With
    /// `s = (t - t_start) / (t_end - t_start)`, the result is the degree-4
    /// polynomial `d0 + s*(d1 + (1-s)*(d2 + s*(d3 + (1-s)*d4)))`.
    fn interpolate(
        &self,
        t_start: f64,
        t_end: f64,
        dense: &[SVector<f64, D>],
        t: f64,
    ) -> SVector<f64, D> {
        let s = (t - t_start) / (t_end - t_start);
        let s1 = 1.0 - s;
        dense[0] + (dense[1] + (dense[2] + (dense[3] + dense[4] * s1) * s) * s1) * s
    }
}