use crate::ode_state::ode_state_trait::{OdeState, StateIndex};
use num_traits::ToPrimitive;
/// Implements `OdeState` for one or more scalar types.
///
/// Every generated method forwards to the scalar's own arithmetic operators, so
/// each listed type must support `&T + &T`, `&T - &T`, `&T * f64`, the compound
/// assignments `+= &T`, `-= &T`, `*= f64`, and must implement `ToPrimitive`.
///
/// The expansion names `OdeState`, `StateIndex` and `ToPrimitive` without a
/// path, so all three must be in scope wherever the macro is invoked.
///
/// The argument is a comma-separated list of types; a trailing comma is accepted.
///
/// # Panics
///
/// The generated `get_state_variable` panics when it is called with
/// `StateIndex::Vector` or `StateIndex::Matrix`, or when the value cannot be
/// converted to `f64`.
#[macro_export]
macro_rules! impl_ode_state_for_scalar {
    ($($type:ty),* $(,)?) => {
        $(
            impl OdeState for $type {
                /// Returns `self + other` as a new value.
                fn add(&self, other: &Self) -> Self {
                    self + other
                }

                /// Adds `other` to `self` in place.
                fn add_assign(&mut self, other: &Self) {
                    *self += other;
                }

                /// Returns `self - other` as a new value.
                fn sub(&self, other: &Self) -> Self {
                    self - other
                }

                /// Subtracts `other` from `self` in place.
                fn sub_assign(&mut self, other: &Self) {
                    *self -= other;
                }

                /// Returns `self` scaled by `scalar` as a new value.
                fn mul(&self, scalar: f64) -> Self {
                    self * scalar
                }

                /// Scales `self` by `scalar` in place.
                fn mul_assign(&mut self, scalar: f64) {
                    *self *= scalar;
                }

                /// Returns the scalar itself as an `f64`.
                ///
                /// Only `StateIndex::Scalar()` is meaningful for a scalar state;
                /// the other index kinds panic.
                fn get_state_variable(&self, index: StateIndex) -> f64 {
                    match index {
                        StateIndex::Scalar() => <Self as ToPrimitive>::to_f64(self)
                            .expect("scalar ODE state is not representable as f64"),
                        StateIndex::Vector(_) => panic!("Cannot index a scalar ODE state with a StateIndex::Vector."),
                        StateIndex::Matrix(_, _) => panic!("Cannot index a scalar ODE state with a StateIndex::Matrix."),
                    }
                }
            }
        )*
    };
}
// Provide the scalar `OdeState` implementation for `f64`.
impl_ode_state_for_scalar! { f64 }
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks each `OdeState` operation on `f64`, in both its value-returning
    /// and in-place form, against hand-computed results.
    #[test]
    fn test_ode_state_f64() {
        let lhs: f64 = 2.0;
        let rhs: f64 = 6.0;
        let factor: f64 = 5.0;

        let expected_sum: f64 = 8.0;
        let expected_difference: f64 = -4.0;
        let expected_product: f64 = 10.0;

        // Addition: by value, then in place.
        assert_eq!(OdeState::add(&lhs, &rhs), expected_sum);
        let mut sum = lhs;
        OdeState::add_assign(&mut sum, &rhs);
        assert_eq!(sum, expected_sum);

        // Subtraction: by value, then in place.
        assert_eq!(OdeState::sub(&lhs, &rhs), expected_difference);
        let mut difference = lhs;
        OdeState::sub_assign(&mut difference, &rhs);
        assert_eq!(difference, expected_difference);

        // Scaling: by value, then in place.
        assert_eq!(OdeState::mul(&lhs, factor), expected_product);
        let mut product = lhs;
        OdeState::mul_assign(&mut product, factor);
        assert_eq!(product, expected_product);

        // A scalar state read through the scalar index yields the value itself.
        assert_eq!(lhs.get_state_variable(StateIndex::Scalar()), lhs);
    }
}