use super::tables::{WhiteboxTables, WhiteboxTablesLite};
use super::sbox::SHIFT_ROWS;
use super::AES_BLOCK_SIZE;
/// Encrypts a single block in place using the full white-box table set.
///
/// Each of the nine inner rounds does three things:
/// 1. Applies ShiftRows to the state.
/// 2. Looks every byte up in its round- and position-specific Ty-box (`tables.tybox`),
///    which yields a `u32`.
/// 3. Folds the four Ty-box words of each column together through the XOR tables
///    (`xor_combine_column`).
///
/// The final round applies ShiftRows followed by a per-byte `tables.tbox_last` lookup,
/// with no column combination.
///
/// NOTE(review): the round count is fixed at 9 inner rounds plus 1 final round, which
/// corresponds to AES-128. Confirm the tables are never generated for longer keys.
// The round loop indexes several tables (`tybox`, XOR tables) by the same round number.
#[allow(clippy::needless_range_loop)]
pub fn whitebox_encrypt(block: &mut [u8; AES_BLOCK_SIZE], tables: &WhiteboxTables) {
    let mut state = *block;

    for round in 0..9 {
        let shifted = apply_shift_rows(&state);

        // Phase 1: one Ty-box lookup per byte, grouped as `ty_words[col][row]`.
        let mut ty_words = [[0u32; 4]; 4];
        for (col, column) in ty_words.iter_mut().enumerate() {
            for (row, word) in column.iter_mut().enumerate() {
                let pos = col * 4 + row;
                *word = tables.tybox[round][pos][shifted[pos] as usize];
            }
        }

        // Phase 2: XOR-combine each column's four words and store the
        // result back into the state, least-significant byte first.
        for (col, &[w0, w1, w2, w3]) in ty_words.iter().enumerate() {
            let combined = xor_combine_column(w0, w1, w2, w3, round, col, tables);
            state[col * 4..col * 4 + 4].copy_from_slice(&combined.to_le_bytes());
        }
    }

    // Final round: ShiftRows, then the last per-byte T-box; no column mixing.
    let shifted = apply_shift_rows(&state);
    for (pos, (byte, &input)) in state.iter_mut().zip(shifted.iter()).enumerate() {
        *byte = tables.tbox_last[pos][input as usize];
    }

    *block = state;
}
/// Encrypts a single block in place using the lite table set.
///
/// Each of the nine inner rounds does three things:
/// 1. Applies ShiftRows.
/// 2. Replaces every byte with its round- and position-specific T-box entry (`tables.tbox`).
/// 3. Mixes the columns directly with `mix_columns`, rather than through lookup tables.
///
/// The final round is ShiftRows followed by a per-byte `tables.tbox_last` lookup.
#[allow(dead_code)]
pub fn whitebox_encrypt_lite(block: &mut [u8; AES_BLOCK_SIZE], tables: &WhiteboxTablesLite) {
    let mut state = *block;

    for round in 0..9 {
        let shifted = apply_shift_rows(&state);

        let mut substituted = [0u8; 16];
        for (pos, (dst, &src)) in substituted.iter_mut().zip(shifted.iter()).enumerate() {
            *dst = tables.tbox[round][pos][src as usize];
        }

        state = mix_columns(&substituted);
    }

    // Final round: ShiftRows, then the last per-byte T-box; no MixColumns.
    let shifted = apply_shift_rows(&state);
    for (pos, (dst, &src)) in state.iter_mut().zip(shifted.iter()).enumerate() {
        *dst = tables.tbox_last[pos][src as usize];
    }

    *block = state;
}
/// Placeholder for white-box AES decryption.
///
/// The block and tables are accepted so the signature mirrors `whitebox_encrypt`,
/// but they are not used yet.
///
/// # Panics
///
/// Always panics via `unimplemented!`, because decryption needs inverse tables that
/// are not available yet.
#[allow(dead_code)]
pub fn whitebox_decrypt(_block: &mut [u8; AES_BLOCK_SIZE], _tables: &WhiteboxTables) {
    unimplemented!("Whitebox decryption requires inverse tables - not yet implemented")
}
/// Returns a permuted copy of `state` according to the ShiftRows table.
///
/// Output byte `i` is read from input position `SHIFT_ROWS[i]`. The permutation itself
/// lives in the `sbox` module's `SHIFT_ROWS`; `state` is not modified.
#[inline]
fn apply_shift_rows(state: &[u8; 16]) -> [u8; 16] {
    let mut permuted = [0u8; 16];
    for (dst_idx, dst) in permuted.iter_mut().enumerate() {
        *dst = state[SHIFT_ROWS[dst_idx]];
    }
    permuted
}
/// Combines the four Ty-box words of one column into a single `u32` using XOR tables.
///
/// Byte lane `k` (0..4, where lane 0 is the least-significant byte of each word) of the
/// result is computed as `(a_k ^ b_k) ^ (c_k ^ d_k)`. Each `^` is evaluated through
/// `xor_byte_via_tables`, with table bases `col * 4 + k`, `col * 4 + k + 16` and
/// `col * 4 + k + 32` for the three XOR stages.
///
/// NOTE(review): `xor_byte_via_tables` reads tables `base` and `base + 1`. With these
/// stride-1 bases, adjacent lanes share a table: the high nibble of lane k and the low
/// nibble of lane k+1 both use table `base + 1`. Table indices above 48 are never
/// reached from here, even though `xor_byte_via_tables` wraps at 96. This is harmless
/// if every XOR table is a plain nibble XOR, but it matters if the generator applies
/// per-table encodings. Confirm against the table generator.
fn xor_combine_column(
    a: u32,
    b: u32,
    c: u32,
    d: u32,
    round: usize,
    col: usize,
    tables: &WhiteboxTables,
) -> u32 {
    let (a, b, c, d) = (a.to_le_bytes(), b.to_le_bytes(), c.to_le_bytes(), d.to_le_bytes());

    let mut lanes = [0u8; 4];
    for (lane, slot) in lanes.iter_mut().enumerate() {
        let base = col * 4 + lane;
        let left = xor_byte_via_tables(a[lane], b[lane], round, base, tables);
        let right = xor_byte_via_tables(c[lane], d[lane], round, base + 16, tables);
        *slot = xor_byte_via_tables(left, right, round, base + 32, tables);
    }

    u32::from_le_bytes(lanes)
}
/// Computes a table-driven "XOR" of two bytes, one nibble at a time.
///
/// Two tables from `tables.xor_tables[round]` are used:
/// - The low nibbles of `a` and `b` index table `table_base % 96`.
/// - The high nibbles index the next table, `(table_base % 96 + 1) % 96`.
///
/// Only the low four bits of the low-nibble lookup are kept. The high-nibble lookup is
/// shifted into the upper four bits; anything above its bit 3 falls off the `u8`.
fn xor_byte_via_tables(a: u8, b: u8, round: usize, table_base: usize, tables: &WhiteboxTables) -> u8 {
    let round_tables = &tables.xor_tables[round];
    let lo_table = table_base % 96;
    let hi_table = (lo_table + 1) % 96;

    let low = round_tables[lo_table][(a & 0x0f) as usize][(b & 0x0f) as usize];
    let high = round_tables[hi_table][(a >> 4) as usize][(b >> 4) as usize];

    (high << 4) | (low & 0x0f)
}
/// Applies AES MixColumns to a column-major 16-byte state and returns the result.
///
/// Each 4-byte column `(s0, s1, s2, s3)` is multiplied by the fixed circulant matrix
/// `[2 3 1 1; 1 2 3 1; 1 1 2 3; 3 1 1 2]` over GF(2^8).
#[allow(dead_code)]
fn mix_columns(state: &[u8; 16]) -> [u8; 16] {
    let mut mixed = [0u8; 16];
    for (out, column) in mixed.chunks_exact_mut(4).zip(state.chunks_exact(4)) {
        let (s0, s1, s2, s3) = (column[0], column[1], column[2], column[3]);
        out[0] = gf_mul_2(s0) ^ gf_mul_3(s1) ^ s2 ^ s3;
        out[1] = s0 ^ gf_mul_2(s1) ^ gf_mul_3(s2) ^ s3;
        out[2] = s0 ^ s1 ^ gf_mul_2(s2) ^ gf_mul_3(s3);
        out[3] = gf_mul_3(s0) ^ s1 ^ s2 ^ gf_mul_2(s3);
    }
    mixed
}
/// Multiplies `a` by 2 in GF(2^8) modulo the AES polynomial (`xtime`).
///
/// The input is shifted left by one. When the input's top bit was set, the reduction
/// constant `0x1b` is XORed in. `a >> 7` is 0 or 1, so the product selects either 0
/// or `0x1b` without a branch.
#[allow(dead_code)]
#[inline]
fn gf_mul_2(a: u8) -> u8 {
    (a << 1) ^ ((a >> 7) * 0x1b)
}
/// Multiplies `a` by 3 in GF(2^8), using the identity `3·a = 2·a ⊕ a`.
#[allow(dead_code)]
#[inline]
fn gf_mul_3(a: u8) -> u8 {
    a ^ gf_mul_2(a)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::whitebox::generator::{generate_tables, generate_tables_lite};

    // FIPS-197 Appendix B example key / plaintext / ciphertext.
    const TEST_KEY: [u8; 16] = [
        0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
        0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c,
    ];
    const TEST_PLAINTEXT: [u8; 16] = [
        0x32, 0x43, 0xf6, 0xa8, 0x88, 0x5a, 0x30, 0x8d,
        0x31, 0x31, 0x98, 0xa2, 0xe0, 0x37, 0x07, 0x34,
    ];
    // Expected ciphertext for TEST_KEY/TEST_PLAINTEXT. No known-answer test asserts it
    // yet, so silence the unused-constant warning instead of deleting the vector.
    #[allow(dead_code)]
    const TEST_CIPHERTEXT: [u8; 16] = [
        0x39, 0x25, 0x84, 0x1d, 0x02, 0xdc, 0x09, 0xfb,
        0xdc, 0x11, 0x85, 0x97, 0x19, 0x6a, 0x0b, 0x32,
    ];

    #[test]
    fn test_shift_rows() {
        let input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let output = apply_shift_rows(&input);
        // Column-major state: row r of output column c comes from input column (c + r) % 4.
        assert_eq!(output, [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11]);
    }

    #[test]
    fn test_mix_columns() {
        let input = [
            0xdb, 0x13, 0x53, 0x45,
            0xf2, 0x0a, 0x22, 0x5c,
            0x01, 0x01, 0x01, 0x01,
            0xc6, 0xc6, 0xc6, 0xc6,
        ];
        let output = mix_columns(&input);
        // Standard MixColumns test columns; the all-equal columns map to themselves.
        assert_eq!(
            output,
            [
                0x8e, 0x4d, 0xa1, 0xbc,
                0x9f, 0xdc, 0x58, 0x9d,
                0x01, 0x01, 0x01, 0x01,
                0xc6, 0xc6, 0xc6, 0xc6,
            ]
        );
    }

    #[test]
    fn test_whitebox_encrypt_lite_roundtrip() {
        // NOTE: despite the name, this cannot round-trip yet (no decryption exists);
        // it only checks that encryption changes the block.
        let tables = generate_tables_lite(&TEST_KEY, b"test_seed");
        let mut block = TEST_PLAINTEXT;
        whitebox_encrypt_lite(&mut block, &tables);
        assert_ne!(block, TEST_PLAINTEXT, "Block should be different after encryption");
    }

    #[test]
    fn test_whitebox_deterministic() {
        let tables = generate_tables(&TEST_KEY, b"deterministic_test");
        let mut block1 = TEST_PLAINTEXT;
        let mut block2 = TEST_PLAINTEXT;
        whitebox_encrypt(&mut block1, &tables);
        whitebox_encrypt(&mut block2, &tables);
        assert_eq!(block1, block2, "Same plaintext should produce same ciphertext");
    }

    #[test]
    fn test_whitebox_different_plaintexts() {
        let tables = generate_tables(&TEST_KEY, b"test_seed");
        let mut block1 = [0u8; 16];
        let mut block2 = [1u8; 16];
        whitebox_encrypt(&mut block1, &tables);
        whitebox_encrypt(&mut block2, &tables);
        assert_ne!(block1, block2, "Different plaintexts should produce different ciphertexts");
    }

    #[test]
    fn test_whitebox_avalanche() {
        let tables = generate_tables(&TEST_KEY, b"avalanche_test");
        let mut block1 = [0u8; 16];
        let mut block2 = [0u8; 16];
        block2[0] = 1;
        whitebox_encrypt(&mut block1, &tables);
        whitebox_encrypt(&mut block2, &tables);

        let diff_bits: u32 = block1
            .iter()
            .zip(block2.iter())
            .map(|(x, y)| (x ^ y).count_ones())
            .sum();
        // A one-bit input change should flip roughly half of the 128 output bits.
        assert!(
            (30..=100).contains(&diff_bits),
            "Avalanche effect: {} bits differ (expected ~64)",
            diff_bits
        );
    }
}