#![cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
__m128i, _mm_clmulepi64_si128, _mm_loadu_si128, _mm_setzero_si128, _mm_storeu_si128,
_mm_xor_si128,
};
/// Reverses the bit order of a single byte (bit 0 ↔ bit 7, bit 1 ↔ bit 6, …).
///
/// Delegates to the standard library's `u8::reverse_bits`, which is usable in
/// `const` contexts and lowers to the cheapest sequence available on the target.
#[inline]
const fn reverse_byte(b: u8) -> u8 {
    b.reverse_bits()
}
/// Reverses the bit order inside each of the 16 bytes; byte positions are unchanged.
///
/// `ghash_mul_clmul` applies this to its inputs before loading them as a
/// little-endian 128-bit value. Its reduction step then treats bit `k` of that
/// value as the coefficient of `x^k`.
#[inline]
const fn natural_bytes(b: &[u8; 16]) -> [u8; 16] {
    let mut out = *b;
    let mut idx = 0;
    // `for` loops are not permitted in a `const fn`, so walk the bytes manually.
    while idx < 16 {
        out[idx] = reverse_byte(out[idx]);
        idx += 1;
    }
    out
}
/// Undoes [`natural_bytes`], restoring the original byte layout.
///
/// Reversing the bits of every byte is its own inverse, so this is the same
/// per-byte bit reversal. It is done in one step here. The block is read as a
/// big-endian `u128`, all 128 bits are reversed, and the result is written back
/// little-endian. Every byte ends up at its original index with its bits reversed.
#[inline]
const fn from_natural_bytes(b: &[u8; 16]) -> [u8; 16] {
    u128::from_be_bytes(*b).reverse_bits().to_le_bytes()
}
#[target_feature(enable = "pclmulqdq,sse2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn ghash_mul_clmul(h: &[u8; 16], x: &[u8; 16]) -> [u8; 16] {
let h_n = natural_bytes(h);
let x_n = natural_bytes(x);
let h_vec = _mm_loadu_si128(h_n.as_ptr().cast::<__m128i>());
let x_vec = _mm_loadu_si128(x_n.as_ptr().cast::<__m128i>());
let t00 = _mm_clmulepi64_si128(h_vec, x_vec, 0x00);
let t01 = _mm_clmulepi64_si128(h_vec, x_vec, 0x01);
let t10 = _mm_clmulepi64_si128(h_vec, x_vec, 0x10);
let t11 = _mm_clmulepi64_si128(h_vec, x_vec, 0x11);
let p_low = _mm_xor_si128(_mm_setzero_si128(), t00);
let p_high = _mm_xor_si128(_mm_setzero_si128(), t11);
let middle = _mm_xor_si128(t01, t10);
let mut scratch = [0u8; 32];
let mut p_low_bytes = [0u8; 16];
let mut p_high_bytes = [0u8; 16];
let mut middle_bytes = [0u8; 16];
_mm_storeu_si128(p_low_bytes.as_mut_ptr().cast::<__m128i>(), p_low);
_mm_storeu_si128(p_high_bytes.as_mut_ptr().cast::<__m128i>(), p_high);
_mm_storeu_si128(middle_bytes.as_mut_ptr().cast::<__m128i>(), middle);
scratch[..16].copy_from_slice(&p_low_bytes);
for i in 0..16 {
scratch[8 + i] ^= middle_bytes[i];
}
for i in 0..16 {
scratch[16 + i] ^= p_high_bytes[i];
}
let mut low = u128::from_le_bytes(scratch[..16].try_into().unwrap_or([0; 16]));
let mut high = u128::from_le_bytes(scratch[16..].try_into().unwrap_or([0; 16]));
let mut idx: i32 = 127;
while idx >= 0 {
#[allow(clippy::cast_sign_loss)]
let i = idx as u32;
let bit = (high >> i) & 1;
let mask = 0u128.wrapping_sub(bit);
let positions = [i, i + 1, i + 2, i + 7];
for &p in &positions {
if p < 128 {
low ^= (1u128 << p) & mask;
} else {
high ^= (1u128 << (p - 128)) & mask;
}
}
high &= !(1u128 << i);
idx -= 1;
}
let _ = high;
let out_natural = low.to_le_bytes();
from_natural_bytes(&out_natural)
}