use crate::error::{Error, Result};
use aes_gcm::aead::rand_core::RngCore;
use aes_gcm::{
Aes256Gcm, Key, Nonce,
aead::{Aead, KeyInit, OsRng},
};
pub struct ContentEncryption {
key: [u8; 32],
}
impl ContentEncryption {
pub fn new(key: [u8; 32]) -> Self {
Self { key }
}
pub fn from_hex(hex_key: &str) -> Result<Self> {
let bytes =
hex::decode(hex_key).map_err(|e| Error::Validation(format!("invalid hex key: {e}")))?;
if bytes.len() != 32 {
return Err(Error::Validation(format!(
"key must be 32 bytes, got {}",
bytes.len()
)));
}
let mut key = [0u8; 32];
key.copy_from_slice(&bytes);
Ok(Self { key })
}
pub fn from_env() -> Result<Self> {
let hex_key = std::env::var("MNEMO_ENCRYPTION_KEY")
.map_err(|_| Error::Validation("MNEMO_ENCRYPTION_KEY not set".to_string()))?;
Self::from_hex(&hex_key)
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
let key = Key::<Aes256Gcm>::from_slice(&self.key);
let cipher = Aes256Gcm::new(key);
let mut nonce_bytes = [0u8; 12];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, plaintext)
.map_err(|e| Error::Internal(format!("encryption failed: {e}")))?;
let mut output = Vec::with_capacity(12 + ciphertext.len());
output.extend_from_slice(&nonce_bytes);
output.extend_from_slice(&ciphertext);
Ok(output)
}
pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
if data.len() < 28 {
return Err(Error::Validation("encrypted data too short".to_string()));
}
let key = Key::<Aes256Gcm>::from_slice(&self.key);
let cipher = Aes256Gcm::new(key);
let nonce = Nonce::from_slice(&data[..12]);
let ciphertext = &data[12..];
cipher
.decrypt(nonce, ciphertext)
.map_err(|_| Error::Validation("decryption tag mismatch".to_string()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Size of the nonce prefix in the wire format.
    const NONCE: usize = 12;
    /// Size of the authentication tag suffix in the wire format.
    const TAG: usize = 16;

    #[test]
    fn test_encryption_round_trip() {
        let key = [0x42u8; 32];
        let enc = ContentEncryption::new(key);
        let plaintext = b"Hello, encrypted world!";
        let encrypted = enc.encrypt(plaintext).unwrap();
        // The ciphertext body (between nonce and tag) must not leak the plaintext.
        assert_ne!(&encrypted[NONCE..encrypted.len() - TAG], plaintext);
        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encryption_from_hex() {
        let hex_key = "42".repeat(32);
        let enc = ContentEncryption::from_hex(&hex_key).unwrap();
        let plaintext = b"test";
        let encrypted = enc.encrypt(plaintext).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_from_hex_interoperates_with_new() {
        let from_hex = ContentEncryption::from_hex(&"42".repeat(32)).unwrap();
        let from_raw = ContentEncryption::new([0x42u8; 32]);
        let encrypted = from_hex.encrypt(b"shared key").unwrap();
        assert_eq!(from_raw.decrypt(&encrypted).unwrap(), b"shared key");
    }

    #[test]
    fn test_invalid_hex_key_length() {
        let result = ContentEncryption::from_hex("abcd");
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_hex_characters() {
        assert!(ContentEncryption::from_hex(&"zz".repeat(32)).is_err());
    }

    #[test]
    fn test_ciphertext_layout() {
        let enc = ContentEncryption::new([0x11u8; 32]);
        for size in [0usize, 1, 37] {
            let plaintext = vec![0x5Au8; size];
            let encrypted = enc.encrypt(&plaintext).unwrap();
            assert_eq!(encrypted.len(), NONCE + size + TAG, "size {size}");
        }
    }

    #[test]
    fn test_nonce_is_fresh_per_encryption() {
        let enc = ContentEncryption::new([0x22u8; 32]);
        let first = enc.encrypt(b"same message").unwrap();
        let second = enc.encrypt(b"same message").unwrap();
        // Random nonces must make repeated encryptions of identical input distinct.
        assert_ne!(first[..NONCE], second[..NONCE]);
        assert_ne!(first, second);
        assert_eq!(enc.decrypt(&first).unwrap(), b"same message");
        assert_eq!(enc.decrypt(&second).unwrap(), b"same message");
    }

    #[test]
    fn test_decrypt_rejects_short_input() {
        let enc = ContentEncryption::new([0x33u8; 32]);
        assert!(enc.decrypt(&[]).is_err());
        assert!(enc.decrypt(&[0u8; NONCE + TAG - 1]).is_err());
        // Long enough to parse, but the all-zero tag cannot authenticate.
        assert!(enc.decrypt(&[0u8; NONCE + TAG]).is_err());
    }

    #[test]
    fn test_wrong_key_fails() {
        let encrypted = ContentEncryption::new([0x44u8; 32])
            .encrypt(b"for key A only")
            .unwrap();
        let other = ContentEncryption::new([0x45u8; 32]);
        assert!(other.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = [0x42u8; 32];
        let enc = ContentEncryption::new(key);
        let encrypted = enc.encrypt(b"secret data").unwrap();
        let mut tampered = encrypted.clone();
        tampered[15] ^= 0xff;
        let result = enc.decrypt(&tampered);
        assert!(result.is_err());
    }

    #[test]
    fn test_aes_gcm_round_trip() {
        let key = [0xABu8; 32];
        let enc = ContentEncryption::new(key);
        for size in [0, 1, 16, 100, 1024, 65536] {
            let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
            let encrypted = enc.encrypt(&plaintext).unwrap();
            let decrypted = enc.decrypt(&encrypted).unwrap();
            assert_eq!(decrypted, plaintext, "round-trip failed for size {size}");
        }
    }

    #[test]
    fn test_aes_gcm_tamper_detection() {
        let key = [0xCDu8; 32];
        let enc = ContentEncryption::new(key);
        let encrypted = enc.encrypt(b"sensitive data").unwrap();
        // Flip a bit in the nonce.
        let mut tampered = encrypted.clone();
        tampered[0] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());
        // Flip a bit in the ciphertext body.
        let mut tampered = encrypted.clone();
        tampered[14] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());
        // Flip a bit in the authentication tag.
        let mut tampered = encrypted.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());
    }
}