use core::hint::black_box;
use core::ptr;
use core::sync::atomic::{compiler_fence, Ordering};
#[cfg(feature = "ct_profile")]
use std::time::Instant;
/// Call counters for the constant-time ANF helpers.
///
/// This variant is compiled only when the `ct_profile` feature is enabled.
/// All counter accesses use `Ordering::Relaxed`.
#[cfg(feature = "ct_profile")]
mod profile {
    use core::sync::atomic::{AtomicU64, Ordering};

    /// Number of `subset_mask8` calls recorded since the last `reset`.
    static SUBSET_MASK8_CALLS: AtomicU64 = AtomicU64::new(0);
    /// Number of `parity128` calls recorded since the last `reset`.
    static PARITY128_CALLS: AtomicU64 = AtomicU64::new(0);
    /// Number of `eval_byte_sbox` calls recorded since the last `reset`.
    static EVAL_BYTE_SBOX_CALLS: AtomicU64 = AtomicU64::new(0);

    /// Adds one to `counter`.
    #[inline(always)]
    fn tick(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Records one call to `subset_mask8`.
    #[inline(always)]
    pub(super) fn bump_subset_mask8() {
        tick(&SUBSET_MASK8_CALLS);
    }

    /// Records one call to `parity128`.
    #[inline(always)]
    pub(super) fn bump_parity128() {
        tick(&PARITY128_CALLS);
    }

    /// Records one call to `eval_byte_sbox`.
    #[inline(always)]
    pub(super) fn bump_eval_byte_sbox() {
        tick(&EVAL_BYTE_SBOX_CALLS);
    }

    /// Sets every counter back to zero.
    pub(super) fn reset() {
        for counter in [&SUBSET_MASK8_CALLS, &PARITY128_CALLS, &EVAL_BYTE_SBOX_CALLS] {
            counter.store(0, Ordering::Relaxed);
        }
    }

    /// Reads the current value of every counter.
    ///
    /// Each counter is loaded separately, so the three values are not taken
    /// at a single instant if other threads are bumping them concurrently.
    pub(super) fn snapshot() -> super::CtAnfProfile {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        super::CtAnfProfile {
            subset_mask8_calls: read(&SUBSET_MASK8_CALLS),
            parity128_calls: read(&PARITY128_CALLS),
            eval_byte_sbox_calls: read(&EVAL_BYTE_SBOX_CALLS),
        }
    }
}
// No-op stand-ins for the profiling hooks, compiled when the `ct_profile`
// feature is disabled so that callers can invoke `profile::bump_*`
// unconditionally.
#[cfg(not(feature = "ct_profile"))]
mod profile {
/// Does nothing; profiling is disabled in this build.
#[inline(always)]
pub(super) fn bump_subset_mask8() {}
/// Does nothing; profiling is disabled in this build.
#[inline(always)]
pub(super) fn bump_parity128() {}
/// Does nothing; profiling is disabled in this build.
#[inline(always)]
pub(super) fn bump_eval_byte_sbox() {}
}
/// Call counts of the profiled ANF helpers, as returned by
/// [`ct_profile_snapshot`].
///
/// Counts include nested calls: `eval_byte_sbox` itself calls `subset_mask8`
/// once and `parity128` eight times per invocation.
#[cfg(feature = "ct_profile")]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CtAnfProfile {
/// Number of `subset_mask8` calls recorded.
pub subset_mask8_calls: u64,
/// Number of `parity128` calls recorded.
pub parity128_calls: u64,
/// Number of `eval_byte_sbox` calls recorded.
pub eval_byte_sbox_calls: u64,
}
/// Average per-call wall-clock cost, in nanoseconds, of each profiled helper
/// as measured by [`ct_profile_measure_helper_costs`].
#[cfg(feature = "ct_profile")]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CtAnfHelperCostsNs {
/// Average nanoseconds per `subset_mask8` call.
pub subset_mask8_ns: f64,
/// Average nanoseconds per `parity128` call.
pub parity128_ns: f64,
/// Average nanoseconds per `eval_byte_sbox` call.
pub eval_byte_sbox_ns: f64,
}
/// Resets all helper call counters to zero.
///
/// Only available when the `ct_profile` feature is enabled.
#[cfg(feature = "ct_profile")]
pub fn ct_profile_reset() {
profile::reset();
}
/// Returns the current values of all helper call counters.
///
/// Only available when the `ct_profile` feature is enabled.
#[cfg(feature = "ct_profile")]
#[must_use]
pub fn ct_profile_snapshot() -> CtAnfProfile {
profile::snapshot()
}
/// Measures the average wall-clock cost per call of `subset_mask8`,
/// `parity128` and `eval_byte_sbox`.
///
/// Each helper is called `iterations` times in a tight loop with varying
/// inputs; the elapsed time is divided by `iterations` and reported in
/// nanoseconds. `eval_byte_sbox` is timed against the ANF of the identity
/// table (`table[i] == i`).
///
/// When `iterations` is zero nothing is timed and all costs are reported as
/// `0.0` (dividing by zero iterations would otherwise yield NaN or infinity).
///
/// Note that the measured helpers bump the profiling counters, so calling
/// this function changes what [`ct_profile_snapshot`] reports; call
/// [`ct_profile_reset`] afterwards if clean counts are needed.
#[cfg(feature = "ct_profile")]
#[must_use]
pub fn ct_profile_measure_helper_costs(iterations: u64) -> CtAnfHelperCostsNs {
    if iterations == 0 {
        return CtAnfHelperCostsNs::default();
    }
    let runs = iterations as f64;
    // Every result is folded into `acc`, which is passed through `black_box`
    // at the end so the timed loops cannot be optimised away as dead code.
    let mut acc = 0u64;

    let mut input = 0u8;
    let t_subset = Instant::now();
    for _ in 0..iterations {
        let (lo, hi) = subset_mask8(input);
        acc ^= (lo as u64) ^ ((hi >> 64) as u64);
        input = input.wrapping_add(1);
    }
    let subset_mask8_ns = t_subset.elapsed().as_secs_f64() * 1e9 / runs;

    let t_parity = Instant::now();
    let mut x = 0x0123_4567_89ab_cdef_0011_2233_4455_6677u128;
    for _ in 0..iterations {
        acc ^= u64::from(parity128(x));
        // Scramble the operand so successive calls see different bit patterns.
        x = x.rotate_left(13) ^ 0x9e37_79b9_7f4a_7c15_6a09_e667_f3bc_c909u128;
    }
    let parity128_ns = t_parity.elapsed().as_secs_f64() * 1e9 / runs;

    // Identity table; built outside the timed region.
    let mut table = [0u8; 256];
    let mut i = 0usize;
    while i < 256 {
        table[i] = i as u8;
        i += 1;
    }
    let coeffs = build_byte_sbox_anf(&table);

    let t_eval = Instant::now();
    let mut y = 0u8;
    for _ in 0..iterations {
        y = y.wrapping_add(17);
        acc ^= u64::from(eval_byte_sbox(&coeffs, y));
    }
    let eval_byte_sbox_ns = t_eval.elapsed().as_secs_f64() * 1e9 / runs;

    black_box(acc);
    CtAnfHelperCostsNs {
        subset_mask8_ns,
        parity128_ns,
        eval_byte_sbox_ns,
    }
}
/// Branch-free equality mask: returns `u32::MAX` when `a == b`, `0` otherwise.
#[inline]
fn eq_mask_u32(a: u8, b: u8) -> u32 {
    let diff = u32::from(a ^ b);
    // For any non-zero `diff`, either `diff` or its two's-complement negation
    // has the top bit set, so `differs` is 1 exactly when the inputs differ.
    let differs = (diff | diff.wrapping_neg()) >> 31;
    // 0 - 1 wraps to all ones (equal); 1 - 1 is zero (different).
    differs.wrapping_sub(1)
}
/// Branch-free equality mask: returns `0xff` when `a == b`, `0x00` otherwise.
#[inline]
fn eq_mask_u8(a: u8, b: u8) -> u8 {
    let diff = a ^ b;
    // For any non-zero `diff`, either `diff` or its two's-complement negation
    // has the top bit set, so `differs` is 1 exactly when the inputs differ.
    let differs = (diff | diff.wrapping_neg()) >> 7;
    // 0 - 1 wraps to 0xff (equal); 1 - 1 is zero (different).
    differs.wrapping_sub(1)
}
/// Overwrites every element of `slice` with `T::default()`.
///
/// Each element is stored with a volatile write and the loop is followed by a
/// sequentially consistent compiler fence. The result is all-zero only for
/// types whose `Default` value is zero (as it is for the primitive integers).
pub fn zeroize_slice<T: Copy + Default>(slice: &mut [T]) {
    slice.iter_mut().for_each(|slot| {
        let dst: *mut T = slot;
        // SAFETY: `dst` comes from a live `&mut T`, so it is valid for writes
        // and properly aligned; `T: Copy` means the overwritten value has no
        // destructor that would need to run.
        unsafe { ptr::write_volatile(dst, T::default()) };
    });
    compiler_fence(Ordering::SeqCst);
}
/// Returns `table[idx]` by scanning the whole table.
///
/// Every one of the 256 entries is read and masked with an equality mask on
/// its position, so the sequence of table accesses does not depend on `idx`.
pub(crate) fn ct_lookup_u32(table: &[u32; 256], idx: u8) -> u32 {
    table
        .iter()
        .zip(0u8..=u8::MAX)
        .fold(0u32, |selected, (&entry, position)| {
            selected | (entry & eq_mask_u32(position, idx))
        })
}
/// Returns `table[idx]` by scanning the whole 16-entry table.
///
/// Every entry is read and masked with an equality mask on its position, so
/// the sequence of table accesses does not depend on `idx`. An `idx` of 16 or
/// more matches no position and yields `0`.
pub(crate) fn ct_lookup_u8_16(table: &[u8; 16], idx: u8) -> u8 {
    table
        .iter()
        .zip(0u8..16)
        .fold(0u8, |selected, (&entry, position)| {
            selected | (entry & eq_mask_u8(position, idx))
        })
}
/// Compares two byte slices and returns the result as a mask.
///
/// Returns `0xff` when `a` and `b` have the same length and contents, and
/// `0x00` otherwise. A length mismatch returns immediately; for equal lengths
/// every byte pair is examined, with no early exit on the first difference.
#[inline]
pub(crate) fn constant_time_eq_mask(a: &[u8], b: &[u8]) -> u8 {
    if a.len() != b.len() {
        return 0;
    }
    // OR together the XOR of every byte pair: zero iff all pairs match.
    let accumulated = a
        .iter()
        .zip(b)
        .fold(0u8, |seen, (&lhs, &rhs)| seen | (lhs ^ rhs));
    // Hide the accumulated value from the optimiser before turning it into a
    // mask.
    let accumulated = black_box(accumulated);
    compiler_fence(Ordering::SeqCst);
    eq_mask_u8(accumulated, 0)
}
/// Builds the algebraic normal form (ANF) of an 8-bit lookup table.
///
/// For each output bit `b`, element `b` of the result holds 256 ANF
/// coefficients packed as `[lo, hi]`: bit `m` of `lo` is the coefficient of
/// monomial `m` for `m < 128`, and bit `m - 128` of `hi` for `m >= 128`.
/// Monomial `m` is the product of the input bits selected by the set bits of
/// `m`.
pub(crate) const fn build_byte_sbox_anf(table: &[u8; 256]) -> [[u128; 2]; 8] {
    let mut planes = [[0u128; 2]; 8];
    let mut plane = 0usize;
    while plane < 8 {
        // Truth table of output bit `plane`.
        let mut anf = [0u8; 256];
        let mut input = 0usize;
        while input < 256 {
            anf[input] = (table[input] >> plane) & 1;
            input += 1;
        }

        // In-place binary Moebius transform: for every variable (`half` is its
        // bit value) XOR the entry without that bit into the entry with it.
        let mut half = 1usize;
        while half < 256 {
            let mut base = 0usize;
            while base < 256 {
                let mut offset = 0usize;
                while offset < half {
                    anf[base + half + offset] ^= anf[base + offset];
                    offset += 1;
                }
                base += 2 * half;
            }
            half <<= 1;
        }

        // Pack the 256 coefficients into two 128-bit words.
        let mut monomial = 0usize;
        while monomial < 256 {
            planes[plane][monomial >> 7] |= (anf[monomial] as u128) << (monomial & 127);
            monomial += 1;
        }
        plane += 1;
    }
    planes
}
/// Builds the algebraic normal form (ANF) of a 4-bit lookup table.
///
/// For each of the four low output bits `b`, element `b` of the result holds
/// 16 ANF coefficients: bit `m` is the coefficient of monomial `m`, the
/// product of the input bits selected by the set bits of `m`. Bits 4..=7 of
/// the table entries are ignored.
pub(crate) const fn build_nibble_sbox_anf(table: &[u8; 16]) -> [u16; 4] {
    let mut planes = [0u16; 4];
    let mut plane = 0usize;
    while plane < 4 {
        // Truth table of output bit `plane`.
        let mut anf = [0u8; 16];
        let mut input = 0usize;
        while input < 16 {
            anf[input] = (table[input] >> plane) & 1;
            input += 1;
        }

        // In-place binary Moebius transform: for every variable (`half` is its
        // bit value) XOR the entry without that bit into the entry with it.
        let mut half = 1usize;
        while half < 16 {
            let mut base = 0usize;
            while base < 16 {
                let mut offset = 0usize;
                while offset < half {
                    anf[base + half + offset] ^= anf[base + offset];
                    offset += 1;
                }
                base += 2 * half;
            }
            half <<= 1;
        }

        // Pack the 16 coefficients into one word.
        let mut monomial = 0usize;
        while monomial < 16 {
            planes[plane] |= (anf[monomial] as u16) << monomial;
            monomial += 1;
        }
        plane += 1;
    }
    planes
}
/// Returns the 256-bit "submask" indicator of `x` as `(lo, hi)`.
///
/// Bit `m` of the 256-bit result (bits 0..=127 in `lo`, bits 128..=255 in
/// `hi`) is set exactly when `m & !x == 0`, i.e. when every set bit of `m` is
/// also set in `x`. The computation is branch-free: each input bit is turned
/// into an all-zeros / all-ones mask that selects a shifted copy.
#[inline]
pub(crate) fn subset_mask8(x: u8) -> (u128, u128) {
    profile::bump_subset_mask8();
    // Start with only the empty submask (bit 0) and, for each input bit `k`,
    // conditionally OR in the current set shifted left by `2^k`.
    //
    // Before step `k` every set bit of `lo` lies below `2^k`, so for
    // `k <= 6` the shifted copy stays below bit 128: nothing ever carries
    // into the high word, which is why no high-word arithmetic is needed
    // here.
    let mut lo = 1u128;
    for bit in 0..7u32 {
        let select = 0u128.wrapping_sub(u128::from((x >> bit) & 1));
        lo |= (lo << (1u32 << bit)) & select;
    }
    // Input bit 7 shifts by exactly 128 bits: the high word is a masked copy
    // of the completed low word.
    let select_hi = 0u128.wrapping_sub(u128::from((x >> 7) & 1));
    let hi = lo & select_hi;
    (lo, hi)
}
/// Returns the parity of `x`: `1` if an odd number of bits are set, else `0`.
#[inline]
pub(crate) fn parity128(x: u128) -> u8 {
    profile::bump_parity128();
    // XOR-folding the two 64-bit halves preserves parity, leaving a single
    // population count to take.
    let folded = (x ^ (x >> 64)) as u64;
    (folded.count_ones() & 1) as u8
}
/// Evaluates an 8-bit function from its packed ANF coefficients.
///
/// `coeffs` uses the layout produced by `build_byte_sbox_anf`. Output bit `b`
/// is the parity of the coefficients of plane `b` restricted to the monomials
/// that are submasks of `input`.
#[inline]
pub(crate) fn eval_byte_sbox(coeffs: &[[u128; 2]; 8], input: u8) -> u8 {
    profile::bump_eval_byte_sbox();
    let (active_lo, active_hi) = subset_mask8(input);
    coeffs
        .iter()
        .enumerate()
        .fold(0u8, |result, (position, plane)| {
            let [plane_lo, plane_hi] = *plane;
            // The parity of the 256 selected bits equals the parity of the
            // XOR of the two selected 128-bit halves.
            let bit = parity128((active_lo & plane_lo) ^ (active_hi & plane_hi));
            result | (bit << position)
        })
}
/// Returns the 16-bit "submask" indicator of the low nibble of `x`.
///
/// Bit `m` of the result is set exactly when every set bit of `m` is also set
/// in the low four bits of `x`. Bits 4..=7 of `x` are ignored. The
/// computation is branch-free.
#[inline]
pub(crate) fn subset_mask4(x: u8) -> u16 {
    // Start with only the empty submask (bit 0) and, for each input bit `k`,
    // conditionally OR in the current set shifted left by `2^k`.
    let mut submasks = 1u16;
    for bit in 0..4u32 {
        let select = u16::from((x >> bit) & 1).wrapping_neg();
        submasks |= (submasks << (1u32 << bit)) & select;
    }
    submasks
}
/// Returns the parity of `x`: `1` if an odd number of bits are set, else `0`.
#[inline]
pub(crate) fn parity16(x: u16) -> u8 {
    /// Bit `n` of this constant is the parity of the 4-bit value `n`.
    const NIBBLE_PARITY: u16 = 0x6996;
    // XOR-fold 16 bits down to 4; folding preserves parity.
    let byte = x ^ (x >> 8);
    let nibble = (byte ^ (byte >> 4)) & 0x0f;
    ((NIBBLE_PARITY >> nibble) & 1) as u8
}
/// Evaluates a 4-bit function from its packed ANF coefficients.
///
/// `coeffs` uses the layout produced by `build_nibble_sbox_anf`. Output bit
/// `b` is the parity of the coefficients of plane `b` restricted to the
/// monomials that are submasks of the low nibble of `input`.
#[inline]
pub(crate) fn eval_nibble_sbox(coeffs: [u16; 4], input: u8) -> u8 {
    let active = subset_mask4(input);
    coeffs
        .iter()
        .enumerate()
        .fold(0u8, |result, (position, &plane)| {
            result | (parity16(active & plane) << position)
        })
}
// Unit tests for the submask, parity, lookup and ANF helpers.
#[cfg(test)]
mod tests {
use super::*;
// Input 0 has only the empty submask (bit 0); input 0xff has all 256.
#[test]
fn subset_mask8_zero_and_all_ones() {
let (lo0, hi0) = subset_mask8(0);
assert_eq!(lo0, 1);
assert_eq!(hi0, 0);
let (lof, hif) = subset_mask8(0xff);
assert_eq!(lof, u128::MAX);
assert_eq!(hif, u128::MAX);
}
// A single set input bit k yields exactly the submasks 0 and 2^k.
#[test]
fn subset_mask8_single_bit() {
let (lo, hi) = subset_mask8(0x01);
assert_eq!(lo, 0b11);
assert_eq!(hi, 0);
// Input bit 7 corresponds to submask 128, i.e. bit 0 of the high word.
let (lo, hi) = subset_mask8(0x80);
assert_eq!(lo, 1); assert_eq!(hi, 1); }
// Spot checks of both parity helpers on values with known bit counts.
#[test]
fn parity_helpers_known_values() {
assert_eq!(parity128(0), 0);
assert_eq!(parity128(1), 1);
assert_eq!(parity128(0b1011), 1);
assert_eq!(parity128(u128::MAX), 0);
assert_eq!(parity16(0), 0);
assert_eq!(parity16(1), 1);
assert_eq!(parity16(0b1011), 1);
assert_eq!(parity16(0xffff), 0); }
// With an identity table, the scanning lookup must return the index itself.
#[test]
fn ct_lookup_u8_16_picks_exact_entry() {
let table = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
for i in 0u8..16 {
assert_eq!(ct_lookup_u8_16(&table, i), i);
}
}
// Round trip: building the ANF of the `AES_SBOX` table and evaluating it
// must reproduce the table for every one of the 256 inputs.
#[test]
fn byte_sbox_anf_matches_direct_lookup_all_inputs() {
#[rustfmt::skip]
const AES_SBOX: [u8; 256] = [
0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16,
];
let coeffs = build_byte_sbox_anf(&AES_SBOX);
for x in 0u8..=255 {
assert_eq!(
eval_byte_sbox(&coeffs, x),
AES_SBOX[x as usize],
"ANF mismatch at input {x:#04x}"
);
}
}
// Round trip: building the ANF of the `PRESENT_SBOX` table and evaluating it
// must reproduce the table for every one of the 16 inputs.
#[test]
fn nibble_sbox_anf_matches_direct_lookup_all_inputs() {
const PRESENT_SBOX: [u8; 16] =
[0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2];
let coeffs = build_nibble_sbox_anf(&PRESENT_SBOX);
for x in 0u8..16 {
assert_eq!(
eval_nibble_sbox(coeffs, x),
PRESENT_SBOX[x as usize],
"ANF mismatch at input {x:#03x}"
);
}
}
}