use aes_gcm::{
aead::{Aead, KeyInit, Payload},
Aes256Gcm, Nonce,
};
use rand::RngCore;
use zeroize::Zeroizing;
use crate::{
key::{MasterKey, KEY_SIZE},
Result, VaultError,
};
/// Length in bytes of the nonce used for every AES-256-GCM operation in this module.
pub const NONCE_SIZE: usize = 12;
/// Symmetric cipher that encrypts and decrypts byte slices with AES-256-GCM.
///
/// The key is stored inside a [`Zeroizing`] wrapper, which wipes the bytes
/// when the `Cipher` is dropped.
pub struct Cipher {
/// Key bytes handed to `Aes256Gcm` on every encrypt/decrypt call.
encryption_key: Zeroizing<[u8; KEY_SIZE]>,
}
impl Cipher {
pub fn new(master_key: &MasterKey) -> Self {
Self {
encryption_key: Zeroizing::new(master_key.encryption_key()),
}
}
pub fn from_raw_key(key: [u8; KEY_SIZE]) -> Self {
Self {
encryption_key: Zeroizing::new(key),
}
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<(Vec<u8>, [u8; NONCE_SIZE])> {
let cipher = Aes256Gcm::new_from_slice(&*self.encryption_key)
.map_err(|e| VaultError::CryptoError(format!("Invalid key: {e}")))?;
let mut nonce_bytes = [0u8; NONCE_SIZE];
rand::rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext)
.map_err(|e| VaultError::CryptoError(format!("Encryption failed: {e}")))?;
Ok((ciphertext, nonce_bytes))
}
pub fn decrypt(&self, ciphertext: &[u8], nonce_bytes: &[u8]) -> Result<Vec<u8>> {
if nonce_bytes.len() != NONCE_SIZE {
return Err(VaultError::CryptoError(format!(
"Invalid nonce size: expected {NONCE_SIZE}, got {}",
nonce_bytes.len()
)));
}
let cipher = Aes256Gcm::new_from_slice(&*self.encryption_key)
.map_err(|e| VaultError::CryptoError(format!("Invalid key: {e}")))?;
let nonce = Nonce::from_slice(nonce_bytes);
cipher
.decrypt(nonce, ciphertext)
.map_err(|e| VaultError::CryptoError(format!("Decryption failed: {e}")))
}
pub fn encrypt_with_aad(
&self,
plaintext: &[u8],
aad: &[u8],
) -> Result<(Vec<u8>, [u8; NONCE_SIZE])> {
let cipher = Aes256Gcm::new_from_slice(&*self.encryption_key)
.map_err(|e| VaultError::CryptoError(format!("Invalid key: {e}")))?;
let mut nonce_bytes = [0u8; NONCE_SIZE];
rand::rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let payload = Payload {
msg: plaintext,
aad,
};
let raw_ct = cipher
.encrypt(nonce, payload)
.map_err(|e| VaultError::CryptoError(format!("Encryption failed: {e}")))?;
let mut tagged = Vec::with_capacity(1 + raw_ct.len());
tagged.push(0x02);
tagged.extend_from_slice(&raw_ct);
Ok((tagged, nonce_bytes))
}
pub fn decrypt_with_aad(
&self,
ciphertext: &[u8],
nonce_bytes: &[u8],
aad: &[u8],
) -> Result<Vec<u8>> {
if nonce_bytes.len() != NONCE_SIZE {
return Err(VaultError::CryptoError(format!(
"Invalid nonce size: expected {NONCE_SIZE}, got {}",
nonce_bytes.len()
)));
}
if ciphertext.first() == Some(&0x02) && ciphertext.len() > 1 {
let cipher = Aes256Gcm::new_from_slice(&*self.encryption_key)
.map_err(|e| VaultError::CryptoError(format!("Invalid key: {e}")))?;
let nonce = Nonce::from_slice(nonce_bytes);
let payload = Payload {
msg: &ciphertext[1..],
aad,
};
match cipher.decrypt(nonce, payload) {
Ok(plaintext) => Ok(plaintext),
Err(_) => self.decrypt(ciphertext, nonce_bytes),
}
} else {
self.decrypt(ciphertext, nonce_bytes)
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`Cipher`]: plain round-trips, failure modes, and the
    //! AAD-tagged format including its fallback to untagged ciphertexts.

    use super::*;
    use crate::key::KEY_SIZE;

    /// Deterministic master key built from all-zero bytes.
    fn test_key() -> MasterKey {
        MasterKey::from_bytes([0u8; KEY_SIZE])
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let cipher = Cipher::new(&test_key());
        let plaintext = b"hello, world!";
        let (ciphertext, nonce) = cipher.encrypt(plaintext).unwrap();
        assert_ne!(&ciphertext[..], plaintext);
        let decrypted = cipher.decrypt(&ciphertext, &nonce).unwrap();
        assert_eq!(&decrypted[..], plaintext);
    }

    #[test]
    fn test_decrypt_wrong_nonce_fails() {
        let cipher = Cipher::new(&test_key());
        let plaintext = b"secret data";
        let (ciphertext, _nonce) = cipher.encrypt(plaintext).unwrap();
        let wrong_nonce = [0u8; NONCE_SIZE];
        let result = cipher.decrypt(&ciphertext, &wrong_nonce);
        assert!(result.is_err());
    }

    #[test]
    fn test_each_encryption_unique_nonce() {
        let cipher = Cipher::new(&test_key());
        let plaintext = b"same text";
        let (_, nonce1) = cipher.encrypt(plaintext).unwrap();
        let (_, nonce2) = cipher.encrypt(plaintext).unwrap();
        assert_ne!(nonce1, nonce2);
    }

    #[test]
    fn test_empty_plaintext() {
        let cipher = Cipher::new(&test_key());
        let plaintext = b"";
        let (ciphertext, nonce) = cipher.encrypt(plaintext).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, &nonce).unwrap();
        assert_eq!(&decrypted[..], plaintext);
    }

    #[test]
    fn test_large_plaintext() {
        let cipher = Cipher::new(&test_key());
        let plaintext = vec![0xabu8; 1024 * 1024];
        let (ciphertext, nonce) = cipher.encrypt(&plaintext).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, &nonce).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_invalid_nonce_size() {
        let cipher = Cipher::new(&test_key());
        let (ciphertext, _) = cipher.encrypt(b"data").unwrap();
        let short_nonce = [0u8; 8];
        let result = cipher.decrypt(&ciphertext, &short_nonce);
        assert!(matches!(result, Err(VaultError::CryptoError(_))));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let cipher = Cipher::new(&test_key());
        let (mut ciphertext, nonce) = cipher.encrypt(b"secret").unwrap();
        if let Some(byte) = ciphertext.first_mut() {
            *byte ^= 0xff;
        }
        let result = cipher.decrypt(&ciphertext, &nonce);
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypt_decrypt_with_aad_roundtrip() {
        let cipher = Cipher::new(&test_key());
        let plaintext = b"hello with aad";
        let aad = b"context-binding-data";
        let (ciphertext, nonce) = cipher.encrypt_with_aad(plaintext, aad).unwrap();
        let decrypted = cipher.decrypt_with_aad(&ciphertext, &nonce, aad).unwrap();
        assert_eq!(&decrypted[..], plaintext);
    }

    #[test]
    fn test_wrong_aad_fails() {
        let cipher = Cipher::new(&test_key());
        let (ciphertext, nonce) = cipher.encrypt_with_aad(b"secret", b"correct-aad").unwrap();
        let result = cipher.decrypt_with_aad(&ciphertext, &nonce, b"wrong-aad");
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_aad_works() {
        let cipher = Cipher::new(&test_key());
        let (ciphertext, nonce) = cipher.encrypt_with_aad(b"data", b"").unwrap();
        let decrypted = cipher.decrypt_with_aad(&ciphertext, &nonce, b"").unwrap();
        assert_eq!(&decrypted[..], b"data");
    }

    #[test]
    fn test_aad_legacy_fallback() {
        let cipher = Cipher::new(&test_key());
        let (ciphertext, nonce) = cipher.encrypt(b"old data").unwrap();
        let decrypted = cipher
            .decrypt_with_aad(&ciphertext, &nonce, b"any-aad")
            .unwrap();
        assert_eq!(&decrypted[..], b"old data");
    }

    #[test]
    fn test_aad_legacy_fallback_when_first_byte_is_0x02() {
        // Each attempt uses a fresh random nonce, so the first ciphertext byte
        // is 0x02 roughly once in 256 tries. The cap keeps a misbehaving build
        // from hanging the test run; missing the byte in this many attempts is
        // astronomically unlikely.
        const MAX_ATTEMPTS: usize = 10_000;

        let cipher = Cipher::new(&test_key());
        let (ciphertext, nonce) = (0..MAX_ATTEMPTS)
            .find_map(|_| {
                let (ciphertext, nonce) = cipher.encrypt(b"edge case").unwrap();
                (ciphertext.first() == Some(&0x02)).then_some((ciphertext, nonce))
            })
            .expect("no untagged ciphertext starting with 0x02 within MAX_ATTEMPTS");

        let decrypted = cipher
            .decrypt_with_aad(&ciphertext, &nonce, b"some-aad")
            .unwrap();
        assert_eq!(&decrypted[..], b"edge case");
    }

    #[test]
    fn test_aad_version_tag_present() {
        let cipher = Cipher::new(&test_key());
        let (ciphertext, _) = cipher.encrypt_with_aad(b"tagged", b"aad").unwrap();
        assert_eq!(ciphertext[0], 0x02, "AAD ciphertext must start with 0x02");
    }

    #[test]
    fn test_aad_tampered_ciphertext_fails() {
        let cipher = Cipher::new(&test_key());
        let (mut ciphertext, nonce) = cipher.encrypt_with_aad(b"secret", b"aad").unwrap();
        if ciphertext.len() > 2 {
            ciphertext[2] ^= 0xff;
        }
        let result = cipher.decrypt_with_aad(&ciphertext, &nonce, b"aad");
        assert!(result.is_err());
    }

    #[test]
    fn test_uses_hkdf_derived_key_not_raw() {
        let master = test_key();
        assert_ne!(
            master.encryption_key(),
            *master.as_bytes(),
            "Cipher should use HKDF-derived key, not raw master key"
        );

        // Exercise `Cipher` itself: data sealed via `Cipher::new` must open
        // under a cipher built directly from the derived key.
        let (ciphertext, nonce) = Cipher::new(&master).encrypt(b"derived").unwrap();
        let decrypted = Cipher::from_raw_key(master.encryption_key())
            .decrypt(&ciphertext, &nonce)
            .unwrap();
        assert_eq!(&decrypted[..], b"derived");
    }
}