use super::{Algorithm, EncryptionResult, Key};
use crate::{Error, Result};
use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use alloc::vec::Vec;
use rand::RngCore;
/// Byte length of the AES-GCM nonce (96 bits) that `encrypt_aes_gcm` generates
/// and that `decrypt_aes_gcm` requires in `EncryptionResult::nonce`.
pub const NONCE_SIZE: usize = 96 / 8;
/// Byte length of the GCM authentication tag (128 bits) that `encrypt_aes_gcm`
/// detaches from the ciphertext and that `decrypt_aes_gcm` requires in
/// `EncryptionResult::tag`.
pub const TAG_SIZE: usize = 128 / 8;
/// Encrypts `plaintext` with AES-256-GCM under `key`, also authenticating `aad`.
///
/// A fresh `NONCE_SIZE`-byte nonce is drawn from `rand::thread_rng()` on every
/// call. The returned [`EncryptionResult`] stores the ciphertext, the nonce and
/// the `TAG_SIZE`-byte authentication tag as separate fields.
///
/// `aad` is authenticated but not encrypted. Pass an empty slice when there is
/// none. Exactly the same bytes must later be given to [`decrypt_aes_gcm`].
///
/// Because nonces are random, NIST SP 800-38D caps the number of encryptions
/// under a single key at 2^32. Rotate keys well before that.
///
/// # Errors
///
/// Returns [`Error::EncryptionFailed`] if `key` cannot initialise the cipher
/// (e.g. its bytes are not a valid AES-256 key length) or if the AEAD
/// encryption itself reports an error.
pub fn encrypt_aes_gcm(plaintext: &[u8], key: &Key, aad: &[u8]) -> Result<EncryptionResult> {
    use aes_gcm::aead::Payload;

    let mut nonce_bytes = [0u8; NONCE_SIZE];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);
    let nonce = Nonce::from_slice(&nonce_bytes);

    let cipher = Aes256Gcm::new_from_slice(key.as_bytes())
        .map_err(|e| Error::EncryptionFailed(e.to_string()))?;

    // `encrypt(nonce, msg)` is shorthand for a `Payload` with empty AAD, so one
    // call handles both the with-AAD and without-AAD cases identically.
    let mut ciphertext = cipher
        .encrypt(nonce, Payload { msg: plaintext, aad })
        .map_err(|e| Error::EncryptionFailed(e.to_string()))?;

    // On success the aead API returns `ciphertext || tag`, so the buffer always
    // holds at least TAG_SIZE bytes. Detach the tag in place rather than copying
    // the (possibly large) ciphertext into a second Vec.
    let tag = ciphertext.split_off(ciphertext.len() - TAG_SIZE);

    Ok(EncryptionResult {
        ciphertext,
        algorithm: Algorithm::Aes256Gcm,
        nonce: nonce_bytes.to_vec(),
        tag,
    })
}
pub fn decrypt_aes_gcm(encrypted: &EncryptionResult, key: &Key, aad: &[u8]) -> Result<Vec<u8>> {
if encrypted.algorithm != Algorithm::Aes256Gcm {
return Err(Error::DecryptionFailed(format!(
"Algorithm mismatch: expected {:?}, got {:?}",
Algorithm::Aes256Gcm,
encrypted.algorithm
)));
}
if encrypted.nonce.len() != NONCE_SIZE {
return Err(Error::InvalidNonceLength {
expected: NONCE_SIZE,
actual: encrypted.nonce.len(),
});
}
if encrypted.tag.len() != TAG_SIZE {
return Err(Error::DecryptionFailed(format!(
"Invalid tag length: expected {}, got {}",
TAG_SIZE,
encrypted.tag.len()
)));
}
let nonce = Nonce::from_slice(&encrypted.nonce);
let cipher = Aes256Gcm::new_from_slice(key.as_bytes())
.map_err(|e| Error::DecryptionFailed(e.to_string()))?;
let mut ciphertext_with_tag = encrypted.ciphertext.clone();
ciphertext_with_tag.extend_from_slice(&encrypted.tag);
let plaintext = if aad.is_empty() {
cipher.decrypt(nonce, ciphertext_with_tag.as_ref())
} else {
use aes_gcm::aead::Payload;
cipher.decrypt(
nonce,
Payload {
msg: ciphertext_with_tag.as_ref(),
aad,
},
)
}
.map_err(|_| Error::AuthenticationFailed)?;
Ok(plaintext)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encryption::generate_key;

    #[test]
    fn test_aes_gcm_roundtrip() {
        let key = generate_key();
        let plaintext = b"Hello, AES-GCM!";
        let encrypted = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        assert_eq!(encrypted.algorithm, Algorithm::Aes256Gcm);
        assert_eq!(encrypted.nonce.len(), NONCE_SIZE);
        assert_eq!(encrypted.tag.len(), TAG_SIZE);
        // GCM is a stream mode: ciphertext length equals plaintext length.
        assert_eq!(encrypted.ciphertext.len(), plaintext.len());
        let decrypted = decrypt_aes_gcm(&encrypted, &key, &[]).unwrap();
        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn test_aes_gcm_empty_plaintext_roundtrip() {
        let key = generate_key();
        let plaintext: &[u8] = b"";
        let encrypted = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        assert!(encrypted.ciphertext.is_empty());
        assert_eq!(encrypted.tag.len(), TAG_SIZE);
        let decrypted = decrypt_aes_gcm(&encrypted, &key, &[]).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_aes_gcm_fresh_nonce_per_call() {
        let key = generate_key();
        let plaintext = b"same message";
        let first = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        let second = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        // 96-bit random nonces: a collision here would indicate a broken RNG.
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn test_aes_gcm_with_aad() {
        let key = generate_key();
        let plaintext = b"Secret message";
        let aad = b"additional authenticated data";
        let encrypted = encrypt_aes_gcm(plaintext, &key, aad).unwrap();
        let decrypted = decrypt_aes_gcm(&encrypted, &key, aad).unwrap();
        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn test_aes_gcm_wrong_aad() {
        let key = generate_key();
        let plaintext = b"Secret message";
        let encrypted = encrypt_aes_gcm(plaintext, &key, b"header-v1").unwrap();

        let result = decrypt_aes_gcm(&encrypted, &key, b"header-v2");
        assert!(matches!(result, Err(Error::AuthenticationFailed)));

        // Omitting AAD that was used at encryption time must also fail.
        let result = decrypt_aes_gcm(&encrypted, &key, &[]);
        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_gcm_unexpected_aad() {
        let key = generate_key();
        let plaintext = b"Secret message";
        let encrypted = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        let result = decrypt_aes_gcm(&encrypted, &key, b"not there at encryption");
        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_gcm_wrong_key() {
        let key = generate_key();
        let wrong_key = generate_key();
        let plaintext = b"Secret message";
        let encrypted = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        let result = decrypt_aes_gcm(&encrypted, &wrong_key, &[]);
        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_gcm_tampered_ciphertext() {
        let key = generate_key();
        let plaintext = b"Secret message";
        let mut encrypted = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        // Precondition: a non-empty plaintext yields a non-empty ciphertext.
        assert!(!encrypted.ciphertext.is_empty());
        encrypted.ciphertext[0] ^= 0xFF;
        let result = decrypt_aes_gcm(&encrypted, &key, &[]);
        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_gcm_tampered_tag() {
        let key = generate_key();
        let plaintext = b"Secret message";
        let mut encrypted = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        encrypted.tag[0] ^= 0xFF;
        let result = decrypt_aes_gcm(&encrypted, &key, &[]);
        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_gcm_tampered_nonce() {
        let key = generate_key();
        let plaintext = b"Secret message";
        let mut encrypted = encrypt_aes_gcm(plaintext, &key, &[]).unwrap();
        encrypted.nonce[0] ^= 0xFF;
        let result = decrypt_aes_gcm(&encrypted, &key, &[]);
        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[test]
    fn test_aes_gcm_invalid_nonce_length() {
        let key = generate_key();
        let mut encrypted = encrypt_aes_gcm(b"Secret message", &key, &[]).unwrap();
        encrypted.nonce.truncate(8);
        let result = decrypt_aes_gcm(&encrypted, &key, &[]);
        assert!(matches!(
            result,
            Err(Error::InvalidNonceLength {
                expected: NONCE_SIZE,
                actual: 8
            })
        ));
    }

    #[test]
    fn test_aes_gcm_invalid_tag_length() {
        let key = generate_key();
        let mut encrypted = encrypt_aes_gcm(b"Secret message", &key, &[]).unwrap();
        encrypted.tag.pop();
        let result = decrypt_aes_gcm(&encrypted, &key, &[]);
        assert!(matches!(result, Err(Error::DecryptionFailed(_))));
    }
}