use crate::ode_state::ode_state_trait::{OdeState, StateIndex};
use linalg_traits::Vector;
#[cfg(feature = "nalgebra")]
use nalgebra::DVector;
#[cfg(feature = "ndarray")]
use ndarray::Array1;
/// Implements `OdeState` for one or more dynamically-sized vector types with `f64` elements.
///
/// Each argument is the bare identifier of a generic container (for example `Vec`); the
/// implementation is generated for `$type<f64>`. Every arithmetic method is forwarded to the
/// type's `Vector<f64>` implementation, so that trait must already be implemented for
/// `$type<f64>`.
///
/// The list of types is comma-separated and an optional trailing comma is accepted.
///
/// The expansion refers to `OdeState`, `StateIndex`, and `Vector` by their bare names (paths in
/// `macro_rules!` expansions are resolved at the invocation site), so all three must be in scope
/// wherever this macro is invoked.
///
/// # Panics
///
/// The generated `get_state_variable` panics when called with `StateIndex::Scalar` or
/// `StateIndex::Matrix`, since neither can address an element of a vector. For
/// `StateIndex::Vector(i)` the element is read with `self[i]`, so an out-of-range `i` behaves
/// however the type's indexing does (for `Vec` that is a panic).
#[macro_export]
macro_rules! impl_ode_state_for_dvector {
    ($($type:ident),* $(,)?) => {
        $(
            impl OdeState for $type<f64> {
                // The method names below also exist on `Vector<f64>`, so each call is written in
                // fully-qualified form to select the `Vector` implementation explicitly.

                fn add(&self, other: &Self) -> Self {
                    <Self as Vector<f64>>::add(self, other)
                }

                fn add_assign(&mut self, other: &Self) {
                    <Self as Vector<f64>>::add_assign(self, other);
                }

                fn sub(&self, other: &Self) -> Self {
                    <Self as Vector<f64>>::sub(self, other)
                }

                fn sub_assign(&mut self, other: &Self) {
                    <Self as Vector<f64>>::sub_assign(self, other);
                }

                fn mul(&self, scalar: f64) -> Self {
                    <Self as Vector<f64>>::mul(self, scalar)
                }

                fn mul_assign(&mut self, scalar: f64) {
                    <Self as Vector<f64>>::mul_assign(self, scalar);
                }

                fn get_state_variable(&self, index: StateIndex) -> f64 {
                    // Only a one-dimensional index is meaningful for a vector state; the other
                    // index kinds are treated as caller errors.
                    match index {
                        StateIndex::Scalar() => panic!("Cannot index a vector ODE state with a StateIndex::Scalar."),
                        StateIndex::Vector(i) => self[i],
                        StateIndex::Matrix(_, _) => panic!("Cannot index a vector ODE state with a StateIndex::Matrix.")
                    }
                }
            }
        )*
    };
}
// `Vec<f64>` is supported unconditionally. The generated impl forwards to `Vector<f64>`, so each
// type listed here must implement that trait.
impl_ode_state_for_dvector!(Vec);
// `DVector<f64>` support is compiled only when the `nalgebra` feature is enabled.
#[cfg(feature = "nalgebra")]
impl_ode_state_for_dvector!(DVector);
// `Array1<f64>` support is compiled only when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
impl_ode_state_for_dvector!(Array1);
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises every `OdeState` method on a vector type `T` and compares the results against
    /// hand-computed values.
    ///
    /// The same helper is reused for each supported vector type so that all implementations are
    /// held to identical expectations.
    fn ode_state_vector_test_helper<T: OdeState + Vector<f64>>() {
        // Raw element values are kept so that `get_state_variable` can be checked against them
        // directly.
        let a_values: [f64; 4] = [1.0, 2.0, 3.0, 4.0];

        let a = <T as Vector<f64>>::from_slice(&a_values);
        let b = <T as Vector<f64>>::from_slice(&[10.0, 8.0, 4.0, 2.0]);
        let c = 5.0;

        // Expected results.
        let a_plus_b = <T as Vector<f64>>::from_slice(&[11.0, 10.0, 7.0, 6.0]);
        let a_minus_b = <T as Vector<f64>>::from_slice(&[-9.0, -6.0, -1.0, 2.0]);
        let a_times_c = <T as Vector<f64>>::from_slice(&[5.0, 10.0, 15.0, 20.0]);

        // Addition.
        assert_eq!(OdeState::add(&a, &b), a_plus_b);
        let mut d = a.clone();
        OdeState::add_assign(&mut d, &b);
        assert_eq!(d, a_plus_b);

        // Subtraction.
        assert_eq!(OdeState::sub(&a, &b), a_minus_b);
        d = a.clone();
        OdeState::sub_assign(&mut d, &b);
        assert_eq!(d, a_minus_b);

        // Scalar multiplication.
        assert_eq!(OdeState::mul(&a, c), a_times_c);
        d = a.clone();
        OdeState::mul_assign(&mut d, c);
        assert_eq!(d, a_times_c);

        // Element access: check every element of the vector, not just a prefix.
        for (i, &expected) in a_values.iter().enumerate() {
            let actual = a.get_state_variable(StateIndex::Vector(i));
            assert_eq!(actual, expected);
            assert_eq!(actual, a.vget(i));
        }
    }

    #[test]
    fn test_ode_state_vec() {
        ode_state_vector_test_helper::<Vec<f64>>();
    }

    #[test]
    #[cfg(feature = "nalgebra")]
    fn test_ode_state_nalgebra_dvector() {
        ode_state_vector_test_helper::<DVector<f64>>();
    }

    #[test]
    #[cfg(feature = "ndarray")]
    fn test_ode_state_ndarray_array1() {
        ode_state_vector_test_helper::<Array1<f64>>();
    }

    #[test]
    #[should_panic(expected = "Cannot index a vector ODE state with a StateIndex::Scalar.")]
    fn test_ode_state_vec_scalar_index_panics() {
        let state: Vec<f64> = vec![1.0, 2.0];
        let _ = state.get_state_variable(StateIndex::Scalar());
    }

    #[test]
    #[should_panic(expected = "Cannot index a vector ODE state with a StateIndex::Matrix.")]
    fn test_ode_state_vec_matrix_index_panics() {
        let state: Vec<f64> = vec![1.0, 2.0];
        let _ = state.get_state_variable(StateIndex::Matrix(0, 0));
    }
}