diffsol 0.1.9

A library for solving ordinary differential equations (ODEs) in Rust.
Documentation
use std::rc::Rc;

use crate::{Matrix, Vector};

use super::{ConstantOp, Op};

pub struct ConstantClosureWithSens<M, I, J>
where
    M: Matrix,
    I: Fn(&M::V, M::T) -> M::V,
    J: Fn(&M::V, M::T, &M::V, &mut M::V),
{
    func: I,
    func_sens: J,
    nstates: usize,
    nout: usize,
    nparams: usize,
    p: Rc<M::V>,
}

impl<M, I, J> ConstantClosureWithSens<M, I, J>
where
    M: Matrix,
    I: Fn(&M::V, M::T) -> M::V,
    J: Fn(&M::V, M::T, &M::V, &mut M::V),
{
    pub fn new(func: I, func_sens: J, nstates: usize, nout: usize, p: Rc<M::V>) -> Self {
        let nparams = p.len();
        Self {
            func,
            func_sens,
            nstates,
            nout,
            nparams,
            p,
        }
    }
}

impl<M, I, J> Op for ConstantClosureWithSens<M, I, J>
where
    M: Matrix,
    I: Fn(&M::V, M::T) -> M::V,
    J: Fn(&M::V, M::T, &M::V, &mut M::V),
{
    type V = M::V;
    type T = M::T;
    type M = M;
    fn nstates(&self) -> usize {
        self.nstates
    }
    fn nout(&self) -> usize {
        self.nout
    }
    fn nparams(&self) -> usize {
        self.nparams
    }
    fn set_params(&mut self, p: Rc<M::V>) {
        assert_eq!(p.len(), self.nparams);
        self.p = p;
    }
}

impl<M, I, J> ConstantOp for ConstantClosureWithSens<M, I, J>
where
    M: Matrix,
    I: Fn(&M::V, M::T) -> M::V,
    J: Fn(&M::V, M::T, &M::V, &mut M::V),
{
    fn call_inplace(&self, t: Self::T, y: &mut Self::V) {
        y.copy_from(&(self.func)(self.p.as_ref(), t));
    }
    fn call(&self, t: Self::T) -> Self::V {
        (self.func)(self.p.as_ref(), t)
    }
    fn sens_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) {
        (self.func_sens)(self.p.as_ref(), t, v, y);
    }
    fn has_sens(&self) -> bool {
        true
    }
}