use crate::system::cpu_features::{CpuFeatures, get_cpu_features};
use crate::error::{Result, ZiporaError};
use crate::memory::cache_layout::{PrefetchHint, CacheLayoutConfig, align_to_cache_line};
use std::ptr;
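// Size thresholds used by the copy paths. SMALL_COPY_THRESHOLD is currently
// not referenced in this module; MEDIUM_COPY_THRESHOLD gates software
// prefetching on large copies, and CACHE_LINE_SIZE is the alignment required
// by `copy_aligned`.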
const SMALL_COPY_THRESHOLD: usize = 64;
const MEDIUM_COPY_THRESHOLD: usize = 4096;
const CACHE_LINE_SIZE: usize = 64;
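/// SIMD instruction-set tier selected at runtime, from the widest vectors
/// (AVX-512) down to a portable scalar fallback.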
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdTier {
Avx512,
Avx2,
Sse2,
Scalar,
}
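/// Runtime-dispatched memory operations (copy, compare, byte search, fill)
/// that pick the best available SIMD tier for the current CPU, with optional
/// cache-aware prefetching driven by a `CacheLayoutConfig`.
///
/// A minimal usage sketch (marked `ignore` because the crate-level module
/// path is assumed rather than confirmed here):
///
/// ```ignore
/// let ops = SimdMemOps::new();
/// let src = vec![7u8; 1024];
/// let mut dst = vec![0u8; 1024];
/// ops.copy_nonoverlapping(&src, &mut dst).unwrap();
/// assert_eq!(ops.compare(&src, &dst), 0);
/// ```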
#[derive(Debug, Clone)]
pub struct SimdMemOps {
tier: SimdTier,
cpu_features: &'static CpuFeatures,
cache_config: CacheLayoutConfig,
}
impl SimdMemOps {
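/// Detects CPU features, selects the best available tier, and uses the
/// default cache layout configuration.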
pub fn new() -> Self {
let cpu_features = get_cpu_features();
let tier = Self::select_optimal_tier(cpu_features);
Self {
tier,
cpu_features,
cache_config: CacheLayoutConfig::new(),
}
}
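/// Same as `new`, but with an explicit `CacheLayoutConfig` controlling
/// prefetch behavior.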
pub fn with_cache_config(cache_config: CacheLayoutConfig) -> Self {
let cpu_features = get_cpu_features();
let tier = Self::select_optimal_tier(cpu_features);
Self {
tier,
cpu_features,
cache_config,
}
}
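/// Picks the widest tier the CPU supports. Note that the `Sse2` tier is
/// conservatively gated on SSE4.1/4.2 being present, even though the SSE
/// code paths below only use SSE2 instructions.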
fn select_optimal_tier(features: &CpuFeatures) -> SimdTier {
if features.has_avx512f && features.has_avx512vl && features.has_avx512bw {
SimdTier::Avx512
} else if features.has_avx2 {
SimdTier::Avx2
} else if features.has_sse41 && features.has_sse42 {
SimdTier::Sse2
} else {
SimdTier::Scalar
}
}
pub fn tier(&self) -> SimdTier {
self.tier
}
pub fn cpu_features(&self) -> &CpuFeatures {
self.cpu_features
}
pub fn cache_config(&self) -> &CacheLayoutConfig {
&self.cache_config
}
}
impl SimdMemOps {
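/// Copies `src` into `dst` using the selected SIMD tier.
///
/// Returns an error if the slice lengths differ or the slices overlap in memory.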
pub fn copy_nonoverlapping(&self, src: &[u8], dst: &mut [u8]) -> Result<()> {
if src.len() != dst.len() {
return Err(ZiporaError::invalid_data(
format!("Source and destination lengths don't match: {} vs {}", src.len(), dst.len())
));
}
if src.is_empty() {
return Ok(());
}
let src_start = src.as_ptr() as usize;
let src_end = src_start + src.len();
let dst_start = dst.as_mut_ptr() as usize;
let dst_end = dst_start + dst.len();
if src_start < dst_end && dst_start < src_end {
return Err(ZiporaError::invalid_data(
"Source and destination slices must not overlap".to_string()
));
}
unsafe {
self.simd_memcpy_unaligned(dst.as_mut_ptr(), src.as_ptr(), src.len());
}
Ok(())
}
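/// Copies `src` into `dst` using aligned SIMD loads/stores.
///
/// Both slices must be 64-byte aligned and of equal length; otherwise an
/// error is returned.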
pub fn copy_aligned(&self, src: &[u8], dst: &mut [u8]) -> Result<()> {
if src.len() != dst.len() {
return Err(ZiporaError::invalid_data(
format!("Source and destination lengths don't match: {} vs {}", src.len(), dst.len())
));
}
let src_aligned = (src.as_ptr() as usize) % CACHE_LINE_SIZE == 0;
let dst_aligned = (dst.as_mut_ptr() as usize) % CACHE_LINE_SIZE == 0;
if !src_aligned || !dst_aligned {
return Err(ZiporaError::invalid_data(
"Source and destination must be 64-byte aligned for aligned copy".to_string()
));
}
if src.is_empty() {
return Ok(());
}
unsafe {
self.simd_memcpy_aligned(dst.as_mut_ptr(), src.as_ptr(), src.len());
}
Ok(())
}
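/// Lexicographically compares `a` and `b` with memcmp-style semantics:
/// negative if `a < b`, zero if equal, positive if `a > b`. When one slice is
/// a prefix of the other, the shorter slice compares as less.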
pub fn compare(&self, a: &[u8], b: &[u8]) -> i32 {
use std::cmp::Ordering;
match a.len().cmp(&b.len()) {
Ordering::Less => {
let result = unsafe { self.simd_memcmp(a.as_ptr(), b.as_ptr(), a.len()) };
if result == 0 { -1 } else { result }
}
Ordering::Greater => {
let result = unsafe { self.simd_memcmp(a.as_ptr(), b.as_ptr(), b.len()) };
if result == 0 { 1 } else { result }
}
Ordering::Equal => {
if a.is_empty() { 0 } else {
unsafe { self.simd_memcmp(a.as_ptr(), b.as_ptr(), a.len()) }
}
}
}
}
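/// Returns the index of the first occurrence of `needle` in `haystack`, if any.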
pub fn find_byte(&self, haystack: &[u8], needle: u8) -> Option<usize> {
if haystack.is_empty() {
return None;
}
unsafe { self.simd_memchr(haystack.as_ptr(), needle, haystack.len()) }
}
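/// Fills `slice` with `value` using the selected SIMD tier.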
pub fn fill(&self, slice: &mut [u8], value: u8) {
if slice.is_empty() {
return;
}
unsafe {
self.simd_memset(slice.as_mut_ptr(), value, slice.len());
}
}
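/// Issues a prefetch hint for `addr`. This is a no-op when prefetching is
/// disabled in the cache config or on architectures without prefetch support.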
pub fn prefetch(&self, addr: *const u8, hint: PrefetchHint) {
if !self.cache_config.enable_prefetch {
return;
}
#[cfg(target_arch = "x86_64")]
unsafe {
match hint {
PrefetchHint::T0 => std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(addr as *const i8),
PrefetchHint::T1 => std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T1 }>(addr as *const i8),
PrefetchHint::T2 => std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T2 }>(addr as *const i8),
PrefetchHint::NTA => std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_NTA }>(addr as *const i8),
}
}
#[cfg(target_arch = "aarch64")]
unsafe {
match hint {
PrefetchHint::T0 | PrefetchHint::T1 => {
std::arch::asm!("prfm pldl1keep, [{}]", in(reg) addr);
}
PrefetchHint::T2 => {
std::arch::asm!("prfm pldl2keep, [{}]", in(reg) addr);
}
PrefetchHint::NTA => {
std::arch::asm!("prfm pldl1strm, [{}]", in(reg) addr);
}
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
let _ = (addr, hint);
}
}
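/// Prefetches `data` in steps of the smaller of the cache line size and the
/// configured prefetch distance.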
pub fn prefetch_range(&self, data: &[u8]) {
if !self.cache_config.enable_prefetch || data.is_empty() {
return;
}
let distance = self.cache_config.prefetch_distance;
let cache_line_size = self.cache_config.cache_line_size;
let step_size = cache_line_size.min(distance);
for chunk in data.chunks(step_size) {
self.prefetch(chunk.as_ptr(), PrefetchHint::T0);
}
}
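/// Copies `src` into `dst`, prefetching the source for larger inputs and
/// using the aligned copy path when both slices are cache-line aligned.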
pub fn copy_cache_optimized(&self, src: &[u8], dst: &mut [u8]) -> Result<()> {
if src.len() != dst.len() {
return Err(ZiporaError::invalid_data(
format!("Source and destination lengths don't match: {} vs {}", src.len(), dst.len())
));
}
if src.is_empty() {
return Ok(());
}
if src.len() >= self.cache_config.prefetch_distance && self.cache_config.enable_prefetch {
self.prefetch_range(src);
if src.len() >= MEDIUM_COPY_THRESHOLD {
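// A single spin-loop hint: a lightweight pause intended to give the issued
// prefetches a head start before the copy begins.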
std::hint::spin_loop();
}
}
let src_aligned = (src.as_ptr() as usize) % self.cache_config.cache_line_size == 0;
let dst_aligned = (dst.as_mut_ptr() as usize) % self.cache_config.cache_line_size == 0;
if src_aligned && dst_aligned && src.len() >= self.cache_config.cache_line_size {
self.copy_aligned(src, dst)
} else {
self.copy_nonoverlapping(src, dst)
}
}
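/// Compares `a` and `b`, prefetching both inputs first when they are large
/// enough and prefetching is enabled.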
pub fn compare_cache_optimized(&self, a: &[u8], b: &[u8]) -> i32 {
let min_len = a.len().min(b.len());
if min_len >= self.cache_config.prefetch_distance && self.cache_config.enable_prefetch {
self.prefetch_range(a);
self.prefetch_range(b);
}
self.compare(a, b)
}
}
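// Tier dispatch: each wrapper routes to the widest implementation whose
// minimum block size fits `len`, falling back to the scalar path otherwise.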
impl SimdMemOps {
#[inline]
unsafe fn simd_memcpy_unaligned(&self, dst: *mut u8, src: *const u8, len: usize) {
match (self.tier, len) {
(SimdTier::Avx512, len) if len >= 64 => {
unsafe { self.avx512_memcpy_unaligned(dst, src, len); }
}
(SimdTier::Avx2, len) if len >= 32 => {
unsafe { self.avx2_memcpy_unaligned(dst, src, len); }
}
(SimdTier::Sse2, len) if len >= 16 => {
unsafe { self.sse2_memcpy_unaligned(dst, src, len); }
}
_ => {
unsafe { self.scalar_memcpy(dst, src, len); }
}
}
}
#[inline]
unsafe fn simd_memcpy_aligned(&self, dst: *mut u8, src: *const u8, len: usize) {
match (self.tier, len) {
(SimdTier::Avx512, len) if len >= 64 => {
unsafe { self.avx512_memcpy_aligned(dst, src, len); }
}
(SimdTier::Avx2, len) if len >= 32 => {
unsafe { self.avx2_memcpy_aligned(dst, src, len); }
}
(SimdTier::Sse2, len) if len >= 16 => {
unsafe { self.sse2_memcpy_aligned(dst, src, len); }
}
_ => {
unsafe { self.scalar_memcpy(dst, src, len); }
}
}
}
#[inline]
unsafe fn simd_memcmp(&self, a: *const u8, b: *const u8, len: usize) -> i32 {
match (self.tier, len) {
(SimdTier::Avx512, len) if len >= 64 => {
unsafe { self.avx512_memcmp(a, b, len) }
}
(SimdTier::Avx2, len) if len >= 32 => {
unsafe { self.avx2_memcmp(a, b, len) }
}
(SimdTier::Sse2, len) if len >= 16 => {
unsafe { self.sse2_memcmp(a, b, len) }
}
_ => {
unsafe { self.scalar_memcmp(a, b, len) }
}
}
}
#[inline]
unsafe fn simd_memchr(&self, haystack: *const u8, needle: u8, len: usize) -> Option<usize> {
match (self.tier, len) {
(SimdTier::Avx512, len) if len >= 64 => {
unsafe { self.avx512_memchr(haystack, needle, len) }
}
(SimdTier::Avx2, len) if len >= 32 => {
unsafe { self.avx2_memchr(haystack, needle, len) }
}
(SimdTier::Sse2, len) if len >= 16 => {
unsafe { self.sse2_memchr(haystack, needle, len) }
}
_ => {
unsafe { self.scalar_memchr(haystack, needle, len) }
}
}
}
#[inline]
unsafe fn simd_memset(&self, ptr: *mut u8, value: u8, len: usize) {
match (self.tier, len) {
(SimdTier::Avx512, len) if len >= 64 => {
unsafe { self.avx512_memset(ptr, value, len); }
}
(SimdTier::Avx2, len) if len >= 32 => {
unsafe { self.avx2_memset(ptr, value, len); }
}
(SimdTier::Sse2, len) if len >= 16 => {
unsafe { self.sse2_memset(ptr, value, len); }
}
_ => {
unsafe { self.scalar_memset(ptr, value, len); }
}
}
}
}
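// AVX-512 backends: process 64-byte blocks and hand any remainder to the
// scalar helpers.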
#[cfg(target_arch = "x86_64")]
impl SimdMemOps {
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
unsafe fn avx512_memcpy_aligned(&self, dst: *mut u8, src: *const u8, mut len: usize) {
use std::arch::x86_64::*;
let mut dst_ptr = dst;
let mut src_ptr = src;
while len >= 64 {
unsafe {
let data = _mm512_load_si512(src_ptr as *const __m512i);
_mm512_store_si512(dst_ptr as *mut __m512i, data);
src_ptr = src_ptr.add(64);
dst_ptr = dst_ptr.add(64);
}
len -= 64;
}
if len > 0 {
unsafe { self.scalar_memcpy(dst_ptr, src_ptr, len); }
}
}
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
unsafe fn avx512_memcpy_unaligned(&self, dst: *mut u8, src: *const u8, mut len: usize) {
use std::arch::x86_64::*;
let mut dst_ptr = dst;
let mut src_ptr = src;
while len >= 64 {
unsafe {
let data = _mm512_loadu_si512(src_ptr as *const __m512i);
_mm512_storeu_si512(dst_ptr as *mut __m512i, data);
src_ptr = src_ptr.add(64);
dst_ptr = dst_ptr.add(64);
}
len -= 64;
}
if len > 0 {
unsafe { self.scalar_memcpy(dst_ptr, src_ptr, len); }
}
}
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
unsafe fn avx512_memcmp(&self, mut a: *const u8, mut b: *const u8, mut len: usize) -> i32 {
use std::arch::x86_64::*;
while len >= 64 {
unsafe {
let va = _mm512_loadu_si512(a as *const __m512i);
let vb = _mm512_loadu_si512(b as *const __m512i);
let mask = _mm512_cmpneq_epu8_mask(va, vb);
if mask != 0 {
let byte_offset = mask.trailing_zeros() as usize;
let byte_a = *a.add(byte_offset);
let byte_b = *b.add(byte_offset);
return (byte_a as i32) - (byte_b as i32);
}
a = a.add(64);
b = b.add(64);
}
len -= 64;
}
if len > 0 {
unsafe { self.scalar_memcmp(a, b, len) }
} else {
0
}
}
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
unsafe fn avx512_memchr(&self, mut haystack: *const u8, needle: u8, mut len: usize) -> Option<usize> {
use std::arch::x86_64::*;
let needle_vec = unsafe { _mm512_set1_epi8(needle as i8) };
let mut offset = 0;
while len >= 64 {
unsafe {
let data = _mm512_loadu_si512(haystack as *const __m512i);
let mask = _mm512_cmpeq_epu8_mask(data, needle_vec);
if mask != 0 {
let byte_offset = mask.trailing_zeros() as usize;
return Some(offset + byte_offset);
}
haystack = haystack.add(64);
}
offset += 64;
len -= 64;
}
if len > 0 {
if let Some(remaining_offset) = unsafe { self.scalar_memchr(haystack, needle, len) } {
Some(offset + remaining_offset)
} else {
None
}
} else {
None
}
}
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
unsafe fn avx512_memset(&self, mut ptr: *mut u8, value: u8, mut len: usize) {
use std::arch::x86_64::*;
let value_vec = unsafe { _mm512_set1_epi8(value as i8) };
while len >= 64 {
unsafe {
_mm512_storeu_si512(ptr as *mut __m512i, value_vec);
ptr = ptr.add(64);
}
len -= 64;
}
if len > 0 {
unsafe { self.scalar_memset(ptr, value, len); }
}
}
}
#[cfg(not(target_arch = "x86_64"))]
impl SimdMemOps {
#[inline]
unsafe fn avx512_memcpy_aligned(&self, dst: *mut u8, src: *const u8, len: usize) {
unsafe { self.scalar_memcpy(dst, src, len); }
}
#[inline]
unsafe fn avx512_memcpy_unaligned(&self, dst: *mut u8, src: *const u8, len: usize) {
unsafe { self.scalar_memcpy(dst, src, len); }
}
#[inline]
unsafe fn avx512_memcmp(&self, a: *const u8, b: *const u8, len: usize) -> i32 {
unsafe { self.scalar_memcmp(a, b, len) }
}
#[inline]
unsafe fn avx512_memchr(&self, haystack: *const u8, needle: u8, len: usize) -> Option<usize> {
unsafe { self.scalar_memchr(haystack, needle, len) }
}
#[inline]
unsafe fn avx512_memset(&self, ptr: *mut u8, value: u8, len: usize) {
unsafe { self.scalar_memset(ptr, value, len); }
}
}
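// AVX2 backends: process 32-byte blocks; the unaligned copy issues a one-off
// software prefetch for inputs at or above MEDIUM_COPY_THRESHOLD.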
#[cfg(target_arch = "x86_64")]
impl SimdMemOps {
#[target_feature(enable = "avx2")]
unsafe fn avx2_memcpy_aligned(&self, dst: *mut u8, src: *const u8, mut len: usize) {
use std::arch::x86_64::*;
let mut dst_ptr = dst;
let mut src_ptr = src;
while len >= 32 {
unsafe {
let data = _mm256_load_si256(src_ptr as *const __m256i);
_mm256_store_si256(dst_ptr as *mut __m256i, data);
src_ptr = src_ptr.add(32);
dst_ptr = dst_ptr.add(32);
}
len -= 32;
}
if len > 0 {
unsafe { self.scalar_memcpy(dst_ptr, src_ptr, len); }
}
}
#[target_feature(enable = "avx2")]
unsafe fn avx2_memcpy_unaligned(&self, dst: *mut u8, src: *const u8, mut len: usize) {
use std::arch::x86_64::*;
let mut dst_ptr = dst;
let mut src_ptr = src;
if len >= MEDIUM_COPY_THRESHOLD {
let prefetch_distance = 256;
if len > prefetch_distance {
unsafe {
_mm_prefetch::<{ _MM_HINT_T0 }>(src_ptr.add(prefetch_distance) as *const i8);
}
}
}
while len >= 32 {
unsafe {
let data = _mm256_loadu_si256(src_ptr as *const __m256i);
_mm256_storeu_si256(dst_ptr as *mut __m256i, data);
src_ptr = src_ptr.add(32);
dst_ptr = dst_ptr.add(32);
}
len -= 32;
}
if len > 0 {
unsafe { self.scalar_memcpy(dst_ptr, src_ptr, len); }
}
}
#[target_feature(enable = "avx2")]
unsafe fn avx2_memcmp(&self, mut a: *const u8, mut b: *const u8, mut len: usize) -> i32 {
use std::arch::x86_64::*;
while len >= 32 {
unsafe {
let va = _mm256_loadu_si256(a as *const __m256i);
let vb = _mm256_loadu_si256(b as *const __m256i);
let cmp = _mm256_cmpeq_epi8(va, vb);
let mask = _mm256_movemask_epi8(cmp) as u32;
if mask != 0xFFFFFFFF {
let diff_mask = !mask;
let byte_offset = diff_mask.trailing_zeros() as usize;
let byte_a = *a.add(byte_offset);
let byte_b = *b.add(byte_offset);
return (byte_a as i32) - (byte_b as i32);
}
a = a.add(32);
b = b.add(32);
}
len -= 32;
}
if len > 0 {
unsafe { self.scalar_memcmp(a, b, len) }
} else {
0
}
}
#[target_feature(enable = "avx2")]
unsafe fn avx2_memchr(&self, mut haystack: *const u8, needle: u8, mut len: usize) -> Option<usize> {
use std::arch::x86_64::*;
let needle_vec = unsafe { _mm256_set1_epi8(needle as i8) };
let mut offset = 0;
while len >= 32 {
unsafe {
let data = _mm256_loadu_si256(haystack as *const __m256i);
let cmp = _mm256_cmpeq_epi8(data, needle_vec);
let mask = _mm256_movemask_epi8(cmp) as u32;
if mask != 0 {
let byte_offset = mask.trailing_zeros() as usize;
return Some(offset + byte_offset);
}
haystack = haystack.add(32);
}
offset += 32;
len -= 32;
}
if len > 0 {
if let Some(remaining_offset) = unsafe { self.scalar_memchr(haystack, needle, len) } {
Some(offset + remaining_offset)
} else {
None
}
} else {
None
}
}
#[target_feature(enable = "avx2")]
unsafe fn avx2_memset(&self, mut ptr: *mut u8, value: u8, mut len: usize) {
use std::arch::x86_64::*;
let value_vec = unsafe { _mm256_set1_epi8(value as i8) };
while len >= 32 {
unsafe {
_mm256_storeu_si256(ptr as *mut __m256i, value_vec);
ptr = ptr.add(32);
}
len -= 32;
}
if len > 0 {
unsafe { self.scalar_memset(ptr, value, len); }
}
}
}
#[cfg(not(target_arch = "x86_64"))]
impl SimdMemOps {
#[inline]
unsafe fn avx2_memcpy_aligned(&self, dst: *mut u8, src: *const u8, len: usize) {
unsafe { self.scalar_memcpy(dst, src, len); }
}
#[inline]
unsafe fn avx2_memcpy_unaligned(&self, dst: *mut u8, src: *const u8, len: usize) {
unsafe { self.scalar_memcpy(dst, src, len); }
}
#[inline]
unsafe fn avx2_memcmp(&self, a: *const u8, b: *const u8, len: usize) -> i32 {
unsafe { self.scalar_memcmp(a, b, len) }
}
#[inline]
unsafe fn avx2_memchr(&self, haystack: *const u8, needle: u8, len: usize) -> Option<usize> {
unsafe { self.scalar_memchr(haystack, needle, len) }
}
#[inline]
unsafe fn avx2_memset(&self, ptr: *mut u8, value: u8, len: usize) {
unsafe { self.scalar_memset(ptr, value, len); }
}
}
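// SSE2 backends: process 16-byte blocks and hand any remainder to the scalar helpers.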
#[cfg(target_arch = "x86_64")]
impl SimdMemOps {
#[target_feature(enable = "sse2")]
unsafe fn sse2_memcpy_aligned(&self, dst: *mut u8, src: *const u8, mut len: usize) {
use std::arch::x86_64::*;
let mut dst_ptr = dst;
let mut src_ptr = src;
while len >= 16 {
unsafe {
let data = _mm_load_si128(src_ptr as *const __m128i);
_mm_store_si128(dst_ptr as *mut __m128i, data);
src_ptr = src_ptr.add(16);
dst_ptr = dst_ptr.add(16);
}
len -= 16;
}
if len > 0 {
unsafe { self.scalar_memcpy(dst_ptr, src_ptr, len); }
}
}
#[target_feature(enable = "sse2")]
unsafe fn sse2_memcpy_unaligned(&self, dst: *mut u8, src: *const u8, mut len: usize) {
use std::arch::x86_64::*;
let mut dst_ptr = dst;
let mut src_ptr = src;
while len >= 16 {
unsafe {
let data = _mm_loadu_si128(src_ptr as *const __m128i);
_mm_storeu_si128(dst_ptr as *mut __m128i, data);
src_ptr = src_ptr.add(16);
dst_ptr = dst_ptr.add(16);
}
len -= 16;
}
if len > 0 {
unsafe { self.scalar_memcpy(dst_ptr, src_ptr, len); }
}
}
#[target_feature(enable = "sse2")]
unsafe fn sse2_memcmp(&self, mut a: *const u8, mut b: *const u8, mut len: usize) -> i32 {
use std::arch::x86_64::*;
while len >= 16 {
unsafe {
let va = _mm_loadu_si128(a as *const __m128i);
let vb = _mm_loadu_si128(b as *const __m128i);
let cmp = _mm_cmpeq_epi8(va, vb);
let mask = _mm_movemask_epi8(cmp) as u16;
if mask != 0xFFFF {
let diff_mask = !mask;
let byte_offset = diff_mask.trailing_zeros() as usize;
let byte_a = *a.add(byte_offset);
let byte_b = *b.add(byte_offset);
return (byte_a as i32) - (byte_b as i32);
}
a = a.add(16);
b = b.add(16);
}
len -= 16;
}
if len > 0 {
unsafe { self.scalar_memcmp(a, b, len) }
} else {
0
}
}
#[target_feature(enable = "sse2")]
unsafe fn sse2_memchr(&self, mut haystack: *const u8, needle: u8, mut len: usize) -> Option<usize> {
use std::arch::x86_64::*;
let needle_vec = unsafe { _mm_set1_epi8(needle as i8) };
let mut offset = 0;
while len >= 16 {
unsafe {
let data = _mm_loadu_si128(haystack as *const __m128i);
let cmp = _mm_cmpeq_epi8(data, needle_vec);
let mask = _mm_movemask_epi8(cmp) as u16;
if mask != 0 {
let byte_offset = mask.trailing_zeros() as usize;
return Some(offset + byte_offset);
}
haystack = haystack.add(16);
}
offset += 16;
len -= 16;
}
if len > 0 {
if let Some(remaining_offset) = unsafe { self.scalar_memchr(haystack, needle, len) } {
Some(offset + remaining_offset)
} else {
None
}
} else {
None
}
}
#[target_feature(enable = "sse2")]
unsafe fn sse2_memset(&self, mut ptr: *mut u8, value: u8, mut len: usize) {
use std::arch::x86_64::*;
let value_vec = unsafe { _mm_set1_epi8(value as i8) };
while len >= 16 {
unsafe {
_mm_storeu_si128(ptr as *mut __m128i, value_vec);
ptr = ptr.add(16);
}
len -= 16;
}
if len > 0 {
unsafe { self.scalar_memset(ptr, value, len); }
}
}
}
#[cfg(not(target_arch = "x86_64"))]
impl SimdMemOps {
#[inline]
unsafe fn sse2_memcpy_aligned(&self, dst: *mut u8, src: *const u8, len: usize) {
unsafe { self.scalar_memcpy(dst, src, len); }
}
#[inline]
unsafe fn sse2_memcpy_unaligned(&self, dst: *mut u8, src: *const u8, len: usize) {
unsafe { self.scalar_memcpy(dst, src, len); }
}
#[inline]
unsafe fn sse2_memcmp(&self, a: *const u8, b: *const u8, len: usize) -> i32 {
unsafe { self.scalar_memcmp(a, b, len) }
}
#[inline]
unsafe fn sse2_memchr(&self, haystack: *const u8, needle: u8, len: usize) -> Option<usize> {
unsafe { self.scalar_memchr(haystack, needle, len) }
}
#[inline]
unsafe fn sse2_memset(&self, ptr: *mut u8, value: u8, len: usize) {
unsafe { self.scalar_memset(ptr, value, len); }
}
}
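// Scalar fallbacks built on the standard library's pointer primitives; these
// also serve as the tail handlers for the SIMD paths above.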
impl SimdMemOps {
#[inline]
unsafe fn scalar_memcpy(&self, dst: *mut u8, src: *const u8, len: usize) {
unsafe { ptr::copy_nonoverlapping(src, dst, len); }
}
#[inline]
unsafe fn scalar_memcmp(&self, a: *const u8, b: *const u8, len: usize) -> i32 {
for i in 0..len {
unsafe {
let byte_a = *a.add(i);
let byte_b = *b.add(i);
if byte_a != byte_b {
return (byte_a as i32) - (byte_b as i32);
}
}
}
0
}
#[inline]
unsafe fn scalar_memchr(&self, haystack: *const u8, needle: u8, len: usize) -> Option<usize> {
for i in 0..len {
unsafe {
if *haystack.add(i) == needle {
return Some(i);
}
}
}
None
}
#[inline]
unsafe fn scalar_memset(&self, ptr: *mut u8, value: u8, len: usize) {
unsafe { ptr::write_bytes(ptr, value, len); }
}
}
impl Default for SimdMemOps {
fn default() -> Self {
Self::new()
}
}
static GLOBAL_SIMD_OPS: std::sync::OnceLock<SimdMemOps> = std::sync::OnceLock::new();
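/// Returns the process-wide `SimdMemOps` instance, initialized on first use.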
pub fn get_global_simd_ops() -> &'static SimdMemOps {
GLOBAL_SIMD_OPS.get_or_init(SimdMemOps::new)
}
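/// Copies `src` into `dst` using the global `SimdMemOps` instance.
///
/// Errors if the lengths differ or the slices overlap. A minimal sketch
/// (marked `ignore` because the public re-export path is assumed, not
/// confirmed here):
///
/// ```ignore
/// let src = vec![0xABu8; 256];
/// let mut dst = vec![0u8; 256];
/// fast_copy(&src, &mut dst).unwrap();
/// assert_eq!(src, dst);
/// ```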
pub fn fast_copy(src: &[u8], dst: &mut [u8]) -> Result<()> {
get_global_simd_ops().copy_nonoverlapping(src, dst)
}
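/// memcmp-style comparison using the global `SimdMemOps` instance.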
pub fn fast_compare(a: &[u8], b: &[u8]) -> i32 {
get_global_simd_ops().compare(a, b)
}
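/// Finds the first occurrence of `needle` using the global `SimdMemOps` instance.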
pub fn fast_find_byte(haystack: &[u8], needle: u8) -> Option<usize> {
get_global_simd_ops().find_byte(haystack, needle)
}
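/// Fills `slice` with `value` using the global `SimdMemOps` instance.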
pub fn fast_fill(slice: &mut [u8], value: u8) {
get_global_simd_ops().fill(slice, value)
}
pub fn fast_copy_cache_optimized(src: &[u8], dst: &mut [u8]) -> Result<()> {
get_global_simd_ops().copy_cache_optimized(src, dst)
}
pub fn fast_compare_cache_optimized(a: &[u8], b: &[u8]) -> i32 {
get_global_simd_ops().compare_cache_optimized(a, b)
}
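/// Prefetches the memory behind `data` with the given hint; a no-op when
/// prefetching is disabled or unsupported on the target architecture.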
pub fn fast_prefetch<T: ?Sized>(data: &T, hint: PrefetchHint) {
let addr = data as *const T as *const u8;
get_global_simd_ops().prefetch(addr, hint)
}
pub fn fast_prefetch_range(data: &[u8]) {
if data.is_empty() {
return;
}
get_global_simd_ops().prefetch_range(data)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_ops_creation() {
let ops = SimdMemOps::new();
println!("Selected SIMD tier: {:?}", ops.tier());
assert!(matches!(ops.tier(), SimdTier::Avx512 | SimdTier::Avx2 | SimdTier::Sse2 | SimdTier::Scalar));
}
#[test]
fn test_global_simd_ops() {
let ops1 = get_global_simd_ops();
let ops2 = get_global_simd_ops();
assert_eq!(ops1.tier(), ops2.tier());
}
#[test]
fn test_memory_copy_basic() {
let src = b"Hello, SIMD World!";
let mut dst = vec![0u8; src.len()];
fast_copy(src, &mut dst).unwrap();
assert_eq!(src, &dst[..]);
}
#[test]
fn test_memory_copy_large() {
let size = 8192;
let src: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
let mut dst = vec![0u8; size];
fast_copy(&src, &mut dst).unwrap();
assert_eq!(src, dst);
}
#[test]
fn test_memory_copy_empty() {
let src: &[u8] = &[];
let mut dst: Vec<u8> = vec![];
let result = fast_copy(src, &mut dst);
assert!(result.is_ok());
}
#[test]
fn test_memory_copy_size_mismatch() {
let src = b"Hello";
let mut dst = vec![0u8; 10];
let result = fast_copy(src, &mut dst);
assert!(result.is_err());
}
#[test]
fn test_memory_compare_equal() {
let a = b"Hello, World!";
let b = b"Hello, World!";
assert_eq!(fast_compare(a, b), 0);
}
#[test]
fn test_memory_compare_different() {
let a = b"Hello, World!";
let b = b"Hello, SIMD!";
let result = fast_compare(a, b);
assert_ne!(result, 0);
}
#[test]
fn test_memory_compare_different_lengths() {
let a = b"Hello";
let b = b"Hello, World!";
let result = fast_compare(a, b);
assert!(result < 0);
let result2 = fast_compare(b, a);
assert!(result2 > 0);
}
#[test]
fn test_byte_search_found() {
let haystack = b"Hello, SIMD World!";
let needle = b'S';
let result = fast_find_byte(haystack, needle);
assert_eq!(result, Some(7));
}
#[test]
fn test_byte_search_not_found() {
let haystack = b"Hello, World!";
let needle = b'X';
let result = fast_find_byte(haystack, needle);
assert_eq!(result, None);
}
#[test]
fn test_byte_search_empty() {
let haystack: &[u8] = &[];
let needle = b'A';
let result = fast_find_byte(haystack, needle);
assert_eq!(result, None);
}
#[test]
fn test_memory_fill() {
let mut buffer = vec![0u8; 100];
fast_fill(&mut buffer, 0xFF);
assert!(buffer.iter().all(|&b| b == 0xFF));
}
#[test]
fn test_memory_fill_empty() {
let mut buffer: Vec<u8> = vec![];
fast_fill(&mut buffer, 0xFF);
assert!(buffer.is_empty());
}
#[test]
fn test_aligned_copy() {
let ops = SimdMemOps::new();
let layout_src = std::alloc::Layout::from_size_align(128, 64).unwrap();
let layout_dst = std::alloc::Layout::from_size_align(128, 64).unwrap();
unsafe {
let src_ptr = std::alloc::alloc(layout_src);
let dst_ptr = std::alloc::alloc(layout_dst);
if !src_ptr.is_null() && !dst_ptr.is_null() {
for i in 0..128 {
*src_ptr.add(i) = (i % 256) as u8;
}
let src_slice = std::slice::from_raw_parts(src_ptr, 128);
let dst_slice = std::slice::from_raw_parts_mut(dst_ptr, 128);
let result = ops.copy_aligned(src_slice, dst_slice);
assert!(result.is_ok());
for i in 0..128 {
assert_eq!(*src_ptr.add(i), *dst_ptr.add(i));
}
std::alloc::dealloc(src_ptr, layout_src);
std::alloc::dealloc(dst_ptr, layout_dst);
}
}
}
#[test]
fn test_size_categories() {
let ops = SimdMemOps::new();
let sizes = vec![1, 8, 16, 32, 64, 128, 1024, 4096, 8192];
for size in sizes {
let src: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
let mut dst = vec![0u8; size];
let result = ops.copy_nonoverlapping(&src, &mut dst);
assert!(result.is_ok(), "Failed for size {}", size);
assert_eq!(src, dst, "Mismatch for size {}", size);
}
}
#[test]
fn test_pattern_search() {
let haystack = b"AAAABBBBCCCCDDDD";
assert_eq!(fast_find_byte(haystack, b'A'), Some(0));
assert_eq!(fast_find_byte(haystack, b'B'), Some(4));
assert_eq!(fast_find_byte(haystack, b'C'), Some(8));
assert_eq!(fast_find_byte(haystack, b'D'), Some(12));
assert_eq!(fast_find_byte(haystack, b'E'), None);
}
#[test]
fn test_performance_comparison() {
let size = 1024;
let src: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
let mut dst_simd = vec![0u8; size];
let mut dst_std = vec![0u8; size];
fast_copy(&src, &mut dst_simd).unwrap();
dst_std.copy_from_slice(&src);
assert_eq!(dst_simd, dst_std);
assert_eq!(fast_compare(&dst_simd, &dst_std), 0);
}
#[test]
fn test_cross_tier_consistency() {
let test_data: Vec<u8> = (0u8..=255u8).collect();
let needle = 128u8;
let ops = SimdMemOps::new();
let result = ops.find_byte(&test_data, needle);
assert_eq!(result, Some(128));
let other_data: Vec<u8> = (0u8..=255u8).map(|i| if i == 128 { 129 } else { i }).collect();
let cmp = ops.compare(&test_data, &other_data);
assert!(cmp < 0);
}
#[test]
fn test_cache_optimized_operations() {
let size = 4096;
let src: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
let mut dst = vec![0u8; size];
let result = fast_copy_cache_optimized(&src, &mut dst);
assert!(result.is_ok());
assert_eq!(src, dst);
let cmp = fast_compare_cache_optimized(&src, &dst);
assert_eq!(cmp, 0);
dst[100] = 255;
let cmp2 = fast_compare_cache_optimized(&src, &dst);
assert_ne!(cmp2, 0);
}
#[test]
fn test_prefetch_operations() {
let data = vec![1u8; 1024];
fast_prefetch(&data[0], PrefetchHint::T0);
fast_prefetch(&data[0], PrefetchHint::T1);
fast_prefetch(&data[0], PrefetchHint::T2);
fast_prefetch(&data[0], PrefetchHint::NTA);
let value: u64 = 42;
fast_prefetch(&value, PrefetchHint::T0);
fast_prefetch_range(&data);
fast_prefetch_range(&data[100..200]);
fast_prefetch_range(&[]);
}
#[test]
fn test_simd_ops_with_cache_config() {
use crate::memory::cache_layout::{CacheLayoutConfig, AccessPattern};
let config = CacheLayoutConfig::sequential();
let ops = SimdMemOps::with_cache_config(config);
assert_eq!(ops.cache_config().access_pattern, AccessPattern::Sequential);
assert!(ops.cache_config().enable_prefetch);
let src = b"Hello, Cache-Optimized SIMD!";
let mut dst = vec![0u8; src.len()];
let result = ops.copy_cache_optimized(src, &mut dst);
assert!(result.is_ok());
assert_eq!(src, &dst[..]);
}
#[test]
fn test_cache_config_access() {
let ops = SimdMemOps::new();
let config = ops.cache_config();
assert!(config.cache_line_size > 0);
assert!(config.hierarchy.l1_size > 0);
assert!(config.prefetch_distance > 0);
}
#[test]
fn test_prefetch_with_different_sizes() {
let ops = SimdMemOps::new();
let small_data = vec![1u8; 32];
ops.prefetch_range(&small_data);
let large_data = vec![1u8; 8192];
ops.prefetch_range(&large_data);
let empty: &[u8] = &[];
ops.prefetch_range(empty);
}
#[test]
fn test_cache_optimized_copy_edge_cases() {
let ops = SimdMemOps::new();
let empty_src: &[u8] = &[];
let mut empty_dst: Vec<u8> = vec![];
let result = ops.copy_cache_optimized(empty_src, &mut empty_dst);
assert!(result.is_ok());
let src = b"hello";
let mut dst = vec![0u8; 10];
let result = ops.copy_cache_optimized(src, &mut dst);
assert!(result.is_err());
let layout = std::alloc::Layout::from_size_align(4096, 64).unwrap();
unsafe {
let src_ptr = std::alloc::alloc(layout);
let dst_ptr = std::alloc::alloc(layout);
if !src_ptr.is_null() && !dst_ptr.is_null() {
for i in 0..4096 {
*src_ptr.add(i) = (i % 256) as u8;
}
let src_slice = std::slice::from_raw_parts(src_ptr, 4096);
let dst_slice = std::slice::from_raw_parts_mut(dst_ptr, 4096);
let result = ops.copy_cache_optimized(src_slice, dst_slice);
assert!(result.is_ok());
for i in 0..4096 {
assert_eq!(*src_ptr.add(i), *dst_ptr.add(i));
}
std::alloc::dealloc(src_ptr, layout);
std::alloc::dealloc(dst_ptr, layout);
}
}
}
}