use crate::gf16::{mul_f, mul_table};
const MASK_LSB: u64 = 0x1111111111111111;
const MASK_MSB: u64 = 0x8888888888888888;
/// Adds `src` into `acc` limb by limb over the first `m_vec_limbs` limbs.
///
/// Addition in a characteristic-2 field is XOR, so this is `acc[i] ^= src[i]`.
///
/// # Panics
/// If `src` or `acc` holds fewer than `m_vec_limbs` limbs. The check happens before any limb
/// of `acc` is modified.
#[inline]
pub(crate) fn m_vec_add(src: &[u64], acc: &mut [u64], m_vec_limbs: usize) {
    // Slicing both sides once lets the loop run without per-element bounds checks.
    for (dst, &word) in acc[..m_vec_limbs].iter_mut().zip(&src[..m_vec_limbs]) {
        *dst ^= word;
    }
}
/// Portable kernel: `acc[..legs] ^= a * src[..legs]`.
///
/// Every 4-bit nibble of a limb is one GF(16) element, and `a` is a GF(16) scalar. The
/// product is bit-sliced. `(word >> k) & MASK_LSB` isolates bit `k` of every nibble as a
/// 0/1 value. Multiplying that mask by a small constant copies the constant into each
/// selected nibble. The four constants come from `mul_table(a)`, one per byte lane,
/// presumably `a * x^k` for lane `k`; see `crate::gf16::mul_table`.
///
/// The step is carry-free only if each constant is below 16. Lanes 1..3 are masked with
/// `0xf`. Lane 0 keeps the whole low byte, which is assumed to already be < 16 for any
/// GF(16) `a` — TODO confirm against `mul_table`.
///
/// # Panics
/// If `src` or `acc` holds fewer than `legs` limbs. The check happens before `acc` is modified.
#[inline(always)]
fn m_vec_mul_add_scalar(src: &[u64], a: u8, acc: &mut [u64], legs: usize) {
    let tab = mul_table(a);
    let t0 = u64::from(tab & 0xff);
    let t1 = u64::from((tab >> 8) & 0xf);
    let t2 = u64::from((tab >> 16) & 0xf);
    let t3 = u64::from((tab >> 24) & 0xf);
    // One up-front length check per slice instead of two bounds checks per iteration.
    for (dst, &word) in acc[..legs].iter_mut().zip(&src[..legs]) {
        *dst ^= (word & MASK_LSB).wrapping_mul(t0)
            ^ ((word >> 1) & MASK_LSB).wrapping_mul(t1)
            ^ ((word >> 2) & MASK_LSB).wrapping_mul(t2)
            ^ ((word >> 3) & MASK_LSB).wrapping_mul(t3);
    }
}
/// SSSE3 kernel for `acc[..legs] ^= a * src[..legs]` over GF(16).
///
/// Every byte of limb data carries two GF(16) elements, one in the low nibble and one in
/// the high nibble. Two 16-entry tables are built from `mul_f`:
/// - `lo_bytes[v] = mul_f(a, v)`
/// - `hi_bytes[v] = mul_f(a, v) << 4`
///
/// For each byte, `_mm_shuffle_epi8` looks up the low nibble in one table and the high
/// nibble in the other, and the two results are XORed. This assumes `mul_f` returns values
/// below 16, so the halves never overlap — TODO confirm.
///
/// Whole 16-byte chunks go through SIMD. The remaining limb, if any, goes through
/// `m_vec_mul_add_scalar`.
///
/// # Panics
/// If `src` or `acc` holds fewer than `legs` limbs.
///
/// # Safety
/// The caller must ensure the running CPU supports SSSE3.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "ssse3")]
unsafe fn m_vec_mul_add_ssse3(src: &[u64], a: u8, acc: &mut [u64], legs: usize) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // The SIMD loop touches `legs * 8` bytes through raw pointers. Without this check, a
    // short slice would cause an out-of-bounds read/write instead of a panic.
    assert!(
        src.len() >= legs && acc.len() >= legs,
        "m_vec_mul_add_ssse3: slice shorter than legs"
    );

    let mut lo_bytes = [0u8; 16];
    let mut hi_bytes = [0u8; 16];
    for i in 0..16u8 {
        let m = mul_f(a, i);
        lo_bytes[i as usize] = m;
        hi_bytes[i as usize] = m << 4;
    }

    // Cannot overflow: `legs <= src.len()`, and a `[u64]` never spans more than `isize::MAX` bytes.
    let total_bytes = legs * 8;
    let mut i = 0usize;
    // SAFETY: SSSE3 support is the caller's obligation (see `# Safety`). Each load and store
    // covers bytes `i..i + 16` with `i + 16 <= total_bytes`. Because of the assertion above,
    // that range lies inside both `src` and `acc`. Only unaligned load/store intrinsics are
    // used, so alignment is irrelevant.
    unsafe {
        let lo_tbl = _mm_loadu_si128(lo_bytes.as_ptr().cast());
        let hi_tbl = _mm_loadu_si128(hi_bytes.as_ptr().cast());
        let lo_mask = _mm_set1_epi8(0x0F);
        let src_ptr = src.as_ptr().cast::<u8>();
        let acc_ptr = acc.as_mut_ptr().cast::<u8>();
        while i + 16 <= total_bytes {
            let data = _mm_loadu_si128(src_ptr.add(i).cast());
            let acc_v = _mm_loadu_si128(acc_ptr.add(i).cast());
            let lo_idx = _mm_and_si128(data, lo_mask);
            // There is no per-byte shift: shift 16-bit lanes and mask off bits leaking in
            // from the neighbouring byte.
            let hi_idx = _mm_and_si128(_mm_srli_epi16(data, 4), lo_mask);
            let product = _mm_xor_si128(
                _mm_shuffle_epi8(lo_tbl, lo_idx),
                _mm_shuffle_epi8(hi_tbl, hi_idx),
            );
            _mm_storeu_si128(acc_ptr.add(i).cast(), _mm_xor_si128(acc_v, product));
            i += 16;
        }
    }
    // Fewer than 16 bytes remain, i.e. at most one limb.
    let j = i / 8;
    if j < legs {
        m_vec_mul_add_scalar(&src[j..], a, &mut acc[j..], legs - j);
    }
}
/// AVX2 kernel for `acc[..legs] ^= a * src[..legs]` over GF(16).
///
/// Every byte of limb data carries two GF(16) elements, one in the low nibble and one in
/// the high nibble. Two 16-entry tables are built from `mul_f`:
/// - `lo_bytes[v] = mul_f(a, v)`
/// - `hi_bytes[v] = mul_f(a, v) << 4`
///
/// Both tables are broadcast into the two 128-bit lanes of a 256-bit register, because
/// `_mm256_shuffle_epi8` looks up indices within each lane. The results of the low and
/// high lookups are XORed. This assumes `mul_f` returns values below 16 — TODO confirm.
///
/// Processing order:
/// 1. 32-byte chunks with the 256-bit instructions.
/// 2. At most one 16-byte chunk with the 128-bit instructions.
/// 3. A final leftover limb, if any, with `m_vec_mul_add_scalar`.
///
/// # Panics
/// If `src` or `acc` holds fewer than `legs` limbs.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn m_vec_mul_add_avx2(src: &[u64], a: u8, acc: &mut [u64], legs: usize) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    // The SIMD code touches `legs * 8` bytes through raw pointers. Without this check, a
    // short slice would cause an out-of-bounds read/write instead of a panic.
    assert!(
        src.len() >= legs && acc.len() >= legs,
        "m_vec_mul_add_avx2: slice shorter than legs"
    );

    let mut lo_bytes = [0u8; 16];
    let mut hi_bytes = [0u8; 16];
    for i in 0..16u8 {
        let m = mul_f(a, i);
        lo_bytes[i as usize] = m;
        hi_bytes[i as usize] = m << 4;
    }

    // Cannot overflow: `legs <= src.len()`, and a `[u64]` never spans more than `isize::MAX` bytes.
    let total_bytes = legs * 8;
    let mut i = 0usize;
    // SAFETY: AVX2 support is the caller's obligation (see `# Safety`). Each 32-byte access
    // covers bytes `i..i + 32` with `i + 32 <= total_bytes`. Each 16-byte access covers bytes
    // `i..i + 16` with `i + 16 <= total_bytes`. Because of the assertion above, both ranges
    // lie inside `src` and `acc`. Only unaligned load/store intrinsics are used.
    unsafe {
        let lo_tbl = _mm256_broadcastsi128_si256(_mm_loadu_si128(lo_bytes.as_ptr().cast()));
        let hi_tbl = _mm256_broadcastsi128_si256(_mm_loadu_si128(hi_bytes.as_ptr().cast()));
        let lo_mask = _mm256_set1_epi8(0x0F);
        let src_ptr = src.as_ptr().cast::<u8>();
        let acc_ptr = acc.as_mut_ptr().cast::<u8>();
        while i + 32 <= total_bytes {
            let data = _mm256_loadu_si256(src_ptr.add(i).cast());
            let acc_v = _mm256_loadu_si256(acc_ptr.add(i).cast());
            let lo_idx = _mm256_and_si256(data, lo_mask);
            // Shift 16-bit lanes, then mask off bits leaking in from the neighbouring byte.
            let hi_idx = _mm256_and_si256(_mm256_srli_epi16(data, 4), lo_mask);
            let product = _mm256_xor_si256(
                _mm256_shuffle_epi8(lo_tbl, lo_idx),
                _mm256_shuffle_epi8(hi_tbl, hi_idx),
            );
            _mm256_storeu_si256(acc_ptr.add(i).cast(), _mm256_xor_si256(acc_v, product));
            i += 32;
        }
        // A 16-byte half chunk can remain; reuse the low lane of the broadcast tables.
        if i + 16 <= total_bytes {
            let lo_tbl128 = _mm256_castsi256_si128(lo_tbl);
            let hi_tbl128 = _mm256_castsi256_si128(hi_tbl);
            let lo_mask128 = _mm_set1_epi8(0x0F);
            let data = _mm_loadu_si128(src_ptr.add(i).cast());
            let acc_v = _mm_loadu_si128(acc_ptr.add(i).cast());
            let lo_idx = _mm_and_si128(data, lo_mask128);
            let hi_idx = _mm_and_si128(_mm_srli_epi16(data, 4), lo_mask128);
            let product = _mm_xor_si128(
                _mm_shuffle_epi8(lo_tbl128, lo_idx),
                _mm_shuffle_epi8(hi_tbl128, hi_idx),
            );
            _mm_storeu_si128(acc_ptr.add(i).cast(), _mm_xor_si128(acc_v, product));
            i += 16;
        }
    }
    // Fewer than 16 bytes remain, i.e. at most one limb.
    let j = i / 8;
    if j < legs {
        m_vec_mul_add_scalar(&src[j..], a, &mut acc[j..], legs - j);
    }
}
/// NEON kernel for `acc[..legs] ^= a * src[..legs]` over GF(16).
///
/// Every byte of limb data carries two GF(16) elements, one in the low nibble and one in
/// the high nibble. Two 16-entry tables are built from `mul_f`:
/// - `lo_bytes[v] = mul_f(a, v)`
/// - `hi_bytes[v] = mul_f(a, v) << 4`
///
/// For each byte, `vqtbl1q_u8` looks up the low nibble in one table and the high nibble in
/// the other, and the two results are XORed. This assumes `mul_f` returns values below 16
/// — TODO confirm.
///
/// Whole 16-byte chunks go through SIMD. The remaining limb, if any, goes through
/// `m_vec_mul_add_scalar`.
///
/// # Panics
/// If `src` or `acc` holds fewer than `legs` limbs.
///
/// # Safety
/// The caller must ensure the running CPU supports NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn m_vec_mul_add_neon(src: &[u64], a: u8, acc: &mut [u64], legs: usize) {
    use std::arch::aarch64::*;

    // The SIMD loop touches `legs * 8` bytes through raw pointers. Without this check, a
    // short slice would cause an out-of-bounds read/write instead of a panic.
    assert!(
        src.len() >= legs && acc.len() >= legs,
        "m_vec_mul_add_neon: slice shorter than legs"
    );

    let mut lo_bytes = [0u8; 16];
    let mut hi_bytes = [0u8; 16];
    for i in 0..16u8 {
        let m = mul_f(a, i);
        lo_bytes[i as usize] = m;
        hi_bytes[i as usize] = m << 4;
    }

    // Cannot overflow: `legs <= src.len()`, and a `[u64]` never spans more than `isize::MAX` bytes.
    let total_bytes = legs * 8;
    let mut i = 0usize;
    // SAFETY: NEON support is the caller's obligation (see `# Safety`). Each load and store
    // covers bytes `i..i + 16` with `i + 16 <= total_bytes`. Because of the assertion above,
    // that range lies inside both `src` and `acc`. `vld1q_u8`/`vst1q_u8` only need byte
    // alignment.
    unsafe {
        let lo_tbl = vld1q_u8(lo_bytes.as_ptr());
        let hi_tbl = vld1q_u8(hi_bytes.as_ptr());
        let lo_mask = vdupq_n_u8(0x0F);
        let src_ptr = src.as_ptr().cast::<u8>();
        let acc_ptr = acc.as_mut_ptr().cast::<u8>();
        while i + 16 <= total_bytes {
            let data = vld1q_u8(src_ptr.add(i));
            let acc_v = vld1q_u8(acc_ptr.add(i));
            let lo_idx = vandq_u8(data, lo_mask);
            // A per-byte logical shift already isolates the high nibble; no mask needed.
            let hi_idx = vshrq_n_u8::<4>(data);
            let product = veorq_u8(vqtbl1q_u8(lo_tbl, lo_idx), vqtbl1q_u8(hi_tbl, hi_idx));
            vst1q_u8(acc_ptr.add(i), veorq_u8(acc_v, product));
            i += 16;
        }
    }
    // Fewer than 16 bytes remain, i.e. at most one limb.
    let j = i / 8;
    if j < legs {
        m_vec_mul_add_scalar(&src[j..], a, &mut acc[j..], legs - j);
    }
}
/// Minimum vector length, in limbs, at which `dispatch_mul_add` uses a SIMD kernel.
/// Shorter vectors go through `m_vec_mul_add_scalar`.
///
/// Only the AArch64 and x86 dispatchers read this value, so it is compiled only for those
/// targets. Elsewhere it would be dead code and trip the `dead_code` lint.
/// NOTE(review): the origin of 256 is not recorded here; the ignored `timing_mul_add`
/// test measures the scalar/SIMD crossover.
#[cfg(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64"))]
const SIMD_MIN_LIMBS: usize = 256;
/// AArch64 kernel selection for `acc[..legs] ^= a * src[..legs]`.
///
/// Vectors shorter than `SIMD_MIN_LIMBS` use the scalar kernel; longer ones use NEON.
#[cfg(target_arch = "aarch64")]
#[inline]
fn dispatch_mul_add(src: &[u64], a: u8, acc: &mut [u64], legs: usize) {
    if legs < SIMD_MIN_LIMBS {
        m_vec_mul_add_scalar(src, a, acc, legs);
        return;
    }
    // SAFETY: assumes NEON is available. It is in the default feature set of the standard
    // AArch64 targets, and it is not re-checked at runtime here. The callers in this file
    // pass slices of exactly `legs` limbs.
    unsafe { m_vec_mul_add_neon(src, a, acc, legs) }
}
/// x86/x86_64 kernel selection for `acc[..legs] ^= a * src[..legs]`.
///
/// Vectors shorter than `SIMD_MIN_LIMBS` always use the scalar kernel. Longer ones prefer
/// AVX2, then SSSE3, based on runtime CPU detection, and fall back to scalar.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline]
fn dispatch_mul_add(src: &[u64], a: u8, acc: &mut [u64], legs: usize) {
    if legs < SIMD_MIN_LIMBS {
        m_vec_mul_add_scalar(src, a, acc, legs);
        return;
    }
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was just confirmed at runtime.
        unsafe { m_vec_mul_add_avx2(src, a, acc, legs) }
    } else if is_x86_feature_detected!("ssse3") {
        // SAFETY: SSSE3 support was just confirmed at runtime.
        unsafe { m_vec_mul_add_ssse3(src, a, acc, legs) }
    } else {
        m_vec_mul_add_scalar(src, a, acc, legs);
    }
}
/// Fallback kernel selection for targets without a SIMD kernel in this module.
///
/// Always runs the portable scalar kernel.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
#[inline]
fn dispatch_mul_add(src: &[u64], a: u8, acc: &mut [u64], legs: usize) {
    m_vec_mul_add_scalar(src, a, acc, legs);
}
/// Computes `acc[..m_vec_limbs] ^= a * src[..m_vec_limbs]` over GF(16).
///
/// `a` is a scalar, and each nibble of a limb is one vector entry.
///
/// # Panics
/// If `src` or `acc` holds fewer than `m_vec_limbs` limbs.
#[inline]
pub(crate) fn m_vec_mul_add(src: &[u64], a: u8, acc: &mut [u64], m_vec_limbs: usize) {
    // Both slices are trimmed to exactly `m_vec_limbs` before reaching the kernels.
    dispatch_mul_add(&src[..m_vec_limbs], a, &mut acc[..m_vec_limbs], m_vec_limbs);
}
#[inline]
pub(crate) fn vec_mul_add_u64(legs: usize, src: &[u64], a: u8, acc: &mut [u64]) {
let src = &src[..legs];
let acc = &mut acc[..legs];
dispatch_mul_add(src, a, acc, legs);
}
/// For each `k < n`: `bins[dst + k] ^= x^{-1} * bins[src + k]`, nibble-wise in GF(16).
///
/// Dividing by `x` shifts each nibble right by one bit. The constant term that falls off
/// re-enters as `x^{-1} = x^3 + 1` (bit pattern `0b1001`, hence the multiply by 9). This
/// holds under the modulus `x^4 + x + 1`, which matches the `* 3` reduction in
/// `bins_mul_add_x`: `(x^3 + 1) * x = x^4 + x = 1`.
///
/// Clearing bit 0 of every nibble before shifting keeps bits from crossing nibble boundaries.
#[inline]
fn bins_mul_add_x_inv(bins: &mut [u64], src: usize, dst: usize, n: usize) {
    (0..n).for_each(|k| {
        let word = bins[src + k];
        let low_bits = word & MASK_LSB;
        bins[dst + k] ^= ((word ^ low_bits) >> 1) ^ low_bits.wrapping_mul(9);
    });
}
/// For each `k < n`: `bins[dst + k] ^= x * bins[src + k]`, nibble-wise in GF(16).
///
/// Multiplying by `x` shifts each nibble left by one bit. The `x^3` coefficient that
/// overflows becomes `x^4`, which is reduced to `x + 1` (bit pattern `0b0011`, hence the
/// multiply by 3).
///
/// Clearing bit 3 of every nibble before shifting keeps bits from crossing nibble boundaries.
#[inline]
fn bins_mul_add_x(bins: &mut [u64], src: usize, dst: usize, n: usize) {
    (0..n).for_each(|k| {
        let word = bins[src + k];
        let top_bits = word & MASK_MSB;
        bins[dst + k] ^= ((word ^ top_bits) << 1) ^ (top_bits >> 3).wrapping_mul(3);
    });
}
/// Collapses 16 accumulator bins into their GF(16)-weighted sum and writes it to `out`.
///
/// `bins` holds consecutive vectors of `m_vec_limbs` limbs; bin `k` starts at
/// `k * m_vec_limbs`. On return, `out[..m_vec_limbs]` equals `Σ_{k=1}^{15} k · bin_k`,
/// where `k` is read as a GF(16) element and the product is taken nibble-wise. Bin 0 is
/// never read.
///
/// Instead of 15 general scalar multiplications, each step folds one bin `s` into another
/// bin `d` using only a multiplication by `x` or `x^{-1}`. The factor `f` satisfies
/// `d · f = s`, so the weighted sum over the bins still in use is unchanged. A folded bin
/// is never read again. After 14 steps everything sits in bin 1 (weight 1), which is
/// copied out.
///
/// `bins` is used as scratch space and is left modified.
///
/// # Panics
/// If `bins` holds fewer than `16 * m_vec_limbs` limbs, or `out` fewer than `m_vec_limbs`.
pub(crate) fn m_vec_multiply_bins(bins: &mut [u64], out: &mut [u64], m_vec_limbs: usize) {
    // Folding schedule: (multiply by x?, source bin, destination bin).
    // `true` means bins[dst] += x * bins[src]; `false` means bins[dst] += x^{-1} * bins[src].
    const STEPS: [(bool, usize, usize); 14] = [
        (false, 5, 10),
        (true, 11, 12),
        (false, 10, 7),
        (true, 12, 6),
        (false, 7, 14),
        (true, 6, 3),
        (false, 14, 15),
        (true, 3, 8),
        (false, 15, 13),
        (true, 8, 4),
        (false, 13, 9),
        (true, 4, 2),
        (false, 9, 1),
        (true, 2, 1),
    ];
    let len = m_vec_limbs;
    for &(times_x, from, to) in &STEPS {
        if times_x {
            bins_mul_add_x(bins, from * len, to * len, len);
        } else {
            bins_mul_add_x_inv(bins, from * len, to * len, len);
        }
    }
    out[..len].copy_from_slice(&bins[len..2 * len]);
}
/// Tests that cross-check every SIMD kernel against the scalar kernel, plus an ignored
/// timing benchmark.
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies kernel `f` to pseudo-random `src`/`acc` vectors of `legs` limbs and returns
    /// the updated `acc`.
    ///
    /// The inputs depend only on `seed`. Calling `run` twice with the same `seed`, `a` and
    /// `legs` therefore feeds two kernels identical data, and their outputs can be compared
    /// directly.
    fn run<F: FnMut(&[u64], u8, &mut [u64], usize)>(
        mut f: F,
        seed: u64,
        a: u8,
        legs: usize,
    ) -> Vec<u64> {
        // Spread the seed over the state. The `+ 1` keeps seed 0 from producing the
        // all-zero state, which xorshift never leaves.
        let mut s = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
        // xorshift64 step with shift triple (13, 7, 17).
        let mut next = || {
            s ^= s << 13;
            s ^= s >> 7;
            s ^= s << 17;
            s
        };
        let mut src = vec![0u64; legs];
        let mut acc = vec![0u64; legs];
        // Interleave the draws so both `src` and `acc` start out random.
        for k in 0..legs {
            src[k] = next();
            acc[k] = next();
        }
        f(&src, a, &mut acc, legs);
        acc
    }

    /// Every kernel must agree with `m_vec_mul_add_scalar` for all 16 scalars `a`.
    ///
    /// Sizes `1..=33` cover every tail shape of the 16- and 32-byte SIMD strides (2 and 4
    /// limbs). The larger sizes straddle `SIMD_MIN_LIMBS` (256), so `dispatch_mul_add` is
    /// exercised on both its scalar and SIMD branches.
    #[test]
    fn simd_paths_match_scalar() {
        let sizes = (1..=33usize).chain([255usize, 256, 257, 260, 300, 384, 512]);
        for legs in sizes {
            for a in 0u8..16 {
                let seed = 0x5EED ^ (legs as u64) ^ ((a as u64) << 32);
                let expected = run(m_vec_mul_add_scalar, seed, a, legs);
                let got = run(dispatch_mul_add, seed, a, legs);
                assert_eq!(expected, got, "dispatch != scalar (legs={legs}, a={a})");
                // Call the x86 kernels directly so they are covered even below
                // SIMD_MIN_LIMBS, where dispatch would take the scalar path.
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    if is_x86_feature_detected!("ssse3") {
                        let g = run(
                            // SAFETY: SSSE3 support was just confirmed at runtime.
                            |s, a, acc, l| unsafe { m_vec_mul_add_ssse3(s, a, acc, l) },
                            seed,
                            a,
                            legs,
                        );
                        assert_eq!(expected, g, "ssse3 != scalar (legs={legs}, a={a})");
                    }
                    if is_x86_feature_detected!("avx2") {
                        let g = run(
                            // SAFETY: AVX2 support was just confirmed at runtime.
                            |s, a, acc, l| unsafe { m_vec_mul_add_avx2(s, a, acc, l) },
                            seed,
                            a,
                            legs,
                        );
                        assert_eq!(expected, g, "avx2 != scalar (legs={legs}, a={a})");
                    }
                }
            }
        }
    }

    /// Prints ns/op for the scalar kernel and, when the CPU supports them, the SSSE3 and
    /// AVX2 kernels, across a range of vector lengths, along with speed-up ratios.
    /// Ignored by default.
    #[test]
    #[ignore = "timing benchmark; run with --release --ignored --nocapture"]
    // `as f64` casts only feed the printed timings. The long type is the
    // `Box<dyn FnMut(&mut [u64])>` parameter of `time`.
    #[allow(clippy::cast_precision_loss, clippy::type_complexity)]
    fn timing_mul_add() {
        use std::hint::black_box;
        use std::time::Instant;
        let iters: u32 = 4_000_000;
        let a: u8 = 0xB;
        for &legs in &[4usize, 5, 7, 9, 16, 32, 64, 128, 256, 512, 1024] {
            // Deterministic, non-trivial source data; no RNG is needed for timing.
            let mut src = vec![0u64; legs];
            for (k, v) in src.iter_mut().enumerate() {
                *v = 0x0123_4567_89AB_CDEFu64.wrapping_mul(k as u64 + 1) ^ 0xDEAD_BEEF_CAFE_BABE;
            }
            // Runs `f` `iters` times on one reused accumulator and returns the mean ns/op.
            // `black_box(&acc)` keeps the optimizer from discarding the repeated updates.
            let time = |label: &str, mut f: Box<dyn FnMut(&mut [u64])>| -> f64 {
                let mut acc = vec![0u64; legs];
                let start = Instant::now();
                for _ in 0..iters {
                    f(&mut acc);
                    black_box(&acc);
                }
                let ns = start.elapsed().as_nanos() as f64 / iters as f64;
                println!(" {label:<8} {ns:7.3} ns/op");
                ns
            };
            println!("\nm_vec_mul_add timing (legs={legs}, {iters} iters):");
            // Each closure owns its own copy of `src` because the boxed closures are `move`.
            let src_s = src.clone();
            let scalar = time(
                "scalar",
                Box::new(move |acc| {
                    m_vec_mul_add_scalar(black_box(&src_s), black_box(a), acc, legs)
                }),
            );
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if is_x86_feature_detected!("ssse3") {
                    let src_x = src.clone();
                    let ssse3 = time(
                        "ssse3",
                        // SAFETY: SSSE3 support was just confirmed at runtime.
                        Box::new(move |acc| unsafe {
                            m_vec_mul_add_ssse3(black_box(&src_x), black_box(a), acc, legs)
                        }),
                    );
                    println!(" ssse3 vs scalar: {:.2}x", scalar / ssse3);
                    // Nested inside the SSSE3 branch so the AVX2/SSSE3 ratio can be printed.
                    if is_x86_feature_detected!("avx2") {
                        let src_y = src.clone();
                        let avx2 = time(
                            "avx2",
                            // SAFETY: AVX2 support was just confirmed at runtime.
                            Box::new(move |acc| unsafe {
                                m_vec_mul_add_avx2(black_box(&src_y), black_box(a), acc, legs)
                            }),
                        );
                        println!(" avx2 vs scalar: {:.2}x", scalar / avx2);
                        println!(" avx2 vs ssse3: {:.2}x", ssse3 / avx2);
                    }
                }
            }
        }
    }
}