// Heuristic byte-frequency rank table, indexed by byte value: LOWER values mark
// rarer bytes, 255 marks the most common ones (e.g. space at index 32). Used by
// `select_rare_pair` / `case_insensitive_rank` to pick the least-frequent
// needle bytes as SIMD probe positions, minimizing false candidate hits.
// NOTE(review): this looks like the frequency-rank table used by substring
// searchers such as the memchr crate — confirm provenance before editing values.
const BYTE_FREQUENCIES: [u8; 256] = [
55, 52, 51, 50, 49, 48, 47, 46, 45, 103, 242, 66, 67, 229, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 56, 32, 31, 30, 29, 28, 255, 148, 164, 149, 136, 160, 155, 173, 221, 222, 134, 122, 232, 202, 215, 224, 208, 220, 204, 187, 183, 179, 177, 168, 178, 200, 226, 195, 154, 184, 174, 126, 120, 191, 157, 194, 170, 189, 162, 161, 150, 193, 142, 137, 171, 176, 185,
167, 186, 112, 175, 192, 188, 156, 140, 143, 123, 133, 128, 147, 138, 146, 114,
223, 151, 249, 216, 238, 236, 253, 227, 218, 230, 247, 135, 180, 241, 233, 246,
244, 231, 139, 245, 243, 251, 235, 201, 196, 240, 214, 152, 182, 205, 181, 127,
27, 212, 211, 210, 213, 228, 197, 169, 159, 131, 172, 105, 80, 98, 96, 97, 81, 207, 145, 116, 115, 144, 130, 153, 121, 107, 132, 109, 110, 124, 111, 82, 108, 118, 141, 113, 129, 119, 125, 165, 117, 92, 106, 83, 72, 99, 93, 65, 79, 166, 237, 163, 199, 190, 225, 209, 203, 198, 217, 219, 206, 234, 248, 158, 239, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ];
/// Lowercases an ASCII uppercase byte; every other byte (lowercase letters,
/// digits, punctuation, non-ASCII) is returned unchanged.
#[inline]
fn ascii_fold_byte(b: u8) -> u8 {
    match b {
        // Setting bit 5 maps 'A'..='Z' onto 'a'..='z'; for these bytes
        // `+ 0x20` and `| 0x20` are identical since bit 5 is clear.
        b'A'..=b'Z' => b + 0x20,
        _ => b,
    }
}
/// Toggles bit 5 of the byte, which flips an ASCII letter between its upper-
/// and lowercase forms. Callers only pass ASCII letters; for any other byte
/// the result is simply `b` with bit 5 inverted.
#[inline]
fn ascii_swap_case(b: u8) -> u8 {
    b ^ 0b0010_0000
}
/// Frequency rank of a (pre-lowercased) needle byte for rare-byte selection.
/// For a lowercase letter the rank is the rank of the *more frequent* of its
/// two case forms, since the search must accept either case at that position.
#[inline]
fn case_insensitive_rank(lower: u8) -> u8 {
    let rank = BYTE_FREQUENCIES[usize::from(lower)];
    if !lower.is_ascii_lowercase() {
        return rank;
    }
    rank.max(BYTE_FREQUENCIES[usize::from(ascii_swap_case(lower))])
}
fn select_rare_pair(needle_lower: &[u8]) -> (usize, usize) {
debug_assert!(needle_lower.len() >= 2);
let mut best1 = (u8::MAX, 0usize); let mut best2 = (u8::MAX, 1usize);
for (i, &b) in needle_lower.iter().enumerate() {
let r = case_insensitive_rank(b);
if r < best1.0 {
best2 = best1;
best1 = (r, i);
} else if r < best2.0 && i != best1.1 {
best2 = (r, i);
}
}
let i1 = best1.1.min(best2.1);
let i2 = best1.1.max(best2.1);
(i1, i2)
}
/// Scalar case-insensitive check: the `needle_lower.len()` bytes at `h`,
/// after ASCII-lowercasing, must equal the pre-lowercased needle. The caller
/// must guarantee `h` points at least `needle_lower.len()` readable bytes.
#[inline]
fn verify_scalar(h: *const u8, needle_lower: &[u8]) -> bool {
    needle_lower.iter().enumerate().all(|(i, &expected)| {
        // SAFETY: the caller promises `needle_lower.len()` readable bytes at
        // `h`, and `i` is always below that length.
        let mut got = unsafe { *h.add(i) };
        if got.is_ascii_uppercase() {
            // Same folding as `ascii_fold_byte`: set bit 5 to lowercase.
            got |= 0x20;
        }
        got == expected
    })
}
/// AVX2 case-insensitive verifier: checks that the `needle_lower.len()` bytes
/// at `h`, after ASCII-lowercasing, equal `needle_lower`, 32 bytes at a time.
///
/// # Safety
/// `h` must be valid for reads of at least `needle_lower.len()` bytes, and
/// the CPU must support AVX2 (enforced by the caller's feature detection).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn verify_avx2(h: *const u8, needle_lower: &[u8]) -> bool {
    use core::arch::x86_64::*;
    let len = needle_lower.len();
    let mut i = 0usize;
    // AVX2 has no unsigned byte compare; XOR-ing both operands with 0x80 maps
    // unsigned order onto signed order so `_mm256_cmpgt_epi8` can range-check.
    let flip = _mm256_set1_epi8(0x80u8 as i8);
    // Pre-biased exclusive bounds for the 'A'..='Z' range test.
    let a_minus_1 = _mm256_set1_epi8((b'A' - 1) as i8 ^ 0x80u8 as i8);
    let z_plus_1 = _mm256_set1_epi8((b'Z' + 1) as i8 ^ 0x80u8 as i8);
    let bit20 = _mm256_set1_epi8(0x20u8 as i8);
    // Main loop: compare 32 haystack bytes against 32 needle bytes at a time.
    while i + 32 <= len {
        let hv = unsafe { _mm256_loadu_si256(h.add(i) as *const __m256i) };
        let nv = unsafe { _mm256_loadu_si256(needle_lower.as_ptr().add(i) as *const __m256i) };
        let x = _mm256_xor_si256(hv, flip);
        // Lanes holding an uppercase letter: 'A'-1 < x AND x < 'Z'+1.
        let ge_a = _mm256_cmpgt_epi8(x, a_minus_1);
        let le_z = _mm256_cmpgt_epi8(z_plus_1, x);
        let upper = _mm256_and_si256(ge_a, le_z);
        // Set bit 5 only in the uppercase lanes: ASCII lowercasing.
        let folded = _mm256_or_si256(hv, _mm256_and_si256(upper, bit20));
        let eq = _mm256_cmpeq_epi8(folded, nv);
        // movemask == -1 (all 32 bits set) means every lane matched.
        if _mm256_movemask_epi8(eq) != -1i32 {
            return false;
        }
        i += 32;
    }
    // Scalar tail for the final `len % 32` bytes.
    while i < len {
        if ascii_fold_byte(unsafe { *h.add(i) }) != needle_lower[i] {
            return false;
        }
        i += 1;
    }
    true
}
/// Emulates SSE `movemask` on NEON: condenses a 16-lane byte vector into a
/// 16-bit mask, bit i reflecting lane i. Assumes each lane is a NEON compare
/// result (0x00 or 0xFF), as produced by the callers' `vceqq`/`vandq` chains.
///
/// # Safety
/// Requires NEON support on the running CPU.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn neon_movemask(v: core::arch::aarch64::uint8x16_t) -> u16 {
    use core::arch::aarch64::*;
    // Per-lane weights 1,2,4,...,128, repeated for the low and high halves:
    // AND-ing keeps exactly one bit per all-ones lane.
    static BITS: [u8; 16] = [1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128];
    let bit_mask = unsafe { vld1q_u8(BITS.as_ptr()) };
    let masked = vandq_u8(v, bit_mask);
    // Horizontal add of each 8-lane half yields one 8-bit mask per half.
    let lo = vaddv_u8(vget_low_u8(masked));
    let hi = vaddv_u8(vget_high_u8(masked));
    (lo as u16) | ((hi as u16) << 8)
}
/// NEON + dotprod case-insensitive verifier, 16 bytes per iteration.
///
/// After lowercasing the haystack chunk it XORs with the needle chunk; `udot`
/// then accumulates the dot product of the XOR result with itself (sum of
/// squared byte differences) into four u32 lanes, so any non-zero lane means
/// at least one byte differed.
///
/// NOTE(review): `udot` is emitted via inline asm — presumably because the
/// `vdotq_*` intrinsics were unavailable/unstable when this was written;
/// confirm before replacing the asm with an intrinsic.
///
/// # Safety
/// `h` must be valid for reads of at least `needle_lower.len()` bytes, and
/// the CPU must support NEON and the dotprod extension.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon,dotprod")]
unsafe fn verify_neon_dotprod(h: *const u8, needle_lower: &[u8]) -> bool {
    use core::arch::aarch64::*;
    let len = needle_lower.len();
    let mut i = 0usize;
    let a_val = vdupq_n_u8(b'A');
    let z_val = vdupq_n_u8(b'Z');
    let bit20 = vdupq_n_u8(0x20);
    while i + 16 <= len {
        let hv = unsafe { vld1q_u8(h.add(i)) };
        let nv = unsafe { vld1q_u8(needle_lower.as_ptr().add(i)) };
        // Lanes holding 'A'..='Z' get bit 5 set: ASCII lowercasing.
        let upper = vandq_u8(vcgeq_u8(hv, a_val), vcleq_u8(hv, z_val));
        let folded = vorrq_u8(hv, vandq_u8(upper, bit20));
        // Zero vector iff the folded chunk equals the needle chunk.
        let xored = veorq_u8(folded, nv);
        let dots: uint32x4_t;
        let zero = vdupq_n_u32(0);
        unsafe {
            core::arch::asm!(
                "udot {d:v}.4s, {a:v}.16b, {b:v}.16b",
                d = inlateout(vreg) zero => dots,
                a = in(vreg) xored,
                b = in(vreg) xored,
            );
        }
        // All four accumulator lanes zero <=> all 16 bytes matched.
        if vmaxvq_u32(dots) != 0 {
            return false;
        }
        i += 16;
    }
    // Scalar tail for the final `len % 16` bytes.
    while i < len {
        if ascii_fold_byte(unsafe { *h.add(i) }) != needle_lower[i] {
            return false;
        }
        i += 1;
    }
    true
}
/// NEON packed-pair search: scans for positions where the two rare needle
/// bytes (needle offsets `i1`, `i2`) both match in either case, then fully
/// verifies each candidate.
///
/// # Safety
/// Requires NEON. The caller must ensure `needle_lower.len() >= 2`,
/// `needle_lower.len() <= haystack.len()`, and
/// `haystack.len() >= i1.max(i2) + 16` (guaranteed by `search_packed_pair`),
/// otherwise `hlen - n` underflows or the vector loads read out of bounds.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn search_packed_pair_neon(
    haystack: &[u8],
    needle_lower: &[u8],
    i1: usize,
    i2: usize,
) -> bool {
    use core::arch::aarch64::*;
    let n = needle_lower.len();
    let hlen = haystack.len();
    let ptr = haystack.as_ptr();
    // Last haystack index where a full-needle match can still start.
    let last_start = hlen - n;
    let b1 = needle_lower[i1];
    // Uppercase twin of the probe byte; non-letters have no twin, so the two
    // broadcast vectors are identical and the `vorrq` below is a no-op.
    let b1_alt = if b1.is_ascii_lowercase() {
        ascii_swap_case(b1)
    } else {
        b1
    };
    let b2 = needle_lower[i2];
    let b2_alt = if b2.is_ascii_lowercase() {
        ascii_swap_case(b2)
    } else {
        b2
    };
    let v1_lo = vdupq_n_u8(b1);
    let v1_hi = vdupq_n_u8(b1_alt);
    let v2_lo = vdupq_n_u8(b2);
    let v2_hi = vdupq_n_u8(b2_alt);
    let max_idx = i1.max(i2);
    // Largest start offset for which both 16-byte loads stay in bounds.
    let max_offset = hlen.saturating_sub(max_idx + 16);
    let mut offset = 0usize;
    // Each iteration screens 16 candidate start positions at once.
    while offset <= max_offset {
        let chunk1 = unsafe { vld1q_u8(ptr.add(offset + i1)) };
        let chunk2 = unsafe { vld1q_u8(ptr.add(offset + i2)) };
        // Lane k set iff haystack byte at start offset+k matches probe 1 (2)
        // in either case.
        let eq1 = vorrq_u8(vceqq_u8(chunk1, v1_lo), vceqq_u8(chunk1, v1_hi));
        let eq2 = vorrq_u8(vceqq_u8(chunk2, v2_lo), vceqq_u8(chunk2, v2_hi));
        // Bit k set => both probes matched at candidate start offset+k.
        let mut mask = unsafe { neon_movemask(vandq_u8(eq1, eq2)) };
        while mask != 0 {
            let bit = mask.trailing_zeros() as usize;
            let candidate = offset + bit;
            // Bits are consumed in ascending order, so once one candidate is
            // past `last_start` every remaining one is too.
            if candidate > last_start {
                return false;
            }
            if unsafe { verify_dispatch(ptr.add(candidate), needle_lower) } {
                return true;
            }
            // Clear the lowest set bit.
            mask &= mask - 1;
        }
        offset += 16;
    }
    // Fewer than 16 candidate windows remain: scan the leftover region with
    // memchr on the rarer of the two probe bytes.
    if offset <= last_start {
        let rare_pos =
            if case_insensitive_rank(needle_lower[i1]) <= case_insensitive_rank(needle_lower[i2]) {
                i1
            } else {
                i2
            };
        let rare_byte = needle_lower[rare_pos];
        // `tail_space` covers haystack positions where the probe byte of a
        // still-possible match could sit: starts shifted by `rare_pos`.
        let tail_start = offset + rare_pos;
        let tail_end = last_start + rare_pos + 1;
        if tail_start < tail_end {
            let tail_space = &haystack[tail_start..tail_end];
            if rare_byte.is_ascii_lowercase() {
                for pos in memchr::memchr2_iter(rare_byte, ascii_swap_case(rare_byte), tail_space) {
                    // `pos` is relative to tail_space (= haystack index
                    // offset + rare_pos + pos), so the match START is
                    // offset + pos.
                    let candidate = offset + pos;
                    if unsafe { verify_dispatch(ptr.add(candidate), needle_lower) } {
                        return true;
                    }
                }
            } else {
                for pos in memchr::memchr_iter(rare_byte, tail_space) {
                    let candidate = offset + pos;
                    if unsafe { verify_dispatch(ptr.add(candidate), needle_lower) } {
                        return true;
                    }
                }
            }
        }
    }
    false
}
/// Chooses the fastest available verifier for one candidate position:
/// SIMD only when the needle fills at least one full vector (32 bytes for
/// AVX2, 16 for NEON+dotprod), otherwise the scalar loop.
///
/// NOTE(review): the runtime feature check runs once per candidate; std is
/// believed to cache detection results, but hoisting the check to the caller
/// may still help — profile before changing.
///
/// # Safety
/// `h` must be valid for reads of at least `needle_lower.len()` bytes.
#[inline]
unsafe fn verify_dispatch(h: *const u8, needle_lower: &[u8]) -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        if needle_lower.len() >= 32 && std::is_x86_feature_detected!("avx2") {
            return unsafe { verify_avx2(h, needle_lower) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if needle_lower.len() >= 16 && std::arch::is_aarch64_feature_detected!("dotprod") {
            return unsafe { verify_neon_dotprod(h, needle_lower) };
        }
    }
    verify_scalar(h, needle_lower)
}
/// AVX2 packed-pair search: scans for positions where the two rare needle
/// bytes (needle offsets `i1`, `i2`) both match in either case, then fully
/// verifies each candidate. 32 candidate start positions per iteration.
///
/// # Safety
/// Requires AVX2. The caller must ensure `needle_lower.len() >= 2`,
/// `needle_lower.len() <= haystack.len()`, and
/// `haystack.len() >= i1.max(i2) + 32` (guaranteed by `search_packed_pair`),
/// otherwise `hlen - n` underflows or the vector loads read out of bounds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn search_packed_pair_avx2(
    haystack: &[u8],
    needle_lower: &[u8],
    i1: usize,
    i2: usize,
) -> bool {
    use core::arch::x86_64::*;
    let n = needle_lower.len();
    let hlen = haystack.len();
    let ptr = haystack.as_ptr();
    // Last haystack index where a full-needle match can still start.
    let last_start = hlen - n;
    let b1 = needle_lower[i1];
    // Uppercase twin of the probe byte; non-letters have no twin, so both
    // broadcast vectors are identical and the OR below is a no-op.
    let b1_alt = if b1.is_ascii_lowercase() {
        ascii_swap_case(b1)
    } else {
        b1
    };
    let b2 = needle_lower[i2];
    let b2_alt = if b2.is_ascii_lowercase() {
        ascii_swap_case(b2)
    } else {
        b2
    };
    let v1_lo = _mm256_set1_epi8(b1 as i8);
    let v1_hi = _mm256_set1_epi8(b1_alt as i8);
    let v2_lo = _mm256_set1_epi8(b2 as i8);
    let v2_hi = _mm256_set1_epi8(b2_alt as i8);
    let max_idx = i1.max(i2);
    // Largest start offset for which both 32-byte loads stay in bounds.
    let max_offset = hlen.saturating_sub(max_idx + 32);
    let mut offset = 0usize;
    while offset <= max_offset {
        let chunk1 = unsafe { _mm256_loadu_si256(ptr.add(offset + i1) as *const __m256i) };
        let chunk2 = unsafe { _mm256_loadu_si256(ptr.add(offset + i2) as *const __m256i) };
        // Lane k set iff haystack byte at start offset+k matches probe 1 (2)
        // in either case.
        let eq1 = _mm256_or_si256(
            _mm256_cmpeq_epi8(chunk1, v1_lo),
            _mm256_cmpeq_epi8(chunk1, v1_hi),
        );
        let eq2 = _mm256_or_si256(
            _mm256_cmpeq_epi8(chunk2, v2_lo),
            _mm256_cmpeq_epi8(chunk2, v2_hi),
        );
        // Bit k set => both probes matched at candidate start offset+k.
        let mut mask = _mm256_movemask_epi8(_mm256_and_si256(eq1, eq2)) as u32;
        while mask != 0 {
            let bit = mask.trailing_zeros() as usize;
            let candidate = offset + bit;
            // Bits are consumed in ascending order, so once one candidate is
            // past `last_start` every remaining one is too.
            if candidate > last_start {
                return false;
            }
            if unsafe { verify_dispatch(ptr.add(candidate), needle_lower) } {
                return true;
            }
            // Clear the lowest set bit.
            mask &= mask - 1;
        }
        offset += 32;
    }
    // Fewer than 32 candidate windows remain: scan the leftover region with
    // memchr on the rarer of the two probe bytes.
    if offset <= last_start {
        let rare_pos =
            if case_insensitive_rank(needle_lower[i1]) <= case_insensitive_rank(needle_lower[i2]) {
                i1
            } else {
                i2
            };
        let rare_byte = needle_lower[rare_pos];
        // `tail_space` covers haystack positions where the probe byte of a
        // still-possible match could sit: starts shifted by `rare_pos`.
        let tail_start = offset + rare_pos;
        let tail_end = last_start + rare_pos + 1;
        if tail_start < tail_end {
            let tail_space = &haystack[tail_start..tail_end];
            if rare_byte.is_ascii_lowercase() {
                for pos in memchr::memchr2_iter(rare_byte, ascii_swap_case(rare_byte), tail_space) {
                    // `pos` is relative to tail_space, so the match START in
                    // the haystack is offset + pos.
                    let candidate = offset + pos;
                    if unsafe { verify_dispatch(ptr.add(candidate), needle_lower) } {
                        return true;
                    }
                }
            } else {
                for pos in memchr::memchr_iter(rare_byte, tail_space) {
                    let candidate = offset + pos;
                    if unsafe { verify_dispatch(ptr.add(candidate), needle_lower) } {
                        return true;
                    }
                }
            }
        }
    }
    false
}
/// Case-insensitive containment test using the packed rare-pair strategy when
/// the platform and input sizes allow it, falling back to `search` otherwise.
/// `needle_lower` must already be ASCII-lowercased.
pub fn search_packed_pair(haystack: &[u8], needle_lower: &[u8]) -> bool {
    match needle_lower.len() {
        // An empty needle matches everything.
        0 => return true,
        // A single byte cannot form a pair; use the memchr-based search.
        1 => return search(haystack, needle_lower),
        len if len > haystack.len() => return false,
        _ => {}
    }
    let (i1, i2) = select_rare_pair(needle_lower);
    #[cfg(target_arch = "x86_64")]
    {
        // The AVX2 kernel needs one full 32-byte load at the larger offset.
        if std::is_x86_feature_detected!("avx2") && haystack.len() >= i1.max(i2) + 32 {
            return unsafe { search_packed_pair_avx2(haystack, needle_lower, i1, i2) };
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // Only worth it when the first byte is common (rank >= 200): a rare
        // first byte already filters well in the memchr-based fallback.
        if case_insensitive_rank(needle_lower[0]) >= 200
            && haystack.len() >= i1.max(i2) + 16
        {
            return unsafe { search_packed_pair_neon(haystack, needle_lower, i1, i2) };
        }
    }
    search(haystack, needle_lower)
}
/// Case-insensitive containment test: scans for the first needle byte in
/// either case with memchr, then fully verifies each candidate position.
/// `needle_lower` must already be ASCII-lowercased.
pub fn search(haystack: &[u8], needle_lower: &[u8]) -> bool {
    let Some(&first) = needle_lower.first() else {
        // An empty needle matches everything.
        return true;
    };
    let Some(last_start) = haystack.len().checked_sub(needle_lower.len()) else {
        // Needle longer than haystack: no match possible.
        return false;
    };
    // A match can only begin within this prefix of the haystack.
    let starts = &haystack[..=last_start];
    let base = haystack.as_ptr();
    // SAFETY: every `pos` comes from `starts`, so `pos <= last_start` and the
    // needle-length read performed by the verifier stays inside `haystack`.
    let verify = |pos: usize| unsafe { verify_dispatch(base.add(pos), needle_lower) };
    if first.is_ascii_lowercase() {
        memchr::memchr2_iter(first, ascii_swap_case(first), starts).any(verify)
    } else {
        memchr::memchr_iter(first, starts).any(verify)
    }
}
/// Like `search`, but verifies every candidate with the scalar routine,
/// bypassing the SIMD dispatch entirely.
/// `needle_lower` must already be ASCII-lowercased.
pub fn search_scalar(haystack: &[u8], needle_lower: &[u8]) -> bool {
    let Some(&first) = needle_lower.first() else {
        // An empty needle matches everything.
        return true;
    };
    let Some(last_start) = haystack.len().checked_sub(needle_lower.len()) else {
        // Needle longer than haystack: no match possible.
        return false;
    };
    // A match can only begin within this prefix of the haystack.
    let starts = &haystack[..=last_start];
    let base = haystack.as_ptr();
    // SAFETY: every `pos` comes from `starts`, so `pos <= last_start` and the
    // needle-length read performed by `verify_scalar` stays inside `haystack`.
    let verify = |pos: usize| unsafe { verify_scalar(base.add(pos), needle_lower) };
    if first.is_ascii_lowercase() {
        memchr::memchr2_iter(first, ascii_swap_case(first), starts).any(verify)
    } else {
        memchr::memchr_iter(first, starts).any(verify)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke test: simple hits and one miss through the public entry point.
    #[test]
    fn basic_case_insensitive() {
        assert!(search_packed_pair(b"Hello World", b"hello"));
        assert!(search_packed_pair(b"Hello World", b"world"));
        assert!(search_packed_pair(b"NOMORE bugs", b"nomore"));
        assert!(!search_packed_pair(b"Hello World", b"xyz"));
    }
    // Degenerate sizes: 2-byte needles, needle longer than the haystack,
    // empty needle (always a match), empty haystack.
    #[test]
    fn edge_cases() {
        assert!(search_packed_pair(b"ab", b"ab"));
        assert!(search_packed_pair(b"AB", b"ab"));
        assert!(!search_packed_pair(b"a", b"ab"));
        assert!(search_packed_pair(b"anything", b""));
        assert!(!search_packed_pair(b"", b"x"));
    }
    // The packed-pair fast path must agree with the memchr-based `search`
    // on every haystack/needle combination.
    #[test]
    fn packed_pair_matches_search() {
        let haystacks: &[&[u8]] = &[
            b"The quick brown fox jumps over the lazy dog",
            b"int mutex_lock(struct mutex *lock) { return 0; }",
            b"#define NOMORE_RETRIES 5\nif (nomore) return;",
            b"abcdefghijklmnopqrstuvwxyz",
            b"short",
        ];
        let needles: &[&[u8]] = &[b"fox", b"mutex", b"nomore", b"xyz", b"the", b"short", b"qr"];
        for h in haystacks {
            for n in needles {
                // Needles are lowered here because the API requires a
                // pre-lowercased needle.
                let lower: Vec<u8> = n.iter().map(|b| b.to_ascii_lowercase()).collect();
                assert_eq!(
                    search_packed_pair(h, &lower),
                    search(h, &lower),
                    "mismatch for haystack={:?} needle={:?}",
                    std::str::from_utf8(h),
                    std::str::from_utf8(n),
                );
            }
        }
    }
    // Exercises the vector verifiers (needles >= 16/32 bytes), matches at the
    // very end of the haystack, and near-matches that must fail.
    #[test]
    fn long_haystack_neon_path() {
        let haystack =
            b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaTHIS_IS_A_LONG_NEEDLE_TESTbbbbbbbbbbbbbbbbbb";
        assert!(search_packed_pair(haystack, b"this_is_a_long_needle_test"));
        assert!(!search_packed_pair(
            haystack,
            b"this_is_a_long_needle_testz"
        ));
        let long_needle = b"struct mutex *lock";
        let haystack2 = b"int STRUCT MUTEX *LOCK(struct mutex *lock) { return 0; }";
        assert!(search_packed_pair(haystack2, long_needle));
        let upper_hay = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        assert!(search_packed_pair(upper_hay, b"qrstuvwxyz0123456789a"));
        assert!(!search_packed_pair(upper_hay, b"qrstuvwxyz01234567899"));
        // Match flush against the end of the haystack.
        let end_hay = b"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxfind_me";
        assert!(search_packed_pair(end_hay, b"find_me"));
        assert!(search_packed_pair(end_hay, b"xx"));
        // Mixed-case fragment deep inside a large buffer; the longer needle
        // extends one byte past the fragment and must miss.
        let mut big = vec![b'z'; 1024];
        big[1000..1010].copy_from_slice(b"hElLo_WoRl");
        assert!(search_packed_pair(&big, b"hello_wo"));
        assert!(!search_packed_pair(&big, b"hello_world"));
    }
    // The selected pair must be (weakly) optimal: no unpicked position may
    // rank strictly below both picked positions.
    #[test]
    fn rare_pair_selection() {
        let (i1, i2) = select_rare_pair(b"nomore");
        let ranks: Vec<u8> = b"nomore"
            .iter()
            .map(|&b| case_insensitive_rank(b))
            .collect();
        let r1 = ranks[i1];
        let r2 = ranks[i2];
        for (i, &r) in ranks.iter().enumerate() {
            if i != i1 && i != i2 {
                assert!(r1 <= r || r2 <= r, "pair ({i1},{i2}) not optimal");
            }
        }
    }
}