ordinary-diffeq 0.2.3

A library for solving differential equations, based on the DifferentialEquations.jl Julia library.
Documentation
use nalgebra::SVector;
use roots::{find_root_brent, SimpleConvergency};

use super::callback::Callback;
use super::controller::{Controller, PIController, TryStep};
use super::integrator::Integrator;
use super::ode::ODE;

/// An ODE bundled with the integrator, step-size controller and callbacks
/// used to solve it.
///
/// `D` is the dimension of the state vector, `S` the integrator type and `P`
/// the type of the parameters carried by the ODE.
#[derive(Clone)]
pub struct Problem<'a, const D: usize, S, P>
where
    S: Integrator<D>,
{
    /// The system being solved. Its `t` and `y` fields are advanced in place
    /// while solving, and its `t_end` field bounds the integration.
    ode: ODE<'a, D, P>,
    /// Produces each trial step (new state, optional error estimate and
    /// optional dense-output coefficients).
    integrator: S,
    /// Holds the current step-size guess and decides whether a trial step is
    /// accepted when the integrator is adaptive.
    controller: PIController,
    /// Event/effect pairs checked after every step, in insertion order.
    callbacks: Vec<Callback<'a, D, P>>,
}

impl<'a, const D: usize, S, P> Problem<'a, D, S, P>
where
    S: Integrator<D> + Copy,
{
    /// Creates a problem from an ODE, the integrator that steps it and the
    /// controller that chooses the step size.
    ///
    /// The new problem has no callbacks; attach them with
    /// [`Problem::with_callback`].
    pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self {
        Problem {
            ode,
            integrator,
            controller,
            callbacks: Vec::new(),
        }
    }

    /// Integrates the ODE from its current time until `t` is no longer below
    /// `t_end`, recording the time, state and dense-output coefficients of
    /// every step taken.
    ///
    /// The ODE stored in the problem is advanced in place, so after this call
    /// its `t` and `y` hold the final time and state.
    ///
    /// For adaptive integrators (`S::ADAPTIVE`) each trial step is handed to
    /// the controller together with its error estimate and is retried with
    /// the controller's proposed size until the controller accepts it.
    /// Otherwise a single step of the current step guess is taken.
    ///
    /// After each step every callback is checked: when its event function
    /// changes sign across the step, Brent's method locates the step length
    /// at which it crosses zero, the step is retaken with that length and the
    /// callback's effect is applied to the ODE.
    ///
    /// # Panics
    ///
    /// Panics if an adaptive integrator returns no error estimate, if the
    /// root search for an event fails, or if the integrator returns no
    /// dense-output coefficients for a step.
    pub fn solve(&mut self) -> Solution<S, D> {
        let mut convergency = SimpleConvergency {
            eps: 1e-12,
            max_iter: 1000,
        };
        let mut times: Vec<f64> = vec![self.ode.t];
        let mut states: Vec<SVector<f64, D>> = vec![self.ode.y];
        let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new();
        while self.ode.t < self.ode.t_end {
            if self.ode.t + self.controller.next_step_guess.extract() > self.ode.t_end {
                // If the next step would go past the end, shorten it so that it
                // lands on the end instead.
                self.controller.next_step_guess =
                    TryStep::NotYetAccepted(self.ode.t_end - self.ode.t);
            }
            let (mut new_y, mut curr_step, mut dense_option) = if S::ADAPTIVE {
                // First, try stepping with the current guess to get an error estimate.
                let initial_guess = self.controller.next_step_guess.extract();
                let (mut trial_y, mut err_option, mut trial_dense) =
                    self.integrator.step(&self.ode, initial_guess);
                let mut err =
                    err_option.expect("adaptive integrator did not return an error estimate");
                // Then let the controller decide whether the step is acceptable.
                let mut next_step_guess = <PIController as Controller<D>>::determine_step(
                    &mut self.controller,
                    initial_guess,
                    err,
                );
                while !next_step_guess.is_accepted() {
                    // The step was rejected: retry with the size the controller proposed.
                    (trial_y, err_option, trial_dense) =
                        self.integrator.step(&self.ode, next_step_guess.extract());
                    // Refresh the error BEFORE consulting the controller, so the
                    // decision is based on the step that was just attempted rather
                    // than on the previous, already rejected one.
                    err = err_option
                        .expect("adaptive integrator did not return an error estimate");
                    next_step_guess = <PIController as Controller<D>>::determine_step(
                        &mut self.controller,
                        next_step_guess.extract(),
                        err,
                    );
                }
                // At this point the step has been accepted by the controller.
                self.controller.next_step_guess = next_step_guess.reset().unwrap();
                (trial_y, next_step_guess.extract(), trial_dense)
            } else {
                // Fixed time step: just step forward once with the current guess.
                let fixed_step = self.controller.next_step_guess.extract();
                let (trial_y, _, trial_dense) = self.integrator.step(&self.ode, fixed_step);
                (trial_y, self.controller.next_step_guess.extract(), trial_dense)
            };
            // Check every callback for an event occurring within this step.
            for callback in &self.callbacks {
                let crossed_zero = (callback.event)(self.ode.t, self.ode.y, &self.ode.params)
                    * (callback.event)(self.ode.t + curr_step, new_y, &self.ode.params)
                    < 0.0;
                if crossed_zero {
                    // The event changed sign: find the step length at which it is zero.
                    let f = |test_t| {
                        let test_y = self.integrator.step(&self.ode, test_t).0;
                        (callback.event)(self.ode.t + test_t, test_y, &self.ode.params)
                    };
                    let root = find_root_brent(0.0, curr_step, &f, &mut convergency)
                        .expect("root search for a callback event failed");
                    // Retake the step so that it ends exactly on the event.
                    curr_step = root;
                    (new_y, _, dense_option) = self.integrator.step(&self.ode, curr_step);
                    (callback.effect)(&mut self.ode);
                }
            }
            self.ode.y = new_y;
            self.ode.t += curr_step;
            times.push(self.ode.t);
            states.push(self.ode.y);
            // TODO: Implement third order interpolation for non-dense algorithms
            dense_coefficients.push(
                dense_option.expect("integrator did not return dense output coefficients"),
            );
        }
        Solution {
            integrator: self.integrator,
            times,
            states,
            dense: dense_coefficients,
        }
    }

    /// Returns the problem with `callback` appended to its callback list.
    ///
    /// Callbacks are checked after each step in the order they were added.
    pub fn with_callback(mut self, callback: Callback<'a, D, P>) -> Self {
        self.callbacks.push(callback);
        self
    }
}

/// The recorded result of solving a problem: the visited times and states
/// plus the data needed to interpolate between them.
pub struct Solution<S, const D: usize>
where
    S: Integrator<D>,
{
    /// The integrator, kept so that its `interpolate` method can be used on
    /// the stored dense-output coefficients.
    pub integrator: S,
    /// Times at which a state was recorded, starting with the initial time.
    pub times: Vec<f64>,
    /// State recorded at each entry of `times` (same length, same order).
    pub states: Vec<SVector<f64, D>>,
    /// Dense-output coefficients, one entry per step: entry `i` is used for
    /// interpolation between `times[i]` and `times[i + 1]`.
    pub dense: Vec<Vec<SVector<f64, D>>>,
}

impl<S, const D: usize> Solution<S, D>
where
    S: Integrator<D>,
{
    /// Evaluates the solution at time `t`.
    ///
    /// If `t` coincides with a recorded time the stored state is returned
    /// directly; otherwise the integrator interpolates within the step that
    /// contains `t`, using that step's dense-output coefficients.
    ///
    /// # Panics
    ///
    /// Panics if the solution holds no times, or if `t` is NaN or lies
    /// outside the range spanned by the first and last recorded time (which
    /// includes every `t` when the first time is greater than the last).
    pub fn interpolate(&self, t: f64) -> SVector<f64, D> {
        // TODO: Improve these errors
        let (first, last) = match (self.times.first(), self.times.last()) {
            (Some(&first), Some(&last)) => (first, last),
            _ => panic!("cannot interpolate a solution with no recorded times"),
        };
        // `contains` is false for NaN and for every `t` when `first > last`,
        // so both cases are rejected here with a clear message.
        if !(first..=last).contains(&t) {
            panic!(
                "interpolation time {} is outside the solved range [{}, {}]",
                t, first, last
            );
        }

        // Find the recorded time equal to `t`, or the step that brackets it.
        // The times are searched in place; no copy is needed.
        match self.times.binary_search_by(|x| x.total_cmp(&t)) {
            Ok(index) => self.states[index],
            // `total_cmp` orders -0.0 before +0.0 although they compare equal
            // with `<`/`>`. For sorted times, an insertion point at either end
            // can therefore only mean `t` equals that endpoint up to the sign
            // of zero, so the endpoint state is the answer.
            Err(0) => self.states[0],
            Err(end_index) if end_index == self.times.len() => self.states[end_index - 1],
            Err(end_index) => {
                // `t` lies strictly inside the step ending at `end_index`;
                // hand that step's coefficients to the integrator.
                let t_start = self.times[end_index - 1];
                let t_end = self.times[end_index];
                self.integrator
                    .interpolate(t_start, t_end, &self.dense[end_index - 1], t)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::callback::stop;
    use crate::controller::PIController;
    use crate::integrator::dormand_prince::DormandPrince45;
    use approx::assert_relative_eq;
    use nalgebra::Vector3;

    /// The test system carries no parameters.
    type Params = ();

    /// Right-hand side of `y' = y`; starting from one, each component of the
    /// exact solution is `exp(t)`.
    fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
        y
    }

    /// Initial state shared by all tests: every component starts at one.
    fn initial_state() -> Vector3<f64> {
        Vector3::new(1.0, 1.0, 1.0)
    }

    /// Every recorded state must track `exp(t)` at its recorded time.
    #[test]
    fn test_problem() {
        let ode = ODE::new(&derivative, 0.0, 1.0, initial_state(), ());
        let stepper = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);
        let mut problem = Problem::new(ode, stepper, PIController::default());

        let solution = problem.solve();
        for (time, state) in solution.times.iter().zip(solution.states.iter()) {
            assert_relative_eq!(state[0], time.exp(), max_relative = 1e-2);
        }
    }

    /// A `stop` callback triggered when the first component reaches ten must
    /// leave the final state on that threshold.
    #[test]
    fn test_with_callback() {
        let ode = ODE::new(&derivative, 0.0, 10.0, initial_state(), ());
        let stepper = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);

        let value_too_high = Callback {
            event: &|_: f64, y: SVector<f64, 3>, _: &Params| 10.0 - y[0],
            effect: &stop,
        };

        let mut problem =
            Problem::new(ode, stepper, PIController::default()).with_callback(value_too_high);
        let solution = problem.solve();

        let final_state = solution.states.last().unwrap();
        assert_relative_eq!(final_state[0], 10.0, max_relative = 1e-11);
    }

    /// Interpolating between recorded steps must still match `exp(t)`.
    #[test]
    fn test_with_interpolation() {
        let ode = ODE::new(&derivative, 0.0, 10.0, initial_state(), ());
        let stepper = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6);
        let mut problem = Problem::new(ode, stepper, PIController::default());

        let solution = problem.solve();
        let interpolated = solution.interpolate(8.8);

        assert_relative_eq!(interpolated[0], 8.8_f64.exp(), max_relative = 1e-6);
    }
}