use super::kernels::{self, CompressFn};
use crate::{
backend::cache::OnceCache,
platform::{Caps, caps},
};
/// Identifies one implementation of the block-compression kernel.
///
/// `Portable` exists on every target. Every other variant is compiled only for
/// the `target_arch` named in its `#[cfg]`. Whether a compiled-in variant may
/// actually be used is decided at runtime by checking its `required_caps`
/// against the host's capabilities.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// Downstream `match`es must include a wildcard arm, so new kernels can be
// added without a breaking change.
#[non_exhaustive]
pub enum KernelId {
/// Target-independent fallback; requires no CPU capabilities.
Portable,
/// AArch64 kernel gated on `caps::aarch64::NEON`.
#[cfg(target_arch = "aarch64")]
Aarch64Neon,
/// x86-64 kernel gated on `caps::x86::AVX2`.
#[cfg(target_arch = "x86_64")]
X86Avx2,
/// x86-64 kernel gated on both `caps::x86::AVX512F` and `caps::x86::AVX512VL`.
#[cfg(target_arch = "x86_64")]
X86Avx512,
/// POWER kernel gated on `caps::power::VSX`.
#[cfg(target_arch = "powerpc64")]
PowerVsx,
/// s390x kernel gated on `caps::s390x::VECTOR`.
#[cfg(target_arch = "s390x")]
S390xVector,
/// RISC-V 64 kernel gated on `caps::riscv::V`.
#[cfg(target_arch = "riscv64")]
Riscv64V,
/// wasm32 kernel gated on `caps::wasm::SIMD128`.
#[cfg(target_arch = "wasm32")]
WasmSimd128,
}
impl KernelId {
#[must_use]
pub const fn as_str(self) -> &'static str {
match self {
Self::Portable => "portable",
#[cfg(target_arch = "aarch64")]
Self::Aarch64Neon => "aarch64-neon",
#[cfg(target_arch = "x86_64")]
Self::X86Avx2 => "x86-avx2",
#[cfg(target_arch = "x86_64")]
Self::X86Avx512 => "x86-avx512",
#[cfg(target_arch = "powerpc64")]
Self::PowerVsx => "power-vsx",
#[cfg(target_arch = "s390x")]
Self::S390xVector => "s390x-vector",
#[cfg(target_arch = "riscv64")]
Self::Riscv64V => "riscv64-v",
#[cfg(target_arch = "wasm32")]
Self::WasmSimd128 => "wasm-simd128",
}
}
}
/// Every kernel compiled for the current target, in selection priority order
/// (most preferred first).
///
/// `active_kernel` walks this slice front to back and picks the first entry
/// whose `required_caps` the host satisfies, so the ORDER IS SIGNIFICANT.
/// `KernelId::Portable` must remain the last element. It requires no
/// capabilities, which guarantees the walk always finds a match.
pub const ALL_KERNELS: &[KernelId] = &[
// Prefer AVX-512 over AVX2 when both are available.
#[cfg(target_arch = "x86_64")]
KernelId::X86Avx512,
#[cfg(target_arch = "x86_64")]
KernelId::X86Avx2,
#[cfg(target_arch = "aarch64")]
KernelId::Aarch64Neon,
#[cfg(target_arch = "powerpc64")]
KernelId::PowerVsx,
#[cfg(target_arch = "s390x")]
KernelId::S390xVector,
#[cfg(target_arch = "riscv64")]
KernelId::Riscv64V,
#[cfg(target_arch = "wasm32")]
KernelId::WasmSimd128,
// Unconditional fallback; keep last.
KernelId::Portable,
];
/// Returns the CPU capabilities the host must report before `kernel` may be
/// selected.
///
/// `KernelId::Portable` needs nothing and yields an empty capability set.
/// The AVX-512 kernel needs AVX-512F and AVX-512VL together.
#[must_use]
pub const fn required_caps(kernel: KernelId) -> Caps {
    // Capability set with no bits set.
    const NOTHING: Caps = Caps::from_words([0; 4]);

    match kernel {
        #[cfg(target_arch = "x86_64")]
        KernelId::X86Avx512 => {
            crate::platform::caps::x86::AVX512F.union(crate::platform::caps::x86::AVX512VL)
        }
        #[cfg(target_arch = "x86_64")]
        KernelId::X86Avx2 => crate::platform::caps::x86::AVX2,
        #[cfg(target_arch = "aarch64")]
        KernelId::Aarch64Neon => crate::platform::caps::aarch64::NEON,
        #[cfg(target_arch = "powerpc64")]
        KernelId::PowerVsx => crate::platform::caps::power::VSX,
        #[cfg(target_arch = "s390x")]
        KernelId::S390xVector => crate::platform::caps::s390x::VECTOR,
        #[cfg(target_arch = "riscv64")]
        KernelId::Riscv64V => crate::platform::caps::riscv::V,
        #[cfg(target_arch = "wasm32")]
        KernelId::WasmSimd128 => crate::platform::caps::wasm::SIMD128,
        KernelId::Portable => NOTHING,
    }
}
/// Maps `kernel` to the compression function that implements it.
///
/// This performs no capability check of its own. Callers are expected to
/// confirm that the host satisfies `required_caps(kernel)` first, as
/// `active_kernel` and the tests in this file do. NOTE(review): invoking a
/// SIMD kernel on a host lacking those caps is presumably unsound — confirm
/// against the `CompressFn` safety contract.
#[inline]
#[must_use]
pub(super) fn compress_fn_for(kernel: KernelId) -> CompressFn {
    match kernel {
        #[cfg(target_arch = "x86_64")]
        KernelId::X86Avx512 => super::x86_64::compress_avx512,
        #[cfg(target_arch = "x86_64")]
        KernelId::X86Avx2 => super::x86_64::compress_avx2,
        #[cfg(target_arch = "aarch64")]
        KernelId::Aarch64Neon => super::aarch64::compress_neon,
        #[cfg(target_arch = "powerpc64")]
        KernelId::PowerVsx => super::power::compress_vsx,
        #[cfg(target_arch = "s390x")]
        KernelId::S390xVector => super::s390x::compress_vector,
        #[cfg(target_arch = "riscv64")]
        KernelId::Riscv64V => super::riscv64::compress_rvv,
        #[cfg(target_arch = "wasm32")]
        KernelId::WasmSimd128 => super::wasm::compress_simd128,
        KernelId::Portable => kernels::compress_portable,
    }
}
/// Cached result of kernel selection, filled in by `active_kernel`.
static ACTIVE_KERNEL: OnceCache<KernelId> = OnceCache::new();

/// Returns the preferred kernel for the running host.
///
/// Selection probes the host with `caps()` and keeps the first entry of
/// `ALL_KERNELS` (which is in priority order) whose `required_caps` the host
/// satisfies. If nothing matches, it falls back to `KernelId::Portable`.
/// The probe runs inside `ACTIVE_KERNEL.get_or_init`. Assuming `OnceCache`
/// initialises only once, it therefore happens only on first use.
#[must_use]
pub(super) fn active_kernel() -> KernelId {
    ACTIVE_KERNEL.get_or_init(|| {
        let host = caps();
        ALL_KERNELS
            .iter()
            .copied()
            .find(|&candidate| host.has(required_caps(candidate)))
            .unwrap_or(KernelId::Portable)
    })
}
/// Returns the compression function for the kernel chosen by `active_kernel`.
///
/// Marked `#[must_use]` like its siblings `compress_fn_for` and
/// `active_kernel`, because discarding the returned function pointer is
/// always a mistake.
#[inline]
#[must_use]
pub(super) fn active_compress() -> CompressFn {
    compress_fn_for(active_kernel())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn portable_has_no_required_caps() {
        assert!(required_caps(KernelId::Portable).is_empty());
    }

    #[test]
    fn all_kernels_terminate_on_portable() {
        assert_eq!(*ALL_KERNELS.last().unwrap(), KernelId::Portable);
    }

    #[test]
    fn portable_kernel_name() {
        assert_eq!(KernelId::Portable.as_str(), "portable");
    }

    #[test]
    fn kernel_names_are_unique() {
        for (i, a) in ALL_KERNELS.iter().enumerate() {
            for b in &ALL_KERNELS[i + 1..] {
                assert_ne!(a.as_str(), b.as_str(), "duplicate kernel name");
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn aarch64_neon_kernel_name() {
        assert_eq!(KernelId::Aarch64Neon.as_str(), "aarch64-neon");
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn aarch64_neon_required_caps_are_neon() {
        assert_eq!(
            required_caps(KernelId::Aarch64Neon),
            crate::platform::caps::aarch64::NEON
        );
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn x86_avx2_kernel_name() {
        assert_eq!(KernelId::X86Avx2.as_str(), "x86-avx2");
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn x86_avx2_required_caps_are_avx2() {
        assert_eq!(required_caps(KernelId::X86Avx2), crate::platform::caps::x86::AVX2);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn x86_avx512_kernel_name() {
        assert_eq!(KernelId::X86Avx512.as_str(), "x86-avx512");
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn x86_avx512_required_caps_are_f_and_vl() {
        let caps = required_caps(KernelId::X86Avx512);
        assert!(caps.has(crate::platform::caps::x86::AVX512F));
        assert!(caps.has(crate::platform::caps::x86::AVX512VL));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn x86_kernels_ordered_avx512_then_avx2_then_portable() {
        let avx512_pos = ALL_KERNELS
            .iter()
            .position(|&k| k == KernelId::X86Avx512)
            .expect("avx-512 in table");
        let avx2_pos = ALL_KERNELS
            .iter()
            .position(|&k| k == KernelId::X86Avx2)
            .expect("avx2 in table");
        let portable_pos = ALL_KERNELS
            .iter()
            .position(|&k| k == KernelId::Portable)
            .expect("portable in table");
        assert!(avx512_pos < avx2_pos);
        assert!(avx2_pos < portable_pos);
    }

    #[cfg(target_arch = "powerpc64")]
    #[test]
    fn power_vsx_kernel_name_and_caps() {
        assert_eq!(KernelId::PowerVsx.as_str(), "power-vsx");
        assert_eq!(required_caps(KernelId::PowerVsx), crate::platform::caps::power::VSX);
    }

    #[cfg(target_arch = "s390x")]
    #[test]
    fn s390x_vector_kernel_name_and_caps() {
        assert_eq!(KernelId::S390xVector.as_str(), "s390x-vector");
        assert_eq!(
            required_caps(KernelId::S390xVector),
            crate::platform::caps::s390x::VECTOR
        );
    }

    #[cfg(target_arch = "riscv64")]
    #[test]
    fn riscv64_v_kernel_name_and_caps() {
        assert_eq!(KernelId::Riscv64V.as_str(), "riscv64-v");
        assert_eq!(required_caps(KernelId::Riscv64V), crate::platform::caps::riscv::V);
    }

    #[cfg(target_arch = "wasm32")]
    #[test]
    fn wasm_simd128_kernel_name_and_caps() {
        assert_eq!(KernelId::WasmSimd128.as_str(), "wasm-simd128");
        assert_eq!(
            required_caps(KernelId::WasmSimd128),
            crate::platform::caps::wasm::SIMD128
        );
    }

    #[test]
    fn active_kernel_is_in_all_kernels() {
        let id = active_kernel();
        assert!(ALL_KERNELS.contains(&id));
    }

    /// The cached selection must equal the first host-supported entry of
    /// `ALL_KERNELS`, and the host must actually provide its capabilities.
    #[test]
    fn active_kernel_is_first_supported_kernel() {
        let host = crate::platform::caps();
        let expected = ALL_KERNELS
            .iter()
            .copied()
            .find(|&k| host.has(required_caps(k)))
            .expect("portable kernel has no requirements, so some kernel is always supported");
        let active = active_kernel();
        assert_eq!(active, expected, "active kernel {} is not the top supported choice", active.as_str());
        assert!(host.has(required_caps(active)));
    }

    #[test]
    fn forced_kernels_match_portable() {
        use super::super::BLOCK_WORDS;

        // `wrapping_mul` keeps both fixtures overflow-free in debug builds,
        // whatever the value of BLOCK_WORDS.
        let x: [u64; BLOCK_WORDS] =
            core::array::from_fn(|i| (i as u64).wrapping_mul(0x0101_0101_0101_0101));
        let y: [u64; BLOCK_WORDS] =
            core::array::from_fn(|i| (i as u64).wrapping_mul(0xdead_beef_feed_face));

        let mut expected = [0u64; BLOCK_WORDS];
        // SAFETY: the portable kernel has no CPU-capability requirement (see
        // `portable_has_no_required_caps`), and all three buffers are distinct,
        // fully initialised BLOCK_WORDS arrays.
        // NOTE(review): assumes that is the full `CompressFn` contract — confirm.
        unsafe { kernels::compress_portable(&mut expected, &x, &y, false) };

        // Host capabilities are invariant for the loop; probe them once.
        let host = crate::platform::caps();
        for &id in ALL_KERNELS {
            if !host.has(required_caps(id)) {
                continue;
            }
            let kernel = compress_fn_for(id);

            let mut got = [0u64; BLOCK_WORDS];
            // SAFETY: the host reports every capability `id` requires (checked
            // just above), and the buffers are distinct BLOCK_WORDS arrays.
            unsafe { kernel(&mut got, &x, &y, false) };
            assert_eq!(got, expected, "kernel {} diverged from portable", id.as_str());

            let mut acc = [0xa5a5_a5a5_a5a5_a5a5u64; BLOCK_WORDS];
            let mut expected_xor = acc;
            // SAFETY: same preconditions as above; the portable kernel needs no
            // capabilities, and `id`'s capabilities were checked.
            unsafe {
                kernels::compress_portable(&mut expected_xor, &x, &y, true);
                kernel(&mut acc, &x, &y, true);
            }
            assert_eq!(
                acc,
                expected_xor,
                "kernel {} xor_into diverged from portable",
                id.as_str()
            );
        }
    }
}