#[cfg(target_arch = "aarch64")]
pub mod aarch64;
pub mod scalar;
#[cfg(target_arch = "x86_64")]
pub mod x86_64;
/// Pairs every needle byte with its ASCII case-flipped counterpart.
///
/// For each byte `c` of `needle` the result holds `(c, other)`, where `other` is the
/// uppercase form when `c` is an ASCII lowercase letter and the ASCII-lowercased form
/// otherwise. Bytes that have no ASCII case (digits, punctuation, non-ASCII bytes) are
/// therefore paired with themselves.
///
/// The output has exactly one entry per input byte, in input order.
pub(crate) fn case_needle(needle: &[u8]) -> Vec<(u8, u8)> {
    let mut pairs = Vec::with_capacity(needle.len());
    for &byte in needle {
        let flipped = match byte {
            b'a'..=b'z' => byte.to_ascii_uppercase(),
            // Uppercase letters fold down; everything else is returned unchanged.
            _ => byte.to_ascii_lowercase(),
        };
        pairs.push((byte, flipped));
    }
    pairs
}
/// A prefilter backend selected for the current target architecture.
///
/// Each variant wraps one concrete implementation; the SIMD variants only exist when
/// compiling for the matching architecture, while `Scalar` is always present.
#[derive(Debug, Clone)]
pub enum Prefilter {
/// AVX-based implementation (x86_64 builds only).
#[cfg(target_arch = "x86_64")]
AVX(x86_64::PrefilterAVX),
/// SSE-based implementation (x86_64 builds only).
#[cfg(target_arch = "x86_64")]
SSE(x86_64::PrefilterSSE),
/// NEON-based implementation (aarch64 builds only).
#[cfg(target_arch = "aarch64")]
NEON(aarch64::PrefilterNEON),
/// Portable implementation, compiled on every architecture.
Scalar(scalar::PrefilterScalar),
}
impl Prefilter {
    /// Builds a prefilter for `needle`, choosing a backend for the current target.
    ///
    /// * x86_64: the AVX backend is tried first, then SSE, each only when its
    ///   `is_available()` check passes; otherwise the scalar backend is used.
    /// * aarch64: the NEON backend is always used.
    /// * any other architecture: the scalar backend is used.
    pub fn new(needle: &[u8]) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if x86_64::PrefilterAVX::is_available() {
                // SAFETY: guarded by `is_available()`, which presumably confirms the CPU
                // supports what the AVX backend needs — TODO confirm against x86_64 module.
                return Prefilter::AVX(unsafe { x86_64::PrefilterAVX::new(needle) });
            }
            if x86_64::PrefilterSSE::is_available() {
                // SAFETY: same reasoning as above, for the SSE backend.
                return Prefilter::SSE(unsafe { x86_64::PrefilterSSE::new(needle) });
            }
        }

        #[cfg(target_arch = "aarch64")]
        return Prefilter::NEON(aarch64::PrefilterNEON::new(needle));

        #[cfg(not(target_arch = "aarch64"))]
        Prefilter::Scalar(scalar::PrefilterScalar::new(needle))
    }

    /// Runs the prefilter over a contiguous `haystack`.
    ///
    /// When `max_typos` is `0` the backend's exact routine (`match_haystack`) is used;
    /// for any other value the typo-tolerant routine (`match_haystack_typos`) is called
    /// with `max_typos` forwarded unchanged.
    ///
    /// Returns the backend's `(bool, usize)` result as-is. The `bool` is presumably
    /// whether the haystack passes the filter; the meaning of the `usize` is defined by
    /// the backends and is not visible from this file.
    #[inline]
    pub fn match_haystack(&self, haystack: &[u8], max_typos: u16) -> (bool, usize) {
        let exact = max_typos == 0;
        // SAFETY (all `unsafe` calls below): inside this module a SIMD variant is only
        // built by `new`, after the corresponding availability check on x86_64.
        // NOTE(review): the variants are public, so external construction is possible.
        match self {
            #[cfg(target_arch = "x86_64")]
            Prefilter::AVX(backend) => {
                if exact {
                    unsafe { backend.match_haystack(haystack) }
                } else {
                    unsafe { backend.match_haystack_typos(haystack, max_typos) }
                }
            }
            #[cfg(target_arch = "x86_64")]
            Prefilter::SSE(backend) => {
                if exact {
                    unsafe { backend.match_haystack(haystack) }
                } else {
                    unsafe { backend.match_haystack_typos(haystack, max_typos) }
                }
            }
            #[cfg(target_arch = "aarch64")]
            Prefilter::NEON(backend) => {
                if exact {
                    unsafe { backend.match_haystack(haystack) }
                } else {
                    unsafe { backend.match_haystack_typos(haystack, max_typos) }
                }
            }
            Prefilter::Scalar(backend) => {
                if exact {
                    backend.match_haystack(haystack)
                } else {
                    backend.match_haystack_typos(haystack, max_typos)
                }
            }
        }
    }

    /// Runs the prefilter over a haystack supplied as a list of chunk pointers.
    ///
    /// `chunk_ptrs` and `byte_len` are forwarded untouched to the backend's
    /// `match_haystack_chunked` (when `max_typos` is `0`) or
    /// `match_haystack_typos_chunked` (otherwise, together with `max_typos`).
    ///
    /// NOTE(review): this is a safe function that accepts raw pointers; what the
    /// pointers must reference and how `byte_len` relates to them is defined by the
    /// backends — confirm the validity requirements there.
    #[inline]
    pub fn match_haystack_chunked(
        &self,
        chunk_ptrs: &[*const u8],
        byte_len: u16,
        max_typos: u16,
    ) -> (bool, usize) {
        let exact = max_typos == 0;
        // SAFETY (all `unsafe` calls below): see `match_haystack`; pointer validity is
        // assumed to be upheld by the caller — TODO confirm.
        match self {
            #[cfg(target_arch = "x86_64")]
            Prefilter::AVX(backend) => {
                if exact {
                    unsafe { backend.match_haystack_chunked(chunk_ptrs, byte_len) }
                } else {
                    unsafe { backend.match_haystack_typos_chunked(chunk_ptrs, byte_len, max_typos) }
                }
            }
            #[cfg(target_arch = "x86_64")]
            Prefilter::SSE(backend) => {
                if exact {
                    unsafe { backend.match_haystack_chunked(chunk_ptrs, byte_len) }
                } else {
                    unsafe { backend.match_haystack_typos_chunked(chunk_ptrs, byte_len, max_typos) }
                }
            }
            #[cfg(target_arch = "aarch64")]
            Prefilter::NEON(backend) => {
                if exact {
                    unsafe { backend.match_haystack_chunked(chunk_ptrs, byte_len) }
                } else {
                    unsafe { backend.match_haystack_typos_chunked(chunk_ptrs, byte_len, max_typos) }
                }
            }
            Prefilter::Scalar(backend) => {
                if exact {
                    backend.match_haystack_chunked(chunk_ptrs, byte_len)
                } else {
                    backend.match_haystack_typos_chunked(chunk_ptrs, byte_len, max_typos)
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    /// Exact (zero-typo) prefilter check, cross-validated across backends.
    fn match_haystack(needle: &str, haystack: &str) -> bool {
        match_haystack_generic(needle, haystack, 0)
    }

    /// Typo-tolerant prefilter check, cross-validated across backends.
    fn match_haystack_typos(needle: &str, haystack: &str, max_typos: u16) -> bool {
        match_haystack_generic(needle, haystack, max_typos)
    }

    /// A haystack identical to the needle passes.
    #[test]
    fn test_exact_match() {
        assert!(match_haystack("foo", "foo"));
        assert!(match_haystack("a", "a"));
        assert!(match_haystack("hello", "hello"));
    }

    /// Needle characters may be separated by unrelated bytes.
    #[test]
    fn test_fuzzy_match_with_gaps() {
        assert!(match_haystack("foo", "f_o_o"));
        assert!(match_haystack("foo", "f__o__o"));
        assert!(match_haystack("abc", "a_b_c"));
        assert!(match_haystack("test", "t_e_s_t"));
    }

    /// For short haystacks the order of the needle characters is not enforced.
    #[test]
    fn test_unordered_within_chunk() {
        assert!(match_haystack("foo", "oof"));
        assert!(match_haystack("abc", "cba"));
        assert!(match_haystack("test", "tset"));
        assert!(match_haystack("hello", "olleh"));
    }

    /// Matching ignores ASCII case in either direction.
    #[test]
    fn test_case_insensitivity() {
        assert!(match_haystack("foo", "FOO"));
        assert!(match_haystack("Foo", "foo"));
        assert!(match_haystack("ABC", "abc"));
    }

    /// Order is enforced once the characters are far enough apart, but not when they
    /// are adjacent at the start of the haystack.
    #[test]
    fn test_chunk_boundary() {
        let haystack = "oo_______________f";
        assert!(!match_haystack("foo", haystack));
        let haystack = "oof_____________";
        assert!(match_haystack("foo", haystack));
    }

    /// Expected to pass for a haystack whose length is presumably not a multiple of the
    /// backends' chunk size (name suggests an overlapping final load — TODO confirm).
    #[test]
    fn test_overlapping_load() {
        assert!(match_haystack("foo", "f_________________________o______"));
    }

    /// Needle characters spread widely across a long haystack, in order, still pass.
    #[test]
    fn test_multiple_chunks() {
        assert!(match_haystack("foo", "f_______________o_______________o"));
        assert!(match_haystack(
            "abc",
            "a_______________b_______________c_______________"
        ));
    }

    /// A haystack missing a distinct needle character is rejected.
    #[test]
    fn test_partial_matches() {
        assert!(!match_haystack("fob", "fo"));
        assert!(!match_haystack("test", "tet"));
        assert!(!match_haystack("abc", "a"));
    }

    /// Repeated needle characters are not counted: one occurrence in the haystack is
    /// enough to satisfy every repetition.
    #[test]
    fn test_duplicate_characters_in_needle() {
        assert!(match_haystack("foo", "foo"));
        assert!(match_haystack("foo", "ofo"));
        assert!(match_haystack("foo", "fo"));
        assert!(match_haystack("aaa", "aaa"));
        assert!(match_haystack("aaa", "aa"));
    }

    /// Surrounding or interleaved extra characters do not prevent a match.
    #[test]
    fn test_haystack_with_extra_characters() {
        assert!(match_haystack("foo", "foobar"));
        assert!(match_haystack("foo", "prefoobar"));
        assert!(match_haystack("abc", "xaxbxcx"));
    }

    /// Characters placed at or just past the position the test name calls the 16-byte
    /// boundary are still found.
    #[test]
    fn test_edge_cases_at_16_byte_boundary() {
        let haystack = "123456789012345f";
        assert!(match_haystack("f", haystack));
        let haystack = "o_______________of";
        assert!(match_haystack("foo", haystack));
    }

    /// A two-character needle located at the very end of a longer haystack passes.
    #[test]
    fn test_overlapping_chunks() {
        let haystack = "_______________fo";
        assert!(match_haystack("fo", haystack));
    }

    /// Single-character needles: found anywhere, rejected on an empty haystack.
    #[test]
    fn test_single_character_needle() {
        assert!(match_haystack("a", "a"));
        assert!(match_haystack("a", "ba"));
        assert!(match_haystack("a", "_______________a"));
        assert!(!match_haystack("a", ""));
    }

    /// Haystacks with many repetitions of the needle characters pass.
    #[test]
    fn test_repeated_character_haystack() {
        assert!(match_haystack("abc", "aaabbbccc"));
        assert!(match_haystack("foo", "fofofoooo"));
    }

    /// One needle character absent from the haystack is tolerated with `max_typos = 1`
    /// and rejected with `max_typos = 0`.
    #[test]
    fn test_typos_single_missing_character() {
        assert!(match_haystack_typos("bar", "ba", 1));
        assert!(match_haystack_typos("bar", "ar", 1));
        assert!(match_haystack_typos("hello", "hllo", 1));
        assert!(match_haystack_typos("test", "tst", 1));
        assert!(!match_haystack_typos("bar", "ba", 0));
        assert!(!match_haystack_typos("hello", "hllo", 0));
    }

    /// Two missing needle characters need `max_typos = 2`; a budget of 1 is not enough.
    #[test]
    fn test_typos_multiple_missing_characters() {
        assert!(match_haystack_typos("hello", "hll", 2));
        assert!(match_haystack_typos("testing", "tstng", 2));
        assert!(match_haystack_typos("abcdef", "abdf", 2));
        assert!(!match_haystack_typos("hello", "hll", 1));
        assert!(!match_haystack_typos("testing", "tstng", 1));
    }

    /// Typo tolerance combines with gaps between the characters that are present.
    #[test]
    fn test_typos_with_gaps() {
        assert!(match_haystack_typos("bar", "b_r", 1));
        assert!(match_haystack_typos("test", "t__s_t", 1));
        assert!(match_haystack_typos("helo", "h_l_", 2));
    }

    /// Typo tolerance combines with out-of-order characters in short haystacks.
    #[test]
    fn test_typos_unordered_permutations() {
        assert!(match_haystack_typos("bar", "rb", 1));
        assert!(match_haystack_typos("abcdef", "fcda", 2));
    }

    /// Typo-tolerant matching is also ASCII case-insensitive.
    #[test]
    fn test_typos_case_insensitive() {
        assert!(match_haystack_typos("BAR", "ba", 1));
        assert!(match_haystack_typos("Hello", "HLL", 2));
        assert!(match_haystack_typos("TeSt", "ES", 2));
        assert!(!match_haystack_typos("TeSt", "ES", 1));
    }

    /// A typo budget covering the whole needle accepts an empty haystack, and a budget
    /// larger than the needle is accepted.
    #[test]
    fn test_typos_edge_cases() {
        assert!(match_haystack_typos("abc", "", 3));
        assert!(match_haystack_typos("foo", "fo", 5));
    }

    /// Typo tolerance works when the present characters are spread over a long haystack.
    #[test]
    fn test_typos_across_chunks() {
        assert!(match_haystack_typos("abc", "a_______________b", 1));
        assert!(match_haystack_typos(
            "test",
            "t_______________s_______________t",
            1
        ));
    }

    /// Single-character needle with and without a typo budget.
    #[test]
    fn test_typos_single_character_needle() {
        assert!(match_haystack_typos("a", "a", 0));
        assert!(match_haystack_typos("a", "", 1));
        assert!(!match_haystack_typos("a", "", 0));
    }

    /// Left-pads `haystack` with `_` so it is at least 8 bytes long.
    ///
    /// Presumably the backends expect a minimum haystack length — TODO confirm.
    fn normalize_haystack(haystack: &str) -> String {
        if haystack.len() < 8 {
            "_".repeat(8 - haystack.len()) + haystack
        } else {
            haystack.to_string()
        }
    }

    /// Runs the scalar backend and every SIMD backend usable on this machine, asserts
    /// that each SIMD result equals the scalar one, and returns that shared result.
    ///
    /// `max_typos == 0` selects the exact routine, anything else the typo-tolerant one.
    /// Only the `bool` of each backend's result tuple is compared.
    fn match_haystack_generic(needle: &str, haystack: &str, max_typos: u16) -> bool {
        use crate::prefilter::scalar::PrefilterScalar;

        let haystack = normalize_haystack(haystack);
        let haystack = haystack.as_bytes();
        let needle = needle.as_bytes();

        let scalar_result = {
            let prefilter = PrefilterScalar::new(needle);
            if max_typos > 0 {
                prefilter.match_haystack_typos(haystack, max_typos).0
            } else {
                prefilter.match_haystack(haystack).0
            }
        };

        #[cfg(target_arch = "x86_64")]
        {
            use crate::prefilter::x86_64::{PrefilterAVX, PrefilterSSE};

            // Mirror `Prefilter::new`: only touch a SIMD backend after its availability
            // check, so the tests do not execute unsupported instructions on older CPUs.
            if PrefilterAVX::is_available() {
                // SAFETY: guarded by `is_available()` (presumably a CPU feature check —
                // TODO confirm against the x86_64 module).
                let avx_result = unsafe {
                    let prefilter = PrefilterAVX::new(needle);
                    if max_typos > 0 {
                        prefilter.match_haystack_typos(haystack, max_typos).0
                    } else {
                        prefilter.match_haystack(haystack).0
                    }
                };
                assert_eq!(
                    avx_result, scalar_result,
                    "avx and scalar results should be the same"
                );
            }
            if PrefilterSSE::is_available() {
                // SAFETY: guarded by `is_available()`, as above.
                let sse_result = unsafe {
                    let prefilter = PrefilterSSE::new(needle);
                    if max_typos > 0 {
                        prefilter.match_haystack_typos(haystack, max_typos).0
                    } else {
                        prefilter.match_haystack(haystack).0
                    }
                };
                assert_eq!(
                    sse_result, scalar_result,
                    "sse and scalar results should be the same"
                );
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            use crate::prefilter::aarch64::PrefilterNEON;

            // SAFETY: `Prefilter::new` uses the NEON backend unconditionally on aarch64,
            // so it is assumed to be usable on every aarch64 target — TODO confirm.
            let neon_result = unsafe {
                let prefilter = PrefilterNEON::new(needle);
                if max_typos > 0 {
                    prefilter.match_haystack_typos(haystack, max_typos).0
                } else {
                    prefilter.match_haystack(haystack).0
                }
            };
            assert_eq!(
                neon_result, scalar_result,
                "neon and scalar results should be the same"
            );
        }

        // Every SIMD result was asserted equal to the scalar one, so this is the common
        // answer on all targets (and the only answer where no SIMD backend exists).
        scalar_result
    }
}