use std::num::Wrapping;
use itertools::Itertools;
/// Portable carry-less (XOR-accumulating) multiplication of two `LIMBS`-limb operands.
///
/// Limb 0 is the least-significant limb of each operand. The `2 * LIMBS`-limb
/// product is returned as `(low, high)`, where `low` holds product limbs
/// `0..LIMBS` and `high` holds product limbs `LIMBS..2 * LIMBS`.
pub fn carry_less_mul<const LIMBS: usize>(
    a: [u64; LIMBS],
    b: [u64; LIMBS],
) -> ([u64; LIMBS], [u64; LIMBS]) {
    // Both halves live in one buffer so a partial product can straddle the
    // boundary between the low and the high half.
    let mut halves = [[0u64; LIMBS]; 2];
    let acc = halves.as_flattened_mut();
    for (i, j) in (0..LIMBS).cartesian_product(0..LIMBS) {
        // a[i] * b[j] is 128 bits wide and lands on product limbs i + j and i + j + 1.
        let at = i + j;
        clfma_u64(&mut acc[at..at + 2], a[i], b[j]);
    }
    let [low, high] = halves;
    (low, high)
}
/// Carry-less product of two single-limb operands, returned as `(low, high)`.
///
/// The implementation is selected at compile time: the PCLMULQDQ routine on
/// x86/x86_64 builds with `pclmulqdq` and `sse2` enabled, the NEON routine on
/// aarch64 builds with `neon` and `aes` enabled, and the portable
/// [`carry_less_mul`] otherwise.
pub fn carry_less_mul_1limb(a: [u64; 1], b: [u64; 1]) -> ([u64; 1], [u64; 1]) {
    // The three cfg predicates below are mutually exclusive and exhaustive, so
    // exactly one block survives and becomes the function's tail expression.
    #[cfg(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "pclmulqdq",
        target_feature = "sse2",
    ))]
    {
        clmul_u64_sse2(a, b)
    }
    #[cfg(all(
        target_arch = "aarch64",
        target_feature = "neon",
        target_feature = "aes",
    ))]
    {
        clmul_u64_neon(a, b)
    }
    #[cfg(not(any(
        all(
            any(target_arch = "x86", target_arch = "x86_64"),
            target_feature = "pclmulqdq",
            target_feature = "sse2",
        ),
        all(
            target_arch = "aarch64",
            target_feature = "neon",
            target_feature = "aes",
        ),
    )))]
    {
        carry_less_mul(a, b)
    }
}
/// Carry-less product of two two-limb operands, returned as `(low, high)`.
///
/// The implementation is selected at compile time: the PCLMULQDQ routine on
/// x86/x86_64 builds with `pclmulqdq` and `sse2` enabled, the NEON routine on
/// aarch64 builds with `neon` and `aes` enabled, and the portable
/// [`carry_less_mul`] otherwise.
pub fn carry_less_mul_2limbs(a: [u64; 2], b: [u64; 2]) -> ([u64; 2], [u64; 2]) {
    // The three cfg predicates below are mutually exclusive and exhaustive, so
    // exactly one block survives and becomes the function's tail expression.
    #[cfg(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "pclmulqdq",
        target_feature = "sse2",
    ))]
    {
        clmul_2_u64_sse2(a, b)
    }
    #[cfg(all(
        target_arch = "aarch64",
        target_feature = "neon",
        target_feature = "aes",
    ))]
    {
        clmul_2_u64_neon(a, b)
    }
    #[cfg(not(any(
        all(
            any(target_arch = "x86", target_arch = "x86_64"),
            target_feature = "pclmulqdq",
            target_feature = "sse2",
        ),
        all(
            target_arch = "aarch64",
            target_feature = "neon",
            target_feature = "aes",
        ),
    )))]
    {
        carry_less_mul(a, b)
    }
}
/// Carry-less multiply-accumulate: XORs the 128-bit carry-less product
/// `a * b` into `res[0]` (low 64 bits) and `res[1]` (high 64 bits).
///
/// Panics if `res` has fewer than two elements.
fn clfma_u64(res: &mut [u64], a: u64, b: u64) {
    // `b` when bit `shift` of `a` is set, zero otherwise. The bit is turned
    // into an all-ones / all-zeros mask rather than branched on
    // (NOTE(review): presumably to keep the routine data-independent — confirm).
    let select = |shift: u32| -> u64 { b & (-Wrapping((a >> shift) & 0x01)).0 };

    // Bit 0 contributes `b` unshifted and never reaches the high word; it is
    // handled separately because `term >> 64` would be an out-of-range shift.
    res[0] ^= select(0);
    for shift in 1..u64::BITS {
        let term = select(shift);
        res[0] ^= term << shift;
        res[1] ^= term >> (u64::BITS - shift);
    }
}
/// Two-limb carry-less multiplication using PCLMULQDQ and a Karatsuba split.
///
/// With `x = [a, b]` and `y = [c, d]` (limb 0 least significant), three
/// 64x64 carry-less products are computed: `a*c`, `b*d` and
/// `(a ^ b) * (c ^ d)`. XORing the first two into the third yields the
/// middle term `a*d ^ b*c`, which is then folded in at a 64-bit offset.
///
/// Returns `(low, high)`: the low and high 128 bits of the 256-bit product.
///
/// Only SSE2 and PCLMULQDQ instructions are used, matching the `cfg` gate.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    target_feature = "pclmulqdq",
    target_feature = "sse2",
))]
fn clmul_2_u64_sse2(x: [u64; 2], y: [u64; 2]) -> ([u64; 2], [u64; 2]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;
    use std::mem::MaybeUninit;
    // SAFETY: `sse2` and `pclmulqdq` are enabled at compile time by the `cfg`
    // gate on this function. Every load/store is the unaligned variant and
    // targets a local 16-byte `[u64; 2]`, and both outputs are fully written
    // by `_mm_storeu_si128` before `assume_init`.
    unsafe {
        let a_b = _mm_loadu_si128(x.as_ptr() as *const _);
        let c_d = _mm_loadu_si128(y.as_ptr() as *const _);

        // Immediate 0 multiplies the low lanes, 17 (0x11) the high lanes.
        let prod_a_c = _mm_clmulepi64_si128::<0>(a_b, c_d);
        let prod_b_d = _mm_clmulepi64_si128::<17>(a_b, c_d);

        // Duplicate the low lane with the SSE2 `_mm_unpacklo_epi64` instead of
        // the AVX2-only `_mm_broadcastq_epi64`: this function is gated on
        // `sse2` + `pclmulqdq` only, so AVX2 may not be available.
        // [a, a] ^ [a, b] = [0, a ^ b], leaving the sum in the high lane.
        let s_ab = _mm_xor_si128(_mm_unpacklo_epi64(a_b, a_b), a_b);
        let s_cd = _mm_xor_si128(_mm_unpacklo_epi64(c_d, c_d), c_d);

        let prod_ab_cd = _mm_clmulepi64_si128::<17>(s_ab, s_cd);
        // Middle Karatsuba term: (a ^ b)(c ^ d) ^ ac ^ bd = ad ^ bc.
        let middle = _mm_xor_si128(prod_ab_cd, _mm_xor_si128(prod_a_c, prod_b_d));

        // The middle term sits at a 64-bit offset: its low half goes to the
        // top of `res_l`, its high half to the bottom of `res_h`
        // (the byte shifts by 8 move whole 64-bit lanes).
        let res_l = _mm_xor_si128(prod_a_c, _mm_slli_si128::<8>(middle));
        let res_h = _mm_xor_si128(prod_b_d, _mm_srli_si128::<8>(middle));

        let mut low = MaybeUninit::<[u64; 2]>::uninit();
        _mm_storeu_si128(low.as_mut_ptr() as *mut _, res_l);
        let mut high = MaybeUninit::<[u64; 2]>::uninit();
        _mm_storeu_si128(high.as_mut_ptr() as *mut _, res_h);
        (low.assume_init(), high.assume_init())
    }
}
/// Single-limb carry-less multiplication using one PCLMULQDQ instruction.
///
/// Returns `(low, high)`: the low and high 64 bits of the 128-bit product of
/// `x[0]` and `y[0]`.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    target_feature = "pclmulqdq",
    target_feature = "sse2",
))]
fn clmul_u64_sse2(x: [u64; 1], y: [u64; 1]) -> ([u64; 1], [u64; 1]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;
    use std::mem::MaybeUninit;

    let mut wide = MaybeUninit::<[u64; 2]>::uninit();
    // SAFETY: `sse2` and `pclmulqdq` are enabled at compile time by the `cfg`
    // gate on this function. Each 64-bit load reads exactly one `u64` from a
    // local array, and `wide` is fully written by the 128-bit unaligned store
    // before `assume_init`.
    let [low, high] = unsafe {
        let lhs = _mm_loadu_si64(x.as_ptr() as *const _);
        let rhs = _mm_loadu_si64(y.as_ptr() as *const _);
        // Immediate 0 selects the low 64-bit lane of both operands.
        let product = _mm_clmulepi64_si128::<0>(lhs, rhs);
        _mm_storeu_si128(wide.as_mut_ptr() as *mut _, product);
        wide.assume_init()
    };
    ([low], [high])
}
/// Two-limb carry-less multiplication using the AArch64 polynomial multiply
/// intrinsics and a Karatsuba split.
///
/// Three 64x64 products are formed: low limbs, high limbs, and the product of
/// the XORed limbs of each operand. XORing the first two into the third gives
/// the middle term, which is folded in at a 64-bit offset.
///
/// Returns `(low, high)`: the low and high 128 bits of the 256-bit product.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    target_feature = "aes",
))]
fn clmul_2_u64_neon(x: [u64; 2], y: [u64; 2]) -> ([u64; 2], [u64; 2]) {
    use core::arch::aarch64::{
        vgetq_lane_p64,
        vld1q_u64,
        vmull_high_p64,
        vmull_p64,
        vreinterpretq_p64_u64,
    };

    // SAFETY: `neon` and `aes` are enabled at compile time by the `cfg` gate on
    // this function, and each `vld1q_u64` reads exactly the two `u64`s of a
    // local `[u64; 2]`.
    let (low_prod, high_prod, cross_prod) = unsafe {
        let xv = vreinterpretq_p64_u64(vld1q_u64(x.as_ptr()));
        let yv = vreinterpretq_p64_u64(vld1q_u64(y.as_ptr()));
        let low_prod = vmull_p64(vgetq_lane_p64::<0>(xv), vgetq_lane_p64::<0>(yv));
        let high_prod = vmull_high_p64(xv, yv);
        let x_sum = x[0] ^ vgetq_lane_p64::<1>(xv);
        let y_sum = y[0] ^ vgetq_lane_p64::<1>(yv);
        (low_prod, high_prod, vmull_p64(x_sum, y_sum))
    };

    // Middle Karatsuba term: cross ^ low ^ high.
    let middle = cross_prod ^ low_prod ^ high_prod;

    // Split each 128-bit value into [low limb, high limb].
    let limbs = |wide: u128| -> [u64; 2] { [wide as u64, (wide >> 64) as u64] };
    let [lo0, lo1] = limbs(low_prod);
    let [hi0, hi1] = limbs(high_prod);
    let [mid0, mid1] = limbs(middle);

    // The middle term overlaps product limbs 1 and 2.
    ([lo0, lo1 ^ mid0], [hi0 ^ mid1, hi1])
}
/// Single-limb carry-less multiplication using the AArch64 `vmull_p64` intrinsic.
///
/// Returns `(low, high)`: the low and high 64 bits of the 128-bit product of
/// `x[0]` and `y[0]`.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    target_feature = "aes",
))]
fn clmul_u64_neon(x: [u64; 1], y: [u64; 1]) -> ([u64; 1], [u64; 1]) {
    use core::arch::aarch64::vmull_p64;

    let [lhs] = x;
    let [rhs] = y;
    // SAFETY: `neon` and `aes` are enabled at compile time by the `cfg` gate on
    // this function; the intrinsic takes its operands by value.
    let wide = unsafe { vmull_p64(lhs, rhs) };
    let low = wide as u64;
    let high = (wide >> 64) as u64;
    ([low], [high])
}
#[cfg(test)]
mod tests {
    use super::*;

    type Limbs2 = [u64; 2];

    /// Checks the portable `carry_less_mul` against hard-coded two-limb products.
    #[test]
    fn test_carry_less_mul_2limbs() {
        // (a, b, expected (low, high) product)
        let vectors: [(Limbs2, Limbs2, (Limbs2, Limbs2)); 3] = [
            (
                [0x9f418f3bffd84bba, 0x4a7c605645afdfb1],
                [0x80b7bd91cddc5be5, 0x3a97291035e41e1f],
                (
                    [0xfffa315f244b1f92, 0x288de208c77fb2f9],
                    [0x94aced70da538690, 0x0f0f341c05b65c5e],
                ),
            ),
            (
                [0x74ef862bc1b6d333, 0x3a88103b80d97b73],
                [0x753f4846eb020b5a, 0x8f108359ea25fa8f],
                (
                    [0x419f60d2de880d0e, 0xfd4d74204161d27d],
                    [0xe96db1bf781f351f, 0x1c36456adc21ac7a],
                ),
            ),
            (
                [0x6447b3dcaed62649, 0x6e4af40b2ee1b4c1],
                [0xbd7a4e12fdb29840, 0x8950f56742015f25],
                (
                    [0x010b1b56e559ca40, 0xf1ff4cfe33d20957],
                    [0xed77c432e4701779, 0x342437199bebda57],
                ),
            ),
        ];
        for (lhs, rhs, expected) in vectors {
            assert_eq!(carry_less_mul(lhs, rhs), expected);
        }
    }

    /// The compile-time-selected two-limb routine must agree with the portable one.
    #[test]
    fn test_dispatched_matches_portable_2limbs() {
        let operands: [(Limbs2, Limbs2); 3] = [
            (
                [0x9f418f3bffd84bba, 0x4a7c605645afdfb1],
                [0x80b7bd91cddc5be5, 0x3a97291035e41e1f],
            ),
            (
                [0x74ef862bc1b6d333, 0x3a88103b80d97b73],
                [0x753f4846eb020b5a, 0x8f108359ea25fa8f],
            ),
            (
                [0x6447b3dcaed62649, 0x6e4af40b2ee1b4c1],
                [0xbd7a4e12fdb29840, 0x8950f56742015f25],
            ),
        ];
        operands
            .iter()
            .for_each(|&(a, b)| assert_eq!(carry_less_mul(a, b), carry_less_mul_2limbs(a, b)));
    }

    /// The compile-time-selected one-limb routine must agree with the portable one.
    #[test]
    fn test_dispatched_matches_portable_1limb() {
        let operands: [([u64; 1], [u64; 1]); 3] = [
            ([0x9f418f3bffd84bba], [0x80b7bd91cddc5be5]),
            ([0x74ef862bc1b6d333], [0x753f4846eb020b5a]),
            ([0xffffffffffffffff], [0xffffffffffffffff]),
        ];
        operands
            .iter()
            .for_each(|&(a, b)| assert_eq!(carry_less_mul(a, b), carry_less_mul_1limb(a, b)));
    }
}