use super::sbox::{SBOX, key_expansion, gf_mul, MIX_COLS, SHIFT_ROWS};
use super::tables::{
WhiteboxTables, WhiteboxTablesLite, Bijection8, Bijection4, MixingBijection32
};
use super::{AES_BLOCK_SIZE, AES_ROUNDS};
/// Deterministic pseudo-random generator used to derive the table encodings
/// and mixing bijections from a build seed.
///
/// Seed bytes are folded into a 64-bit state; outputs are then produced with
/// the xorshift64* algorithm. The generator is reproducible by design and is
/// NOT cryptographically secure.
pub struct SeededRng {
    state: u64,
}

impl SeededRng {
    /// State before any seed bytes are mixed in; also the fallback used if
    /// seed mixing happens to produce the forbidden all-zero state.
    const INITIAL_STATE: u64 = 0x853c49e6748fea9b;

    /// Creates a generator whose output stream is a deterministic function of `seed`.
    ///
    /// Byte `i` of the seed is XORed into byte lane `i % 8` of the state, which
    /// is then scrambled by an odd multiply and an xorshift.
    pub fn new(seed: &[u8]) -> Self {
        let mut state = Self::INITIAL_STATE;
        for (i, &byte) in seed.iter().enumerate() {
            state ^= u64::from(byte) << ((i % 8) * 8);
            state = state.wrapping_mul(0x5851f42d4c957f2d);
            state ^= state >> 33;
        }
        // xorshift maps 0 to 0 forever, so a zero state would make every later
        // output zero. Fall back to the non-zero initial constant instead.
        if state == 0 {
            state = Self::INITIAL_STATE;
        }
        Self { state }
    }

    /// Advances the state (xorshift 12/25/27) and returns the scrambled output.
    pub fn next_u64(&mut self) -> u64 {
        self.state ^= self.state >> 12;
        self.state ^= self.state << 25;
        self.state ^= self.state >> 27;
        self.state.wrapping_mul(0x2545f4914f6cdd1d)
    }

    /// Returns the low byte of the next 64-bit output.
    #[allow(dead_code)] // kept as part of the RNG API; unused by the table generators
    pub fn next_u8(&mut self) -> u8 {
        self.next_u64() as u8
    }

    /// Returns a Fisher–Yates shuffle of `0..n`, as bytes.
    ///
    /// Indices are drawn with `next_u64() % (i + 1)`, which carries a
    /// negligible modulo bias.
    ///
    /// # Panics
    /// Panics if `n > 256`, because values `>= 256` cannot be represented as
    /// `u8` and the result would silently stop being a permutation.
    pub fn random_permutation(&mut self, n: usize) -> alloc::vec::Vec<u8> {
        assert!(
            n <= 256,
            "random_permutation: n = {} does not fit in the u8 range",
            n
        );
        let mut perm: alloc::vec::Vec<u8> = (0..n).map(|i| i as u8).collect();
        for i in (1..n).rev() {
            let j = (self.next_u64() as usize) % (i + 1);
            perm.swap(i, j);
        }
        perm
    }
}
extern crate alloc;
/// Generates the full white-box table set for `key`, with all random
/// encodings and mixing bijections derived deterministically from `seed`.
pub fn generate_tables(key: &[u8; 16], seed: &[u8]) -> WhiteboxTables {
    let round_keys = key_expansion(key);

    // The RNG draw order is part of the output: encodings are drawn first,
    // then mixing bijections, so the same seed always yields the same tables.
    let mut rng = SeededRng::new(seed);
    let encodings = generate_encodings(&mut rng);
    let mbs = generate_mixing_bijections(&mut rng);

    let mut out = WhiteboxTables::new();
    generate_tboxes_tyboxes(&round_keys, &encodings, &mbs, &mut out);
    generate_xor_tables(&encodings, &mut out);
    generate_mbl_tables(&mbs, &encodings, &mut out);
    generate_last_round_tboxes(&round_keys, &encodings, &mut out);
    out
}
/// Generates the reduced ("lite") table set: one encoded S-box table per
/// round and byte position, plus an unencoded final table.
///
/// `tbox[r][pos][x]` is `enc_r(SBOX[dec_{r-1}(x) ^ round_keys[r][pos]])`, where
/// round 0 takes `x` unencoded. `tbox_last[pos][x]` decodes with the encoding
/// of round `AES_ROUNDS - 2`, adds `round_keys[AES_ROUNDS - 1]` and applies the
/// S-box without any output encoding.
///
/// NOTE(review): this construction uses no ShiftRows/MixColumns and never reads
/// the final round key `round_keys[AES_ROUNDS]` (unlike
/// `generate_last_round_tboxes`). Confirm this is intentional before relying on
/// it as AES.
#[allow(dead_code)] #[allow(clippy::needless_range_loop)]
pub fn generate_tables_lite(key: &[u8; 16], seed: &[u8]) -> WhiteboxTablesLite {
    let round_keys = key_expansion(key);
    let mut rng = SeededRng::new(seed);
    let encodings = generate_encodings(&mut rng);
    let mut tables = WhiteboxTablesLite::new();

    for round in 0..AES_ROUNDS {
        for pos in 0..AES_BLOCK_SIZE {
            let key_byte = round_keys[round][pos];
            let out_enc = &encodings.round_output[round][pos];
            for x in 0..256usize {
                let input = x as u8;
                let plain = match round {
                    0 => input,
                    r => encodings.round_output[r - 1][pos].decode(input),
                };
                tables.tbox[round][pos][x] = out_enc.encode(SBOX[(plain ^ key_byte) as usize]);
            }
        }
    }

    let last_in = AES_ROUNDS - 2;
    let last_key = &round_keys[AES_ROUNDS - 1];
    for pos in 0..AES_BLOCK_SIZE {
        for x in 0..256usize {
            let plain = encodings.round_output[last_in][pos].decode(x as u8);
            tables.tbox_last[pos][x] = SBOX[(plain ^ last_key[pos]) as usize];
        }
    }
    tables
}
/// Random bijections drawn from the build-seed RNG by `generate_encodings` and
/// baked into the generated tables.
struct InternalEncodings {
    // Byte bijection per round and state position. Tables of round `r` encode
    // their output bytes with `round_output[r]`; tables of round `r + 1`
    // decode their input bytes with it.
    round_output: [[Bijection8; AES_BLOCK_SIZE]; AES_ROUNDS],
    // Nibble bijections indexed `[inner round][XOR table][operand]`: operand 0
    // decodes a table's first input (and encodes the previous table's output),
    // operand 1 decodes its second input.
    nibble_encodings: [[[Bijection4; 2]; 96]; 9],
}
/// Draws every internal encoding from `rng`.
///
/// Byte encodings are drawn first, in `[round][pos]` order, each from a
/// random permutation of 256 values. Nibble encodings follow, in
/// `[round][table][operand]` order, each from a random permutation of 16
/// values. The draw order is fixed so that a given seed is reproducible.
fn generate_encodings(rng: &mut SeededRng) -> InternalEncodings {
    let mut enc = InternalEncodings {
        round_output: [[Bijection8::identity(); AES_BLOCK_SIZE]; AES_ROUNDS],
        nibble_encodings: [[[Bijection4::identity(); 2]; 96]; 9],
    };

    for slot in enc.round_output.iter_mut().flatten() {
        let mut bij = Bijection8::identity();
        for (plain, coded) in rng.random_permutation(256).into_iter().enumerate() {
            bij.forward[plain] = coded;
            bij.inverse[coded as usize] = plain as u8;
        }
        *slot = bij;
    }

    for slot in enc.nibble_encodings.iter_mut().flatten().flatten() {
        let mut bij = Bijection4::identity();
        for (plain, coded) in rng.random_permutation(16).into_iter().enumerate() {
            bij.forward[plain] = coded;
            bij.inverse[coded as usize] = plain as u8;
        }
        *slot = bij;
    }

    enc
}
/// Draws nine random invertible 32x32 binary matrices, one per inner round.
///
/// Each matrix starts as the identity and is scrambled by 64 random
/// elementary row additions over GF(2) (`row[dst] ^= row[src]`, skipped when
/// `dst == src`). The inverse is kept in sync by applying the matching column
/// operation (`col[src] ^= col[dst]`), so no separate inversion is needed.
fn generate_mixing_bijections(rng: &mut SeededRng) -> [MixingBijection32; 9] {
    let mut bijections: [MixingBijection32; 9] =
        core::array::from_fn(|_| MixingBijection32::default());

    for slot in bijections.iter_mut() {
        let mut forward = [[0u8; 32]; 32];
        let mut backward = [[0u8; 32]; 32];
        for (d, (f_row, b_row)) in forward.iter_mut().zip(backward.iter_mut()).enumerate() {
            f_row[d] = 1;
            b_row[d] = 1;
        }

        for _ in 0..64 {
            // Draw order matters for reproducibility: destination, then source.
            let dst = (rng.next_u64() as usize) % 32;
            let src = (rng.next_u64() as usize) % 32;
            if dst == src {
                continue;
            }
            let src_row = forward[src];
            for (cell, &bit) in forward[dst].iter_mut().zip(src_row.iter()) {
                *cell ^= bit;
            }
            for row in backward.iter_mut() {
                row[src] ^= row[dst];
            }
        }

        *slot = MixingBijection32 { matrix: forward, inverse: backward };
    }

    bijections
}
/// Builds the T-boxes and Ty-boxes for the nine inner rounds.
///
/// For round `r` and position `pos` (row `pos % 4` of its column), the input
/// byte is decoded with `round_output[r - 1][pos]` (round 0 takes it raw),
/// XORed with `round_keys[r][SHIFT_ROWS[pos]]` and passed through the S-box.
/// `tbox` stores that S-box output directly. `tybox` stores the column's
/// MixColumns contribution (packed little-endian into a `u32`) after the
/// round's mixing bijection.
///
/// NOTE(review): `tbox` values are stored without any encoding. For round 0,
/// `tbox[0][pos][0] == SBOX[round_keys[0][SHIFT_ROWS[pos]]]`, so if these tables
/// ship in the white-box artifact they expose key bytes. Confirm this is intended.
fn generate_tboxes_tyboxes(
    round_keys: &[[u8; 16]; 11],
    encodings: &InternalEncodings,
    mixing_bijections: &[MixingBijection32; 9],
    tables: &mut WhiteboxTables,
) {
    for (round, mb) in mixing_bijections.iter().enumerate() {
        for (pos, &src) in SHIFT_ROWS.iter().enumerate().take(16) {
            let row = pos % 4;
            let key_byte = round_keys[round][src];
            for x in 0..256usize {
                let input = x as u8;
                let plain = if round > 0 {
                    encodings.round_output[round - 1][pos].decode(input)
                } else {
                    input
                };
                let s = SBOX[(plain ^ key_byte) as usize];
                let packed = (0..4).fold(0u32, |acc, out_row| {
                    acc | (u32::from(gf_mul(MIX_COLS[out_row][row], s)) << (8 * out_row))
                });
                tables.tybox[round][pos][x] = mb.apply(packed);
                tables.tbox[round][pos][x] = s;
            }
        }
    }
}
/// Builds the 4-bit XOR tables for the nine inner rounds (96 per round).
///
/// Table `t` decodes operand `a` with `nibble_encodings[r][t][0]` and operand
/// `b` with `nibble_encodings[r][t][1]`. The XOR of the two is re-encoded with
/// the first-operand encoding of table `t + 1`. The last table in a round has
/// no successor, so its output is left unencoded.
fn generate_xor_tables(encodings: &InternalEncodings, tables: &mut WhiteboxTables) {
    for (round, round_nibbles) in encodings.nibble_encodings.iter().enumerate() {
        for (table_idx, pair) in round_nibbles.iter().enumerate() {
            let next = round_nibbles.get(table_idx + 1);
            for a in 0..16u8 {
                let lhs = pair[0].decode(a);
                for b in 0..16u8 {
                    let plain = lhs ^ pair[1].decode(b);
                    tables.xor_tables[round][table_idx][a as usize][b as usize] = match next {
                        Some(successor) => successor[0].encode(plain),
                        None => plain,
                    };
                }
            }
        }
    }
}
/// Builds the MB/L tables that undo each round's mixing bijection and apply
/// the round's output byte encodings.
///
/// For round `r` and position `pos`, byte `x` is placed in lane `pos % 4` of a
/// 32-bit word and passed through the inverse mixing bijection. Each of the
/// four resulting bytes is encoded with `round_output[r][col_base + lane]`,
/// where `col_base = (pos / 4) * 4` is the first position of `pos`'s column.
#[allow(clippy::needless_range_loop)]
fn generate_mbl_tables(
    mixing_bijections: &[MixingBijection32; 9],
    encodings: &InternalEncodings,
    tables: &mut WhiteboxTables,
) {
    for (round, mb) in mixing_bijections.iter().enumerate() {
        let round_enc = &encodings.round_output[round];
        for pos in 0..AES_BLOCK_SIZE {
            let shift = (pos % 4) * 8;
            let col_base = (pos / 4) * 4;
            for x in 0..256 {
                let unmixed = mb.apply_inverse((x as u32) << shift);
                let mut word = 0u32;
                for lane in 0..4 {
                    let byte = (unmixed >> (8 * lane)) as u8;
                    word |= (round_enc[col_base + lane].encode(byte) as u32) << (8 * lane);
                }
                tables.mbl[round][pos][x] = word;
            }
        }
    }
}
/// Builds the final-round T-boxes, which fold in the last two round keys.
///
/// For each position `pos` (with `src = SHIFT_ROWS[pos]`) the value is
/// `SBOX[dec(x) ^ round_keys[AES_ROUNDS - 1][src]] ^ round_keys[AES_ROUNDS][src]`.
/// Here `dec` inverts `round_output[AES_ROUNDS - 2][pos]`. The result is left
/// unencoded and written to both `tbox_last` and `tbox[AES_ROUNDS - 1]`.
///
/// NOTE(review): the final whitening key is indexed by the shifted position
/// (`src`). In the usual T-box formulation the last AddRoundKey is applied at
/// the output position (`pos`). Confirm against the table evaluator.
fn generate_last_round_tboxes(
    round_keys: &[[u8; 16]; 11],
    encodings: &InternalEncodings,
    tables: &mut WhiteboxTables,
) {
    let round = AES_ROUNDS - 1;
    let input_encodings = &encodings.round_output[round - 1];
    for (pos, &src) in SHIFT_ROWS.iter().enumerate().take(AES_BLOCK_SIZE) {
        let pre_key = round_keys[round][src];
        let post_key = round_keys[AES_ROUNDS][src];
        for x in 0..256usize {
            let plain = input_encodings[pos].decode(x as u8);
            let value = SBOX[(plain ^ pre_key) as usize] ^ post_key;
            tables.tbox_last[pos][x] = value;
            tables.tbox[round][pos][x] = value;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// AES-128 example key from FIPS-197 Appendix A.1.
    const SAMPLE_KEY: [u8; 16] = [
        0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
        0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c,
    ];

    #[test]
    fn test_seeded_rng_deterministic() {
        let seed = b"test_seed_12345";
        let mut first = SeededRng::new(seed);
        let mut second = SeededRng::new(seed);
        for _ in 0..100 {
            assert_eq!(first.next_u64(), second.next_u64());
        }
    }

    #[test]
    fn test_seeded_rng_different_seeds() {
        let mut first = SeededRng::new(b"seed1");
        let mut second = SeededRng::new(b"seed2");
        let diverged = (0..10).any(|_| first.next_u64() != second.next_u64());
        assert!(diverged);
    }

    #[test]
    fn test_random_permutation() {
        let perm = SeededRng::new(b"permutation_test").random_permutation(256);
        assert_eq!(perm.len(), 256);
        let mut seen = [false; 256];
        for value in perm.iter().map(|&v| v as usize) {
            assert!(!seen[value], "Duplicate value in permutation");
            seen[value] = true;
        }
    }

    #[test]
    fn test_generate_tables() {
        let seed = b"test_build_seed";
        let tables = generate_tables(&SAMPLE_KEY, seed);
        let has_nonzero_tybox = (0..9).any(|round| {
            (0..16).any(|pos| (0..256).any(|x| tables.tybox[round][pos][x] != 0))
        });
        assert!(has_nonzero_tybox, "Ty-boxes should have non-zero values");
        assert!(tables.memory_size() > 500_000);
    }

    #[test]
    fn test_generate_tables_lite() {
        let seed = b"test_build_seed";
        let tables = generate_tables_lite(&SAMPLE_KEY, seed);
        let has_nonzero = (0..AES_ROUNDS).any(|round| {
            (0..16).any(|pos| (0..256).any(|x| tables.tbox[round][pos][x] != 0))
        });
        assert!(has_nonzero, "T-boxes should have non-zero values");
        assert!(tables.memory_size() < 50_000);
    }

    #[test]
    fn test_tables_deterministic() {
        let key = [0u8; 16];
        let seed = b"deterministic_test";
        let first = generate_tables(&key, seed);
        let second = generate_tables(&key, seed);
        for round in 0..9 {
            for pos in 0..16 {
                for x in 0..256 {
                    assert_eq!(
                        first.tybox[round][pos][x],
                        second.tybox[round][pos][x],
                        "Tables should be deterministic"
                    );
                }
            }
        }
    }
}