/// Returns the index of the first occurrence of `needle` in `haystack`,
/// or `None` if the byte does not appear.
///
/// On x86_64 CPUs with AVX2 (detected at runtime, result cached by the
/// standard library) this dispatches to a vectorized search; otherwise it
/// falls back to a portable scalar scan.
pub fn memchr(needle: u8, haystack: &[u8]) -> Option<usize> {
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was confirmed by the runtime check above.
        return unsafe { memchr_avx2(needle, haystack) };
    }
    memchr_scalar(needle, haystack)
}
/// AVX2-accelerated `memchr`: scans 32 bytes per iteration, then finishes
/// the sub-32-byte tail with a scalar loop.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 (e.g. via
/// `is_x86_feature_detected!("avx2")`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn memchr_avx2(needle: u8, haystack: &[u8]) -> Option<usize> {
    use std::arch::x86_64::*;
    let len = haystack.len();
    if len == 0 {
        return None;
    }
    let ptr = haystack.as_ptr();
    // Broadcast the needle byte into all 32 lanes of a 256-bit vector.
    let needle_vec = _mm256_set1_epi8(needle as i8);
    let mut offset = 0;
    // Main loop: one full 32-byte chunk per iteration.
    while offset + 32 <= len {
        // `loadu` has no alignment requirement, so loading at any offset is fine.
        let data = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
        // Lanes equal to the needle become 0xFF, all others 0x00.
        let cmp = _mm256_cmpeq_epi8(data, needle_vec);
        // One bit per lane: bit i is set iff byte `offset + i` matched.
        let mask = _mm256_movemask_epi8(cmp) as u32;
        if mask != 0 {
            // Lowest set bit corresponds to the earliest match in the chunk.
            return Some(offset + mask.trailing_zeros() as usize);
        }
        offset += 32;
    }
    // Scalar tail for the final < 32 bytes; the range keeps `i < len`,
    // so the `get_unchecked` access is in bounds.
    (offset..len).find(|&i| *haystack.get_unchecked(i) == needle)
}
/// Portable scalar fallback: linear scan for the first byte equal to
/// `needle`, returning its index or `None` when absent.
#[inline]
fn memchr_scalar(needle: u8, haystack: &[u8]) -> Option<usize> {
    for (idx, &byte) in haystack.iter().enumerate() {
        if byte == needle {
            return Some(idx);
        }
    }
    None
}
/// Returns the index of the first byte equal to either `needle1` or
/// `needle2`, or `None` if neither appears.
///
/// Dispatches to an AVX2 implementation when the CPU supports it; otherwise
/// performs a scalar scan.
pub fn memchr2(needle1: u8, needle2: u8, haystack: &[u8]) -> Option<usize> {
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was confirmed by the runtime check above.
        return unsafe { memchr2_avx2(needle1, needle2, haystack) };
    }
    haystack.iter().position(|&b| [needle1, needle2].contains(&b))
}
/// AVX2-accelerated two-needle search: returns the index of the first byte
/// equal to either `needle1` or `needle2`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn memchr2_avx2(needle1: u8, needle2: u8, haystack: &[u8]) -> Option<usize> {
    use std::arch::x86_64::*;
    let len = haystack.len();
    if len == 0 {
        return None;
    }
    let ptr = haystack.as_ptr();
    // Broadcast each needle into all 32 lanes of its own 256-bit vector.
    let needle1_vec = _mm256_set1_epi8(needle1 as i8);
    let needle2_vec = _mm256_set1_epi8(needle2 as i8);
    let mut offset = 0;
    // Main loop: 32 bytes per iteration.
    while offset + 32 <= len {
        let data = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
        let cmp1 = _mm256_cmpeq_epi8(data, needle1_vec);
        let cmp2 = _mm256_cmpeq_epi8(data, needle2_vec);
        // A lane is 0xFF if it matched either needle.
        let combined = _mm256_or_si256(cmp1, cmp2);
        let mask = _mm256_movemask_epi8(combined) as u32;
        if mask != 0 {
            // Lowest set bit = earliest matching byte within this chunk.
            return Some(offset + mask.trailing_zeros() as usize);
        }
        offset += 32;
    }
    // Scalar tail for the final < 32 bytes; `i < len` keeps the
    // `get_unchecked` access in bounds.
    for i in offset..len {
        let b = *haystack.get_unchecked(i);
        if b == needle1 || b == needle2 {
            return Some(i);
        }
    }
    None
}
/// Returns the index of the first byte equal to `needle1`, `needle2`, or
/// `needle3`, or `None` if none of them appear.
///
/// Dispatches to an AVX2 implementation when the CPU supports it; otherwise
/// performs a scalar scan.
pub fn memchr3(needle1: u8, needle2: u8, needle3: u8, haystack: &[u8]) -> Option<usize> {
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was confirmed by the runtime check above.
        return unsafe { memchr3_avx2(needle1, needle2, needle3, haystack) };
    }
    haystack
        .iter()
        .position(|&b| [needle1, needle2, needle3].contains(&b))
}
/// AVX2-accelerated three-needle search: returns the index of the first byte
/// equal to any of `needle1`, `needle2`, or `needle3`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn memchr3_avx2(needle1: u8, needle2: u8, needle3: u8, haystack: &[u8]) -> Option<usize> {
    use std::arch::x86_64::*;
    let len = haystack.len();
    if len == 0 {
        return None;
    }
    let ptr = haystack.as_ptr();
    // Broadcast each needle into all 32 lanes of its own 256-bit vector.
    let needle1_vec = _mm256_set1_epi8(needle1 as i8);
    let needle2_vec = _mm256_set1_epi8(needle2 as i8);
    let needle3_vec = _mm256_set1_epi8(needle3 as i8);
    let mut offset = 0;
    // Main loop: 32 bytes per iteration.
    while offset + 32 <= len {
        let data = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
        let cmp1 = _mm256_cmpeq_epi8(data, needle1_vec);
        let cmp2 = _mm256_cmpeq_epi8(data, needle2_vec);
        let cmp3 = _mm256_cmpeq_epi8(data, needle3_vec);
        // A lane is 0xFF if it matched any of the three needles.
        let combined = _mm256_or_si256(_mm256_or_si256(cmp1, cmp2), cmp3);
        let mask = _mm256_movemask_epi8(combined) as u32;
        if mask != 0 {
            // Lowest set bit = earliest matching byte within this chunk.
            return Some(offset + mask.trailing_zeros() as usize);
        }
        offset += 32;
    }
    // Scalar tail for the final < 32 bytes; `i < len` keeps the
    // `get_unchecked` access in bounds.
    for i in offset..len {
        let b = *haystack.get_unchecked(i);
        if b == needle1 || b == needle2 || b == needle3 {
            return Some(i);
        }
    }
    None
}
/// Returns the index of the *last* occurrence of `needle` in `haystack`,
/// or `None` if the byte does not appear.
///
/// Dispatches to an AVX2 implementation when the CPU supports it; otherwise
/// performs a scalar reverse scan.
pub fn memrchr(needle: u8, haystack: &[u8]) -> Option<usize> {
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was confirmed by the runtime check above.
        return unsafe { memrchr_avx2(needle, haystack) };
    }
    haystack.iter().rposition(|&b| b == needle)
}
/// AVX2-accelerated reverse search: returns the index of the *last* byte
/// equal to `needle`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn memrchr_avx2(needle: u8, haystack: &[u8]) -> Option<usize> {
    use std::arch::x86_64::*;
    let len = haystack.len();
    if len == 0 {
        return None;
    }
    let ptr = haystack.as_ptr();
    let needle_vec = _mm256_set1_epi8(needle as i8);
    // `offset` starts at the beginning of the final partial chunk.
    let mut offset = (len / 32) * 32;
    // Check the trailing < 32 bytes first, back to front: any match there is
    // later in the haystack than anything inside the full chunks.
    for i in (offset..len).rev() {
        if *haystack.get_unchecked(i) == needle {
            return Some(i);
        }
    }
    // Walk full 32-byte chunks from the end toward the start.
    while offset >= 32 {
        offset -= 32;
        let data = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
        let cmp = _mm256_cmpeq_epi8(data, needle_vec);
        let mask = _mm256_movemask_epi8(cmp) as u32;
        if mask != 0 {
            // Highest set bit (31 - leading_zeros) = last match in the chunk.
            return Some(offset + 31 - mask.leading_zeros() as usize);
        }
    }
    None
}
/// Returns the index of the first byte `b` in `haystack` satisfying
/// `lo <= b <= hi` (both bounds inclusive), or `None` if no byte falls in
/// the range.
///
/// Dispatches to an AVX2 implementation when the CPU supports it; otherwise
/// performs a scalar scan.
pub fn memchr_range(lo: u8, hi: u8, haystack: &[u8]) -> Option<usize> {
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was confirmed by the runtime check above.
        return unsafe { memchr_range_avx2(lo, hi, haystack) };
    }
    haystack.iter().position(|&b| (lo..=hi).contains(&b))
}
/// AVX2-accelerated range search: returns the index of the first byte `b`
/// with `lo <= b <= hi`.
///
/// AVX2 only provides *signed* byte comparisons (`_mm256_cmpgt_epi8`), so
/// every byte is biased by -128 (equivalent to XOR 0x80), which maps the
/// unsigned ordering of `u8` onto the signed ordering of `i8`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn memchr_range_avx2(lo: u8, hi: u8, haystack: &[u8]) -> Option<usize> {
    use std::arch::x86_64::*;
    let len = haystack.len();
    if len == 0 {
        return None;
    }
    let ptr = haystack.as_ptr();
    // Bias vector plus pre-biased bounds for the signed comparisons below.
    let bias = _mm256_set1_epi8(-128i8);
    let lo_biased = _mm256_set1_epi8((lo as i8).wrapping_add(-128i8));
    let hi_biased = _mm256_set1_epi8((hi as i8).wrapping_add(-128i8));
    let mut offset = 0;
    // Main loop: 32 bytes per iteration.
    while offset + 32 <= len {
        let data = _mm256_loadu_si256(ptr.add(offset) as *const __m256i);
        let data_biased = _mm256_add_epi8(data, bias);
        // A lane is out of range iff data < lo or data > hi (biased/signed).
        let lt_lo = _mm256_cmpgt_epi8(lo_biased, data_biased);
        let gt_hi = _mm256_cmpgt_epi8(data_biased, hi_biased);
        let out_of_range = _mm256_or_si256(lt_lo, gt_hi);
        // ANDNOT against all-ones inverts: in-range lanes become 0xFF.
        let in_range = _mm256_andnot_si256(out_of_range, _mm256_set1_epi8(-1i8));
        let mask = _mm256_movemask_epi8(in_range) as u32;
        if mask != 0 {
            // Lowest set bit = earliest in-range byte within this chunk.
            return Some(offset + mask.trailing_zeros() as usize);
        }
        offset += 32;
    }
    // Scalar tail for the final < 32 bytes; `i < len` keeps the
    // `get_unchecked` access in bounds.
    for i in offset..len {
        let b = *haystack.get_unchecked(i);
        if b >= lo && b <= hi {
            return Some(i);
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    // First match, no match, match at index 0, and first-of-several.
    #[test]
    fn test_memchr() {
        assert_eq!(memchr(b'o', b"hello"), Some(4));
        assert_eq!(memchr(b'x', b"hello"), None);
        assert_eq!(memchr(b'h', b"hello"), Some(0));
        assert_eq!(memchr(b'o', b"hello world"), Some(4));
    }

    // Empty haystack must return None.
    #[test]
    fn test_memchr_empty() {
        assert_eq!(memchr(b'x', b""), None);
    }

    // 65-byte inputs exercise both full 32-byte SIMD chunks and the scalar
    // tail: match in the tail (index 64) and match in the first chunk.
    #[test]
    fn test_memchr_large() {
        let data = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaax";
        assert_eq!(memchr(b'x', data), Some(64));
        let data = b"xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        assert_eq!(memchr(b'x', data), Some(0));
    }

    // Two-needle search returns the earliest hit of either needle,
    // independent of argument order.
    #[test]
    fn test_memchr2() {
        assert_eq!(memchr2(b'e', b'o', b"hello"), Some(1));
        assert_eq!(memchr2(b'x', b'y', b"hello"), None);
        assert_eq!(memchr2(b'o', b'h', b"hello"), Some(0));
    }

    // Three-needle search: only one needle present, and none present.
    #[test]
    fn test_memchr3() {
        assert_eq!(memchr3(b'x', b'y', b'e', b"hello"), Some(1));
        assert_eq!(memchr3(b'x', b'y', b'z', b"hello"), None);
    }

    // Reverse search returns the *last* occurrence.
    #[test]
    fn test_memrchr() {
        assert_eq!(memrchr(b'l', b"hello"), Some(3));
        assert_eq!(memrchr(b'x', b"hello"), None);
        assert_eq!(memrchr(b'o', b"hello world"), Some(7));
    }

    // Matches at both ends of a 65-byte input: the tail match (index 64)
    // must win over the one at index 0.
    #[test]
    fn test_memrchr_large() {
        let data = b"xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaax";
        assert_eq!(memrchr(b'x', data), Some(64));
    }

    // Sweep the needle across chunk boundaries (0/31/32/63/64...) to catch
    // off-by-one errors between SIMD chunks and the scalar tail.
    #[test]
    fn test_memchr_at_chunk_boundaries() {
        for pos in [0, 1, 15, 16, 30, 31, 32, 33, 63, 64, 65] {
            if pos < 70 {
                let mut data = vec![b'a'; 70];
                data[pos] = b'x';
                assert_eq!(memchr(b'x', &data), Some(pos), "Failed at position {}", pos);
            }
        }
    }

    // Range search over ASCII digits: hit, miss, start, and end positions.
    #[test]
    fn test_memchr_range_digits() {
        assert_eq!(memchr_range(b'0', b'9', b"hello 123 world"), Some(6));
        assert_eq!(memchr_range(b'0', b'9', b"no digits here"), None);
        assert_eq!(memchr_range(b'0', b'9', b"0 at start"), Some(0));
        assert_eq!(memchr_range(b'0', b'9', b"end is 9"), Some(7));
    }

    // Range search over lowercase letters.
    #[test]
    fn test_memchr_range_letters() {
        assert_eq!(memchr_range(b'a', b'z', b"123 abc"), Some(4));
        assert_eq!(memchr_range(b'a', b'z', b"123 456"), None);
    }

    // 66-byte input: the single in-range byte sits inside a SIMD chunk.
    #[test]
    fn test_memchr_range_large() {
        let data = b"............................................5....................";
        assert_eq!(memchr_range(b'0', b'9', data), Some(44));
    }

    // Empty haystack must return None.
    #[test]
    fn test_memchr_range_empty() {
        assert_eq!(memchr_range(b'0', b'9', b""), None);
    }

    // Every byte in range: first index wins.
    #[test]
    fn test_memchr_range_all_match() {
        assert_eq!(memchr_range(b'0', b'9', b"123456"), Some(0));
    }
}