use aes::{
cipher::{generic_array::GenericArray, BlockDecryptMut, BlockEncryptMut, IvState, KeyIvInit},
Aes256,
};
use cbc::{Decryptor, Encryptor};
use alloc::vec::Vec;
/// Cipher state held by [`Aes`], one variant per direction.
///
/// Each variant wraps a `cbc` mode object over AES-256; the variant tells `Aes` which of
/// `encrypt`/`decrypt` is allowed.
enum AesKind {
    /// AES-256 CBC-mode encryption state.
    Encryptor { aes: Encryptor<Aes256> },
    /// AES-256 CBC-mode decryption state.
    Decryptor { aes: Decryptor<Aes256> },
}
/// AES-256 in CBC mode, set up for exactly one direction: encryption or decryption.
///
/// Construct with `Aes::new_encryptor` or `Aes::new_decryptor`. Data passed to
/// `encrypt`/`decrypt` must be a multiple of 16 bytes; no padding is added or removed.
pub struct Aes {
    // Cipher state for the direction chosen by the constructor.
    kind: AesKind,
}
impl Aes {
    /// Creates an instance that encrypts with AES-256 in CBC mode.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not exactly 32 bytes or `iv` is not exactly 16 bytes.
    pub fn new_encryptor(key: &[u8], iv: &[u8]) -> Self {
        let (key, iv) = Self::key_and_iv(key, iv);
        let aes = Encryptor::<Aes256>::new(&key.into(), &iv.into());
        Self {
            kind: AesKind::Encryptor { aes },
        }
    }

    /// Creates an instance that decrypts with AES-256 in CBC mode.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not exactly 32 bytes or `iv` is not exactly 16 bytes.
    pub fn new_decryptor(key: &[u8], iv: &[u8]) -> Self {
        let (key, iv) = Self::key_and_iv(key, iv);
        let aes = Decryptor::<Aes256>::new(&key.into(), &iv.into());
        Self {
            kind: AesKind::Decryptor { aes },
        }
    }

    /// Copies `key` and `iv` into the fixed-size arrays AES-256-CBC needs.
    ///
    /// The key is checked before the IV, so a caller that gets both wrong sees the key error.
    fn key_and_iv(key: &[u8], iv: &[u8]) -> ([u8; 32], [u8; 16]) {
        let key: [u8; 32] = key.try_into().expect("valid aes key");
        let iv: [u8; 16] = iv.try_into().expect("valid aes iv");
        (key, iv)
    }

    /// Encrypts `plaintext` and returns ciphertext of the same length.
    ///
    /// No padding is applied, so the input must already be block-aligned; empty input yields
    /// empty output. The underlying CBC state is kept in `self` and mutated here, so successive
    /// calls continue the same chain rather than restarting from the original IV.
    ///
    /// # Panics
    ///
    /// Panics if the input length is not a multiple of 16, or if this instance was built
    /// with [`Aes::new_decryptor`].
    pub fn encrypt<T: AsRef<[u8]>>(&mut self, plaintext: T) -> Vec<u8> {
        let plaintext = plaintext.as_ref();
        assert!(plaintext.len() % 16 == 0, "invalid plaintext");
        let AesKind::Encryptor { aes } = &mut self.kind else {
            panic!("tried to call `encrypt()` for an aes decryptor");
        };
        // `chunks_exact` never yields a short chunk here because of the assert above.
        let mut blocks = plaintext
            .chunks_exact(16)
            .map(|chunk| GenericArray::from(<[u8; 16]>::try_from(chunk).expect("to succeed")))
            .collect::<Vec<_>>();
        aes.encrypt_blocks_mut(&mut blocks);
        // One allocation for the whole output instead of a temporary Vec per block.
        let mut ciphertext: Vec<u8> = Vec::with_capacity(plaintext.len());
        for block in &blocks {
            ciphertext.extend_from_slice(block.as_slice());
        }
        ciphertext
    }

    /// Decrypts `ciphertext` and returns plaintext of the same length.
    ///
    /// No padding is removed; empty input yields empty output. As with [`Aes::encrypt`], the
    /// CBC state lives in `self`, so successive calls continue the same chain.
    ///
    /// # Panics
    ///
    /// Panics if the input length is not a multiple of 16, or if this instance was built
    /// with [`Aes::new_encryptor`].
    pub fn decrypt<T: AsRef<[u8]>>(&mut self, ciphertext: T) -> Vec<u8> {
        let ciphertext = ciphertext.as_ref();
        assert!(ciphertext.len() % 16 == 0, "invalid ciphertext");
        let AesKind::Decryptor { aes } = &mut self.kind else {
            panic!("tried to call `decrypt()` for an aes encryptor");
        };
        // `chunks_exact` never yields a short chunk here because of the assert above.
        let mut blocks = ciphertext
            .chunks_exact(16)
            .map(|chunk| GenericArray::from(<[u8; 16]>::try_from(chunk).expect("to succeed")))
            .collect::<Vec<_>>();
        // Decrypt all blocks in one call rather than one block at a time.
        aes.decrypt_blocks_mut(&mut blocks);
        let mut plaintext: Vec<u8> = Vec::with_capacity(ciphertext.len());
        for block in &blocks {
            plaintext.extend_from_slice(block.as_slice());
        }
        plaintext
    }

    /// Returns the cipher's current IV state, as reported by `IvState::iv_state`.
    ///
    /// Because `encrypt`/`decrypt` advance the CBC chain, this value changes as data is
    /// processed. It can be used to resume the chain with a fresh instance.
    pub fn iv(&self) -> [u8; 16] {
        match &self.kind {
            AesKind::Encryptor { aes } => aes.iv_state().into(),
            AesKind::Decryptor { aes } => aes.iv_state().into(),
        }
    }
}