/// Finds `byte` among the first `count` slots of a 16-entry key array.
///
/// Returns the index of the first slot in `keys[..count]` equal to `byte`, or `None` when
/// there is no such slot. Slots at or beyond `count` are never reported, and a `count`
/// greater than 16 behaves like 16.
///
/// Dispatches to the SSE2 kernel on x86_64, the NEON kernel on aarch64, and a scalar scan
/// on every other target.
#[inline]
pub fn node16_find_byte(keys: &[u8; 16], count: u8, byte: u8) -> Option<u8> {
    #[cfg(target_arch = "x86_64")]
    {
        let base = keys.as_ptr();
        // SAFETY: `keys` borrows a `[u8; 16]`, so `base` is valid for the 16-byte read the
        // kernel performs.
        unsafe { x86::find_byte_in_16(base, count, byte) }
    }
    #[cfg(target_arch = "aarch64")]
    {
        let base = keys.as_ptr();
        // SAFETY: `keys` borrows a `[u8; 16]`, so `base` is valid for the 16-byte read the
        // kernel performs.
        unsafe { arm::find_byte_in_16(base, count, byte) }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        node16_find_byte_scalar(keys, count, byte)
    }
}
/// Portable linear-scan version of `node16_find_byte`.
///
/// Returns the index of the first of the leading `count` keys equal to `byte`; a `count`
/// above 16 is effectively clamped because the array only has 16 entries. Also compiled
/// under `cfg(test)` so the SIMD kernels can be cross-checked against it.
#[cfg(any(test, not(any(target_arch = "x86_64", target_arch = "aarch64"))))]
#[inline]
pub(crate) fn node16_find_byte_scalar(keys: &[u8; 16], count: u8, byte: u8) -> Option<u8> {
    keys.iter()
        .take(usize::from(count))
        .position(|&key| key == byte)
        // The position is always < 16, so it fits in a `u8`.
        .map(|slot| slot as u8)
}
/// Returns the length of the longest common prefix of `a` and `b`.
///
/// The result is at most `a.len().min(b.len())`. On x86_64 and aarch64, full 16-byte
/// chunks are compared with SIMD. The remaining bytes, and the whole input on other
/// targets, are compared one at a time.
#[inline]
pub fn longest_common_prefix(a: &[u8], b: &[u8]) -> usize {
    let limit = a.len().min(b.len());
    // Invariant: `matched <= limit` and `a[..matched] == b[..matched]`.
    let mut matched = 0;

    #[cfg(target_arch = "x86_64")]
    while limit - matched >= 16 {
        // SAFETY: `matched + 16 <= limit <= min(a.len(), b.len())`, so both pointers are
        // valid for a 16-byte read.
        let eq_bits =
            unsafe { x86::cmp_16_bytes_bitmask(a[matched..].as_ptr(), b[matched..].as_ptr()) };
        // Bit k is set when byte k matched; the first clear bit marks the first mismatch.
        if eq_bits != 0xFFFF {
            return matched + eq_bits.trailing_ones() as usize;
        }
        matched += 16;
    }

    #[cfg(target_arch = "aarch64")]
    while limit - matched >= 16 {
        // SAFETY: `matched + 16 <= limit <= min(a.len(), b.len())`, so both pointers are
        // valid for a 16-byte read.
        let eq_nibbles =
            unsafe { arm::cmp_16_bytes_nibble(a[matched..].as_ptr(), b[matched..].as_ptr()) };
        // Each byte owns one nibble; a mismatch at byte k clears bits 4k..4k+4.
        if eq_nibbles != u64::MAX {
            return matched + (eq_nibbles.trailing_ones() / 4) as usize;
        }
        matched += 16;
    }

    matched += a[matched..limit]
        .iter()
        .zip(&b[matched..limit])
        .take_while(|(x, y)| x == y)
        .count();
    matched
}
/// SSE2 kernels backing the x86_64 fast paths.
#[cfg(target_arch = "x86_64")]
mod x86 {
    use std::arch::x86_64::{
        __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
    };

    /// Compares 16 bytes at `a` with 16 bytes at `b`.
    ///
    /// Returns a mask in which bit `k` is set iff `a[k] == b[k]`. Only the low 16 bits can
    /// be set, so an all-equal chunk yields `0xFFFF`.
    ///
    /// # Safety
    ///
    /// `a` and `b` must each be valid for reading 16 bytes. No alignment is required.
    #[inline]
    pub(super) unsafe fn cmp_16_bytes_bitmask(a: *const u8, b: *const u8) -> u32 {
        // SAFETY: the caller guarantees 16 readable bytes behind each pointer, and the
        // unaligned `loadu` form accepts any address.
        let (lhs, rhs) = unsafe {
            (
                _mm_loadu_si128(a.cast::<__m128i>()),
                _mm_loadu_si128(b.cast::<__m128i>()),
            )
        };
        _mm_movemask_epi8(_mm_cmpeq_epi8(lhs, rhs)) as u32
    }

    /// Returns the index of the first of the leading `count` bytes at `keys` that equals
    /// `byte`, or `None`. A `count` above 16 behaves like 16.
    ///
    /// # Safety
    ///
    /// `keys` must be valid for reading 16 bytes. All 16 are loaded even when `count < 16`.
    #[inline]
    pub(super) unsafe fn find_byte_in_16(keys: *const u8, count: u8, byte: u8) -> Option<u8> {
        // SAFETY: the caller guarantees 16 readable bytes at `keys`; `loadu` has no alignment
        // requirement.
        let lanes = unsafe { _mm_loadu_si128(keys.cast::<__m128i>()) };
        let hits = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, _mm_set1_epi8(byte as i8))) as u32;
        // Shifting 0xFFFF right by `16 - count` leaves exactly `count` low ones: 0 for
        // `count == 0`, and 0xFFFF for `count >= 16`.
        let live = 0xFFFF_u32 >> (16 - count.min(16));
        let found = hits & live;
        (found != 0).then_some(found.trailing_zeros() as u8)
    }
}
/// NEON kernels backing the aarch64 fast paths.
#[cfg(target_arch = "aarch64")]
mod arm {
    use std::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };

    /// Packs a 16-lane byte mask whose lanes are all-zero or all-ones into a `u64`.
    ///
    /// In the result, nibble `k` (bits `4k..4k + 4`) mirrors lane `k`.
    #[inline]
    unsafe fn byte_mask_to_nibble_u64(cmp: uint8x16_t) -> u64 {
        // The vector is viewed as eight u16 lanes, and each lane is shifted right by 4 before
        // narrowing to u8. That keeps the top nibble of the even byte and the bottom nibble
        // of the odd byte, one nibble per original byte.
        let packed = vshrn_n_u16::<4>(vreinterpretq_u16_u8(cmp));
        vget_lane_u64::<0>(vreinterpret_u64_u8(packed))
    }

    /// Compares 16 bytes at `a` with 16 bytes at `b`.
    ///
    /// Nibble `k` of the result is `0xF` when `a[k] == b[k]` and `0` otherwise, so an
    /// all-equal chunk yields `u64::MAX`.
    ///
    /// # Safety
    ///
    /// `a` and `b` must each be valid for reading 16 bytes.
    #[inline]
    pub(super) unsafe fn cmp_16_bytes_nibble(a: *const u8, b: *const u8) -> u64 {
        // SAFETY: the caller guarantees 16 readable bytes behind each pointer.
        let (lhs, rhs) = unsafe { (vld1q_u8(a), vld1q_u8(b)) };
        // SAFETY: the helper only operates on its vector argument; no memory is touched.
        unsafe { byte_mask_to_nibble_u64(vceqq_u8(lhs, rhs)) }
    }

    /// Returns the index of the first of the leading `count` bytes at `keys` that equals
    /// `byte`, or `None`. A `count` above 16 behaves like 16.
    ///
    /// # Safety
    ///
    /// `keys` must be valid for reading 16 bytes. All 16 are loaded even when `count < 16`.
    #[inline]
    pub(super) unsafe fn find_byte_in_16(keys: *const u8, count: u8, byte: u8) -> Option<u8> {
        // SAFETY: the caller guarantees 16 readable bytes at `keys`.
        let lanes = unsafe { vld1q_u8(keys) };
        // SAFETY: the helper only operates on its vector argument; no memory is touched.
        let nibbles = unsafe { byte_mask_to_nibble_u64(vceqq_u8(lanes, vdupq_n_u8(byte))) };
        // Four mask bits per slot. When `count == 0` the shift amount is 64, for which
        // `checked_shr` returns `None`, so we fall back to an empty mask.
        let live_bits = u32::from(count.min(16)) * 4;
        let live = u64::MAX.checked_shr(64 - live_bits).unwrap_or(0);
        let found = nibbles & live;
        (found != 0).then_some((found.trailing_zeros() / 4) as u8)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances a 64-bit LCG and returns its high bits; deterministic so failures reproduce.
    fn lcg(state: &mut u64) -> u64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        *state >> 33
    }

    /// Byte-at-a-time reference used to cross-check `longest_common_prefix`.
    fn lcp_reference(a: &[u8], b: &[u8]) -> usize {
        a.iter().zip(b).take_while(|(x, y)| x == y).count()
    }

    /// `n` copies of `fill` followed by `tail`, so long inputs need no hand-counted literals.
    fn run_then(fill: u8, n: usize, tail: &[u8]) -> Vec<u8> {
        let mut v = vec![fill; n];
        v.extend_from_slice(tail);
        v
    }

    #[test]
    fn find_byte_at_index_zero() {
        let mut keys = [0u8; 16];
        keys[0] = 0x42;
        assert_eq!(node16_find_byte(&keys, 1, 0x42), Some(0));
    }

    #[test]
    fn find_byte_at_last_valid_index() {
        let mut keys = [0u8; 16];
        keys[15] = 0xAB;
        assert_eq!(node16_find_byte(&keys, 16, 0xAB), Some(15));
    }

    #[test]
    fn find_byte_middle() {
        let mut keys = [0u8; 16];
        for (i, slot) in keys.iter_mut().enumerate().take(10) {
            *slot = b'a' + i as u8;
        }
        assert_eq!(node16_find_byte(&keys, 10, b'f'), Some(5));
    }

    #[test]
    fn find_byte_absent_returns_none() {
        let mut keys = [0u8; 16];
        for (i, slot) in keys.iter_mut().enumerate().take(8) {
            *slot = b'a' + i as u8;
        }
        assert_eq!(node16_find_byte(&keys, 8, b'z'), None);
    }

    #[test]
    fn find_byte_count_zero_returns_none() {
        let keys = [0xAB; 16];
        assert_eq!(node16_find_byte(&keys, 0, 0xAB), None);
    }

    #[test]
    fn find_byte_ignores_unused_tail() {
        let mut keys = [0u8; 16];
        keys[10] = 0x77;
        assert_eq!(node16_find_byte(&keys, 4, 0x77), None);
    }

    #[test]
    fn find_byte_first_of_duplicates() {
        let mut keys = [0u8; 16];
        keys[3] = 0x55;
        keys[7] = 0x55;
        assert_eq!(node16_find_byte(&keys, 16, 0x55), Some(3));
    }

    #[test]
    fn find_byte_count_above_16_is_clamped() {
        // Every implementation treats `count > 16` as a full node.
        let mut keys = [0u8; 16];
        keys[15] = 0xAB;
        assert_eq!(node16_find_byte(&keys, 200, 0xAB), Some(15));
        assert_eq!(node16_find_byte(&keys, u8::MAX, 0x01), None);
    }

    #[test]
    fn find_byte_matches_scalar_random() {
        use std::collections::HashSet;

        let mut state: u64 = 0xDEAD_BEEF_CAFE_BABE;
        for _ in 0..1000 {
            // 0..=16 keys, covering both partial and full nodes.
            let count = lcg(&mut state) as u8 % 17;
            let mut keys = [0u8; 16];
            // Draw distinct keys for the live slots.
            let mut used = HashSet::new();
            for k in keys.iter_mut().take(count as usize) {
                loop {
                    let b = lcg(&mut state) as u8;
                    if used.insert(b) {
                        *k = b;
                        break;
                    }
                }
            }
            let query = lcg(&mut state) as u8;
            let got = node16_find_byte(&keys, count, query);
            let expected = node16_find_byte_scalar(&keys, count, query);
            assert_eq!(
                got, expected,
                "mismatch on keys={keys:?} count={count} q={query}"
            );
        }
    }

    #[test]
    fn lcp_empty_inputs() {
        assert_eq!(longest_common_prefix(b"", b""), 0);
        assert_eq!(longest_common_prefix(b"abc", b""), 0);
        assert_eq!(longest_common_prefix(b"", b"abc"), 0);
    }

    #[test]
    fn lcp_identical() {
        assert_eq!(longest_common_prefix(b"hello", b"hello"), 5);
    }

    #[test]
    fn lcp_strict_prefix() {
        assert_eq!(longest_common_prefix(b"abc", b"abcdef"), 3);
        assert_eq!(longest_common_prefix(b"abcdef", b"abc"), 3);
    }

    #[test]
    fn lcp_no_common() {
        assert_eq!(longest_common_prefix(b"abc", b"xyz"), 0);
    }

    #[test]
    fn lcp_divergence_at_boundary() {
        let a = b"0123456789ABCDEFhello";
        let b = b"0123456789ABCDEFworld";
        assert_eq!(longest_common_prefix(a, b), 16);
    }

    #[test]
    fn lcp_long_match_then_diverge_in_chunk() {
        let a = run_then(b'a', 35, b"01");
        let b = run_then(b'a', 35, b"99");
        assert_eq!(longest_common_prefix(&a, &b), 35);
    }

    #[test]
    fn lcp_match_then_diverge_at_byte_15() {
        let a = run_then(b'a', 15, b"Xrest");
        let b = run_then(b'a', 15, b"Yrest");
        assert_eq!(longest_common_prefix(&a, &b), 15);
    }

    #[test]
    fn lcp_match_then_diverge_at_byte_16() {
        let a = run_then(b'a', 16, b"Xrest");
        let b = run_then(b'a', 16, b"Yrest");
        assert_eq!(longest_common_prefix(&a, &b), 16);
    }

    #[test]
    fn lcp_matches_reference_random() {
        let mut state: u64 = 0x0123_4567_89AB_CDEF;
        for _ in 0..2000 {
            let len = (lcg(&mut state) % 70) as usize;
            let a: Vec<u8> = (0..len).map(|_| lcg(&mut state) as u8).collect();
            let mut b = a.clone();
            // Flip one byte so the divergence point lands anywhere, including inside and
            // at the edges of 16-byte chunks. XOR with 1..=255 always changes the byte.
            if len > 0 {
                let pos = lcg(&mut state) as usize % len;
                b[pos] ^= 1 + (lcg(&mut state) as u8 % 255);
            }
            // Truncate `b` half the time so the length limit is exercised too.
            if lcg(&mut state) % 2 == 0 {
                let cut = lcg(&mut state) as usize % (len + 1);
                b.truncate(cut);
            }
            let expected = lcp_reference(&a, &b);
            assert_eq!(longest_common_prefix(&a, &b), expected, "a={a:?} b={b:?}");
            assert_eq!(longest_common_prefix(&b, &a), expected, "a={a:?} b={b:?}");
        }
    }
}