ordinary-diffeq 0.2.3

A library for solving differential equations, based on the Julia library DifferentialEquations.jl.
Documentation
use nalgebra::SVector;

use super::super::ode::ODE;
use super::Integrator;

/// Butcher-tableau storage for Tsitouras-style Runge–Kutta integrators.
///
/// The tableau is exposed as three flat slices:
/// * `A` — the strictly lower-triangular stage matrix, stored row by row, so
///   the row for stage `i` (0-based) starts at index `i * (i - 1) / 2` and
///   holds `i` entries;
/// * `B` — the solution weights, immediately followed by the error-estimate
///   weights;
/// * `C` — the stage nodes, i.e. the fraction of the step at which each stage
///   is evaluated.
pub trait TsitIntegrator<'t> {
    /// Flattened stage coefficients `a_ij`.
    const A: &'t [f64];
    /// Solution weights followed by error-estimate weights.
    const B: &'t [f64];
    /// Stage nodes `c_i`.
    const C: &'t [f64];
}

/// Tsitouras 5(4) explicit Runge–Kutta integrator for a state of dimension `D`.
///
/// The two tolerances normalise the embedded error estimate that each step
/// reports.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Tsit5<const D: usize> {
    /// Absolute error tolerance, added uniformly to every component's scale.
    a_tol: f64,
    /// Relative error tolerance, applied proportionally to the state.
    r_tol: f64,
}

impl<const D: usize> Tsit5<D>
where
    Tsit5<D>: Integrator<D>,
{
    /// Creates a Tsit5 integrator from an absolute tolerance `a_tol` and a
    /// relative tolerance `r_tol`.
    pub fn new(a_tol: f64, r_tol: f64) -> Self {
        Self { a_tol, r_tol }
    }
}

/// Coefficient tables for [`Tsit5`]: Tsitouras' 5(4) Runge–Kutta pair.
impl<'a, const D: usize> TsitIntegrator<'a> for Tsit5<D> {
    // Stage matrix a_ij (strictly lower triangular), one row per stage after
    // the first, flattened row by row. Each row sums to the matching entry of C.
    const A: &'a [f64] = &[
        // stage 2
        0.161,
        // stage 3
        -0.008480655492356989,
        0.335480655492357,
        // stage 4
        2.8971530571054935,
        -6.359448489975075,
        4.3622954328695815,
        // stage 5
        5.325864828439257,
        -11.748883564062828,
        7.4955393428898365,
        -0.09249506636175525,
        // stage 6
        5.86145544294642,
        -12.92096931784711,
        8.159367898576159,
        -0.071584973281401,
        -0.028269050394068383,
        // stage 7: identical to the first six solution weights in B, so the
        // last stage is evaluated at the new solution (first-same-as-last).
        0.09646076681806523,
        0.01,
        0.4798896504144996,
        1.379008574103742,
        -3.290069515436081,
        2.324710524099774,
    ];
    // Seven solution weights b_i, followed by seven error-estimate weights.
    const B: &'a [f64] = &[
        // solution weights (the last stage has weight 0)
        0.09646076681806523,
        0.01,
        0.4798896504144996,
        1.379008574103742,
        -3.290069515436081,
        2.324710524099774,
        0.0,

        // error-estimate weights (they sum to zero up to rounding)
        -0.001780011052226,
        -0.000816434459657,
        0.007880878010262,
        -0.144711007173263,
        0.582357165452555,
        -0.458082105929187,
        1.0 / 66.0,
    ];
    // Stage nodes c_i: the fraction of the step at which each stage is evaluated.
    const C: &'a [f64] = &[
        0.0,
        0.161,
        0.327,
        0.9,
        0.9800255409045097,
        1.0,
        1.0,
    ];
}

impl<const D: usize> Integrator<D> for Tsit5<D>
where
    Tsit5<D>: TsitIntegrator,
{
    const ORDER: usize = 5;
    const STAGES: usize = 7;
    const ADAPTIVE: bool = true;
    const DENSE: bool = true;

    /// Advances `ode` by one step of size `h`.
    ///
    /// Returns `(y_next, error, dense)` where:
    /// * `y_next` is the fifth-order solution at `ode.t + h`;
    /// * `error` is the norm of the embedded error estimate, measured
    ///   componentwise against `a_tol + r_tol * max(|y_n|, |y_next|)`;
    /// * `dense` holds the seven stage derivatives `k_1..k_7` followed by the
    ///   starting state `y_n`. This is the layout `interpolate` expects.
    fn step<P>(
        &self,
        ode: &ODE<D, P>,
        h: f64,
    ) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>) {
        // Stage derivatives, plus one trailing slot for y_n (dense output).
        let mut k: Vec<SVector<f64, D>> = Vec::with_capacity(Self::STAGES + 1);
        let mut next_y = ode.y;
        let mut err = SVector::<f64, D>::zeros();
        for i in 0..Self::STAGES {
            // Row i of the strictly lower-triangular A is stored flat: it starts
            // at i*(i-1)/2 and has i entries (empty for the first stage).
            let row_start = i * i.saturating_sub(1) / 2;
            let a_row = &Self::A[row_start..row_start + i];
            let y_term = k
                .iter()
                .zip(a_row)
                .fold(SVector::<f64, D>::zeros(), |acc, (kj, &aij)| acc + kj * aij);
            let ki = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params);
            // B holds the solution weights followed by the error weights.
            next_y += ki * Self::B[i] * h;
            err += ki * Self::B[i + Self::STAGES] * h;
            k.push(ki);
        }
        // Use magnitudes so the scale stays positive for negative components.
        // Otherwise `a_tol + r_tol * y` can reach zero or go negative, and the
        // normalised error becomes meaningless or infinite.
        let scale = ode
            .y
            .zip_map(&next_y, |y0, y1| self.a_tol + self.r_tol * y0.abs().max(y1.abs()));
        // NOTE(review): `norm()` is the Euclidean norm, not an RMS norm. For
        // D > 1 the effective threshold therefore tightens with sqrt(D).
        let error = err.component_div(&scale).norm();
        // Keep y_n after the stages so `interpolate` can anchor the polynomial
        // at the start of the step. dense[i] is still k_{i+1} for i < STAGES.
        k.push(ode.y);
        (next_y, Some(error), Some(k))
    }

    /// Evaluates the Tsit5 dense-output polynomial at `t` for the step from
    /// `t_start` to `t_end`. `dense` must be the dense output `step` returned
    /// for that step, i.e. `[k_1, ..., k_7, y_n]`.
    ///
    /// Returns `y_n` at `t == t_start`, including the degenerate case
    /// `t_end == t_start`.
    ///
    /// # Panics
    ///
    /// Panics if `dense` has fewer than `STAGES + 1` entries.
    fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec<SVector<f64,D>>, t: f64) -> SVector<f64,D> {
        let y_start = dense[Self::STAGES];
        let stages = &dense[..Self::STAGES];
        let hn = t_end - t_start;
        if hn == 0.0 {
            // Zero-length step: avoid 0/0 in the normalised time below.
            return y_start;
        }
        let s = (t - t_start) / hn;
        // Interpolation weights b_i(s). At s = 1 they reduce to the solution
        // weights, and at s = 0 they all vanish.
        let weights = [
            -1.0530884977290216 * s * (s - 1.3299890189751412) * (s * s - 1.4364028541716351 * s + 0.7139816917074209),
            0.1017 * s * s * (s * s - 2.1966568338249754 * s + 1.2949852507374631),
            2.490627285651252793 * s * s * (s * s - 2.38535645472061657 * s + 1.57803468208092486),
            -16.54810288924490272 * (s - 1.21712927295533244) * (s - 0.61620406037800089) * s * s,
            47.37952196281928122 * (s - 1.203071208372362603) * (s - 0.658047292653547382) * s * s,
            -34.87065786149660974 * (s - 1.2) * (s - 0.666666666666666667) * s * s,
            2.5 * (s - 1.0) * (s - 0.6) * s * s,
        ];
        let increment = weights
            .iter()
            .zip(stages)
            .fold(SVector::<f64, D>::zeros(), |acc, (&w, ki)| acc + ki * w);
        y_start + increment * hn
    }
}