use std::{
arch::x86_64::{
__m256i, __m512i, _mm256_and_si256, _mm256_or_si256, _mm256_ternarylogic_epi32,
_mm256_test_epi64_mask, _mm256_testz_si256, _mm256_xor_si256, _mm512_and_si512,
_mm512_or_si512, _mm512_set1_epi64, _mm512_setzero_si512, _mm512_ternarylogic_epi32,
_mm512_test_epi64_mask, _mm512_xor_si512,
},
io::{Error, ErrorKind},
sync::OnceLock,
};
use wide::u64x8;
use crate::transpose::transpose_scalar;
pub mod benchmark;
pub mod transpose;
/// Plane in which every bit of all eight 64-bit lanes is set.
pub const ALL_ONES: u64x8 = u64x8::splat(u64::MAX);
/// Plane in which every bit of all eight 64-bit lanes is cleared.
pub const ZERO: u64x8 = u64x8::splat(0);
/// Builds a `u64x8` by forwarding `n` to `u64x8::splat`.
///
/// Thin convenience wrapper so callers do not have to name the `wide` type.
pub fn splat(n: u64) -> u64x8 {
u64x8::splat(n)
}
/// Adds two bitsliced operands with a ripple-carry adder and returns the sum.
///
/// Planes are visited from index 63 down to index 0, so index 63 acts as the
/// least significant bit plane and index 0 as the most significant one. The
/// carry produced at index 0 is dropped, i.e. the addition wraps.
///
/// All operations are bitwise, so every bit position of a plane is an
/// independent adder.
pub fn bitsliced_add(a: &[u64x8; 64], b: &[u64x8; 64]) -> [u64x8; 64] {
    let mut result = [u64x8::ZERO; 64];
    let mut carry_in = u64x8::ZERO;

    let planes = result.iter_mut().zip(a.iter().zip(b.iter()));
    for (out, (&lhs, &rhs)) in planes.rev() {
        let (bit_sum, carry_out) = calc_sum_carry(lhs, rhs, carry_in);
        *out = bit_sum;
        carry_in = carry_out;
    }

    result
}
/// Adds the scalar `b` to every value held in the bitsliced operand `a`.
///
/// Bit `n` of `b` (counting from the least significant bit) is broadcast to a
/// full plane and added to plane `63 - n` of `a`. The carry ripples from
/// index 63 towards index 0 and the carry out of index 0 is dropped.
pub fn bitsliced_add_single(a: &[u64x8; 64], b: u64) -> [u64x8; 64] {
    let mut result = [u64x8::ZERO; 64];
    let mut carry_in = u64x8::ZERO;

    // Walking the planes back to front makes the enumeration index equal to
    // the bit position inside `b`.
    for (bit_index, (out, &plane)) in result.iter_mut().zip(a.iter()).rev().enumerate() {
        let bit_is_set = ((b >> bit_index) & 1) == 1;
        let addend = if bit_is_set { ALL_ONES } else { ZERO };
        let (bit_sum, carry_out) = calc_sum_carry(plane, addend, carry_in);
        *out = bit_sum;
        carry_in = carry_out;
    }

    result
}
/// Adds the bitsliced operand `b` to `a`, storing the sum back into `a`.
///
/// Uses the same ripple-carry scheme as the allocating variant: planes are
/// processed from index 63 down to index 0 and the final carry is dropped.
pub fn bitsliced_add_inline(a: &mut [u64x8; 64], b: &[u64x8; 64]) {
    let mut carry_in = u64x8::ZERO;

    for (acc, &rhs) in a.iter_mut().zip(b.iter()).rev() {
        let (bit_sum, carry_out) = calc_sum_carry(*acc, rhs, carry_in);
        *acc = bit_sum;
        carry_in = carry_out;
    }
}
/// Adds the scalar `b` to every value held in `a`, updating `a` in place.
///
/// Bit `n` of `b` is broadcast to a full plane and added to plane `63 - n`.
/// The carry ripples from index 63 towards index 0 and the carry out of
/// index 0 is dropped.
pub fn bitsliced_add_single_inline(a: &mut [u64x8; 64], b: u64) {
    let mut carry_in = u64x8::ZERO;

    // Reverse iteration: the enumeration index is the bit position in `b`.
    for (bit_index, acc) in a.iter_mut().rev().enumerate() {
        let bit_is_set = ((b >> bit_index) & 1) == 1;
        let addend = if bit_is_set { ALL_ONES } else { ZERO };
        let (bit_sum, carry_out) = calc_sum_carry(*acc, addend, carry_in);
        *acc = bit_sum;
        carry_in = carry_out;
    }
}
/// Bitwise full adder: returns `(sum, carry_out)` for the inputs `a`, `b`
/// and the incoming `carry`.
fn calc_sum_carry(a: u64x8, b: u64x8, carry: u64x8) -> (u64x8, u64x8) {
    // `a ^ b` is shared by the sum and by the carry-propagate term.
    let half_sum = a ^ b;
    let generate = a & b;
    let propagate = carry & half_sum;
    (half_sum ^ carry, generate | propagate)
}
/// 512-bit plane with every bit set.
// SAFETY: `[u64; 8]` and `__m512i` are both 64 bytes wide and every bit
// pattern is a valid `__m512i`.
const M512_ONES: __m512i = unsafe { std::mem::transmute([u64::MAX; 8]) };
/// 512-bit plane with every bit cleared.
// SAFETY: same size and validity argument as for `M512_ONES`.
const M512_ZERO: __m512i = unsafe { std::mem::transmute([u64::MIN; 8]) };
/// Adds the scalar `b` to every value held in the 512-bit bitsliced operand
/// `a`, updating `a` in place.
///
/// Planes are processed from index 63 (least significant) towards index 0.
/// Once all set bits of `b` have been consumed and no carry is pending, the
/// remaining planes cannot change and the loop stops early.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx512f`, `avx512vl` and
/// `avx512bw`.
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
pub unsafe fn bitsliced_add_single_inline_avx_512(a: &mut [__m512i; 64], b: u64) {
    // Number of significant bits in `b`; every bit at or above this position is zero.
    let significant_bits = 64 - (b.leading_zeros() as usize);
    let mut carry = M512_ZERO;

    // Reverse iteration: the enumeration index is the bit position in `b`.
    for (bit_index, plane) in a.iter_mut().rev().enumerate() {
        // The carry is only inspected once `b` has no more set bits left.
        if bit_index >= significant_bits && _mm512_test_epi64_mask(carry, carry) == 0 {
            break;
        }

        let addend = match (b >> bit_index) & 1 {
            1 => M512_ONES,
            _ => M512_ZERO,
        };
        let original = *plane;
        // Truth table 0x96 is the three-input XOR: the sum bit.
        *plane = _mm512_ternarylogic_epi32(original, addend, carry, 0x96);
        // Truth table 0xE8 is the three-input majority: the carry out.
        carry = _mm512_ternarylogic_epi32(original, addend, carry, 0xE8);
    }
}
/// 256-bit plane with every bit set.
// SAFETY: `[u64; 4]` and `__m256i` are both 32 bytes wide and every bit
// pattern is a valid `__m256i`.
const M2_ONES: __m256i = unsafe { std::mem::transmute([u64::MAX; 4]) };
/// 256-bit plane with every bit cleared.
// SAFETY: same size and validity argument as for `M2_ONES`.
const M2_ZERO: __m256i = unsafe { std::mem::transmute([u64::MIN; 4]) };

/// Adds the scalar `b` to every value held in the 256-bit bitsliced operand
/// `a`, updating `a` in place.
///
/// Planes are processed from index 63 (least significant) towards index 0.
/// Once all set bits of `b` have been consumed and no carry is pending, the
/// remaining planes cannot change and the loop stops early.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
#[target_feature(enable = "avx2")]
pub unsafe fn bitsliced_add_single_inline_avx_2(a: &mut [__m256i; 64], b: u64) {
    // Number of significant bits in `b`; every bit at or above this position is zero.
    let significant_bits = 64 - (b.leading_zeros() as usize);
    let mut carry = M2_ZERO;

    // Reverse iteration: the enumeration index is the bit position in `b`.
    for (bit_index, plane) in a.iter_mut().rev().enumerate() {
        // `_mm256_testz_si256(x, x)` is non-zero exactly when `x` is all zeros.
        if bit_index >= significant_bits && _mm256_testz_si256(carry, carry) != 0 {
            break;
        }

        let addend = if ((b >> bit_index) & 1) == 1 {
            M2_ONES
        } else {
            M2_ZERO
        };
        let original = *plane;
        let half_sum = _mm256_xor_si256(original, addend);
        let generate = _mm256_and_si256(original, addend);
        let propagate = _mm256_and_si256(carry, half_sum);
        *plane = _mm256_xor_si256(half_sum, carry);
        carry = _mm256_or_si256(generate, propagate);
    }
}
/// Returns a copy of `a` reduced modulo `2^k`.
///
/// The `k` planes at the highest indices (the low-order bits under the
/// layout used by the adders in this file) are kept; planes `0..64 - k` are
/// zeroed.
///
/// # Errors
///
/// Returns an `InvalidData` error when `k` is greater than 64.
pub fn bitsliced_modulo_power_of_two(a: &[u64x8; 64], k: usize) -> Result<[u64x8; 64], Error> {
    // `checked_sub` fails exactly when `k > 64`.
    let Some(cleared) = 64usize.checked_sub(k) else {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "k must be <= 64 for bitsliced modulo",
        ));
    };

    let mut reduced = *a;
    reduced[..cleared].fill(u64x8::splat(0));
    Ok(reduced)
}
/// Reduces `a` modulo `2^k` in place by zeroing planes `0..64 - k`.
///
/// # Errors
///
/// Returns an `InvalidData` error when `k` is greater than 64; `a` is left
/// untouched in that case.
pub fn bitsliced_modulo_power_of_two_inline(a: &mut [u64x8; 64], k: usize) -> Result<(), Error> {
    if k > 64 {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "k must be <= 64 for bitsliced modulo",
        ));
    }

    a.iter_mut()
        .take(64 - k)
        .for_each(|plane| *plane = u64x8::splat(0));
    Ok(())
}
/// Reduces the 512-bit bitsliced operand `a` modulo `2^k` in place by zeroing
/// planes `0..64 - k`.
///
/// # Errors
///
/// Returns an `InvalidData` error when `k` is greater than 64; `a` is left
/// untouched in that case.
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
pub fn bitsliced_modulo_power_of_two_inline_avx_512(
    a: &mut [__m512i; 64],
    k: usize,
) -> Result<(), Error> {
    // `checked_sub` yields `None` exactly when `k > 64`.
    match 64usize.checked_sub(k) {
        Some(cleared) => {
            a[..cleared].fill(M512_ZERO);
            Ok(())
        }
        None => Err(Error::new(
            ErrorKind::InvalidData,
            "k must be <= 64 for bitsliced modulo",
        )),
    }
}
/// Reduces the 256-bit bitsliced operand `a` modulo `2^k` in place by zeroing
/// planes `0..64 - k`.
///
/// # Errors
///
/// Returns an `InvalidData` error when `k` is greater than 64; `a` is left
/// untouched in that case.
#[target_feature(enable = "avx2")]
pub fn bitsliced_modulo_power_of_two_inline_avx_2(
    a: &mut [__m256i; 64],
    k: usize,
) -> Result<(), Error> {
    if k > 64 {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "k must be <= 64 for bitsliced modulo",
        ));
    }

    for plane in &mut a[..64 - k] {
        *plane = M2_ZERO;
    }
    Ok(())
}
pub fn des_reduction(h: &[u64x8; 64], i: u64) -> [u64x8; 64] {
let mut sum = bitsliced_add_single(h, i);
bitsliced_modulo_power_of_two_inline(&mut sum, 56).unwrap();
sum
}
/// Replaces `h` with `h + i` reduced modulo `2^56`, working in place.
///
/// The scalar `i` is added to every value in `h`, then planes `0..8` are
/// zeroed so that only the 56 planes at the highest indices remain.
pub fn des_reduction_inline(h: &mut [u64x8; 64], i: u64) {
    const KEPT_BITS: usize = 56;

    bitsliced_add_single_inline(h, i);
    // `KEPT_BITS` is a constant below 64, so the modulo cannot return an error.
    let reduced = bitsliced_modulo_power_of_two_inline(h, KEPT_BITS);
    reduced.unwrap();
}
/// Replaces the 512-bit bitsliced operand `h` with `h + i` reduced modulo
/// `2^56`, working in place.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx512f`, `avx512vl` and
/// `avx512bw`.
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
pub unsafe fn des_reduction_inline_avx_512(h: &mut [__m512i; 64], i: u64) {
    const KEPT_BITS: usize = 56;

    // SAFETY: this function enables the same target features as the callee,
    // and its own caller has promised they are available.
    unsafe { bitsliced_add_single_inline_avx_512(h, i) };
    // `KEPT_BITS` is a constant below 64, so the modulo cannot return an error.
    let reduced = bitsliced_modulo_power_of_two_inline_avx_512(h, KEPT_BITS);
    reduced.unwrap();
}
/// Replaces the 256-bit bitsliced operand `h` with `h + i` reduced modulo
/// `2^56`, working in place.
///
/// # Safety
///
/// The caller must ensure the CPU supports `avx2`.
#[target_feature(enable = "avx2")]
pub unsafe fn des_reduction_inline_avx_2(h: &mut [__m256i; 64], i: u64) {
    const KEPT_BITS: usize = 56;

    // SAFETY: this function enables the same target feature as the callee,
    // and its own caller has promised it is available.
    unsafe { bitsliced_add_single_inline_avx_2(h, i) };
    // `KEPT_BITS` is a constant below 64, so the modulo cannot return an error.
    let reduced = bitsliced_modulo_power_of_two_inline_avx_2(h, KEPT_BITS);
    reduced.unwrap();
}
/// Caches the outcome of the CPU feature probe so it runs at most once.
static USE_GFNI: OnceLock<bool> = OnceLock::new();

/// Probes the CPU for the features checked before taking the GFNI transpose path.
fn detect_gfni_transpose_support() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        std::is_x86_feature_detected!("gfni")
            && std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512vbmi")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}

/// Transposes a 64x64 bit matrix stored as 64 `u64` words.
///
/// Dispatches to `transpose_gfni` when `gfni`, `avx512f`, `avx512bw` and
/// `avx512vbmi` are all detected at runtime, and to `transpose_scalar`
/// otherwise. The detection result is cached in `USE_GFNI`.
pub fn transpose_64x64(input: &[u64; 64]) -> [u64; 64] {
    let gfni_supported = *USE_GFNI.get_or_init(detect_gfni_transpose_support);

    if gfni_supported {
        // SAFETY: the features listed in `detect_gfni_transpose_support` were
        // detected on this CPU.
        // NOTE(review): presumably those are exactly the features
        // `transpose_gfni` requires — confirm against its definition.
        unsafe { crate::transpose::transpose_gfni(input) }
    } else {
        transpose_scalar(input)
    }
}
#[cfg(test)]
mod tests {
    use std::arch::x86_64::{_mm256_setzero_si256, _mm256_storeu_si256, _mm512_storeu_si512};

    use super::*;

    /// Reports whether the running CPU supports every feature the AVX-512
    /// kernels are compiled with. Calling those kernels without this check is
    /// undefined behaviour on CPUs that lack the features.
    fn avx512_available() -> bool {
        std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512vl")
            && std::is_x86_feature_detected!("avx512bw")
    }

    /// Reports whether the running CPU supports the feature the AVX2 kernels
    /// are compiled with.
    fn avx2_available() -> bool {
        std::is_x86_feature_detected!("avx2")
    }

    #[test]
    fn test_add_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        // 1 + 1 = 2: the lowest plane clears and the carry lands in plane 62.
        let sum = bitsliced_add(&a, &b);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(sum[i], ZERO);
        }
    }

    #[test]
    fn test_add_single_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let sum = bitsliced_add_single(&a, 1);
        assert_eq!(sum[63], ZERO);
        assert_eq!(sum[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(sum[i], ZERO);
        }
    }

    #[test]
    fn test_add_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        let mut b = [ZERO; 64];
        b[63] = ALL_ONES;
        bitsliced_add_inline(&mut a, &b);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(a[i], ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_works() {
        let mut a = [ZERO; 64];
        a[63] = ALL_ONES;
        bitsliced_add_single_inline(&mut a, 1);
        assert_eq!(a[63], ZERO);
        assert_eq!(a[62], ALL_ONES);
        for i in 0..62 {
            assert_eq!(a[i], ZERO);
        }
    }

    #[test]
    fn test_add_single_inline_avx_512_works() {
        if !avx512_available() {
            eprintln!("skipping test_add_single_inline_avx_512_works: AVX-512 not available");
            return;
        }
        // SAFETY: AVX-512 support was verified above; the stores below write
        // 64 bytes into a 64-byte buffer.
        let mut a = [unsafe { _mm512_setzero_si512() }; 64];
        a[63] = unsafe { std::mem::transmute([!0u64; 8]) };
        unsafe { bitsliced_add_single_inline_avx_512(&mut a, 1) };
        let mut arr = [0u64; 8];
        unsafe { _mm512_storeu_si512(arr.as_mut_ptr() as *mut _, a[63]) };
        assert_eq!(arr, [0u64; 8]);
        unsafe { _mm512_storeu_si512(arr.as_mut_ptr() as *mut _, a[62]) };
        assert_eq!(arr, [0xFFFFFFFFFFFFFFFFu64; 8]);
        for i in 0..62 {
            unsafe { _mm512_storeu_si512(arr.as_mut_ptr() as *mut _, a[i]) };
            assert_eq!(arr, [0u64; 8], "Index {} should be ZERO", i);
        }
    }

    #[test]
    fn test_add_single_inline_avx_2_works() {
        if !avx2_available() {
            eprintln!("skipping test_add_single_inline_avx_2_works: AVX2 not available");
            return;
        }
        // SAFETY: AVX2 support was verified above; the stores below write
        // 32 bytes into a 32-byte buffer.
        let mut a = [unsafe { _mm256_setzero_si256() }; 64];
        a[63] = unsafe { std::mem::transmute([!0u64; 4]) };
        unsafe { bitsliced_add_single_inline_avx_2(&mut a, 1) };
        let mut arr = [0u64; 4];
        unsafe { _mm256_storeu_si256(arr.as_mut_ptr() as *mut _, a[63]) };
        assert_eq!(arr, [0u64; 4]);
        unsafe { _mm256_storeu_si256(arr.as_mut_ptr() as *mut _, a[62]) };
        assert_eq!(arr, [0xFFFFFFFFFFFFFFFFu64; 4]);
        for i in 0..62 {
            unsafe { _mm256_storeu_si256(arr.as_mut_ptr() as *mut _, a[i]) };
            assert_eq!(arr, [0u64; 4], "Index {} should be ZERO", i);
        }
    }

    #[test]
    fn test_modulo_works() {
        let a = [ALL_ONES; 64];
        let res = bitsliced_modulo_power_of_two(&a, 56).unwrap();
        for i in 0..8 {
            assert_eq!(res[i], ZERO);
        }
        for i in 8..64 {
            assert_eq!(res[i], ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_inline_works() {
        let mut a = [ALL_ONES; 64];
        bitsliced_modulo_power_of_two_inline(&mut a, 56).unwrap();
        for i in 0..8 {
            assert_eq!(a[i], ZERO);
        }
        for i in 8..64 {
            assert_eq!(a[i], ALL_ONES);
        }
    }

    #[test]
    fn test_modulo_inline_avx_2_works() {
        if !avx2_available() {
            eprintln!("skipping test_modulo_inline_avx_2_works: AVX2 not available");
            return;
        }
        // SAFETY: AVX2 support was verified above; the transmutes convert
        // between `__m256i` and a byte array of the same 32-byte size.
        unsafe {
            let mut a = [M2_ONES; 64];
            bitsliced_modulo_power_of_two_inline_avx_2(&mut a, 56).unwrap();
            let zero_raw: [u8; 32] = std::mem::transmute(M2_ZERO);
            let ones_raw: [u8; 32] = std::mem::transmute(M2_ONES);
            for i in 0..8 {
                let actual: [u8; 32] = std::mem::transmute(a[i]);
                assert_eq!(actual, zero_raw, "Index {} should be ZERO", i);
            }
            for i in 8..64 {
                let actual: [u8; 32] = std::mem::transmute(a[i]);
                assert_eq!(actual, ones_raw, "Index {} should be ONES", i);
            }
        }
    }
}