#![cfg(target_arch = "aarch64")]
/// Width in bytes of one sub-vector step: a single 128-bit register's worth of lanes.
const SUB_VEC_BYTES: usize = 16;
/// Number of sub-vector steps per chunk. A chunk produces a `u64` mask with one
/// bit per byte, so the chunk width equals the bit width of `u64` (64).
const SUB_VECS_PER_CHUNK: usize = u64::BITS as usize / SUB_VEC_BYTES;
/// Compares the 16 bytes at `ptr` against `bound` (unsigned `>=`) and packs the
/// per-byte results into the low 16 bits of the return value.
///
/// The low byte of the result comes from bytes 0..8 and the next byte from
/// bytes 8..16. Each matching lane contributes its weight from
/// `super::MOVEMASK_BIT_MASK`. NOTE(review): that table is defined outside this
/// file; the weights are presumably `1, 2, 4, …, 128` for each 8-byte half, so
/// byte `k` maps to bit `k`. Confirm against the definition.
///
/// # Safety
///
/// - `ptr` must be valid for reads of `SUB_VEC_BYTES` (16) bytes.
/// - `super::MOVEMASK_BIT_MASK` must provide at least 16 readable bytes (it is
///   loaded with a 128-bit `ldr q3`).
/// - The executing CPU must support SVE2.
#[target_feature(enable = "sve2")]
#[inline]
unsafe fn simd_cmpge_mask_sve(ptr: *const u8, bound: u8) -> u32 {
    let lo: u64;
    let hi: u64;
    // SAFETY: the caller guarantees 16 readable bytes at `ptr`. The `whilelo`
    // predicate limits the SVE load to exactly those 16 lanes whatever the
    // hardware vector length. Every vector and predicate register the asm writes
    // is declared as clobbered below.
    unsafe {
        core::arch::asm!(
            // p0 = first 16 byte lanes active (WHILELO writes NZCV).
            "whilelo p0.b, xzr, {sixteen:x}",
            // z1 = `bound` broadcast to every byte lane.
            "mov z1.b, {bound:w}",
            "ld1b {{z0.b}}, p0/z, [{ptr}]",
            // p1 = active lanes with z0 >= z1, unsigned (CMPHS also writes NZCV).
            "cmphs p1.b, p0/z, z0.b, z1.b",
            // z2 = 0xFF in matching lanes, 0 elsewhere.
            "mov z2.b, p1/z, #-1",
            // Keep only each matching lane's weight from the table.
            "ldr q3, [{bm_ptr}]",
            "and v2.16b, v2.16b, v3.16b",
            // Horizontally add each 8-byte half into a single byte.
            "addv b4, v2.8b",
            "umov {lo:w}, v4.b[0]",
            "ext v5.16b, v2.16b, v2.16b, #8",
            "addv b5, v5.8b",
            "umov {hi:w}, v5.b[0]",
            ptr = in(reg) ptr,
            bound = in(reg) bound,
            bm_ptr = in(reg) super::MOVEMASK_BIT_MASK.as_ptr(),
            sixteen = in(reg) SUB_VEC_BYTES as u64,
            lo = lateout(reg) lo,
            hi = lateout(reg) hi,
            out("z0") _, out("z1") _, out("z2") _,
            out("v3") _, out("v4") _, out("v5") _,
            out("p0") _, out("p1") _,
            // No `preserves_flags`: `whilelo` and `cmphs` both overwrite NZCV,
            // so promising to preserve the flags would be undefined behaviour.
            options(nostack, readonly),
        );
    }
    ((lo as u32) & 0xFF) | (((hi as u32) & 0xFF) << 8)
}
/// Returns `true` when at least one of the 16 bytes at `ptr` is `>= bound`
/// (unsigned comparison).
///
/// `scan_chunk` calls this as a cheap early-out. It only builds the full
/// per-byte mask with `simd_cmpge_mask_sve` when this returns `true`.
///
/// # Safety
///
/// - `ptr` must be valid for reads of `SUB_VEC_BYTES` (16) bytes. The `whilelo`
///   predicate limits the load to those lanes even on wider SVE implementations.
/// - The executing CPU must support SVE2.
#[target_feature(enable = "sve2")]
#[inline]
unsafe fn simd_any_ge_sve(ptr: *const u8, bound: u8) -> bool {
let any: u64;
// SAFETY: the caller guarantees 16 readable bytes at `ptr`, and the load is
// predicated to exactly those lanes. Every vector and predicate register the
// asm writes is listed as clobbered.
unsafe {
core::arch::asm!(
// p0 = first 16 byte lanes active.
"whilelo p0.b, xzr, {sixteen:x}",
// z1 = `bound` broadcast to every byte lane.
"mov z1.b, {bound:w}",
"ld1b {{z0.b}}, p0/z, [{ptr}]",
// p1 = active lanes whose loaded byte is >= bound (unsigned "higher or same").
"cmphs p1.b, p0/z, z0.b, z1.b",
// Z is clear iff some lane of p1 is active under p0, so `ne` means
// "at least one byte matched".
"ptest p0, p1.b",
"cset {any:x}, ne",
ptr = in(reg) ptr,
bound = in(reg) bound,
sixteen = in(reg) SUB_VEC_BYTES as u64,
any = lateout(reg) any,
out("z0") _, out("z1") _,
out("p0") _, out("p1") _,
// `preserves_flags` is deliberately absent: `whilelo`, `cmphs` and `ptest`
// all overwrite NZCV.
options(nostack, readonly),
);
}
any != 0
}
/// Scans one 64-byte chunk and returns a mask with bit `k` set when byte `k`
/// is `>= bound` (unsigned comparison).
///
/// The chunk is split into `SUB_VECS_PER_CHUNK` sub-vectors of `SUB_VEC_BYTES`
/// bytes each. A sub-vector with no qualifying byte is skipped after the cheap
/// `simd_any_ge_sve` test, which leaves its 16 bits of the mask zero.
///
/// # Safety
///
/// - `ptr` must be valid for reads of `SUB_VECS_PER_CHUNK * SUB_VEC_BYTES`
///   (64) bytes.
/// - The executing CPU must support SVE2.
#[target_feature(enable = "sve2")]
#[inline]
pub(crate) unsafe fn scan_chunk(ptr: *const u8, bound: u8) -> u64 {
    let mut bits = 0u64;
    for lane_group in 0..SUB_VECS_PER_CHUNK {
        let offset = lane_group * SUB_VEC_BYTES;
        // SAFETY: `offset + 16 <= 64` and the caller guarantees 64 readable bytes.
        let sub_ptr = unsafe { ptr.add(offset) };
        // SAFETY: `sub_ptr` has 16 readable bytes and SVE2 is enabled for this fn.
        let hit = unsafe { simd_any_ge_sve(sub_ptr, bound) };
        if !hit {
            continue;
        }
        // SAFETY: same preconditions as the `simd_any_ge_sve` call above.
        let sub_bits = unsafe { simd_cmpge_mask_sve(sub_ptr, bound) };
        // The sub-vector's 16-bit mask lands at bit position `offset`.
        bits |= u64::from(sub_bits) << offset;
    }
    bits
}
/// Calls the streaming prefetch helpers on `prefetch_l1` and `prefetch_l2`,
/// then scans the 64-byte chunk at `ptr` exactly like [`scan_chunk`].
///
/// NOTE(review): `#[inline(never)]` presumably keeps this out of line as a
/// distinct symbol (e.g. for profiling); confirm the intent.
///
/// # Safety
///
/// - `ptr` must satisfy [`scan_chunk`]'s contract (64 readable bytes).
/// - `prefetch_l1` / `prefetch_l2` must satisfy whatever
///   `crate::simd::prefetch::prefetch_l1_stream` / `prefetch_l2_stream` require
///   (defined outside this file).
/// - The executing CPU must support SVE2.
// The lint is silenced because this entry point is not referenced in every
// build configuration; the tests in this file call it.
#[allow(dead_code)]
#[target_feature(enable = "sve2")]
#[inline(never)]
pub(crate) unsafe fn scan_and_prefetch(
    ptr: *const u8,
    prefetch_l1: *const u8,
    prefetch_l2: *const u8,
    bound: u8,
) -> u64 {
    // SAFETY: the caller upholds the prefetch helpers' contracts for both hint
    // pointers and `scan_chunk`'s contract for `ptr`; SVE2 is enabled here.
    unsafe {
        crate::simd::prefetch::prefetch_l1_stream(prefetch_l1);
        crate::simd::prefetch::prefetch_l2_stream(prefetch_l2);
        scan_chunk(ptr, bound)
    }
}
#[cfg(all(test, target_arch = "aarch64"))]
mod compile_smoke {
    // Expected signatures of the SVE2 entry points. The tests below fail to
    // compile if either function's signature drifts.
    type ScanFn = unsafe fn(*const u8, u8) -> u64;
    type ScanPrefetchFn = unsafe fn(*const u8, *const u8, *const u8, u8) -> u64;

    /// Type-checks the SVE2 entry points without executing any SVE2 code, so it
    /// runs even when SVE2 is not enabled at compile time.
    #[test]
    fn sve2_symbols_compile() {
        let scan: ScanFn = super::scan_chunk;
        let scan_prefetch: ScanPrefetchFn = super::scan_and_prefetch;
        let _ = (scan, scan_prefetch);
    }
}
// Runtime checks for the SVE2 scanner. They are only built when SVE2 is enabled
// at compile time, so the asm can actually execute on the test host.
#[cfg(all(test, target_arch = "aarch64", target_feature = "sve2"))]
mod tests {
    use super::*;

    #[test]
    fn sve2_scan_all_below() {
        let data = [0x41u8; 64];
        let mask = unsafe { scan_chunk(data.as_ptr(), 0xC0) };
        assert_eq!(mask, 0);
    }

    #[test]
    fn sve2_scan_all_above() {
        let data = [0xFFu8; 64];
        let mask = unsafe { scan_chunk(data.as_ptr(), 0xC0) };
        assert_eq!(mask, u64::MAX);
    }

    #[test]
    fn sve2_scan_at_bound() {
        let data = [0xC0u8; 64];
        let mask = unsafe { scan_chunk(data.as_ptr(), 0xC0) };
        assert_eq!(mask, u64::MAX);
    }

    /// One below the bound must not match; this catches a `>=` vs `>` mix-up.
    #[test]
    fn sve2_scan_just_below_bound() {
        let data = [0xBFu8; 64];
        let mask = unsafe { scan_chunk(data.as_ptr(), 0xC0) };
        assert_eq!(mask, 0);
    }

    /// 0x80 >= 0x80 but 0x7F < 0x80 when compared unsigned. A signed compare
    /// would also flag the 0x7F bytes (127 >= -128).
    #[test]
    fn sve2_scan_unsigned_across_sign_bit() {
        let mut data = [0x7Fu8; 64];
        for byte in data.iter_mut().step_by(2) {
            *byte = 0x80;
        }
        let mask = unsafe { scan_chunk(data.as_ptr(), 0x80) };
        assert_eq!(mask, 0x5555_5555_5555_5555);
    }

    /// With the maximum bound, only 0xFF bytes qualify.
    #[test]
    fn sve2_scan_bound_max() {
        let mut data = [0xFEu8; 64];
        data[5] = 0xFF;
        data[40] = 0xFF;
        let mask = unsafe { scan_chunk(data.as_ptr(), 0xFF) };
        assert_eq!(mask, (1u64 << 5) | (1u64 << 40));
    }

    #[test]
    fn sve2_scan_mixed() {
        let mut data = [0x41u8; 64];
        data[0] = 0xC0;
        data[15] = 0xC0;
        data[16] = 0xC0;
        data[63] = 0xFF;
        let mask = unsafe { scan_chunk(data.as_ptr(), 0xC0) };
        let expected = (1u64 << 0) | (1u64 << 15) | (1u64 << 16) | (1u64 << 63);
        assert_eq!(mask, expected);
    }

    #[test]
    fn sve2_scan_every_position() {
        for pos in 0..64 {
            let mut chunk = [0u8; 64];
            chunk[pos] = 0xC0;
            let mask = unsafe { scan_chunk(chunk.as_ptr(), 0xC0) };
            assert_eq!(mask, 1u64 << pos, "SVE2: Expected only bit {pos} set");
        }
    }

    #[test]
    fn sve2_scan_bound_zero() {
        let data = [0x00u8; 64];
        let mask = unsafe { scan_chunk(data.as_ptr(), 0x00) };
        assert_eq!(mask, u64::MAX);
    }

    #[test]
    fn sve2_scan_and_prefetch_matches() {
        let mut data = [0x30u8; 64];
        data[7] = 0xD0;
        data[31] = 0xE5;
        let dummy = data.as_ptr();
        let m1 = unsafe { scan_chunk(data.as_ptr(), 0xC0) };
        let m2 = unsafe { scan_and_prefetch(data.as_ptr(), dummy, dummy, 0xC0) };
        assert_eq!(m1, m2, "Prefetch variant must produce identical bitmask");
    }

    #[test]
    fn sve2_matches_scalar() {
        let mut chunk = [0u8; 64];
        for (i, byte) in chunk.iter_mut().enumerate() {
            *byte = (i as u8).wrapping_mul(7);
        }
        let sve_mask = unsafe { scan_chunk(chunk.as_ptr(), 0xC0) };
        let scalar_mask = unsafe { crate::simd::scalar::scan_chunk(chunk.as_ptr(), 0xC0) };
        assert_eq!(sve_mask, scalar_mask, "SVE2 must match scalar");
    }

    #[test]
    fn sve2_matches_neon() {
        let mut chunk = [0u8; 64];
        for (i, byte) in chunk.iter_mut().enumerate() {
            *byte = (i as u8).wrapping_mul(13).wrapping_add(0x80);
        }
        let sve_mask = unsafe { scan_chunk(chunk.as_ptr(), 0xC0) };
        let neon_mask = unsafe { crate::simd::aarch64::neon::scan_chunk(chunk.as_ptr(), 0xC0) };
        assert_eq!(sve_mask, neon_mask, "SVE2 and NEON must agree");
    }

    #[test]
    fn sve2_any_ge_helper() {
        unsafe {
            let zeros = [0x00u8; 16];
            let ones = [0xFFu8; 16];
            let bound = [0xC0u8; 16];
            assert!(!simd_any_ge_sve(zeros.as_ptr(), 0xC0));
            assert!(simd_any_ge_sve(ones.as_ptr(), 0xC0));
            assert!(simd_any_ge_sve(bound.as_ptr(), 0xC0));
        }
    }

    /// Exercises the 16-lane mask helper directly, including both 8-byte halves
    /// and their edge lanes (0, 7, 8, 15).
    #[test]
    fn sve2_cmpge_mask_helper() {
        let mut data = [0x00u8; 16];
        data[0] = 0xC0;
        data[7] = 0xFF;
        data[8] = 0xC1;
        data[15] = 0xC0;
        let mask = unsafe { simd_cmpge_mask_sve(data.as_ptr(), 0xC0) };
        assert_eq!(mask, 0x8181);
    }
}