use std::cmp::Ordering;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use crate::system::cpu_features::get_cpu_features;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchTier {
Scalar,
SSE42,
AVX2,
AVX512,
NEON,
}
#[derive(Debug, Clone)]
pub struct SearchConfig {
pub enable_sse42: bool,
pub enable_avx2: bool,
pub enable_avx512: bool,
pub enable_neon: bool,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
enable_sse42: true,
enable_avx2: true,
enable_avx512: true,
enable_neon: true,
}
}
}
pub struct SimdStringSearch {
config: SearchConfig,
tier: SearchTier,
}
impl SimdStringSearch {
pub fn new() -> Self {
Self::with_config(SearchConfig::default())
}
pub fn with_config(config: SearchConfig) -> Self {
let tier = Self::detect_optimal_tier(&config);
Self { config, tier }
}
fn detect_optimal_tier(config: &SearchConfig) -> SearchTier {
let features = get_cpu_features();
#[cfg(target_arch = "x86_64")]
{
if config.enable_avx512 && features.has_avx512f {
return SearchTier::AVX512;
}
if config.enable_avx2 && features.has_avx2 {
return SearchTier::AVX2;
}
if config.enable_sse42 && features.has_sse42 {
return SearchTier::SSE42;
}
}
#[cfg(target_arch = "aarch64")]
{
if config.enable_neon && features.has_neon {
return SearchTier::NEON;
}
}
SearchTier::Scalar
}
pub fn tier(&self) -> SearchTier {
self.tier
}
pub fn find_char(&self, haystack: &[u8], ch: u8) -> Option<usize> {
match self.tier {
#[cfg(target_arch = "x86_64")]
SearchTier::SSE42 => sse42_strchr(haystack, ch),
#[cfg(target_arch = "x86_64")]
SearchTier::AVX2 | SearchTier::AVX512 => {
if is_x86_feature_detected!("sse4.2") {
sse42_strchr(haystack, ch)
} else {
scalar_strchr(haystack, ch)
}
}
#[cfg(target_arch = "aarch64")]
SearchTier::NEON => neon_strchr(haystack, ch),
_ => scalar_strchr(haystack, ch),
}
}
pub fn find_pattern(&self, haystack: &[u8], needle: &[u8]) -> Option<usize> {
if needle.is_empty() {
return Some(0);
}
if haystack.len() < needle.len() {
return None;
}
match self.tier {
#[cfg(target_arch = "x86_64")]
SearchTier::SSE42 => sse42_strstr(haystack, needle),
#[cfg(target_arch = "x86_64")]
SearchTier::AVX2 | SearchTier::AVX512 => {
if needle.len() <= 16 && is_x86_feature_detected!("sse4.2") {
sse42_strstr(haystack, needle)
} else {
scalar_strstr(haystack, needle)
}
}
#[cfg(target_arch = "aarch64")]
SearchTier::NEON => neon_strstr(haystack, needle),
_ => scalar_strstr(haystack, needle),
}
}
pub fn find_any_of(&self, haystack: &[u8], chars: &[u8]) -> Option<usize> {
if chars.is_empty() {
return None;
}
match self.tier {
#[cfg(target_arch = "x86_64")]
SearchTier::SSE42 => sse42_multi_search(haystack, chars),
#[cfg(target_arch = "x86_64")]
SearchTier::AVX2 | SearchTier::AVX512 => {
if is_x86_feature_detected!("sse4.2") {
sse42_multi_search(haystack, chars)
} else {
scalar_multi_search(haystack, chars)
}
}
#[cfg(target_arch = "aarch64")]
SearchTier::NEON => neon_multi_search(haystack, chars),
_ => scalar_multi_search(haystack, chars),
}
}
pub fn compare_strings(&self, s1: &[u8], s2: &[u8]) -> Ordering {
match self.tier {
#[cfg(target_arch = "x86_64")]
SearchTier::SSE42 => sse42_strcmp(s1, s2),
#[cfg(target_arch = "x86_64")]
SearchTier::AVX2 | SearchTier::AVX512 => {
if is_x86_feature_detected!("sse4.2") {
sse42_strcmp(s1, s2)
} else {
scalar_strcmp(s1, s2)
}
}
#[cfg(target_arch = "aarch64")]
SearchTier::NEON => neon_strcmp(s1, s2),
_ => scalar_strcmp(s1, s2),
}
}
}
impl Default for SimdStringSearch {
fn default() -> Self {
Self::new()
}
}
#[cfg(target_arch = "x86_64")]
pub fn sse42_strchr(haystack: &[u8], ch: u8) -> Option<usize> {
if !is_x86_feature_detected!("sse4.2") {
return scalar_strchr(haystack, ch);
}
unsafe { sse42_strchr_impl(haystack, ch) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse42_strchr_impl(haystack: &[u8], ch: u8) -> Option<usize> {
const CHUNK_SIZE: usize = 16;
let mut needle_bytes = [0u8; 16];
needle_bytes[0] = ch;
let needle = unsafe { _mm_loadu_si128(needle_bytes.as_ptr() as *const __m128i) };
let mut offset = 0;
let len = haystack.len();
while offset + CHUNK_SIZE <= len {
let chunk = unsafe { _mm_loadu_si128(haystack.as_ptr().add(offset) as *const __m128i) };
let idx = unsafe {
_mm_cmpestri(
needle,
1, chunk,
CHUNK_SIZE as i32,
_SIDD_CMP_EQUAL_ANY | _SIDD_UBYTE_OPS,
)
};
if idx < CHUNK_SIZE as i32 {
return Some(offset + idx as usize);
}
offset += CHUNK_SIZE;
}
haystack[offset..].iter().position(|&b| b == ch).map(|i| offset + i)
}
#[cfg(target_arch = "x86_64")]
pub fn sse42_strstr(haystack: &[u8], needle: &[u8]) -> Option<usize> {
if needle.is_empty() {
return Some(0);
}
if haystack.len() < needle.len() {
return None;
}
if !is_x86_feature_detected!("sse4.2") {
return scalar_strstr(haystack, needle);
}
unsafe { sse42_strstr_impl(haystack, needle) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse42_strstr_impl(haystack: &[u8], needle: &[u8]) -> Option<usize> {
let needle_len = needle.len();
if needle_len <= 16 {
unsafe { sse42_strstr_short(haystack, needle) }
} else if needle_len <= 32 {
unsafe { sse42_strstr_medium(haystack, needle) }
} else {
unsafe { sse42_strstr_long(haystack, needle) }
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse42_strstr_short(haystack: &[u8], needle: &[u8]) -> Option<usize> {
const CHUNK_SIZE: usize = 16;
let needle_len = needle.len();
let mut needle_bytes = [0u8; 16];
needle_bytes[..needle_len].copy_from_slice(needle);
let needle_vec = unsafe { _mm_loadu_si128(needle_bytes.as_ptr() as *const __m128i) };
let mut offset = 0;
let search_limit = haystack.len().saturating_sub(needle_len);
while offset <= search_limit {
let remaining = haystack.len() - offset;
let chunk_len = remaining.min(CHUNK_SIZE);
let chunk = unsafe { _mm_loadu_si128(haystack.as_ptr().add(offset) as *const __m128i) };
let idx = unsafe {
_mm_cmpestri(
needle_vec,
needle_len as i32,
chunk,
chunk_len as i32,
_SIDD_CMP_EQUAL_ORDERED | _SIDD_UBYTE_OPS,
)
};
if idx < chunk_len as i32 {
let match_pos = offset + idx as usize;
if match_pos + needle_len <= haystack.len()
&& &haystack[match_pos..match_pos + needle_len] == needle
{
return Some(match_pos);
}
}
offset += 1; }
None
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse42_strstr_medium(haystack: &[u8], needle: &[u8]) -> Option<usize> {
let needle_len = needle.len();
let first_16 = &needle[..16];
let mut offset = 0;
let search_limit = haystack.len().saturating_sub(needle_len);
while offset <= search_limit {
if let Some(match_pos) = unsafe { sse42_strstr_short(&haystack[offset..], first_16) } {
let absolute_pos = offset + match_pos;
if absolute_pos + needle_len <= haystack.len()
&& &haystack[absolute_pos..absolute_pos + needle_len] == needle
{
return Some(absolute_pos);
}
offset = absolute_pos + 1;
} else {
break;
}
}
None
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse42_strstr_long(haystack: &[u8], needle: &[u8]) -> Option<usize> {
let first_byte = needle[0];
let mut offset = 0;
while let Some(match_pos) = unsafe { sse42_strchr_impl(&haystack[offset..], first_byte) } {
let absolute_pos = offset + match_pos;
if absolute_pos + needle.len() <= haystack.len()
&& &haystack[absolute_pos..absolute_pos + needle.len()] == needle
{
return Some(absolute_pos);
}
offset = absolute_pos + 1;
}
None
}
#[cfg(target_arch = "x86_64")]
pub fn sse42_multi_search(haystack: &[u8], chars: &[u8]) -> Option<usize> {
if chars.is_empty() {
return None;
}
if !is_x86_feature_detected!("sse4.2") {
return scalar_multi_search(haystack, chars);
}
unsafe { sse42_multi_search_impl(haystack, chars) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse42_multi_search_impl(haystack: &[u8], chars: &[u8]) -> Option<usize> {
const CHUNK_SIZE: usize = 16;
let chars_len = chars.len().min(16);
let mut chars_bytes = [0u8; 16];
chars_bytes[..chars_len].copy_from_slice(&chars[..chars_len]);
let chars_vec = unsafe { _mm_loadu_si128(chars_bytes.as_ptr() as *const __m128i) };
let mut offset = 0;
let len = haystack.len();
while offset + CHUNK_SIZE <= len {
let chunk = unsafe { _mm_loadu_si128(haystack.as_ptr().add(offset) as *const __m128i) };
let idx = unsafe {
_mm_cmpestri(
chars_vec,
chars_len as i32,
chunk,
CHUNK_SIZE as i32,
_SIDD_CMP_EQUAL_ANY | _SIDD_UBYTE_OPS,
)
};
if idx < CHUNK_SIZE as i32 {
return Some(offset + idx as usize);
}
offset += CHUNK_SIZE;
}
haystack[offset..]
.iter()
.position(|&b| chars.contains(&b))
.map(|i| offset + i)
}
#[cfg(target_arch = "x86_64")]
pub fn sse42_strcmp(s1: &[u8], s2: &[u8]) -> Ordering {
if !is_x86_feature_detected!("sse4.2") {
return scalar_strcmp(s1, s2);
}
unsafe { sse42_strcmp_impl(s1, s2) }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse42_strcmp_impl(s1: &[u8], s2: &[u8]) -> Ordering {
const CHUNK_SIZE: usize = 16;
let min_len = s1.len().min(s2.len());
let mut offset = 0;
while offset + CHUNK_SIZE <= min_len {
let chunk1 = unsafe { _mm_loadu_si128(s1.as_ptr().add(offset) as *const __m128i) };
let chunk2 = unsafe { _mm_loadu_si128(s2.as_ptr().add(offset) as *const __m128i) };
let cmp = unsafe { _mm_cmpeq_epi8(chunk1, chunk2) };
let mask = unsafe { _mm_movemask_epi8(cmp) };
if mask != 0xFFFF {
for i in 0..CHUNK_SIZE {
let b1 = s1[offset + i];
let b2 = s2[offset + i];
if b1 != b2 {
return b1.cmp(&b2);
}
}
}
offset += CHUNK_SIZE;
}
for i in offset..min_len {
match s1[i].cmp(&s2[i]) {
Ordering::Equal => continue,
other => return other,
}
}
s1.len().cmp(&s2.len())
}
#[cfg(target_arch = "aarch64")]
pub fn neon_strchr(haystack: &[u8], ch: u8) -> Option<usize> {
scalar_strchr(haystack, ch)
}
#[cfg(target_arch = "aarch64")]
pub fn neon_strstr(haystack: &[u8], needle: &[u8]) -> Option<usize> {
scalar_strstr(haystack, needle)
}
#[cfg(target_arch = "aarch64")]
pub fn neon_multi_search(haystack: &[u8], chars: &[u8]) -> Option<usize> {
scalar_multi_search(haystack, chars)
}
#[cfg(target_arch = "aarch64")]
pub fn neon_strcmp(s1: &[u8], s2: &[u8]) -> Ordering {
scalar_strcmp(s1, s2)
}
pub fn scalar_strchr(haystack: &[u8], ch: u8) -> Option<usize> {
haystack.iter().position(|&b| b == ch)
}
pub fn scalar_strstr(haystack: &[u8], needle: &[u8]) -> Option<usize> {
if needle.is_empty() {
return Some(0);
}
if haystack.len() < needle.len() {
return None;
}
haystack
.windows(needle.len())
.position(|window| window == needle)
}
pub fn scalar_multi_search(haystack: &[u8], chars: &[u8]) -> Option<usize> {
haystack.iter().position(|&b| chars.contains(&b))
}
pub fn scalar_strcmp(s1: &[u8], s2: &[u8]) -> Ordering {
s1.cmp(s2)
}
pub fn find_char(haystack: &[u8], ch: u8) -> Option<usize> {
SimdStringSearch::new().find_char(haystack, ch)
}
pub fn find_pattern(haystack: &[u8], needle: &[u8]) -> Option<usize> {
SimdStringSearch::new().find_pattern(haystack, needle)
}
pub fn find_any_of(haystack: &[u8], chars: &[u8]) -> Option<usize> {
SimdStringSearch::new().find_any_of(haystack, chars)
}
pub fn compare_strings(s1: &[u8], s2: &[u8]) -> Ordering {
SimdStringSearch::new().compare_strings(s1, s2)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_search_tier_detection() {
let searcher = SimdStringSearch::new();
let tier = searcher.tier();
#[cfg(target_arch = "x86_64")]
{
assert!(matches!(
tier,
SearchTier::SSE42 | SearchTier::AVX2 | SearchTier::AVX512 | SearchTier::Scalar
));
}
#[cfg(target_arch = "aarch64")]
{
assert!(matches!(tier, SearchTier::NEON | SearchTier::Scalar));
}
}
#[test]
fn test_find_char_basic() {
let searcher = SimdStringSearch::new();
assert_eq!(searcher.find_char(b"hello", b'h'), Some(0));
assert_eq!(searcher.find_char(b"hello", b'e'), Some(1));
assert_eq!(searcher.find_char(b"hello", b'o'), Some(4));
assert_eq!(searcher.find_char(b"hello", b'x'), None);
assert_eq!(searcher.find_char(b"", b'a'), None);
}
#[test]
fn test_find_char_long_string() {
let searcher = SimdStringSearch::new();
let haystack = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaax";
assert_eq!(searcher.find_char(haystack, b'x'), Some(31));
let haystack = b"aaaaaxaaaaaxaaaaax";
assert_eq!(searcher.find_char(haystack, b'x'), Some(5));
}
#[test]
fn test_find_pattern_basic() {
let searcher = SimdStringSearch::new();
assert_eq!(searcher.find_pattern(b"hello world", b"world"), Some(6));
assert_eq!(searcher.find_pattern(b"hello world", b"hello"), Some(0));
assert_eq!(searcher.find_pattern(b"hello world", b"lo wo"), Some(3));
assert_eq!(searcher.find_pattern(b"hello world", b"xyz"), None);
assert_eq!(searcher.find_pattern(b"hello", b""), Some(0));
assert_eq!(searcher.find_pattern(b"hi", b"hello"), None);
}
#[test]
fn test_find_pattern_short() {
let searcher = SimdStringSearch::new();
assert_eq!(searcher.find_pattern(b"abcdefgh", b"cde"), Some(2));
assert_eq!(
searcher.find_pattern(b"0123456789abcdef", b"789a"),
Some(7)
);
let haystack = b"prefix_0123456789abcdef_suffix";
let needle = b"0123456789abcdef";
assert_eq!(searcher.find_pattern(haystack, needle), Some(7));
}
#[test]
fn test_find_pattern_medium() {
let searcher = SimdStringSearch::new();
let needle = b"0123456789abcdefghij"; let haystack = b"prefix_0123456789abcdefghij_suffix";
assert_eq!(searcher.find_pattern(haystack, needle), Some(7));
}
#[test]
fn test_find_pattern_long() {
let searcher = SimdStringSearch::new();
let needle = b"0123456789abcdefghijklmnopqrstuvwxyz"; let haystack = b"prefix_0123456789abcdefghijklmnopqrstuvwxyz_suffix";
assert_eq!(searcher.find_pattern(haystack, needle), Some(7));
let mut long_haystack = vec![b'a'; 1000];
long_haystack.extend_from_slice(needle);
long_haystack.extend_from_slice(b"_suffix");
assert_eq!(searcher.find_pattern(&long_haystack, needle), Some(1000));
}
#[test]
fn test_find_pattern_overlapping() {
let searcher = SimdStringSearch::new();
assert_eq!(searcher.find_pattern(b"aaaa", b"aa"), Some(0));
assert_eq!(searcher.find_pattern(b"ababab", b"abab"), Some(0));
}
#[test]
fn test_find_any_of_basic() {
let searcher = SimdStringSearch::new();
assert_eq!(searcher.find_any_of(b"hello", b"aeiou"), Some(1)); assert_eq!(searcher.find_any_of(b"world", b"aeiou"), Some(1));
assert_eq!(searcher.find_any_of(b"xyz", b"aeiou"), None);
assert_eq!(searcher.find_any_of(b"hello", b""), None);
}
#[test]
fn test_find_any_of_multiple() {
let searcher = SimdStringSearch::new();
let haystack = b"abcdefghijklmnop";
let chars = b"xyz";
assert_eq!(searcher.find_any_of(haystack, chars), None);
let chars = b"efg";
assert_eq!(searcher.find_any_of(haystack, chars), Some(4)); }
#[test]
fn test_compare_strings_basic() {
let searcher = SimdStringSearch::new();
assert_eq!(
searcher.compare_strings(b"hello", b"hello"),
Ordering::Equal
);
assert_eq!(searcher.compare_strings(b"abc", b"abd"), Ordering::Less);
assert_eq!(
searcher.compare_strings(b"xyz", b"abc"),
Ordering::Greater
);
assert_eq!(searcher.compare_strings(b"abc", b"abcd"), Ordering::Less);
assert_eq!(
searcher.compare_strings(b"abcd", b"abc"),
Ordering::Greater
);
}
#[test]
fn test_compare_strings_long() {
let searcher = SimdStringSearch::new();
let s1 = b"0123456789abcdefghijklmnopqrstuvwxyz";
let s2 = b"0123456789abcdefghijklmnopqrstuvwxyz";
assert_eq!(searcher.compare_strings(s1, s2), Ordering::Equal);
let s1 = b"0123456789abcdefghijklmnopqrstuvwxyz";
let s2 = b"0123456789abcdefghijklmnopqrstuvwxyZ"; assert_eq!(searcher.compare_strings(s1, s2), Ordering::Greater); }
#[test]
fn test_convenience_functions() {
assert_eq!(find_char(b"hello", b'e'), Some(1));
assert_eq!(find_pattern(b"hello world", b"world"), Some(6));
assert_eq!(find_any_of(b"hello", b"aeiou"), Some(1));
assert_eq!(compare_strings(b"abc", b"abc"), Ordering::Equal);
}
#[test]
fn test_scalar_fallback() {
assert_eq!(scalar_strchr(b"hello", b'e'), Some(1));
assert_eq!(scalar_strstr(b"hello world", b"world"), Some(6));
assert_eq!(scalar_multi_search(b"hello", b"aeiou"), Some(1));
assert_eq!(scalar_strcmp(b"abc", b"abc"), Ordering::Equal);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_sse42_implementation() {
if !is_x86_feature_detected!("sse4.2") {
println!("SSE4.2 not available, skipping test");
return;
}
assert_eq!(sse42_strchr(b"hello world", b'w'), Some(6));
assert_eq!(sse42_strstr(b"hello world", b"world"), Some(6));
assert_eq!(sse42_multi_search(b"hello", b"aeiou"), Some(1));
assert_eq!(sse42_strcmp(b"abc", b"abc"), Ordering::Equal);
}
#[test]
fn test_edge_cases() {
let searcher = SimdStringSearch::new();
assert_eq!(searcher.find_char(b"a", b'a'), Some(0));
assert_eq!(searcher.find_char(b"a", b'b'), None);
assert_eq!(searcher.find_pattern(b"hello", b"h"), Some(0));
assert_eq!(searcher.find_pattern(b"hello", b"lo"), Some(3));
assert_eq!(searcher.find_char(b"aaaa", b'a'), Some(0));
assert_eq!(searcher.find_pattern(b"aaaa", b"aa"), Some(0));
}
#[test]
fn test_custom_config() {
let config = SearchConfig {
enable_sse42: false,
enable_avx2: false,
enable_avx512: false,
enable_neon: false,
};
let searcher = SimdStringSearch::with_config(config);
assert_eq!(searcher.tier(), SearchTier::Scalar);
assert_eq!(searcher.find_char(b"hello", b'e'), Some(1));
assert_eq!(searcher.find_pattern(b"hello world", b"world"), Some(6));
}
}