mod aes_cfg;
use crate::aes::aes_cfg::{INV_MIX_COLS_MATRIX, INV_S_BOX, MIX_COLS_MATRIX, RCON, S_BOX};
pub enum AesType {
AES128,
AES192,
AES256,
}
#[cfg(test)]
mod aes_tests {
use super::*;
#[test]
fn aes_ecb_test() {
let plain_text: String = String::from("SUNYSUNYSUNYSUNY");
let aes_key: Vec<u8> = String::from("abcdabcdabcdabcd").into_bytes();
let aes_type: AesType = AesType::AES128;
let ciphered_bytes: Vec<u8> =
aes_ecb_enc(&plain_text.clone().into_bytes(), &aes_key, &aes_type);
let mut ciphered_hex_str: String = String::new();
for byte in ciphered_bytes.clone() {
ciphered_hex_str.push_str(&format!("{:02x}", byte))
}
assert_eq!(
ciphered_hex_str,
String::from("30e1afaf9c36e4814f2abfd05c76cf12")
);
let deciphered_bytes: Vec<u8> = aes_ecb_dec(&ciphered_bytes, &aes_key, &aes_type);
let mut deciphered_str: String = String::new();
for byte in deciphered_bytes {
deciphered_str.push(char::from(byte));
}
assert_eq!(deciphered_str, plain_text);
}
#[test]
fn aes_cbc_test() {
let plain_text: String = String::from("SUNYSUNYSUNYSUNYSUNYSUNYSUNYSUNY");
let aes_key: Vec<u8> = String::from("abcdabcdabcdabcd").into_bytes();
let aes_type: AesType = AesType::AES128;
let iv = String::from("abcdabcdabcdabcd").into_bytes();
let ciphered_bytes: Vec<u8> =
aes_cbc_enc(&plain_text.clone().into_bytes(), &aes_key, &iv, &aes_type);
let mut ciphered_hex_str: String = String::new();
for byte in ciphered_bytes.clone() {
ciphered_hex_str.push_str(&format!("{:02x}", byte))
}
assert_eq!(
ciphered_hex_str,
String::from("8131ea8c4597cfcfdc096a878e65d35b9bf68b2edcd9d80812bb8a9253d4cb45")
);
let deciphered_bytes: Vec<u8> = aes_cbc_dec(&ciphered_bytes, &aes_key, &iv, &aes_type);
let mut deciphered_str: String = String::new();
for byte in deciphered_bytes {
deciphered_str.push(char::from(byte));
}
assert_eq!(deciphered_str, plain_text);
}
}
pub fn aes_ecb_enc(plain_text: &[u8], key: &[u8], aes_type: &AesType) -> Vec<u8> {
let mut message_blocks: Vec<Vec<u8>> = Vec::new();
let mut res: Vec<u8> = Vec::new();
for chunk in plain_text.chunks(16) {
message_blocks.push(chunk.to_vec());
}
let expanded_key: Vec<u32> = key_expansion(key);
for mut state in message_blocks {
cipher(&mut state, &expanded_key, &aes_type);
res.append(&mut state);
}
res
}
pub fn aes_ecb_dec(ciphered_bytes: &[u8], key: &[u8], aes_type: &AesType) -> Vec<u8> {
let mut blocks: Vec<Vec<u8>> = Vec::new();
let mut res: Vec<u8> = Vec::new();
for chunk in ciphered_bytes.chunks(16) {
blocks.push(chunk.to_vec());
}
let expanded_key: Vec<u32> = key_expansion(key);
for mut state in blocks {
decipher(&mut state, &expanded_key, &aes_type);
res.append(&mut state);
}
res
}
pub fn aes_cbc_enc(plain_text: &[u8], key: &[u8], iv: &[u8], aes_type: &AesType) -> Vec<u8> {
let mut message_blocks: Vec<Vec<u8>> = Vec::new();
for chunk in plain_text.chunks(16) {
message_blocks.push(chunk.to_vec());
}
let mut res: Vec<u8> = Vec::new();
let expanded_key: Vec<u32> = key_expansion(&key);
let mut pre_state: Vec<u8> = Vec::new();
for (i, mut state) in message_blocks.iter_mut().enumerate() {
if i != 0 {
for (j, byte) in pre_state.iter().enumerate() {
state[j] ^= byte;
}
} else {
for (j, byte) in iv.iter().enumerate() {
state[j] ^= byte;
}
}
cipher(&mut state, &expanded_key, &aes_type);
res.append(&mut state.clone());
pre_state = state.clone();
}
res
}
pub fn aes_cbc_dec(ciphered_bytes: &[u8], key: &[u8], iv: &[u8], aes_type: &AesType) -> Vec<u8> {
let mut message_blocks: Vec<Vec<u8>> = Vec::new();
for chunk in ciphered_bytes.chunks(16) {
message_blocks.push(chunk.to_vec());
}
let mut res: Vec<u8> = Vec::new();
let expanded_key: Vec<u32> = key_expansion(&key);
#[allow(unused_assignments)]
let mut temp_state: Vec<u8> = Vec::new();
let mut pre_state: Vec<u8> = Vec::new();
for (i, state) in message_blocks.iter_mut().enumerate() {
temp_state = state.clone();
decipher(&mut temp_state, &expanded_key, &aes_type);
if i != 0 {
for (j, byte) in pre_state.iter().enumerate() {
temp_state[j] ^= byte;
}
} else {
for (j, byte) in iv.iter().enumerate() {
temp_state[j] ^= byte;
}
}
res.append(&mut temp_state.clone());
pre_state = state.to_vec();
}
res
}
fn cipher(state: &mut [u8], expanded_key: &[u32], aes_type: &AesType) {
let loop_num: usize = match aes_type {
AesType::AES128 => 10,
AesType::AES192 => 12,
AesType::AES256 => 14,
};
add_round_key(state, &expanded_key[0..4]);
for round in 1..loop_num {
let round_key: &[u32] = &expanded_key[4 * round..(4 * round + 4)];
byte_sub(state, &S_BOX);
shift_rows(state);
mix_cols(state, &MIX_COLS_MATRIX);
add_round_key(state, round_key);
}
byte_sub(state, &S_BOX);
shift_rows(state);
add_round_key(state, &expanded_key[40..44]);
}
fn decipher(state: &mut [u8], expanded_key: &[u32], aes_type: &AesType) {
let round_num: usize = match aes_type {
AesType::AES128 => 10,
AesType::AES192 => 12,
AesType::AES256 => 14,
};
add_round_key(state, &expanded_key[40..44]);
let mut round = 0;
while round < round_num - 1 {
let round_key_group = round_num - 1 - round;
let round_key: &[u32] = &expanded_key[(4 * round_key_group)..(4 * round_key_group + 4)];
inv_shift_rows(state);
byte_sub(state, &INV_S_BOX);
add_round_key(state, round_key);
mix_cols(state, &INV_MIX_COLS_MATRIX);
round += 1;
}
inv_shift_rows(state);
byte_sub(state, &INV_S_BOX);
add_round_key(state, &expanded_key[0..4]);
}
fn byte_sub(state: &mut [u8], s_box: &[u8]) {
for byte in state {
let row = (((*byte & 0xF0) >> 4) * 16) as usize;
let col = (*byte & 0x0F) as usize;
*byte = s_box[row..row + 16][col];
}
}
fn t(word: u32, round: u8) -> u32 {
let mut bytes = word.to_be_bytes();
bytes.rotate_left(1);
byte_sub(&mut bytes, &S_BOX);
u32::from_be_bytes(bytes) ^ RCON[round as usize]
}
fn key_expansion(key_bytes: &[u8]) -> Vec<u32> {
let mut expanded_key: Vec<u32> = Vec::new();
for chunk in key_bytes.chunks(4) {
expanded_key.push(u32::from_be_bytes(
chunk.try_into().expect("Convert chunk to u32 failed!"),
));
}
let mut round: u8 = 0;
for i in 4..44 {
if i % 4 != 0 {
expanded_key.push(expanded_key[i - 4] ^ expanded_key[i - 1]);
} else {
expanded_key.push(expanded_key[i - 4] ^ t(expanded_key[i - 1], round));
round += 1;
}
}
expanded_key
}
fn add_round_key(state: &mut [u8], expanded_key: &[u32]) {
for i in 0..4 {
let key_bytes = expanded_key[i].to_be_bytes();
let row = 4 * i;
state[row..(row + 4)][0] ^= key_bytes[0];
state[row..(row + 4)][1] ^= key_bytes[1];
state[row..(row + 4)][2] ^= key_bytes[2];
state[row..(row + 4)][3] ^= key_bytes[3];
}
}
fn shift_row_left(arr: &mut [u8]) {
let first: u8 = arr[0];
let iter = vec![0, 4, 8].into_iter();
for col in iter {
arr[col] = arr[col + 4];
}
arr[12] = first;
}
fn shift_row_right(arr: &mut [u8]) {
let last: u8 = arr[12];
let iter = vec![8, 4, 0].into_iter();
for col in iter {
arr[col + 4] = arr[col];
}
arr[0] = last;
}
fn shift_rows(state: &mut [u8]) {
for row in 0..4 {
for _ in 0..row {
shift_row_left(&mut state[row..]);
}
}
}
fn inv_shift_rows(state: &mut [u8]) {
for row in 0..4 {
for _ in 0..row {
shift_row_right(&mut state[row..]);
}
}
}
fn gf_mul(n1: u8, n2: u8) -> u8 {
let gf: &dyn Fn(u8) -> u8 = &|mut byte| {
if (byte & 0x80) == 0 {
byte << 1
} else {
byte <<= 1;
byte ^ 0x1B
}
};
match n1 {
0x01 => n2,
0x02 => gf(n2),
0x03 => gf(n2) ^ n2,
0x09 => gf(gf(gf(n2))) ^ n2,
0x0b => (gf(gf(gf(n2))) ^ n2) ^ gf(n2),
0x0d => (gf(gf(gf(n2))) ^ gf(gf(n2))) ^ n2,
0x0e => (gf(gf(gf(n2))) ^ gf(gf(n2))) ^ gf(n2),
_ => 0,
}
}
fn mix_cols(state: &mut [u8], mix_cols_matrix: &[[u8; 4]; 4]) {
let state_old: Vec<u8> = state.to_vec();
for i in 0..4 {
for j in 0..4 {
let col = 4 * j;
state[col..(col + 4)][i] = gf_mul(mix_cols_matrix[i][0], state_old[col..(col + 4)][0])
^ gf_mul(mix_cols_matrix[i][1], state_old[col..(col + 4)][1])
^ gf_mul(mix_cols_matrix[i][2], state_old[col..(col + 4)][2])
^ gf_mul(mix_cols_matrix[i][3], state_old[col..(col + 4)][3]);
}
}
}