use crate::simd_optimize::feature_detect::CpuFeatures;
/// The SIMD instruction-set tier used to run a kernel.
///
/// `Scalar` is the portable fallback; the remaining variants name x86
/// (`SSE` through `AVX512`) and Arm (`NEON`, `SVE`) vector extensions.
// Variant names deliberately mirror the vendor instruction-set spellings,
// so the all-caps acronyms are intentional and part of the public API.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SimdImplementation {
    /// No vector instructions: plain scalar code.
    Scalar,
    /// x86 SSE family.
    SSE,
    /// x86 AVX.
    AVX,
    /// x86 AVX2.
    AVX2,
    /// x86 AVX-512.
    AVX512,
    /// Arm NEON (Advanced SIMD).
    NEON,
    /// Arm Scalable Vector Extension.
    SVE,
}
impl SimdImplementation {
pub fn name(&self) -> &'static str {
match self {
SimdImplementation::Scalar => "Scalar",
SimdImplementation::SSE => "SSE",
SimdImplementation::AVX => "AVX",
SimdImplementation::AVX2 => "AVX2",
SimdImplementation::AVX512 => "AVX512",
SimdImplementation::NEON => "NEON",
SimdImplementation::SVE => "SVE",
}
}
pub fn is_avx2_or_better(&self) -> bool {
matches!(self, SimdImplementation::AVX2 | SimdImplementation::AVX512)
}
pub fn supports_fma(&self, features: &CpuFeatures) -> bool {
match self {
SimdImplementation::AVX2 | SimdImplementation::AVX512 => features.fma,
_ => false,
}
}
pub fn is_neon_or_better(&self) -> bool {
matches!(self, SimdImplementation::NEON | SimdImplementation::SVE)
}
pub fn vector_width(&self) -> usize {
match self {
SimdImplementation::Scalar => 0,
SimdImplementation::SSE => 128,
SimdImplementation::AVX => 256,
SimdImplementation::AVX2 => 256,
SimdImplementation::AVX512 => 512,
SimdImplementation::NEON => 128,
SimdImplementation::SVE => 128, }
}
}
/// Chooses the most preferred SIMD tier enabled in `features`.
///
/// The x86 tiers are considered first (AVX-512F, then AVX2, AVX, SSE2),
/// followed by Arm (SVE before NEON). When no flag is set the result is
/// [`SimdImplementation::Scalar`].
pub fn select_simd_implementation(features: &CpuFeatures) -> SimdImplementation {
    // Priority table, most preferred first: the first enabled entry wins.
    let ladder = [
        (features.avx512f, SimdImplementation::AVX512),
        (features.avx2, SimdImplementation::AVX2),
        (features.avx, SimdImplementation::AVX),
        (features.sse2, SimdImplementation::SSE),
        (features.sve, SimdImplementation::SVE),
        (features.neon, SimdImplementation::NEON),
    ];
    ladder
        .iter()
        .find(|(enabled, _)| *enabled)
        .map(|&(_, choice)| choice)
        .unwrap_or(SimdImplementation::Scalar)
}
/// Runs exactly one of the supplied closures and returns its result.
///
/// The closure is picked by [`select_simd_implementation`] for `features`.
/// The other closures are dropped without ever being called.
// Taking one closure per SIMD tier is the whole point of this API, so the
// long parameter list is intentional.
#[allow(clippy::too_many_arguments)]
pub fn apply_simd_strategy<T, S, SSE, AVX, AVX2, AVX512, NEON, SVE>(
    features: &CpuFeatures,
    scalar: S,
    sse: SSE,
    avx: AVX,
    avx2: AVX2,
    avx512: AVX512,
    neon: NEON,
    sve: SVE,
) -> T
where
    S: FnOnce() -> T,
    SSE: FnOnce() -> T,
    AVX: FnOnce() -> T,
    AVX2: FnOnce() -> T,
    AVX512: FnOnce() -> T,
    NEON: FnOnce() -> T,
    SVE: FnOnce() -> T,
{
    match select_simd_implementation(features) {
        SimdImplementation::Scalar => scalar(),
        SimdImplementation::SSE => sse(),
        SimdImplementation::AVX => avx(),
        SimdImplementation::AVX2 => avx2(),
        SimdImplementation::AVX512 => avx512(),
        SimdImplementation::NEON => neon(),
        SimdImplementation::SVE => sve(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each feature combination maps to the expected preferred tier.
    #[test]
    fn test_simd_selection() {
        let features = CpuFeatures::default();
        assert_eq!(
            select_simd_implementation(&features),
            SimdImplementation::Scalar
        );
        let features = CpuFeatures {
            sse2: true,
            ..Default::default()
        };
        assert_eq!(
            select_simd_implementation(&features),
            SimdImplementation::SSE
        );
        let features = CpuFeatures {
            sse2: true,
            avx: true,
            ..Default::default()
        };
        assert_eq!(
            select_simd_implementation(&features),
            SimdImplementation::AVX
        );
        let features = CpuFeatures {
            sse2: true,
            avx: true,
            avx2: true,
            ..Default::default()
        };
        assert_eq!(
            select_simd_implementation(&features),
            SimdImplementation::AVX2
        );
        let features = CpuFeatures {
            sse2: true,
            avx: true,
            avx2: true,
            avx512f: true,
            ..Default::default()
        };
        assert_eq!(
            select_simd_implementation(&features),
            SimdImplementation::AVX512
        );
        let features = CpuFeatures {
            neon: true,
            ..Default::default()
        };
        assert_eq!(
            select_simd_implementation(&features),
            SimdImplementation::NEON
        );
        let features = CpuFeatures {
            neon: true,
            sve: true,
            ..Default::default()
        };
        assert_eq!(
            select_simd_implementation(&features),
            SimdImplementation::SVE
        );
    }

    /// `apply_simd_strategy` must invoke the closure matching every tier,
    /// including the AVX-512 and Arm paths that were previously unexercised.
    #[test]
    fn test_simd_strategy() {
        let apply_test = |features: &CpuFeatures| {
            apply_simd_strategy(
                features,
                || "scalar",
                || "sse",
                || "avx",
                || "avx2",
                || "avx512",
                || "neon",
                || "sve",
            )
        };
        let features = CpuFeatures::default();
        assert_eq!(apply_test(&features), "scalar");
        let features = CpuFeatures {
            sse2: true,
            ..Default::default()
        };
        assert_eq!(apply_test(&features), "sse");
        let features = CpuFeatures {
            sse2: true,
            avx: true,
            ..Default::default()
        };
        assert_eq!(apply_test(&features), "avx");
        let features = CpuFeatures {
            sse2: true,
            avx: true,
            avx2: true,
            ..Default::default()
        };
        assert_eq!(apply_test(&features), "avx2");
        let features = CpuFeatures {
            sse2: true,
            avx: true,
            avx2: true,
            avx512f: true,
            ..Default::default()
        };
        assert_eq!(apply_test(&features), "avx512");
        let features = CpuFeatures {
            neon: true,
            ..Default::default()
        };
        assert_eq!(apply_test(&features), "neon");
        let features = CpuFeatures {
            neon: true,
            sve: true,
            ..Default::default()
        };
        assert_eq!(apply_test(&features), "sve");
    }

    /// Labels and reported widths for every variant.
    #[test]
    fn test_names_and_widths() {
        let expected = [
            (SimdImplementation::Scalar, "Scalar", 0),
            (SimdImplementation::SSE, "SSE", 128),
            (SimdImplementation::AVX, "AVX", 256),
            (SimdImplementation::AVX2, "AVX2", 256),
            (SimdImplementation::AVX512, "AVX512", 512),
            (SimdImplementation::NEON, "NEON", 128),
            (SimdImplementation::SVE, "SVE", 128),
        ];
        for (implementation, name, width) in expected {
            assert_eq!(implementation.name(), name);
            assert_eq!(implementation.vector_width(), width);
        }
    }

    /// Capability predicates, including FMA gating on the CPU feature flag.
    #[test]
    fn test_capability_predicates() {
        assert!(SimdImplementation::AVX2.is_avx2_or_better());
        assert!(SimdImplementation::AVX512.is_avx2_or_better());
        assert!(!SimdImplementation::AVX.is_avx2_or_better());
        assert!(!SimdImplementation::NEON.is_avx2_or_better());

        assert!(SimdImplementation::NEON.is_neon_or_better());
        assert!(SimdImplementation::SVE.is_neon_or_better());
        assert!(!SimdImplementation::SSE.is_neon_or_better());
        assert!(!SimdImplementation::Scalar.is_neon_or_better());

        let with_fma = CpuFeatures {
            fma: true,
            ..Default::default()
        };
        let without_fma = CpuFeatures::default();
        assert!(SimdImplementation::AVX2.supports_fma(&with_fma));
        assert!(SimdImplementation::AVX512.supports_fma(&with_fma));
        assert!(!SimdImplementation::AVX2.supports_fma(&without_fma));
        // AVX without AVX2 never reports FMA, even if the CPU has it.
        assert!(!SimdImplementation::AVX.supports_fma(&with_fma));
        assert!(!SimdImplementation::NEON.supports_fma(&with_fma));
    }
}