#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Maximum number of patterns a `Teddy` searcher accepts. On x86_64 each pattern
/// owns one bit of the `u8` nibble-table entries, so the limit is the bit width of `u8`.
pub const MAX_PATTERNS: usize = u8::BITS as usize;
/// Maximum length, in bytes, of a single pattern accepted by `Teddy::new`.
pub const MAX_PATTERN_LEN: usize = 8;
/// Multi-pattern byte-string searcher.
///
/// Holds up to [`MAX_PATTERNS`] patterns of at most [`MAX_PATTERN_LEN`] bytes each.
/// On x86_64, lookup tables indexed by the low and high nibble of a byte filter out
/// positions that cannot start any pattern before the full comparison is done.
#[derive(Debug)]
pub struct Teddy {
    /// The patterns in the order they were supplied; matches report an index into this.
    patterns: Vec<Vec<u8>>,
    /// Bit `i` of entry `n` is set when the first byte of pattern `i` has low nibble `n`.
    #[cfg(target_arch = "x86_64")]
    lo_nibble_table: [u8; 16],
    /// Bit `i` of entry `n` is set when the first byte of pattern `i` has high nibble `n`.
    #[cfg(target_arch = "x86_64")]
    hi_nibble_table: [u8; 16],
    /// `lo_nibble_table` copied into both halves of a 32-byte table for the AVX2
    /// shuffle; built lazily on first AVX2 search.
    #[cfg(target_arch = "x86_64")]
    lo_simd_table: std::sync::OnceLock<[u8; 32]>,
    /// `hi_nibble_table` copied into both halves of a 32-byte table for the AVX2
    /// shuffle; built lazily on first AVX2 search.
    #[cfg(target_arch = "x86_64")]
    hi_simd_table: std::sync::OnceLock<[u8; 32]>,
    /// Whether runtime detection found AVX2 when the searcher was constructed.
    #[cfg(target_arch = "x86_64")]
    use_avx2: bool,
}
impl Teddy {
    /// Builds a searcher for `patterns`.
    ///
    /// Returns `None` if `patterns` is empty, holds more than [`MAX_PATTERNS`]
    /// entries, or contains a pattern that is empty or longer than
    /// [`MAX_PATTERN_LEN`] bytes. Matches report the index of the pattern within
    /// `patterns`.
    pub fn new(patterns: Vec<Vec<u8>>) -> Option<Self> {
        if patterns.is_empty() || patterns.len() > MAX_PATTERNS {
            return None;
        }
        if patterns
            .iter()
            .any(|p| p.is_empty() || p.len() > MAX_PATTERN_LEN)
        {
            return None;
        }
        #[cfg(target_arch = "x86_64")]
        {
            // Bit `i` of `lo_nibble_table[n]` / `hi_nibble_table[n]` is set when the
            // first byte of pattern `i` has low / high nibble `n`. ANDing both lookups
            // for a byte yields exactly the patterns whose first byte equals it.
            let mut lo_nibble_table = [0u8; 16];
            let mut hi_nibble_table = [0u8; 16];
            for (i, pattern) in patterns.iter().enumerate() {
                let first_byte = pattern[0];
                lo_nibble_table[(first_byte & 0x0F) as usize] |= 1 << i;
                hi_nibble_table[(first_byte >> 4) as usize] |= 1 << i;
            }
            Some(Self {
                patterns,
                lo_nibble_table,
                hi_nibble_table,
                lo_simd_table: std::sync::OnceLock::new(),
                hi_simd_table: std::sync::OnceLock::new(),
                use_avx2: is_x86_feature_detected!("avx2"),
            })
        }
        #[cfg(not(target_arch = "x86_64"))]
        Some(Self { patterns })
    }

    /// Returns the patterns in the order they were supplied to [`Teddy::new`].
    pub fn patterns(&self) -> &[Vec<u8>] {
        &self.patterns
    }

    /// Returns the number of patterns.
    pub fn pattern_count(&self) -> usize {
        self.patterns.len()
    }

    /// Returns the leftmost match in `haystack` as `(pattern_index, byte_offset)`.
    ///
    /// When several patterns match at the same offset, the lowest pattern index wins.
    pub fn find(&self, haystack: &[u8]) -> Option<(usize, usize)> {
        self.find_from(haystack, 0)
    }

    /// Returns the leftmost match starting at or after byte offset `pos`.
    ///
    /// The result is `(pattern_index, byte_offset)`, where `byte_offset` is measured
    /// from the start of `haystack` (not from `pos`). Returns `None` when
    /// `pos >= haystack.len()` or nothing matches.
    #[inline]
    pub fn find_from(&self, haystack: &[u8], pos: usize) -> Option<(usize, usize)> {
        if pos >= haystack.len() {
            return None;
        }
        #[cfg(target_arch = "x86_64")]
        {
            if self.use_avx2 {
                // SAFETY: `use_avx2` is only true when `is_x86_feature_detected!("avx2")`
                // succeeded in `new`, so the AVX2 instructions are available.
                return unsafe { self.find_avx2_from(haystack, pos) };
            }
        }
        // `find_scalar_from` scans the slice it is given from index 0 and adds
        // `base_offset`, so it must receive the suffix beginning at `pos`.
        self.find_scalar_from(&haystack[pos..], pos)
    }

    /// Returns an iterator over matches in `haystack`, yielding
    /// `(pattern_index, byte_offset)` pairs in increasing offset order.
    #[inline]
    pub fn find_iter<'a, 'h>(&'a self, haystack: &'h [u8]) -> TeddyIter<'a, 'h> {
        TeddyIter {
            teddy: self,
            haystack,
            pos: 0,
        }
    }

    /// AVX2 search over the whole haystack.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    // Convenience entry point kept for callers/benchmarks; not used internally.
    #[allow(dead_code)]
    unsafe fn find_avx2(&self, haystack: &[u8]) -> Option<(usize, usize)> {
        self.find_avx2_from(haystack, 0)
    }

    /// AVX2 implementation of [`Teddy::find_from`].
    ///
    /// Processes 32-byte chunks: the nibble tables flag bytes that may start a
    /// pattern, each flagged position is verified in order, and the final partial
    /// chunk is handed to the scalar routine.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn find_avx2_from(&self, haystack: &[u8], start_pos: usize) -> Option<(usize, usize)> {
        let len = haystack.len();
        if start_pos >= len {
            return None;
        }
        // `_mm256_shuffle_epi8` looks up within each 128-bit lane separately, so the
        // 16-entry tables are duplicated into both lanes.
        let lo_bytes = self
            .lo_simd_table
            .get_or_init(|| self.build_simd_table_cached(&self.lo_nibble_table));
        let hi_bytes = self
            .hi_simd_table
            .get_or_init(|| self.build_simd_table_cached(&self.hi_nibble_table));
        let lo_table = _mm256_loadu_si256(lo_bytes.as_ptr() as *const __m256i);
        let hi_table = _mm256_loadu_si256(hi_bytes.as_ptr() as *const __m256i);
        let lo_mask = _mm256_set1_epi8(0x0F);
        let ptr = haystack.as_ptr();
        let mut offset = start_pos;
        while offset + 32 <= len {
            // SAFETY: `offset + 32 <= len`, so this unaligned 32-byte load is in bounds.
            let data = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
            let lo_nibbles = _mm256_and_si256(data, lo_mask);
            // There is no 8-bit shift; shift 16-bit lanes and mask away the bits that
            // came from the neighbouring byte.
            let hi_nibbles = _mm256_and_si256(_mm256_srli_epi16(data, 4), lo_mask);
            let lo_matches = _mm256_shuffle_epi8(lo_table, lo_nibbles);
            let hi_matches = _mm256_shuffle_epi8(hi_table, hi_nibbles);
            let candidates = _mm256_and_si256(lo_matches, hi_matches);
            let zero_mask =
                _mm256_movemask_epi8(_mm256_cmpeq_epi8(candidates, _mm256_setzero_si256())) as u32;
            // Bit `k` is set when byte `offset + k` may start some pattern.
            let mut remaining = !zero_mask;
            while remaining != 0 {
                let pos = offset + remaining.trailing_zeros() as usize;
                remaining &= remaining - 1;
                // SAFETY: `pos < offset + 32 <= len`.
                let first_byte = *haystack.get_unchecked(pos);
                if let Some(pat_idx) = self.match_at(haystack, pos, self.candidate_mask(first_byte)) {
                    return Some((pat_idx, pos));
                }
            }
            offset += 32;
        }
        self.find_scalar_from(&haystack[offset..], offset)
    }

    /// Copies a 16-entry nibble table into both 16-byte halves of a 32-byte table.
    #[cfg(target_arch = "x86_64")]
    fn build_simd_table_cached(&self, table: &[u8; 16]) -> [u8; 32] {
        let mut result = [0u8; 32];
        result[..16].copy_from_slice(table);
        result[16..].copy_from_slice(table);
        result
    }

    /// Bitmask of patterns (bit `i` = pattern `i`) that may start with `byte`.
    ///
    /// On x86_64 this is exact, via the nibble tables. Elsewhere every pattern is a
    /// candidate; this is valid because `new` caps the pattern count at 8.
    #[inline]
    fn candidate_mask(&self, byte: u8) -> u8 {
        #[cfg(target_arch = "x86_64")]
        {
            self.lo_nibble_table[(byte & 0x0F) as usize]
                & self.hi_nibble_table[(byte >> 4) as usize]
        }
        #[cfg(not(target_arch = "x86_64"))]
        {
            let _ = byte;
            u8::MAX
        }
    }

    /// Returns the lowest index of a pattern allowed by `mask` that occurs in
    /// `haystack` starting at `pos` (`pos` must be `< haystack.len()`).
    #[inline]
    fn match_at(&self, haystack: &[u8], pos: usize, mask: u8) -> Option<usize> {
        let rest = &haystack[pos..];
        self.patterns
            .iter()
            .enumerate()
            .find(|&(idx, pattern)| mask & (1 << idx) != 0 && rest.starts_with(pattern))
            .map(|(idx, _)| idx)
    }

    /// Portable search over the whole haystack.
    // Convenience entry point kept for callers/benchmarks; not used internally.
    #[allow(dead_code)]
    fn find_scalar(&self, haystack: &[u8]) -> Option<(usize, usize)> {
        self.find_scalar_from(haystack, 0)
    }

    /// Portable search over all of `haystack`.
    ///
    /// Reported offsets are indices into `haystack` plus `base_offset`, so callers
    /// that search a suffix pass the suffix's start as `base_offset`.
    fn find_scalar_from(&self, haystack: &[u8], base_offset: usize) -> Option<(usize, usize)> {
        haystack.iter().enumerate().find_map(|(i, &byte)| {
            let mask = self.candidate_mask(byte);
            if mask == 0 {
                return None;
            }
            self.match_at(haystack, i, mask)
                .map(|pat_idx| (pat_idx, base_offset + i))
        })
    }
}
/// Iterator over the matches of a [`Teddy`] searcher in a haystack, created by
/// [`Teddy::find_iter`]. Yields `(pattern_index, byte_offset)` pairs.
#[derive(Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TeddyIter<'a, 'h> {
    /// The searcher whose patterns are matched.
    teddy: &'a Teddy,
    /// The bytes being searched.
    haystack: &'h [u8],
    /// Byte offset at which the next search starts.
    pos: usize,
}
impl<'a, 'h> Iterator for TeddyIter<'a, 'h> {
    /// `(pattern_index, byte_offset)` of a match.
    type Item = (usize, usize);

    /// Returns the next match. The search resumes one byte past the start of the
    /// previous match, so consecutive matches may overlap. When nothing further
    /// matches, the cursor is parked at the end of the haystack.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let end = self.haystack.len();
        if self.pos < end {
            match self.teddy.find_from(self.haystack, self.pos) {
                Some((pattern, start)) => {
                    self.pos = start + 1;
                    Some((pattern, start))
                }
                None => {
                    self.pos = end;
                    None
                }
            }
        } else {
            None
        }
    }
}
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    #[test]
    fn test_teddy_single() {
        let teddy = Teddy::new(vec![b"hello".to_vec()]).unwrap();
        assert_eq!(teddy.find(b"say hello world"), Some((0, 4)));
    }

    #[test]
    fn test_teddy_multiple() {
        let teddy = Teddy::new(vec![b"cat".to_vec(), b"dog".to_vec()]).unwrap();
        assert_eq!(teddy.find(b"I have a dog"), Some((1, 9)));
    }

    #[test]
    fn test_teddy_no_match() {
        let teddy = Teddy::new(vec![b"xyz".to_vec()]).unwrap();
        assert_eq!(teddy.find(b"hello world"), None);
    }

    #[test]
    fn test_teddy_first_pattern_wins() {
        let teddy = Teddy::new(vec![b"abc".to_vec(), b"abc".to_vec()]).unwrap();
        let result = teddy.find(b"xxxabcxxx");
        assert_eq!(result, Some((0, 3)));
    }

    #[test]
    fn test_teddy_overlapping() {
        let teddy = Teddy::new(vec![b"aa".to_vec(), b"aaa".to_vec()]).unwrap();
        let result = teddy.find(b"xaaaax");
        assert_eq!(result, Some((0, 1)));
    }

    #[test]
    fn test_teddy_at_start() {
        let teddy = Teddy::new(vec![b"hello".to_vec()]).unwrap();
        assert_eq!(teddy.find(b"hello world"), Some((0, 0)));
    }

    #[test]
    fn test_teddy_at_end() {
        let teddy = Teddy::new(vec![b"world".to_vec()]).unwrap();
        assert_eq!(teddy.find(b"hello world"), Some((0, 6)));
    }

    #[test]
    fn test_teddy_iter() {
        let teddy = Teddy::new(vec![b"a".to_vec()]).unwrap();
        let matches: Vec<_> = teddy.find_iter(b"abacada").collect();
        assert_eq!(matches, vec![(0, 0), (0, 2), (0, 4), (0, 6)]);
    }

    #[test]
    fn test_teddy_iter_multiple_patterns() {
        let teddy = Teddy::new(vec![b"a".to_vec(), b"b".to_vec()]).unwrap();
        let matches: Vec<_> = teddy.find_iter(b"abba").collect();
        assert_eq!(matches, vec![(0, 0), (1, 1), (1, 2), (0, 3)]);
    }

    #[test]
    fn test_teddy_iter_long_haystack() {
        // Spans several 32-byte SIMD chunks, so later searches start mid-haystack.
        let teddy = Teddy::new(vec![b"needle".to_vec()]).unwrap();
        let mut haystack = vec![b'x'; 100];
        haystack[10..16].copy_from_slice(b"needle");
        haystack[70..76].copy_from_slice(b"needle");
        let matches: Vec<_> = teddy.find_iter(&haystack).collect();
        assert_eq!(matches, vec![(0, 10), (0, 70)]);
    }

    #[test]
    fn test_teddy_empty_haystack() {
        let teddy = Teddy::new(vec![b"hello".to_vec()]).unwrap();
        assert_eq!(teddy.find(b""), None);
    }

    #[test]
    fn test_teddy_large_input() {
        let teddy = Teddy::new(vec![b"needle".to_vec()]).unwrap();
        let mut haystack = vec![b'x'; 100];
        haystack[50..56].copy_from_slice(b"needle");
        assert_eq!(teddy.find(&haystack), Some((0, 50)));
    }

    #[test]
    fn test_teddy_match_straddles_simd_chunk() {
        // Starts inside the first 32-byte chunk and ends in the second.
        let teddy = Teddy::new(vec![b"needle".to_vec()]).unwrap();
        let mut haystack = vec![b'x'; 64];
        haystack[29..35].copy_from_slice(b"needle");
        assert_eq!(teddy.find(&haystack), Some((0, 29)));
    }

    #[test]
    fn test_teddy_match_in_tail_after_chunks() {
        // Only reachable through the scalar tail that follows the SIMD loop.
        let teddy = Teddy::new(vec![b"needle".to_vec()]).unwrap();
        let mut haystack = vec![b'x'; 40];
        haystack[34..40].copy_from_slice(b"needle");
        assert_eq!(teddy.find(&haystack), Some((0, 34)));
    }

    #[test]
    fn test_teddy_find_from_offset() {
        let teddy = Teddy::new(vec![b"ab".to_vec()]).unwrap();
        assert_eq!(teddy.find_from(b"ab ab", 0), Some((0, 0)));
        assert_eq!(teddy.find_from(b"ab ab", 1), Some((0, 3)));
        assert_eq!(teddy.find_from(b"ab ab", 4), None);
        assert_eq!(teddy.find_from(b"ab ab", 5), None);
        assert_eq!(teddy.find_from(b"ab ab", 100), None);
    }

    #[test]
    fn test_teddy_shared_nibbles_no_false_positive() {
        // 'a' = 0x61 and 'r' = 0x72; 'b' = 0x62 and 'q' = 0x71 mix their nibbles
        // but must not be reported as matches.
        let teddy = Teddy::new(vec![b"ab".to_vec(), b"rz".to_vec()]).unwrap();
        let mut haystack = b"bq".repeat(20);
        haystack.extend_from_slice(b"rz");
        assert_eq!(teddy.find(&haystack), Some((1, 40)));
    }

    #[test]
    fn test_teddy_accessors() {
        let teddy = Teddy::new(vec![b"cat".to_vec(), b"dog".to_vec()]).unwrap();
        assert_eq!(teddy.pattern_count(), 2);
        assert_eq!(teddy.patterns()[1], b"dog".to_vec());
    }

    #[test]
    fn test_teddy_too_many_patterns() {
        let patterns: Vec<Vec<u8>> = (0..10).map(|i| vec![b'a' + i]).collect();
        assert!(Teddy::new(patterns).is_none());
    }

    #[test]
    fn test_teddy_no_patterns() {
        assert!(Teddy::new(vec![]).is_none());
    }

    #[test]
    fn test_teddy_empty_pattern() {
        assert!(Teddy::new(vec![vec![]]).is_none());
    }

    #[test]
    fn test_teddy_pattern_length_limit() {
        assert!(Teddy::new(vec![vec![b'a'; MAX_PATTERN_LEN]]).is_some());
        assert!(Teddy::new(vec![vec![b'a'; MAX_PATTERN_LEN + 1]]).is_none());
    }
}