use super::IndCpaCipher;
use aes::{
cipher::{consts::U16, generic_array::GenericArray, FromBlockCipher, StreamCipher},
NewBlockCipher,
};
use tink_core::{utils::wrap_err, TinkError};
/// Smallest IV length, in bytes (96 bits), accepted by `AesCtr::new`.
pub const AES_CTR_MIN_IV_SIZE: usize = 96 / 8;
/// AES block length in bytes (128 bits); also the largest IV length accepted by `AesCtr::new`.
pub const AES_BLOCK_SIZE_IN_BYTES: usize = 128 / 8;
// The variants hold differently-sized block-cipher states (presumably the AES-256
// key schedule is the larger one), which trips `large_enum_variant`; the lint is
// silenced rather than boxing the variant, avoiding a heap allocation per cipher.
// NOTE(review): confirm the size trade-off is still intended.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
/// The keyed AES block cipher backing an [`AesCtr`], selected by key length.
enum AesCtrVariant {
    /// Keyed with a 16-byte key.
    Aes128(aes::Aes128),
    /// Keyed with a 32-byte key.
    Aes256(aes::Aes256),
}
#[derive(Clone)]
/// AES in counter mode (AES-CTR) as an IND-CPA cipher.
///
/// Each ciphertext is the IV prefix followed by the plaintext XORed with the
/// AES-CTR keystream; no authentication tag is added.
pub struct AesCtr {
    /// The keyed AES-128 or AES-256 block cipher.
    key: AesCtrVariant,
    /// Number of IV bytes prefixed to every ciphertext
    /// (between `AES_CTR_MIN_IV_SIZE` and `AES_BLOCK_SIZE_IN_BYTES` inclusive).
    pub iv_size: usize,
}
impl AesCtr {
    /// Creates an AES-CTR cipher from raw key bytes.
    ///
    /// `key` selects AES-128 (16 bytes) or AES-256 (32 bytes). `iv_size` is the
    /// number of IV bytes prefixed to each ciphertext. It must lie in
    /// `AES_CTR_MIN_IV_SIZE..=AES_BLOCK_SIZE_IN_BYTES`.
    ///
    /// # Errors
    ///
    /// Returns an error if the key length is not accepted or `iv_size` is out of range.
    pub fn new(key: &[u8], iv_size: usize) -> Result<AesCtr, TinkError> {
        let key_size = key.len();
        super::validate_aes_key_size(key_size).map_err(|e| wrap_err("AesCtr", e))?;
        if !(AES_CTR_MIN_IV_SIZE..=AES_BLOCK_SIZE_IN_BYTES).contains(&iv_size) {
            return Err(format!("AesCtr: invalid IV size: {}", iv_size).into());
        }
        // Single source for the key-size error so every failure path reports the same message;
        // the closure defers formatting to the error path.
        let invalid_key_size = || -> TinkError {
            format!("AesCtr: invalid AES key size {} (want 16, 32)", key_size).into()
        };
        // `new_from_slice` only fails on a length mismatch, which the match arms already
        // rule out; mapping to an error instead of `unwrap` keeps the constructor panic-free.
        let key = match key_size {
            16 => AesCtrVariant::Aes128(
                aes::Aes128::new_from_slice(key).map_err(|_| invalid_key_size())?,
            ),
            32 => AesCtrVariant::Aes256(
                aes::Aes256::new_from_slice(key).map_err(|_| invalid_key_size())?,
            ),
            _ => return Err(invalid_key_size()),
        };
        Ok(AesCtr { key, iv_size })
    }

    /// Returns the AES key length in bytes: 16 for AES-128, 32 for AES-256.
    pub fn key_len(&self) -> usize {
        match &self.key {
            AesCtrVariant::Aes128(_) => 16,
            AesCtrVariant::Aes256(_) => 32,
        }
    }

    /// Builds a fresh 16-byte initial counter block.
    ///
    /// The leading bytes are filled from `get_random_bytes(self.iv_size)`; any
    /// remaining bytes of the block stay zero.
    fn new_iv(&self) -> GenericArray<u8, U16> {
        let mut padded_iv = [0; AES_BLOCK_SIZE_IN_BYTES];
        let iv = tink_core::subtle::random::get_random_bytes(self.iv_size);
        padded_iv[..iv.len()].copy_from_slice(&iv);
        padded_iv.into()
    }
}
impl IndCpaCipher for AesCtr {
fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, TinkError> {
if plaintext.len() > ((isize::MAX as usize) - self.iv_size) {
return Err("AesCtr: plaintext too long".into());
}
let iv = self.new_iv();
let mut ciphertext = Vec::with_capacity(self.iv_size + plaintext.len());
ciphertext.extend_from_slice(&iv[..self.iv_size]);
ciphertext.extend_from_slice(plaintext);
match &self.key {
AesCtrVariant::Aes128(key) => {
let mut stream = aes::Aes128Ctr::from_block_cipher(key.clone(), &iv);
stream.apply_keystream(&mut ciphertext[self.iv_size..]);
}
AesCtrVariant::Aes256(key) => {
let mut stream = aes::Aes256Ctr::from_block_cipher(key.clone(), &iv);
stream.apply_keystream(&mut ciphertext[self.iv_size..]);
}
}
Ok(ciphertext)
}
fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, TinkError> {
if ciphertext.len() < self.iv_size {
return Err("AesCtr: ciphertext too short".into());
}
let mut padded_iv = [0; AES_BLOCK_SIZE_IN_BYTES];
padded_iv[..self.iv_size].copy_from_slice(&ciphertext[..self.iv_size]);
let mut plaintext = Vec::with_capacity(ciphertext.len() - self.iv_size);
plaintext.extend_from_slice(&ciphertext[self.iv_size..]);
match &self.key {
AesCtrVariant::Aes128(key) => {
let mut stream = aes::Aes128Ctr::from_block_cipher(key.clone(), &padded_iv.into());
stream.apply_keystream(&mut plaintext);
}
AesCtrVariant::Aes256(key) => {
let mut stream = aes::Aes256Ctr::from_block_cipher(key.clone(), &padded_iv.into());
stream.apply_keystream(&mut plaintext);
}
}
Ok(plaintext)
}
}