#![cfg(all(target_arch = "aarch64", target_endian = "little"))]
#![allow(dead_code)]
use core::arch::aarch64::{
__crc32d, uint8x16_t, vceqq_u8, vgetq_lane_u64, vld1q_u8, vreinterpretq_u64_u8,
};
use super::scalar;
/// Short tag naming this kernel implementation (`"neon"`).
// NOTE(review): presumably reported by kernel-selection or diagnostic code — confirm against callers.
pub(crate) const KERNEL_TAG: &str = "neon";
/// Mixes a 64-bit value into a 64-bit hash with help from the CRC32 instruction.
///
/// The zero-seeded CRC-32 of `value` (via `__crc32d`) is moved into the upper 32 bits,
/// XORed with `value` rotated left by 13 bits, and the result is multiplied (wrapping)
/// by `scalar::HASH_MIX_PRIME`.
///
/// # Safety
///
/// The executing CPU must support the `crc` target feature.
#[target_feature(enable = "crc")]
#[inline]
pub(crate) unsafe fn hash_mix_u64(value: u64) -> u64 {
    let checksum = u64::from(__crc32d(0, value));
    let rotated = value.rotate_left(13);
    let combined = (checksum << 32) ^ rotated;
    combined.wrapping_mul(scalar::HASH_MIX_PRIME)
}
#[target_feature(enable = "neon")]
#[inline]
pub(crate) unsafe fn prefix_len_simd(lhs: *const u8, rhs: *const u8, max: usize) -> usize {
let mut off = 0usize;
while off + 16 <= max {
let a: uint8x16_t = unsafe { vld1q_u8(lhs.add(off)) };
let b: uint8x16_t = unsafe { vld1q_u8(rhs.add(off)) };
let eq = vceqq_u8(a, b);
let lanes = vreinterpretq_u64_u8(eq);
let low = vgetq_lane_u64(lanes, 0);
if low != u64::MAX {
let diff = low ^ u64::MAX;
return off + scalar::mismatch_byte_index(diff as usize);
}
let high = vgetq_lane_u64(lanes, 1);
if high != u64::MAX {
let diff = high ^ u64::MAX;
return off + 8 + scalar::mismatch_byte_index(diff as usize);
}
off += 16;
}
off
}
/// Returns the common prefix length of the buffers at `lhs` and `rhs`, limited to `max` bytes.
///
/// `prefix_len_simd` runs first; the offset it returns is then passed to
/// `scalar::common_prefix_len_scalar_ptr` together with `max`.
// NOTE(review): the scalar helper presumably resumes the byte comparison at that offset
// and stops at `max` — confirm against the `scalar` module.
///
/// # Safety
///
/// * `lhs` and `rhs` must each be valid for reads of `max` bytes.
/// * The executing CPU must support the `neon` target feature.
#[target_feature(enable = "neon")]
#[inline]
pub(crate) unsafe fn common_prefix_len_ptr(lhs: *const u8, rhs: *const u8, max: usize) -> usize {
    // SAFETY: the caller upholds the pointer-validity and target-feature requirements,
    // which are forwarded unchanged to both helpers.
    unsafe {
        let simd_matched = prefix_len_simd(lhs, rhs, max);
        scalar::common_prefix_len_scalar_ptr(lhs, rhs, simd_matched, max)
    }
}
/// Extends a match between two positions of `concat` beyond an already-known seed.
///
/// `seed_len` is clamped to `tail_limit`. If the clamped seed already reaches
/// `tail_limit` it is returned as is; otherwise the bytes following the seed at
/// `candidate_idx` and `current_idx` are compared via `common_prefix_len_ptr`, and the
/// seed plus the additional matching bytes is returned. The result never exceeds
/// `tail_limit` as long as `common_prefix_len_ptr` returns at most its `max` argument.
///
/// # Safety
///
/// * `candidate_idx + tail_limit` and `current_idx + tail_limit` must not exceed
///   `concat.len()`; no bounds checks are performed here.
/// * The executing CPU must support the `neon` target feature.
#[target_feature(enable = "neon")]
#[inline]
pub(crate) unsafe fn count_match_from_indices(
    concat: &[u8],
    current_idx: usize,
    candidate_idx: usize,
    tail_limit: usize,
    seed_len: usize,
) -> usize {
    let already_matched = seed_len.min(tail_limit);
    // Cannot underflow: `already_matched` was clamped to `tail_limit` above.
    let left_to_check = tail_limit - already_matched;
    if left_to_check == 0 {
        return already_matched;
    }
    let start = concat.as_ptr();
    // SAFETY: the caller guarantees both index ranges lie inside `concat`, so the
    // offset pointers and the `left_to_check` bytes after them are readable.
    let extended = unsafe {
        let candidate_ptr = start.add(candidate_idx + already_matched);
        let current_ptr = start.add(current_idx + already_matched);
        common_prefix_len_ptr(candidate_ptr, current_ptr, left_to_check)
    };
    already_matched + extended
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    /// A 40-byte input with a single difference at byte 25 must give the same answer
    /// from the NEON path and the scalar path, and that answer must be 25.
    #[test]
    fn neon_prefix_len_matches_scalar_on_long_run() {
        let original = b"abcdefghijklmnopqrstuvwxyz0123456789-+=*";
        let mut altered: Vec<u8> = original.to_vec();
        altered[25] = b'!';
        let limit = original.len();
        let (lhs, rhs) = (original.as_ptr(), altered.as_ptr());
        // SAFETY: both buffers are `limit` bytes long.
        let via_neon = unsafe { common_prefix_len_ptr(lhs, rhs, limit) };
        // SAFETY: both buffers are `limit` bytes long.
        let via_scalar = unsafe { scalar::common_prefix_len_ptr(lhs, rhs, limit) };
        assert_eq!(via_neon, via_scalar);
        assert_eq!(via_neon, 25);
    }

    /// Inputs shorter than one 16-byte block never enter the vector loop and must
    /// still report the full length when they are equal.
    #[test]
    fn neon_handles_short_input() {
        let lhs = b"abc";
        let rhs = b"abc";
        let limit = lhs.len();
        // SAFETY: both buffers are `limit` bytes long.
        let matched = unsafe { common_prefix_len_ptr(lhs.as_ptr(), rhs.as_ptr(), limit) };
        assert_eq!(matched, 3);
    }
}