automatica 1.0.0

Automatic control systems library
Documentation
//! # Ordinary differential equations solvers
//!
//! `Rk2` is an explicit Runge-Kutta method of order 2 with 2 stages; it is
//! suitable for non-stiff systems.
//!
//! `Rk4` is an explicit Runge-Kutta method of order 4 with 4 stages; it is
//! suitable for non-stiff systems.
//!
//! `Rkf45` is an explicit Runge-Kutta-Fehlberg method of orders 4 and 5 with
//! 6 stages and an adaptive integration step; it is suitable for non-stiff
//! systems.
//!
//! `Radau` is an implicit Runge-Kutta-Radau method of order 3 with 2 stages;
//! it is suitable for stiff systems.

use std::ops::{Add, AddAssign, Div, Mul, Sub};
use std::{fmt::Debug, marker::Sized};

use ndarray::{s, Array1, Array2};

use crate::{
    linear_system::continuous::Ss,
    matrix::{identity, lup, max_vabs, relative_eq, Lup, MatrixMul, ScalarMul},
    units::Seconds,
    Abs, Max, NumCast, One, Pow, RelativeEq, Zero,
};

/// Selects which explicit Runge-Kutta scheme the `Rk` solver applies.
#[derive(Debug, Clone)]
pub(super) enum Order {
    /// Second-order scheme with two stages.
    Rk2,
    /// Classic fourth-order scheme with four stages.
    Rk4,
}

/// Fixed-step explicit Runge-Kutta solver for the time evolution of a linear
/// system.
///
/// Used as an iterator: it first yields the initial condition and then one
/// step of length `h` per call, `n` steps in total. The scheme (order 2 or 4)
/// is selected by `order`.
///
/// The bound on `F` is intentionally not placed on the struct; the impl blocks
/// that need to call the input function declare it.
#[derive(Clone, Debug)]
pub struct Rk<'a, T, F> {
    /// Linear system being integrated.
    sys: &'a Ss<T>,
    /// Input function, evaluated at the times required by the scheme.
    input: F,
    /// State vector.
    state: Array1<T>,
    /// Output vector.
    output: Array1<T>,
    /// Integration interval (fixed step size).
    h: Seconds<T>,
    /// Number of integration steps after the initial condition.
    n: usize,
    /// Index of the next item to produce (0 is the initial condition).
    index: usize,
    /// Order of the solver.
    order: Order,
}

impl<'a, T, F> Rk<'a, T, F>
where
    T: Add<Output = T>
        + AddAssign
        + Clone
        + Div<Output = T>
        + Mul<Output = T>
        + NumCast
        + RkConst
        + Zero,
    F: Fn(Seconds<T>) -> Vec<T>,
{
    /// Create the solver for a Runge-Kutta method.
    ///
    /// The output at time 0 is computed here as `C x0 + D u(0)`, so that the
    /// first item produced by the solver is the initial condition.
    ///
    /// # Arguments
    ///
    /// * `sys` - linear system
    /// * `u` - input function that returns a vector (column vector)
    /// * `x0` - initial state (column vector)
    /// * `h` - integration time interval
    /// * `n` - integration steps (produced after the initial condition)
    /// * `order` - order of the solver
    pub(super) fn new(
        sys: &'a Ss<T>,
        u: F,
        x0: &[T],
        h: Seconds<T>,
        n: usize,
        order: Order,
    ) -> Self {
        let start = u(Seconds::zero());
        let state = Array1::from(x0.to_vec());
        // Output at time 0: y = C x0 + D u(0).
        let output = sys.c.mmul(&state) + sys.d.mmul(start.as_slice());
        Self {
            sys,
            input: u,
            state,
            output,
            h,
            n,
            index: 0,
            order,
        }
    }

    /// Initial step (time 0) of the Runge-Kutta solver.
    /// It contains the initial state and the calculated initial output
    /// at the constructor.
    ///
    /// Increments `index`, so the following calls perform integration steps.
    fn initial_step(&mut self) -> Step<T> {
        self.index += 1;
        // State and output at time 0.
        Step {
            time: Seconds::zero(),
            state: self.state.to_vec(),
            output: self.output.to_vec(),
        }
    }

    /// Runge-Kutta order 2 method.
    ///
    /// Advances the state from `(index - 1) * h` to `index * h`, updates the
    /// output with the input at the end of the step and increments `index`.
    /// Returns `None` if `index` cannot be converted to `T`.
    #[allow(clippy::cast_precision_loss)]
    fn main_iteration_rk2(&mut self) -> Option<Step<T>> {
        // y_n+1 = y_n + 1/2(k1 + k2) + O(h^3)
        // k1 = h*f(t_n, y_n)
        // k2 = h*f(t_n + h, y_n + k1)
        // with f(t, x) = A x + B u(t).
        // Return None if conversion fails.
        let init_time = &self.h * T::from(self.index - 1)?;
        let end_time = &self.h * T::from(self.index)?;
        let u = (self.input)(init_time);
        let uh = (self.input)(end_time.clone());
        let bu = self.sys.b.mmul(u.as_slice());
        let buh = self.sys.b.mmul(uh.as_slice());
        let k1 = (self.sys.a.mmul(&self.state) + bu).smul(self.h.0.clone());
        let k2 = (self.sys.a.mmul(&self.state + &k1) + buh).smul(self.h.0.clone());
        self.state += &(k1 + k2).smul(T::_05());
        // Output at the end of the step: y = C x + D u(t_n + h).
        self.output = self.sys.c.mmul(&self.state) + self.sys.d.mmul(uh.as_slice());

        self.index += 1;
        Some(Step {
            time: end_time,
            state: self.state.to_vec(),
            output: self.output.to_vec(),
        })
    }

    /// Runge-Kutta order 4 method.
    ///
    /// Advances the state from `(index - 1) * h` to `index * h`, updates the
    /// output with the input at the end of the step and increments `index`.
    /// Returns `None` if `index` cannot be converted to `T`.
    #[allow(clippy::cast_precision_loss, clippy::similar_names)]
    fn main_iteration_rk4(&mut self) -> Option<Step<T>> {
        // y_n+1 = y_n + h/6(k1 + 2*k2 + 2*k3 + k4) + O(h^5)
        // k1 = f(t_n, y_n)
        // k2 = f(t_n + h/2, y_n + h/2 * k1)
        // k3 = f(t_n + h/2, y_n + h/2 * k2)
        // k4 = f(t_n + h, y_n + h*k3)
        // with f(t, x) = A x + B u(t).
        // Return None if conversion fails
        let init_time = &self.h * T::from(self.index - 1)?;
        let mid_step = T::_05() * self.h.0.clone();
        let mid_time = Seconds(init_time.0.clone() + mid_step.clone());
        let end_time = &self.h * T::from(self.index)?;
        let u = (self.input)(init_time);
        let u_mid = (self.input)(mid_time);
        let u_end = (self.input)(end_time.clone());
        let bu = self.sys.b.mmul(u.as_slice());
        let bu_mid = self.sys.b.mmul(u_mid.as_slice());
        let bu_end = self.sys.b.mmul(u_end.as_slice());
        let k1 = self.sys.a.mmul(&self.state) + bu;
        // k2 and k3 share the midpoint input, hence the borrow for k2.
        let k2 = self.sys.a.mmul((&k1).smul(mid_step.clone()) + &self.state) + &bu_mid;
        let k3 = self.sys.a.mmul((&k2).smul(mid_step) + &self.state) + bu_mid;
        let k4 = self.sys.a.mmul((&k3).smul(self.h.0.clone()) + &self.state) + bu_end;
        // n_2 = 2 (weight of k2 and k3), n_6 = 6 (divisor of h).
        let [n_2, n_6] = T::A_RK();
        self.state += &(k1 + k2.smul(n_2.clone()) + k3.smul(n_2) + k4).smul(self.h.0.clone() / n_6);
        // Output at the end of the step: y = C x + D u(t_n + h).
        self.output = self.sys.c.mmul(&self.state) + self.sys.d.mmul(u_end.as_slice());

        self.index += 1;
        Some(Step {
            time: end_time,
            state: self.state.to_vec(),
            output: self.output.to_vec(),
        })
    }
}

// Numeric constants of the fixed-step Runge-Kutta schemes.
/// Trait that defines the constants used in the Rk solver.
#[allow(non_snake_case)]
pub trait RkConst: Sized {
    /// The value one half (`0.5`).
    fn _05() -> Self;
    /// The pair `[2, 6]`: weight of the two middle stages of RK4 and the
    /// divisor applied to the step size.
    fn A_RK() -> [Self; 2];
}

/// Implements `RkConst` for every floating point type in the list.
macro_rules! impl_rk_const {
    ($($t:ty),+ $(,)?) => {
        $(
            impl RkConst for $t {
                fn _05() -> Self {
                    0.5
                }
                fn A_RK() -> [Self; 2] {
                    [2., 6.]
                }
            }
        )+
    };
}

impl_rk_const!(f32, f64);
//////

/// Implementation of the Iterator trait for the `Rk` struct
impl<'a, T, F> Iterator for Rk<'a, T, F>
where
    T: AddAssign
        + Add<Output = T>
        + Clone
        + Div<Output = T>
        + Mul<Output = T>
        + NumCast
        + RkConst
        + Zero,
    F: Fn(Seconds<T>) -> Vec<T>,
{
    type Item = Step<T>;

    /// Yields the initial condition first, then one integration step per
    /// call, and stops once `n` steps have been produced.
    fn next(&mut self) -> Option<Self::Item> {
        match self.index {
            past_end if past_end > self.n => None,
            0 => Some(self.initial_step()),
            _ => match self.order {
                Order::Rk2 => self.main_iteration_rk2(),
                Order::Rk4 => self.main_iteration_rk4(),
            },
        }
    }
}

/// Snapshot of the linear system at one instant of a fixed-step integration.
#[derive(Clone, Debug)]
pub struct Step<T> {
    /// Instant this snapshot refers to
    time: Seconds<T>,
    /// State vector at `time`
    state: Vec<T>,
    /// Output vector at `time`
    output: Vec<T>,
}

impl<T: Clone> Step<T> {
    /// Time of this step.
    pub fn time(&self) -> Seconds<T> {
        Clone::clone(&self.time)
    }

    /// State of the system at this step.
    pub fn state(&self) -> &[T] {
        self.state.as_slice()
    }

    /// Output of the system at this step.
    pub fn output(&self) -> &[T] {
        self.output.as_slice()
    }
}

/// Adaptive-step Runge-Kutta-Fehlberg (orders 4 and 5) solver for the time
/// evolution of a linear system.
///
/// Used as an iterator: it first yields the initial condition and then one
/// accepted step per call, until the current time exceeds `limit`.
///
/// The bound on `F` is intentionally not placed on the struct; the impl blocks
/// that need to call the input function declare it.
#[derive(Clone, Debug)]
pub struct Rkf45<'a, T, F> {
    /// Linear system being integrated.
    sys: &'a Ss<T>,
    /// Input function, evaluated at the times required by the scheme.
    input: F,
    /// State vector.
    state: Array1<T>,
    /// Output vector.
    output: Array1<T>,
    /// Current step size; rescaled after every attempted step.
    h: Seconds<T>,
    /// Time limit of the evaluation.
    limit: Seconds<T>,
    /// Time of the current state.
    time: Seconds<T>,
    /// Tolerance on the estimated maximum absolute error of a step.
    tol: T,
    /// Whether the initial condition still has to be produced.
    initial_step: bool,
}

impl<'a, T, F> Rkf45<'a, T, F>
where
    T: Abs
        + Add<Output = T>
        + AddAssign
        + Clone
        + Div<Output = T>
        + Max
        + Mul<Output = T>
        + PartialOrd
        + Pow<T>
        + Rkf45Const
        + Sub<Output = T>
        + Zero,
    F: Fn(Seconds<T>) -> Vec<T>,
{
    /// Create a solver using Runge-Kutta-Fehlberg method
    ///
    /// The output at time 0 is computed here as `C x0 + D u(0)`.
    ///
    /// # Arguments
    ///
    /// * `sys` - linear system
    /// * `u` - input function (column vector)
    /// * `x0` - initial state (column vector)
    /// * `h` - initial integration time interval, adapted at every step
    /// * `limit` - time limit of the evaluation
    /// * `tol` - error tolerance
    pub(super) fn new(
        sys: &'a Ss<T>,
        u: F,
        x0: &[T],
        h: Seconds<T>,
        limit: Seconds<T>,
        tol: T,
    ) -> Self {
        let start = u(Seconds::zero());
        let state = Array1::from(x0.to_vec());
        // Calculate the output at time 0.
        let output = sys.c.mmul(&state) + sys.d.mmul(start.as_slice());
        Self {
            sys,
            input: u,
            state,
            output,
            h,
            limit,
            time: Seconds::zero(),
            tol,
            initial_step: true,
        }
    }

    /// Initial step (time 0) of the rkf45 solver.
    /// It contains the initial state and the calculated initial output
    /// at the constructor; its error is zero.
    fn initial_step(&mut self) -> StepWithError<T> {
        self.initial_step = false;
        StepWithError {
            time: Seconds::zero(),
            state: self.state.to_vec(),
            output: self.output.to_vec(),
            error: T::zero(),
        }
    }

    /// Runge-Kutta-Fehlberg order 4 and 5 method with adaptive step size.
    ///
    /// A step of length `h` is attempted and its error is estimated as the
    /// maximum absolute difference between the fourth and fifth order
    /// solutions. While the error is not below `tol`, `h` is shrunk and the
    /// step is retried. Once accepted:
    ///
    /// * the fourth order solution becomes the new state;
    /// * time advances by the step size that was actually used;
    /// * `h` is rescaled for the next step.
    ///
    /// NOTE(review): if the error estimate is NaN, `err < tol` never holds and
    /// the retry loop does not terminate.
    fn main_iteration(&mut self) -> StepWithError<T> {
        let [a0, a1, a2, a3] = T::A();
        let (step, error) = loop {
            // The coefficients are consumed by `smul`, so they are fetched
            // again on every attempt.
            let [b30, b31] = T::B3();
            let [b40, b41, b42] = T::B4();
            let [b50, b51, b52, b53] = T::B5();
            let [b60, b61, b62, b63, b64] = T::B6();
            let [c0, c1, c2, c3] = T::C();
            let [d0, d1, d2, d3, d4] = T::D();

            let u1 = (self.input)(self.time.clone());
            let u2 = (self.input)(&self.time + &self.h * &a0);
            let u3 = (self.input)(&self.time + &self.h * &a1);
            let u4 = (self.input)(&self.time + &self.h * &a2);
            let u5 = (self.input)(&self.time + &self.h);
            let u6 = (self.input)(&self.time + &self.h * &a3);

            let k1 = (self.sys.a.mmul(&self.state) + self.sys.b.mmul(u1.as_slice()))
                .smul(self.h.0.clone());
            let k2 = (self.sys.a.mmul((&k1).smul(T::B21()) + &self.state)
                + self.sys.b.mmul(u2.as_slice()))
            .smul(self.h.0.clone());
            let k3 = (self
                .sys
                .a
                .mmul((&k1).smul(b30) + (&k2).smul(b31) + &self.state)
                + self.sys.b.mmul(u3.as_slice()))
            .smul(self.h.0.clone());
            let k4 = (self
                .sys
                .a
                .mmul((&k1).smul(b40) + (&k2).smul(b41) + (&k3).smul(b42) + &self.state)
                + self.sys.b.mmul(u4.as_slice()))
            .smul(self.h.0.clone());
            let k5 = (self.sys.a.mmul(
                (&k1).smul(b50) + (&k2).smul(b51) + (&k3).smul(b52) + (&k4).smul(b53) + &self.state,
            ) + self.sys.b.mmul(u5.as_slice()))
            .smul(self.h.0.clone());
            let k6 = (self.sys.a.mmul(
                (&k1).smul(b60)
                    + k2.smul(b61)
                    + (&k3).smul(b62)
                    + (&k4).smul(b63)
                    + (&k5).smul(b64)
                    + &self.state,
            ) + self.sys.b.mmul(u6.as_slice()))
            .smul(self.h.0.clone());

            // Fourth order solution (propagated) and fifth order solution
            // (used only for the error estimate).
            let xn1 =
                (&k1).smul(c0) + (&k3).smul(c1) + (&k4).smul(c2) + (&k5).smul(c3) + &self.state;
            let xn1_ =
                k1.smul(d0) + k3.smul(d1) + k4.smul(d2) + k5.smul(d3) + k6.smul(d4) + &self.state;

            // Take the maximum absolute error between the states of the system.
            let err = max_vabs(&xn1 - &xn1_);
            let [exp0, exp1] = T::EXP();
            if err < self.tol {
                // `xn1` was computed with the current `h`: remember it before
                // rescaling, time must advance by the step actually taken.
                let step = self.h.0.clone();
                // A zero error estimate would make `tol / err` infinite and
                // turn the next step size into infinity; keep `h` unchanged.
                if err > T::zero() {
                    let error_ratio = self.tol.clone() / err.clone();
                    self.h.0 = T::SAFETY_FACTOR() * self.h.0.clone() * error_ratio.powf(exp0);
                }
                self.state = xn1;
                break (step, err);
            }
            // Rejected: err >= tol, so the ratio is below one and `h` shrinks.
            let error_ratio = self.tol.clone() / err;
            self.h.0 = T::SAFETY_FACTOR() * self.h.0.clone() * error_ratio.powf(exp1);
        };

        // Update time (by the accepted step) before calculating the output.
        self.time.0 += step;

        let u = (self.input)(self.time.clone());
        self.output = self.sys.c.mmul(&self.state) + &self.sys.d.mmul(u.as_slice());

        StepWithError {
            time: self.time.clone(),
            state: self.state.to_vec(),
            output: self.output.to_vec(),
            error,
        }
    }
}

/// Implementation of the Iterator trait for the `Rkf45` struct
impl<'a, T, F> Iterator for Rkf45<'a, T, F>
where
    T: Abs
        + Add<Output = T>
        + AddAssign
        + Clone
        + Div<Output = T>
        + Max
        + Mul<Output = T>
        + PartialOrd
        + Pow<T>
        + Rkf45Const
        + Sub<Output = T>
        + Zero,
    F: Fn(Seconds<T>) -> Vec<T>,
{
    type Item = StepWithError<T>;

    /// Yields the initial condition first, then one accepted adaptive step
    /// per call, until the current time exceeds the limit.
    fn next(&mut self) -> Option<Self::Item> {
        if self.time > self.limit {
            return None;
        }
        let item = if self.initial_step {
            self.initial_step()
        } else {
            self.main_iteration()
        };
        Some(item)
    }
}

// Coefficients of the Butcher table of rkf45 method.
/// Trait that defines the constants used in the Rkf45 solver: the
/// Runge-Kutta-Fehlberg tableau and the step size control parameters.
#[allow(non_snake_case)]
pub trait Rkf45Const
where
    Self: Sized,
{
    /// Nodes of stages 2, 3, 4 and 6: `[1/4, 3/8, 12/13, 1/2]`.
    /// Stage 5 is evaluated at the end of the step (node 1).
    fn A() -> [Self; 4];
    /// Coefficient of `k1` in stage 2.
    fn B21() -> Self;
    /// Coefficients of `k1`, `k2` in stage 3.
    fn B3() -> [Self; 2];
    /// Coefficients of `k1`..`k3` in stage 4.
    fn B4() -> [Self; 3];
    /// Coefficients of `k1`..`k4` in stage 5.
    fn B5() -> [Self; 4];
    /// Coefficients of `k1`..`k5` in stage 6.
    fn B6() -> [Self; 5];
    /// Weights of `k1`, `k3`, `k4`, `k5` for the fourth order solution
    /// (the weight of `k2` is zero). They sum to one.
    fn C() -> [Self; 4];
    /// Weights of `k1`, `k3`, `k4`, `k5`, `k6` for the fifth order solution
    /// (the weight of `k2` is zero). They sum to one.
    fn D() -> [Self; 5];
    /// Safety factor to avoid too small step changes.
    fn SAFETY_FACTOR() -> Self;
    /// Exponents applied to `tol / err` when rescaling the step size:
    /// `[accepted step, rejected step]`.
    fn EXP() -> [Self; 2];
}

macro_rules! impl_rkf45_const {
    ($t:ty) => {
        impl Rkf45Const for $t {
            fn A() -> [Self; 4] {
                [1. / 4., 3. / 8., 12. / 13., 1. / 2.]
            }
            fn B21() -> Self {
                1. / 4.
            }
            fn B3() -> [Self; 2] {
                [3. / 32., 9. / 32.]
            }
            fn B4() -> [Self; 3] {
                [1932. / 2197., -7200. / 2197., 7296. / 2197.]
            }
            fn B5() -> [Self; 4] {
                [439. / 216., -8., 3680. / 513., -845. / 4104.]
            }
            fn B6() -> [Self; 5] {
                [-8. / 27., 2., -3544. / 2565., 1859. / 4104., -11. / 40.]
            }
            fn C() -> [Self; 4] {
                // Fehlberg fourth order weights: 25/216, 1408/2565,
                // 2197/4104, -1/5.
                [25. / 216., 1408. / 2565., 2197. / 4104., -1. / 5.]
            }
            fn D() -> [Self; 5] {
                [
                    16. / 135.,
                    6656. / 12_825.,
                    28_561. / 56_430.,
                    -9. / 50.,
                    2. / 55.,
                ]
            }
            fn SAFETY_FACTOR() -> Self {
                0.95
            }
            fn EXP() -> [Self; 2] {
                [0.25, 0.2]
            }
        }
    };
}

impl_rkf45_const!(f32);
impl_rkf45_const!(f64);
//////

/// Snapshot of the linear system at one instant of an adaptive-step
/// integration, together with the error estimate of the step that produced
/// it.
#[derive(Clone, Debug)]
pub struct StepWithError<T> {
    /// Time of the current step
    time: Seconds<T>,
    /// Current state
    state: Vec<T>,
    /// Current output
    output: Vec<T>,
    /// Current maximum absolute error
    error: T,
}

impl<T> StepWithError<T>
where
    T: Clone,
{
    /// Get the time of the current step
    pub fn time(&self) -> Seconds<T> {
        self.time.clone()
    }

    /// Get the current state of the system
    ///
    /// Returns `&Vec<T>` (rather than `&[T]` like `Step::state`) to keep the
    /// public signature backward compatible.
    pub fn state(&self) -> &Vec<T> {
        &self.state
    }

    /// Get the current output of the system
    ///
    /// Returns `&Vec<T>` to keep the public signature backward compatible.
    pub fn output(&self) -> &Vec<T> {
        &self.output
    }

    /// Get the current maximum absolute error
    pub fn error(&self) -> T {
        self.error.clone()
    }
}

/// Struct for the time evolution of the linear system using the implicit
/// Radau method of order 3 with 2 stages.
///
/// Used as an iterator: it first yields the initial condition and then one
/// step of length `h` per call, `n` steps in total.
///
/// The bound on `F` is intentionally not placed on the struct; the impl blocks
/// that need to call the input function declare it.
#[derive(Clone, Debug)]
pub struct Radau<'a, T, F> {
    /// Linear system being integrated
    sys: &'a Ss<T>,
    /// Input function, evaluated at the stage times
    input: F,
    /// State vector
    state: Array1<T>,
    /// Output vector
    output: Array1<T>,
    /// Integration interval (fixed step size)
    h: Seconds<T>,
    /// Number of integration steps after the initial condition
    n: usize,
    /// Index of the next item to produce (0 is the initial condition)
    index: usize,
    /// Tolerance used to stop the Newton iterations on the stage equations
    tol: T,
    /// Store the LU decomposition of the Jacobian matrix
    lu_jacobian: Lup<T>,
}

impl<'a, T, F> Radau<'a, T, F>
where
    T: Abs
        + Add<Output = T>
        + AddAssign
        + Clone
        + Div<Output = T>
        + Mul<Output = T>
        + NumCast
        + One
        + PartialOrd
        + RadauConst
        + RelativeEq
        + Sub<Output = T>
        + Zero,
    F: Fn(Seconds<T>) -> Vec<T>,
{
    /// Create the solver for a Radau order 3 with 2 stages method.
    ///
    /// The Jacobian of the stage equations is constant for a linear system,
    /// so it is built and LU-decomposed once here. Returns `None` when `lup`
    /// fails on that Jacobian (presumably because it is singular — confirm
    /// against `lup`).
    ///
    /// # Arguments
    ///
    /// * `sys` - linear system
    /// * `u` - input function that returns a vector (column vector)
    /// * `x0` - initial state (column vector)
    /// * `h` - integration time interval
    /// * `n` - integration steps
    /// * `tol` - tolerance of implicit solution finding
    pub(super) fn new(
        sys: &'a Ss<T>,
        u: F,
        x0: &[T],
        h: Seconds<T>,
        n: usize,
        tol: T,
    ) -> Option<Self> {
        let start = u(Seconds::zero());
        let state = Array1::from(x0.to_vec());
        let output = sys.c.mmul(&state) + sys.d.mmul(start.as_slice());
        // Jacobian matrix can be precomputed since it is constant for the
        // given system.
        let g = &(&sys.a).smul(h.0.clone());
        let nr = sys.a.nrows(); // A is a square matrix.
        let n2 = 2 * nr;
        let identity = identity(nr);
        let [ra0, ra1, ra2, ra3] = T::RADAU_A();
        // Blocks of the Jacobian of the residual F(k) with respect to k:
        // J_ij = h * a_ij * A - delta_ij * I
        let j11 = g.smul(ra0) - &identity;
        let j12 = g.smul(ra1);
        let j21 = g.smul(ra2);
        let j22 = g.smul(ra3) - &identity;
        let mut jac = Array2::from_elem((n2, n2), T::zero());
        // Copy the sub matrices into the Jacobian.
        jac.slice_mut(s![0..nr, 0..nr]).assign(&j11);
        jac.slice_mut(s![0..nr, nr..n2]).assign(&j12);
        jac.slice_mut(s![nr..n2, 0..nr]).assign(&j21);
        jac.slice_mut(s![nr..n2, nr..n2]).assign(&j22);

        Some(Self {
            sys,
            input: u,
            state,
            output,
            h,
            n,
            index: 0,
            tol,
            lu_jacobian: lup(jac)?,
        })
    }

    /// Initial step (time 0) of the Radau solver.
    /// It contains the initial state and the calculated initial output
    /// at the constructor.
    fn initial_step(&mut self) -> Step<T> {
        self.index += 1;
        Step {
            time: Seconds::zero(),
            state: self.state.to_vec(),
            output: self.output.to_vec(),
        }
    }

    /// Radau order 3 with 2 stages implicit method.
    ///
    /// Solves the stage equations with at most 10 Newton iterations, stopping
    /// early when two successive iterates are relatively equal within `tol`.
    /// It then advances the state from `(index - 1) * h` to `index * h`.
    ///
    /// Returns `None` if `index` cannot be converted to `T`. The conversions
    /// are done before any field is modified, so in that case the solver is
    /// left untouched.
    #[allow(clippy::cast_precision_loss, clippy::similar_names)]
    fn main_iteration(&mut self) -> Option<Step<T>> {
        let h = self.h.0.clone();
        // Return None if conversion fails, before mutating any state.
        let time = T::from(self.index - 1)? * h.clone();
        let end_time = Seconds(T::from(self.index)? * h.clone());
        let rows = self.sys.a.nrows();
        // k = [k1; k2] (column vector)
        let mut k = Array1::from_elem(2 * rows, T::zero());
        // k sub-vectors (or block vectors) have size (rows x 1).
        // Use the current state as first guess for k1 and k2. The residual is
        // linear in k and `lu_jacobian` is its exact Jacobian, so the guess
        // only affects rounding.
        k.slice_mut(s![0..rows]).assign(&self.state);
        k.slice_mut(s![rows..(2 * rows)]).assign(&self.state);

        let [rc0, rc1] = T::RADAU_C();
        let u1 = (self.input)(Seconds(time.clone() + rc0 * h.clone()));
        let bu1 = self.sys.b.mmul(u1.as_slice());
        let u2 = (self.input)(Seconds(time + rc1 * h.clone()));
        let bu2 = self.sys.b.mmul(u2.as_slice());
        let mut f = Array1::from_elem(2 * rows, T::zero());
        // Max 10 iterations.
        for _ in 0..10 {
            let [ra0, ra1, ra2, ra3] = T::RADAU_A();
            let k1 = k.slice(s![0..rows]);
            let k2 = k.slice(s![rows..(2 * rows)]);

            // Residuals F_i = A (x + h * sum_j a_ij k_j) + B u_i - k_i
            let f1 = self
                .sys
                .a
                .mmul((k1.smul(ra0) + k2.smul(ra1)).smul(h.clone()) + &self.state)
                + &bu1
                - k1;
            let f2 = self
                .sys
                .a
                .mmul((k1.smul(ra2) + k2.smul(ra3)).smul(h.clone()) + &self.state)
                + &bu2
                - k2;
            f.slice_mut(s![0..rows]).assign(&f1);
            f.slice_mut(s![rows..(2 * rows)]).assign(&f2);

            // J * dk = f -> dk = J^-1 * f
            // Override f with dk so there is less allocations of matrices.
            // f = J^-1 * f
            let knew = {
                // k(n+1) = k(n) - dk = k(n) - f
                self.lu_jacobian.solve_mut(&mut f);
                &k - &f
            };

            let equal = relative_eq(&knew, &k, &self.tol);
            k = knew; // Use the latest solution calculated.
            if equal {
                break;
            }
        }
        let [rb0, rb1] = T::RADAU_B();
        self.state += &(k.slice(s![0..rows]).smul(rb0) + k.slice(s![rows..(2 * rows)]).smul(rb1))
            .smul(h);

        // Output at the end of the step: y = C x + D u(end_time).
        let u = (self.input)(end_time.clone());
        self.output = self.sys.c.mmul(&self.state) + self.sys.d.mmul(u.as_slice());

        self.index += 1;
        Some(Step {
            time: end_time,
            state: self.state.to_vec(),
            output: self.output.to_vec(),
        })
    }
}

// Butcher tableau of the two-stage, third-order Radau IIA method.
/// Trait that defines the constants used in the Radau solver.
#[allow(non_snake_case)]
pub trait RadauConst: Sized {
    /// Stage matrix in row-major order: `[a11, a12, a21, a22]`.
    fn RADAU_A() -> [Self; 4];
    /// Weights `[b1, b2]`.
    fn RADAU_B() -> [Self; 2];
    /// Nodes `[c1, c2]`.
    fn RADAU_C() -> [Self; 2];
}

/// Implements `RadauConst` for every floating point type in the list.
macro_rules! impl_radau_const {
    ($($t:ty),+ $(,)?) => {
        $(
            impl RadauConst for $t {
                fn RADAU_A() -> [Self; 4] {
                    [5. / 12., -1. / 12., 3. / 4., 1. / 4.]
                }
                fn RADAU_B() -> [Self; 2] {
                    [3. / 4., 1. / 4.]
                }
                fn RADAU_C() -> [Self; 2] {
                    [1. / 3., 1.]
                }
            }
        )+
    };
}

impl_radau_const!(f32, f64);
//////

/// Implementation of the Iterator trait for the `Radau` struct.
impl<'a, T, F> Iterator for Radau<'a, T, F>
where
    T: Abs
        + Add<Output = T>
        + AddAssign
        + Clone
        + Div<Output = T>
        + Mul<Output = T>
        + NumCast
        + One
        + PartialOrd
        + RadauConst
        + RelativeEq
        + Sub<Output = T>
        + Zero,
    F: Fn(Seconds<T>) -> Vec<T>,
{
    type Item = Step<T>;

    /// Yields the initial condition first, then one implicit integration step
    /// per call, and stops once `n` steps have been produced.
    fn next(&mut self) -> Option<Self::Item> {
        match self.index {
            past_end if past_end > self.n => None,
            0 => Some(self.initial_step()),
            _ => self.main_iteration(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The `Step` getters return exactly the stored values.
    #[test]
    fn runge_kutta_struct() {
        let (time, state, output) = (Seconds(3.), vec![2., 3.], vec![-5., -4.]);

        let step = Step {
            time,
            state: state.clone(),
            output: output.clone(),
        };
        assert_eq!(time, step.time());
        assert_eq!(&state, step.state());
        assert_eq!(&output, step.output());
    }

    /// The `StepWithError` getters return exactly the stored values.
    #[test]
    #[allow(clippy::float_cmp)]
    fn runge_kutta_fehlberg_struct() {
        let (time, state, output, error) =
            (Seconds(3.), vec![2., 3.], vec![-5., -4.], 0.5);

        let step = StepWithError {
            time,
            state: state.clone(),
            output: output.clone(),
            error,
        };
        assert_eq!(time, step.time());
        assert_eq!(&state, step.state());
        assert_eq!(&output, step.output());
        assert_eq!(error, step.error());
    }

    /// `Step` as produced by the Radau solver keeps its values.
    #[test]
    fn radau_struct() {
        let (time, state, output) = (Seconds(12.), vec![2., 2.4], vec![-5.33, -4.]);

        let step = Step {
            time,
            state: state.clone(),
            output: output.clone(),
        };
        assert_eq!(time, step.time());
        assert_eq!(&state, step.state());
        assert_eq!(&output, step.output());
    }
}