use crate::protocol::Message;
use std::sync::atomic::{AtomicU64, Ordering};
pub struct SIMDProcessor {
vectorized_operations: AtomicU64,
scalar_fallbacks: AtomicU64,
bytes_processed: AtomicU64,
has_avx2: bool,
has_sse42: bool,
}
impl SIMDProcessor {
pub fn new() -> Self {
Self {
vectorized_operations: AtomicU64::new(0),
scalar_fallbacks: AtomicU64::new(0),
bytes_processed: AtomicU64::new(0),
has_avx2: Self::detect_avx2(),
has_sse42: Self::detect_sse42(),
}
}
fn detect_avx2() -> bool {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
is_x86_feature_detected!("avx2")
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
{
false
}
}
fn detect_sse42() -> bool {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
is_x86_feature_detected!("sse4.2")
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
{
false
}
}
pub fn simd_memcmp(&self, a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
return false;
}
self.bytes_processed
.fetch_add(a.len() as u64, Ordering::Relaxed);
if a.len() >= 32 && self.has_avx2 {
self.vectorized_operations.fetch_add(1, Ordering::Relaxed);
self.simd_memcmp_avx2(a, b)
} else if a.len() >= 16 && self.has_sse42 {
self.vectorized_operations.fetch_add(1, Ordering::Relaxed);
self.simd_memcmp_sse42(a, b)
} else {
self.scalar_fallbacks.fetch_add(1, Ordering::Relaxed);
a == b
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn simd_memcmp_avx2(&self, a: &[u8], b: &[u8]) -> bool {
if !is_x86_feature_detected!("avx2") {
return a == b;
}
unsafe { self.simd_memcmp_avx2_impl(a, b) }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn simd_memcmp_avx2_impl(&self, a: &[u8], b: &[u8]) -> bool {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 32;
let remainder = len % 32;
for i in 0..chunks {
let offset = i * 32;
let va = _mm256_loadu_si256(a.as_ptr().add(offset) as *const __m256i);
let vb = _mm256_loadu_si256(b.as_ptr().add(offset) as *const __m256i);
let cmp = _mm256_cmpeq_epi8(va, vb);
let mask = _mm256_movemask_epi8(cmp);
if mask != 0xFFFF_FFFF_u32 as i32 {
return false;
}
}
if remainder > 0 {
let start = chunks * 32;
return &a[start..] == &b[start..];
}
true
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
fn simd_memcmp_avx2(&self, a: &[u8], b: &[u8]) -> bool {
a == b
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn simd_memcmp_sse42(&self, a: &[u8], b: &[u8]) -> bool {
if !is_x86_feature_detected!("sse4.2") {
return a == b;
}
unsafe { self.simd_memcmp_sse42_impl(a, b) }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.2")]
unsafe fn simd_memcmp_sse42_impl(&self, a: &[u8], b: &[u8]) -> bool {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let len = a.len();
let chunks = len / 16;
let remainder = len % 16;
for i in 0..chunks {
let offset = i * 16;
let va = _mm_loadu_si128(a.as_ptr().add(offset) as *const __m128i);
let vb = _mm_loadu_si128(b.as_ptr().add(offset) as *const __m128i);
let cmp = _mm_cmpeq_epi8(va, vb);
let mask = _mm_movemask_epi8(cmp);
if mask != 0xFFFF {
return false;
}
}
if remainder > 0 {
let start = chunks * 16;
return &a[start..] == &b[start..];
}
true
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
fn simd_memcmp_sse42(&self, a: &[u8], b: &[u8]) -> bool {
a == b
}
pub fn simd_crc32(&self, data: &[u8]) -> u32 {
self.bytes_processed
.fetch_add(data.len() as u64, Ordering::Relaxed);
if self.has_sse42 {
self.vectorized_operations.fetch_add(1, Ordering::Relaxed);
self.simd_crc32_sse42(data)
} else {
self.scalar_fallbacks.fetch_add(1, Ordering::Relaxed);
self.scalar_crc32(data)
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn simd_crc32_sse42(&self, data: &[u8]) -> u32 {
if !is_x86_feature_detected!("sse4.2") {
return self.scalar_crc32(data);
}
unsafe { self.simd_crc32_sse42_impl(data) }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.2")]
unsafe fn simd_crc32_sse42_impl(&self, data: &[u8]) -> u32 {
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
let mut crc = 0xFFFF_FFFF_u32;
let mut ptr = data.as_ptr();
let mut remaining = data.len();
while remaining >= 8 {
let value = std::ptr::read_unaligned(ptr as *const u64);
crc = _mm_crc32_u64(crc as u64, value) as u32;
ptr = ptr.add(8);
remaining -= 8;
}
while remaining >= 4 {
let value = std::ptr::read_unaligned(ptr as *const u32);
crc = _mm_crc32_u32(crc, value);
ptr = ptr.add(4);
remaining -= 4;
}
while remaining > 0 {
let value = *ptr;
crc = _mm_crc32_u8(crc, value);
ptr = ptr.add(1);
remaining -= 1;
}
!crc
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
fn simd_crc32_sse42(&self, data: &[u8]) -> u32 {
self.scalar_crc32(data)
}
fn scalar_crc32(&self, data: &[u8]) -> u32 {
let mut crc = 0xFFFF_FFFF_u32;
for &byte in data {
crc ^= byte as u32;
for _ in 0..8 {
if crc & 1 != 0 {
crc = (crc >> 1) ^ 0xEDB88320;
} else {
crc >>= 1;
}
}
}
!crc
}
pub fn batch_calculate_message_sizes(&self, messages: &[Message]) -> Vec<usize> {
self.vectorized_operations.fetch_add(1, Ordering::Relaxed);
let mut sizes = Vec::with_capacity(messages.len());
for message in messages {
let key_size = message.key.as_ref().map(|k| k.len()).unwrap_or(0);
let value_size = message.value.len();
let total_size = key_size + value_size + 8; sizes.push(total_size);
self.bytes_processed
.fetch_add(total_size as u64, Ordering::Relaxed);
}
sizes
}
pub fn batch_validate_messages(&self, messages: &[Message]) -> Vec<bool> {
self.vectorized_operations.fetch_add(1, Ordering::Relaxed);
let mut results = Vec::with_capacity(messages.len());
for message in messages {
let is_valid = !message.value.is_empty();
results.push(is_valid);
}
results
}
pub fn simd_string_search(&self, haystack: &[u8], needle: &[u8]) -> bool {
if needle.is_empty() || haystack.len() < needle.len() {
return false;
}
self.bytes_processed
.fetch_add(haystack.len() as u64, Ordering::Relaxed);
if haystack.len() >= 32 && needle.len() >= 4 && self.has_avx2 {
self.vectorized_operations.fetch_add(1, Ordering::Relaxed);
self.simd_string_search_avx2(haystack, needle)
} else {
self.scalar_fallbacks.fetch_add(1, Ordering::Relaxed);
self.scalar_string_search(haystack, needle)
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn simd_string_search_avx2(&self, haystack: &[u8], needle: &[u8]) -> bool {
if !is_x86_feature_detected!("avx2") || needle.is_empty() {
return self.scalar_string_search(haystack, needle);
}
let first_char = needle[0];
for i in 0..=(haystack.len().saturating_sub(needle.len())) {
if haystack[i] == first_char {
if &haystack[i..i + needle.len()] == needle {
return true;
}
}
}
false
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
fn simd_string_search_avx2(&self, haystack: &[u8], needle: &[u8]) -> bool {
self.scalar_string_search(haystack, needle)
}
fn scalar_string_search(&self, haystack: &[u8], needle: &[u8]) -> bool {
haystack
.windows(needle.len())
.any(|window| window == needle)
}
pub fn get_stats(&self) -> SIMDStats {
let vectorized = self.vectorized_operations.load(Ordering::Relaxed);
let scalar = self.scalar_fallbacks.load(Ordering::Relaxed);
let total_ops = vectorized + scalar;
SIMDStats {
vectorized_operations: vectorized,
scalar_fallbacks: scalar,
bytes_processed: self.bytes_processed.load(Ordering::Relaxed),
vectorization_rate: if total_ops > 0 {
vectorized as f64 / total_ops as f64
} else {
0.0
},
has_avx2: self.has_avx2,
has_sse42: self.has_sse42,
}
}
}
#[derive(Debug, Clone)]
pub struct SIMDStats {
pub vectorized_operations: u64,
pub scalar_fallbacks: u64,
pub bytes_processed: u64,
pub vectorization_rate: f64,
pub has_avx2: bool,
pub has_sse42: bool,
}
impl SIMDStats {
pub fn report(&self) -> String {
format!(
"SIMD Optimization Stats:\n\
Vectorized Ops: {} ({:.1}% of total)\n\
Scalar Fallbacks: {}\n\
Bytes Processed: {:.1} MB\n\
CPU Features: AVX2={}, SSE4.2={}",
self.vectorized_operations,
self.vectorization_rate * 100.0,
self.scalar_fallbacks,
self.bytes_processed as f64 / 1_000_000.0,
self.has_avx2,
self.has_sse42
)
}
}
pub struct SIMDBatchProcessor {
simd: SIMDProcessor,
batches_processed: AtomicU64,
messages_per_batch: AtomicU64,
}
impl SIMDBatchProcessor {
pub fn new() -> Self {
Self {
simd: SIMDProcessor::new(),
batches_processed: AtomicU64::new(0),
messages_per_batch: AtomicU64::new(0),
}
}
pub fn process_message_batch(&self, messages: &[Message]) -> BatchProcessResult {
if messages.is_empty() {
return BatchProcessResult::default();
}
let batch_start = std::time::Instant::now();
let sizes = self.simd.batch_calculate_message_sizes(messages);
let validations = self.simd.batch_validate_messages(messages);
let total_size: usize = sizes.iter().sum();
let valid_count = validations.iter().filter(|&&v| v).count();
let invalid_count = messages.len() - valid_count;
let batch_duration = batch_start.elapsed();
self.batches_processed.fetch_add(1, Ordering::Relaxed);
self.messages_per_batch
.fetch_add(messages.len() as u64, Ordering::Relaxed);
BatchProcessResult {
message_count: messages.len(),
total_bytes: total_size,
valid_messages: valid_count,
invalid_messages: invalid_count,
processing_time_micros: batch_duration.as_micros() as u64,
simd_stats: self.simd.get_stats(),
}
}
pub fn get_batch_stats(&self) -> SIMDBatchStats {
let batches = self.batches_processed.load(Ordering::Relaxed);
let total_messages = self.messages_per_batch.load(Ordering::Relaxed);
SIMDBatchStats {
batches_processed: batches,
total_messages_processed: total_messages,
avg_messages_per_batch: if batches > 0 {
total_messages as f64 / batches as f64
} else {
0.0
},
simd_stats: self.simd.get_stats(),
}
}
}
#[derive(Debug, Clone, Default)]
pub struct BatchProcessResult {
pub message_count: usize,
pub total_bytes: usize,
pub valid_messages: usize,
pub invalid_messages: usize,
pub processing_time_micros: u64,
pub simd_stats: SIMDStats,
}
impl Default for SIMDStats {
fn default() -> Self {
Self {
vectorized_operations: 0,
scalar_fallbacks: 0,
bytes_processed: 0,
vectorization_rate: 0.0,
has_avx2: false,
has_sse42: false,
}
}
}
#[derive(Debug, Clone)]
pub struct SIMDBatchStats {
pub batches_processed: u64,
pub total_messages_processed: u64,
pub avg_messages_per_batch: f64,
pub simd_stats: SIMDStats,
}
impl SIMDBatchStats {
pub fn report(&self) -> String {
format!(
"SIMD Batch Processing Stats:\n\
Batches: {} ({:.1} avg msgs/batch)\n\
Total Messages: {}\n\
{}",
self.batches_processed,
self.avg_messages_per_batch,
self.total_messages_processed,
self.simd_stats.report()
)
}
}