#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
/// Sixteen bytes held as a single unit so they can all be compared in one step.
///
/// x86_64 representation: the bytes live in an SSE `__m128i` vector.
#[cfg(target_arch = "x86_64")]
#[derive(Copy, Clone)]
pub(crate) struct Group(__m128i);
/// Sixteen bytes held as a single unit so they can all be compared in one step.
///
/// aarch64 representation: the bytes live in a NEON `uint8x16_t` vector.
#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
pub(crate) struct Group(uint8x16_t);
/// Sixteen bytes held as a single unit so they can all be compared in one step.
///
/// Portable representation for every other architecture: a plain byte array
/// that is scanned with scalar code.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
#[derive(Copy, Clone)]
pub(crate) struct Group([u8; 16]);
/// Set of slots that matched in a [`Group`], packed into a `u64`.
///
/// The raw bit layout depends on the target: `Group::match_byte` produces one
/// bit per slot (bit `i` for slot `i`) on x86_64 and in the portable fallback,
/// and one bit per 4-bit nibble (bit `4 * i` for slot `i`) on aarch64. Raw bit
/// positions are translated to slot indices through `BIT_TO_SLOT`, so the
/// inner value should not be interpreted directly.
#[derive(Copy, Clone)]
pub(crate) struct BitMask(u64);
impl Group {
    /// Number of bytes held by one group.
    // NOTE(review): `dead_code` is presumably allowed because the constant has
    // no reader in some build configurations — confirm before removing.
    #[allow(dead_code)]
    pub const WIDTH: usize = 16;

    /// Reads 16 bytes starting at `ptr` into a group.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reading 16 consecutive bytes. No particular
    /// alignment is required: every code path performs an unaligned read.
    #[inline]
    pub(crate) unsafe fn load(ptr: *const u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            let src = ptr as *const __m128i;
            // SAFETY: the caller guarantees 16 readable bytes at `ptr`, and
            // `_mm_loadu_si128` has no alignment requirement.
            Self(unsafe { _mm_loadu_si128(src) })
        }
        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: the caller guarantees 16 readable bytes at `ptr`;
            // `vld1q_u8` reads exactly sixteen `u8` lanes.
            Self(unsafe { vld1q_u8(ptr) })
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let mut bytes = [0u8; 16];
            // SAFETY: the caller guarantees 16 readable bytes at `ptr`; the
            // destination is a fresh local array of that same length, so the
            // two regions cannot overlap.
            unsafe { core::ptr::copy_nonoverlapping(ptr, bytes.as_mut_ptr(), bytes.len()) };
            Self(bytes)
        }
    }

    /// Returns a mask of the slots whose byte equals `b`.
    ///
    /// The mask uses the target-specific bit layout: bit `i` for slot `i` on
    /// x86_64 and in the portable fallback, bit `4 * i` for slot `i` on
    /// aarch64.
    #[inline]
    pub(crate) fn match_byte(self, b: u8) -> BitMask {
        #[cfg(target_arch = "x86_64")]
        // SAFETY: these are SSE2 intrinsics operating on register values only;
        // SSE2 is a baseline feature of x86_64 targets.
        unsafe {
            // Lanes equal to `b` become 0xFF, all others 0x00.
            let lanes = _mm_cmpeq_epi8(self.0, _mm_set1_epi8(b as i8));
            // One bit per lane, taken from each lane's top bit. Going through
            // `u32` keeps the widening to `u64` a zero-extension.
            let bits = _mm_movemask_epi8(lanes);
            BitMask(bits as u32 as u64)
        }
        #[cfg(target_arch = "aarch64")]
        // SAFETY: these are NEON intrinsics operating on register values only;
        // NEON is assumed to be available on the aarch64 targets this builds for.
        unsafe {
            // Lanes equal to `b` become 0xFF, all others 0x00.
            let hits = vceqq_u8(self.0, vdupq_n_u8(b));
            // Shift each 16-bit pair right by four and narrow to 8 bits: the
            // 128-bit comparison result collapses to 64 bits, 4 bits per lane.
            let nibbles = vshrn_n_u16(vreinterpretq_u16_u8(hits), 4);
            let packed = vget_lane_u64(vreinterpret_u64_u8(nibbles), 0);
            // Keep only the lowest bit of every nibble so each match is one bit.
            BitMask(packed & 0x1111_1111_1111_1111u64)
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let bits = self
                .0
                .iter()
                .enumerate()
                .filter(|&(_, &lane)| lane == b)
                .fold(0u64, |acc, (slot, _)| acc | (1u64 << slot));
            BitMask(bits)
        }
    }
}
impl BitMask {
#[inline]
pub(crate) fn iter(self) -> BitMaskIter {
BitMaskIter(self.0)
}
#[inline]
pub(crate) fn is_empty(self) -> bool {
self.0 == 0
}
#[inline]
pub(crate) fn lowest_set(self) -> Option<usize> {
if self.0 == 0 {
None
} else {
Some(BIT_TO_SLOT(self.0.trailing_zeros() as usize))
}
}
}
/// Iterator over the slot indices recorded in a `BitMask`.
///
/// Holds the bits that have not been visited yet. Every set bit yields exactly
/// one index, lowest bit first.
pub(crate) struct BitMaskIter(u64);

impl Iterator for BitMaskIter {
    type Item = usize;

    /// Pops the lowest remaining bit and returns the slot it stands for.
    #[inline]
    fn next(&mut self) -> Option<usize> {
        if self.0 == 0 {
            return None;
        }
        let bit = self.0.trailing_zeros() as usize;
        // `x & (x - 1)` clears the lowest set bit; `x` is non-zero here, so the
        // subtraction cannot underflow.
        self.0 &= self.0 - 1;
        Some(BIT_TO_SLOT(bit))
    }

    /// Reports the exact number of remaining items.
    ///
    /// `next` consumes one set bit per item, so the population count of the
    /// remaining bits is the remaining length. This lets `collect` reserve the
    /// right capacity up front.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.0.count_ones() as usize;
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so the iterator can advertise its length.
impl ExactSizeIterator for BitMaskIter {}

/// Converts a raw bit position inside a mask into a slot index.
///
/// On aarch64 the mask carries four bits per slot, so the position is divided
/// by four. Every other target uses one bit per slot and the position already
/// is the slot index.
// The upper-case name is kept because existing callers use it.
#[allow(non_snake_case)]
#[inline]
fn BIT_TO_SLOT(bit: usize) -> usize {
    if cfg!(target_arch = "aarch64") {
        bit >> 2
    } else {
        bit
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar reference: every index of `bytes` whose value equals `needle`.
    fn scalar_positions(bytes: &[u8], needle: u8) -> Vec<usize> {
        let mut found = Vec::new();
        for (idx, &byte) in bytes.iter().enumerate() {
            if byte == needle {
                found.push(idx);
            }
        }
        found
    }

    /// Loads the first 16 bytes of `bytes` into a `Group`.
    fn group_from(bytes: &[u8]) -> Group {
        assert!(bytes.len() >= 16, "a group load reads 16 bytes");
        // SAFETY: the assertion above guarantees 16 readable bytes at the pointer.
        unsafe { Group::load(bytes.as_ptr()) }
    }

    #[test]
    fn match_finds_all_positions() {
        let buf: [u8; 16] = [
            0xAB, 0x00, 0xAB, 0x01, 0xAB, 0xAB, 0x02, 0xAB,
            0x03, 0x04, 0xAB, 0x05, 0xAB, 0xAB, 0x06, 0xAB,
        ];
        let hits: Vec<usize> = group_from(&buf).match_byte(0xAB).iter().collect();
        assert_eq!(hits, scalar_positions(&buf, 0xAB));
    }

    #[test]
    fn match_no_hits_is_empty() {
        let zeros = [0u8; 16];
        let mask = group_from(&zeros).match_byte(0x42);
        assert!(mask.is_empty());
        assert_eq!(mask.iter().count(), 0);
        assert_eq!(mask.lowest_set(), None);
    }

    #[test]
    fn match_all_hits() {
        let ones = [0xFFu8; 16];
        let group = group_from(&ones);
        let every_slot: Vec<usize> = (0..16).collect();
        let hits: Vec<usize> = group.match_byte(0xFF).iter().collect();
        assert_eq!(hits, every_slot);
        assert_eq!(group.match_byte(0xFF).lowest_set(), Some(0));
    }

    #[test]
    fn unaligned_load_works() {
        let buf: [u8; 17] = [
            0xDE, 0xAB, 0x00, 0xAB, 0x01, 0xAB, 0xAB, 0x02,
            0xAB, 0x03, 0x04, 0xAB, 0x05, 0xAB, 0xAB, 0x06, 0xAB,
        ];
        // Skip the first byte so the 16-byte window starts at an odd offset.
        let window = &buf[1..];
        let hits: Vec<usize> = group_from(window).match_byte(0xAB).iter().collect();
        assert_eq!(hits, scalar_positions(window, 0xAB));
    }

    #[test]
    fn lowest_set_matches_first_iter() {
        let mut buf = [0u8; 16];
        buf[4] = 0xAA;
        buf[8] = 0xAA;
        let mask = group_from(&buf).match_byte(0xAA);
        assert_eq!(mask.lowest_set(), Some(4));
        assert_eq!(mask.iter().next(), Some(4));
    }
}