1use core::hint::black_box;
22use core::ptr;
23use core::sync::atomic::{compiler_fence, Ordering};
24
25#[cfg(feature = "ct_profile")]
26use std::time::Instant;
27
#[cfg(feature = "ct_profile")]
/// Call counters for the ANF helpers, compiled in only with the `ct_profile` feature.
///
/// Every counter is a process-wide `AtomicU64` accessed with `Ordering::Relaxed`.
mod profile {
    use core::sync::atomic::{AtomicU64, Ordering};

    use super::CtAnfProfile;

    static SUBSET_MASK8_CALLS: AtomicU64 = AtomicU64::new(0);
    static PARITY128_CALLS: AtomicU64 = AtomicU64::new(0);
    static EVAL_BYTE_SBOX_CALLS: AtomicU64 = AtomicU64::new(0);

    /// Adds one to `counter`.
    #[inline(always)]
    fn bump(counter: &AtomicU64) {
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Reads the current value of `counter`.
    #[inline(always)]
    fn read(counter: &AtomicU64) -> u64 {
        counter.load(Ordering::Relaxed)
    }

    /// Records one call to `subset_mask8`.
    #[inline(always)]
    pub(super) fn bump_subset_mask8() {
        bump(&SUBSET_MASK8_CALLS);
    }

    /// Records one call to `parity128`.
    #[inline(always)]
    pub(super) fn bump_parity128() {
        bump(&PARITY128_CALLS);
    }

    /// Records one call to `eval_byte_sbox`.
    #[inline(always)]
    pub(super) fn bump_eval_byte_sbox() {
        bump(&EVAL_BYTE_SBOX_CALLS);
    }

    /// Sets all three counters back to zero.
    pub(super) fn reset() {
        for counter in [&SUBSET_MASK8_CALLS, &PARITY128_CALLS, &EVAL_BYTE_SBOX_CALLS] {
            counter.store(0, Ordering::Relaxed);
        }
    }

    /// Copies the current counter values into a `CtAnfProfile`.
    ///
    /// The three loads are independent relaxed reads, so the snapshot is not
    /// taken atomically across counters.
    pub(super) fn snapshot() -> CtAnfProfile {
        let subset_mask8_calls = read(&SUBSET_MASK8_CALLS);
        let parity128_calls = read(&PARITY128_CALLS);
        let eval_byte_sbox_calls = read(&EVAL_BYTE_SBOX_CALLS);
        CtAnfProfile {
            subset_mask8_calls,
            parity128_calls,
            eval_byte_sbox_calls,
        }
    }
}
65
#[cfg(not(feature = "ct_profile"))]
/// No-op counterparts of the profiling hooks, used when the `ct_profile`
/// feature is disabled so call sites need no `cfg` of their own.
mod profile {
    /// Does nothing; profiling is disabled.
    #[inline(always)]
    pub(super) fn bump_subset_mask8() {}

    /// Does nothing; profiling is disabled.
    #[inline(always)]
    pub(super) fn bump_parity128() {}

    /// Does nothing; profiling is disabled.
    #[inline(always)]
    pub(super) fn bump_eval_byte_sbox() {}
}
77
/// Snapshot of the helper call counters maintained by the `profile` module.
///
/// `eval_byte_sbox` itself calls `subset_mask8` once and `parity128` eight
/// times, so those nested calls are included in the first two counts.
#[cfg(feature = "ct_profile")]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CtAnfProfile {
    /// Number of `subset_mask8` calls recorded since the counters were last reset.
    pub subset_mask8_calls: u64,
    /// Number of `parity128` calls recorded since the counters were last reset.
    pub parity128_calls: u64,
    /// Number of `eval_byte_sbox` calls recorded since the counters were last reset.
    pub eval_byte_sbox_calls: u64,
}
85
/// Average per-call wall-clock cost of each helper, in nanoseconds, as
/// produced by `ct_profile_measure_helper_costs`.
#[cfg(feature = "ct_profile")]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct CtAnfHelperCostsNs {
    /// Mean nanoseconds per `subset_mask8` call.
    pub subset_mask8_ns: f64,
    /// Mean nanoseconds per `parity128` call.
    pub parity128_ns: f64,
    /// Mean nanoseconds per `eval_byte_sbox` call.
    pub eval_byte_sbox_ns: f64,
}
93
/// Resets every helper call counter to zero.
#[cfg(feature = "ct_profile")]
pub fn ct_profile_reset() {
    profile::reset();
}
98
/// Returns the current values of the helper call counters.
#[cfg(feature = "ct_profile")]
#[must_use]
pub fn ct_profile_snapshot() -> CtAnfProfile {
    profile::snapshot()
}
104
/// Times `iterations` back-to-back calls of each ANF helper and reports the
/// mean wall-clock cost per call in nanoseconds.
///
/// The three helpers (`subset_mask8`, `parity128`, `eval_byte_sbox`) are
/// measured in separate loops. `eval_byte_sbox` is evaluated against the ANF
/// of the identity table. All results are folded into an accumulator that is
/// passed through `black_box` so the loops are not optimized away.
///
/// When `iterations` is zero there is nothing to average, so all costs are
/// reported as `0.0` rather than the ∞/NaN a division by zero would produce.
///
/// Side effect: the measured helpers bump the profiling counters, so the
/// counts returned by `ct_profile_snapshot` include these calls.
#[cfg(feature = "ct_profile")]
#[must_use]
pub fn ct_profile_measure_helper_costs(iterations: u64) -> CtAnfHelperCostsNs {
    if iterations == 0 {
        return CtAnfHelperCostsNs::default();
    }

    // Mean nanoseconds per iteration since `started`.
    let per_call_ns = |started: Instant| started.elapsed().as_secs_f64() * 1e9 / iterations as f64;

    let mut input = 0u8;
    let mut acc = 0u64;

    let t_subset = Instant::now();
    for _ in 0..iterations {
        let (lo, hi) = subset_mask8(input);
        acc ^= (lo as u64) ^ ((hi >> 64) as u64);
        input = input.wrapping_add(1);
    }
    let subset_mask8_ns = per_call_ns(t_subset);

    let t_parity = Instant::now();
    let mut x = 0x0123_4567_89ab_cdef_0011_2233_4455_6677u128;
    for _ in 0..iterations {
        acc ^= u64::from(parity128(x));
        // Cheap scrambling step so consecutive inputs differ.
        x = x.rotate_left(13) ^ 0x9e37_79b9_7f4a_7c15_6a09_e667_f3bc_c909u128;
    }
    let parity128_ns = per_call_ns(t_parity);

    // Identity table: table[i] == i. Built outside the timed region.
    let mut table = [0u8; 256];
    for (slot, value) in table.iter_mut().zip(0u8..=u8::MAX) {
        *slot = value;
    }
    let coeffs = build_byte_sbox_anf(&table);

    let t_eval = Instant::now();
    let mut y = 0u8;
    for _ in 0..iterations {
        y = y.wrapping_add(17);
        acc ^= u64::from(eval_byte_sbox(&coeffs, y));
    }
    let eval_byte_sbox_ns = per_call_ns(t_eval);

    black_box(acc);
    CtAnfHelperCostsNs {
        subset_mask8_ns,
        parity128_ns,
        eval_byte_sbox_ns,
    }
}
149
/// Branch-free equality mask: `0xFFFF_FFFF` when `a == b`, `0` otherwise.
#[inline]
fn eq_mask_u32(a: u8, b: u8) -> u32 {
    let delta = u32::from(a ^ b);
    // `delta - 1` wraps to all-ones only for `delta == 0`; for 1..=255 the
    // result stays below 256, so bit 8 is set exactly when the bytes match.
    let matched = (delta.wrapping_sub(1) >> 8) & 1;
    matched.wrapping_neg()
}
170
/// Branch-free equality mask: `0xFF` when `a == b`, `0x00` otherwise.
#[inline]
fn eq_mask_u8(a: u8, b: u8) -> u8 {
    let delta = u16::from(a ^ b);
    // Subtracting one borrows into the high byte only when `delta` is zero.
    let matched = (delta.wrapping_sub(1) >> 8) as u8 & 1;
    matched.wrapping_neg()
}
181
/// Overwrites every element of `slice` with `T::default()` using volatile
/// writes, then issues a sequentially consistent compiler fence.
///
/// The volatile stores and the fence are there so the compiler does not drop
/// or reorder the wipe. For the integer types `T::default()` is zero.
pub fn zeroize_slice<T: Copy + Default>(slice: &mut [T]) {
    for slot in slice {
        let target: *mut T = ptr::from_mut(slot);
        // SAFETY: `target` comes from a live `&mut T`, so it is valid for
        // writes and properly aligned; `T: Copy` means no destructor is skipped.
        unsafe { target.write_volatile(T::default()) };
    }
    compiler_fence(Ordering::SeqCst);
}
202
203pub(crate) fn ct_lookup_u32(table: &[u32; 256], idx: u8) -> u32 {
213 let mut out = 0u32;
214 let mut i = 0usize;
215 while i < 256 {
216 let table_index = i as u8;
217 out |= table[i] & eq_mask_u32(table_index, idx);
218 i += 1;
219 }
220 out
221}
222
223pub(crate) fn ct_lookup_u8_16(table: &[u8; 16], idx: u8) -> u8 {
228 let mut out = 0u8;
229 let mut i = 0usize;
230 while i < 16 {
231 let table_index = i as u8;
232 out |= table[i] & eq_mask_u8(table_index, idx);
233 i += 1;
234 }
235 out
236}
237
238#[inline]
250pub(crate) fn constant_time_eq_mask(a: &[u8], b: &[u8]) -> u8 {
251 if a.len() != b.len() {
252 return 0;
253 }
254 let mut diff = 0u8;
255 for (x, y) in a.iter().zip(b.iter()) {
256 diff |= *x ^ *y;
257 }
258 let diff = black_box(diff);
261 compiler_fence(Ordering::SeqCst);
262 eq_mask_u8(diff, 0)
264}
265
/// Builds the algebraic normal form (ANF) of an 8-bit S-box at compile time.
///
/// For each output bit `b`, the result `[b]` holds 256 coefficient bits:
/// the coefficient of monomial `m` (the product of the input bits set in `m`)
/// is bit `m` of `[b][0]` for `m < 128` and bit `m - 128` of `[b][1]` otherwise.
pub(crate) const fn build_byte_sbox_anf(table: &[u8; 256]) -> [[u128; 2]; 8] {
    let mut packed = [[0u128; 2]; 8];
    let mut plane = 0usize;
    while plane < 8 {
        // Truth table of output bit `plane`, one 0/1 entry per input value.
        let mut anf = [0u8; 256];
        let mut input = 0usize;
        while input < 256 {
            anf[input] = (table[input] >> plane) & 1;
            input += 1;
        }

        // In-place Möbius transform: one pass per input variable, where
        // `step` is that variable's bit (1, 2, 4, ..., 128).
        let mut step = 1usize;
        while step < 256 {
            let mut index = 0usize;
            while index < 256 {
                if index & step != 0 {
                    anf[index] ^= anf[index ^ step];
                }
                index += 1;
            }
            step <<= 1;
        }

        // Pack the coefficients: the top bit of the monomial index picks the
        // word, the low seven bits pick the bit position.
        let mut term = 0usize;
        while term < 256 {
            packed[plane][term >> 7] |= (anf[term] as u128) << (term & 127);
            term += 1;
        }

        plane += 1;
    }
    packed
}
327
/// Builds the algebraic normal form (ANF) of a 4-bit S-box at compile time.
///
/// Element `b` of the result describes output bit `b`: bit `m` of that word
/// is the coefficient of monomial `m` (the product of the input bits set in `m`).
/// Only the low four bits of each table entry are used.
pub(crate) const fn build_nibble_sbox_anf(table: &[u8; 16]) -> [u16; 4] {
    let mut packed = [0u16; 4];
    let mut plane = 0usize;
    while plane < 4 {
        // Truth table of output bit `plane`.
        let mut anf = [0u8; 16];
        let mut input = 0usize;
        while input < 16 {
            anf[input] = (table[input] >> plane) & 1;
            input += 1;
        }

        // In-place Möbius transform, one pass per input variable (`step` = 1, 2, 4, 8).
        let mut step = 1usize;
        while step < 16 {
            let mut index = 0usize;
            while index < 16 {
                if index & step != 0 {
                    anf[index] ^= anf[index ^ step];
                }
                index += 1;
            }
            step <<= 1;
        }

        // One coefficient bit per monomial index.
        let mut term = 0usize;
        while term < 16 {
            packed[plane] |= (anf[term] as u16) << term;
            term += 1;
        }

        plane += 1;
    }
    packed
}
369
370#[inline]
385pub(crate) fn subset_mask8(x: u8) -> (u128, u128) {
386 profile::bump_subset_mask8();
387 let mut lo = 1u128;
396 let mut hi = 0u128;
397
398 let mask0 = 0u128.wrapping_sub(u128::from(x & 1));
400 let add_lo = lo << 1;
401 let add_hi = (hi << 1) | (lo >> 127);
402 lo |= add_lo & mask0;
403 hi |= add_hi & mask0;
404
405 let mask1 = 0u128.wrapping_sub(u128::from((x >> 1) & 1));
407 let add_lo = lo << 2;
408 let add_hi = (hi << 2) | (lo >> 126);
409 lo |= add_lo & mask1;
410 hi |= add_hi & mask1;
411
412 let mask2 = 0u128.wrapping_sub(u128::from((x >> 2) & 1));
414 let add_lo = lo << 4;
415 let add_hi = (hi << 4) | (lo >> 124);
416 lo |= add_lo & mask2;
417 hi |= add_hi & mask2;
418
419 let mask3 = 0u128.wrapping_sub(u128::from((x >> 3) & 1));
421 let add_lo = lo << 8;
422 let add_hi = (hi << 8) | (lo >> 120);
423 lo |= add_lo & mask3;
424 hi |= add_hi & mask3;
425
426 let mask4 = 0u128.wrapping_sub(u128::from((x >> 4) & 1));
428 let add_lo = lo << 16;
429 let add_hi = (hi << 16) | (lo >> 112);
430 lo |= add_lo & mask4;
431 hi |= add_hi & mask4;
432
433 let mask5 = 0u128.wrapping_sub(u128::from((x >> 5) & 1));
435 let add_lo = lo << 32;
436 let add_hi = (hi << 32) | (lo >> 96);
437 lo |= add_lo & mask5;
438 hi |= add_hi & mask5;
439
440 let mask6 = 0u128.wrapping_sub(u128::from((x >> 6) & 1));
442 let add_lo = lo << 64;
443 let add_hi = (hi << 64) | (lo >> 64);
444 lo |= add_lo & mask6;
445 hi |= add_hi & mask6;
446
447 let mask7 = 0u128.wrapping_sub(u128::from((x >> 7) & 1));
450 hi |= lo & mask7;
451
452 (lo, hi)
453}
454
455#[inline]
468pub(crate) fn parity128(x: u128) -> u8 {
469 profile::bump_parity128();
470 let lo = x as u64;
471 let hi = (x >> 64) as u64;
472 ((lo.count_ones() ^ hi.count_ones()) & 1) as u8
473}
474
475#[inline]
487pub(crate) fn eval_byte_sbox(coeffs: &[[u128; 2]; 8], input: u8) -> u8 {
488 profile::bump_eval_byte_sbox();
489 let (active_lo, active_hi) = subset_mask8(input);
490 let mut out = 0u8;
491 let mut bit_idx = 0usize;
492 while bit_idx < 8 {
493 let coeff_lo = coeffs[bit_idx][0];
494 let coeff_hi = coeffs[bit_idx][1];
495 let bit = parity128((active_lo & coeff_lo) ^ (active_hi & coeff_hi));
497 out |= bit << bit_idx;
498 bit_idx += 1;
499 }
500 out
501}
502
/// Returns a 16-bit mask with bit `m` set exactly when every bit set in `m`
/// is also set in the low nibble of `x`. The upper four bits of `x` are ignored.
#[inline]
pub(crate) fn subset_mask4(x: u8) -> u16 {
    let mut mask = 1u16;
    for var in 0..4u32 {
        // All-ones when input bit `var` is set, all-zeros otherwise.
        let take = u16::from((x >> var) & 1).wrapping_neg();
        mask |= (mask << (1u32 << var)) & take;
    }
    mask
}
529
/// Returns the parity of `x`: `1` if an odd number of bits are set, else `0`.
#[inline]
pub(crate) fn parity16(mut x: u16) -> u8 {
    // Halve the width with XOR folds until bit 0 holds the parity of all 16 bits.
    x ^= x >> 8;
    x ^= x >> 4;
    x ^= x >> 2;
    x ^= x >> 1;
    (x & 1) as u8
}
543
544#[inline]
549pub(crate) fn eval_nibble_sbox(coeffs: [u16; 4], input: u8) -> u8 {
550 let active = subset_mask4(input);
551 let mut out = 0u8;
552 let mut bit = 0usize;
553 while bit < 4 {
554 out |= parity16(active & coeffs[bit]) << bit;
555 bit += 1;
556 }
557 out
558}
559
/// Unit tests for the mask, parity, lookup and ANF helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// The empty input activates only the constant monomial; `0xff` activates all 256.
    #[test]
    fn subset_mask8_zero_and_all_ones() {
        let (lo0, hi0) = subset_mask8(0);
        assert_eq!(lo0, 1);
        assert_eq!(hi0, 0);

        let (lof, hif) = subset_mask8(0xff);
        assert_eq!(lof, u128::MAX);
        assert_eq!(hif, u128::MAX);
    }

    /// A single input bit activates monomial 0 and the monomial for that bit.
    #[test]
    fn subset_mask8_single_bit() {
        // Bit 0 set: monomials 0 and 1.
        let (lo, hi) = subset_mask8(0x01);
        assert_eq!(lo, 0b11);
        assert_eq!(hi, 0);

        // Bit 7 set: monomials 0 and 128; the latter is bit 0 of the high word.
        let (lo, hi) = subset_mask8(0x80);
        assert_eq!(lo, 1);
        assert_eq!(hi, 1);
    }

    /// Spot checks for both parity helpers on the same inputs.
    #[test]
    fn parity_helpers_known_values() {
        assert_eq!(parity128(0), 0);
        assert_eq!(parity128(1), 1);
        assert_eq!(parity128(0b1011), 1);
        assert_eq!(parity128(u128::MAX), 0);
        assert_eq!(parity16(0), 0);
        assert_eq!(parity16(1), 1);
        assert_eq!(parity16(0b1011), 1);
        assert_eq!(parity16(0xffff), 0);
    }

    /// With an identity table the masked scan must return the index itself.
    #[test]
    fn ct_lookup_u8_16_picks_exact_entry() {
        let table = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        for i in 0u8..16 {
            assert_eq!(ct_lookup_u8_16(&table, i), i);
        }
    }

    /// Builds the ANF of the table below and checks that evaluating it
    /// reproduces the table entry for every one of the 256 inputs.
    #[test]
    fn byte_sbox_anf_matches_direct_lookup_all_inputs() {
        #[rustfmt::skip]
        const AES_SBOX: [u8; 256] = [
            0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
            0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
            0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
            0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
            0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
            0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
            0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
            0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
            0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
            0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
            0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
            0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
            0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
            0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
            0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
            0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16,
        ];
        let coeffs = build_byte_sbox_anf(&AES_SBOX);
        for x in 0u8..=255 {
            assert_eq!(
                eval_byte_sbox(&coeffs, x),
                AES_SBOX[x as usize],
                "ANF mismatch at input {x:#04x}"
            );
        }
    }

    /// Same round-trip check for the 4-bit path, over all 16 inputs.
    #[test]
    fn nibble_sbox_anf_matches_direct_lookup_all_inputs() {
        const PRESENT_SBOX: [u8; 16] =
            [0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2];
        let coeffs = build_nibble_sbox_anf(&PRESENT_SBOX);
        for x in 0u8..16 {
            assert_eq!(
                eval_nibble_sbox(coeffs, x),
                PRESENT_SBOX[x as usize],
                "ANF mismatch at input {x:#03x}"
            );
        }
    }
}
666}