diffsol 0.1.9

A library for solving ordinary differential equations (ODEs) in Rust.
Documentation
use std::ops::{Div, Mul, MulAssign};

use nalgebra::{DMatrix, DVector, DVectorView, DVectorViewMut};

use crate::{IndexType, Scalar, Scale};

use super::{DefaultDenseMatrix, Vector, VectorCommon, VectorIndex, VectorView, VectorViewMut};

impl<T: Scalar> DefaultDenseMatrix for DVector<T> {
    type M = DMatrix<T>;
}

macro_rules! impl_op_for_dvector_struct {
    ($struct:ident, $trait_name:ident, $func_name:ident) => {
        impl<'a, T: Scalar> $trait_name<Scale<T>> for $struct<'a, T> {
            type Output = DVector<T>;
            fn $func_name(self, rhs: Scale<T>) -> Self::Output {
                self * rhs.value()
            }
        }
    };
}

impl_op_for_dvector_struct!(DVectorView, Mul, mul);
impl_op_for_dvector_struct!(DVectorViewMut, Mul, mul);

impl VectorIndex for DVector<IndexType> {
    fn zeros(len: IndexType) -> Self {
        DVector::from_element(len, 0)
    }
    fn len(&self) -> crate::IndexType {
        self.len()
    }
    fn from_slice(slice: &[IndexType]) -> Self {
        DVector::from_iterator(slice.len(), slice.iter().copied())
    }
    fn clone_as_vec(&self) -> Vec<IndexType> {
        self.iter().copied().collect()
    }
}

macro_rules! impl_vector_common {
    ($vector_type:ty) => {
        impl<'a, T: Scalar> VectorCommon for $vector_type {
            type T = T;
        }
    };
}

impl_vector_common!(DVector<T>);
impl_vector_common!(DVectorView<'a, T>);
impl_vector_common!(DVectorViewMut<'a, T>);

impl<'a, T: Scalar> VectorView<'a> for DVectorView<'a, T> {
    type Owned = DVector<T>;
    fn abs_to(&self, y: &mut Self::Owned) {
        y.zip_apply(self, |y, x| *y = x.abs());
    }
    fn into_owned(self) -> Self::Owned {
        self.into_owned()
    }
    fn norm(&self) -> T {
        self.norm()
    }
    fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T {
        let mut acc = T::zero();
        if y.len() != self.len() || y.len() != atol.len() {
            panic!("Vector lengths do not match");
        }
        for i in 0..self.len() {
            let yi = unsafe { y.get_unchecked(i) };
            let ai = unsafe { atol.get_unchecked(i) };
            let xi = unsafe { self.get_unchecked(i) };
            acc += (*xi / (yi.abs() * rtol + *ai)).powi(2);
        }
        acc / Self::T::from(self.len() as f64)
    }
}

macro_rules! impl_mul_scale_vector {
    ($vector_type:ty) => {
        impl<T: Scalar> Mul<Scale<T>> for $vector_type {
            type Output = DVector<T>;
            fn mul(self, rhs: Scale<T>) -> Self::Output {
                self * rhs.value()
            }
        }

        impl<T: Scalar> Mul<Scale<T>> for &$vector_type {
            type Output = DVector<T>;
            fn mul(self, rhs: Scale<T>) -> Self::Output {
                self * rhs.value()
            }
        }
    };
}

impl_mul_scale_vector!(DVector<T>);
macro_rules! impl_mul_assign_scale_vector {
    ($vector_type:ty) => {
        impl<'a, T: Scalar> MulAssign<Scale<T>> for $vector_type {
            fn mul_assign(&mut self, rhs: Scale<T>) {
                *self *= rhs.value();
            }
        }
    };
}

impl_mul_assign_scale_vector!(DVector<T>);
impl_mul_assign_scale_vector!(DVectorViewMut<'a, T>);

impl<'a, T: Scalar> VectorViewMut<'a> for DVectorViewMut<'a, T> {
    type Owned = DVector<T>;
    type View = DVectorView<'a, T>;
    fn abs_to(&self, y: &mut Self::Owned) {
        y.zip_apply(self, |y, x| *y = x.abs());
    }
    fn copy_from(&mut self, other: &Self::Owned) {
        self.copy_from(other);
    }
    fn copy_from_view(&mut self, other: &Self::View) {
        self.copy_from(other);
    }
}

impl<T: Scalar> Div<Scale<T>> for DVector<T> {
    type Output = DVector<T>;
    fn div(self, rhs: Scale<T>) -> Self::Output {
        self / rhs.value()
    }
}

impl<T: Scalar> Vector for DVector<T> {
    type View<'a> = DVectorView<'a, T>;
    type ViewMut<'a> = DVectorViewMut<'a, T>;
    type Index = DVector<IndexType>;
    fn len(&self) -> IndexType {
        self.len()
    }
    fn norm(&self) -> Self::T {
        self.norm()
    }
    fn squared_norm(&self, y: &Self, atol: &Self, rtol: Self::T) -> Self::T {
        let mut acc = T::zero();
        if y.len() != self.len() || y.len() != atol.len() {
            panic!("Vector lengths do not match");
        }
        for i in 0..self.len() {
            let yi = unsafe { y.get_unchecked(i) };
            let ai = unsafe { atol.get_unchecked(i) };
            let xi = unsafe { self.get_unchecked(i) };
            acc += (*xi / (yi.abs() * rtol + *ai)).powi(2);
        }
        acc / Self::T::from(self.len() as f64)
    }
    fn abs_to(&self, y: &mut Self) {
        y.zip_apply(self, |y, x| *y = x.abs());
    }
    fn fill(&mut self, value: T) {
        self.fill(value);
    }
    fn as_view(&self) -> Self::View<'_> {
        self.as_view()
    }
    fn as_view_mut(&mut self) -> Self::ViewMut<'_> {
        self.as_view_mut()
    }
    fn copy_from(&mut self, other: &Self) {
        self.copy_from(other);
    }
    fn exp(&self) -> Self {
        self.map(|x| x.exp())
    }
    fn copy_from_view(&mut self, other: &Self::View<'_>) {
        self.copy_from(other);
    }
    fn from_element(nstates: usize, value: T) -> Self {
        Self::from_element(nstates, value)
    }
    fn from_vec(vec: Vec<T>) -> Self {
        Self::from_vec(vec)
    }
    fn zeros(nstates: usize) -> Self {
        Self::zeros(nstates)
    }
    fn add_scalar_mut(&mut self, scalar: T) {
        self.add_scalar_mut(scalar);
    }
    fn axpy(&mut self, alpha: T, x: &Self, beta: T) {
        self.axpy(alpha, x, beta);
    }
    fn axpy_v(&mut self, alpha: Self::T, x: &Self::View<'_>, beta: Self::T) {
        self.axpy(alpha, x, beta);
    }
    fn component_div_assign(&mut self, other: &Self) {
        self.component_div_assign(other);
    }
    fn component_mul_assign(&mut self, other: &Self) {
        self.component_mul_assign(other);
    }
    fn filter_indices<F: Fn(T) -> bool>(&self, f: F) -> Self::Index {
        let mut indices = vec![];
        for (i, &x) in self.iter().enumerate() {
            if f(x) {
                indices.push(i as IndexType);
            }
        }
        Self::Index::from_vec(indices)
    }
    fn gather_from(&mut self, other: &Self, indices: &Self::Index) {
        for (i, &index) in indices.iter().enumerate() {
            self[i] = other[index];
        }
    }
    fn scatter_from(&mut self, other: &Self, indices: &Self::Index) {
        for (i, &index) in indices.iter().enumerate() {
            self[index] = other[i];
        }
    }
    fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
        for &index in indices.iter() {
            self[index] = value;
        }
    }
    fn binary_fold<B, F>(&self, other: &Self, init: B, f: F) -> B
    where
        F: Fn(B, Self::T, Self::T, IndexType) -> B,
    {
        let mut acc = init;
        for (i, (x, y)) in self.iter().zip(other.iter()).enumerate() {
            acc = f(acc, *x, *y, i);
        }
        acc
    }
}

// tests
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_abs() {
        let v = DVector::from_vec(vec![1.0, -2.0, 3.0]);
        let v_abs = v.abs();
        assert_eq!(v_abs, DVector::from_vec(vec![1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_error_norm() {
        let v = DVector::from_vec(vec![1.0, -2.0, 3.0]);
        let y = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let atol = DVector::from_vec(vec![0.1, 0.2, 0.3]);
        let rtol = 0.1;
        let mut tmp = y.clone() * rtol;
        tmp += &atol;
        let mut r = v.clone();
        r.component_div_assign(&tmp);
        let errorn_check = r.norm_squared() / 3.0;
        assert_eq!(v.squared_norm(&y, &atol, rtol), errorn_check);
        let vview = v.as_view();
        assert_eq!(
            VectorView::squared_norm(&vview, &y, &atol, rtol),
            errorn_check
        );
    }
}