use crate::{
core::{
circuits::{
boolean::{
boolean_value::{Boolean, BooleanValue},
byte::Byte,
},
traits::boolean_circuit::BooleanCircuit,
},
expressions::expr::EvalFailure,
},
utils::{
crypto::key::{AES128Key, AES192Key, AES256Key},
matrix::Matrix,
},
};
use aes::{
cipher::{BlockEncrypt, KeyInit},
Aes128,
Aes192,
Aes256,
};
use core::panic;
/// AES key-schedule round constants: `RC[k]` is x^k in GF(2^8), reduced modulo
/// x^8 + x^4 + x^3 + x + 1. It is XOR-ed into the first byte of the word
/// produced at every multiple of Nk during key expansion.
const RC: [u8; 10] = [
    0x01, 0x02, 0x04, 0x08, 0x10, // x^0 .. x^4
    0x20, 0x40, 0x80, 0x1B, 0x36, // x^5 .. x^9 (reduction starts at x^8)
];
/// Expanded AES key schedule, generic over the bit representation `B`.
///
/// Stores one 4x4 byte matrix per round key, in the order they are added to
/// the state: the initial whitening key first, then one key per round.
// `AES` is the cipher's conventional spelling; clippy would otherwise suggest `AesDesc`.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)]
pub struct AESDesc<B: Boolean> {
    /// Round keys in application order (`round_keys[0]` is added before round 1).
    round_keys: Vec<Matrix<Byte<B>>>,
}
impl<B: Boolean> AESDesc<B> {
fn new(key: Matrix<Byte<B>>) -> Self {
if key.nrows != 4 {
panic!("key must have 4 rows (found {})", key.nrows);
}
let key_length = 8 * key.nrows * key.ncols;
let n_rounds = match key_length {
128 => 10,
192 => 12,
256 => 14,
_ => panic!(
"key_length must be one of {{128, 192, 256}} (found {})",
key_length
),
};
let n = key_length / 32;
let mut round_keys_col = (0..n).map(|j| key.col(j)).collect::<Vec<Matrix<Byte<B>>>>();
for i in n..4 * (n_rounds + 1) {
let mut word = round_keys_col.last().unwrap().clone();
if i % n == 0 {
let col = [
*word.get((0, 0)).unwrap(),
*word.get((1, 0)).unwrap(),
*word.get((2, 0)).unwrap(),
*word.get((3, 0)).unwrap(),
];
for r in 0..4 {
let w = word.get_mut((r, 0)).unwrap();
*w = col[(r + 1) % 4];
}
word = Self::sub_bytes(word);
let w = word.get_mut((0, 0)).unwrap();
*w ^= Byte::from(RC[i / n - 1]);
} else if key_length == 256 && i % n == 4 {
word = Self::sub_bytes(word);
}
for r in 0..4 {
let w = word.get_mut((r, 0)).unwrap();
*w ^= *round_keys_col[i - n].get((r, 0)).unwrap();
}
round_keys_col.push(word);
}
let round_keys = round_keys_col
.chunks(4)
.map(|cols| {
Matrix::new_from_column_major_iter(
(4, 4),
cols.iter()
.flat_map(|col| col.into_iter().collect::<Vec<Byte<B>>>()),
)
})
.collect();
Self { round_keys }
}
fn sub_bytes(state: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
fn sub_byte<B: Boolean>(s: Byte<B>) -> Byte<B> {
let bits = s.get_bits();
let u0 = bits[7];
let u1 = bits[6];
let u2 = bits[5];
let u3 = bits[4];
let u4 = bits[3];
let u5 = bits[2];
let u6 = bits[1];
let u7 = bits[0];
let t1 = u0 ^ u3;
let t2 = u0 ^ u5;
let t3 = u0 ^ u6;
let t4 = u3 ^ u5;
let t5 = u4 ^ u6;
let t6 = t1 ^ t5;
let t7 = u1 ^ u2;
let t8 = u7 ^ t6;
let t9 = u7 ^ t7;
let t10 = t6 ^ t7;
let t11 = u1 ^ u5;
let t12 = u2 ^ u5;
let t13 = t3 ^ t4;
let t14 = t6 ^ t11;
let t15 = t5 ^ t11;
let t16 = t5 ^ t12;
let t17 = t9 ^ t16;
let t18 = u3 ^ u7;
let t19 = t7 ^ t18;
let t20 = t1 ^ t19;
let t21 = u6 ^ u7;
let t22 = t7 ^ t21;
let t23 = t2 ^ t22;
let t24 = t2 ^ t10;
let t25 = t20 ^ t17;
let t26 = t3 ^ t16;
let t27 = t1 ^ t12;
let m1 = t13 & t6;
let m2 = t23 & t8;
let m3 = t14 ^ m1;
let m4 = t19 & u7;
let m5 = m4 ^ m1;
let m6 = t3 & t16;
let m7 = t22 & t9;
let m8 = t26 ^ m6;
let m9 = t20 & t17;
let m10 = m9 ^ m6;
let m11 = t1 & t15;
let m12 = t4 & t27;
let m13 = m12 ^ m11;
let m14 = t2 & t10;
let m15 = m14 ^ m11;
let m16 = m3 ^ m2;
let m17 = m5 ^ t24;
let m18 = m8 ^ m7;
let m19 = m10 ^ m15;
let m20 = m16 ^ m13;
let m21 = m17 ^ m15;
let m22 = m18 ^ m13;
let m23 = m19 ^ t25;
let m24 = m22 ^ m23;
let m25 = m22 & m20;
let m26 = m21 ^ m25;
let m27 = m20 ^ m21;
let m28 = m23 ^ m25;
let m29 = m28 & m27;
let m30 = m26 & m24;
let m31 = m20 & m23;
let m32 = m27 & m31;
let m33 = m27 ^ m25;
let m34 = m21 & m22;
let m35 = m24 & m34;
let m36 = m24 ^ m25;
let m37 = m21 ^ m29;
let m38 = m32 ^ m33;
let m39 = m23 ^ m30;
let m40 = m35 ^ m36;
let m41 = m38 ^ m40;
let m42 = m37 ^ m39;
let m43 = m37 ^ m38;
let m44 = m39 ^ m40;
let m45 = m42 ^ m41;
let m46 = m44 & t6;
let m47 = m40 & t8;
let m48 = m39 & u7;
let m49 = m43 & t16;
let m50 = m38 & t9;
let m51 = m37 & t17;
let m52 = m42 & t15;
let m53 = m45 & t27;
let m54 = m41 & t10;
let m55 = m44 & t13;
let m56 = m40 & t23;
let m57 = m39 & t19;
let m58 = m43 & t3;
let m59 = m38 & t22;
let m60 = m37 & t20;
let m61 = m42 & t1;
let m62 = m45 & t4;
let m63 = m41 & t2;
let l0 = m61 ^ m62;
let l1 = m50 ^ m56;
let l2 = m46 ^ m48;
let l3 = m47 ^ m55;
let l4 = m54 ^ m58;
let l5 = m49 ^ m61;
let l6 = m62 ^ l5;
let l7 = m46 ^ l3;
let l8 = m51 ^ m59;
let l9 = m52 ^ m53;
let l10 = m53 ^ l4;
let l11 = m60 ^ l2;
let l12 = m48 ^ m51;
let l13 = m50 ^ l0;
let l14 = m52 ^ m61;
let l15 = m55 ^ l1;
let l16 = m56 ^ l0;
let l17 = m57 ^ l1;
let l18 = m58 ^ l8;
let l19 = m63 ^ l4;
let l20 = l0 ^ l1;
let l21 = l1 ^ l7;
let l22 = l3 ^ l12;
let l23 = l18 ^ l2;
let l24 = l15 ^ l9;
let l25 = l6 ^ l10;
let l26 = l7 ^ l9;
let l27 = l8 ^ l10;
let l28 = l11 ^ l14;
let l29 = l11 ^ l17;
let s0 = l6 ^ l24;
let s1 = l16 ^ l26;
let s2 = l19 ^ l28;
let s3 = l6 ^ l21;
let s4 = l20 ^ l22;
let s5 = l25 ^ l29;
let s6 = l13 ^ l27;
let s7 = l6 ^ l23;
let mut res = [B::from(false); 8];
res[7] = s0;
res[6] = !s1;
res[5] = !s2;
res[4] = s3;
res[3] = s4;
res[2] = s5;
res[1] = !s6;
res[0] = !s7;
Byte::new(res)
}
let mut state = state;
state.map_mut(sub_byte);
state
}
fn shift_rows(state: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
let mut state = state;
for i in 1..4 {
let row = [
*state.get((i, 0)).unwrap(),
*state.get((i, 1)).unwrap(),
*state.get((i, 2)).unwrap(),
*state.get((i, 3)).unwrap(),
];
for j in 0..4 {
let s = state.get_mut((i, j)).unwrap();
*s = row[(i + j) % 4];
}
}
state
}
fn mix_columns(state: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
fn mix_column<B: Boolean>(col: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
let a_0 = *col.get((0, 0)).unwrap();
let a_1 = *col.get((1, 0)).unwrap();
let a_2 = *col.get((2, 0)).unwrap();
let a_3 = *col.get((3, 0)).unwrap();
fn mul_x<B: Boolean>(a: Byte<B>) -> Byte<B> {
let mut bits = a.to_vec();
bits.insert(0, B::from(false));
let c_8 = bits.pop().unwrap();
bits[4] ^= c_8;
bits[3] ^= c_8;
bits[1] ^= c_8;
bits[0] ^= c_8;
Byte::new(bits.try_into().unwrap_or_else(|v: Vec<B>| {
panic!("Expected a Vec of length 8 (found {})", v.len())
}))
}
let a_0_xor_a_1 = a_0 ^ a_1;
let a_1_xor_a_2 = a_1 ^ a_2;
let a_2_xor_a_3 = a_2 ^ a_3;
let a_0_xor_a_3 = a_0 ^ a_3;
Matrix::new_from_iter(
(4, 1),
vec![
mul_x(a_0_xor_a_1) ^ a_1_xor_a_2 ^ a_3,
mul_x(a_1_xor_a_2) ^ a_0 ^ a_2_xor_a_3,
mul_x(a_2_xor_a_3) ^ a_0_xor_a_1 ^ a_3,
mul_x(a_0_xor_a_3) ^ a_0_xor_a_1 ^ a_2,
]
.into_iter(),
)
}
let cols = (0..4).map(|j| state.col(j));
let mixed_cols = cols.flat_map(|col| mix_column(col).into_iter().collect::<Vec<Byte<B>>>());
Matrix::new_from_column_major_iter((4, 4), mixed_cols)
}
fn add_round_key(state: Matrix<Byte<B>>, key: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
let mut state = state;
for i in 0..4 {
for j in 0..4 {
let s = state.get_mut((i, j)).unwrap();
*s ^= *key.get((i, j)).unwrap();
}
}
state
}
pub fn encrypt_block(&self, block: Matrix<Byte<B>>) -> Matrix<Byte<B>> {
if block.nrows != 4 || block.ncols != 4 {
panic!(
"block must be a 4x4 matrix (found {}x{})",
block.nrows, block.ncols
);
}
let initial_key = self.round_keys[0].clone();
let mut round_keys = self.round_keys[1..].to_vec();
let last_key = round_keys.pop().unwrap();
let mut state = round_keys.iter().fold(
Self::add_round_key(block.clone(), initial_key),
|mut state, key| {
state = Self::sub_bytes(state);
state = Self::shift_rows(state);
state = Self::mix_columns(state);
Self::add_round_key(state, key.clone())
},
);
state = Self::sub_bytes(state);
state = Self::shift_rows(state);
Self::add_round_key(state, last_key)
}
}
macro_rules! impl_aes {
($t: ident, $key: ident, $key_len: expr) => {
#[derive(Clone, Debug)]
pub struct $t<B: Boolean> {
desc: AESDesc<B>,
}
impl<B: Boolean> $t<B> {
pub fn new(key: $key<B>) -> Self {
Self {
desc: AESDesc::new(Matrix::new_from_column_major_iter(
(4, $key_len / 4),
key.inner().into_iter(),
)),
}
}
pub fn encrypt_block(&self, block: [Byte<B>; 16]) -> [Byte<B>; 16] {
let ciphertext = self.desc.encrypt_block(Matrix::new_from_column_major_iter(
(4, 4),
block.into_iter(),
));
(0..4)
.flat_map(|j| ciphertext.col(j).into_iter().collect::<Vec<Byte<B>>>())
.collect::<Vec<Byte<B>>>()
.try_into()
.unwrap_or_else(|v: Vec<Byte<B>>| {
panic!("Expected a Vec of length 16 (found {})", v.len())
})
}
}
};
}
impl_aes!(AES128, AES128Key, 16);
impl_aes!(AES192, AES192Key, 24);
impl_aes!(AES256, AES256Key, 32);
impl BooleanCircuit for AES128<BooleanValue> {
    /// Reference evaluation on plain booleans via the `aes` crate.
    ///
    /// `x` is the 128 key bits followed by the 128 plaintext bits, eight bits per
    /// byte. The result is the 128 ciphertext bits in the same layout. The key
    /// comes from `x`; `self` is not consulted.
    ///
    /// # Panics
    /// Panics if `x.len() != 256`.
    fn eval(&self, x: Vec<bool>) -> Result<Vec<bool>, EvalFailure> {
        // Packs consecutive 8-bit chunks into `N` bytes through `Byte<bool>`, so the
        // bit order matches the one used by `run`.
        fn pack<const N: usize>(bits: &[bool]) -> [u8; N] {
            bits.chunks(8)
                .map(|chunk| {
                    u8::from(Byte::new(chunk.to_vec().try_into().unwrap_or_else(
                        |v: Vec<bool>| panic!("Expected a Vec of length 8 (found {})", v.len()),
                    )))
                })
                .collect::<Vec<u8>>()
                .try_into()
                .unwrap_or_else(|v: Vec<u8>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }
        if x.len() != 256 {
            panic!("AES128 expects input Vec of length 256");
        }
        let (key_bits, block_bits) = x.split_at(128);
        let key = pack::<16>(key_bits).into();
        let mut block = pack::<16>(block_bits).into();
        let cipher = Aes128::new(&key);
        cipher.encrypt_block(&mut block);
        let ciphertext_bool = block
            .into_iter()
            .flat_map(|byte| Byte::<bool>::from(byte).to_vec())
            .collect();
        Ok(ciphertext_bool)
    }

    /// Evaluates the AES-128 circuit on `BooleanValue` bits.
    ///
    /// The input layout is the same as `eval`: 128 key bits, then 128 plaintext
    /// bits. The cipher is built from the key bits in `vals`, not from `self`.
    ///
    /// # Panics
    /// Panics if `vals.len() != 256`.
    fn run(&self, vals: Vec<BooleanValue>) -> Vec<BooleanValue> {
        // Groups consecutive 8-bit chunks of `bits` into `N` circuit bytes.
        fn pack<const N: usize>(bits: &[BooleanValue]) -> [Byte<BooleanValue>; N] {
            bits.chunks(8)
                .map(|chunk| {
                    Byte::new(chunk.to_vec().try_into().unwrap_or_else(
                        |v: Vec<BooleanValue>| {
                            panic!("Expected a Vec of length 8 (found {})", v.len())
                        },
                    ))
                })
                .collect::<Vec<Byte<BooleanValue>>>()
                .try_into()
                .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }
        if vals.len() != 256 {
            panic!("AES128 expects input Vec of length 256");
        }
        let (key_bits, block_bits) = vals.split_at(128);
        let key = pack::<16>(key_bits);
        let block = pack::<16>(block_bits);
        let cipher = AES128::new(AES128Key::new_from_inner(key));
        cipher
            .encrypt_block(block)
            .into_iter()
            .flat_map(|byte| byte.get_bits())
            .collect::<Vec<BooleanValue>>()
    }
}
impl BooleanCircuit for AES192<BooleanValue> {
    /// Reference evaluation on plain booleans via the `aes` crate.
    ///
    /// `x` is the 192 key bits followed by the 128 plaintext bits, eight bits per
    /// byte. The result is the 128 ciphertext bits in the same layout. The key
    /// comes from `x`; `self` is not consulted.
    ///
    /// # Panics
    /// Panics if `x.len() != 320`.
    fn eval(&self, x: Vec<bool>) -> Result<Vec<bool>, EvalFailure> {
        // Packs consecutive 8-bit chunks into `N` bytes through `Byte<bool>`, so the
        // bit order matches the one used by `run`.
        fn pack<const N: usize>(bits: &[bool]) -> [u8; N] {
            bits.chunks(8)
                .map(|chunk| {
                    u8::from(Byte::new(chunk.to_vec().try_into().unwrap_or_else(
                        |v: Vec<bool>| panic!("Expected a Vec of length 8 (found {})", v.len()),
                    )))
                })
                .collect::<Vec<u8>>()
                .try_into()
                .unwrap_or_else(|v: Vec<u8>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }
        if x.len() != 320 {
            panic!("AES192 expects input Vec of length 320");
        }
        let (key_bits, block_bits) = x.split_at(192);
        let key = pack::<24>(key_bits).into();
        let mut block = pack::<16>(block_bits).into();
        let cipher = Aes192::new(&key);
        cipher.encrypt_block(&mut block);
        let ciphertext_bool = block
            .into_iter()
            .flat_map(|byte| Byte::<bool>::from(byte).to_vec())
            .collect();
        Ok(ciphertext_bool)
    }

    /// Evaluates the AES-192 circuit on `BooleanValue` bits.
    ///
    /// The input layout is the same as `eval`: 192 key bits, then 128 plaintext
    /// bits. The cipher is built from the key bits in `vals`, not from `self`.
    ///
    /// # Panics
    /// Panics if `vals.len() != 320`.
    fn run(&self, vals: Vec<BooleanValue>) -> Vec<BooleanValue> {
        // Groups consecutive 8-bit chunks of `bits` into `N` circuit bytes.
        fn pack<const N: usize>(bits: &[BooleanValue]) -> [Byte<BooleanValue>; N] {
            bits.chunks(8)
                .map(|chunk| {
                    Byte::new(chunk.to_vec().try_into().unwrap_or_else(
                        |v: Vec<BooleanValue>| {
                            panic!("Expected a Vec of length 8 (found {})", v.len())
                        },
                    ))
                })
                .collect::<Vec<Byte<BooleanValue>>>()
                .try_into()
                .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }
        if vals.len() != 320 {
            panic!("AES192 expects input Vec of length 320");
        }
        let (key_bits, block_bits) = vals.split_at(192);
        let key = pack::<24>(key_bits);
        let block = pack::<16>(block_bits);
        let cipher = AES192::new(AES192Key::new_from_inner(key));
        cipher
            .encrypt_block(block)
            .into_iter()
            .flat_map(|byte| byte.get_bits())
            .collect::<Vec<BooleanValue>>()
    }
}
impl BooleanCircuit for AES256<BooleanValue> {
    /// Reference evaluation on plain booleans via the `aes` crate.
    ///
    /// `x` is the 256 key bits followed by the 128 plaintext bits, eight bits per
    /// byte. The result is the 128 ciphertext bits in the same layout. The key
    /// comes from `x`; `self` is not consulted.
    ///
    /// # Panics
    /// Panics if `x.len() != 384`.
    fn eval(&self, x: Vec<bool>) -> Result<Vec<bool>, EvalFailure> {
        // Packs consecutive 8-bit chunks into `N` bytes through `Byte<bool>`, so the
        // bit order matches the one used by `run`.
        fn pack<const N: usize>(bits: &[bool]) -> [u8; N] {
            bits.chunks(8)
                .map(|chunk| {
                    u8::from(Byte::new(chunk.to_vec().try_into().unwrap_or_else(
                        |v: Vec<bool>| panic!("Expected a Vec of length 8 (found {})", v.len()),
                    )))
                })
                .collect::<Vec<u8>>()
                .try_into()
                .unwrap_or_else(|v: Vec<u8>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }
        if x.len() != 384 {
            panic!("AES256 expects input Vec of length 384");
        }
        let (key_bits, block_bits) = x.split_at(256);
        let key = pack::<32>(key_bits).into();
        let mut block = pack::<16>(block_bits).into();
        let cipher = Aes256::new(&key);
        cipher.encrypt_block(&mut block);
        let ciphertext_bool = block
            .into_iter()
            .flat_map(|byte| Byte::<bool>::from(byte).to_vec())
            .collect();
        Ok(ciphertext_bool)
    }

    /// Evaluates the AES-256 circuit on `BooleanValue` bits.
    ///
    /// The input layout is the same as `eval`: 256 key bits, then 128 plaintext
    /// bits. The cipher is built from the key bits in `vals`, not from `self`.
    ///
    /// # Panics
    /// Panics if `vals.len() != 384`.
    fn run(&self, vals: Vec<BooleanValue>) -> Vec<BooleanValue> {
        // Groups consecutive 8-bit chunks of `bits` into `N` circuit bytes.
        fn pack<const N: usize>(bits: &[BooleanValue]) -> [Byte<BooleanValue>; N] {
            bits.chunks(8)
                .map(|chunk| {
                    Byte::new(chunk.to_vec().try_into().unwrap_or_else(
                        |v: Vec<BooleanValue>| {
                            panic!("Expected a Vec of length 8 (found {})", v.len())
                        },
                    ))
                })
                .collect::<Vec<Byte<BooleanValue>>>()
                .try_into()
                .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                    panic!("Expected a Vec of length {} (found {})", N, v.len())
                })
        }
        if vals.len() != 384 {
            panic!("AES256 expects input Vec of length 384");
        }
        let (key_bits, block_bits) = vals.split_at(256);
        let key = pack::<32>(key_bits);
        let block = pack::<16>(block_bits);
        let cipher = AES256::new(AES256Key::new_from_inner(key));
        cipher
            .encrypt_block(block)
            .into_iter()
            .flat_map(|byte| byte.get_bits())
            .collect::<Vec<BooleanValue>>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::circuits::traits::boolean_circuit::tests::TestedBooleanCircuit;
    use rand::Rng;

    // Implements `TestedBooleanCircuit` for one AES wrapper type.
    // `eval` and `run` build their cipher from the input key bits rather than from
    // `self`, so a description with an empty key schedule is sufficient.
    macro_rules! impl_tested_aes {
        ($t: ident, $n_inputs: expr) => {
            impl TestedBooleanCircuit for $t<BooleanValue> {
                fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
                    let desc = AESDesc {
                        round_keys: Vec::new(),
                    };
                    Self { desc }
                }

                fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
                    $n_inputs
                }
            }
        };
    }

    // Input width = key bits followed by 128 plaintext bits.
    impl_tested_aes!(AES128, 256);
    impl_tested_aes!(AES192, 320);
    impl_tested_aes!(AES256, 384);

    #[test]
    fn tested_aes128() {
        AES128::test(1, 1)
    }

    #[test]
    fn tested_aes192() {
        AES192::test(1, 1)
    }

    #[test]
    fn tested_aes256() {
        AES256::test(1, 1)
    }
}