#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use crate::primitives::aead::{
AES_GCM_128_KEY_LEN, AES_GCM_256_KEY_LEN, AeadCipher, AeadError, Nonce, TAG_LEN, Tag,
};
use aws_lc_rs::aead::{AES_128_GCM, AES_256_GCM, Aad, LessSafeKey, Nonce as AwsNonce, UnboundKey};
use rand::RngCore;
use rand::rngs::OsRng;
use tracing::instrument;
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
use crate::primitives::resource_limits::{validate_decryption_size, validate_encryption_size};
macro_rules! impl_aes_gcm {
(
$(#[$meta:meta])*
$name:ident, $key_len:expr, $algorithm:expr, $label:literal
) => {
$(#[$meta])*
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct $name {
key_bytes: [u8; $key_len],
}
impl AeadCipher for $name {
const KEY_LEN: usize = $key_len;
#[instrument(level = "debug", skip(key), fields(key_len = key.len()))]
fn new(key: &[u8]) -> Result<Self, AeadError> {
if key.len() != Self::KEY_LEN {
return Err(AeadError::InvalidKeyLength);
}
crate::primitives::aead::warn_if_all_zero_key(key, $label);
let mut key_bytes = [0u8; $key_len];
key_bytes.copy_from_slice(key);
Ok($name { key_bytes })
}
fn generate_nonce() -> Nonce {
let mut nonce = [0u8; 12];
OsRng.fill_bytes(&mut nonce);
nonce
}
#[instrument(level = "debug", skip(self, nonce, plaintext, aad), fields(algorithm = $label, plaintext_len = plaintext.len(), has_aad = aad.is_some()))]
fn encrypt(
&self,
nonce: &Nonce,
plaintext: &[u8],
aad: Option<&[u8]>,
) -> Result<(Vec<u8>, Tag), AeadError> {
validate_encryption_size(plaintext.len()).map_err(
|_e: crate::primitives::resource_limits::ResourceError| {
AeadError::EncryptionFailed(
"plaintext exceeds resource limits".to_string(),
)
},
)?;
let unbound_key = UnboundKey::new($algorithm, &self.key_bytes)
.map_err(|_e| AeadError::InvalidKeyLength)?;
let key = LessSafeKey::new(unbound_key);
let aws_nonce = AwsNonce::try_assume_unique_for_key(nonce)
.map_err(|_e| AeadError::InvalidNonceLength)?;
let aad = Aad::from(aad.unwrap_or(&[]));
let mut in_out = Vec::with_capacity(plaintext.len().saturating_add(TAG_LEN));
in_out.extend_from_slice(plaintext);
key.seal_in_place_append_tag(aws_nonce, aad, &mut in_out)
.map_err(|_e| AeadError::EncryptionFailed("AEAD seal failed".to_string()))?;
if in_out.len() < TAG_LEN {
return Err(AeadError::EncryptionFailed("ciphertext too short".to_string()));
}
let ct_len = in_out.len().saturating_sub(TAG_LEN);
let ciphertext = in_out
.get(..ct_len)
.ok_or_else(|| AeadError::EncryptionFailed("invalid ciphertext length".to_string()))?
.to_vec();
let mut tag = [0u8; TAG_LEN];
let tag_slice = in_out
.get(ct_len..)
.ok_or_else(|| AeadError::EncryptionFailed("invalid tag offset".to_string()))?;
tag.copy_from_slice(tag_slice);
Ok((ciphertext, tag))
}
#[instrument(level = "debug", skip(self, nonce, ciphertext, tag, aad), fields(algorithm = $label, ciphertext_len = ciphertext.len(), has_aad = aad.is_some()))]
fn decrypt(
&self,
nonce: &Nonce,
ciphertext: &[u8],
tag: &Tag,
aad: Option<&[u8]>,
) -> Result<Zeroizing<Vec<u8>>, AeadError> {
validate_decryption_size(ciphertext.len()).map_err(
|_e: crate::primitives::resource_limits::ResourceError| {
AeadError::DecryptionFailed(
"ciphertext exceeds resource limits".to_string(),
)
},
)?;
let unbound_key = UnboundKey::new($algorithm, &self.key_bytes)
.map_err(|_e| AeadError::InvalidKeyLength)?;
let key = LessSafeKey::new(unbound_key);
let aws_nonce = AwsNonce::try_assume_unique_for_key(nonce)
.map_err(|_e| AeadError::InvalidNonceLength)?;
let aad = Aad::from(aad.unwrap_or(&[]));
let mut in_out: Zeroizing<Vec<u8>> =
Zeroizing::new(Vec::with_capacity(ciphertext.len().saturating_add(TAG_LEN)));
in_out.extend_from_slice(ciphertext);
in_out.extend_from_slice(tag);
let plaintext_len = key
.open_in_place(aws_nonce, aad, in_out.as_mut_slice())
.map_err(|_e| {
AeadError::DecryptionFailed(
"AEAD authentication failed".to_string(),
)
})?
.len();
let plaintext_slice = in_out.get(..plaintext_len).ok_or_else(|| {
AeadError::DecryptionFailed("plaintext length exceeds buffer".to_string())
})?;
Ok(Zeroizing::new(plaintext_slice.to_vec()))
}
}
impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct(stringify!($name))
.field("key_bytes", &"[REDACTED]")
.finish()
}
}
impl subtle::ConstantTimeEq for $name {
fn ct_eq(&self, other: &Self) -> subtle::Choice {
self.key_bytes.ct_eq(&other.key_bytes)
}
}
impl $name {
#[must_use]
pub fn generate_key() -> zeroize::Zeroizing<[u8; $key_len]> {
let mut key = zeroize::Zeroizing::new([0u8; $key_len]);
OsRng.fill_bytes(&mut *key);
key
}
}
};
}
impl_aes_gcm!(
AesGcm128, AES_GCM_128_KEY_LEN, &AES_128_GCM, "AES-GCM-128"
);
impl_aes_gcm!(
AesGcm256, AES_GCM_256_KEY_LEN, &AES_256_GCM, "AES-GCM-256"
);
#[cfg(test)]
#[allow(clippy::unwrap_used)] #[allow(clippy::panic)] mod tests {
use super::*;
#[test]
fn test_aes_gcm_128_key_generation_succeeds() {
let key1 = AesGcm128::generate_key();
let key2 = AesGcm128::generate_key();
assert_eq!(key1.len(), AES_GCM_128_KEY_LEN);
assert_eq!(key2.len(), AES_GCM_128_KEY_LEN);
assert_ne!(key1, key2);
}
#[test]
fn test_aes_gcm_128_invalid_key_length_fails() {
let key = [0u8; 8]; let result = AesGcm128::new(&key);
assert!(result.is_err());
if let Err(AeadError::InvalidKeyLength) = result {
} else {
panic!("Expected InvalidKeyLength error");
}
}
#[test]
fn test_aes_gcm_256_key_generation_succeeds() {
let key1 = AesGcm256::generate_key();
let key2 = AesGcm256::generate_key();
assert_eq!(key1.len(), AES_GCM_256_KEY_LEN);
assert_eq!(key2.len(), AES_GCM_256_KEY_LEN);
assert_ne!(key1, key2);
}
#[test]
fn test_aes_gcm_256_invalid_key_length_fails() {
let key = [0u8; 16]; let result = AesGcm256::new(&key);
assert!(result.is_err());
if let Err(AeadError::InvalidKeyLength) = result {
} else {
panic!("Expected InvalidKeyLength error");
}
}
#[test]
fn test_aes_gcm_128_encrypt_decrypt_roundtrip() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = b"Hello, World!";
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
}
#[test]
fn test_aes_gcm_256_encrypt_decrypt_roundtrip() {
let key = AesGcm256::generate_key();
let cipher = AesGcm256::new(&*key).unwrap();
let nonce = AesGcm256::generate_nonce();
let plaintext = b"Hello, World!";
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
}
#[test]
fn test_aes_gcm_256_seal_generates_fresh_nonce_per_call() {
let key = AesGcm256::generate_key();
let cipher = AesGcm256::new(&*key).unwrap();
let plaintext = b"seal must generate a fresh nonce per call";
let (nonce1, ct1, tag1) = cipher.seal(plaintext, None).unwrap();
let (nonce2, ct2, tag2) = cipher.seal(plaintext, None).unwrap();
assert_ne!(nonce1, nonce2, "seal() must generate a fresh nonce per call");
assert_ne!(ct1, ct2, "fresh nonces must produce distinct ciphertexts");
assert_ne!(tag1, tag2, "fresh nonces must produce distinct tags");
let pt1 = cipher.decrypt(&nonce1, &ct1, &tag1, None).unwrap();
let pt2 = cipher.decrypt(&nonce2, &ct2, &tag2, None).unwrap();
assert_eq!(plaintext, pt1.as_slice());
assert_eq!(plaintext, pt2.as_slice());
}
#[test]
fn test_aes_gcm_128_seal_roundtrip_with_aad() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let plaintext = b"seal with aad";
let aad = b"context";
let (nonce, ct, tag) = cipher.seal(plaintext, Some(aad)).unwrap();
let decrypted = cipher.decrypt(&nonce, &ct, &tag, Some(aad)).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
}
#[test]
fn test_aes_gcm_128_with_aad_roundtrip() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = b"Secret data";
let aad = b"Additional authenticated data";
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, Some(aad)).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, Some(aad)).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
}
#[test]
fn test_aes_gcm_128_with_aad_verification_failure_fails() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = b"Secret data";
let aad = b"Correct AAD";
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, Some(aad)).unwrap();
let wrong_aad = b"Wrong AAD";
let result = cipher.decrypt(&nonce, &ciphertext, &tag, Some(wrong_aad));
assert!(result.is_err());
if let Err(AeadError::DecryptionFailed(_)) = result {
} else {
panic!("Expected DecryptionFailed error");
}
}
#[test]
fn test_aes_gcm_128_invalid_tag_fails() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = b"Secret data";
let (ciphertext, mut tag) = cipher.encrypt(&nonce, plaintext, None).unwrap();
tag[0] ^= 0xFF;
let result = cipher.decrypt(&nonce, &ciphertext, &tag, None);
assert!(result.is_err());
if let Err(AeadError::DecryptionFailed(_)) = result {
} else {
panic!("Expected DecryptionFailed error");
}
}
#[test]
fn test_aes_gcm_constant_time_tag_verification_is_correct() {
let tag1 = [1u8; 16];
let tag2 = [1u8; 16];
let tag3 = [2u8; 16];
assert!(super::super::verify_tag_constant_time(&tag1, &tag2));
assert!(!super::super::verify_tag_constant_time(&tag1, &tag3));
}
#[test]
fn test_aes_gcm_zeroize_data_clears_bytes_succeeds() {
let mut data = vec![0xFF; 100];
super::super::zeroize_data(&mut data);
assert_eq!(data, vec![0u8; 100]);
}
#[test]
fn test_aes_gcm_128_empty_plaintext_roundtrip() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = b"";
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
assert_eq!(ciphertext.len(), 0);
assert_eq!(tag.len(), TAG_LEN);
}
#[test]
fn test_aes_gcm_256_empty_plaintext_roundtrip() {
let key = AesGcm256::generate_key();
let cipher = AesGcm256::new(&*key).unwrap();
let nonce = AesGcm256::generate_nonce();
let plaintext = b"";
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
assert_eq!(ciphertext.len(), 0);
assert_eq!(tag.len(), TAG_LEN);
}
#[test]
fn test_aes_gcm_128_large_plaintext_roundtrip() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = vec![0xAB; 1024 * 1024];
let (ciphertext, tag) = cipher.encrypt(&nonce, &plaintext, None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
assert_eq!(ciphertext.len(), 1024 * 1024);
assert_eq!(tag.len(), TAG_LEN);
}
#[test]
fn test_aes_gcm_256_large_plaintext_roundtrip() {
let key = AesGcm256::generate_key();
let cipher = AesGcm256::new(&*key).unwrap();
let nonce = AesGcm256::generate_nonce();
let plaintext = vec![0xAB; 1024 * 1024];
let (ciphertext, tag) = cipher.encrypt(&nonce, &plaintext, None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
assert_eq!(ciphertext.len(), 1024 * 1024);
assert_eq!(tag.len(), TAG_LEN);
}
#[test]
fn test_aes_gcm_128_corrupted_ciphertext_fails() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = b"Secret data";
let (mut ciphertext, tag) = cipher.encrypt(&nonce, plaintext, None).unwrap();
if let Some(last) = ciphertext.last_mut() {
*last ^= 0xFF;
}
let result = cipher.decrypt(&nonce, &ciphertext, &tag, None);
assert!(result.is_err());
if let Err(AeadError::DecryptionFailed(_)) = result {
} else {
panic!("Expected DecryptionFailed error");
}
}
#[test]
fn test_aes_gcm_256_with_aad_roundtrip() {
let key = AesGcm256::generate_key();
let cipher = AesGcm256::new(&*key).unwrap();
let nonce = AesGcm256::generate_nonce();
let plaintext = b"Secret data";
let aad = b"Additional authenticated data";
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, Some(aad)).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, Some(aad)).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
}
#[test]
fn test_aes_gcm_128_multiple_encryptions_all_roundtrip() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
for i in 0..100 {
let nonce = AesGcm128::generate_nonce();
let plaintext = format!("Message {}", i);
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext.as_bytes(), None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext.as_bytes(), decrypted.as_slice());
}
}
#[test]
fn test_aes_gcm_256_multiple_encryptions_all_roundtrip() {
let key = AesGcm256::generate_key();
let cipher = AesGcm256::new(&*key).unwrap();
for i in 0..100 {
let nonce = AesGcm256::generate_nonce();
let plaintext = format!("Message {}", i);
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext.as_bytes(), None).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, None).unwrap();
assert_eq!(plaintext.as_bytes(), decrypted.as_slice());
}
}
#[test]
fn test_aes_gcm_128_roundtrip_consistency_roundtrip() {
let key: [u8; 16] = [
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
0x0e, 0x0f,
];
let nonce: [u8; 12] =
[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let plaintext: &[u8] = b"Test message for AES-128-GCM";
let aad: &[u8] = b"Additional data";
let cipher = AesGcm128::new(&key).unwrap();
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, Some(aad)).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, Some(aad)).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
}
#[test]
fn test_aes_gcm_256_roundtrip_consistency_roundtrip() {
let key: [u8; 32] = [
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b,
0x1c, 0x1d, 0x1e, 0x1f,
];
let nonce: [u8; 12] =
[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let plaintext: &[u8] = b"Test message for AES-256-GCM";
let aad: &[u8] = b"Additional data";
let cipher = AesGcm256::new(&key).unwrap();
let (ciphertext, tag) = cipher.encrypt(&nonce, plaintext, Some(aad)).unwrap();
let decrypted = cipher.decrypt(&nonce, &ciphertext, &tag, Some(aad)).unwrap();
assert_eq!(plaintext, decrypted.as_slice());
}
#[test]
fn test_aes_gcm_128_encryption_size_limit_has_correct_size() {
let key = AesGcm128::generate_key();
let cipher = AesGcm128::new(&*key).unwrap();
let nonce = AesGcm128::generate_nonce();
let plaintext = vec![0xAB; 101 * 1024 * 1024];
let result = cipher.encrypt(&nonce, &plaintext, None);
assert!(result.is_err(), "Should fail with resource limit exceeded");
if let Err(AeadError::EncryptionFailed(msg)) = result {
assert!(msg.contains("resource limit"), "Error should mention resource limit: {}", msg);
} else {
panic!("Expected EncryptionFailed error");
}
}
#[test]
fn test_aes_gcm_256_decryption_size_limit_has_correct_size() {
let key = AesGcm256::generate_key();
let cipher = AesGcm256::new(&*key).unwrap();
let nonce = AesGcm256::generate_nonce();
let ciphertext = vec![0xCD; 101 * 1024 * 1024];
let tag = [0u8; TAG_LEN];
let result = cipher.decrypt(&nonce, &ciphertext, &tag, None);
assert!(result.is_err(), "Should fail with resource limit exceeded");
if let Err(AeadError::DecryptionFailed(msg)) = result {
assert!(
msg.contains("resource limits") || msg.contains("AEAD"),
"Error should be a DecryptionFailed variant: {}",
msg
);
} else {
panic!("Expected DecryptionFailed error");
}
}
}