use crate::detect::has_avx2;
/// Low byte of the SM4 field polynomial x^8 + x^7 + x^6 + x^5 + x^4 + x^2 + 1 (0x1F5).
/// The x^8 term is implicit: this byte is XORed in whenever a doubling carries out of bit 7.
const SM4_GF_POLY: u8 = 0xF5;
/// First row of the circulant 8x8 bit matrix `A`; every other row is a right rotation of it.
const A_FIRST_ROW: u8 = 0xD3;
/// Constant XORed onto the output of each multiplication by `A`.
const AFFINE_B: u8 = 0xD3;
/// Rows of `A`: entry `k` is `A_FIRST_ROW` rotated right by `k` bits and produces output
/// bit `7 - k` (most-significant output bit first).
const A_ROWS: [u8; 8] = {
    let mut rows = [0u8; 8];
    let mut shift: u32 = 0;
    while shift < 8 {
        rows[shift as usize] = A_FIRST_ROW.rotate_right(shift);
        shift += 1;
    }
    rows
};
/// Multiplies `a` by `b` in GF(2^8) reduced by [`SM4_GF_POLY`].
///
/// Runs a fixed eight-round shift-and-add loop. Each conditional XOR is selected with an
/// all-ones/all-zeros mask instead of a data-dependent branch.
#[inline]
const fn gf_mul(mut a: u8, mut b: u8) -> u8 {
    let mut acc: u8 = 0;
    let mut round = 0;
    while round < 8 {
        // 0xFF when the current low bit of `b` is set, 0x00 otherwise.
        let take = (b & 1).wrapping_neg();
        acc ^= take & a;
        // Capture the bit about to be shifted out so the doubling can be reduced.
        let overflow = (a >> 7).wrapping_neg();
        a = (a << 1) ^ (overflow & SM4_GF_POLY);
        b >>= 1;
        round += 1;
    }
    acc
}
/// Returns `x^254` in GF(2^8).
///
/// Because the field's multiplicative group has order 255, `x^254` is the inverse of any
/// nonzero `x`. Input 0 maps to 0. The exponent 254 = 2 + 4 + 8 + 16 + 32 + 64 + 128.
/// Each step squares the running power and folds it into the accumulator, so the work is
/// the same for every input.
#[inline]
const fn gf_inv(x: u8) -> u8 {
    // x^2 seeds both the square chain and the product.
    let mut square = gf_mul(x, x);
    let mut acc = square;
    let mut k = 2;
    while k < 8 {
        // square == x^(2^k) after this line.
        square = gf_mul(square, square);
        acc = gf_mul(acc, square);
        k += 1;
    }
    acc
}
/// Multiplies the bit vector `x` by the matrix `A` over GF(2).
///
/// Output bit `7 - k` is the parity of `A_ROWS[k] & x`. The bits are accumulated
/// most-significant first by shifting the partial result left before appending each
/// new parity bit.
#[inline]
const fn affine_a(x: u8) -> u8 {
    let mut acc: u8 = 0;
    let mut k = 0;
    while k < 8 {
        let bit = ((A_ROWS[k] & x).count_ones() & 1) as u8;
        acc = (acc << 1) | bit;
        k += 1;
    }
    acc
}
/// Evaluates the S-box on one byte as `A · inv(A · x ⊕ B) ⊕ B`.
///
/// `A` is the matrix applied by [`affine_a`], `B` is [`AFFINE_B`], and `inv` is [`gf_inv`].
#[inline]
#[must_use]
const fn sbox_byte(x: u8) -> u8 {
    let inverted = gf_inv(affine_a(x) ^ AFFINE_B);
    affine_a(inverted) ^ AFFINE_B
}
#[must_use]
pub fn sbox_x8_scalar(input: &[u8; 8]) -> [u8; 8] {
let mut out = [0u8; 8];
let mut i = 0;
while i < 8 {
out[i] = sbox_byte(input[i]);
i += 1;
}
out
}
/// Applies the S-box to each of the eight bytes of `input`.
///
/// On x86_64 this calls [`sbox_x8_avx2`] when `has_avx2()` reports AVX2 support. Otherwise,
/// and on every other architecture, it calls [`sbox_x8_scalar`].
#[must_use]
#[inline]
pub fn sbox_x8(input: &[u8; 8]) -> [u8; 8] {
    #[cfg(target_arch = "x86_64")]
    {
        if has_avx2() {
            // SAFETY: `sbox_x8_avx2` is compiled with `target_feature(enable = "avx2")`,
            // and the runtime check just above confirmed the CPU supports AVX2.
            return unsafe { sbox_x8_avx2(input) };
        }
    }
    // `has_avx2` is only consulted on x86_64. Name the function item here without
    // calling it, so the unconditional import is still "used" on other targets.
    // The previous form re-ran detection on every call.
    #[cfg(not(target_arch = "x86_64"))]
    let _ = has_avx2;
    sbox_x8_scalar(input)
}
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
__m256i, _mm256_add_epi8, _mm256_and_si256, _mm256_cmpgt_epi8, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi8, _mm256_setzero_si256, _mm256_slli_epi16, _mm256_srli_epi16,
_mm256_storeu_si256, _mm256_sub_epi8, _mm256_xor_si256,
};
/// AVX2 implementation of [`sbox_x8`]: applies the S-box to each of the eight input bytes.
///
/// The input is copied into the low eight lanes of a zero-padded 32-byte vector, and the
/// S-box is evaluated branch-free on all 32 lanes. Only the first eight results are
/// returned; the 24 padding lanes are computed and discarded.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2, for example by checking
/// `has_avx2()` first as [`sbox_x8`] does. Calling this function on a CPU without AVX2 is
/// undefined behavior.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
// The body consists almost entirely of AVX2 intrinsic calls whose soundness rests on the
// single caller-provided AVX2 guarantee above. Per-call `unsafe` blocks would not narrow it.
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn sbox_x8_avx2(input: &[u8; 8]) -> [u8; 8] {
    let mut staged = [0u8; 32];
    staged[..8].copy_from_slice(input);
    // `loadu`/`storeu` accept unaligned addresses, so plain byte arrays are fine here.
    let x = _mm256_loadu_si256(staged.as_ptr().cast::<__m256i>());
    let b_const = _mm256_set1_epi8(AFFINE_B as i8);
    let pre = _mm256_xor_si256(affine_a_simd(x), b_const);
    let inv = gf_inv_simd(pre);
    let out = _mm256_xor_si256(affine_a_simd(inv), b_const);
    let mut staged_out = [0u8; 32];
    _mm256_storeu_si256(staged_out.as_mut_ptr().cast::<__m256i>(), out);
    let mut result = [0u8; 8];
    result.copy_from_slice(&staged_out[..8]);
    result
}
/// Lane-wise GF(2^8) multiplication of 32 byte lanes, reduced by [`SM4_GF_POLY`].
///
/// This mirrors the scalar `gf_mul`: eight shift-and-add rounds with mask-selected XORs.
/// AVX2 has no 8-bit shifts, so doubling uses `add_epi8` (`a + a`). The right shift of `b`
/// uses a 16-bit shift followed by a `0x7F` mask, which clears the bit that leaks in from
/// the neighbouring byte.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn gf_mul_simd(mut a: __m256i, mut b: __m256i) -> __m256i {
    let zero = _mm256_setzero_si256();
    let lsb = _mm256_set1_epi8(1);
    let low7 = _mm256_set1_epi8(0x7F);
    let reduce = _mm256_set1_epi8(SM4_GF_POLY as i8);
    let mut acc = zero;
    for _ in 0..8 {
        // 0 - 1 = 0xFF selects `a` in lanes whose multiplier bit is set.
        let take = _mm256_sub_epi8(zero, _mm256_and_si256(b, lsb));
        acc = _mm256_xor_si256(acc, _mm256_and_si256(take, a));
        // Signed 0 > a is true exactly when the lane's top bit is set.
        let carry = _mm256_cmpgt_epi8(zero, a);
        a = _mm256_xor_si256(_mm256_add_epi8(a, a), _mm256_and_si256(carry, reduce));
        b = _mm256_and_si256(_mm256_srli_epi16(b, 1), low7);
    }
    acc
}
/// Lane-wise `x^254` in GF(2^8): the inverse of every nonzero lane, with 0 mapped to 0.
///
/// Uses the same square-and-accumulate chain as the scalar `gf_inv`:
/// 254 = 2 + 4 + 8 + 16 + 32 + 64 + 128.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn gf_inv_simd(x: __m256i) -> __m256i {
    let mut square = gf_mul_simd(x, x);
    let mut acc = square;
    for _ in 2..8 {
        // Advance square to the next power x^(2^k) and fold it into the product.
        square = gf_mul_simd(square, square);
        acc = gf_mul_simd(acc, square);
    }
    acc
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn affine_a_simd(x: __m256i) -> __m256i {
let mut out = _mm256_setzero_si256();
let row0 = _mm256_set1_epi8(A_ROWS[0] as i8);
let row1 = _mm256_set1_epi8(A_ROWS[1] as i8);
let row2 = _mm256_set1_epi8(A_ROWS[2] as i8);
let row3 = _mm256_set1_epi8(A_ROWS[3] as i8);
let row4 = _mm256_set1_epi8(A_ROWS[4] as i8);
let row5 = _mm256_set1_epi8(A_ROWS[5] as i8);
let row6 = _mm256_set1_epi8(A_ROWS[6] as i8);
let row7 = _mm256_set1_epi8(A_ROWS[7] as i8);
out = _mm256_or_si256(
out,
_mm256_slli_epi16(parity_simd(_mm256_and_si256(row0, x)), 7),
);
out = _mm256_or_si256(
out,
_mm256_slli_epi16(parity_simd(_mm256_and_si256(row1, x)), 6),
);
out = _mm256_or_si256(
out,
_mm256_slli_epi16(parity_simd(_mm256_and_si256(row2, x)), 5),
);
out = _mm256_or_si256(
out,
_mm256_slli_epi16(parity_simd(_mm256_and_si256(row3, x)), 4),
);
out = _mm256_or_si256(
out,
_mm256_slli_epi16(parity_simd(_mm256_and_si256(row4, x)), 3),
);
out = _mm256_or_si256(
out,
_mm256_slli_epi16(parity_simd(_mm256_and_si256(row5, x)), 2),
);
out = _mm256_or_si256(
out,
_mm256_slli_epi16(parity_simd(_mm256_and_si256(row6, x)), 1),
);
out = _mm256_or_si256(out, parity_simd(_mm256_and_si256(row7, x)));
out
}
/// Returns, in each byte lane, 1 if that byte of `x` has an odd number of set bits and 0
/// otherwise.
///
/// The byte is XOR-folded by 4, 2, then 1 bit positions. The shifts operate on 16-bit
/// lanes, so a low byte picks up bits from its neighbour in its upper positions. Those
/// contaminated positions never feed bit 0, which is the only bit kept by the final mask.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn parity_simd(x: __m256i) -> __m256i {
    let lsb = _mm256_set1_epi8(1);
    let fold4 = _mm256_xor_si256(_mm256_srli_epi16(x, 4), x);
    let fold2 = _mm256_xor_si256(_mm256_srli_epi16(fold4, 2), fold4);
    let fold1 = _mm256_xor_si256(_mm256_srli_epi16(fold2, 1), fold2);
    _mm256_and_si256(lsb, fold1)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fills an 8-byte block with `base, base + 1, ..., base + 7`.
    fn block_from(base: u8) -> [u8; 8] {
        let mut block = [0u8; 8];
        for (i, b) in block.iter_mut().enumerate() {
            *b = base + i as u8;
        }
        block
    }

    /// Reference points from the published SM4 S-box table (row 0, columns 0 and 1).
    #[test]
    fn scalar_sbox_matches_reference_values() {
        assert_eq!(sbox_byte(0x00), 0xD6);
        assert_eq!(sbox_byte(0x01), 0x90);
    }

    #[test]
    fn gf_inv_is_multiplicative_inverse() {
        assert_eq!(gf_inv(0), 0);
        for x in 1..=u8::MAX {
            assert_eq!(gf_mul(x, gf_inv(x)), 1, "x = 0x{x:02x}");
        }
    }

    /// `A` is invertible and field inversion is a bijection, so the S-box must permute 0..=255.
    #[test]
    fn scalar_sbox_is_a_permutation() {
        let mut seen = [false; 256];
        for x in 0..=u8::MAX {
            let y = sbox_byte(x) as usize;
            assert!(!seen[y], "duplicate output 0x{y:02x}");
            seen[y] = true;
        }
    }

    #[test]
    fn scalar_mixed_lanes() {
        let input: [u8; 8] = [0x00, 0x01, 0x55, 0xAA, 0xFF, 0x80, 0x7F, 0x42];
        let out = sbox_x8_scalar(&input);
        for (lane, (&inp, &got)) in input.iter().zip(out.iter()).enumerate() {
            let expected = sbox_byte(inp);
            assert_eq!(
                got, expected,
                "lane {lane} disagrees at input 0x{inp:02x}: got 0x{got:02x}, want 0x{expected:02x}",
            );
        }
    }

    /// Exercises whichever path `sbox_x8` dispatches to on the test machine.
    #[test]
    fn dispatch_matches_scalar_exhaustively() {
        for base in (0..=u8::MAX).step_by(8) {
            let block = block_from(base);
            assert_eq!(sbox_x8(&block), sbox_x8_scalar(&block), "block base 0x{base:02x}");
        }
    }

    /// Calls the AVX2 kernel directly, so a SIMD regression cannot hide behind dispatch.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn avx2_matches_scalar_exhaustively() {
        if !has_avx2() {
            return;
        }
        for base in (0..=u8::MAX).step_by(8) {
            let block = block_from(base);
            // SAFETY: AVX2 support was confirmed at runtime just above.
            let simd = unsafe { sbox_x8_avx2(&block) };
            assert_eq!(simd, sbox_x8_scalar(&block), "block base 0x{base:02x}");
        }
    }
}