use crate::{
core::{
circuits::{
boolean::{
boolean_value::{Boolean, BooleanValue},
byte::Byte,
},
traits::boolean_circuit::BooleanCircuit,
},
expressions::expr::EvalFailure,
},
utils::{
crypto::key::{AES128Key, AES192Key, AES256Key},
matrix::Matrix,
},
};
use aes::{
cipher::{BlockEncrypt, KeyInit},
Aes128,
Aes192,
Aes256,
};
use core::panic;
// AES forward substitution box (SubBytes lookup table).
//
// Entry `S_BOX[16 * hi + lo]` is the substitution for the byte `0x{hi}{lo}`:
// each row below covers one high nibble, each column one low nibble.
const S_BOX: [u8; 256] = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];
// Round constants for the key schedule: RC[k] is x^k in GF(2^8) (so each entry
// doubles the previous one, reduced by 0x1b once the top bit overflows).
// Indexed by `i / n - 1` in `AESDesc::new`.
const RC: [u8; 10] = [
    0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
];
/// Expanded AES key schedule shared by the `AES128`, `AES192` and `AES256`
/// wrappers.
///
/// Built by `AESDesc::new` from a cipher key and consumed by
/// `AESDesc::encrypt_block`.
// The acronym is kept upper-case on purpose (`AESDesc`, not `AesDesc`).
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)]
pub struct AESDesc<B: Boolean> {
    /// Round keys in the order they are applied, each a 4x4 byte matrix.
    round_keys: Vec<Matrix<Byte<B>>>,
}
impl<B: Boolean> AESDesc<B> {
    /// Expands a cipher key into the AES round-key schedule.
    ///
    /// `key` is a byte matrix with 4 rows whose columns are the 32-bit key
    /// words. Its total size (`8 * nrows * ncols` bits) selects the variant:
    /// 128 bits → 10 rounds, 192 → 12, 256 → 14. The schedule stores
    /// `n_rounds + 1` round keys, each a 4x4 matrix made of four consecutive
    /// expanded words.
    ///
    /// # Panics
    ///
    /// Panics if `key` does not have exactly 4 rows, or if its size is not
    /// 128, 192 or 256 bits.
    fn new(key: Matrix<Byte<B>>) -> Self {
        if key.nrows != 4 {
            panic!("key must have 4 rows (found {})", key.nrows);
        }
        let key_length = 8 * key.nrows * key.ncols;
        let n_rounds = match key_length {
            128 => 10,
            192 => 12,
            256 => 14,
            _ => panic!(
                "key_length must be one of {{128, 192, 256}} (found {})",
                key_length
            ),
        };
        // Number of 32-bit words in the cipher key.
        let n = key_length / 32;
        // Expanded key words, each stored as a 4x1 column; seeded with the
        // cipher key's own columns.
        let mut round_keys_col = (0..n).map(|j| key.col(j)).collect::<Vec<Matrix<Byte<B>>>>();
        for i in n..4 * (n_rounds + 1) {
            let mut word = round_keys_col.last().unwrap().clone();
            if i % n == 0 {
                // Rotate the word up by one byte: [a0, a1, a2, a3] -> [a1, a2, a3, a0].
                let col = [
                    *word.get((0, 0)).unwrap(),
                    *word.get((1, 0)).unwrap(),
                    *word.get((2, 0)).unwrap(),
                    *word.get((3, 0)).unwrap(),
                ];
                for r in 0..4 {
                    let w = word.get_mut((r, 0)).unwrap();
                    *w = col[(r + 1) % 4];
                }
                // Substitute every byte, then fold the round constant into byte 0.
                word = Self::sub_bytes(word);
                let w = word.get_mut((0, 0)).unwrap();
                *w ^= Byte::from(RC[i / n - 1]);
            } else if key_length == 256 && i % n == 4 {
                // 256-bit keys get an extra substitution halfway through each
                // group of `n` words.
                word = Self::sub_bytes(word);
            }
            // Every new word is XOR-ed with the word `n` positions earlier.
            for r in 0..4 {
                let w = word.get_mut((r, 0)).unwrap();
                *w ^= *round_keys_col[i - n].get((r, 0)).unwrap();
            }
            round_keys_col.push(word);
        }
        // Group every four consecutive words into one 4x4 round key.
        let round_keys = round_keys_col
            .chunks(4)
            .map(|cols| {
                Matrix::new_from_column_major_iter(
                    (4, 4),
                    cols.iter()
                        .flat_map(|col| col.into_iter().collect::<Vec<Byte<B>>>()),
                )
            })
            .collect();
        Self { round_keys }
    }

    /// Applies the AES S-box to every byte of `state` (SubBytes).
    fn sub_bytes(state: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
        // Evaluates the S-box on one byte without branching on its value. The
        // byte is one-hot encoded, and each output bit is the XOR of the
        // one-hot lines whose table entry has that bit set. This assumes
        // `one_hot_encode` yields line k set iff the byte equals k, and that bit
        // index i carries weight 2^i (TODO confirm against `Byte`).
        fn sub_byte<B: Boolean>(s: Byte<B>) -> Byte<B> {
            let ohe = s.one_hot_encode();
            // A fixed-size accumulator updated in place; the previous fold
            // allocated a fresh Vec for each of the 256 table entries. The XOR
            // operations happen in the same order as before.
            let mut acc = [B::from(false); 8];
            for (s_box_byte, b) in S_BOX.iter().zip(ohe) {
                for (i, bit) in acc.iter_mut().enumerate() {
                    if (*s_box_byte >> i) & 1u8 == 1u8 {
                        *bit ^= b;
                    }
                }
            }
            Byte::new(acc)
        }
        let mut state = state;
        state.map_mut(sub_byte);
        state
    }

    /// Cyclically shifts row `i` of the 4x4 state left by `i` positions
    /// (ShiftRows); row 0 is unchanged.
    fn shift_rows(state: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
        let mut state = state;
        for i in 1..4 {
            let row = [
                *state.get((i, 0)).unwrap(),
                *state.get((i, 1)).unwrap(),
                *state.get((i, 2)).unwrap(),
                *state.get((i, 3)).unwrap(),
            ];
            for j in 0..4 {
                let s = state.get_mut((i, j)).unwrap();
                *s = row[(i + j) % 4];
            }
        }
        state
    }

    /// Mixes each column of the 4x4 state (MixColumns).
    fn mix_columns(state: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
        // Multiplies one 4x1 column by the circulant matrix
        // [2 3 1 1; 1 2 3 1; 1 1 2 3; 3 1 1 2] over GF(2^8), using
        // 3a = 2a ^ a to share the doublings between output rows.
        fn mix_column<B: Boolean>(col: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
            let a_0 = *col.get((0, 0)).unwrap();
            let a_1 = *col.get((1, 0)).unwrap();
            let a_2 = *col.get((2, 0)).unwrap();
            let a_3 = *col.get((3, 0)).unwrap();
            // Multiplies by x (0x02) in GF(2^8): shift every bit one position up
            // and, if bit 7 overflowed, reduce by 0x1b (bits 0, 1, 3, 4). This
            // assumes bits[0] is the least-significant bit.
            fn mul_x<B: Boolean>(a: Byte<B>) -> Byte<B> {
                let mut bits = a.to_vec();
                bits.insert(0, B::from(false));
                let c_8 = bits.pop().unwrap();
                bits[4] ^= c_8;
                bits[3] ^= c_8;
                bits[1] ^= c_8;
                bits[0] ^= c_8;
                Byte::new(bits.try_into().unwrap_or_else(|v: Vec<B>| {
                    panic!("Expected a Vec of length 8 (found {})", v.len())
                }))
            }
            let a_0_xor_a_1 = a_0 ^ a_1;
            let a_1_xor_a_2 = a_1 ^ a_2;
            let a_2_xor_a_3 = a_2 ^ a_3;
            let a_0_xor_a_3 = a_0 ^ a_3;
            Matrix::new_from_iter(
                (4, 1),
                vec![
                    mul_x(a_0_xor_a_1) ^ a_1_xor_a_2 ^ a_3,
                    mul_x(a_1_xor_a_2) ^ a_0 ^ a_2_xor_a_3,
                    mul_x(a_2_xor_a_3) ^ a_0_xor_a_1 ^ a_3,
                    mul_x(a_0_xor_a_3) ^ a_0_xor_a_1 ^ a_2,
                ]
                .into_iter(),
            )
        }
        let cols = (0..4).map(|j| state.col(j));
        let mixed_cols = cols.flat_map(|col| mix_column(col).into_iter().collect::<Vec<Byte<B>>>());
        Matrix::new_from_column_major_iter((4, 4), mixed_cols)
    }

    /// XORs the 4x4 round key `key` into the 4x4 `state` (AddRoundKey).
    fn add_round_key(state: Matrix<Byte<B>>, key: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
        let mut state = state;
        for i in 0..4 {
            for j in 0..4 {
                let s = state.get_mut((i, j)).unwrap();
                *s ^= *key.get((i, j)).unwrap();
            }
        }
        state
    }

    /// Encrypts one 4x4 state matrix with the expanded key schedule.
    ///
    /// The first round key is added up front. Every middle round key drives a
    /// full round (SubBytes, ShiftRows, MixColumns, AddRoundKey). The last
    /// round key closes a final round that omits MixColumns.
    ///
    /// # Panics
    ///
    /// Panics if `block` is not 4x4, or if the schedule holds fewer than two
    /// round keys.
    pub fn encrypt_block(&self, block: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
        if block.nrows != 4 || block.ncols != 4 {
            panic!(
                "block must be a 4x4 matrix (found {}x{})",
                block.nrows, block.ncols
            );
        }
        // Borrow the schedule instead of cloning all of it.
        let (initial_key, rest) = self
            .round_keys
            .split_first()
            .expect("round-key schedule is empty");
        let (last_key, middle_keys) = rest
            .split_last()
            .expect("round-key schedule needs at least two round keys");
        let state = middle_keys.iter().fold(
            Self::add_round_key(block, initial_key.clone()),
            |state, key| {
                let state = Self::sub_bytes(state);
                let state = Self::shift_rows(state);
                let state = Self::mix_columns(state);
                Self::add_round_key(state, key.clone())
            },
        );
        let state = Self::sub_bytes(state);
        let state = Self::shift_rows(state);
        Self::add_round_key(state, last_key.clone())
    }
}
/// Generates a public AES cipher type that wraps an expanded [`AESDesc`].
///
/// * `$name` is the generated struct.
/// * `$key_ty` is the key type accepted by `new`.
/// * `$key_bytes` is the key length in bytes. The key is laid out as a
///   4-row matrix with `$key_bytes / 4` columns.
macro_rules! impl_aes {
    ($name: ident, $key_ty: ident, $key_bytes: expr) => {
        #[doc = concat!(
            "`", stringify!($name), "` block cipher keyed by a `", stringify!($key_ty),
            "` (", stringify!($key_bytes), " bytes), generic over the bit type `B`."
        )]
        #[derive(Clone, Debug)]
        pub struct $name<B: Boolean> {
            desc: AESDesc<B>,
        }

        impl<B: Boolean> $name<B> {
            /// Builds the cipher. The key bytes are placed column by column
            /// into a 4-row matrix, which is then expanded into round keys.
            pub fn new(key: $key_ty<B>) -> Self {
                let key_matrix = Matrix::new_from_column_major_iter(
                    (4, $key_bytes / 4),
                    key.inner().into_iter(),
                );
                Self {
                    desc: AESDesc::new(key_matrix),
                }
            }

            /// Encrypts one 16-byte block.
            ///
            /// The input bytes fill the 4x4 state column by column, and the
            /// ciphertext is read back out in the same column-major order.
            pub fn encrypt_block(&self, block: [Byte<B>; 16]) -> [Byte<B>; 16] {
                let state = Matrix::new_from_column_major_iter((4, 4), block.into_iter());
                let ciphertext = self.desc.encrypt_block(state);
                let mut bytes: Vec<Byte<B>> = Vec::with_capacity(16);
                for j in 0..4 {
                    let column = ciphertext.col(j);
                    bytes.extend(column.into_iter());
                }
                bytes.try_into().unwrap_or_else(|v: Vec<Byte<B>>| {
                    panic!("Expected a Vec of length 16 (found {})", v.len())
                })
            }
        }
    };
}
impl_aes!(AES128, AES128Key, 16);
impl_aes!(AES192, AES192Key, 24);
impl_aes!(AES256, AES256Key, 32);
/// Exposes AES-128 as a Boolean circuit.
///
/// The 256-bit input is 128 key bits followed by 128 plaintext bits, and the
/// output is the 128 ciphertext bits. Bits are grouped into bytes eight at a
/// time through `Byte::new`, so the bit order inside a byte is whatever
/// `Byte` defines; `eval` and `run` share that convention. Neither method
/// reads `self`, because the key always comes from the input.
impl BooleanCircuit for AES128<BooleanValue> {
    /// Evaluates the circuit on plain booleans, using the `aes` crate as the
    /// reference cipher.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not hold exactly 256 bits.
    fn eval(&self, x: Vec<bool>) -> Result<Vec<bool>, EvalFailure> {
        // Packs exactly `8 * N` bits into `N` bytes.
        fn pack<const N: usize>(bits: &[bool]) -> [u8; N] {
            bits.chunks_exact(8)
                .map(|chunk| {
                    let octet: [bool; 8] = chunk
                        .try_into()
                        .expect("chunks_exact(8) yields 8-element slices");
                    u8::from(Byte::new(octet))
                })
                .collect::<Vec<u8>>()
                .try_into()
                .unwrap_or_else(|v: Vec<u8>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }

        if x.len() != 256 {
            panic!("AES128 expects input Vec of length 256");
        }
        let (key_bits, block_bits) = x.split_at(128);
        let key = pack::<16>(key_bits).into();
        let mut block = pack::<16>(block_bits).into();
        let cipher = Aes128::new(&key);
        cipher.encrypt_block(&mut block);
        Ok(block
            .into_iter()
            .flat_map(|byte| Byte::<bool>::from(byte).to_vec())
            .collect())
    }

    /// Runs the generic `AES128` implementation on circuit values.
    ///
    /// # Panics
    ///
    /// Panics if `vals` does not hold exactly 256 values.
    fn run(&self, vals: Vec<BooleanValue>) -> Vec<BooleanValue> {
        // Groups exactly `8 * N` values into `N` circuit bytes.
        fn pack<const N: usize>(bits: &[BooleanValue]) -> [Byte<BooleanValue>; N] {
            bits.chunks_exact(8)
                .map(|chunk| {
                    let octet: [BooleanValue; 8] = chunk
                        .try_into()
                        .expect("chunks_exact(8) yields 8-element slices");
                    Byte::new(octet)
                })
                .collect::<Vec<Byte<BooleanValue>>>()
                .try_into()
                .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }

        if vals.len() != 256 {
            panic!("AES128 expects input Vec of length 256");
        }
        let (key_bits, block_bits) = vals.split_at(128);
        let key = pack::<16>(key_bits);
        let block = pack::<16>(block_bits);
        let cipher = AES128::new(AES128Key::new_from_inner(key));
        cipher
            .encrypt_block(block)
            .into_iter()
            .flat_map(|byte| byte.get_bits())
            .collect::<Vec<BooleanValue>>()
    }
}
/// Exposes AES-192 as a Boolean circuit.
///
/// The 320-bit input is 192 key bits followed by 128 plaintext bits, and the
/// output is the 128 ciphertext bits. Bits are grouped into bytes eight at a
/// time through `Byte::new`, so the bit order inside a byte is whatever
/// `Byte` defines; `eval` and `run` share that convention. Neither method
/// reads `self`, because the key always comes from the input.
impl BooleanCircuit for AES192<BooleanValue> {
    /// Evaluates the circuit on plain booleans, using the `aes` crate as the
    /// reference cipher.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not hold exactly 320 bits.
    fn eval(&self, x: Vec<bool>) -> Result<Vec<bool>, EvalFailure> {
        // Packs exactly `8 * N` bits into `N` bytes.
        fn pack<const N: usize>(bits: &[bool]) -> [u8; N] {
            bits.chunks_exact(8)
                .map(|chunk| {
                    let octet: [bool; 8] = chunk
                        .try_into()
                        .expect("chunks_exact(8) yields 8-element slices");
                    u8::from(Byte::new(octet))
                })
                .collect::<Vec<u8>>()
                .try_into()
                .unwrap_or_else(|v: Vec<u8>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }

        if x.len() != 320 {
            panic!("AES192 expects input Vec of length 320");
        }
        let (key_bits, block_bits) = x.split_at(192);
        let key = pack::<24>(key_bits).into();
        let mut block = pack::<16>(block_bits).into();
        let cipher = Aes192::new(&key);
        cipher.encrypt_block(&mut block);
        Ok(block
            .into_iter()
            .flat_map(|byte| Byte::<bool>::from(byte).to_vec())
            .collect())
    }

    /// Runs the generic `AES192` implementation on circuit values.
    ///
    /// # Panics
    ///
    /// Panics if `vals` does not hold exactly 320 values.
    fn run(&self, vals: Vec<BooleanValue>) -> Vec<BooleanValue> {
        // Groups exactly `8 * N` values into `N` circuit bytes.
        fn pack<const N: usize>(bits: &[BooleanValue]) -> [Byte<BooleanValue>; N] {
            bits.chunks_exact(8)
                .map(|chunk| {
                    let octet: [BooleanValue; 8] = chunk
                        .try_into()
                        .expect("chunks_exact(8) yields 8-element slices");
                    Byte::new(octet)
                })
                .collect::<Vec<Byte<BooleanValue>>>()
                .try_into()
                .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }

        if vals.len() != 320 {
            panic!("AES192 expects input Vec of length 320");
        }
        let (key_bits, block_bits) = vals.split_at(192);
        let key = pack::<24>(key_bits);
        let block = pack::<16>(block_bits);
        let cipher = AES192::new(AES192Key::new_from_inner(key));
        cipher
            .encrypt_block(block)
            .into_iter()
            .flat_map(|byte| byte.get_bits())
            .collect::<Vec<BooleanValue>>()
    }
}
/// Exposes AES-256 as a Boolean circuit.
///
/// The 384-bit input is 256 key bits followed by 128 plaintext bits, and the
/// output is the 128 ciphertext bits. Bits are grouped into bytes eight at a
/// time through `Byte::new`, so the bit order inside a byte is whatever
/// `Byte` defines; `eval` and `run` share that convention. Neither method
/// reads `self`, because the key always comes from the input.
impl BooleanCircuit for AES256<BooleanValue> {
    /// Evaluates the circuit on plain booleans, using the `aes` crate as the
    /// reference cipher.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not hold exactly 384 bits.
    fn eval(&self, x: Vec<bool>) -> Result<Vec<bool>, EvalFailure> {
        // Packs exactly `8 * N` bits into `N` bytes.
        fn pack<const N: usize>(bits: &[bool]) -> [u8; N] {
            bits.chunks_exact(8)
                .map(|chunk| {
                    let octet: [bool; 8] = chunk
                        .try_into()
                        .expect("chunks_exact(8) yields 8-element slices");
                    u8::from(Byte::new(octet))
                })
                .collect::<Vec<u8>>()
                .try_into()
                .unwrap_or_else(|v: Vec<u8>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }

        if x.len() != 384 {
            panic!("AES256 expects input Vec of length 384");
        }
        let (key_bits, block_bits) = x.split_at(256);
        let key = pack::<32>(key_bits).into();
        let mut block = pack::<16>(block_bits).into();
        let cipher = Aes256::new(&key);
        cipher.encrypt_block(&mut block);
        Ok(block
            .into_iter()
            .flat_map(|byte| Byte::<bool>::from(byte).to_vec())
            .collect())
    }

    /// Runs the generic `AES256` implementation on circuit values.
    ///
    /// # Panics
    ///
    /// Panics if `vals` does not hold exactly 384 values.
    fn run(&self, vals: Vec<BooleanValue>) -> Vec<BooleanValue> {
        // Groups exactly `8 * N` values into `N` circuit bytes.
        fn pack<const N: usize>(bits: &[BooleanValue]) -> [Byte<BooleanValue>; N] {
            bits.chunks_exact(8)
                .map(|chunk| {
                    let octet: [BooleanValue; 8] = chunk
                        .try_into()
                        .expect("chunks_exact(8) yields 8-element slices");
                    Byte::new(octet)
                })
                .collect::<Vec<Byte<BooleanValue>>>()
                .try_into()
                .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }

        if vals.len() != 384 {
            panic!("AES256 expects input Vec of length 384");
        }
        let (key_bits, block_bits) = vals.split_at(256);
        let key = pack::<32>(key_bits);
        let block = pack::<16>(block_bits);
        let cipher = AES256::new(AES256Key::new_from_inner(key));
        cipher
            .encrypt_block(block)
            .into_iter()
            .flat_map(|byte| byte.get_bits())
            .collect::<Vec<BooleanValue>>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::circuits::traits::boolean_circuit::tests::TestedBooleanCircuit;
    use rand::Rng;

    // Implements `TestedBooleanCircuit` for one AES wrapper. The descriptor's
    // round-key schedule is left empty because the `BooleanCircuit` impls take
    // the key from their input and never read `self.desc`.
    macro_rules! impl_tested_aes {
        ($cipher: ident, $n_inputs: expr) => {
            impl TestedBooleanCircuit for $cipher<BooleanValue> {
                fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
                    let desc = AESDesc { round_keys: vec![] };
                    Self { desc }
                }

                fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
                    $n_inputs
                }
            }
        };
    }

    // Input width = key bits + 128 plaintext bits.
    impl_tested_aes!(AES128, 256);
    impl_tested_aes!(AES192, 320);
    impl_tested_aes!(AES256, 384);

    #[test]
    fn tested_aes128() {
        AES128::test(1, 1)
    }

    #[test]
    fn tested_aes192() {
        AES192::test(1, 1)
    }

    #[test]
    fn tested_aes256() {
        AES256::test(1, 1)
    }
}