use aes::Aes128;
use cbc::cipher::block_padding::Pkcs7;
use cbc::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit};
use crate::crypto::CryptoError;
use crate::io::mbuf::{Mbuf, MbufPool, MbufQueue};
/// AES-128 in CBC mode, decryption direction. The padding scheme is selected per call
/// (this module always passes `Pkcs7`).
type Aes128CbcDec = cbc::Decryptor<Aes128>;
/// AES-128 in CBC mode, encryption direction. The padding scheme is selected per call
/// (this module always passes `Pkcs7`).
type Aes128CbcEnc = cbc::Encryptor<Aes128>;
/// Size in bytes of one AES block; CBC ciphertext is always a whole number of these.
pub const AES_BLOCK_SIZE: usize = 16;
/// Number of leading key bytes this module feeds to AES-128 (used as both key and IV).
pub const AES_128_KEY_LEN: usize = 16;
/// Length in bytes of the key buffer accepted by this module's public functions.
///
/// Only the first [`AES_128_KEY_LEN`] bytes are used by the cipher here.
/// NOTE(review): the purpose of the remaining bytes is not visible in this module.
pub const AES_KEYLEN: usize = 32;
fn key_iv(aes_key: &[u8; AES_KEYLEN]) -> &[u8; AES_128_KEY_LEN] {
aes_key
.first_chunk::<AES_128_KEY_LEN>()
.expect("AES_KEYLEN >= AES_128_KEY_LEN by construction")
}
/// Encrypts `msg` with AES-128-CBC and PKCS#7 padding and returns the ciphertext.
///
/// The leading [`AES_128_KEY_LEN`] bytes of `aes_key` serve as *both* the cipher key
/// and the CBC initialization vector. PKCS#7 always appends at least one padding byte,
/// so the output is a non-empty multiple of [`AES_BLOCK_SIZE`]. For example, an empty
/// `msg` yields exactly one block.
///
/// # Security
///
/// Using the key as the IV makes encryption deterministic: equal plaintexts give equal
/// ciphertexts. It is a known-weak construction. NOTE(review): presumably retained for
/// compatibility with an existing peer; confirm before using it to protect new data.
///
/// # Errors
///
/// The current implementation never returns `Err`. The `Result` keeps the signature
/// symmetric with [`decrypt_to_vec`].
pub fn encrypt_to_vec(msg: &[u8], aes_key: &[u8; AES_KEYLEN]) -> Result<Vec<u8>, CryptoError> {
    let key_and_iv = key_iv(aes_key);
    let encryptor = Aes128CbcEnc::new(key_and_iv.into(), key_and_iv.into());
    let ciphertext = encryptor.encrypt_padded_vec_mut::<Pkcs7>(msg);
    Ok(ciphertext)
}
/// Decrypts AES-128-CBC / PKCS#7 ciphertext of the kind produced by [`encrypt_to_vec`].
///
/// As in the encryptor, the leading [`AES_128_KEY_LEN`] bytes of `aes_key` are used as
/// both the key and the IV.
///
/// # Errors
///
/// * [`CryptoError::DecryptionFailed`] if `enc` is empty or its length is not a multiple
///   of [`AES_BLOCK_SIZE`]. This is checked before any decryption happens.
/// * [`CryptoError::BadPadding`] if the decrypted bytes do not end in valid PKCS#7
///   padding, e.g. because of a wrong key or corrupted ciphertext.
///
/// # Security
///
/// The two error variants are distinguishable. If they reach a remote party, they can
/// act as a padding oracle. NOTE(review): callers handling attacker-supplied ciphertext
/// should collapse them into a single error before surfacing it.
pub fn decrypt_to_vec(enc: &[u8], aes_key: &[u8; AES_KEYLEN]) -> Result<Vec<u8>, CryptoError> {
    let whole_blocks = !enc.is_empty() && enc.len().is_multiple_of(AES_BLOCK_SIZE);
    if !whole_blocks {
        return Err(CryptoError::DecryptionFailed);
    }
    let key_and_iv = key_iv(aes_key);
    Aes128CbcDec::new(key_and_iv.into(), key_and_iv.into())
        .decrypt_padded_vec_mut::<Pkcs7>(enc)
        .map_err(|_unpad_error| CryptoError::BadPadding)
}
/// Copies `bytes` into a new [`MbufQueue`] whose buffers are drawn from `pool`.
///
/// If a freshly obtained mbuf accepts zero bytes, the loop could never make progress.
/// In that case every mbuf taken so far (including the empty one) is handed back to
/// `pool` via [`MbufQueue::recycle`], and `on_stall` is returned.
fn copy_into_chain(
    bytes: &[u8],
    pool: &MbufPool,
    on_stall: CryptoError,
) -> Result<MbufQueue, CryptoError> {
    let mut queue = MbufQueue::new();
    let mut remaining = bytes;
    while !remaining.is_empty() {
        let mut buf: Mbuf = pool.get();
        let n = buf.recv(remaining);
        debug_assert!(n > 0, "fresh mbuf cannot accept any bytes");
        if n == 0 {
            // Return the buffers to the pool instead of silently dropping them on the
            // error path.
            queue.push_back(buf);
            queue.recycle(pool);
            return Err(on_stall);
        }
        remaining = &remaining[n..];
        queue.push_back(buf);
    }
    Ok(queue)
}

/// Encrypts `msg` exactly like [`encrypt_to_vec`] and returns the ciphertext split
/// across mbufs taken from `pool`.
///
/// # Errors
///
/// [`CryptoError::EncryptionFailed`] if a fresh mbuf from `pool` accepts no bytes. In
/// debug builds a `debug_assert!` fires first. Any mbufs already taken are returned to
/// `pool` before the error is reported.
pub fn encrypt_to_chain(
    msg: &[u8],
    aes_key: &[u8; AES_KEYLEN],
    pool: &MbufPool,
) -> Result<MbufQueue, CryptoError> {
    let cipher_bytes = encrypt_to_vec(msg, aes_key)?;
    copy_into_chain(&cipher_bytes, pool, CryptoError::EncryptionFailed)
}

/// Drains every mbuf from `enc`, decrypts the concatenated bytes exactly like
/// [`decrypt_to_vec`], and returns the plaintext split across mbufs taken from `pool`.
///
/// `enc` is drained *before* decryption is attempted, so it is left empty on both
/// success and failure. The drained input mbufs are dropped, not recycled.
///
/// # Errors
///
/// * Any error from [`decrypt_to_vec`].
/// * [`CryptoError::DecryptionFailed`] if a fresh mbuf from `pool` accepts no bytes.
///   Any output mbufs already taken are returned to `pool`.
pub fn decrypt_chain_to_chain(
    enc: &mut MbufQueue,
    aes_key: &[u8; AES_KEYLEN],
    pool: &MbufPool,
) -> Result<MbufQueue, CryptoError> {
    let mut bytes = Vec::with_capacity(enc.total_len());
    while let Some(buf) = enc.pop_front() {
        bytes.extend_from_slice(buf.readable());
    }
    let plain = decrypt_to_vec(&bytes, aes_key)?;
    copy_into_chain(&plain, pool, CryptoError::DecryptionFailed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::Crypto;

    /// The empty message still encrypts to one full padding block and round-trips.
    #[test]
    fn empty_plaintext_round_trips() {
        let key = Crypto::generate_aes_key().unwrap();
        let cipher = encrypt_to_vec(b"", &key).unwrap();
        assert_eq!(cipher.len(), AES_BLOCK_SIZE);
        let plain = decrypt_to_vec(&cipher, &key).unwrap();
        assert!(plain.is_empty());
    }

    /// A block-aligned message gains a whole extra block of PKCS#7 padding.
    #[test]
    fn block_aligned_plaintext_pads_full_block() {
        let key = Crypto::generate_aes_key().unwrap();
        let msg = vec![0xab; AES_BLOCK_SIZE];
        let cipher = encrypt_to_vec(&msg, &key).unwrap();
        assert_eq!(cipher.len(), 2 * AES_BLOCK_SIZE);
        let plain = decrypt_to_vec(&cipher, &key).unwrap();
        assert_eq!(plain, msg);
    }

    #[test]
    fn encryption_is_deterministic() {
        let key = Crypto::generate_aes_key().unwrap();
        let a = encrypt_to_vec(b"same", &key).unwrap();
        let b = encrypt_to_vec(b"same", &key).unwrap();
        assert_eq!(a, b, "key-as-IV makes the cipher deterministic");
    }

    /// Pins the exact ciphertext for a fixed key, guarding the wire format.
    #[test]
    fn known_vector_pin() {
        let key: [u8; AES_KEYLEN] = [
            0x10, 0x32, 0x54, 0x76, 0x98, 0xba, 0xdc, 0xfe, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab,
            0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10, 0xef, 0xcd, 0xab, 0x89,
            0x67, 0x45, 0x23, 0x01,
        ];
        let plaintext = b"";
        let cipher = encrypt_to_vec(plaintext, &key).unwrap();
        assert_eq!(cipher.len(), AES_BLOCK_SIZE);
        let expected: [u8; AES_BLOCK_SIZE] = [
            0x98, 0xe1, 0x44, 0x32, 0xf6, 0x65, 0x78, 0xb9, 0x45, 0xd6, 0x4f, 0xc4, 0x60, 0x27,
            0x1b, 0xab,
        ];
        assert_eq!(cipher, expected);
        let round = decrypt_to_vec(&cipher, &key).unwrap();
        assert_eq!(round.as_slice(), plaintext);
    }

    /// A length that is not a whole number of blocks must be rejected by the length
    /// guard, not by the padding check.
    #[test]
    fn truncated_ciphertext_is_rejected() {
        let key = Crypto::generate_aes_key().unwrap();
        let cipher = encrypt_to_vec(b"abc", &key).unwrap();
        let truncated = &cipher[..cipher.len() - 1];
        assert!(matches!(
            decrypt_to_vec(truncated, &key),
            Err(CryptoError::DecryptionFailed)
        ));
    }

    #[test]
    fn empty_ciphertext_is_rejected() {
        let key = Crypto::generate_aes_key().unwrap();
        assert!(matches!(
            decrypt_to_vec(&[], &key),
            Err(CryptoError::DecryptionFailed)
        ));
    }

    /// Decrypting well-formed ciphertext with the wrong key passes the length guard.
    /// It should usually fail only at the PKCS#7 padding check. A random wrong key
    /// occasionally yields valid-looking padding, hence the retries.
    #[test]
    fn wrong_key_fails_padding_check() {
        const TRIALS: usize = 32;
        let mut observed_rejection = false;
        for _ in 0..TRIALS {
            let key_a = Crypto::generate_aes_key().unwrap();
            let key_b = Crypto::generate_aes_key().unwrap();
            let cipher = encrypt_to_vec(b"abc", &key_a).unwrap();
            match decrypt_to_vec(&cipher, &key_b) {
                Ok(_) => {}
                Err(CryptoError::BadPadding) => {
                    observed_rejection = true;
                    break;
                }
                Err(other) => panic!(
                    "wrong-key decryption of block-aligned ciphertext must fail with \
                     BadPadding, got {other:?}"
                ),
            }
        }
        assert!(
            observed_rejection,
            "expected at least one wrong-key decryption in {TRIALS} trials to fail PKCS#7 padding"
        );
    }

    /// Round-trips through the mbuf-chain API; the input chain must be drained.
    #[test]
    fn chain_round_trip() {
        let pool = MbufPool::default();
        let key = Crypto::generate_aes_key().unwrap();
        let mut chain = encrypt_to_chain(b"hello world", &key, &pool).unwrap();
        let mut plain = decrypt_chain_to_chain(&mut chain, &key, &pool).unwrap();
        assert!(chain.is_empty());
        let bytes: Vec<u8> = plain.iter().flat_map(|m| m.readable().to_vec()).collect();
        assert_eq!(bytes, b"hello world");
        plain.recycle(&pool);
    }
}