use core::arch::aarch64::*;
use crate::Needles;
/// Number of bytes held by one `uint8x16_t`, i.e. the amount consumed by a single `vld1q_u8` load.
const NEON_CHUNK_SIZE: usize = 16;
/// Packs a 16-lane byte vector into a `u64` that carries four bits per lane.
///
/// Every 16-bit pair of lanes is shifted right by four and narrowed to eight bits, so nibble `i`
/// of the result is taken from lane `i` of `cmp`. When each lane of `cmp` is either `0x00` or
/// `0xFF` (the shape produced by the NEON comparison intrinsics), nibble `i` is `0x0` or `0xF`
/// respectively.
#[doc(hidden)]
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn nibble_mask(cmp: uint8x16_t) -> u64 {
    // SAFETY: register-only NEON operations; no memory is read or written.
    unsafe {
        let halfwords = vreinterpretq_u16_u8(cmp);
        let packed = vshrn_n_u16::<4>(halfwords);
        vget_lane_u64::<0>(vreinterpret_u64_u8(packed))
    }
}
/// Builds a lane mask selecting the bytes of `chunk` that lie in the inclusive range `lo..=hi`.
///
/// Matching lanes are `0xFF`, all other lanes are `0x00`. The test is performed as a single
/// unsigned comparison: each byte is rebased by subtracting `lo` (wrapping) and then compared
/// against the width `hi - lo` (also wrapping). A width of `0xFF` therefore selects every lane.
#[doc(hidden)]
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn range_mask(chunk: uint8x16_t, lo: u8, hi: u8) -> uint8x16_t {
    let span = hi.wrapping_sub(lo);
    // SAFETY: register-only NEON operations; no memory is read or written.
    unsafe {
        let rebased = vsubq_u8(chunk, vdupq_n_u8(lo));
        // `rebased <= span` is the same predicate as `rebased < span + 1`, but it cannot
        // overflow, so the full-width case needs no special handling.
        vcleq_u8(rebased, vdupq_n_u8(span))
    }
}
/// Lane mask for ASCII binary digits: lanes holding `'0'` or `'1'` are selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn binary_mask(chunk: uint8x16_t) -> uint8x16_t {
    const FIRST: u8 = b'0';
    const LAST: u8 = b'1';
    range_mask(chunk, FIRST, LAST)
}
/// Lane mask for ASCII octal digits: lanes holding `'0'..='7'` are selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn octal_digit_mask(chunk: uint8x16_t) -> uint8x16_t {
    const FIRST: u8 = b'0';
    const LAST: u8 = b'7';
    range_mask(chunk, FIRST, LAST)
}
/// Lane mask for ASCII decimal digits: lanes holding `'0'..='9'` are selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn digit_mask(chunk: uint8x16_t) -> uint8x16_t {
    const FIRST: u8 = b'0';
    const LAST: u8 = b'9';
    range_mask(chunk, FIRST, LAST)
}
/// Lane mask for ASCII hexadecimal digits: `'0'..='9'`, `'a'..='f'` and `'A'..='F'`.
#[cfg_attr(not(tarpaulin), inline(always))]
fn hex_digit_mask(chunk: uint8x16_t) -> uint8x16_t {
    // Setting bit 0x20 maps 'A'..='F' onto 'a'..='f', so a single range test covers both cases.
    // SAFETY: register-only NEON operations.
    let case_folded = unsafe { vorrq_u8(chunk, vdupq_n_u8(0x20)) };
    let is_hex_letter = range_mask(case_folded, b'a', b'f');
    let is_decimal = digit_mask(chunk);
    // SAFETY: register-only NEON operation.
    unsafe { vorrq_u8(is_decimal, is_hex_letter) }
}
/// Lane mask for the four whitespace bytes handled here: space, tab, line feed and carriage
/// return. Matching lanes are `0xFF`, all others `0x00`.
#[cfg_attr(not(tarpaulin), inline(always))]
fn whitespace_mask(chunk: uint8x16_t) -> uint8x16_t {
    const WHITESPACE: [u8; 4] = [b' ', b'\t', b'\n', b'\r'];
    // SAFETY: register-only NEON operation.
    let nothing = unsafe { vdupq_n_u8(0) };
    WHITESPACE.iter().fold(nothing, |selected, &byte| {
        // SAFETY: register-only NEON operations.
        unsafe { vorrq_u8(selected, vceqq_u8(chunk, vdupq_n_u8(byte))) }
    })
}
/// Lane mask for ASCII letters: `'a'..='z'` and `'A'..='Z'`.
#[cfg_attr(not(tarpaulin), inline(always))]
fn alpha_mask(chunk: uint8x16_t) -> uint8x16_t {
    // Setting bit 0x20 maps upper-case letters onto lower-case ones, so one range test suffices.
    // SAFETY: register-only NEON operations.
    let case_folded = unsafe {
        let case_bit = vdupq_n_u8(0x20);
        vorrq_u8(chunk, case_bit)
    };
    range_mask(case_folded, b'a', b'z')
}
/// Lane mask for ASCII letters and decimal digits (the union of `alpha_mask` and `digit_mask`).
#[cfg_attr(not(tarpaulin), inline(always))]
fn alphanumeric_mask(chunk: uint8x16_t) -> uint8x16_t {
    let (letters, digits) = (alpha_mask(chunk), digit_mask(chunk));
    // SAFETY: register-only NEON operation.
    unsafe { vorrq_u8(letters, digits) }
}
/// Lane mask for bytes that are an ASCII letter or an underscore.
#[cfg_attr(not(tarpaulin), inline(always))]
fn ident_start_mask(chunk: uint8x16_t) -> uint8x16_t {
    let letters = alpha_mask(chunk);
    // SAFETY: register-only NEON operations.
    unsafe {
        let is_underscore = vceqq_u8(chunk, vdupq_n_u8(b'_'));
        vorrq_u8(letters, is_underscore)
    }
}
/// Lane mask for bytes that are an ASCII letter, a decimal digit or an underscore.
#[cfg_attr(not(tarpaulin), inline(always))]
fn ident_mask(chunk: uint8x16_t) -> uint8x16_t {
    let letters_and_digits = alphanumeric_mask(chunk);
    // SAFETY: register-only NEON operations.
    unsafe {
        let is_underscore = vceqq_u8(chunk, vdupq_n_u8(b'_'));
        vorrq_u8(letters_and_digits, is_underscore)
    }
}
/// Lane mask for ASCII lower-case letters: lanes holding `'a'..='z'` are selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn lower_mask(chunk: uint8x16_t) -> uint8x16_t {
    let (first, last) = (b'a', b'z');
    range_mask(chunk, first, last)
}
/// Lane mask for ASCII upper-case letters: lanes holding `'A'..='Z'` are selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn upper_mask(chunk: uint8x16_t) -> uint8x16_t {
    let (first, last) = (b'A', b'Z');
    range_mask(chunk, first, last)
}
/// Lane mask for 7-bit ASCII bytes: lanes holding `0x00..=0x7F` are selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn ascii_mask(chunk: uint8x16_t) -> uint8x16_t {
    const FIRST: u8 = 0x00;
    const LAST: u8 = 0x7F;
    range_mask(chunk, FIRST, LAST)
}
/// Lane mask for bytes outside 7-bit ASCII: lanes holding `0x80..=0xFF` are selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn non_ascii_mask(chunk: uint8x16_t) -> uint8x16_t {
    const FIRST: u8 = 0x80;
    const LAST: u8 = 0xFF;
    range_mask(chunk, FIRST, LAST)
}
/// Lane mask for ASCII graphic bytes: lanes holding `0x21..=0x7E` (`'!'` through `'~'`) are
/// selected.
#[cfg_attr(not(tarpaulin), inline(always))]
fn ascii_graphic_mask(chunk: uint8x16_t) -> uint8x16_t {
    const FIRST: u8 = 0x21;
    const LAST: u8 = 0x7E;
    range_mask(chunk, FIRST, LAST)
}
/// Lane mask for ASCII control bytes: `0x00..=0x1F` plus `0x7F` (DEL).
#[cfg_attr(not(tarpaulin), inline(always))]
fn ascii_control_mask(chunk: uint8x16_t) -> uint8x16_t {
    const DEL: u8 = 0x7F;
    let low_controls = range_mask(chunk, 0x00, 0x1F);
    // SAFETY: register-only NEON operations.
    unsafe {
        let is_del = vceqq_u8(chunk, vdupq_n_u8(DEL));
        vorrq_u8(low_controls, is_del)
    }
}
/// Generates `pub(super) fn $name(input: &[u8]) -> usize`, which returns the index of the first
/// byte of `input` whose lane is not selected by `$mask`, or `input.len()` when every byte is
/// selected.
///
/// * `$name` – name of the generated function.
/// * `$prefix_len` – function in the parent module, called as `super::$prefix_len(&[u8])`. It
///   handles inputs shorter than one chunk and the first chunk of longer inputs; its result is
///   returned as-is unless it equals the chunk size.
/// * `$mask` – function in this module mapping a `uint8x16_t` to a lane mask (`0xFF` for bytes
///   that belong to the class).
macro_rules! skip_ascii_class {
    ($name:ident, $prefix_len:ident, $mask:ident) => {
        #[cfg_attr(not(tarpaulin), inline(always))]
        #[cfg(target_feature = "neon")]
        pub(super) fn $name(input: &[u8]) -> usize {
            const PAIR: usize = 2 * NEON_CHUNK_SIZE;
            // Index of the lowest lane flagged in a nibble mask (four bits per lane).
            let first_lane = |bits: u64| (bits.trailing_zeros() / 4) as usize;

            let total = input.len();
            if total < NEON_CHUNK_SIZE {
                return super::$prefix_len(input);
            }

            // The first chunk goes through the parent module's function as well.
            let head = super::$prefix_len(&input[..NEON_CHUNK_SIZE]);
            if head != NEON_CHUNK_SIZE {
                return head;
            }

            let base = input.as_ptr();
            let mut offset = NEON_CHUNK_SIZE;

            // Two chunks per iteration: a single combined test decides whether either chunk
            // contains a byte outside the class.
            while offset + PAIR <= total {
                // SAFETY: `offset + PAIR <= total`, so both 16-byte loads are in bounds.
                let first = unsafe { vld1q_u8(base.add(offset)) };
                let second = unsafe { vld1q_u8(base.add(offset + NEON_CHUNK_SIZE)) };
                let in_first = $mask(first);
                let in_second = $mask(second);
                // SAFETY: register-only NEON operation.
                let in_both = unsafe { vandq_u8(in_first, in_second) };
                if nibble_mask(in_both) != u64::MAX {
                    let miss_first = !nibble_mask(in_first);
                    return if miss_first != 0 {
                        offset + first_lane(miss_first)
                    } else {
                        offset + NEON_CHUNK_SIZE + first_lane(!nibble_mask(in_second))
                    };
                }
                offset += PAIR;
            }

            // At most one whole chunk is left for this loop.
            while offset + NEON_CHUNK_SIZE <= total {
                // SAFETY: `offset + NEON_CHUNK_SIZE <= total`, so the load is in bounds.
                let chunk = unsafe { vld1q_u8(base.add(offset)) };
                let misses = !nibble_mask($mask(chunk));
                if misses != 0 {
                    return offset + first_lane(misses);
                }
                offset += NEON_CHUNK_SIZE;
            }

            if offset == total {
                return total;
            }

            // Fewer than 16 bytes remain: reload the last 16 bytes of the input and ignore the
            // lanes that were already examined.
            let tail_start = total - NEON_CHUNK_SIZE;
            // SAFETY: `total >= NEON_CHUNK_SIZE`, so the final 16 bytes are in bounds.
            let tail = unsafe { vld1q_u8(base.add(tail_start)) };
            let seen_lanes = offset - tail_start;
            let unseen = u64::MAX << (seen_lanes * 4);
            match !nibble_mask($mask(tail)) & unseen {
                0 => total,
                misses => tail_start + first_lane(misses),
            }
        }
    };
}
// One generated `skip_*` function per byte class. Each invocation names the function to create,
// the parent module's `prefix_len_*` function it delegates short inputs and the first chunk to,
// and the lane-mask function from this module that classifies a whole vector.
skip_ascii_class!(skip_binary, prefix_len_binary, binary_mask);
skip_ascii_class!(skip_digits, prefix_len_digits, digit_mask);
skip_ascii_class!(skip_hex_digits, prefix_len_hex_digits, hex_digit_mask);
skip_ascii_class!(skip_octal_digits, prefix_len_octal_digits, octal_digit_mask);
skip_ascii_class!(skip_whitespace, prefix_len_whitespace, whitespace_mask);
skip_ascii_class!(skip_alpha, prefix_len_alpha, alpha_mask);
skip_ascii_class!(
skip_alphanumeric,
prefix_len_alphanumeric,
alphanumeric_mask
);
skip_ascii_class!(skip_ident_start, prefix_len_ident_start, ident_start_mask);
skip_ascii_class!(skip_ident, prefix_len_ident, ident_mask);
skip_ascii_class!(skip_lower, prefix_len_lower, lower_mask);
skip_ascii_class!(skip_upper, prefix_len_upper, upper_mask);
skip_ascii_class!(skip_ascii, prefix_len_ascii, ascii_mask);
skip_ascii_class!(skip_non_ascii, prefix_len_non_ascii, non_ascii_mask);
skip_ascii_class!(
skip_ascii_graphic,
prefix_len_ascii_graphic,
ascii_graphic_mask
);
skip_ascii_class!(
skip_ascii_control,
prefix_len_ascii_control,
ascii_control_mask
);
/// Counts the bytes of `input` that are reported as matching by `needles`.
///
/// Inputs shorter than one NEON chunk are tested byte by byte through `Needles::tail_find` on
/// one-byte slices. Longer inputs are classified with `Needles::eq_any_mask_neon`, two chunks per
/// iteration, then one chunk, and finally a chunk covering the last 16 bytes in which lanes that
/// were already counted are masked out.
#[cfg_attr(not(tarpaulin), inline(always))]
#[cfg(target_feature = "neon")]
pub(super) fn count_matches<Nd>(input: &[u8], needles: Nd) -> usize
where
    Nd: Needles,
{
    const PAIR: usize = 2 * NEON_CHUNK_SIZE;

    let total = input.len();
    if total < NEON_CHUNK_SIZE {
        let mut hits = 0usize;
        for byte in input {
            if needles.tail_find(core::slice::from_ref(byte)).is_some() {
                hits += 1;
            }
        }
        return hits;
    }

    let base = input.as_ptr();
    let mut hits = 0usize;
    let mut offset = 0;

    while offset + PAIR <= total {
        // SAFETY: `offset + PAIR <= total`, so both 16-byte loads are in bounds.
        let first = unsafe { vld1q_u8(base.add(offset)) };
        let second = unsafe { vld1q_u8(base.add(offset + NEON_CHUNK_SIZE)) };
        let in_first = needles.eq_any_mask_neon(first);
        let in_second = needles.eq_any_mask_neon(second);
        // Each selected lane contributes four set bits to the nibble mask.
        hits += (nibble_mask(in_first).count_ones() / 4) as usize;
        hits += (nibble_mask(in_second).count_ones() / 4) as usize;
        offset += PAIR;
    }

    while offset + NEON_CHUNK_SIZE <= total {
        // SAFETY: `offset + NEON_CHUNK_SIZE <= total`, so the load is in bounds.
        let chunk = unsafe { vld1q_u8(base.add(offset)) };
        let selected = needles.eq_any_mask_neon(chunk);
        hits += (nibble_mask(selected).count_ones() / 4) as usize;
        offset += NEON_CHUNK_SIZE;
    }

    if offset < total {
        // Reload the last 16 bytes and drop the lanes that were already counted.
        let tail_start = total - NEON_CHUNK_SIZE;
        // SAFETY: `total >= NEON_CHUNK_SIZE`, so the final 16 bytes are in bounds.
        let tail = unsafe { vld1q_u8(base.add(tail_start)) };
        let selected = needles.eq_any_mask_neon(tail);
        let counted_lanes = offset - tail_start;
        let uncounted = u64::MAX << (counted_lanes * 4);
        hits += (nibble_mask(selected) & uncounted).count_ones() as usize / 4;
    }

    hits
}
/// Returns the index of the last byte of `input` that is reported as matching by `needles`, or
/// `None` when no byte matches.
///
/// The input is scanned from the end towards the start so the search stops at the first hit
/// instead of always classifying the whole buffer:
///
/// * inputs shorter than one NEON chunk are tested byte by byte, back to front, through
///   `Needles::tail_find` on one-byte slices;
/// * longer inputs are classified with `Needles::eq_any_mask_neon`, two chunks per iteration,
///   then one chunk, and finally a chunk covering the first 16 bytes in which the lanes that
///   were already examined are masked out.
#[cfg_attr(not(tarpaulin), inline(always))]
#[cfg(target_feature = "neon")]
pub(super) fn find_last<Nd>(input: &[u8], needles: Nd) -> Option<usize>
where
    Nd: Needles,
{
    const PAIR: usize = 2 * NEON_CHUNK_SIZE;
    // Index of the highest lane flagged in a non-zero nibble mask (four bits per lane).
    let last_lane = |bits: u64| (15 - bits.leading_zeros() / 4) as usize;

    let len = input.len();
    if len < NEON_CHUNK_SIZE {
        return input
            .iter()
            .rposition(|b| needles.tail_find(core::slice::from_ref(b)).is_some());
    }

    let ptr = input.as_ptr();
    // Everything at index `end` and above has been examined and contains no match.
    let mut end = len;

    while end >= PAIR {
        let lo_start = end - PAIR;
        let hi_start = end - NEON_CHUNK_SIZE;
        // SAFETY: `lo_start + PAIR == end <= len`, so both 16-byte loads are in bounds.
        let c_lo = unsafe { vld1q_u8(ptr.add(lo_start)) };
        let c_hi = unsafe { vld1q_u8(ptr.add(hi_start)) };
        let m_lo = needles.eq_any_mask_neon(c_lo);
        let m_hi = needles.eq_any_mask_neon(c_hi);
        // SAFETY: register-only NEON operation.
        let combined = nibble_mask(unsafe { vorrq_u8(m_lo, m_hi) });
        if combined != 0 {
            // Prefer the upper chunk: it holds the larger indices.
            let b_hi = nibble_mask(m_hi);
            if b_hi != 0 {
                return Some(hi_start + last_lane(b_hi));
            }
            return Some(lo_start + last_lane(nibble_mask(m_lo)));
        }
        end = lo_start;
    }

    while end >= NEON_CHUNK_SIZE {
        let start = end - NEON_CHUNK_SIZE;
        // SAFETY: `start + NEON_CHUNK_SIZE == end <= len`, so the load is in bounds.
        let chunk = unsafe { vld1q_u8(ptr.add(start)) };
        let bits = nibble_mask(needles.eq_any_mask_neon(chunk));
        if bits != 0 {
            return Some(start + last_lane(bits));
        }
        end = start;
    }

    if end > 0 {
        // Between 1 and 15 bytes remain at the front. Load the first 16 bytes of the input and
        // keep only the lanes below `end`; the rest were covered by the loops above.
        // SAFETY: `len >= NEON_CHUNK_SIZE`, so the first 16 bytes are in bounds.
        let chunk = unsafe { vld1q_u8(ptr) };
        let lane_mask = (!0u64) >> ((NEON_CHUNK_SIZE - end) * 4);
        let bits = nibble_mask(needles.eq_any_mask_neon(chunk)) & lane_mask;
        if bits != 0 {
            return Some(last_lane(bits));
        }
    }

    None
}
/// Returns the index of the first byte of `input` selected by `needles`, or `None` when no byte
/// is selected.
///
/// Inputs shorter than one NEON chunk, and the first chunk of longer inputs, are handed to
/// `Needles::tail_find`. The rest is classified with `Needles::eq_any_mask_neon`, two chunks per
/// iteration, then one chunk, and finally a chunk covering the last 16 bytes in which lanes that
/// were already examined are masked out.
#[cfg_attr(not(tarpaulin), inline(always))]
#[cfg(target_feature = "neon")]
pub(super) fn skip_until<Nd>(input: &[u8], needles: Nd) -> Option<usize>
where
    Nd: Needles,
{
    const PAIR: usize = 2 * NEON_CHUNK_SIZE;
    // Index of the lowest lane flagged in a nibble mask (four bits per lane).
    let first_lane = |bits: u64| (bits.trailing_zeros() / 4) as usize;

    let total = input.len();
    if total < NEON_CHUNK_SIZE {
        return needles.tail_find(input);
    }
    if let found @ Some(_) = needles.tail_find(&input[..NEON_CHUNK_SIZE]) {
        return found;
    }

    let base = input.as_ptr();
    let mut offset: usize = NEON_CHUNK_SIZE;

    // Two chunks per iteration: one combined test decides whether either chunk holds a hit.
    while offset + PAIR <= total {
        // SAFETY: `offset + PAIR <= total`, so both 16-byte loads are in bounds.
        let first = unsafe { vld1q_u8(base.add(offset)) };
        let second = unsafe { vld1q_u8(base.add(offset + NEON_CHUNK_SIZE)) };
        let in_first = needles.eq_any_mask_neon(first);
        let in_second = needles.eq_any_mask_neon(second);
        // SAFETY: register-only NEON operation.
        let in_either = unsafe { vorrq_u8(in_first, in_second) };
        if nibble_mask(in_either) != 0 {
            let hits_first = nibble_mask(in_first);
            return Some(if hits_first != 0 {
                offset + first_lane(hits_first)
            } else {
                offset + NEON_CHUNK_SIZE + first_lane(nibble_mask(in_second))
            });
        }
        offset += PAIR;
    }

    while offset + NEON_CHUNK_SIZE <= total {
        // SAFETY: `offset + NEON_CHUNK_SIZE <= total`, so the load is in bounds.
        let chunk = unsafe { vld1q_u8(base.add(offset)) };
        let hits = nibble_mask(needles.eq_any_mask_neon(chunk));
        if hits != 0 {
            return Some(offset + first_lane(hits));
        }
        offset += NEON_CHUNK_SIZE;
    }

    if offset == total {
        return None;
    }

    // Fewer than 16 bytes remain: reload the last 16 bytes and ignore the lanes already examined.
    let tail_start = total - NEON_CHUNK_SIZE;
    // SAFETY: `total >= NEON_CHUNK_SIZE`, so the final 16 bytes are in bounds.
    let tail = unsafe { vld1q_u8(base.add(tail_start)) };
    let selected = needles.eq_any_mask_neon(tail);
    let seen_lanes = offset - tail_start;
    let unseen = u64::MAX << (seen_lanes * 4);
    match nibble_mask(selected) & unseen {
        0 => None,
        hits => Some(tail_start + first_lane(hits)),
    }
}
/// Returns the index of the first byte of `input` that is not selected by `needles`, or
/// `input.len()` when every byte is selected.
///
/// Inputs shorter than one NEON chunk, and the first chunk of longer inputs, are handed to
/// `Needles::prefix_len`. The rest is classified with `Needles::eq_any_mask_neon`, two chunks per
/// iteration, then one chunk, and finally a chunk covering the last 16 bytes in which lanes that
/// were already examined are masked out.
#[cfg_attr(not(tarpaulin), inline(always))]
#[cfg(target_feature = "neon")]
pub(super) fn skip_while<Nd>(input: &[u8], needles: Nd) -> usize
where
    Nd: Needles,
{
    const PAIR: usize = 2 * NEON_CHUNK_SIZE;
    // Index of the lowest lane flagged in a nibble mask (four bits per lane).
    let first_lane = |bits: u64| (bits.trailing_zeros() / 4) as usize;

    let total = input.len();
    if total < NEON_CHUNK_SIZE {
        return needles.prefix_len(input);
    }
    let head = needles.prefix_len(&input[..NEON_CHUNK_SIZE]);
    if head != NEON_CHUNK_SIZE {
        return head;
    }

    let base = input.as_ptr();
    let mut offset: usize = NEON_CHUNK_SIZE;

    // Two chunks per iteration: one combined test decides whether either chunk holds a miss.
    while offset + PAIR <= total {
        // SAFETY: `offset + PAIR <= total`, so both 16-byte loads are in bounds.
        let first = unsafe { vld1q_u8(base.add(offset)) };
        let second = unsafe { vld1q_u8(base.add(offset + NEON_CHUNK_SIZE)) };
        let in_first = needles.eq_any_mask_neon(first);
        let in_second = needles.eq_any_mask_neon(second);
        // SAFETY: register-only NEON operation.
        let in_both = unsafe { vandq_u8(in_first, in_second) };
        if nibble_mask(in_both) != u64::MAX {
            let miss_first = !nibble_mask(in_first);
            return if miss_first != 0 {
                offset + first_lane(miss_first)
            } else {
                offset + NEON_CHUNK_SIZE + first_lane(!nibble_mask(in_second))
            };
        }
        offset += PAIR;
    }

    while offset + NEON_CHUNK_SIZE <= total {
        // SAFETY: `offset + NEON_CHUNK_SIZE <= total`, so the load is in bounds.
        let chunk = unsafe { vld1q_u8(base.add(offset)) };
        let misses = !nibble_mask(needles.eq_any_mask_neon(chunk));
        if misses != 0 {
            return offset + first_lane(misses);
        }
        offset += NEON_CHUNK_SIZE;
    }

    if offset == total {
        return total;
    }

    // Fewer than 16 bytes remain: reload the last 16 bytes and ignore the lanes already examined.
    let tail_start = total - NEON_CHUNK_SIZE;
    // SAFETY: `total >= NEON_CHUNK_SIZE`, so the final 16 bytes are in bounds.
    let tail = unsafe { vld1q_u8(base.add(tail_start)) };
    let selected = needles.eq_any_mask_neon(tail);
    let seen_lanes = offset - tail_start;
    let unseen = u64::MAX << (seen_lanes * 4);
    match !nibble_mask(selected) & unseen {
        0 => total,
        misses => tail_start + first_lane(misses),
    }
}
/// Checks of the short-input paths (inputs below one NEON chunk), which never reach the vector
/// loops.
#[cfg(test)]
#[cfg(target_feature = "neon")]
mod tests {
    use super::*;

    #[test]
    fn skip_binary_short_input_defensive() {
        let cases: [(&[u8], usize); 3] = [(&b""[..], 0), (&b"010"[..], 3), (&b"012"[..], 2)];
        for &(input, expected) in cases.iter() {
            assert_eq!(skip_binary(input), expected);
        }
    }

    #[test]
    fn skip_until_short_input_defensive() {
        let needles = [b'a', b'b'];
        // First byte is already a needle.
        assert_eq!(skip_until(b"aaa", needles), Some(0));
        // No needle anywhere in the input.
        assert_eq!(skip_until(b"zzz", needles), None);
    }

    #[test]
    fn skip_while_short_input_defensive() {
        let needles = [b'a', b'b'];
        // The run of needles stops at 'z'.
        assert_eq!(skip_while(b"aabz", needles), 3);
        // The very first byte is not a needle.
        assert_eq!(skip_while(b"zzz", needles), 0);
    }
}