use crate::ode_state::ode_state_trait::{OdeState, StateIndex};
use linalg_traits::{Mat, Matrix};
#[cfg(feature = "nalgebra")]
use nalgebra::DMatrix;
#[cfg(feature = "ndarray")]
use ndarray::Array2;
#[cfg(feature = "faer")]
use faer::Mat as FMat;
/// Implements `OdeState` for one or more `f64` matrix types by forwarding each arithmetic
/// method to the type's `Matrix<f64>` implementation.
///
/// Every argument is a bare type constructor (for example `DMatrix`). The impl is generated for
/// `<Type><f64>`.
///
/// # Call-site requirements
///
/// The expansion names `OdeState`, `Matrix` and `StateIndex` without a path prefix. All three
/// must therefore be in scope wherever this macro is invoked. Each listed type must implement
/// `Matrix<f64>` and be indexable by the `(i, j)` pair carried in `StateIndex::Matrix`.
///
/// # Panics
///
/// The generated `get_state_variable` only accepts `StateIndex::Matrix(i, j)`. It panics when
/// given `StateIndex::Scalar()` or `StateIndex::Vector(_)`.
#[macro_export]
macro_rules! impl_ode_state_for_dmatrix {
    ($($matrix_type:ident),*) => {
        $(
            impl OdeState for $matrix_type<f64> {
                // Out-of-place operations: forward to the `Matrix<f64>` versions. Fully
                // qualified syntax is required because both traits define methods with these
                // names.
                fn add(&self, other: &Self) -> Self {
                    <Self as Matrix<f64>>::add(self, other)
                }
                fn sub(&self, other: &Self) -> Self {
                    <Self as Matrix<f64>>::sub(self, other)
                }
                fn mul(&self, scalar: f64) -> Self {
                    <Self as Matrix<f64>>::mul(self, scalar)
                }

                // In-place operations: same forwarding, mutating `self`.
                fn add_assign(&mut self, other: &Self) {
                    <Self as Matrix<f64>>::add_assign(self, other);
                }
                fn sub_assign(&mut self, other: &Self) {
                    <Self as Matrix<f64>>::sub_assign(self, other);
                }
                fn mul_assign(&mut self, scalar: f64) {
                    <Self as Matrix<f64>>::mul_assign(self, scalar);
                }

                /// Returns element `(row, col)`. Any non-matrix index is a caller bug and panics.
                fn get_state_variable(&self, index: StateIndex) -> f64 {
                    match index {
                        StateIndex::Matrix(row, col) => self[(row, col)],
                        StateIndex::Vector(_) => panic!("Cannot index a matrix ODE state with a StateIndex::Vector."),
                        StateIndex::Scalar() => panic!("Cannot index a matrix ODE state with a StateIndex::Scalar."),
                    }
                }
            }
        )*
    };
}
// linalg_traits' `Mat` is always available, so it always gets an `OdeState` impl.
impl_ode_state_for_dmatrix! { Mat }

// Optional backends: each impl is compiled only when its Cargo feature is enabled.
#[cfg(feature = "nalgebra")]
impl_ode_state_for_dmatrix! { DMatrix }
#[cfg(feature = "ndarray")]
impl_ode_state_for_dmatrix! { Array2 }
#[cfg(feature = "faer")]
impl_ode_state_for_dmatrix! { FMat }
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `OdeState` method on a 2x2 matrix of type `T` against hand-computed results.
    ///
    /// Covers `add`, `sub`, `mul`, their `*_assign` counterparts, and `get_state_variable` for
    /// every `StateIndex::Matrix(i, j)` of the matrix.
    fn ode_state_matrix_test_helper<T: OdeState + Matrix<f64>>() {
        let a = <T as Matrix<f64>>::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let b = <T as Matrix<f64>>::from_row_slice(2, 2, &[10.0, 8.0, 4.0, 2.0]);
        let c = 5.0;

        // Expected results, written out row by row.
        let a_plus_b = <T as Matrix<f64>>::from_row_slice(2, 2, &[11.0, 10.0, 7.0, 6.0]);
        let a_minus_b = <T as Matrix<f64>>::from_row_slice(2, 2, &[-9.0, -6.0, -1.0, 2.0]);
        let a_times_c = <T as Matrix<f64>>::from_row_slice(2, 2, &[5.0, 10.0, 15.0, 20.0]);

        // Use `OdeState::` paths explicitly so the `OdeState` impl is tested, not the
        // identically named `Matrix` methods.
        assert_eq!(OdeState::add(&a, &b), a_plus_b);
        let mut d = a.clone();
        OdeState::add_assign(&mut d, &b);
        assert_eq!(d, a_plus_b);

        assert_eq!(OdeState::sub(&a, &b), a_minus_b);
        d = a.clone();
        OdeState::sub_assign(&mut d, &b);
        assert_eq!(d, a_minus_b);

        assert_eq!(OdeState::mul(&a, c), a_times_c);
        d = a.clone();
        OdeState::mul_assign(&mut d, c);
        assert_eq!(d, a_times_c);

        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(a.get_state_variable(StateIndex::Matrix(i, j)), a[(i, j)]);
            }
        }
    }

    #[test]
    fn test_ode_state_linalg_traits_mat() {
        ode_state_matrix_test_helper::<Mat<f64>>();
    }

    #[test]
    #[should_panic(expected = "Cannot index a matrix ODE state with a StateIndex::Scalar.")]
    fn test_ode_state_matrix_scalar_index_panics() {
        let a = <Mat<f64> as Matrix<f64>>::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let _ = a.get_state_variable(StateIndex::Scalar());
    }

    #[test]
    #[should_panic(expected = "Cannot index a matrix ODE state with a StateIndex::Vector.")]
    fn test_ode_state_matrix_vector_index_panics() {
        let a = <Mat<f64> as Matrix<f64>>::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let _ = a.get_state_variable(StateIndex::Vector(0));
    }

    #[test]
    #[cfg(feature = "nalgebra")]
    fn test_ode_state_nalgebra_dmatrix() {
        ode_state_matrix_test_helper::<DMatrix<f64>>();
    }

    #[test]
    #[cfg(feature = "ndarray")]
    fn test_ode_state_ndarray_array2() {
        ode_state_matrix_test_helper::<Array2<f64>>();
    }

    #[test]
    #[cfg(feature = "faer")]
    fn test_ode_state_faer_mat() {
        // `FMat` is faer's `Mat`. Using the linalg_traits `Mat` here would silently skip the
        // faer impl.
        ode_state_matrix_test_helper::<FMat<f64>>();
    }
}