#![allow(unsafe_code)]
/// Trace hook for the AVX2 kernels in test builds that enable `simd-avx2`.
///
/// It currently expands to nothing: the tokens passed in are discarded
/// without being evaluated or type-checked.
#[allow(unused_macros)]
#[cfg(all(test, feature = "simd-avx2"))]
macro_rules! debug_log {
    ($($tokens:tt)*) => {};
}
/// Trace hook for every other build configuration.
///
/// Expands to nothing; the tokens passed in are discarded without being
/// evaluated or type-checked.
#[allow(unused_macros)]
#[cfg(not(all(test, feature = "simd-avx2")))]
macro_rules! debug_log {
    ($($tokens:tt)*) => {};
}
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
use core::arch::x86_64::{
__m256i,
_mm256_loadu_si256,
_mm256_storeu_si256,
_mm256_xor_si256,
};
/// `sparse_dense_mul` entry point for x86_64 builds with `simd-avx2` enabled.
///
/// There is no vectorised kernel here: every argument is forwarded unchanged
/// to `portable::sparse_dense_mul_portable`.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
pub fn sparse_dense_mul_avx2(
    output: &mut [u8],
    sparse: &[u8],
    dense: &[u8],
    weight: u32,
    n_bits: usize,
) {
    use super::super::portable::sparse_dense_mul_portable as portable_impl;
    portable_impl(output, sparse, dense, weight, n_bits);
}
/// XORs `source`, displaced by `distance` bits, into `dest` (byte slices).
///
/// Byte `i` of `source` lines up with `dest[i + distance / 8]`. Only the
/// `dest.len() - distance / 8` destination bytes of that window are touched,
/// and source bytes past `source.len()` read as zero. If `distance / 8` is at
/// least `dest.len()` the call is a no-op.
///
/// When `distance % 8 == 0`, whole 32-byte chunks are XORed with AVX2 and the
/// remainder byte by byte. Otherwise, with `b = distance % 8`, each window
/// byte `i` receives `(source[i] << b) | (source[i + 1] >> (8 - b))`.
///
/// NOTE(review): this executes AVX/AVX2 instructions with neither
/// `#[target_feature]` nor runtime CPU detection, so it assumes callers only
/// reach it on AVX2-capable CPUs — confirm.
/// NOTE(review): the bit-level formula pulls its extra bits from
/// `source[i + 1]`, moving bits toward lower indices, while the byte shift
/// moves data toward higher indices. Confirm this is the intended convention.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
pub fn shift_xor_avx2_bytes(dest: &mut [u8], source: &[u8], distance: usize) {
    let byte_shift = distance / 8;
    let bit_shift = distance % 8;
    debug_log!(
        "shift_xor_avx2_bytes: byte_shift={}, bit_shift={}, dest.len={}, source.len={}",
        byte_shift,
        bit_shift,
        dest.len(),
        source.len()
    );
    // Nothing of `source` lands inside `dest`. Checking this once for both
    // paths also keeps `dest.len() - byte_shift` from underflowing below.
    if byte_shift >= dest.len() {
        return;
    }
    // Number of destination bytes that can receive data from `source`.
    let window = dest.len() - byte_shift;
    // Source byte `i`, or zero past the end of `source`.
    let src_at = |i: usize| source.get(i).copied().unwrap_or(0);

    if bit_shift == 0 {
        let chunks = window / 32;
        debug_log!(" byte-aligned: chunks={}", chunks);
        for i in 0..chunks {
            let offset = i * 32;
            let src_chunk = if offset + 32 <= source.len() {
                // SAFETY: `source[offset..offset + 32]` is in bounds and
                // `loadu` has no alignment requirement.
                unsafe { _mm256_loadu_si256(source.as_ptr().add(offset) as *const __m256i) }
            } else {
                // Zero-pad the short (possibly empty) remainder of `source`.
                let mut src_bytes = [0u8; 32];
                let copy_len = source.len().saturating_sub(offset);
                if copy_len > 0 {
                    // Here `copy_len == source.len() - offset < 32`.
                    src_bytes[..copy_len].copy_from_slice(&source[offset..]);
                }
                // SAFETY: `src_bytes` provides 32 readable bytes. It is only
                // 1-byte aligned, so it must go through `loadu`; a
                // `ptr::read` as `__m256i` would be a misaligned read.
                unsafe { _mm256_loadu_si256(src_bytes.as_ptr() as *const __m256i) }
            };
            // SAFETY: `offset + 32 <= chunks * 32 <= window`, so the 32 bytes
            // at `dest[byte_shift + offset..]` are in bounds. The load and
            // store are unaligned.
            unsafe {
                let dst_ptr = dest.as_mut_ptr().add(byte_shift + offset) as *mut __m256i;
                let result = _mm256_xor_si256(_mm256_loadu_si256(dst_ptr), src_chunk);
                _mm256_storeu_si256(dst_ptr, result);
            }
        }
        // Scalar tail: the final `window % 32` bytes.
        for i in (chunks * 32)..window {
            dest[i + byte_shift] ^= src_at(i);
        }
    } else {
        let inv_shift = 8 - bit_shift;
        debug_log!(" bit-level: inv_shift={}", inv_shift);
        // Cover the whole window, as the byte-aligned path does. `src_at`
        // already yields zero for `i + 1 >= source.len()`, so the last window
        // byte needs no special casing.
        for i in 0..window {
            dest[i + byte_shift] ^= (src_at(i) << bit_shift) | (src_at(i + 1) >> inv_shift);
        }
    }
}
/// XORs `source`, displaced by `distance` bits, into `dest` (64-bit words).
///
/// With `ws = distance / 64` and `b = distance % 64`:
/// - if `b == 0`, `source[i]` is XORed into `dest[i + ws]`, using AVX2 for
///   full 4-word chunks;
/// - otherwise `source[i] >> b` is XORed into `dest[i + ws]`, and the bits
///   shifted out (`source[i] << (64 - b)`) are carried into `dest[i + ws + 1]`.
///
/// Anything that would land at or past `dest.len()` is dropped.
///
/// NOTE(review): this executes AVX/AVX2 instructions with neither
/// `#[target_feature]` nor runtime CPU detection, so it assumes callers only
/// reach it on AVX2-capable CPUs — confirm.
/// NOTE(review): the bit-level path treats `>>` as moving bits forward
/// (MSB-first within a word). Confirm this matches `shift_xor_portable`.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
pub fn shift_xor_avx2(dest: &mut [u64], source: &[u64], distance: usize) {
    let word_shift = distance / 64;
    let bit_shift = distance % 64;
    debug_log!(
        "shift_xor_avx2: word_shift={}, bit_shift={}, dest.len={}, source.len={}",
        word_shift,
        bit_shift,
        dest.len(),
        source.len()
    );
    if bit_shift == 0 {
        if word_shift >= dest.len() {
            return;
        }
        let chunks = (dest.len() - word_shift) / 4;
        debug_log!(" word-aligned: chunks={}", chunks);
        for i in 0..chunks {
            let offset = i * 4;
            if offset + 4 <= source.len() {
                // SAFETY: `source[offset..offset + 4]` is in bounds, and since
                // `offset + 4 <= chunks * 4 <= dest.len() - word_shift`, so is
                // `dest[word_shift + offset..][..4]`. The loads and store are
                // unaligned.
                unsafe {
                    let src_chunk =
                        _mm256_loadu_si256(source.as_ptr().add(offset) as *const __m256i);
                    let dst_ptr = dest.as_mut_ptr().add(word_shift + offset) as *mut __m256i;
                    let result = _mm256_xor_si256(_mm256_loadu_si256(dst_ptr), src_chunk);
                    _mm256_storeu_si256(dst_ptr, result);
                }
            } else {
                // Partial (or absent) source chunk: finish word by word. The
                // destination index is in bounds for the same reason as above.
                for idx in offset..(offset + 4).min(source.len()) {
                    dest[idx + word_shift] ^= source[idx];
                }
            }
        }
        // Source words beyond the last full destination chunk.
        for i in (chunks * 4)..source.len() {
            if i + word_shift < dest.len() {
                dest[i + word_shift] ^= source[i];
            }
        }
    } else {
        let inv_shift = 64 - bit_shift;
        debug_log!(" bit-level: inv_shift={}", inv_shift);
        for (i, &src_val) in source.iter().enumerate() {
            let lo = i + word_shift;
            if lo >= dest.len() {
                // `lo` only grows with `i`, so nothing further fits.
                break;
            }
            dest[lo] ^= src_val >> bit_shift;
            // The bits shifted out of `src_val` belong in the next destination
            // word. They must come from the same source word; taking them
            // from `source[i + 1]` would turn each output word into a rotation
            // of a single input word instead of a shift.
            if lo + 1 < dest.len() {
                dest[lo + 1] ^= src_val << inv_shift;
            }
        }
    }
}
/// `vect_add` entry point for x86_64 builds with `simd-avx2` enabled.
///
/// There is no vectorised kernel here: `output`, `a` and `b` are forwarded
/// unchanged to `portable::vect_add_portable`.
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
pub fn vect_add_avx2(output: &mut [u8], a: &[u8], b: &[u8]) {
    use super::super::portable::vect_add_portable as portable_impl;
    portable_impl(output, a, b);
}
/// Fallback `sparse_dense_mul_avx2` for builds that are not x86_64 with
/// `simd-avx2` enabled.
///
/// It keeps the `_avx2` name available and forwards every argument
/// unchanged to `portable::sparse_dense_mul_portable`.
#[cfg(not(all(target_arch = "x86_64", feature = "simd-avx2")))]
pub fn sparse_dense_mul_avx2(
    output: &mut [u8],
    sparse: &[u8],
    dense: &[u8],
    weight: u32,
    n_bits: usize,
) {
    use super::super::portable::sparse_dense_mul_portable as portable_impl;
    portable_impl(output, sparse, dense, weight, n_bits);
}
/// Fallback `shift_xor_avx2` for builds that are not x86_64 with
/// `simd-avx2` enabled.
///
/// It keeps the `_avx2` name available and forwards `dest`, `source` and
/// `distance` unchanged to `portable::shift_xor_portable`.
#[cfg(not(all(target_arch = "x86_64", feature = "simd-avx2")))]
pub fn shift_xor_avx2(dest: &mut [u64], source: &[u64], distance: usize) {
    use super::super::portable::shift_xor_portable as portable_impl;
    portable_impl(dest, source, distance);
}
/// Fallback `vect_add_avx2` for builds that are not x86_64 with `simd-avx2`
/// enabled.
///
/// It keeps the `_avx2` name available and forwards `output`, `a` and `b`
/// unchanged to `portable::vect_add_portable`.
#[cfg(not(all(target_arch = "x86_64", feature = "simd-avx2")))]
pub fn vect_add_avx2(output: &mut [u8], a: &[u8], b: &[u8]) {
    use super::super::portable::vect_add_portable as portable_impl;
    portable_impl(output, a, b);
}