use alloc::vec::Vec;
use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use super::{AES_GCM_NONCE_LEN, AES_GCM_TAG_LEN, KEY_LEN};
use crate::error::{Error, Result};
/// Seals `plaintext` with AES-256-GCM under `key`, authenticating `aad` alongside it.
///
/// Every call draws a fresh `AES_GCM_NONCE_LEN`-byte nonce from
/// `mod_rand::tier3::fill_bytes` and returns a self-contained wire buffer laid out as:
///
/// ```text
/// nonce (AES_GCM_NONCE_LEN) || ciphertext (plaintext.len()) || tag (AES_GCM_TAG_LEN)
/// ```
///
/// `aad` is authenticated but not copied into the output. The caller must supply the same
/// bytes again to [`decrypt`].
///
/// # Errors
///
/// - `Error::InvalidKey` if `key` is not exactly `KEY_LEN` bytes long.
/// - `Error::RandomFailure` if the nonce could not be generated.
/// - `Error::AuthenticationFailed` if the underlying AEAD refuses to encrypt.
///   NOTE(review): per the aes-gcm docs this only happens for inputs beyond the crate's
///   length limits, so this label is misleading. Consider a dedicated error variant.
pub(super) fn encrypt(key: &[u8], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
    check_key_len(key)?;

    // GCM must never reuse a nonce under the same key, so each message gets fresh random bytes.
    // NOTE(review): with random 96-bit nonces, NIST SP 800-38D limits one key to 2^32
    // encryptions. Confirm that callers rotate keys well before that.
    let mut nonce = [0u8; AES_GCM_NONCE_LEN];
    mod_rand::tier3::fill_bytes(&mut nonce)
        .map_err(|_| Error::RandomFailure("mod_rand::tier3::fill_bytes"))?;

    // Upstream returns the ciphertext with the authentication tag appended.
    let sealed = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key))
        .encrypt(Nonce::from_slice(&nonce), Payload { msg: plaintext, aad })
        .map_err(|_| Error::AuthenticationFailed)?;

    let mut wire = Vec::with_capacity(nonce.len() + sealed.len());
    wire.extend_from_slice(&nonce);
    wire.extend_from_slice(&sealed);
    Ok(wire)
}
/// Opens a buffer produced by [`encrypt`] and returns the recovered plaintext.
///
/// `wire` must be laid out as `nonce || ciphertext || tag`. The first `AES_GCM_NONCE_LEN`
/// bytes are used as the nonce. The remainder (ciphertext with the tag appended) is verified
/// and decrypted by AES-256-GCM together with `aad`.
///
/// # Errors
///
/// - `Error::InvalidKey` if `key` is not exactly `KEY_LEN` bytes long.
/// - `Error::InvalidCiphertext` if `wire` is shorter than a nonce plus a tag.
/// - `Error::AuthenticationFailed` if the tag does not verify, for example because of a
///   wrong key, mismatched `aad`, or modified wire bytes.
pub(super) fn decrypt(key: &[u8], wire: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
    check_key_len(key)?;

    // The shortest legal buffer is a nonce followed by the tag of an empty plaintext.
    let min_len = AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN;
    if wire.len() < min_len {
        return Err(Error::InvalidCiphertext(alloc::format!(
            "buffer too short ({} bytes, need at least {})",
            wire.len(),
            min_len
        )));
    }

    let (nonce, sealed) = wire.split_at(AES_GCM_NONCE_LEN);
    let payload = Payload { msg: sealed, aad };
    Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key))
        .decrypt(Nonce::from_slice(nonce), payload)
        .map_err(|_| Error::AuthenticationFailed)
}
#[inline]
fn check_key_len(key: &[u8]) -> Result<()> {
if key.len() == KEY_LEN {
Ok(())
} else {
Err(Error::InvalidKey {
expected: KEY_LEN,
actual: key.len(),
})
}
}
/// Known-answer, round-trip, and tamper-detection tests for the AES-256-GCM wrapper.
#[cfg(test)]
// Tests report failure by panicking, so unwrap/expect are the intended tool here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Checks one NIST GCM known-answer vector (empty AAD) against the upstream primitive.
    ///
    /// It then frames the same vector as `nonce || ciphertext || tag` and checks that this
    /// module's `decrypt` accepts it.
    fn assert_known_answer(
        key: &[u8],
        nonce: &[u8],
        plaintext: &[u8],
        expected_ct: &[u8],
        expected_tag: &[u8],
    ) {
        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key));
        let n = Nonce::from_slice(nonce);
        let got = cipher
            .encrypt(
                n,
                Payload {
                    msg: plaintext,
                    aad: &[],
                },
            )
            .unwrap();
        // Upstream output is the ciphertext with the tag appended. Check the total length
        // first so a mismatch fails with a clear message instead of panicking in `split_at`.
        assert_eq!(got.len(), expected_ct.len() + AES_GCM_TAG_LEN);
        let (got_ct, got_tag) = got.split_at(expected_ct.len());
        assert_eq!(got_ct, expected_ct);
        assert_eq!(got_tag, expected_tag);

        let recovered = cipher
            .decrypt(
                n,
                Payload {
                    msg: &got,
                    aad: &[],
                },
            )
            .unwrap();
        assert_eq!(recovered, plaintext);

        // The random nonce in `encrypt` rules out fixed-vector tests of that function, but
        // `decrypt` is deterministic, so pin its wire framing to the standard vector.
        let mut wire = Vec::with_capacity(nonce.len() + got.len());
        wire.extend_from_slice(nonce);
        wire.extend_from_slice(&got);
        assert_eq!(decrypt(key, &wire, &[]).unwrap(), plaintext);
    }

    #[test]
    fn nist_test_case_14_known_answer() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let plaintext = [0u8; 16];
        let expected_ct = hex_to_bytes("cea7403d4d606b6e074ec5d3baf39d18");
        let expected_tag = hex_to_bytes("d0d1c8a799996bf0265b98b5d48ab919");
        assert_known_answer(&key, &nonce, &plaintext, &expected_ct, &expected_tag);
    }

    #[test]
    fn nist_test_case_15_known_answer() {
        let key = hex_to_bytes("feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308");
        let nonce = hex_to_bytes("cafebabefacedbaddecaf888");
        let plaintext = hex_to_bytes(
            "d9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72\
             1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255",
        );
        let expected_ct = hex_to_bytes(
            "522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa\
             8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662898015ad",
        );
        let expected_tag = hex_to_bytes("b094dac5d93471bdec1a502270e3cc6c");
        assert_known_answer(&key, &nonce, &plaintext, &expected_ct, &expected_tag);
    }

    #[test]
    fn round_trip_via_module_wrapper() {
        let key = [0xb2u8; 32];
        let pt = b"the wrapper layers nonce-prepend on top of the upstream primitive";
        let wire = encrypt(&key, pt, &[]).unwrap();
        assert_eq!(wire.len(), AES_GCM_NONCE_LEN + pt.len() + AES_GCM_TAG_LEN);
        let recovered = decrypt(&key, &wire, &[]).unwrap();
        assert_eq!(recovered, pt);
    }

    #[test]
    fn empty_plaintext_round_trips_at_minimum_length() {
        let key = [0x33u8; 32];
        let wire = encrypt(&key, &[], b"only aad").unwrap();
        assert_eq!(wire.len(), AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN);
        assert!(decrypt(&key, &wire, b"only aad").unwrap().is_empty());
    }

    #[test]
    fn decrypt_rejects_any_single_byte_modification() {
        let key = [0x5au8; 32];
        let aad = b"associated header";
        let wire = encrypt(&key, b"authenticated payload", aad).unwrap();
        // Covers the nonce prefix, the ciphertext body, and the trailing tag.
        for i in 0..wire.len() {
            let mut tampered = wire.clone();
            tampered[i] ^= 0x01;
            assert!(
                matches!(decrypt(&key, &tampered, aad), Err(Error::AuthenticationFailed)),
                "modification at byte {} was not detected",
                i
            );
        }
    }

    #[test]
    fn decrypt_rejects_wrong_aad() {
        let key = [0x5au8; 32];
        let wire = encrypt(&key, b"payload", b"aad-one").unwrap();
        assert!(matches!(
            decrypt(&key, &wire, b"aad-two"),
            Err(Error::AuthenticationFailed)
        ));
        assert!(matches!(
            decrypt(&key, &wire, &[]),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn decrypt_rejects_wrong_key() {
        let wire = encrypt(&[0x01u8; 32], b"payload", &[]).unwrap();
        assert!(matches!(
            decrypt(&[0x02u8; 32], &wire, &[]),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn decrypt_rejects_buffers_shorter_than_nonce_plus_tag() {
        let key = [0u8; 32];
        let min_len = AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN;
        for &len in &[0, 1, AES_GCM_NONCE_LEN, min_len - 1] {
            let wire = alloc::vec![0u8; len];
            assert!(
                matches!(decrypt(&key, &wire, &[]), Err(Error::InvalidCiphertext(_))),
                "length {} was not rejected as too short",
                len
            );
        }
        // A buffer of exactly the minimum length passes the length check. Its garbage tag
        // then fails authentication.
        let wire = alloc::vec![0u8; min_len];
        assert!(matches!(
            decrypt(&key, &wire, &[]),
            Err(Error::AuthenticationFailed)
        ));
    }

    #[test]
    fn encrypt_and_decrypt_reject_wrong_key_length() {
        let short_key = [0u8; 16];
        assert!(matches!(
            encrypt(&short_key, b"x", &[]),
            Err(Error::InvalidKey { expected, actual }) if expected == KEY_LEN && actual == 16
        ));
        let wire = alloc::vec![0u8; AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN];
        assert!(matches!(
            decrypt(&short_key, &wire, &[]),
            Err(Error::InvalidKey { expected, actual }) if expected == KEY_LEN && actual == 16
        ));
    }

    #[test]
    fn encrypt_uses_fresh_nonce_per_call() {
        let key = [0x44u8; 32];
        let a = encrypt(&key, b"same message", &[]).unwrap();
        let b = encrypt(&key, b"same message", &[]).unwrap();
        // Two random 96-bit nonces collide with negligible probability. A repeat here
        // points at a broken RNG, which would be catastrophic for GCM.
        assert_ne!(&a[..AES_GCM_NONCE_LEN], &b[..AES_GCM_NONCE_LEN]);
    }

    #[test]
    fn check_key_len_accepts_exactly_32() {
        assert!(check_key_len(&[0u8; 32]).is_ok());
    }

    #[test]
    fn check_key_len_rejects_off_by_one() {
        assert!(check_key_len(&[0u8; 31]).is_err());
        assert!(check_key_len(&[0u8; 33]).is_err());
    }

    /// Decodes a hex literal used in a test vector. Panics on malformed input.
    fn hex_to_bytes(s: &str) -> alloc::vec::Vec<u8> {
        hex::decode(s).expect("valid hex")
    }
}