diffsol 0.12.3

A library for solving ordinary differential equations (ODEs) in Rust.
Documentation
use crate::{
    error::{DiffsolError, OdeSolverError},
    ode_solver_error,
    scalar::IndexType,
    scale, AugmentedOdeEquations, DefaultDenseMatrix, DenseMatrix, OdeEquations, OdeSolverProblem,
    OdeSolverState, Op, StateRef, StateRefMut, Vector, VectorViewMut,
};
use num_traits::Zero;
use std::ops::MulAssign;

use super::state::StateCommon;

/// State container for the BDF integrator. For the common state API use `as_ref` and `as_mut` methods.
///
/// Besides the vectors shared with every solver state (`y`, `dy`, `g`, `dg`, `s`, `ds`,
/// `sg`, `dsg`, `t`, `h`) this type carries one "diff" matrix per tracked quantity plus a
/// flag recording whether that matrix has been seeded from the current vectors.
#[derive(Clone)]
pub struct BdfState<V, M = <V as DefaultDenseMatrix>::M>
where
    V: Vector + DefaultDenseMatrix,
    M: DenseMatrix<T = V::T, V = V>,
{
    /// Current order; reset to 1 by `initialise_diff_to_first_order`.
    pub(crate) order: usize,
    /// Difference matrix for `y`. Allocated by `new_from_common` with `y.len()` rows and
    /// `MAX_ORDER + 3` columns; seeding writes `y` into column 0 and `h * dy` into column 1.
    pub(crate) diff: M,
    /// One difference matrix per entry of `s`, seeded from `s[i]` and `h * ds[i]`.
    pub(crate) sdiff: Vec<M>,
    /// Difference matrix for `g` (`g.len()` rows), seeded from `g` and `h * dg`.
    pub(crate) gdiff: M,
    /// One difference matrix per entry of `sg`, seeded from `sg[i]` and `h * dsg[i]`;
    /// empty when `sg` is empty.
    pub(crate) sgdiff: Vec<M>,
    /// State vector (presumably the ODE solution at `t` — confirm against `StateCommon`).
    pub(crate) y: V,
    /// Companion of `y` that is scaled by `h` when seeding `diff` (presumably dy/dt).
    pub(crate) dy: V,
    /// Vector tracked by `gdiff` (presumably the integrated output — TODO confirm).
    pub(crate) g: V,
    /// Companion of `g` that is scaled by `h` when seeding `gdiff`.
    pub(crate) dg: V,
    /// Per-index vectors tracked by `sdiff` (presumably sensitivities/augmented states).
    pub(crate) s: Vec<V>,
    /// Companions of `s`, scaled by `h` when seeding `sdiff`.
    pub(crate) ds: Vec<V>,
    /// Per-index vectors tracked by `sgdiff`.
    pub(crate) sg: Vec<V>,
    /// Companions of `sg`, scaled by `h` when seeding `sgdiff`.
    pub(crate) dsg: Vec<V>,
    /// Scalar exposed as `t` through the common state API (presumably the current time).
    pub(crate) t: V::T,
    /// Scalar used to scale the derivative columns (presumably the current step size).
    pub(crate) h: V::T,
    /// `true` once `diff` has been seeded; cleared by `as_mut`.
    pub(crate) diff_initialised: bool,
    /// `true` once every matrix in `sdiff` has been seeded; cleared by `as_mut`.
    pub(crate) sdiff_initialised: bool,
    /// `true` once `gdiff` has been seeded; cleared by `as_mut`.
    pub(crate) gdiff_initialised: bool,
    /// `true` once every matrix in `sgdiff` has been seeded; cleared by `as_mut`.
    pub(crate) sgdiff_initialised: bool,
}

impl<V, M> BdfState<V, M>
where
    V: Vector + DefaultDenseMatrix,
    M: DenseMatrix<T = V::T, V = V, C = V::C>,
{
    /// Order limit used to size the difference matrices: `new_from_common` allocates
    /// `MAX_ORDER + 3` columns for each of them. Presumably the highest BDF order the
    /// solver will use — confirm against the solver implementation.
    pub(crate) const MAX_ORDER: IndexType = 5;

    /// Builds a placeholder state in which nothing has any size.
    ///
    /// Every vector is a zero-length `V`, both single difference matrices are `0 x 0`,
    /// all per-index collections are empty, `t` and `h` are zero, `order` is 1 and every
    /// `*_initialised` flag is cleared.
    pub(crate) fn new_empty(ctx: V::C) -> Self {
        // One prototype of each kind; the final use of each takes ownership.
        let empty_vec = V::zeros(0, ctx.clone());
        let empty_mat = M::zeros(0, 0, ctx);
        Self {
            order: 1,
            t: V::T::zero(),
            h: V::T::zero(),
            // Main state and output vectors.
            y: empty_vec.clone(),
            dy: empty_vec.clone(),
            g: empty_vec.clone(),
            dg: empty_vec,
            // Per-index vectors: none yet.
            s: Vec::new(),
            ds: Vec::new(),
            sg: Vec::new(),
            dsg: Vec::new(),
            // Difference matrices.
            diff: empty_mat.clone(),
            gdiff: empty_mat,
            sdiff: Vec::new(),
            sgdiff: Vec::new(),
            // Nothing has been seeded.
            diff_initialised: false,
            sdiff_initialised: false,
            gdiff_initialised: false,
            sgdiff_initialised: false,
        }
    }

    /// Seeds `diff` from the current `y`/`dy` and drops the order back to 1.
    ///
    /// Column 0 receives `y`, column 1 receives `h * dy`; afterwards `diff_initialised`
    /// is set.
    pub fn initialise_diff_to_first_order(&mut self) {
        let step = self.h;
        self.order = 1;

        self.diff.column_mut(0).copy_from(&self.y);

        let mut scaled_slope = self.diff.column_mut(1);
        scaled_slope.copy_from(&self.dy);
        scaled_slope.mul_assign(scale(step));

        self.diff_initialised = true;
    }

    /// Seeds every matrix in `sdiff` from the matching entries of `s` and `ds`.
    ///
    /// For each index `i`, column 0 of `sdiff[i]` receives `s[i]` and column 1 receives
    /// `h * ds[i]`; afterwards `sdiff_initialised` is set.
    ///
    /// # Panics
    /// Panics if `s` or `ds` holds fewer entries than `sdiff`.
    pub fn initialise_sdiff_to_first_order(&mut self) {
        let step = self.h;
        for (idx, table) in self.sdiff.iter_mut().enumerate() {
            // Indexed on purpose: a length mismatch must panic rather than be skipped.
            let (value, slope) = (&self.s[idx], &self.ds[idx]);

            table.column_mut(0).copy_from(value);

            let mut scaled_slope = table.column_mut(1);
            scaled_slope.copy_from(slope);
            scaled_slope.mul_assign(scale(step));
        }
        self.sdiff_initialised = true;
    }

    /// Seeds `gdiff` from the current `g`/`dg`.
    ///
    /// Column 0 receives `g`, column 1 receives `h * dg`; afterwards `gdiff_initialised`
    /// is set. Unlike the `diff` variant this does not touch `order`.
    pub fn initialise_gdiff_to_first_order(&mut self) {
        let step = self.h;

        self.gdiff.column_mut(0).copy_from(&self.g);

        let mut scaled_slope = self.gdiff.column_mut(1);
        scaled_slope.copy_from(&self.dg);
        scaled_slope.mul_assign(scale(step));

        self.gdiff_initialised = true;
    }

    /// Seeds every matrix in `sgdiff` from the matching entries of `sg` and `dsg`.
    ///
    /// For each index `i`, column 0 of `sgdiff[i]` receives `sg[i]` and column 1 receives
    /// `h * dsg[i]`; afterwards `sgdiff_initialised` is set.
    ///
    /// # Panics
    /// Panics if `sg` or `dsg` holds fewer entries than `sgdiff`.
    pub fn initialise_sgdiff_to_first_order(&mut self) {
        let step = self.h;
        for (idx, table) in self.sgdiff.iter_mut().enumerate() {
            // Indexed on purpose: a length mismatch must panic rather than be skipped.
            let (value, slope) = (&self.sg[idx], &self.dsg[idx]);

            table.column_mut(0).copy_from(value);

            let mut scaled_slope = table.column_mut(1);
            scaled_slope.copy_from(slope);
            scaled_slope.mul_assign(scale(step));
        }
        self.sgdiff_initialised = true;
    }
}

impl<V, M> OdeSolverState<V> for BdfState<V, M>
where
    V: Vector + DefaultDenseMatrix,
    M: DenseMatrix<T = V::T, V = V, C = V::C>,
{
    /// Checks that this state fits `ode_problem` and seeds any unseeded difference matrix.
    ///
    /// `diff` and `gdiff` are only (re)seeded when their `*_initialised` flag is clear,
    /// so a state that is already seeded is left as it is.
    ///
    /// # Errors
    /// Returns a `StateProblemMismatch` error when `diff` does not have one row per state
    /// of the right-hand side, or when `gdiff` does not have `eqn.nout()` rows (zero rows
    /// if the problem does not integrate its output).
    fn set_problem<Eqn: OdeEquations>(
        &mut self,
        ode_problem: &OdeSolverProblem<Eqn>,
    ) -> Result<(), DiffsolError> {
        // The state matrix must have one row per state of the right-hand side.
        let rhs_states = ode_problem.eqn.rhs().nstates();
        if self.diff.nrows() != rhs_states {
            return Err(ode_solver_error!(StateProblemMismatch));
        }

        // Output integration is optional; without it `gdiff` has to be empty.
        let wanted_gdiff_rows = if ode_problem.integrate_out {
            ode_problem.eqn.nout()
        } else {
            0
        };
        if self.gdiff.nrows() != wanted_gdiff_rows {
            return Err(ode_solver_error!(StateProblemMismatch));
        }

        // Shapes agree: seed whatever has not been seeded yet.
        if !self.diff_initialised {
            self.initialise_diff_to_first_order();
        }
        if !self.gdiff_initialised {
            self.initialise_gdiff_to_first_order();
        }
        Ok(())
    }

    /// Checks that the augmented part of this state fits `augmented_eqn` and seeds any
    /// unseeded augmented difference matrix.
    ///
    /// `sdiff` must hold `augmented_eqn.max_index()` matrices, each with one row per
    /// state of `ode_problem`'s right-hand side (only the first matrix is inspected).
    /// `sgdiff` must hold the same number of matrices with `out.nout()` rows when the
    /// augmented equations provide an output, and must be empty otherwise.
    ///
    /// # Errors
    /// Returns a `StateProblemMismatch` error when either shape check fails.
    fn set_augmented_problem<Eqn: OdeEquations, AugmentedEqn: AugmentedOdeEquations<Eqn>>(
        &mut self,
        ode_problem: &OdeSolverProblem<Eqn>,
        augmented_eqn: &AugmentedEqn,
    ) -> Result<(), DiffsolError> {
        let naug = augmented_eqn.max_index();
        let nstates = ode_problem.eqn.rhs().nstates();
        // Only look at the first matrix when one is expected: with `naug == 0` and an
        // empty `sdiff`, indexing `sdiff[0]` would panic instead of validating. This
        // mirrors the guard used for `sgdiff` below.
        if self.sdiff.len() != naug || (naug > 0 && self.sdiff[0].nrows() != nstates) {
            return Err(ode_solver_error!(StateProblemMismatch));
        }
        // Without an augmented output there must be no `sgdiff` matrices at all.
        let (sgdiff_len, sgdiff_size) = if let Some(out) = augmented_eqn.out() {
            (naug, out.nout())
        } else {
            (0, 0)
        };
        if self.sgdiff.len() != sgdiff_len
            || (sgdiff_len > 0 && self.sgdiff[0].nrows() != sgdiff_size)
        {
            return Err(ode_solver_error!(StateProblemMismatch));
        }
        // Shapes agree: seed whatever has not been seeded yet.
        if !self.sdiff_initialised {
            self.initialise_sdiff_to_first_order();
        }
        if !self.sgdiff_initialised {
            self.initialise_sgdiff_to_first_order();
        }
        Ok(())
    }

    /// Wraps a common state in a BDF state, allocating zeroed difference matrices.
    ///
    /// Each matrix gets `MAX_ORDER + 3` columns. `diff` and every `sdiff` matrix have
    /// `y.len()` rows, `gdiff` has `g.len()` rows and every `sgdiff` matrix has
    /// `sg[0].len()` rows (no `sgdiff` matrices are created when `sg` is empty). The
    /// result starts at order 1 with all `*_initialised` flags cleared, so nothing is
    /// seeded here.
    fn new_from_common(state: super::state::StateCommon<V>) -> Self {
        let StateCommon {
            y,
            dy,
            g,
            dg,
            s,
            ds,
            sg,
            dsg,
            t,
            h,
        } = state;

        let ncols = Self::MAX_ORDER + 3;
        let ctx = y.context();
        let zeroed = |nrows| M::zeros(nrows, ncols, ctx.clone());

        let state_rows = y.len();
        let diff = zeroed(state_rows);
        let sdiff = vec![zeroed(state_rows); s.len()];
        let gdiff = zeroed(g.len());
        // The first entry of `sg` fixes the row count for all `sgdiff` matrices.
        let sgdiff = match sg.first() {
            Some(first) => vec![zeroed(first.len()); sg.len()],
            None => Vec::new(),
        };

        Self {
            order: 1,
            diff,
            sdiff,
            gdiff,
            sgdiff,
            y,
            dy,
            g,
            dg,
            s,
            ds,
            sg,
            dsg,
            t,
            h,
            diff_initialised: false,
            sdiff_initialised: false,
            gdiff_initialised: false,
            sgdiff_initialised: false,
        }
    }

    fn into_common(self) -> StateCommon<V> {
        StateCommon {
            y: self.y,
            dy: self.dy,
            g: self.g,
            dg: self.dg,
            s: self.s,
            ds: self.ds,
            sg: self.sg,
            dsg: self.dsg,
            t: self.t,
            h: self.h,
        }
    }

    fn as_ref(&self) -> StateRef<'_, V> {
        StateRef {
            y: &self.y,
            dy: &self.dy,
            g: &self.g,
            dg: &self.dg,
            s: &self.s,
            ds: &self.ds,
            sg: &self.sg,
            dsg: &self.dsg,
            t: self.t,
            h: self.h,
        }
    }

    fn as_mut(&mut self) -> StateRefMut<'_, V> {
        self.diff_initialised = false;
        self.sdiff_initialised = false;
        self.gdiff_initialised = false;
        self.sgdiff_initialised = false;
        StateRefMut {
            y: &mut self.y,
            dy: &mut self.dy,
            g: &mut self.g,
            dg: &mut self.dg,
            s: &mut self.s,
            ds: &mut self.ds,
            sg: &mut self.sg,
            dsg: &mut self.dsg,
            t: &mut self.t,
            h: &mut self.h,
        }
    }
}