ordinary-diffeq 0.2.3

A library for solving differential equations, based on the DifferentialEquations.jl Julia library.
Documentation
use nalgebra::SVector;

/// Borrowed right-hand side of an ODE system.
///
/// It is invoked as `f(t, y, &params)`: `t` is the scalar time, `y` is a state vector of
/// length `N`, and `params` is a reference to user data of type `P`. It returns a vector of
/// the same length (presumably the time derivative `dy/dt`; the solver defines the meaning).
type ProblemFunction<'f, const N: usize, P> =
    &'f dyn Fn(f64, SVector<f64, N>, &P) -> SVector<f64, N>;

/// An ordinary differential equation problem together with its current integration state.
///
/// The state dimension `D` and the parameter type `P` are generic and are fixed when the
/// object is created (normally through [`ODE::new`]). The right-hand side `f` is borrowed
/// for the lifetime `'a`. The derived `Clone`/`Copy` impls apply whenever `P` is
/// `Clone`/`Copy`.
#[derive(Clone, Copy)]
pub struct ODE<'a, const D: usize, P> {
    /// Right-hand side of the system, called as `f(t, y, &params)`.
    pub f: ProblemFunction<'a, D, P>,
    /// Current state; [`ODE::new`] sets it to the initial state `y0`.
    pub y: SVector<f64, D>,
    /// Current time; [`ODE::new`] sets it to `t0`.
    pub t: f64,
    /// User parameters; `f` receives a reference to a `P` as its third argument.
    pub params: P,
    /// Initial time passed to [`ODE::new`].
    pub t0: f64,
    /// Final time passed to [`ODE::new`] (presumably where the solver stops — confirm).
    pub t_end: f64,
    /// Step size; [`ODE::new`] initialises it to `0.001`.
    /// NOTE(review): presumably adjusted by the solver during integration — confirm.
    pub h: f64,
    /// Completion flag; `false` after [`ODE::new`].
    /// NOTE(review): presumably set by the solver once `t` reaches `t_end` — confirm.
    pub finished: bool,
}

impl<'a, const D: usize, P> ODE<'a, D, P> {
    /// Creates a problem for the system `f` on the time span from `t0` to `t_end`.
    ///
    /// The returned value starts with current time `t = t0`, current state `y = y0`,
    /// step size `h = 0.001` and `finished = false`. `params` is stored as-is; `f` expects
    /// a reference to it as its third argument.
    pub fn new(
        f: ProblemFunction<'a, D, P>,
        t0: f64,
        t_end: f64,
        y0: SVector<f64, D>,
        params: P,
    ) -> Self {
        // Fixed starting step size for every freshly created problem.
        let initial_step = 0.001;

        Self {
            t0,
            t_end,
            t: t0,
            y: y0,
            f,
            params,
            h: initial_step,
            finished: false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::Vector3;

    /// `ODE::new` stores the right-hand side and initialises every state field.
    #[test]
    fn test_ode_creation() {
        type Params = ();
        fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
            -y
        }

        let y0 = Vector3::new(1.0, 0.0, 0.0);
        let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());

        // The stored right-hand side must be the function that was passed in.
        assert_eq!((ode.f)(0.0, y0, &()), Vector3::new(-1.0, 0.0, 0.0));
        // Integration starts at (t0, y0) with the default step size and is not finished.
        assert_eq!(ode.y, Vector3::new(1.0, 0.0, 0.0));
        assert_eq!(ode.t, 0.0);
        assert_eq!(ode.t0, 0.0);
        assert_eq!(ode.t_end, 10.0);
        assert_eq!(ode.h, 0.001);
        assert!(!ode.finished);
    }

    /// Parameters are stored unchanged and are forwarded to `f` on each call.
    #[test]
    fn test_ode_with_params() {
        type Params = (f64, bool);
        let params = (34.0, true);

        fn derivative(t: f64, y: Vector3<f64>, p: &Params) -> Vector3<f64> {
            if p.1 {
                -y
            } else {
                y * t
            }
        }

        let y0 = Vector3::new(1.0, 0.0, 0.0);
        let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);

        assert_eq!(ode.params, params);
        assert_eq!((ode.f)(0.0, y0, &params), Vector3::new(-1.0, 0.0, 0.0));
        // `f` uses the parameters given at call time, so the other branch is reachable too.
        assert_eq!(
            (ode.f)(2.0, y0, &(params.0, false)),
            Vector3::new(2.0, 0.0, 0.0)
        );
        assert_eq!(ode.y, Vector3::new(1.0, 0.0, 0.0));
        assert_eq!(ode.t, 0.0);
        assert_eq!(ode.t0, 0.0);
        assert_eq!(ode.t_end, 10.0);
        assert_eq!(ode.h, 0.001);
        assert!(!ode.finished);
    }
}