mod gf;
use super::BlockCipher;
use gf::{gf_mul, inv_sub_byte, sub_byte};
#[inline]
/// XORs round-key bytes into the state (FIPS-197 `AddRoundKey`).
///
/// Bytes are paired positionally; only the first `min(16, rk.len())` state
/// bytes are touched. Callers pass a 16-byte round-key slice.
#[inline]
fn add_round_key(state: &mut [u8; 16], rk: &[u8]) {
    state
        .iter_mut()
        .zip(rk)
        .for_each(|(byte, key_byte)| *byte ^= key_byte);
}
#[inline]
fn sub_bytes(state: &mut [u8; 16]) {
for b in state.iter_mut() {
*b = sub_byte(*b);
}
}
#[inline]
fn inv_sub_bytes(state: &mut [u8; 16]) {
for b in state.iter_mut() {
*b = inv_sub_byte(*b);
}
}
#[inline]
/// Rotates row `r` of the state left by `r` positions (FIPS-197 `ShiftRows`).
///
/// The state is column-major: the byte at (row, col) lives at index
/// `row + 4 * col`. Row 0 is left unchanged.
#[inline]
fn shift_rows(s: &mut [u8; 16]) {
    let src = *s;
    for row in 1..4 {
        for col in 0..4 {
            s[row + 4 * col] = src[row + 4 * ((col + row) % 4)];
        }
    }
}
#[inline]
/// Rotates row `r` of the state right by `r` positions, undoing `shift_rows`
/// (FIPS-197 `InvShiftRows`).
///
/// The state is column-major: the byte at (row, col) lives at index
/// `row + 4 * col`. Row 0 is left unchanged.
#[inline]
fn inv_shift_rows(s: &mut [u8; 16]) {
    let src = *s;
    for row in 1..4 {
        for col in 0..4 {
            // `+ 4` keeps the subtraction in range before reducing mod 4.
            s[row + 4 * col] = src[row + 4 * ((col + 4 - row) % 4)];
        }
    }
}
#[inline]
/// Multiplies each state column by the fixed matrix
/// `[[2,3,1,1],[1,2,3,1],[1,1,2,3],[3,1,1,2]]` using `gf_mul` for the
/// products and XOR for the sums (FIPS-197 `MixColumns`).
///
/// Columns are the contiguous 4-byte groups `s[4c..4c + 4]`.
#[inline]
fn mix_columns(s: &mut [u8; 16]) {
    for column in s.chunks_exact_mut(4) {
        let [a0, a1, a2, a3] = [column[0], column[1], column[2], column[3]];
        let mixed = [
            gf_mul(a0, 2) ^ gf_mul(a1, 3) ^ a2 ^ a3,
            a0 ^ gf_mul(a1, 2) ^ gf_mul(a2, 3) ^ a3,
            a0 ^ a1 ^ gf_mul(a2, 2) ^ gf_mul(a3, 3),
            gf_mul(a0, 3) ^ a1 ^ a2 ^ gf_mul(a3, 2),
        ];
        column.copy_from_slice(&mixed);
    }
}
#[inline]
/// Multiplies each state column by the inverse MixColumns matrix
/// `[[0e,0b,0d,09],[09,0e,0b,0d],[0d,09,0e,0b],[0b,0d,09,0e]]`
/// (FIPS-197 `InvMixColumns`).
///
/// Columns are the contiguous 4-byte groups `s[4c..4c + 4]`.
#[inline]
fn inv_mix_columns(s: &mut [u8; 16]) {
    for column in s.chunks_exact_mut(4) {
        let [a0, a1, a2, a3] = [column[0], column[1], column[2], column[3]];
        let unmixed = [
            gf_mul(a0, 0x0e) ^ gf_mul(a1, 0x0b) ^ gf_mul(a2, 0x0d) ^ gf_mul(a3, 0x09),
            gf_mul(a0, 0x09) ^ gf_mul(a1, 0x0e) ^ gf_mul(a2, 0x0b) ^ gf_mul(a3, 0x0d),
            gf_mul(a0, 0x0d) ^ gf_mul(a1, 0x09) ^ gf_mul(a2, 0x0e) ^ gf_mul(a3, 0x0b),
            gf_mul(a0, 0x0b) ^ gf_mul(a1, 0x0d) ^ gf_mul(a2, 0x09) ^ gf_mul(a3, 0x0e),
        ];
        column.copy_from_slice(&unmixed);
    }
}
/// Expands `key` into the round-key schedule (FIPS-197 `KeyExpansion`).
///
/// The schedule is written to the front of `out` as `4 * (nr + 1)` consecutive
/// 4-byte words, i.e. `nr + 1` round keys of 16 bytes each. The first `nk`
/// words are the key itself.
///
/// * `key` - cipher key, exactly `4 * nk` bytes.
/// * `nk`  - key length in 32-bit words.
/// * `nr`  - number of cipher rounds.
/// * `out` - destination buffer, at least `16 * (nr + 1)` bytes.
///
/// # Panics
///
/// Panics if `nk` is zero, if `key.len() != 4 * nk`, or if `out` is too short
/// to hold the schedule (or the key).
fn key_expansion(key: &[u8], nk: usize, nr: usize, out: &mut [u8]) {
    let total_words = 4 * (nr + 1);
    // Fail loudly on a mismatched key/parameter set. Without these checks the
    // schedule would silently be derived from the wrong number of key words.
    assert!(nk > 0, "key_expansion: nk must be non-zero");
    assert_eq!(
        key.len(),
        4 * nk,
        "key_expansion: key length must be 4 * nk bytes"
    );
    assert!(
        out.len() >= 4 * total_words,
        "key_expansion: round-key buffer too small"
    );

    out[..key.len()].copy_from_slice(key);
    let mut rcon = 1u8;
    for i in nk..total_words {
        // Start from the previous word w[i - 1].
        let prev = 4 * (i - 1);
        let mut t = [out[prev], out[prev + 1], out[prev + 2], out[prev + 3]];
        if i % nk == 0 {
            // RotWord, then SubWord, then XOR the round constant into byte 0.
            t.rotate_left(1);
            for b in t.iter_mut() {
                *b = sub_byte(*b);
            }
            t[0] ^= rcon;
            // Next round constant: rcon * 2 in GF(2^8).
            rcon = gf_mul(rcon, 2);
        } else if nk > 6 && i % nk == 4 {
            // Only for Nk > 6 (AES-256): extra SubWord mid-way through each key block.
            for b in t.iter_mut() {
                *b = sub_byte(*b);
            }
        }
        // w[i] = w[i - nk] XOR t
        let base = 4 * i;
        let src = 4 * (i - nk);
        for (j, &tb) in t.iter().enumerate() {
            out[base + j] = out[src + j] ^ tb;
        }
    }
}
/// Encrypts one 16-byte block in place with the expanded schedule `rk`
/// (FIPS-197 `Cipher`).
///
/// The sequence is: whitening with round key 0, then rounds `1..nr` of
/// SubBytes/ShiftRows/MixColumns/AddRoundKey, then a final round that skips
/// MixColumns.
///
/// # Panics
///
/// Panics if `rk` holds fewer than `16 * (nr + 1)` bytes.
fn encrypt(rk: &[u8], nr: usize, block: &mut [u8; 16]) {
    /// Borrows the 16-byte round key for `round` out of the schedule.
    fn round_key(rk: &[u8], round: usize) -> &[u8] {
        let start = 16 * round;
        &rk[start..start + 16]
    }

    add_round_key(block, round_key(rk, 0));
    for round in 1..nr {
        sub_bytes(block);
        shift_rows(block);
        mix_columns(block);
        add_round_key(block, round_key(rk, round));
    }
    // Final round: no MixColumns.
    sub_bytes(block);
    shift_rows(block);
    add_round_key(block, round_key(rk, nr));
}
/// Decrypts one 16-byte block in place with the expanded schedule `rk`
/// (FIPS-197 `InvCipher`).
///
/// Round keys are consumed from `nr` down to 0. Each round applies
/// InvShiftRows, InvSubBytes, AddRoundKey and then InvMixColumns. The last
/// step (round key 0) skips InvMixColumns.
///
/// # Panics
///
/// Panics if `rk` holds fewer than `16 * (nr + 1)` bytes.
fn decrypt(rk: &[u8], nr: usize, block: &mut [u8; 16]) {
    /// Borrows the 16-byte round key for `round` out of the schedule.
    fn round_key(rk: &[u8], round: usize) -> &[u8] {
        let start = 16 * round;
        &rk[start..start + 16]
    }

    add_round_key(block, round_key(rk, nr));
    for round in (1..nr).rev() {
        inv_shift_rows(block);
        inv_sub_bytes(block);
        add_round_key(block, round_key(rk, round));
        inv_mix_columns(block);
    }
    // Final step mirrors the initial encryption whitening.
    inv_shift_rows(block);
    inv_sub_bytes(block);
    add_round_key(block, round_key(rk, 0));
}
/// Generates an AES cipher type.
///
/// Arguments: optional attributes/docs for the type, the type name, the key
/// size in bytes, `Nk` (key words), `Nr` (rounds), and the length in bytes of
/// the expanded round-key schedule.
macro_rules! aes_variant {
    ($(#[$meta:meta])* $name:ident, $key_bytes:literal, $nk:literal, $nr:literal, $rk_len:literal) => {
        // Reject inconsistent parameter sets at compile time. FIPS-197 fixes
        // Nk = key bytes / 4 and Nr = Nk + 6, and the schedule holds Nr + 1
        // 16-byte round keys.
        const _: () = assert!(
            $key_bytes == 4 * $nk && $nr == $nk + 6 && $rk_len == 16 * ($nr + 1),
            "inconsistent AES variant parameters"
        );

        $(#[$meta])*
        #[derive(Clone)]
        pub struct $name {
            // Expanded key schedule: Nr + 1 consecutive 16-byte round keys.
            rk: [u8; $rk_len],
        }

        impl $name {
            /// Creates a cipher by expanding `key` into the round-key schedule.
            pub fn new(key: &[u8; $key_bytes]) -> Self {
                let mut rk = [0u8; $rk_len];
                key_expansion(key, $nk, $nr, &mut rk);
                $name { rk }
            }
        }

        impl BlockCipher for $name {
            const BLOCK_SIZE: usize = 16;
            const KEY_SIZE: usize = $key_bytes;

            #[inline]
            fn encrypt_block(&self, block: &mut [u8; 16]) {
                encrypt(&self.rk, $nr, block);
            }

            #[inline]
            fn decrypt_block(&self, block: &mut [u8; 16]) {
                decrypt(&self.rk, $nr, block);
            }
        }

        impl Drop for $name {
            /// Overwrites the round keys with zeros before the value is freed.
            ///
            /// Volatile stores are used because plain stores to memory that is
            /// about to die may be removed as dead by the optimizer.
            /// `core::hint::black_box` is documented as best-effort and not
            /// suitable for security purposes, so it is not relied on here.
            /// This is still best-effort: copies left behind by earlier moves
            /// of the value are not reachable from here.
            fn drop(&mut self) {
                for b in self.rk.iter_mut() {
                    // SAFETY: `b` comes from `&mut self.rk`, so it is non-null,
                    // properly aligned, valid for writes and not aliased.
                    unsafe { core::ptr::write_volatile(b as *mut u8, 0) };
                }
                // Keep the wipe from being reordered past later memory operations.
                core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
            }
        }
    };
}
aes_variant!(
    /// AES block cipher with a 128-bit (16-byte) key: `Nk = 4`, `Nr = 10`,
    /// 176 bytes of expanded round keys. Operates on 16-byte blocks.
    Aes128, 16, 4, 10, 176
);
aes_variant!(
    /// AES block cipher with a 192-bit (24-byte) key: `Nk = 6`, `Nr = 12`,
    /// 208 bytes of expanded round keys. Operates on 16-byte blocks.
    Aes192, 24, 6, 12, 208
);
aes_variant!(
    /// AES block cipher with a 256-bit (32-byte) key: `Nk = 8`, `Nr = 14`,
    /// 240 bytes of expanded round keys. Operates on 16-byte blocks.
    Aes256, 32, 8, 14, 240
);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::from_hex;

    /// Known-answer check: `pt` must encrypt to `ct`, and decrypting the result
    /// must give back `pt`.
    fn assert_kat<C: BlockCipher>(cipher: &C, pt: &'static str, ct: &'static str) {
        let plaintext = from_hex::<16>(pt);
        let ciphertext = from_hex::<16>(ct);
        let mut block = plaintext;
        cipher.encrypt_block(&mut block);
        assert_eq!(block, ciphertext, "encryption mismatch");
        cipher.decrypt_block(&mut block);
        assert_eq!(block, plaintext, "decryption mismatch");
    }

    #[test]
    fn fips197_aes128() {
        let key = from_hex::<16>("000102030405060708090a0b0c0d0e0f");
        assert_kat(
            &Aes128::new(&key),
            "00112233445566778899aabbccddeeff",
            "69c4e0d86a7b0430d8cdb78070b4c55a",
        );
    }

    #[test]
    fn fips197_aes192() {
        let key = from_hex::<24>("000102030405060708090a0b0c0d0e0f1011121314151617");
        assert_kat(
            &Aes192::new(&key),
            "00112233445566778899aabbccddeeff",
            "dda97ca4864cdfe06eaf70a0ec0d7191",
        );
    }

    #[test]
    fn fips197_aes256() {
        let key =
            from_hex::<32>("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f");
        assert_kat(
            &Aes256::new(&key),
            "00112233445566778899aabbccddeeff",
            "8ea2b7ca516745bfeafc49904b496089",
        );
    }

    /// FIPS-197 Appendix B worked example (non-sequential key and plaintext).
    #[test]
    fn fips197_appendix_b_aes128() {
        let key = from_hex::<16>("2b7e151628aed2a6abf7158809cf4f3c");
        assert_kat(
            &Aes128::new(&key),
            "3243f6a8885a308d313198a2e0370734",
            "3925841d02dc09fbdc118597196a0b32",
        );
    }

    /// NIST SP 800-38A F.1.1 / F.1.5, ECB block #1.
    #[test]
    fn sp800_38a_ecb_block1() {
        let key128 = from_hex::<16>("2b7e151628aed2a6abf7158809cf4f3c");
        assert_kat(
            &Aes128::new(&key128),
            "6bc1bee22e409f96e93d7e117393172a",
            "3ad77bb40d7a3660a89ecaf32466ef97",
        );
        let key256 =
            from_hex::<32>("603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4");
        assert_kat(
            &Aes256::new(&key256),
            "6bc1bee22e409f96e93d7e117393172a",
            "f3eed1bdb5d2a03c064b5a7e3db181f8",
        );
    }

    #[test]
    fn shift_rows_permutation_and_inverse() {
        let original: [u8; 16] = core::array::from_fn(|i| i as u8);
        let mut s = original;
        shift_rows(&mut s);
        assert_eq!(s, [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11]);
        inv_shift_rows(&mut s);
        assert_eq!(s, original);
    }

    /// Standard MixColumns test columns (db 13 53 45 -> 8e 4d a1 bc, ...).
    #[test]
    fn mix_columns_known_vector_and_inverse() {
        let original = [
            0xdb, 0x13, 0x53, 0x45, 0xf2, 0x0a, 0x22, 0x5c, 0x01, 0x01, 0x01, 0x01, 0xc6, 0xc6,
            0xc6, 0xc6,
        ];
        let mut s = original;
        mix_columns(&mut s);
        assert_eq!(
            s,
            [
                0x8e, 0x4d, 0xa1, 0xbc, 0x9f, 0xdc, 0x58, 0x9d, 0x01, 0x01, 0x01, 0x01, 0xc6,
                0xc6, 0xc6, 0xc6,
            ]
        );
        inv_mix_columns(&mut s);
        assert_eq!(s, original);
    }

    #[test]
    fn roundtrip_all_byte_values() {
        let key = from_hex::<16>("2b7e151628aed2a6abf7158809cf4f3c");
        let cipher = Aes128::new(&key);
        for v in 0u16..=255 {
            let original = [v as u8; 16];
            let mut block = original;
            cipher.encrypt_block(&mut block);
            assert_ne!(block, original, "ciphertext should differ from plaintext");
            cipher.decrypt_block(&mut block);
            assert_eq!(block, original);
        }
    }
}