use crate::sieve::dispatch::HardwareTier;
pub(crate) mod collector;
pub(crate) mod compiler;
pub(crate) mod dispatch;
/// Multi-pattern byte scanner that reports candidate match positions in a
/// haystack, using a hardware-specific backend (see [`HardwareTier`]) for the
/// bulk scan.
///
/// When iterated it yields `usize` start offsets (the tests in this module
/// collect it into `Vec<usize>`). The constructors and the `Iterator` impl are
/// defined outside this file.
pub struct SimdSieve<'a> {
    /// Bytes being searched.
    pub(crate) haystack: &'a [u8],
    /// Current scan position in `haystack`. `estimate_match_count` reads it as
    /// the start of its scalar tail pass.
    pub(crate) offset: usize,
    /// Fixed-capacity table of up to 16 patterns; only the first
    /// `pattern_count` entries are meaningful.
    pub(crate) verification_patterns: [&'a [u8]; 16],
    /// Number of valid entries in `verification_patterns`.
    pub(crate) pattern_count: usize,
    /// Presumably the length of the longest pattern — TODO confirm against the
    /// constructors. Code in this file treats it as a window size that covers
    /// every pattern prefix.
    pub(crate) max_len: usize,
    /// Selected scan backend (AVX-512 / AVX2 / NEON / scalar, depending on the
    /// target architecture).
    pub(crate) tier: HardwareTier,
    /// Candidate bitmask for the current chunk (presumably one bit per byte
    /// position — TODO confirm in `fetch_next_chunk`).
    pub(crate) current_mask: u64,
    /// Cached candidate bitmask for a following chunk; `estimate_match_count`
    /// counts it alongside `current_mask`.
    pub(crate) next_mask_cache: u64,
    /// Haystack offset that the masks are relative to (inferred from the name —
    /// TODO confirm).
    pub(crate) mask_base_offset: usize,
    /// Comparison applied as `verifier(haystack_window, pattern)`; presumably
    /// exact or case-insensitive equality depending on which constructor built
    /// the sieve.
    pub(crate) verifier: fn(&[u8], &[u8]) -> bool,
}
impl<'a> SimdSieve<'a> {
    /// Estimates how many candidate positions the sieve reports in `haystack`.
    ///
    /// Only the first 4096 bytes of `haystack` are examined. The estimate is the
    /// sum of:
    /// * the set bits in `current_mask` and `next_mask_cache` after every
    ///   successful `fetch_next_chunk`, and
    /// * the scalar-tail positions (from the sieve's `offset` through
    ///   `haystack.len()` inclusive) where the first 4 bytes of some pattern
    ///   (or the whole pattern, if shorter) fit and are accepted by the
    ///   sieve's verifier. Each position counts at most once.
    ///
    /// Tail positions are compared against a pattern prefix only, so this is a
    /// candidate count rather than a count of fully verified matches.
    ///
    /// Returns `0` if the sieve cannot be constructed for `patterns`.
    #[must_use]
    pub fn estimate_match_count(
        haystack: &'a [u8],
        patterns: &[&'a [u8]],
        case_insensitive: bool,
    ) -> u64 {
        // Upper bound on how many leading haystack bytes are sampled.
        const ESTIMATE_WINDOW: usize = 4096;
        // Number of leading pattern bytes compared in the scalar tail.
        const PREFIX_LEN: usize = 4;

        let haystack = &haystack[..haystack.len().min(ESTIMATE_WINDOW)];
        let sieve_result = if case_insensitive {
            Self::new_case_insensitive(haystack, patterns)
        } else {
            Self::new(haystack, patterns)
        };
        let Ok(mut sieve) = sieve_result else {
            return 0;
        };

        let mut global_popcnt: u64 = 0;

        // SIMD phase: count both masks, then clear them so that a value the next
        // fetch leaves untouched is not counted a second time.
        while sieve.fetch_next_chunk() {
            global_popcnt += u64::from(std::mem::take(&mut sieve.current_mask).count_ones());
            global_popcnt += u64::from(std::mem::take(&mut sieve.next_mask_cache).count_ones());
        }

        // Scalar tail: one bounds-checked pass over the remaining positions.
        // A prefix that would run past the end simply fails to match rather
        // than panicking, regardless of what `max_len` reports.
        let verify = sieve.verifier;
        for idx in sieve.offset..=haystack.len() {
            let hit = (0..sieve.pattern_count).any(|p_idx| {
                let vp = sieve.verification_patterns[p_idx];
                let prefix_len = vp.len().min(PREFIX_LEN);
                idx + prefix_len <= haystack.len()
                    && verify(&haystack[idx..idx + prefix_len], &vp[..prefix_len])
            });
            if hit {
                global_popcnt += 1;
            }
        }

        global_popcnt
    }
}
#[cfg(test)]
mod tests {
    use super::SimdSieve;
    use super::dispatch::HardwareTier;

    /// Case-insensitive matches must survive at boundaries. Index 63 is the last
    /// position of a 64-bit mask that starts at 0 (assuming one mask bit per
    /// byte — TODO confirm). Index 150 is an interior position, and index 1023
    /// is the final byte of the haystack.
    #[test]
    fn case_insensitive_iterator_keeps_boundary_matches() {
        let mut haystack = vec![0u8; 1024];
        haystack[63] = b'Z';
        haystack[150] = b'z';
        haystack[1023] = b'Z';

        let sieve = SimdSieve::new_case_insensitive(&haystack, &[b"Z"]).unwrap();
        // Capture the backend before `collect` consumes the sieve, so a failure
        // reports which tier misbehaved (only on failure, not on every run).
        let tier = match &sieve.tier {
            #[cfg(target_arch = "x86_64")]
            HardwareTier::Avx512(_) => "avx512",
            #[cfg(target_arch = "x86_64")]
            HardwareTier::Avx2(_) => "avx2",
            #[cfg(target_arch = "aarch64")]
            HardwareTier::Neon(_) => "neon",
            HardwareTier::Scalar(_) => "scalar",
        };
        let results: Vec<usize> = sieve.collect();
        assert_eq!(
            results,
            vec![63, 150, 1023],
            "unexpected matches on tier={tier}"
        );
    }
}
#[cfg(test)]
mod sieve_unit_tests {
    // Behavioural tests for the `SimdSieve` iterator (yields match start
    // offsets) and for `estimate_match_count`.
    use super::*;

    #[test]
    fn empty_haystack_no_matches() {
        let patterns: &[&[u8]] = &[b"test"];
        let matches: Vec<_> = SimdSieve::new(b"", patterns).unwrap().collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn single_pattern_found() {
        let patterns: &[&[u8]] = &[b"GET"];
        let matches: Vec<_> = SimdSieve::new(b"GET /index.html", patterns)
            .unwrap()
            .collect();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0], 0);
    }

    #[test]
    fn multiple_patterns_found() {
        let patterns: &[&[u8]] = &[b"GET", b"POST"];
        let haystack = b"GET /a HTTP/1.1\r\nPOST /b HTTP/1.1\r\n";
        let matches: Vec<_> = SimdSieve::new(haystack, patterns).unwrap().collect();
        // Check each pattern individually: a bare length check would also pass
        // if only "GET" were reported twice.
        // "GET" starts at 0; "POST" starts right after the first "\r\n", at 17.
        assert!(matches.contains(&0), "GET should match at 0: {matches:?}");
        assert!(matches.contains(&17), "POST should match at 17: {matches:?}");
    }

    #[test]
    fn pattern_at_end_of_haystack() {
        let patterns: &[&[u8]] = &[b"end"];
        let matches: Vec<_> = SimdSieve::new(b"the end", patterns).unwrap().collect();
        // "end" occupies the last three bytes of "the end" (indices 4..7).
        assert_eq!(matches, vec![4]);
    }

    #[test]
    fn no_match_returns_empty() {
        let patterns: &[&[u8]] = &[b"xyz"];
        let matches: Vec<_> = SimdSieve::new(b"abcdefghijklmnop", patterns)
            .unwrap()
            .collect();
        assert!(matches.is_empty());
    }

    #[test]
    fn estimate_match_count_basic() {
        let patterns: &[&[u8]] = &[b"abc"];
        let haystack = vec![b'x'; 128];
        let count = SimdSieve::estimate_match_count(&haystack, patterns, false);
        assert!(count <= haystack.len() as u64, "estimate should not exceed haystack length");
    }

    #[test]
    fn binary_data_no_crash() {
        let patterns: &[&[u8]] = &[b"\x00\x01\x02"];
        let data = vec![0u8; 1024];
        let matches: Vec<_> = SimdSieve::new(&data, patterns).unwrap().collect();
        // An all-zero haystack can never contain 00 01 02.
        assert!(matches.is_empty(), "unexpected matches: {matches:?}");
    }

    #[test]
    fn pattern_longer_than_haystack() {
        let patterns: &[&[u8]] = &[b"toolongpattern"];
        let matches: Vec<_> = SimdSieve::new(b"short", patterns).unwrap().collect();
        assert!(matches.is_empty());
    }
}