use std::sync::OnceLock;
use crate::cpu::features::CpuFeatures;
use crate::optimizer::scalar;
use crate::optimizer::simd::{sse2, avx, avx2};
/// Function-pointer signature of the add kernels handed out by
/// `Selector::get_add_fn`: two shared `f32` slices and one mutable `f32`
/// slice (passed as `a`, `b` and `out` by `Selector::dispatch_add`).
type AddFn = fn(&[f32], &[f32], &mut [f32]);
/// Identifies the kernel family that `Selector` dispatches to.
///
/// Variant order is part of the type's observable behaviour (it fixes the
/// discriminant values), so new variants should be appended, not inserted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DispatchPath {
    /// Portable fallback used when no SIMD feature flag is set.
    Scalar,
    /// SSE2 kernels.
    SSE2,
    /// AVX kernels.
    AVX,
    /// AVX2 kernels.
    AVX2,
    /// AVX-512 kernels (selected via the `avx512f` feature flag).
    AVX512,
    /// NEON kernels.
    Neon,
}
/// Stateless entry point for kernel selection.
///
/// The type carries no data; its associated functions are called as
/// `Selector::…` without constructing a value.
pub struct Selector;
/// Process-wide cache of the add kernel chosen by `Selector::get_add_fn`.
/// It is filled by the first call and only read afterwards.
static CACHED_ADD_FN: OnceLock<AddFn> = OnceLock::new();
impl Selector {
    /// Maps a set of CPU feature flags to the dispatch path to use.
    ///
    /// Flags are consulted in a fixed priority order — `avx512f`, `avx2`,
    /// `avx`, `sse2`, then `neon` — and the first flag that is set decides
    /// the result. [`DispatchPath::Scalar`] is returned when none is set.
    pub fn best_path(features: &CpuFeatures) -> DispatchPath {
        if features.avx512f {
            return DispatchPath::AVX512;
        }
        if features.avx2 {
            return DispatchPath::AVX2;
        }
        if features.avx {
            return DispatchPath::AVX;
        }
        if features.sse2 {
            return DispatchPath::SSE2;
        }
        if features.neon {
            return DispatchPath::Neon;
        }
        DispatchPath::Scalar
    }

    /// Returns the add kernel to use in this process.
    ///
    /// The first call runs `CpuFeatures::detect`, picks a path with
    /// [`Selector::best_path`] and stores the matching kernel in
    /// `CACHED_ADD_FN`; later calls return the stored pointer.
    ///
    /// SIMD kernels are only referenced on the architecture they are gated
    /// for (`x86_64` for the SSE2/AVX/AVX2/AVX-512 paths, `aarch64` for
    /// NEON). On any other target the same path resolves to
    /// `scalar::add_impl`.
    pub fn get_add_fn() -> AddFn {
        let kernel = CACHED_ADD_FN.get_or_init(|| {
            let detected = CpuFeatures::detect();
            match Self::best_path(&detected) {
                // x86_64 builds: each x86 path has a dedicated kernel.
                #[cfg(target_arch = "x86_64")]
                DispatchPath::AVX512 => crate::optimizer::simd::avx512::add_avx512_impl,
                #[cfg(target_arch = "x86_64")]
                DispatchPath::AVX2 => avx2::add_avx2_impl,
                #[cfg(target_arch = "x86_64")]
                DispatchPath::AVX => avx::add_avx_impl,
                #[cfg(target_arch = "x86_64")]
                DispatchPath::SSE2 => sse2::add_sse2_impl,
                // Other targets: the x86 paths have no kernel to call.
                #[cfg(not(target_arch = "x86_64"))]
                DispatchPath::AVX512
                | DispatchPath::AVX2
                | DispatchPath::AVX
                | DispatchPath::SSE2 => scalar::add_impl,

                // NEON is only wired up on aarch64.
                #[cfg(target_arch = "aarch64")]
                DispatchPath::Neon => crate::optimizer::simd::neon::add_neon_impl,
                #[cfg(not(target_arch = "aarch64"))]
                DispatchPath::Neon => scalar::add_impl,

                DispatchPath::Scalar => scalar::add_impl,
            }
        });
        *kernel
    }

    /// Runs the cached add kernel on `a`, `b` and `out`.
    ///
    /// The slices are forwarded unchanged; any length requirements are
    /// those of the selected kernel.
    pub fn dispatch_add(a: &[f32], b: &[f32], out: &mut [f32]) {
        (Self::get_add_fn())(a, b, out);
    }
}