#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Builds a 256-bit vector whose 32 byte lanes all hold `byte`.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn broadcast_byte(byte: u8) -> __m256i {
    // The intrinsic takes a signed lane value; reinterpret the bits unchanged.
    let lane = i8::from_ne_bytes([byte]);
    _mm256_set1_epi8(lane)
}
/// Loads the 32 bytes starting at `ptr` into a 256-bit vector.
///
/// `ptr` does not need any particular alignment.
///
/// # Safety
///
/// - The CPU executing this function must support AVX2.
/// - `ptr` must be valid for reading 32 consecutive bytes.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn load_unaligned(ptr: *const u8) -> __m256i {
    _mm256_loadu_si256(ptr.cast::<__m256i>())
}
/// Compares `a` and `b` byte lane by byte lane.
///
/// Each lane of the result is `0xFF` where the two inputs hold the same byte
/// and `0x00` where they differ.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn compare_eq(a: __m256i, b: __m256i) -> __m256i {
    _mm256_cmpeq_epi8(a, b)
}
/// Packs the most significant bit of every byte lane of `v` into a 32-bit mask.
///
/// Bit `i` of the result is the top bit of byte lane `i`.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn movemask(v: __m256i) -> u32 {
    let signed_bits: i32 = _mm256_movemask_epi8(v);
    // Same-width cast: keeps the bit pattern, so lane 31 lands on bit 31
    // instead of turning the value negative.
    signed_bits as u32
}
/// Returns the index of the first byte in `chunk` equal to `needle`, or
/// `None` when no byte matches.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn find_byte_in_chunk(chunk: &[u8; 32], needle: u8) -> Option<usize> {
    let pattern = broadcast_byte(needle);
    let haystack = load_unaligned(chunk.as_ptr());
    // One bit per byte of `chunk`; a set bit marks a matching byte.
    let hit_bits = movemask(compare_eq(haystack, pattern));
    match hit_bits {
        0 => None,
        // The lowest set bit corresponds to the earliest matching byte.
        bits => Some(bits.trailing_zeros() as usize),
    }
}
/// Returns the index of the first byte in `chunk` equal to either `needle1`
/// or `needle2`, or `None` when neither occurs.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn find_byte2_in_chunk(chunk: &[u8; 32], needle1: u8, needle2: u8) -> Option<usize> {
    let pattern1 = broadcast_byte(needle1);
    let pattern2 = broadcast_byte(needle2);
    let haystack = load_unaligned(chunk.as_ptr());
    let hits1 = compare_eq(haystack, pattern1);
    let hits2 = compare_eq(haystack, pattern2);
    // A lane counts as a hit when it matched either needle.
    let hit_bits = movemask(_mm256_or_si256(hits1, hits2));
    if hit_bits == 0 {
        return None;
    }
    // The lowest set bit corresponds to the earliest matching byte.
    Some(hit_bits.trailing_zeros() as usize)
}
/// Returns the index of the first byte in `chunk` equal to any of `needle1`,
/// `needle2` or `needle3`, or `None` when none of them occurs.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn find_byte3_in_chunk(
    chunk: &[u8; 32],
    needle1: u8,
    needle2: u8,
    needle3: u8,
) -> Option<usize> {
    let pattern1 = broadcast_byte(needle1);
    let pattern2 = broadcast_byte(needle2);
    let pattern3 = broadcast_byte(needle3);
    let haystack = load_unaligned(chunk.as_ptr());
    let hits1 = compare_eq(haystack, pattern1);
    let hits2 = compare_eq(haystack, pattern2);
    let hits3 = compare_eq(haystack, pattern3);
    // A lane counts as a hit when it matched any of the three needles.
    let first_two = _mm256_or_si256(hits1, hits2);
    let any_hit = _mm256_or_si256(first_two, hits3);
    match movemask(any_hit) {
        0 => None,
        // The lowest set bit corresponds to the earliest matching byte.
        bits => Some(bits.trailing_zeros() as usize),
    }
}
/// Looks up bytes of `table` using the low four bits of each byte of `indices`.
///
/// The upper four bits of every index byte are cleared before the shuffle.
/// The lookup is done independently in each 128-bit half: an output byte is
/// taken from the same half of `table` as the index byte that selected it.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn shuffle_nibble(table: __m256i, indices: __m256i) -> __m256i {
    let nibble_mask = _mm256_set1_epi8(0x0F);
    // Clearing the top bits also clears bit 7, which would otherwise make the
    // shuffle write zero into that lane.
    let selectors = _mm256_and_si256(indices, nibble_mask);
    _mm256_shuffle_epi8(table, selectors)
}
/// Returns the bitwise AND of the two 256-bit vectors `a` and `b`.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn and(a: __m256i, b: __m256i) -> __m256i {
    _mm256_and_si256(a, b)
}
/// Returns the bitwise OR of the two 256-bit vectors `a` and `b`.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn or(a: __m256i, b: __m256i) -> __m256i {
    _mm256_or_si256(a, b)
}
/// Returns a 256-bit vector with every bit cleared.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn zero() -> __m256i {
    _mm256_setzero_si256()
}
/// Returns a 256-bit vector with every bit set.
///
/// # Safety
///
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
#[allow(dead_code)]
pub unsafe fn all_ones() -> __m256i {
    // `!0` as an `i8` is -1, i.e. the byte 0xFF in every lane.
    _mm256_set1_epi8(!0)
}
/// Unit tests for the chunk-search helpers.
///
/// Every test returns early (and therefore passes without checking anything)
/// when the host CPU does not report AVX2 support.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    /// Builds a zero-filled chunk with each `(index, byte)` pair written into it.
    fn chunk_with(bytes: &[(usize, u8)]) -> [u8; 32] {
        let mut chunk = [0u8; 32];
        for &(index, byte) in bytes {
            chunk[index] = byte;
        }
        chunk
    }

    #[test]
    fn test_find_byte_in_chunk() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let chunk = chunk_with(&[(15, b'x')]);
        // SAFETY: AVX2 support was checked above.
        unsafe {
            assert_eq!(find_byte_in_chunk(&chunk, b'x'), Some(15));
            assert_eq!(find_byte_in_chunk(&chunk, b'y'), None);
        }
    }

    #[test]
    fn test_find_byte_in_chunk_boundaries() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let first = chunk_with(&[(0, b'z')]);
        // Lane 31 maps to the sign bit of the `i32` movemask result.
        let last = chunk_with(&[(31, b'z')]);
        let repeated = chunk_with(&[(7, b'z'), (8, b'z'), (31, b'z')]);
        // SAFETY: AVX2 support was checked above.
        unsafe {
            assert_eq!(find_byte_in_chunk(&first, b'z'), Some(0));
            assert_eq!(find_byte_in_chunk(&last, b'z'), Some(31));
            // With several matches the earliest index is reported.
            assert_eq!(find_byte_in_chunk(&repeated, b'z'), Some(7));
        }
    }

    #[test]
    fn test_find_byte2_in_chunk() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let chunk = chunk_with(&[(10, b'a'), (20, b'b')]);
        // SAFETY: AVX2 support was checked above.
        unsafe {
            assert_eq!(find_byte2_in_chunk(&chunk, b'a', b'b'), Some(10));
            // Needle order must not affect which position is reported.
            assert_eq!(find_byte2_in_chunk(&chunk, b'b', b'a'), Some(10));
            assert_eq!(find_byte2_in_chunk(&chunk, b'x', b'b'), Some(20));
            assert_eq!(find_byte2_in_chunk(&chunk, b'x', b'y'), None);
        }
    }

    #[test]
    fn test_find_byte3_in_chunk() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let chunk = chunk_with(&[(5, b'c'), (15, b'a'), (25, b'b')]);
        // SAFETY: AVX2 support was checked above.
        unsafe {
            assert_eq!(find_byte3_in_chunk(&chunk, b'a', b'b', b'c'), Some(5));
            // Only one of the three needles is present.
            assert_eq!(find_byte3_in_chunk(&chunk, b'x', b'b', b'y'), Some(25));
            assert_eq!(find_byte3_in_chunk(&chunk, b'x', b'y', b'z'), None);
        }
    }

    #[test]
    fn test_find_byte3_in_chunk_last_lane() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let chunk = chunk_with(&[(31, b'q')]);
        // SAFETY: AVX2 support was checked above.
        unsafe {
            assert_eq!(find_byte3_in_chunk(&chunk, b'x', b'y', b'q'), Some(31));
        }
    }
}