/// Reduction polynomial for GF(2^8): x^8 + x^4 + x^3 + x^2 + 1.
const POLY: usize = 0x11D;

/// Builds the exponent and logarithm tables for GF(2^8) using generator 2.
///
/// Returns `(exp, log)`:
/// - `exp[i]` is 2^i for `i` in `0..510`. Entries `255..510` repeat entries
///   `0..255`, so `exp[log[a] + log[b]]` can be indexed without a `% 255`.
///   Entries 510 and 511 are never written and stay 0.
/// - `log[v]` is the exponent `i` in `0..255` with 2^i == v, for nonzero `v`.
///   `log[0]` has no meaning and stays 0.
fn build_tables() -> ([u8; 512], [u8; 256]) {
    let mut exp = [0u8; 512];
    let mut log = [0u8; 256];
    let mut power: usize = 1;
    for i in 0..255usize {
        let byte = power as u8;
        exp[i] = byte;
        // Mirror copy; at i == 0 this also sets exp[255] == 1 (2^255 == 2^0).
        exp[255 + i] = byte;
        log[power] = i as u8;
        // Multiply by the generator (x); reduce when the degree reaches 8.
        power <<= 1;
        if power >= 0x100 {
            power ^= POLY;
        }
    }
    (exp, log)
}
use std::sync::OnceLock;
/// Process-wide `(exp, log)` tables, filled in on first use by [`tables`].
static TABLES: OnceLock<([u8; 512], [u8; 256])> = OnceLock::new();

/// Returns the shared `(exp, log)` GF(2^8) tables, building them on the first call.
///
/// `OnceLock::get_or_init` runs the initializer at most once, so every call
/// (from any thread) observes the same `'static` tables.
fn tables() -> &'static ([u8; 512], [u8; 256]) {
    let init: fn() -> ([u8; 512], [u8; 256]) = build_tables;
    TABLES.get_or_init(init)
}
/// Multiplies `a` and `b` in GF(2^8) via the log/exp tables.
///
/// Zero has no logarithm, so a zero operand short-circuits to 0. Otherwise
/// the product is `exp[log[a] + log[b]]`. The doubled exp table means the sum
/// (at most 508) needs no reduction mod 255.
#[inline]
#[must_use]
pub fn gf_mul_lut(a: u8, b: u8) -> u8 {
    match (a, b) {
        (0, _) | (_, 0) => 0,
        _ => {
            let (exp, log) = tables();
            let exponent = usize::from(log[usize::from(a)]) + usize::from(log[usize::from(b)]);
            exp[exponent]
        }
    }
}
/// Multiplies `a` and `b` in GF(2^8) with shift-and-add (no tables).
///
/// Walks the eight bits of `b` from least to most significant. For each set
/// bit, the running multiple of `a` is XORed into the product. After each bit
/// the multiple is doubled, and reduced by 0x1D when bit 7 falls off, which is
/// the low byte of the reduction polynomial 0x11D.
/// Serves as an independent reference for the table-driven multiplier.
#[must_use]
pub fn gf_mul_naive(a: u8, b: u8) -> u8 {
    let (product, _) = (0..8u32).fold((0u8, a), |(acc, multiple), bit| {
        let acc = if (b >> bit) & 1 == 1 { acc ^ multiple } else { acc };
        let overflows = multiple & 0x80 != 0;
        let doubled = multiple << 1;
        let multiple = if overflows { doubled ^ 0x1D } else { doubled };
        (acc, multiple)
    });
    product
}
/// Multiplies each byte of `src` by `scalar` in GF(2^8), storing the
/// products in the first `src.len()` bytes of `dst`.
///
/// Bytes of `dst` beyond `src.len()` are not modified. On x86_64 an AVX2 or
/// SSSE3 kernel is chosen by run-time CPU feature detection, with AVX2
/// preferred. Otherwise the table-driven scalar loop is used.
///
/// # Panics
///
/// Panics if `dst.len() < src.len()`.
#[allow(unsafe_code)]
pub fn gf_mul_slice_simd(src: &[u8], scalar: u8, dst: &mut [u8]) {
    let needed = src.len();
    let available = dst.len();
    assert!(
        available >= needed,
        "dst too short: {} < {}",
        available,
        needed
    );

    // Every product with zero is zero, so no kernel is needed.
    if scalar == 0 {
        dst[..needed].fill(0);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 was detected on this CPU just above, and
            // `dst.len() >= src.len()` was asserted at the top.
            return unsafe { gf_mul_slice_avx2(src, scalar, dst) };
        }
        if is_x86_feature_detected!("ssse3") {
            // SAFETY: SSSE3 was detected on this CPU just above, and
            // `dst.len() >= src.len()` was asserted at the top.
            return unsafe { gf_mul_slice_ssse3(src, scalar, dst) };
        }
    }

    gf_mul_slice_scalar(src, scalar, dst);
}
fn gf_mul_slice_scalar(src: &[u8], scalar: u8, dst: &mut [u8]) {
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = gf_mul_lut(*s, scalar);
}
}
/// Precomputes the two 16-entry shuffle tables used by the SIMD kernels.
///
/// The first table holds `n * scalar` for every low nibble `n` (0..16). The
/// second holds `(n << 4) * scalar` for every high nibble. Because GF(2^8)
/// multiplication distributes over XOR, `x * scalar == lo[x & 0xF] ^ hi[x >> 4]`.
#[cfg(target_arch = "x86_64")]
fn build_nibble_tables(scalar: u8) -> ([u8; 16], [u8; 16]) {
    let mut low_products = [0u8; 16];
    let mut high_products = [0u8; 16];
    let slots = low_products.iter_mut().zip(high_products.iter_mut());
    for (nibble, (lo, hi)) in (0u8..).zip(slots) {
        *lo = gf_mul_lut(nibble, scalar);
        *hi = gf_mul_lut(nibble << 4, scalar);
    }
    (low_products, high_products)
}
/// AVX2 kernel for `gf_mul_slice_simd`: `dst[i] = src[i] * scalar` in GF(2^8).
///
/// Each byte is split into its low and high nibble. Each nibble indexes a
/// 16-entry product table through a byte shuffle, and the two partial
/// products are XORed. Whole 32-byte chunks take the vector path; the
/// remaining tail goes to `gf_mul_slice_scalar`. Bytes of `dst` past
/// `src.len()` are left untouched.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2. Memory safety does
/// not depend on the slice lengths: every load/store address comes from a
/// bounds-checked 32-byte chunk.
///
/// # Panics
///
/// Panics if `dst` is shorter than `src`.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code)]
#[allow(clippy::cast_ptr_alignment)]
#[target_feature(enable = "avx2")]
unsafe fn gf_mul_slice_avx2(src: &[u8], scalar: u8, dst: &mut [u8]) {
    use std::arch::x86_64::*;
    // Bound `dst` first so a short destination panics here instead of being
    // written out of bounds through a raw pointer.
    let dst = &mut dst[..src.len()];
    let (lo16, hi16) = build_nibble_tables(scalar);
    // The 256-bit byte shuffle indexes within each 128-bit lane, so each
    // 16-byte table is copied into both lanes.
    let lo_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(lo16.as_ptr().cast::<__m128i>()));
    let hi_vec = _mm256_broadcastsi128_si256(_mm_loadu_si128(hi16.as_ptr().cast::<__m128i>()));
    let mask_lo = _mm256_set1_epi8(0x0F_u8 as i8);
    let mut src_chunks = src.chunks_exact(32);
    let mut dst_chunks = dst.chunks_exact_mut(32);
    for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
        // Unaligned load/store: slices carry no 32-byte alignment guarantee.
        let data = _mm256_loadu_si256(s.as_ptr().cast::<__m256i>());
        let lo_idx = _mm256_and_si256(data, mask_lo);
        let lo_res = _mm256_shuffle_epi8(lo_vec, lo_idx);
        // Shifting 16-bit lanes pulls bits of the neighbouring byte into each
        // byte's top nibble; the mask discards them.
        let hi_idx = _mm256_and_si256(_mm256_srli_epi16(data, 4), mask_lo);
        let hi_res = _mm256_shuffle_epi8(hi_vec, hi_idx);
        let result = _mm256_xor_si256(lo_res, hi_res);
        _mm256_storeu_si256(d.as_mut_ptr().cast::<__m256i>(), result);
    }
    gf_mul_slice_scalar(src_chunks.remainder(), scalar, dst_chunks.into_remainder());
}
/// SSSE3 kernel for `gf_mul_slice_simd`: `dst[i] = src[i] * scalar` in GF(2^8).
///
/// Each byte is split into its low and high nibble. Each nibble indexes a
/// 16-entry product table through a byte shuffle, and the two partial
/// products are XORed. Whole 16-byte chunks take the vector path; the
/// remaining tail goes to `gf_mul_slice_scalar`. Bytes of `dst` past
/// `src.len()` are left untouched.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSSE3. Memory safety does
/// not depend on the slice lengths: every load/store address comes from a
/// bounds-checked 16-byte chunk.
///
/// # Panics
///
/// Panics if `dst` is shorter than `src`.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code)]
#[allow(clippy::cast_ptr_alignment)]
#[target_feature(enable = "ssse3")]
unsafe fn gf_mul_slice_ssse3(src: &[u8], scalar: u8, dst: &mut [u8]) {
    use std::arch::x86_64::*;
    // Bound `dst` first so a short destination panics here instead of being
    // written out of bounds through a raw pointer.
    let dst = &mut dst[..src.len()];
    let (lo16, hi16) = build_nibble_tables(scalar);
    let lo_vec = _mm_loadu_si128(lo16.as_ptr().cast::<__m128i>());
    let hi_vec = _mm_loadu_si128(hi16.as_ptr().cast::<__m128i>());
    let mask_lo = _mm_set1_epi8(0x0F_u8 as i8);
    let mut src_chunks = src.chunks_exact(16);
    let mut dst_chunks = dst.chunks_exact_mut(16);
    for (s, d) in (&mut src_chunks).zip(&mut dst_chunks) {
        // Unaligned load/store: slices carry no 16-byte alignment guarantee.
        let data = _mm_loadu_si128(s.as_ptr().cast::<__m128i>());
        let lo_idx = _mm_and_si128(data, mask_lo);
        let lo_res = _mm_shuffle_epi8(lo_vec, lo_idx);
        // Shifting 16-bit lanes pulls bits of the neighbouring byte into each
        // byte's top nibble; the mask discards them.
        let hi_idx = _mm_and_si128(_mm_srli_epi16(data, 4), mask_lo);
        let hi_res = _mm_shuffle_epi8(hi_vec, hi_idx);
        let result = _mm_xor_si128(lo_res, hi_res);
        _mm_storeu_si128(d.as_mut_ptr().cast::<__m128i>(), result);
    }
    gf_mul_slice_scalar(src_chunks.remainder(), scalar, dst_chunks.into_remainder());
}
/// Reed–Solomon-style parity encoder over GF(2^8) that uses the SIMD slice
/// multiplier for the per-shard work.
///
/// Row `i` of `gen_matrix` holds the coefficients that combine the data
/// shards into parity shard `i` (see [`SimdRsEncoder::encode`]).
///
/// NOTE(review): each parity row is a Vandermonde row `[1, x, x^2, ...]` with
/// `x = 2^(data_shards + i)`. Stacked under an identity, such a matrix is not
/// guaranteed to be MDS, so some erasure patterns may be unrecoverable.
/// Confirm against the decoder; a Cauchy matrix would guarantee MDS.
#[derive(Debug, Clone)]
pub struct SimdRsEncoder {
    data_shards: usize,
    parity_shards: usize,
    pub gen_matrix: Vec<Vec<u8>>,
}

impl SimdRsEncoder {
    /// Creates an encoder for `data_shards` data shards and `parity_shards`
    /// parity shards.
    ///
    /// # Panics
    ///
    /// Panics if either count is zero or if their sum exceeds 255.
    #[must_use]
    pub fn new(data_shards: usize, parity_shards: usize) -> Self {
        assert!(data_shards >= 1, "data_shards must be >= 1");
        assert!(parity_shards >= 1, "parity_shards must be >= 1");
        // `checked_add` so huge counts cannot wrap around (release builds)
        // and slip past the limit.
        assert!(
            data_shards
                .checked_add(parity_shards)
                .is_some_and(|total| total <= 255),
            "total shards must be <= 255"
        );
        let gen_matrix = Self::build_generator(data_shards, parity_shards);
        Self {
            data_shards,
            parity_shards,
            gen_matrix,
        }
    }

    /// Builds the `parity x data` coefficient matrix.
    ///
    /// Row `pi` is `exp[(base * di) % 255]` for `di` in `0..data`, with
    /// `base = data + pi`. Since `new` enforces `data + parity <= 255`, each row
    /// uses a distinct `base` in `1..=254`.
    fn build_generator(data: usize, parity: usize) -> Vec<Vec<u8>> {
        let (exp, _) = tables();
        (0..parity)
            .map(|pi| {
                let base = (data + pi) % 255;
                // exp[0] == 1, so column 0 is 1 without a special case.
                (0..data).map(|di| exp[(base * di) % 255]).collect()
            })
            .collect()
    }

    /// Computes the parity shards for `data`.
    ///
    /// Parity shard `i` is the XOR (GF(2^8) sum) over `j` of
    /// `gen_matrix[i][j] * data[j]`, byte by byte. Every parity shard has the
    /// same length as the data shards.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != self.data_shards()` or if the data shards are
    /// not all the same length.
    pub fn encode(&self, data: &[&[u8]]) -> Vec<Vec<u8>> {
        assert_eq!(
            data.len(),
            self.data_shards,
            "expected {} data shards, got {}",
            self.data_shards,
            data.len()
        );
        // `new` guarantees data_shards >= 1, so data[0] exists.
        let shard_len = data[0].len();
        for d in data.iter() {
            assert_eq!(
                d.len(),
                shard_len,
                "all data shards must be the same length"
            );
        }
        let mut parity = vec![vec![0u8; shard_len]; self.parity_shards];
        // Scratch buffer reused for every coefficient * shard product.
        let mut product = vec![0u8; shard_len];
        for (i, par) in parity.iter_mut().enumerate() {
            let row = &self.gen_matrix[i];
            for (j, &shard) in data.iter().enumerate() {
                gf_mul_slice_simd(shard, row[j], &mut product);
                for (p, &t) in par.iter_mut().zip(&product) {
                    *p ^= t;
                }
            }
        }
        parity
    }

    /// Number of data shards this encoder expects.
    #[must_use]
    pub fn data_shards(&self) -> usize {
        self.data_shards
    }

    /// Number of parity shards this encoder produces.
    #[must_use]
    pub fn parity_shards(&self) -> usize {
        self.parity_shards
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference slice multiply built only on the bitwise multiplier.
    fn naive_mul_slice(src: &[u8], scalar: u8) -> Vec<u8> {
        src.iter().map(|&s| gf_mul_naive(s, scalar)).collect()
    }

    #[test]
    fn test_gf_mul_lut_matches_naive() {
        for a in 0u8..=255 {
            for b in 0u8..=255 {
                let lut = gf_mul_lut(a, b);
                let naive = gf_mul_naive(a, b);
                assert_eq!(lut, naive, "gf_mul_lut({a},{b}) = {lut} != naive = {naive}");
            }
        }
    }

    #[test]
    fn test_gf_mul_slice_simd_matches_scalar() {
        let src: Vec<u8> = (0u8..=255).collect();
        for scalar in 0u8..=255 {
            let mut dst_simd = vec![0u8; src.len()];
            let mut dst_scalar = vec![0u8; src.len()];
            gf_mul_slice_simd(&src, scalar, &mut dst_simd);
            gf_mul_slice_scalar(&src, scalar, &mut dst_scalar);
            assert_eq!(
                dst_simd, dst_scalar,
                "simd vs scalar mismatch at scalar={scalar}"
            );
        }
    }

    #[test]
    fn test_gf_mul_slice_simd_tail_lengths() {
        // Lengths around the 16-byte (SSSE3) and 32-byte (AVX2) widths make
        // the scalar tail run after, or instead of, the vector loop.
        for len in [0usize, 1, 15, 16, 17, 31, 32, 33, 48, 63, 64, 65, 100] {
            let src: Vec<u8> = (0..len).map(|i| (i * 37 + 11) as u8).collect();
            for scalar in [1u8, 2, 0x1D, 0x80, 0xFF] {
                let mut dst = vec![0u8; len];
                gf_mul_slice_simd(&src, scalar, &mut dst);
                assert_eq!(
                    dst,
                    naive_mul_slice(&src, scalar),
                    "len={len} scalar={scalar}"
                );
            }
        }
    }

    #[test]
    fn test_gf_mul_slice_simd_leaves_dst_tail_untouched() {
        let src: Vec<u8> = (1..=40).collect();
        for scalar in [0u8, 3] {
            let mut dst = vec![0xAAu8; 50];
            gf_mul_slice_simd(&src, scalar, &mut dst);
            assert_eq!(&dst[..40], naive_mul_slice(&src, scalar).as_slice());
            assert!(
                dst[40..].iter().all(|&b| b == 0xAA),
                "bytes past src.len() must be untouched (scalar={scalar})"
            );
        }
    }

    #[test]
    #[should_panic(expected = "dst too short")]
    fn test_gf_mul_slice_simd_short_dst_panics() {
        let src = [7u8; 8];
        let mut dst = [0u8; 4];
        gf_mul_slice_simd(&src, 5, &mut dst);
    }

    #[test]
    fn test_rs_encode_still_correct_after_simd() {
        let data_shards = 4usize;
        let parity_shards = 2usize;
        // Not a multiple of 16 or 32, so the SIMD tail path is exercised too.
        let shard_len = 67usize;
        let raw_data: Vec<Vec<u8>> = (0..data_shards)
            .map(|i| {
                (0u8..shard_len as u8)
                    .map(|b| b.wrapping_mul(i as u8 + 1))
                    .collect()
            })
            .collect();
        let enc = SimdRsEncoder::new(data_shards, parity_shards);
        let data_refs: Vec<&[u8]> = raw_data.iter().map(|s| s.as_slice()).collect();
        let parity = enc.encode(&data_refs);
        let parity2 = enc.encode(&data_refs);
        assert_eq!(
            parity, parity2,
            "encode is deterministic: same inputs must give same parity"
        );
        assert_eq!(parity.len(), parity_shards);
        assert!(parity.iter().all(|p| p.len() == shard_len));
        for (i, (row, par)) in enc.gen_matrix.iter().zip(&parity).enumerate() {
            let mut expected = vec![0u8; shard_len];
            for (data_j, &coeff) in raw_data.iter().zip(row) {
                for (e, &d) in expected.iter_mut().zip(data_j) {
                    *e ^= gf_mul_naive(d, coeff);
                }
            }
            assert_eq!(
                *par, expected,
                "parity shard {i} from SIMD path must match naive GF computation"
            );
        }
    }

    #[test]
    fn test_gf_mul_zero() {
        for a in 0u8..=255 {
            assert_eq!(gf_mul_lut(a, 0), 0);
            assert_eq!(gf_mul_lut(0, a), 0);
        }
    }

    #[test]
    fn test_gf_mul_one_identity() {
        for a in 0u8..=255 {
            assert_eq!(gf_mul_lut(a, 1), a);
        }
    }

    #[test]
    fn test_gf_mul_commutative() {
        for a in 0u8..=255 {
            for b in 0u8..=255 {
                assert_eq!(gf_mul_lut(a, b), gf_mul_lut(b, a));
            }
        }
    }

    #[test]
    fn test_gf_mul_slice_simd_zero_scalar() {
        let src: Vec<u8> = (0..64).map(|i| i as u8).collect();
        let mut dst = vec![0xFFu8; 64];
        gf_mul_slice_simd(&src, 0, &mut dst);
        assert!(dst.iter().all(|&b| b == 0), "all zeros for scalar=0");
    }

    #[test]
    fn test_gf_mul_slice_simd_one_scalar() {
        let src: Vec<u8> = (0..64).map(|i| i as u8).collect();
        let mut dst = vec![0u8; 64];
        gf_mul_slice_simd(&src, 1, &mut dst);
        assert_eq!(dst, src, "identity for scalar=1");
    }

    #[test]
    fn test_simd_rs_encoder_parity_length() {
        let enc = SimdRsEncoder::new(5, 3);
        let data: Vec<Vec<u8>> = (0..5).map(|i| vec![i as u8; 32]).collect();
        let refs: Vec<&[u8]> = data.iter().map(|s| s.as_slice()).collect();
        let parity = enc.encode(&refs);
        assert_eq!(parity.len(), 3);
        assert!(parity.iter().all(|p| p.len() == 32));
    }
}