#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
mod avx512;
mod build;
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
mod neon;
#[cfg(test)]
mod tests;
/// Number of 7-bit ASCII byte values; sizes the per-byte single-byte value table.
const ASCII_BYTES: usize = 128;
/// Number of verification buckets held by the matcher (one `BucketVerify` each).
const N_BUCKETS: usize = 8;
/// Number of mask columns; the keyed prefix groups cover lengths 2 through this value.
const MAX_SCAN_LEN: usize = 8;
/// Number of rows in each mask column (rows are selected by a 6-bit index).
const MASK_ROWS: usize = 64;
/// NOTE(review): not referenced in this file — presumably the minimum pattern
/// count at which callers select this matcher; confirm at the use sites.
pub const HARRY_MIN_PATTERN_COUNT: usize = 64;
/// A pattern literal that needs a byte comparison beyond its keyed prefix.
#[derive(Clone)]
struct BucketLiteral {
/// Complete literal bytes; verification compares `bytes[prefix_len..]` with the
/// haystack once the prefix key has matched.
bytes: Box<[u8]>,
/// Value handed to the match callback when the literal matches.
value: u32,
}
/// Everything stored under one prefix key.
#[derive(Clone, Default)]
struct PrefixGroup {
/// Values reported as soon as the prefix key matches (no further comparison).
exact_values: Vec<u32>,
/// Literals whose bytes after the prefix must also match the haystack.
long_literals: Vec<BucketLiteral>,
}
/// Map from packed prefix key to `PrefixGroup`, stored as two parallel slices.
#[derive(Clone, Default)]
struct PrefixMap {
/// Prefix keys; `from_unsorted` sorts them so `get` can binary-search.
keys: Box<[u64]>,
/// Group for the key at the same index in `keys`.
values: Box<[PrefixGroup]>,
}
impl BucketLiteral {
    /// Heap bytes owned by this literal: the length of its boxed byte slice.
    fn heap_bytes(&self) -> usize {
        let Self { bytes, .. } = self;
        bytes.len()
    }
}
impl PrefixGroup {
fn heap_bytes(&self) -> usize {
self.exact_values.capacity() * size_of::<u32>()
+ self.long_literals.capacity() * size_of::<BucketLiteral>()
+ self
.long_literals
.iter()
.map(|l| l.heap_bytes())
.sum::<usize>()
}
}
impl PrefixMap {
fn from_unsorted(iter: impl Iterator<Item = (u64, PrefixGroup)>) -> Self {
let mut pairs: Vec<(u64, PrefixGroup)> = iter.collect();
if pairs.is_empty() {
return Self::default();
}
pairs.sort_unstable_by_key(|(k, _)| *k);
let (keys, values): (Vec<_>, Vec<_>) = pairs.into_iter().unzip();
Self {
keys: keys.into_boxed_slice(),
values: values.into_boxed_slice(),
}
}
fn heap_bytes(&self) -> usize {
self.keys.len() * size_of::<u64>()
+ self.values.len() * size_of::<PrefixGroup>()
+ self.values.iter().map(|g| g.heap_bytes()).sum::<usize>()
}
#[inline(always)]
fn get(&self, key: u64) -> Option<&PrefixGroup> {
self.keys
.binary_search(&key)
.ok()
.map(|idx| &self.values[idx])
}
}
/// Verification tables for one bucket.
#[derive(Clone, Default)]
struct BucketVerify {
/// Bit `i` set means prefix length `i + 2` is probed, using `groups[i]`.
length_mask: u8,
/// One prefix map per prefix length, index `i` holding length `i + 2`.
groups: [PrefixMap; MAX_SCAN_LEN - 1],
}
impl BucketVerify {
    /// Heap bytes owned by this bucket: the sum over all of its prefix maps.
    fn heap_bytes(&self) -> usize {
        let mut total = 0usize;
        for map in &self.groups {
            total += map.heap_bytes();
        }
        total
    }
}
/// Multi-pattern matcher: a bitmap/table for one-byte patterns plus per-column
/// mask tables that select candidate buckets, which are then verified against
/// prefix maps.
#[derive(Clone)]
pub struct HarryMatcher {
/// For each ASCII byte, the values reported when that byte occurs in the haystack.
single_byte_values: Box<[Vec<u32>; ASCII_BYTES]>,
/// NOTE(review): only its length/emptiness is read in this file — presumably the
/// distinct one-byte pattern bytes consumed by the SIMD modules; confirm there.
single_byte_keys: Box<[u8]>,
/// 128-bit bitmap (two words); bit `b` is set when byte `b` is a one-byte match.
single_byte_match_mask: [u64; 2],
/// Gates every single-byte scan.
has_single_byte: bool,
/// Per-column table indexed by `byte & 0x3F`; set bits rule buckets out.
low_mask: Box<[[u8; MASK_ROWS]; MAX_SCAN_LEN]>,
/// Per-column table indexed by `(byte >> 1) & 0x3F`; set bits rule buckets out.
high_mask: Box<[[u8; MASK_ROWS]; MAX_SCAN_LEN]>,
/// Verification tables, one per bucket bit of the hit mask.
bucket_verify: [BucketVerify; N_BUCKETS],
/// When set, scans may skip haystack bytes `>= 0x80`.
all_patterns_ascii: bool,
/// Upper bound on the number of mask columns consulted per position.
max_prefix_len: usize,
}
impl HarryMatcher {
pub fn heap_bytes(&self) -> usize {
let sbv = ASCII_BYTES * size_of::<Vec<u32>>()
+ self
.single_byte_values
.iter()
.map(|v| v.capacity() * size_of::<u32>())
.sum::<usize>();
let sbk = self.single_byte_keys.len();
let masks = 2 * MAX_SCAN_LEN * MASK_ROWS; let buckets: usize = self.bucket_verify.iter().map(|b| b.heap_bytes()).sum();
sbv + sbk + masks + buckets
}
/// Returns `true` when at least one pattern occurs in `text`.
///
/// Empty text never matches.
#[inline(always)]
pub fn is_match(&self, text: &str) -> bool {
    !text.is_empty() && self.is_match_bytes(text.as_bytes())
}
/// Feeds the value of every match found in `text` to `on_value`.
///
/// One-byte patterns are scanned first (when present), then multi-byte ones.
/// Scanning stops as soon as `on_value` returns `true`; the return value is
/// `true` exactly when the callback stopped the scan. Empty text yields `false`.
#[inline(always)]
pub fn for_each_match_value(&self, text: &str, mut on_value: impl FnMut(u32) -> bool) -> bool {
    let haystack = text.as_bytes();
    if haystack.is_empty() {
        return false;
    }
    let stopped_early =
        self.has_single_byte && self.scan_single_byte_literals(haystack, &mut on_value);
    stopped_early || self.scan_multi_dispatch(haystack, &mut on_value)
}
/// Returns whether any pattern matches `haystack`.
///
/// With `simd_runtime_dispatch` on aarch64 or x86_64, boolean-only fast paths
/// are tried first; every other build — and every case those paths do not
/// return from — goes through `for_each_match_value` with a callback that stops
/// at the first reported value.
///
/// The fast paths index `haystack[0]`, so `haystack` must be non-empty.
#[inline(always)]
fn is_match_bytes(&self, haystack: &[u8]) -> bool {
#[cfg(any(
all(feature = "simd_runtime_dispatch", target_arch = "aarch64"),
all(feature = "simd_runtime_dispatch", target_arch = "x86_64")
))]
{
// A one-byte haystack can only match a one-byte pattern: answer from the bitmap.
if self.has_single_byte && haystack.len() == 1 {
return self.single_byte_contains(haystack[0]);
}
if self.all_patterns_ascii {
// ASCII first byte: separate single-byte pass, then a multi-byte pass
// that does not re-check single bytes.
if haystack[0] < 0x80 {
if self.has_single_byte && self.scan_single_byte_any_ascii_haystack(haystack) {
return true;
}
return self.scan_multi_dispatch_any_ascii_lead(haystack);
}
// Non-ASCII first byte: combined boolean scan.
return self.scan_multi_dispatch_any(haystack);
}
if haystack[0] >= 0x80 && self.has_single_byte {
return self.scan_multi_dispatch_any(haystack);
}
}
// SAFETY: the caller visible in this file (`is_match`) passes `str::as_bytes()`,
// which is valid UTF-8. NOTE(review): any other caller must uphold the same.
self.for_each_match_value(unsafe { std::str::from_utf8_unchecked(haystack) }, |_| true)
}
/// Reports the values of every one-byte pattern occurring in `haystack`.
///
/// When all patterns are ASCII and the haystack starts with a non-ASCII byte,
/// the work is delegated to `scan_single_byte_literals_ascii`. Returns `true`
/// as soon as `on_value` does, `false` once the haystack is exhausted.
#[inline(always)]
fn scan_single_byte_literals(
    &self,
    haystack: &[u8],
    on_value: &mut impl FnMut(u32) -> bool,
) -> bool {
    if self.all_patterns_ascii && haystack[0] >= 0x80 {
        return self.scan_single_byte_literals_ascii(haystack, on_value);
    }
    // Only ASCII bytes have an entry in the value table.
    haystack
        .iter()
        .filter(|&&byte| byte < 128)
        .flat_map(|&byte| self.single_byte_values[usize::from(byte)].iter())
        .any(|&value| on_value(value))
}
/// Tests `byte` against the 128-bit single-byte bitmap.
///
/// Bytes outside the ASCII range are never contained.
#[inline(always)]
fn single_byte_contains(&self, byte: u8) -> bool {
    if usize::from(byte) >= ASCII_BYTES {
        return false;
    }
    // Bit 6 selects the word, the low six bits select the bit inside it.
    let word = self.single_byte_match_mask[usize::from(byte >> 6)];
    (word & (1u64 << (byte & 0x3F))) != 0
}
/// Returns whether any haystack byte is flagged in the single-byte bitmap
/// (boolean only; no values are reported).
///
/// Debug-asserts `all_patterns_ascii`. Returns `false` immediately when there
/// are no single-byte keys.
#[inline(always)]
fn scan_single_byte_any_ascii_haystack(&self, haystack: &[u8]) -> bool {
debug_assert!(self.all_patterns_ascii);
if self.single_byte_keys.is_empty() {
return false;
}
// aarch64 + `simd_runtime_dispatch`: at most 4 keys and at least 16 haystack
// bytes go to the NEON routine.
// NOTE(review): that routine's safety contract is defined outside this file.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
if self.single_byte_keys.len() <= 4 && haystack.len() >= 16 {
return unsafe { self.scan_single_byte_any_ascii_haystack_neon(haystack) };
}
// x86_64 + `simd_runtime_dispatch`: at most 4 keys, at least 64 haystack bytes
// and runtime-detected AVX-512 VBMI go to the AVX-512 routine.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
if self.single_byte_keys.len() <= 4
&& haystack.len() >= 64
&& is_x86_feature_detected!("avx512vbmi")
{
return unsafe { self.scan_single_byte_any_ascii_haystack_avx512(haystack) };
}
// Scalar fallback: probe the bitmap byte by byte.
haystack
.iter()
.copied()
.any(|byte| self.single_byte_contains(byte))
}
/// Dispatches the value-reporting single-byte scan to the implementation for
/// the current build.
///
/// Exactly one of the three cfg-gated tails below is compiled in and becomes
/// the function's result: NEON on aarch64, AVX-512 VBMI (runtime-detected, with
/// a scalar fallback) on x86_64, and the scalar loop everywhere else.
#[inline(always)]
fn scan_single_byte_literals_ascii(
&self,
haystack: &[u8],
on_value: &mut impl FnMut(u32) -> bool,
) -> bool {
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
{
// NOTE(review): the NEON routine's safety contract is defined outside this file.
unsafe { self.scan_single_byte_literals_ascii_neon(haystack, on_value) }
}
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
{
// The AVX-512 routine is only entered after the runtime feature check.
if is_x86_feature_detected!("avx512vbmi") {
return unsafe { self.scan_single_byte_literals_ascii_avx512(haystack, on_value) };
}
self.scan_single_byte_literals_ascii_scalar(haystack, on_value)
}
#[cfg(not(any(
all(feature = "simd_runtime_dispatch", target_arch = "aarch64"),
all(feature = "simd_runtime_dispatch", target_arch = "x86_64")
)))]
self.scan_single_byte_literals_ascii_scalar(haystack, on_value)
}
/// Scalar single-byte scan: for each ASCII byte of `haystack`, reports the
/// values registered for it. Returns `true` as soon as `on_value` does.
///
/// Not compiled on aarch64 with `simd_runtime_dispatch`.
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
#[inline(always)]
fn scan_single_byte_literals_ascii_scalar(
    &self,
    haystack: &[u8],
    on_value: &mut impl FnMut(u32) -> bool,
) -> bool {
    let ascii_bytes = haystack.iter().copied().filter(|&byte| byte < 128);
    for byte in ascii_bytes {
        let values = &self.single_byte_values[usize::from(byte)];
        if values.iter().any(|&value| on_value(value)) {
            return true;
        }
    }
    false
}
/// Value-reporting scan for the keyed (two-byte-and-longer) patterns, dispatched
/// per build target.
///
/// Returns `false` for haystacks shorter than two bytes. When all patterns are
/// ASCII and the haystack starts with a non-ASCII byte, the `_ascii` variants
/// are used. Per target: aarch64 with `simd_runtime_dispatch` always uses NEON;
/// x86_64 with the feature uses AVX-512 VBMI when detected at runtime and the
/// scalar scan otherwise; all other builds use the scalar scan.
#[inline(always)]
fn scan_multi_dispatch(&self, haystack: &[u8], on_value: &mut impl FnMut(u32) -> bool) -> bool {
if haystack.len() < 2 {
return false;
}
if self.all_patterns_ascii && haystack[0] >= 0x80 {
// NOTE(review): the NEON routines' safety contract is defined outside this file.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
return unsafe { self.scan_neon_ascii(haystack, on_value) };
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx512vbmi") {
return unsafe { self.scan_avx512vbmi_ascii(haystack, on_value) };
}
// Scalar scan over every position, last byte included (inclusive end).
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
return self.scan_scalar_range_ascii(haystack, 0, haystack.len() - 1, on_value);
}
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
return unsafe { self.scan_neon(haystack, on_value) };
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx512vbmi") {
return unsafe { self.scan_avx512vbmi(haystack, on_value) };
}
#[cfg(not(all(feature = "simd_runtime_dispatch", target_arch = "aarch64")))]
return self.scan_scalar_range(haystack, 0, haystack.len() - 1, on_value);
}
/// Boolean-only scan covering both one-byte and keyed patterns, dispatched per
/// build target.
///
/// A one-byte haystack is answered from the single-byte bitmap; any other
/// haystack shorter than two bytes yields `false`. The scalar fallbacks check
/// the single-byte bitmap at every position before the keyed tables.
#[inline(always)]
fn scan_multi_dispatch_any(&self, haystack: &[u8]) -> bool {
if self.has_single_byte && haystack.len() == 1 {
return self.single_byte_contains(haystack[0]);
}
if haystack.len() < 2 {
return false;
}
if self.all_patterns_ascii && haystack[0] >= 0x80 {
// NOTE(review): the NEON routines' safety contract is defined outside this file.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
return unsafe { self.scan_neon_ascii_any(haystack) };
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx512vbmi") {
return unsafe { self.scan_avx512vbmi_ascii_any(haystack) };
}
// NOTE(review): this scalar ASCII return is compiled only for non-SIMD builds,
// so x86_64 without AVX-512 VBMI leaves this branch and uses the generic scalar
// scan at the end of the function (unlike `scan_multi_dispatch`, which returns
// the ASCII scalar scan here) — confirm that is intended.
#[cfg(not(any(
all(feature = "simd_runtime_dispatch", target_arch = "aarch64"),
all(feature = "simd_runtime_dispatch", target_arch = "x86_64")
)))]
return self.scan_scalar_range_any_ascii(haystack, 0, haystack.len() - 1);
}
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
return unsafe { self.scan_neon_any(haystack) };
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx512vbmi") {
return unsafe { self.scan_avx512vbmi_any(haystack) };
}
#[cfg(not(any(
all(feature = "simd_runtime_dispatch", target_arch = "aarch64"),
all(feature = "simd_runtime_dispatch", target_arch = "x86_64")
)))]
return self.scan_scalar_range_any(haystack, 0, haystack.len() - 1);
// Reached only on x86_64 SIMD builds when AVX-512 VBMI is not detected; in every
// other build an unconditional `return` above makes this tail unreachable.
#[allow(unreachable_code)]
self.scan_scalar_range_any(haystack, 0, haystack.len() - 1)
}
/// Boolean-only scan of the keyed patterns for a haystack whose first byte is
/// ASCII, dispatched per build target.
///
/// Debug-asserts that all patterns are ASCII and that `haystack` is non-empty
/// with an ASCII first byte. Returns `false` for haystacks shorter than two
/// bytes. The scalar fallback does not consult the single-byte bitmap; the
/// caller in this file (`is_match_bytes`) runs the single-byte scan first.
#[inline(always)]
fn scan_multi_dispatch_any_ascii_lead(&self, haystack: &[u8]) -> bool {
debug_assert!(self.all_patterns_ascii);
debug_assert!(!haystack.is_empty() && haystack[0] < 0x80);
if haystack.len() < 2 {
return false;
}
// NOTE(review): the NEON routine's safety contract is defined outside this file.
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "aarch64"))]
return unsafe { self.scan_neon_ascii_lead_any(haystack) };
#[cfg(all(feature = "simd_runtime_dispatch", target_arch = "x86_64"))]
if is_x86_feature_detected!("avx512vbmi") {
return unsafe { self.scan_avx512vbmi_ascii_lead_any(haystack) };
}
#[cfg(not(any(
all(feature = "simd_runtime_dispatch", target_arch = "aarch64"),
all(feature = "simd_runtime_dispatch", target_arch = "x86_64")
)))]
return self.scan_scalar_range_any_no_single_byte(haystack, 0, haystack.len() - 1);
// Reached only on x86_64 SIMD builds when AVX-512 VBMI is not detected.
#[allow(unreachable_code)]
self.scan_scalar_range_any_no_single_byte(haystack, 0, haystack.len() - 1)
}
/// Scalar value-reporting scan over positions `start..=end` (inclusive).
///
/// Unless all patterns are ASCII, positions holding a UTF-8 continuation byte
/// (`10xxxxxx`) are skipped. Returns `true` as soon as `on_value` does.
#[inline(always)]
fn scan_scalar_range(
    &self,
    haystack: &[u8],
    start: usize,
    end: usize,
    on_value: &mut impl FnMut(u32) -> bool,
) -> bool {
    (start..=end).any(|pos| {
        if !self.all_patterns_ascii && (haystack[pos] & 0xC0) == 0x80 {
            return false;
        }
        let candidates = self.match_mask_at(haystack, pos);
        candidates != 0 && self.verify_hits(haystack, pos, candidates, on_value)
    })
}
/// Scalar boolean scan over positions `start..=end` (inclusive).
///
/// Each position is first tested against the single-byte bitmap, then — unless
/// it is a UTF-8 continuation byte while non-ASCII patterns exist — against
/// the keyed tables. Returns `true` at the first match.
#[inline(always)]
fn scan_scalar_range_any(&self, haystack: &[u8], start: usize, end: usize) -> bool {
    (start..=end).any(|pos| {
        let current = haystack[pos];
        if self.single_byte_contains(current) {
            return true;
        }
        if !self.all_patterns_ascii && (current & 0xC0) == 0x80 {
            return false;
        }
        let candidates = self.match_mask_at(haystack, pos);
        candidates != 0 && self.verify_hits_any(haystack, pos, candidates)
    })
}
/// Scalar boolean scan over positions `start..=end` (inclusive) that consults
/// only the keyed tables, never the single-byte bitmap.
///
/// Unless all patterns are ASCII, UTF-8 continuation bytes are skipped.
#[inline(always)]
fn scan_scalar_range_any_no_single_byte(
    &self,
    haystack: &[u8],
    start: usize,
    end: usize,
) -> bool {
    (start..=end).any(|pos| {
        let current = haystack[pos];
        let may_start_pattern = self.all_patterns_ascii || (current & 0xC0) != 0x80;
        if !may_start_pattern {
            return false;
        }
        let candidates = self.match_mask_at(haystack, pos);
        candidates != 0 && self.verify_hits_any(haystack, pos, candidates)
    })
}
/// Scalar value-reporting scan over positions `start..=end` (inclusive) that
/// skips every position holding a byte `>= 0x80`.
///
/// Returns `true` as soon as `on_value` does.
#[inline(always)]
fn scan_scalar_range_ascii(
    &self,
    haystack: &[u8],
    start: usize,
    end: usize,
    on_value: &mut impl FnMut(u32) -> bool,
) -> bool {
    (start..=end)
        .filter(|&pos| haystack[pos] < 0x80)
        .any(|pos| {
            let candidates = self.match_mask_at(haystack, pos);
            candidates != 0 && self.verify_hits(haystack, pos, candidates, on_value)
        })
}
/// Scalar boolean scan over positions `start..=end` (inclusive) that skips
/// every byte `>= 0x80`; ASCII positions are tested against the single-byte
/// bitmap first and the keyed tables second.
#[inline(always)]
fn scan_scalar_range_any_ascii(&self, haystack: &[u8], start: usize, end: usize) -> bool {
    (start..=end).any(|pos| {
        let current = haystack[pos];
        if current >= 0x80 {
            return false;
        }
        if self.single_byte_contains(current) {
            return true;
        }
        let candidates = self.match_mask_at(haystack, pos);
        candidates != 0 && self.verify_hits_any(haystack, pos, candidates)
    })
}
/// Computes the candidate-bucket mask for the window beginning at `start`.
///
/// For each of the first `min(remaining bytes, max_prefix_len)` columns the
/// low and high table rows selected by the haystack byte are OR-ed together;
/// the accumulated byte is then inverted, so a set bit in the result marks a
/// bucket that no consulted column ruled out.
#[inline(always)]
fn match_mask_at(&self, haystack: &[u8], start: usize) -> u8 {
    let window_len = (haystack.len() - start).min(self.max_prefix_len);
    let ruled_out = (0..window_len).fold(0u8, |acc, column| {
        let current = haystack[start + column];
        let low = self.low_mask[column][usize::from(current & 0x3F)];
        let high = self.high_mask[column][usize::from((current >> 1) & 0x3F)];
        acc | low | high
    });
    !ruled_out
}
/// Verifies every bucket whose bit is set in `hit_mask`, lowest bit first,
/// reporting values through `on_value`.
///
/// Returns `true` as soon as a bucket verification does, `false` once all set
/// bits have been tried.
#[inline(always)]
fn verify_hits(
    &self,
    haystack: &[u8],
    start: usize,
    mut hit_mask: u8,
    on_value: &mut impl FnMut(u32) -> bool,
) -> bool {
    loop {
        if hit_mask == 0 {
            break false;
        }
        let lowest = hit_mask.trailing_zeros();
        // Clear the bit that is about to be verified.
        hit_mask ^= 1u8 << lowest;
        if self.verify_bucket(haystack, start, lowest as usize, on_value) {
            break true;
        }
    }
}
/// Boolean-only counterpart of `verify_hits`: checks every bucket whose bit is
/// set in `hit_mask`, lowest bit first, and returns `true` at the first bucket
/// that verifies.
#[inline(always)]
fn verify_hits_any(&self, haystack: &[u8], start: usize, mut hit_mask: u8) -> bool {
    loop {
        if hit_mask == 0 {
            break false;
        }
        let lowest = hit_mask.trailing_zeros();
        // Clear the bit that is about to be verified.
        hit_mask ^= 1u8 << lowest;
        if self.verify_bucket_any(haystack, start, lowest as usize) {
            break true;
        }
    }
}
/// Verifies one bucket at haystack position `start`, reporting matches.
///
/// For every prefix length enabled in the bucket's `length_mask` (bit `i`
/// meaning length `i + 2`) that still fits in the haystack, the prefix is
/// packed into a key and looked up. A hit reports the group's exact values
/// first, then every long literal whose remaining bytes also match.
/// Returns `true` as soon as `on_value` does.
#[inline(always)]
fn verify_bucket(
    &self,
    haystack: &[u8],
    start: usize,
    bucket: usize,
    on_value: &mut impl FnMut(u32) -> bool,
) -> bool {
    let verifier = &self.bucket_verify[bucket];
    let mut pending = verifier.length_mask;
    while pending != 0 {
        let slot = pending.trailing_zeros() as usize;
        pending &= pending - 1;

        let prefix_len = slot + 2;
        let prefix_end = start + prefix_len;
        if prefix_end > haystack.len() {
            continue;
        }
        let key = prefix_key(&haystack[start..prefix_end]);
        let Some(group) = verifier.groups[slot].get(key) else {
            continue;
        };

        if group.exact_values.iter().any(|&value| on_value(value)) {
            return true;
        }
        let stopped_on_literal = group.long_literals.iter().any(|literal| {
            let literal_end = start + literal.bytes.len();
            literal_end <= haystack.len()
                && haystack[prefix_end..literal_end] == literal.bytes[prefix_len..]
                && on_value(literal.value)
        });
        if stopped_on_literal {
            return true;
        }
    }
    false
}
/// Boolean-only verification of one bucket at haystack position `start`.
///
/// For every prefix length enabled in the bucket's `length_mask` (bit `i`
/// meaning length `i + 2`) that still fits in the haystack, the prefix is
/// packed into a key and looked up. A hit counts as a match when the group
/// has any exact value, or when a long literal's remaining bytes also match.
#[inline(always)]
fn verify_bucket_any(&self, haystack: &[u8], start: usize, bucket: usize) -> bool {
    let verifier = &self.bucket_verify[bucket];
    let mut pending = verifier.length_mask;
    while pending != 0 {
        let slot = pending.trailing_zeros() as usize;
        pending &= pending - 1;

        let prefix_len = slot + 2;
        let prefix_end = start + prefix_len;
        if prefix_end > haystack.len() {
            continue;
        }
        let key = prefix_key(&haystack[start..prefix_end]);
        let Some(group) = verifier.groups[slot].get(key) else {
            continue;
        };

        if !group.exact_values.is_empty() {
            return true;
        }
        let literal_matches = group.long_literals.iter().any(|literal| {
            let literal_end = start + literal.bytes.len();
            literal_end <= haystack.len()
                && haystack[prefix_end..literal_end] == literal.bytes[prefix_len..]
        });
        if literal_matches {
            return true;
        }
    }
    false
}
}
/// Packs up to eight bytes into a `u64` key: the first byte lands in the
/// least-significant position and unused high bytes are zero.
///
/// `bytes` is expected to hold 1 to 8 bytes (checked with `debug_assert!`).
///
/// # Panics
///
/// Panics if `bytes` is longer than 8 bytes. The copy is bounds-checked, so an
/// over-long slice can no longer write past the key in release builds.
#[inline(always)]
fn prefix_key(bytes: &[u8]) -> u64 {
    // Kept from the original so big-endian targets still fail to build.
    // NOTE(review): code outside this file may rely on the native
    // little-endian layout — confirm before relaxing.
    const { assert!(cfg!(target_endian = "little")) };
    debug_assert!(!bytes.is_empty() && bytes.len() <= 8);
    let mut packed = [0u8; 8];
    packed[..bytes.len()].copy_from_slice(bytes);
    u64::from_le_bytes(packed)
}