use std::marker::PhantomData;
/// Compile-time description of a SIMD instruction set in terms of its register width.
///
/// An implementor only has to supply [`REGISTER_SIZE`](Self::REGISTER_SIZE); the lane counts
/// for 8-, 16-, 32- and 64-bit elements are derived from it by the default constants below.
pub(crate) trait SIMDInstructionSet {
    /// Width of a single SIMD register, expressed in bits.
    const REGISTER_SIZE: usize;

    /// Number of 8-bit elements that fit in one register.
    const LANE_SIZE_8: usize = Self::REGISTER_SIZE / (u8::BITS as usize);

    /// Number of 16-bit elements that fit in one register.
    const LANE_SIZE_16: usize = Self::REGISTER_SIZE / (u16::BITS as usize);

    /// Number of 32-bit elements that fit in one register.
    const LANE_SIZE_32: usize = Self::REGISTER_SIZE / (u32::BITS as usize);

    /// Number of 64-bit elements that fit in one register.
    const LANE_SIZE_64: usize = Self::REGISTER_SIZE / (u64::BITS as usize);
}
/// Marker type naming the SSE instruction set.
///
/// `DTypeStrategy` is a type-level tag only: it is carried through [`PhantomData`], so a value
/// of this type holds no data.
pub struct SSE<DTypeStrategy> {
    /// Ties the marker to `DTypeStrategy` without storing a value of that type.
    pub(crate) _dtype_strategy: PhantomData<DTypeStrategy>,
}
/// Register description for [`SSE`]; the `LANE_SIZE_*` constants keep the trait defaults
/// derived from this width.
impl<DTypeStrategy> SIMDInstructionSet for SSE<DTypeStrategy> {
    /// Register width in bits.
    const REGISTER_SIZE: usize = 128;
}
/// Marker type naming the AVX2 instruction set.
///
/// `DTypeStrategy` is a type-level tag only: it is carried through [`PhantomData`], so a value
/// of this type holds no data.
pub struct AVX2<DTypeStrategy> {
    /// Ties the marker to `DTypeStrategy` without storing a value of that type.
    pub(crate) _dtype_strategy: PhantomData<DTypeStrategy>,
}
/// Register description for [`AVX2`]; the `LANE_SIZE_*` constants keep the trait defaults
/// derived from this width.
impl<DTypeStrategy> SIMDInstructionSet for AVX2<DTypeStrategy> {
    /// Register width in bits.
    const REGISTER_SIZE: usize = 256;
}
/// Marker type naming the AVX-512 instruction set.
///
/// Only compiled when the `nightly_simd` feature is enabled. `DTypeStrategy` is a type-level
/// tag only: it is carried through [`PhantomData`], so a value of this type holds no data.
#[cfg(feature = "nightly_simd")]
pub struct AVX512<DTypeStrategy> {
    /// Ties the marker to `DTypeStrategy` without storing a value of that type.
    pub(crate) _dtype_strategy: PhantomData<DTypeStrategy>,
}
/// Register description for [`AVX512`] (only compiled with the `nightly_simd` feature); the
/// `LANE_SIZE_*` constants keep the trait defaults derived from this width.
#[cfg(feature = "nightly_simd")]
impl<DTypeStrategy> SIMDInstructionSet for AVX512<DTypeStrategy> {
    /// Register width in bits.
    const REGISTER_SIZE: usize = 512;
}
/// Marker type naming the NEON instruction set.
///
/// `DTypeStrategy` is a type-level tag only: it is carried through [`PhantomData`], so a value
/// of this type holds no data.
pub struct NEON<DTypeStrategy> {
    /// Ties the marker to `DTypeStrategy` without storing a value of that type.
    pub(crate) _dtype_strategy: PhantomData<DTypeStrategy>,
}
/// Register description for [`NEON`]; the `LANE_SIZE_*` constants keep the trait defaults
/// derived from this width.
impl<DTypeStrategy> SIMDInstructionSet for NEON<DTypeStrategy> {
    /// Register width in bits.
    const REGISTER_SIZE: usize = 128;
}
// Checks that every instruction-set marker reports the expected number of lanes for each
// element width, for every dtype strategy that is compiled in.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype_strategy::*;

    use rstest::rstest;
    use rstest_reuse::{self, *};

    #[cfg(feature = "half")]
    use half::f16;

    // Template used when float support is compiled in: runs a test once per dtype strategy.
    // Each case label names the strategy it actually passes, so a failing test name points at
    // the right strategy.
    #[cfg(any(feature = "float", feature = "half"))]
    #[template]
    #[rstest]
    #[case::int(Int)]
    #[case::float_ignore_nan(FloatIgnoreNaN)]
    #[case::float_return_nan(FloatReturnNaN)]
    fn dtype_strategies<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {}

    // Template used when no float feature is enabled: only the integer strategy is exercised.
    #[cfg(not(any(feature = "float", feature = "half")))]
    #[template]
    #[rstest]
    #[case::int(Int)]
    fn dtype_strategies<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {}

    // 8-bit elements: lanes = register width in bits / 8.
    #[apply(dtype_strategies)]
    fn test_lane_size_8bit_dtype<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
        assert_eq!(SSE::<DTypeStrategy>::LANE_SIZE_8, 16);
        assert_eq!(AVX2::<DTypeStrategy>::LANE_SIZE_8, 32);
        #[cfg(feature = "nightly_simd")]
        assert_eq!(AVX512::<DTypeStrategy>::LANE_SIZE_8, 64);
        assert_eq!(NEON::<DTypeStrategy>::LANE_SIZE_8, 16);
    }

    // 16-bit elements: lanes = register width in bits / 16.
    #[apply(dtype_strategies)]
    fn test_lane_size_16bit_dtype<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
        assert_eq!(SSE::<DTypeStrategy>::LANE_SIZE_16, 8);
        assert_eq!(AVX2::<DTypeStrategy>::LANE_SIZE_16, 16);
        #[cfg(feature = "nightly_simd")]
        assert_eq!(AVX512::<DTypeStrategy>::LANE_SIZE_16, 32);
        assert_eq!(NEON::<DTypeStrategy>::LANE_SIZE_16, 8);
    }

    // 32-bit elements: lanes = register width in bits / 32.
    #[apply(dtype_strategies)]
    fn test_lane_size_32bit_dtype<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
        assert_eq!(SSE::<DTypeStrategy>::LANE_SIZE_32, 4);
        assert_eq!(AVX2::<DTypeStrategy>::LANE_SIZE_32, 8);
        #[cfg(feature = "nightly_simd")]
        assert_eq!(AVX512::<DTypeStrategy>::LANE_SIZE_32, 16);
        assert_eq!(NEON::<DTypeStrategy>::LANE_SIZE_32, 4);
    }

    // 64-bit elements: lanes = register width in bits / 64.
    #[apply(dtype_strategies)]
    fn test_lane_size_64bit_dtype<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
        assert_eq!(SSE::<DTypeStrategy>::LANE_SIZE_64, 2);
        assert_eq!(AVX2::<DTypeStrategy>::LANE_SIZE_64, 4);
        #[cfg(feature = "nightly_simd")]
        assert_eq!(AVX512::<DTypeStrategy>::LANE_SIZE_64, 8);
        assert_eq!(NEON::<DTypeStrategy>::LANE_SIZE_64, 2);
    }
}