#[cfg(all(feature = "simd", target_arch = "x86_64"))]
use core::arch::x86_64::*;
/// Minimum length, in bytes, at which the helpers in this file switch from
/// their scalar routines to the SSE2 ones. The SSE2 loops consume 16 bytes
/// per iteration, so shorter inputs would never reach the vector body.
pub(crate) const SIMD_THRESHOLD: usize = 16;
/// Reports whether the running CPU supports SSE2.
///
/// This variant is compiled when the `std` feature is enabled and relies on
/// the standard library's runtime CPU feature detection.
#[cfg(all(feature = "simd", target_arch = "x86_64", feature = "std"))]
#[inline]
fn has_sse2() -> bool {
    std::is_x86_feature_detected!("sse2")
}
/// Reports whether SSE2 may be used when the `std` feature is disabled.
///
/// Runtime CPU detection is only available through `std`, so this variant
/// answers from the compile-time target configuration instead. On ordinary
/// x86_64 targets SSE2 is enabled by default and this is `true`; on targets
/// built with SSE2 turned off (for example soft-float bare-metal
/// configurations) it is `false`, which keeps callers on the scalar paths
/// rather than executing SSE2 instructions the environment did not opt into.
#[cfg(all(feature = "simd", target_arch = "x86_64", not(feature = "std")))]
#[inline]
fn has_sse2() -> bool {
    cfg!(target_feature = "sse2")
}
/// Returns `true` when `a` and `b` have the same length and identical bytes.
///
/// With the `simd` feature on x86_64, inputs of at least [`SIMD_THRESHOLD`]
/// bytes are handed to the SSE2 comparison when `has_sse2()` reports support;
/// every other case uses the ordinary slice comparison.
#[inline]
pub(crate) fn eq_bytes(a: &[u8], b: &[u8]) -> bool {
    let len = a.len();
    if len != b.len() {
        return false;
    }

    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if len >= SIMD_THRESHOLD && has_sse2() {
            // SAFETY: `has_sse2()` just confirmed SSE2 support, and the
            // slices were checked above to have equal lengths.
            return unsafe { eq_bytes_sse2(a, b) };
        }
    }

    a == b
}
/// Returns `true` when `haystack` begins with `needle`.
///
/// An empty `needle` is a prefix of every haystack, and a needle longer than
/// the haystack never matches. With the `simd` feature on x86_64, needles of
/// at least [`SIMD_THRESHOLD`] bytes are compared with the SSE2 routine when
/// `has_sse2()` reports support.
#[inline]
pub(crate) fn starts_with_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    let needle_len = needle.len();
    if needle_len > haystack.len() {
        return false;
    }
    if needle_len == 0 {
        return true;
    }

    // In bounds: `needle_len <= haystack.len()` was established above.
    let prefix = &haystack[..needle_len];

    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if needle_len >= SIMD_THRESHOLD && has_sse2() {
            // SAFETY: `has_sse2()` just confirmed SSE2 support, and `prefix`
            // has exactly `needle.len()` bytes.
            return unsafe { eq_bytes_sse2(prefix, needle) };
        }
    }

    prefix == needle
}
/// Returns `true` when `haystack` ends with `needle`.
///
/// An empty `needle` is a suffix of every haystack, and a needle longer than
/// the haystack never matches. With the `simd` feature on x86_64, needles of
/// at least [`SIMD_THRESHOLD`] bytes are compared with the SSE2 routine when
/// `has_sse2()` reports support.
#[inline]
pub(crate) fn ends_with_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    // `checked_sub` yields `None` exactly when the needle is too long.
    let start = match haystack.len().checked_sub(needle.len()) {
        Some(start) => start,
        None => return false,
    };
    if needle.is_empty() {
        return true;
    }

    let suffix = &haystack[start..];

    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if suffix.len() >= SIMD_THRESHOLD && has_sse2() {
            // SAFETY: `has_sse2()` just confirmed SSE2 support, and `suffix`
            // has exactly `needle.len()` bytes.
            return unsafe { eq_bytes_sse2(suffix, needle) };
        }
    }

    suffix == needle
}
/// Returns the index of the first occurrence of `needle` in `haystack`.
///
/// Dispatch by needle length:
/// * empty needle: always `Some(0)`;
/// * needle longer than the haystack: `None`;
/// * one byte: [`find_first_byte`];
/// * shorter than [`SIMD_THRESHOLD`]: [`find_short_bytes`];
/// * otherwise the SSE2 search when available, else a scan over every
///   window of the needle's length.
#[allow(dead_code)]
#[inline]
pub(crate) fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    match needle.len() {
        0 => Some(0),
        needle_len if needle_len > haystack.len() => None,
        1 => find_first_byte(haystack, needle[0]),
        needle_len if needle_len < SIMD_THRESHOLD => find_short_bytes(haystack, needle),
        needle_len => {
            // Reaching this arm already implies
            // `SIMD_THRESHOLD <= needle_len <= haystack.len()`, so only CPU
            // support is left to check.
            #[cfg(all(feature = "simd", target_arch = "x86_64"))]
            {
                if has_sse2() {
                    // SAFETY: `has_sse2()` just confirmed SSE2 support.
                    return unsafe { find_bytes_sse2(haystack, needle) };
                }
            }

            haystack
                .windows(needle_len)
                .position(|candidate| candidate == needle)
        }
    }
}
/// Returns the index of the first byte in `haystack` equal to `needle`.
///
/// With the `simd` feature on x86_64, haystacks of at least
/// [`SIMD_THRESHOLD`] bytes are scanned by the SSE2 routine when `has_sse2()`
/// reports support; otherwise the bytes are checked one at a time.
#[allow(dead_code)]
#[inline]
fn find_first_byte(haystack: &[u8], needle: u8) -> Option<usize> {
    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if haystack.len() >= SIMD_THRESHOLD && has_sse2() {
            // SAFETY: `has_sse2()` just confirmed SSE2 support.
            return unsafe { find_byte_sse2(haystack, needle) };
        }
    }

    haystack.iter().position(|byte| *byte == needle)
}
/// Finds the first occurrence of a short `needle` in `haystack`.
///
/// Intended for needles of 2 up to `SIMD_THRESHOLD - 1` bytes (checked by a
/// debug assertion). The search repeatedly locates the needle's first byte
/// with [`find_first_byte`] and then compares the remaining bytes at that
/// candidate position.
///
/// Returns `None` when there is no match, including when `needle` is longer
/// than `haystack`.
#[allow(dead_code)]
#[inline]
fn find_short_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    debug_assert!(needle.len() > 1 && needle.len() < SIMD_THRESHOLD);
    let needle_len = needle.len();
    // A needle longer than the haystack cannot match; `checked_sub` turns
    // that case into `None` instead of an arithmetic underflow.
    let last_start = haystack.len().checked_sub(needle_len)?;
    let mut pos = 0;
    while pos <= last_start {
        // Only positions up to `last_start` leave room for the whole needle,
        // so the first-byte scan is limited to that window.
        let offset = find_first_byte(&haystack[pos..=last_start], needle[0])?;
        let candidate = pos + offset;
        // The first byte already matched; compare the rest of the needle.
        if haystack[candidate + 1..candidate + needle_len] == needle[1..] {
            return Some(candidate);
        }
        pos = candidate + 1;
    }
    None
}
/// Compares two byte slices for equality, 16 bytes at a time, using SSE2.
///
/// Slices of different lengths compare unequal. Full 16-byte blocks are
/// compared with vector instructions; the remaining tail (fewer than 16
/// bytes) is compared with the ordinary slice comparison.
///
/// # Safety
///
/// The CPU executing this function must support SSE2. There are no other
/// requirements: every load stays inside a 16-byte block borrowed from the
/// input slices.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn eq_bytes_sse2(a: &[u8], b: &[u8]) -> bool {
    // Checked in every build profile: a debug-only assertion would let a
    // release build pair blocks of unequal-length inputs.
    if a.len() != b.len() {
        return false;
    }

    let a_blocks = a.chunks_exact(16);
    let b_blocks = b.chunks_exact(16);
    let a_tail = a_blocks.remainder();
    let b_tail = b_blocks.remainder();

    for (a_block, b_block) in a_blocks.zip(b_blocks) {
        // SAFETY: `chunks_exact(16)` yields slices of exactly 16 readable
        // bytes, and `_mm_loadu_si128` has no alignment requirement.
        let a_vec = _mm_loadu_si128(a_block.as_ptr() as *const __m128i);
        let b_vec = _mm_loadu_si128(b_block.as_ptr() as *const __m128i);
        // One mask bit per byte lane; all 16 set means the blocks are equal.
        let mask = _mm_movemask_epi8(_mm_cmpeq_epi8(a_vec, b_vec));
        if mask != 0xFFFF {
            return false;
        }
    }

    a_tail == b_tail
}
/// Finds the first occurrence of `needle` in `haystack` using SSE2 helpers.
///
/// The search locates the needle's first byte with `find_byte_sse2` and
/// verifies each candidate with `eq_bytes_sse2`. An empty needle matches at
/// offset 0, as in `find_bytes`; a needle longer than the haystack yields
/// `None`.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[allow(dead_code)]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn find_bytes_sse2(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    let needle_len = needle.len();
    if needle_len == 0 {
        // Without this guard the `needle[0]` below would panic.
        return Some(0);
    }
    // `checked_sub` doubles as the "needle longer than haystack" guard.
    let last_start = haystack.len().checked_sub(needle_len)?;
    let first_byte = needle[0];
    if needle_len == 1 {
        return find_byte_sse2(haystack, first_byte);
    }

    let mut pos = 0;
    while pos <= last_start {
        // Scan only positions that still leave room for the whole needle, so
        // every hit is a viable candidate and no extra bounds branch is
        // needed afterwards.
        let offset = find_byte_sse2(&haystack[pos..=last_start], first_byte)?;
        let candidate = pos + offset;
        if eq_bytes_sse2(&haystack[candidate..candidate + needle_len], needle) {
            return Some(candidate);
        }
        pos = candidate + 1;
    }
    None
}
/// Returns the index of the first byte in `haystack` equal to `needle`,
/// scanning 16 bytes at a time with SSE2.
///
/// Full 16-byte blocks are tested with vector instructions; the remaining
/// tail (fewer than 16 bytes) is scanned byte by byte.
///
/// # Safety
///
/// The CPU executing this function must support SSE2.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[allow(dead_code)]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn find_byte_sse2(haystack: &[u8], needle: u8) -> Option<usize> {
    // Broadcast the wanted byte into all 16 lanes.
    let pattern = _mm_set1_epi8(needle as i8);

    let blocks = haystack.chunks_exact(16);
    let tail = blocks.remainder();
    let tail_start = haystack.len() - tail.len();

    for (block_index, block) in blocks.enumerate() {
        // SAFETY: `chunks_exact(16)` yields slices of exactly 16 readable
        // bytes, and `_mm_loadu_si128` has no alignment requirement.
        let lanes = _mm_loadu_si128(block.as_ptr() as *const __m128i);
        // One mask bit per byte lane; the lowest set bit is the first match.
        let hits = _mm_movemask_epi8(_mm_cmpeq_epi8(lanes, pattern));
        if hits != 0 {
            return Some(block_index * 16 + hits.trailing_zeros() as usize);
        }
    }

    tail.iter()
        .position(|&byte| byte == needle)
        .map(|index| tail_start + index)
}
// Unit tests for the byte helpers. The module is built whenever the `simd`
// feature is enabled; tests that call the SSE2 routines directly are
// additionally restricted to x86_64, the only architecture on which those
// routines are compiled.
#[cfg(all(test, feature = "simd"))]
mod tests {
    use super::*;

    #[test]
    fn test_eq_bytes() {
        let a = b"hello world, this is a test";
        let b = b"hello world, this is a test";
        let c = b"hello world, this is b test";
        assert!(eq_bytes(a, b));
        assert!(!eq_bytes(a, c));
        assert!(!eq_bytes(&a[..10], a));
        assert!(eq_bytes(b"", b""));
    }

    #[test]
    fn test_starts_with_bytes() {
        let haystack = b"hello world, this is a test";
        assert!(starts_with_bytes(haystack, b"hello"));
        assert!(starts_with_bytes(haystack, b"hello world"));
        assert!(!starts_with_bytes(haystack, b"world"));
        assert!(starts_with_bytes(haystack, b""));
        // Needles of at least SIMD_THRESHOLD bytes.
        assert!(starts_with_bytes(haystack, b"hello world, this is"));
        assert!(!starts_with_bytes(haystack, b"hello world, this iz"));
        // Needle longer than the haystack.
        assert!(!starts_with_bytes(b"hi", b"hello"));
    }

    #[test]
    fn test_ends_with_bytes() {
        let haystack = b"hello world, this is a test";
        assert!(ends_with_bytes(haystack, b"test"));
        assert!(ends_with_bytes(haystack, b"a test"));
        assert!(!ends_with_bytes(haystack, b"hello"));
        assert!(ends_with_bytes(haystack, b""));
        // Needles of at least SIMD_THRESHOLD bytes.
        assert!(ends_with_bytes(haystack, b"world, this is a test"));
        assert!(!ends_with_bytes(haystack, b"World, this is a test"));
        // Needle longer than the haystack.
        assert!(!ends_with_bytes(b"st", b"test"));
    }

    #[test]
    fn test_find_bytes() {
        let haystack = b"hello world, this is a test";
        assert_eq!(find_bytes(haystack, b"world"), Some(6));
        assert_eq!(find_bytes(haystack, b"test"), Some(23));
        assert_eq!(find_bytes(haystack, b"xyz"), None);
        assert_eq!(find_bytes(haystack, b""), Some(0));
        assert_eq!(find_bytes(b"aaabaaaab", b"aab"), Some(1));
        assert_eq!(find_bytes(b"abcabcabcd", b"abcd"), Some(6));
        assert_eq!(find_bytes(b"aaaaaa", b"aab"), None);
        // Single-byte needle and needle longer than the haystack.
        assert_eq!(find_bytes(haystack, b"h"), Some(0));
        assert_eq!(find_bytes(b"abc", b"abcd"), None);
        // Needles of at least SIMD_THRESHOLD bytes.
        assert_eq!(find_bytes(haystack, b"world, this is a test"), Some(6));
        assert_eq!(find_bytes(haystack, b"world, this is a tesT"), None);
    }

    // `find_byte_sse2` only exists on x86_64, so this test must not be
    // compiled for other architectures.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_find_byte() {
        let haystack = b"hello world";
        // SAFETY: SSE2 is part of the x86_64 baseline instruction set.
        unsafe {
            assert_eq!(find_byte_sse2(haystack, b'w'), Some(6));
            assert_eq!(find_byte_sse2(haystack, b'h'), Some(0));
            assert_eq!(find_byte_sse2(haystack, b'd'), Some(10));
            assert_eq!(find_byte_sse2(haystack, b'x'), None);
        }
    }

    #[test]
    fn test_simd_threshold() {
        let small_a = b"hello";
        let small_b = b"hello";
        assert!(eq_bytes(small_a, small_b));
        let large_a = b"this is a longer string that exceeds the SIMD threshold";
        let large_b = b"this is a longer string that exceeds the SIMD threshold";
        assert!(eq_bytes(large_a, large_b));
        // A difference in the final bytes (past the last full 16-byte block).
        let large_c = b"this is a longer string that exceeds the SIMD threshole";
        assert!(!eq_bytes(large_a, large_c));
    }
}