differential-equations 0.6.1

A Rust library for solving differential equations.
Documentation
//! Adams-Predictor-Corrector 4th Order Variable Step Size Method

use crate::{
    error::Error,
    interpolate::{Interpolation, cubic_hermite_interpolate},
    methods::{Adaptive, Ordinary, h_init::InitialStepSize},
    ode::{ODE, OrdinaryNumericalMethod},
    stats::Evals,
    status::Status,
    tolerance::Tolerance,
    traits::{Real, State},
    utils::{constrain_step_size, validate_step_size_parameters},
};

use super::AdamsPredictorCorrector;

impl<T: Real, Y: State<T>> AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4> {
    /// Creates an Adams-Predictor-Corrector 4th order variable step size method.
    ///
    /// The Adams-Predictor-Corrector method is an explicit multistep method:
    /// it uses previously computed states to predict the next state (4-step
    /// Adams-Bashforth) and then corrects that prediction (Adams-Moulton).
    /// This implementation varies the step size to keep the estimated local
    /// error below the requested tolerance. It is recommended to start with a
    /// small step size so that the tolerance can be met quickly and the
    /// algorithm can adjust the step size accordingly.
    ///
    /// The first 3 steps (and the history after every step size change) are
    /// computed with the classic 4th order Runge-Kutta method; the
    /// Adams-Predictor-Corrector method is then used for the remaining steps
    /// until the final time.
    ///
    /// # Example
    ///
    /// ```rust
    /// use differential_equations::prelude::*;
    ///
    /// struct HarmonicOscillator {
    ///     k: f64,
    /// }
    ///
    /// impl ODE<f64, [f64; 2]> for HarmonicOscillator {
    ///     fn diff(&self, _t: f64, y: &[f64; 2], dydt: &mut [f64; 2]) {
    ///         dydt[0] = y[1];
    ///         dydt[1] = -self.k * y[0];
    ///     }
    /// }
    /// let apcv4 = AdamsPredictorCorrector::v4();
    /// let t0 = 0.0;
    /// let tf = 10.0;
    /// let y0 = [1.0, 0.0];
    /// let system = HarmonicOscillator { k: 1.0 };
    /// let results = IVP::ode(&system, t0, tf, y0).method(apcv4).solve().unwrap();
    /// let expected = [-0.83907153, 0.54402111];
    /// assert!((results.y.last().unwrap()[0] - expected[0]).abs() < 1e-6);
    /// assert!((results.y.last().unwrap()[1] - expected[1]).abs() < 1e-6);
    /// ```
    ///
    /// # Warning
    ///
    /// This method is not suitable for stiff problems and can result in
    /// extremely small step sizes and long computation times.
    pub fn v4() -> Self {
        Self::default()
    }

    /// Advances `(t, y)` in place by one classic 4th order Runge-Kutta step of size `h`.
    ///
    /// `k` is scratch storage for the four stage derivatives; on return `k[0]`
    /// holds the derivative at the starting point `(t, y)`.
    ///
    /// Returns the number of function evaluations performed (always 4).
    fn rk4_step<F>(ode: &F, t: &mut T, y: &mut Y, h: T, k: &mut [Y; 4]) -> usize
    where
        F: ODE<T, Y> + ?Sized,
    {
        let two = T::from_f64(2.0).unwrap();
        let six = T::from_f64(6.0).unwrap();
        let half = h / two;
        // Classic RK4 weights: h/6 for the outer stages, 2h/6 for the inner ones.
        let w_outer = h / six;
        let w_inner = two * h / six;

        ode.diff(*t, y, &mut k[0]);
        ode.diff(*t + half, &y.plus_scaled(half, &k[0]), &mut k[1]);
        ode.diff(*t + half, &y.plus_scaled(half, &k[1]), &mut k[2]);
        ode.diff(*t + h, &y.plus_scaled(h, &k[2]), &mut k[3]);

        y.add_scaled(w_outer, &k[0]);
        y.add_scaled(w_inner, &k[1]);
        y.add_scaled(w_inner, &k[2]);
        y.add_scaled(w_outer, &k[3]);
        *t += h;
        4
    }
}

// Implement OrdinaryNumericalMethod Trait for APCV4
impl<T: Real, Y: State<T>> OrdinaryNumericalMethod<T, Y>
    for AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4>
{
    /// Prepares the solver for integrating from `t0` towards `tf` starting at `y0`.
    ///
    /// When no initial step size was configured (`h0 == 0`) one is estimated with
    /// `InitialStepSize`. The four-point history required by the multistep
    /// formulas is then bootstrapped with three RK4 steps, so on return `t()` is
    /// `t0 + 3h`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `validate_step_size_parameters` when the
    /// initial step size is not usable for `[t0, tf]`.
    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y> + ?Sized,
    {
        let mut evals = Evals::new();

        self.tf = tf;

        // If h0 is zero, calculate an initial step size.
        if self.h0 == T::zero() {
            let tol = Tolerance::Scalar(self.tol);
            self.h0 = InitialStepSize::<Ordinary>::compute(
                ode, t0, tf, y0, 4, &tol, &tol, self.h_min, self.h_max, &mut evals,
            );
            // NOTE(review): `compute` also receives `evals`; confirm it does not
            // already record these two evaluations, otherwise they are double counted.
            evals.function += 2;
        }

        // Check that the initial step size is valid for the integration interval.
        let h0 =
            validate_step_size_parameters::<T, Y>(self.h0, T::zero(), T::infinity(), t0, tf)?;
        self.h = (self.filter)(h0);

        // Initialize state and scratch buffers.
        self.t = t0;
        self.y = y0.clone();
        self.dydt = y0.zeros_like();
        self.dydt_old = y0.zeros_like();
        self.y_prev = core::array::from_fn(|_| y0.zeros_like());
        self.k = core::array::from_fn(|_| y0.zeros_like());
        self.t_prev[0] = t0;
        self.y_prev[0] = y0.clone();

        // The previous saved point is the initial condition.
        self.t_old = t0;
        self.y_old = y0.clone();

        // Bootstrap the history with three RK4 steps of size h.
        for i in 1..=3 {
            evals.function += Self::rk4_step(ode, &mut self.t, &mut self.y, self.h, &mut self.k);
            self.t_prev[i] = self.t;
            self.y_prev[i] = self.y.clone();

            if i == 1 {
                // k[0] of the first RK4 step is the derivative at (t0, y0).
                self.dydt_old = self.k[0].clone();
            }
        }

        // Slope at the current point (t0 + 3h), so Hermite interpolation on
        // [t_old, t] uses the derivative at t rather than at t0.
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.function += 1;

        self.status = Status::Initialized;
        Ok(evals)
    }

    /// Attempts one Adams-Bashforth-Moulton predictor-corrector step of size `h`.
    ///
    /// The local error is estimated from the predictor/corrector difference. If it
    /// is within tolerance the state advances by `h` and `h` may grow (at most 4x);
    /// otherwise the step is rejected and `h` shrinks (at most 10x). In both cases
    /// the new step size is bounded so that the rebuilt history plus the next step
    /// stay within `tf`. The equally spaced history is then rebuilt with three RK4
    /// steps of the new size, which also advances `t`.
    ///
    /// When the step lands exactly on `tf`, a single RK4 step is taken instead.
    ///
    /// # Errors
    ///
    /// Returns `Error::MaxSteps` once `max_steps` steps have been taken.
    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
    where
        F: ODE<T, Y> + ?Sized,
    {
        let mut evals = Evals::new();

        // Check if max steps reached
        if self.steps >= self.max_steps {
            self.status = Status::Error(Error::MaxSteps {
                t: self.t,
                y: self.y.clone(),
            });
            return Err(Error::MaxSteps {
                t: self.t,
                y: self.y.clone(),
            });
        }
        self.steps += 1;

        // If this step lands exactly on the final time, finish with a single RK4
        // step. Typically this happens because the driver shortened `h` via
        // `set_h`, so it no longer matches the spacing of the stored history.
        if self.t + self.h == self.tf {
            self.t_old = self.t;
            self.y_old = self.y.clone();
            evals.function += Self::rk4_step(ode, &mut self.t, &mut self.y, self.h, &mut self.k);
            // k[0] of the RK4 step is the derivative at (t_old, y_old).
            self.dydt_old = self.k[0].clone();
            ode.diff(self.t, &self.y, &mut self.dydt);
            evals.function += 1;

            if matches!(self.status, Status::RejectedStep) {
                self.status = Status::Solving;
            }
            return Ok(evals);
        }

        // Adams weight h * c / 24 for a coefficient c.
        let h = self.h;
        let c24 = T::from_f64(24.0).unwrap();
        let w = move |c: f64| h * T::from_f64(c).unwrap() / c24;

        // Derivatives at the history points: k[0] = f_n, k[1] = f_{n-1},
        // k[2] = f_{n-2}, k[3] = f_{n-3}.
        ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k[0]);
        ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k[1]);
        ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k[2]);
        ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k[3]);

        // Predictor: 4-step Adams-Bashforth.
        let predictor = self.y_prev[3].plus_linear_combination(&[
            (&self.k[0], w(55.0)),
            (&self.k[1], -w(59.0)),
            (&self.k[2], w(37.0)),
            (&self.k[3], -w(9.0)),
        ]);

        // Corrector: Adams-Moulton using f(t + h, predictor). k[3] (f_{n-3}) is
        // no longer needed, so it is reused for the predicted derivative.
        ode.diff(self.t + self.h, &predictor, &mut self.k[3]);
        let corrector = self.y_prev[3].plus_linear_combination(&[
            (&self.k[3], w(9.0)),
            (&self.k[0], w(19.0)),
            (&self.k[1], -w(5.0)),
            (&self.k[2], w(1.0)),
        ]);

        // Four history derivatives plus the predicted derivative.
        evals.function += 5;

        // Local error estimate per unit step from the predictor/corrector difference.
        let sigma = T::from_f64(19.0).unwrap() * corrector.diff_norm_squared(&predictor).sqrt()
            / (T::from_f64(270.0).unwrap() * self.h.abs());

        let two = T::from_f64(2.0).unwrap();
        let four = T::from_f64(4.0).unwrap();
        let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());

        if sigma <= self.tol {
            // Step accepted: keep the previous point for interpolation.
            self.t_old = self.t;
            self.y_old = self.y.clone();
            self.dydt_old = self.dydt.clone();

            self.t += self.h;
            self.y = corrector;

            if matches!(self.status, Status::RejectedStep) {
                self.status = Status::Solving;
            }

            // Grow the step size by at most a factor of 4.
            self.h = if q > four { four * self.h } else { q * self.h };
        } else {
            // Step rejected: shrink the step size by at most a factor of 10.
            self.status = Status::RejectedStep;
            let tenth = T::from_f64(0.1).unwrap();
            self.h = if q < tenth { tenth * self.h } else { q * self.h };
        }

        // Bound the new step size by h_max and by a quarter of the remaining
        // interval. The three rebuild steps plus the next step span 4h, so the
        // rebuild does not run past tf (assuming h_min does not exceed that bound).
        let quarter_remaining = (self.tf - self.t).abs() / four;
        let h_max_effective = if self.h_max < quarter_remaining {
            self.h_max
        } else {
            quarter_remaining
        };
        self.h = constrain_step_size(self.h, self.h_min, h_max_effective);
        self.h = (self.filter)(self.h);

        // Rebuild the equally spaced history with the new step size.
        self.t_prev[0] = self.t;
        self.y_prev[0] = self.y.clone();
        for i in 1..=3 {
            evals.function += Self::rk4_step(ode, &mut self.t, &mut self.y, self.h, &mut self.k);
            self.t_prev[i] = self.t;
            self.y_prev[i] = self.y.clone();
        }

        // Slope at the new current point, used as the right-hand Hermite slope.
        ode.diff(self.t, &self.y, &mut self.dydt);
        evals.function += 1;

        Ok(evals)
    }

    /// Current time.
    fn t(&self) -> T {
        self.t
    }

    /// Current state.
    fn y(&self) -> &Y {
        &self.y
    }

    /// Time of the previous saved point (left end of the interpolation interval).
    fn t_prev(&self) -> T {
        self.t_old
    }

    /// State at the previous saved point.
    fn y_prev(&self) -> &Y {
        &self.y_old
    }

    /// Effective step size reported to the driver.
    fn h(&self) -> T {
        // Each call to `step` advances `t` by one Adams step plus three RK4
        // history steps, so the driver is given four times the step size as the
        // distance the next call will cover.
        self.h * T::from_f64(4.0).unwrap()
    }

    /// Sets the step size, passed through the configured step size filter.
    fn set_h(&mut self, h: T) {
        self.h = (self.filter)(h);
    }

    /// Current solver status.
    fn status(&self) -> &Status<T, Y> {
        &self.status
    }

    /// Overrides the solver status.
    fn set_status(&mut self, status: Status<T, Y>) {
        self.status = status;
    }
}

// Implement the Interpolation trait for APCV4
impl<T: Real, Y: State<T>> Interpolation<T, Y>
    for AdamsPredictorCorrector<Ordinary, Adaptive, T, Y, 4>
{
    /// Evaluates the solution at `t_interp` on the interval between the previous
    /// saved point (`t_old`) and the current point (`t`). It uses cubic Hermite
    /// interpolation of the stored values and slopes at both ends.
    ///
    /// The bounds check does not assume `t_old <= t`, so it also accepts
    /// requests when integrating backwards.
    /// NOTE(review): this relies on `cubic_hermite_interpolate` accepting a
    /// right endpoint smaller than the left one — confirm.
    ///
    /// # Errors
    ///
    /// Returns `Error::OutOfBounds` if `t_interp` lies outside that interval.
    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
        // Order the interval endpoints independently of the integration direction.
        let (lo, hi) = if self.t_old <= self.t {
            (self.t_old, self.t)
        } else {
            (self.t, self.t_old)
        };

        if t_interp < lo || t_interp > hi {
            return Err(Error::OutOfBounds {
                t_interp,
                t_prev: self.t_old,
                t_curr: self.t,
            });
        }

        // Calculate the interpolated value using cubic Hermite interpolation
        let y_interp = cubic_hermite_interpolate(
            self.t_old,
            self.t,
            &self.y_old,
            &self.y,
            &self.dydt_old,
            &self.dydt,
            t_interp,
        );

        Ok(y_interp)
    }
}