#[cfg(all(feature = "simd", target_arch = "x86_64"))]
use core::arch::x86_64::*;
/// Minimum input length, in bytes, at which the comparison helpers try the SSE2 path.
///
/// This is the width of one 128-bit vector, which `eq_bytes_sse2` compares per loop
/// iteration. Shorter inputs go straight to ordinary slice comparison.
pub(crate) const SIMD_THRESHOLD: usize = 128 / 8;
#[cfg(all(feature = "simd", target_arch = "x86_64", feature = "std"))]
#[inline]
/// Reports whether SSE2 instructions may be used (std build).
///
/// A compile-time guarantee answers immediately. Otherwise the CPU is queried at
/// runtime through the standard library's feature detection.
fn has_sse2() -> bool {
    cfg!(target_feature = "sse2") || std::arch::is_x86_feature_detected!("sse2")
}
#[cfg(all(feature = "simd", target_arch = "x86_64", not(feature = "std")))]
#[inline]
/// Reports whether SSE2 instructions may be used (no_std build).
///
/// Without `std` there is no runtime CPU detection, so this relies on what the
/// compilation target guarantees. SSE2 is enabled on the usual x86_64 targets. Some
/// bare-metal/kernel targets (e.g. `x86_64-unknown-none`) deliberately disable
/// SSE/vector registers, and on those the scalar fallback must be used instead of
/// unconditionally emitting SSE2 code.
fn has_sse2() -> bool {
    cfg!(target_feature = "sse2")
}
/// Returns `true` when `a` and `b` hold exactly the same bytes.
///
/// Slices of different lengths are rejected immediately. With the `simd` feature on
/// x86_64, inputs of at least [`SIMD_THRESHOLD`] bytes are compared with SSE2 when the
/// CPU supports it. Everything else uses the standard slice equality.
#[inline]
pub(crate) fn eq_bytes(a: &[u8], b: &[u8]) -> bool {
    let len = a.len();
    if len != b.len() {
        return false;
    }

    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if len >= SIMD_THRESHOLD && has_sse2() {
            // SAFETY: SSE2 support was just confirmed and both slices have length `len`.
            return unsafe { eq_bytes_sse2(a, b) };
        }
    }

    a == b
}
/// Returns `true` if `haystack` begins with `needle`.
///
/// An empty needle matches every haystack. A needle longer than the haystack never
/// matches. With the `simd` feature on x86_64, needles of at least [`SIMD_THRESHOLD`]
/// bytes are compared with SSE2 when available.
#[inline]
pub(crate) fn starts_with_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    let n = needle.len();
    if n == 0 {
        return true;
    }
    // `get` yields `None` exactly when the needle is longer than the haystack.
    let prefix = match haystack.get(..n) {
        Some(prefix) => prefix,
        None => return false,
    };

    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if n >= SIMD_THRESHOLD && has_sse2() {
            // SAFETY: SSE2 support was just confirmed and `prefix` has length `n`.
            return unsafe { eq_bytes_sse2(prefix, needle) };
        }
    }

    prefix == needle
}
/// Returns `true` if `haystack` ends with `needle`.
///
/// An empty needle matches every haystack. A needle longer than the haystack never
/// matches. With the `simd` feature on x86_64, needles of at least [`SIMD_THRESHOLD`]
/// bytes are compared with SSE2 when available.
#[inline]
pub(crate) fn ends_with_bytes(haystack: &[u8], needle: &[u8]) -> bool {
    let n = needle.len();
    if n == 0 {
        return true;
    }
    // `checked_sub` fails exactly when the needle is longer than the haystack.
    let start = match haystack.len().checked_sub(n) {
        Some(start) => start,
        None => return false,
    };
    let suffix = &haystack[start..];

    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
    {
        if n >= SIMD_THRESHOLD && has_sse2() {
            // SAFETY: SSE2 support was just confirmed and `suffix` has length `n`.
            return unsafe { eq_bytes_sse2(suffix, needle) };
        }
    }

    suffix == needle
}
/// Compares two equal-length byte slices using 128-bit SSE2 vectors.
///
/// Full 16-byte chunks are compared lane-by-lane. The function returns `false` as soon
/// as any chunk differs. The remaining `len % 16` tail bytes use ordinary slice equality.
///
/// # Safety
///
/// * The executing CPU must support SSE2. Callers check `has_sse2()` first.
/// * `a.len()` must equal `b.len()`. The vector loop reads `b` using `a`'s length, so a
///   shorter `b` would cause out-of-bounds reads. This is only checked in debug builds.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[inline]
unsafe fn eq_bytes_sse2(a: &[u8], b: &[u8]) -> bool {
    debug_assert_eq!(a.len(), b.len());
    // Bytes per `__m128i`.
    const LANES: usize = 16;
    let len = a.len();
    let mut offset = 0;
    while offset + LANES <= len {
        // SAFETY: `offset + LANES <= len` keeps both unaligned 16-byte loads inside
        // their slices, given the equal-length contract above.
        let a_vec = _mm_loadu_si128(a.as_ptr().add(offset) as *const __m128i);
        let b_vec = _mm_loadu_si128(b.as_ptr().add(offset) as *const __m128i);
        // Equal lanes become 0xFF. movemask packs each lane's high bit into the low 16
        // bits of the result, so "all 16 lanes equal" is exactly 0xFFFF.
        let mask = _mm_movemask_epi8(_mm_cmpeq_epi8(a_vec, b_vec));
        if mask != 0xFFFF {
            return false;
        }
        offset += LANES;
    }
    // The tail is shorter than one vector. Slice equality avoids per-index bounds checks.
    a[offset..] == b[offset..]
}
#[cfg(all(test, feature = "simd"))]
mod tests {
    use super::*;

    #[test]
    fn test_eq_bytes() {
        let a = b"hello world, this is a test";
        let b = b"hello world, this is a test";
        let c = b"hello world, this is b test";
        assert!(eq_bytes(a, b));
        assert!(!eq_bytes(a, c));
        assert!(!eq_bytes(&a[..10], a));
        assert!(eq_bytes(b"", b""));
    }

    #[test]
    fn test_eq_bytes_mismatch_positions() {
        // 40 bytes = two full 16-byte vectors plus an 8-byte scalar tail.
        let a = [0xABu8; 40];
        let mut b = a;
        assert!(eq_bytes(&a, &b));
        // Difference inside the first vector.
        b[3] = 0;
        assert!(!eq_bytes(&a, &b));
        // Difference inside the second vector.
        b = a;
        b[20] = 0;
        assert!(!eq_bytes(&a, &b));
        // Difference in the tail.
        b = a;
        b[39] = 0;
        assert!(!eq_bytes(&a, &b));
    }

    #[test]
    fn test_starts_with_bytes() {
        let haystack = b"hello world, this is a test";
        assert!(starts_with_bytes(haystack, b"hello"));
        assert!(starts_with_bytes(haystack, b"hello world"));
        assert!(!starts_with_bytes(haystack, b"world"));
        assert!(starts_with_bytes(haystack, b""));
    }

    #[test]
    fn test_starts_with_bytes_long_needle() {
        let haystack = b"hello world, this is a test";
        // 17 bytes: long enough to take the SIMD path when it is available.
        let needle = b"hello world, this";
        assert!(needle.len() >= SIMD_THRESHOLD);
        assert!(starts_with_bytes(haystack, needle));
        assert!(!starts_with_bytes(haystack, b"hello world, thiz"));
        assert!(starts_with_bytes(haystack, haystack));
        assert!(!starts_with_bytes(needle, haystack));
    }

    #[test]
    fn test_ends_with_bytes() {
        let haystack = b"hello world, this is a test";
        assert!(ends_with_bytes(haystack, b"test"));
        assert!(ends_with_bytes(haystack, b"a test"));
        assert!(!ends_with_bytes(haystack, b"hello"));
        assert!(ends_with_bytes(haystack, b""));
    }

    #[test]
    fn test_ends_with_bytes_long_needle() {
        let haystack = b"hello world, this is a test";
        // 21 bytes: long enough to take the SIMD path when it is available.
        let needle = b"world, this is a test";
        assert!(needle.len() >= SIMD_THRESHOLD);
        assert!(ends_with_bytes(haystack, needle));
        assert!(!ends_with_bytes(haystack, b"World, this is a test"));
        assert!(ends_with_bytes(haystack, haystack));
        assert!(!ends_with_bytes(needle, haystack));
    }

    #[test]
    fn test_simd_threshold() {
        let small_a = b"hello";
        let small_b = b"hello";
        assert!(eq_bytes(small_a, small_b));
        let large_a = b"this is a longer string that exceeds the SIMD threshold";
        let large_b = b"this is a longer string that exceeds the SIMD threshold";
        assert!(eq_bytes(large_a, large_b));
        let large_c = b"This is a longer string that exceeds the SIMD threshold";
        assert!(!eq_bytes(large_a, large_c));
    }
}