#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
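/// Coefficient-wise operations on 256-coefficient polynomials, with
/// AVX2-accelerated paths selected at runtime and portable scalar fallbacks.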
pub struct SimdPolynomialOps;
impl SimdPolynomialOps {
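    /// Adds two 256-coefficient polynomials coefficient-wise (wrapping on
    /// overflow), using AVX2 when it is detected at runtime.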
#[cfg(target_arch = "x86_64")]
pub fn poly_add_simd(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
if is_x86_feature_detected!("avx2") {
unsafe {
Self::poly_add_avx2(a, b, result);
}
} else {
Self::poly_add_scalar(a, b, result);
}
}
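    /// Subtracts `b` from `a` coefficient-wise (wrapping on overflow), using
    /// AVX2 when it is detected at runtime.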
#[cfg(target_arch = "x86_64")]
pub fn poly_sub_simd(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
if is_x86_feature_detected!("avx2") {
unsafe {
Self::poly_sub_avx2(a, b, result);
}
} else {
Self::poly_sub_scalar(a, b, result);
}
}
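    /// Multiplies every coefficient of `a` by `scalar`, keeping the low 32
    /// bits of each product, using AVX2 when it is detected at runtime.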
#[cfg(target_arch = "x86_64")]
pub fn poly_scalar_mul_simd(a: &[i32; 256], scalar: i32, result: &mut [i32; 256]) {
if is_x86_feature_detected!("avx2") {
unsafe {
Self::poly_scalar_mul_avx2(a, scalar, result);
}
} else {
Self::poly_scalar_mul_scalar(a, scalar, result);
}
}
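    /// Reduces each coefficient of `a` to the centered range (-q/2, q/2].
    /// The AVX2 path assumes coefficients already lie in [0, q); the scalar
    /// fallback first applies a full reduction modulo `q`.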
#[cfg(target_arch = "x86_64")]
pub fn poly_reduce_simd(a: &mut [i32; 256], q: i32) {
if is_x86_feature_detected!("avx2") {
unsafe {
Self::poly_reduce_avx2(a, q);
}
} else {
Self::poly_reduce_scalar(a, q);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
    unsafe fn poly_add_avx2(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
        // Process eight i32 coefficients per iteration with unaligned 256-bit
        // loads and stores; 256 is a multiple of 8, so no tail loop is needed.
        for i in (0..256).step_by(8) {
            let va = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
            let vb = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);
            let vr = _mm256_add_epi32(va, vb);
            _mm256_storeu_si256(result.as_mut_ptr().add(i) as *mut __m256i, vr);
        }
    }
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn poly_sub_avx2(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
for i in (0..256).step_by(8) {
let va = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
let vb = _mm256_loadu_si256(b.as_ptr().add(i) as *const __m256i);
let vr = _mm256_sub_epi32(va, vb);
_mm256_storeu_si256(result.as_mut_ptr().add(i) as *mut __m256i, vr);
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
    unsafe fn poly_scalar_mul_avx2(a: &[i32; 256], scalar: i32, result: &mut [i32; 256]) {
        // Broadcast the scalar to all eight lanes; `_mm256_mullo_epi32` keeps
        // the low 32 bits of each product, matching `wrapping_mul` in the
        // scalar fallback.
        let vscalar = _mm256_set1_epi32(scalar);
        for i in (0..256).step_by(8) {
            let va = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
            let vr = _mm256_mullo_epi32(va, vscalar);
            _mm256_storeu_si256(result.as_mut_ptr().add(i) as *mut __m256i, vr);
        }
    }
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
    unsafe fn poly_reduce_avx2(a: &mut [i32; 256], q: i32) {
        // Maps each coefficient from [0, q) to the centered range (-q/2, q/2]
        // with a single conditional subtraction. Unlike the scalar fallback,
        // no `% q` step is applied, so coefficients are assumed to already
        // lie in [0, q).
        let vq = _mm256_set1_epi32(q);
        let vq_half = _mm256_set1_epi32(q / 2);
        for i in (0..256).step_by(8) {
            let va = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
            // Lanes greater than q/2 get q subtracted; the rest pass through.
            let mask_gt = _mm256_cmpgt_epi32(va, vq_half);
            let adjustment = _mm256_and_si256(mask_gt, vq);
            let reduced = _mm256_sub_epi32(va, adjustment);
            _mm256_storeu_si256(a.as_mut_ptr().add(i) as *mut __m256i, reduced);
        }
    }
fn poly_add_scalar(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
for i in 0..256 {
result[i] = a[i].wrapping_add(b[i]);
}
}
fn poly_sub_scalar(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
for i in 0..256 {
result[i] = a[i].wrapping_sub(b[i]);
}
}
fn poly_scalar_mul_scalar(a: &[i32; 256], scalar: i32, result: &mut [i32; 256]) {
for i in 0..256 {
result[i] = a[i].wrapping_mul(scalar);
}
}
    fn poly_reduce_scalar(a: &mut [i32; 256], q: i32) {
        for i in 0..256 {
            // Reduce into [0, q) first (`rem_euclid` handles negative inputs,
            // unlike `%`), then shift to the centered range (-q/2, q/2].
            a[i] = a[i].rem_euclid(q);
            if a[i] > q / 2 {
                a[i] -= q;
            }
        }
    }
#[cfg(not(target_arch = "x86_64"))]
pub fn poly_add_simd(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
Self::poly_add_scalar(a, b, result);
}
#[cfg(not(target_arch = "x86_64"))]
pub fn poly_sub_simd(a: &[i32; 256], b: &[i32; 256], result: &mut [i32; 256]) {
Self::poly_sub_scalar(a, b, result);
}
#[cfg(not(target_arch = "x86_64"))]
pub fn poly_scalar_mul_simd(a: &[i32; 256], scalar: i32, result: &mut [i32; 256]) {
Self::poly_scalar_mul_scalar(a, scalar, result);
}
#[cfg(not(target_arch = "x86_64"))]
pub fn poly_reduce_simd(a: &mut [i32; 256], q: i32) {
Self::poly_reduce_scalar(a, q);
}
}
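/// Batch hashing helpers; digests are produced with BLAKE3.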
pub struct SimdHashOps;
impl SimdHashOps {
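    /// Hashes four byte slices into four 32-byte BLAKE3 digests.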
#[cfg(target_arch = "x86_64")]
pub fn parallel_hash_4way(
inputs: [&[u8]; 4],
outputs: &mut [[u8; 32]; 4]
) {
if is_x86_feature_detected!("sha") {
unsafe {
Self::parallel_sha256_4way(inputs, outputs);
}
} else {
Self::parallel_hash_fallback(inputs, outputs);
}
}
    /// Placeholder for a multi-buffer SHA-NI implementation. Despite the
    /// name and the `sha` target feature, it currently hashes each input
    /// sequentially with BLAKE3 (which performs its own runtime SIMD
    /// dispatch), so its output is identical to the fallback path.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "sha")]
    unsafe fn parallel_sha256_4way(
        inputs: [&[u8]; 4],
        outputs: &mut [[u8; 32]; 4]
    ) {
        for i in 0..4 {
            let hash = blake3::hash(inputs[i]);
            outputs[i].copy_from_slice(hash.as_bytes());
        }
    }
fn parallel_hash_fallback(
inputs: [&[u8]; 4],
outputs: &mut [[u8; 32]; 4]
) {
for i in 0..4 {
let hash = blake3::hash(inputs[i]);
outputs[i].copy_from_slice(hash.as_bytes());
}
}
#[cfg(not(target_arch = "x86_64"))]
pub fn parallel_hash_4way(
inputs: [&[u8]; 4],
outputs: &mut [[u8; 32]; 4]
) {
Self::parallel_hash_fallback(inputs, outputs);
}
}
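/// Thin wrappers around the x86 prefetch hint instructions; on other
/// architectures every method is a no-op.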
pub struct PrefetchUtils;
impl PrefetchUtils {
    #[cfg(target_arch = "x86_64")]
    pub fn prefetch_read(ptr: *const u8) {
        // Prefetch hints never fault, so this is safe to call with any pointer.
        unsafe {
            _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8);
        }
    }
    #[cfg(target_arch = "x86_64")]
    pub fn prefetch_write(ptr: *const u8) {
        // Baseline SSE has no write-intent prefetch hint, so this uses the
        // L2-targeted read hint as an approximation.
        unsafe {
            _mm_prefetch::<_MM_HINT_T1>(ptr as *const i8);
        }
    }
    #[cfg(target_arch = "x86_64")]
    pub fn prefetch_range(start: *const u8, len: usize) {
        const CACHE_LINE_SIZE: usize = 64;
        // Walk the range as integer addresses to avoid creating out-of-bounds
        // pointers, issuing one prefetch per cache line.
        let end = start as usize + len;
        let mut current = start as usize;
        while current < end {
            unsafe {
                _mm_prefetch::<_MM_HINT_T0>(current as *const i8);
            }
            current += CACHE_LINE_SIZE;
        }
    }
#[cfg(not(target_arch = "x86_64"))]
pub fn prefetch_read(_ptr: *const u8) {}
#[cfg(not(target_arch = "x86_64"))]
pub fn prefetch_write(_ptr: *const u8) {}
#[cfg(not(target_arch = "x86_64"))]
pub fn prefetch_range(_start: *const u8, _len: usize) {}
}
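/// Bulk copy and zeroization helpers with optional AVX2 fast paths.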
pub struct CacheOptimizedOps;
impl CacheOptimizedOps {
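    /// Copies `src` into `dst`; the slices must have equal length. Despite
    /// the name, the AVX2 path uses unaligned loads and stores, so callers
    /// need not guarantee any particular alignment.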
pub fn aligned_copy(src: &[u8], dst: &mut [u8]) {
assert_eq!(src.len(), dst.len());
#[cfg(target_arch = "x86_64")]
{
if src.len() >= 32 && is_x86_feature_detected!("avx2") {
unsafe {
Self::avx2_copy(src, dst);
}
return;
}
}
dst.copy_from_slice(src);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_copy(src: &[u8], dst: &mut [u8]) {
let len = src.len();
let mut i = 0;
while i + 32 <= len {
let chunk = _mm256_loadu_si256(src.as_ptr().add(i) as *const __m256i);
_mm256_storeu_si256(dst.as_mut_ptr().add(i) as *mut __m256i, chunk);
i += 32;
}
while i < len {
dst[i] = src[i];
i += 1;
}
}
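    /// Best-effort zeroization: overwrites `data` with zeros and issues a
    /// compiler fence to discourage the writes from being optimized away.
    /// Callers needing stronger guarantees typically use volatile writes
    /// (for example via the `zeroize` crate).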
pub fn secure_zero_simd(data: &mut [u8]) {
#[cfg(target_arch = "x86_64")]
{
if data.len() >= 32 && is_x86_feature_detected!("avx2") {
unsafe {
Self::avx2_zero(data);
}
return;
}
}
data.fill(0);
std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_zero(data: &mut [u8]) {
let len = data.len();
let mut i = 0;
let zero = _mm256_setzero_si256();
while i + 32 <= len {
_mm256_storeu_si256(data.as_mut_ptr().add(i) as *mut __m256i, zero);
i += 32;
}
while i < len {
data[i] = 0;
i += 1;
}
std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_poly_add_simd() {
let a = [1i32; 256];
let b = [2i32; 256];
let mut result = [0i32; 256];
SimdPolynomialOps::poly_add_simd(&a, &b, &mut result);
for i in 0..256 {
assert_eq!(result[i], 3);
}
}
#[test]
fn test_poly_scalar_mul_simd() {
let a = [2i32; 256];
let scalar = 3;
let mut result = [0i32; 256];
SimdPolynomialOps::poly_scalar_mul_simd(&a, scalar, &mut result);
for i in 0..256 {
assert_eq!(result[i], 6);
}
}
#[test]
fn test_parallel_hash() {
let inputs = [
b"test1".as_slice(),
b"test2".as_slice(),
b"test3".as_slice(),
b"test4".as_slice(),
];
let mut outputs = [[0u8; 32]; 4];
SimdHashOps::parallel_hash_4way(inputs, &mut outputs);
assert_ne!(outputs[0], outputs[1]);
assert_ne!(outputs[1], outputs[2]);
assert_ne!(outputs[2], outputs[3]);
}
#[test]
fn test_aligned_copy() {
let src = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let mut dst = vec![0u8; 10];
CacheOptimizedOps::aligned_copy(&src, &mut dst);
assert_eq!(src, dst);
}
#[test]
fn test_secure_zero() {
let mut data = vec![1u8, 2, 3, 4, 5];
CacheOptimizedOps::secure_zero_simd(&mut data);
assert_eq!(data, vec![0u8; 5]);
}
#[test]
fn test_prefetch_operations() {
let data = vec![1u8; 1024];
PrefetchUtils::prefetch_read(data.as_ptr());
PrefetchUtils::prefetch_write(data.as_ptr());
PrefetchUtils::prefetch_range(data.as_ptr(), data.len());
}
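    // Illustrative additional tests for the subtraction and centered-reduction
    // paths; the modulus 3329 and the coefficient 3000 are arbitrary example
    // values chosen to satisfy the [0, q) precondition of the AVX2 reduce path.
    #[test]
    fn test_poly_sub_simd() {
        let a = [5i32; 256];
        let b = [2i32; 256];
        let mut result = [0i32; 256];
        SimdPolynomialOps::poly_sub_simd(&a, &b, &mut result);
        for i in 0..256 {
            assert_eq!(result[i], 3);
        }
    }
    #[test]
    fn test_poly_reduce_simd() {
        // 3000 lies in [0, 3329) but above 3329 / 2, so it should reduce to
        // the centered value 3000 - 3329.
        let q = 3329;
        let mut a = [3000i32; 256];
        SimdPolynomialOps::poly_reduce_simd(&mut a, q);
        for i in 0..256 {
            assert_eq!(a[i], 3000 - 3329);
        }
    }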
}