1use crate::error::{Error, Result};
7
8use aes_gcm::aead::rand_core::RngCore;
9use aes_gcm::{
10 Aes256Gcm, Key, Nonce,
11 aead::{Aead, KeyInit, OsRng},
12};
13
/// Symmetric content encryption keyed by a single 256-bit key.
///
/// The cipher used is AES-256-GCM (see the inherent methods). `Debug` is
/// implemented by hand rather than derived so that the raw key bytes are never
/// printed in logs, panic messages, or `{:?}` output.
pub struct ContentEncryption {
    /// Raw AES-256 key bytes.
    key: [u8; 32],
}

impl std::fmt::Debug for ContentEncryption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately redact the key material instead of printing `self.key`.
        f.debug_struct("ContentEncryption")
            .field("key", &"<redacted>")
            .finish()
    }
}
18
19impl ContentEncryption {
20 pub fn new(key: [u8; 32]) -> Self {
22 Self { key }
23 }
24
25 pub fn from_hex(hex_key: &str) -> Result<Self> {
27 let bytes =
28 hex::decode(hex_key).map_err(|e| Error::Validation(format!("invalid hex key: {e}")))?;
29 if bytes.len() != 32 {
30 return Err(Error::Validation(format!(
31 "key must be 32 bytes, got {}",
32 bytes.len()
33 )));
34 }
35 let mut key = [0u8; 32];
36 key.copy_from_slice(&bytes);
37 Ok(Self { key })
38 }
39
40 pub fn from_env() -> Result<Self> {
42 let hex_key = std::env::var("MNEMO_ENCRYPTION_KEY")
43 .map_err(|_| Error::Validation("MNEMO_ENCRYPTION_KEY not set".to_string()))?;
44 Self::from_hex(&hex_key)
45 }
46
47 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
51 let key = Key::<Aes256Gcm>::from_slice(&self.key);
52 let cipher = Aes256Gcm::new(key);
53
54 let mut nonce_bytes = [0u8; 12];
55 OsRng.fill_bytes(&mut nonce_bytes);
56 let nonce = Nonce::from_slice(&nonce_bytes);
57
58 let ciphertext = cipher
59 .encrypt(nonce, plaintext)
60 .map_err(|e| Error::Internal(format!("encryption failed: {e}")))?;
61
62 let mut output = Vec::with_capacity(12 + ciphertext.len());
63 output.extend_from_slice(&nonce_bytes);
64 output.extend_from_slice(&ciphertext);
65 Ok(output)
66 }
67
68 pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
70 if data.len() < 28 {
71 return Err(Error::Validation("encrypted data too short".to_string()));
73 }
74
75 let key = Key::<Aes256Gcm>::from_slice(&self.key);
76 let cipher = Aes256Gcm::new(key);
77
78 let nonce = Nonce::from_slice(&data[..12]);
79 let ciphertext = &data[12..];
80
81 cipher
82 .decrypt(nonce, ciphertext)
83 .map_err(|_| Error::Validation("decryption tag mismatch".to_string()))
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_round_trip() {
        let key = [0x42u8; 32];
        let enc = ContentEncryption::new(key);

        let plaintext = b"Hello, encrypted world!";
        let encrypted = enc.encrypt(plaintext).unwrap();

        // The ciphertext body (between nonce and tag) must not equal the plaintext.
        assert_ne!(&encrypted[12..encrypted.len() - 16], plaintext);

        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encryption_from_hex() {
        let hex_key = "42".repeat(32);
        let enc = ContentEncryption::from_hex(&hex_key).unwrap();

        let plaintext = b"test";
        let encrypted = enc.encrypt(plaintext).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_invalid_hex_key_length() {
        // Valid hex, but only 2 bytes.
        assert!(ContentEncryption::from_hex("abcd").is_err());
        // Valid hex, 31 bytes: one short of an AES-256 key.
        assert!(ContentEncryption::from_hex(&"42".repeat(31)).is_err());
        // Valid hex, 33 bytes: one too many.
        assert!(ContentEncryption::from_hex(&"42".repeat(33)).is_err());
    }

    #[test]
    fn test_non_hex_key_rejected() {
        // Right length in characters, but not hex.
        assert!(ContentEncryption::from_hex(&"zz".repeat(32)).is_err());
    }

    #[test]
    fn test_output_layout() {
        let enc = ContentEncryption::new([0x11u8; 32]);
        let plaintext = b"layout";
        let encrypted = enc.encrypt(plaintext).unwrap();
        // nonce (12) || ciphertext (same length as plaintext) || tag (16)
        assert_eq!(encrypted.len(), 12 + plaintext.len() + 16);
    }

    #[test]
    fn test_fresh_nonce_per_encryption() {
        let enc = ContentEncryption::new([0x22u8; 32]);
        let first = enc.encrypt(b"same").unwrap();
        let second = enc.encrypt(b"same").unwrap();
        // A random 96-bit nonce colliding here would indicate broken RNG use.
        assert_ne!(&first[..12], &second[..12]);
        assert_ne!(first, second);
    }

    #[test]
    fn test_wrong_key_fails() {
        let encrypted = ContentEncryption::new([0x01u8; 32])
            .encrypt(b"data")
            .unwrap();
        assert!(ContentEncryption::new([0x02u8; 32])
            .decrypt(&encrypted)
            .is_err());
    }

    #[test]
    fn test_short_input_rejected() {
        let enc = ContentEncryption::new([0x33u8; 32]);
        // Shorter than nonce + tag: rejected by the length check.
        assert!(enc.decrypt(&[0u8; 0]).is_err());
        assert!(enc.decrypt(&[0u8; 27]).is_err());
        // Exactly nonce + tag (empty ciphertext) with a bogus tag: rejected by AES-GCM.
        assert!(enc.decrypt(&[0u8; 28]).is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = [0x42u8; 32];
        let enc = ContentEncryption::new(key);

        let encrypted = enc.encrypt(b"secret data").unwrap();
        let mut tampered = encrypted.clone();
        // Byte 15 lies inside the ciphertext body (bytes 12..23 for 11 bytes of plaintext).
        tampered[15] ^= 0xff;
        let result = enc.decrypt(&tampered);
        assert!(result.is_err());
    }

    #[test]
    fn test_aes_gcm_round_trip() {
        let key = [0xABu8; 32];
        let enc = ContentEncryption::new(key);

        // Includes the empty plaintext and sizes on and off the 16-byte block boundary.
        for size in [0, 1, 16, 100, 1024, 65536] {
            let plaintext: Vec<u8> = (0..size).map(|i| (i % 256) as u8).collect();
            let encrypted = enc.encrypt(&plaintext).unwrap();
            let decrypted = enc.decrypt(&encrypted).unwrap();
            assert_eq!(decrypted, plaintext, "round-trip failed for size {size}");
        }
    }

    #[test]
    fn test_aes_gcm_tamper_detection() {
        let key = [0xCDu8; 32];
        let enc = ContentEncryption::new(key);
        let encrypted = enc.encrypt(b"sensitive data").unwrap();

        // Flip a nonce bit.
        let mut tampered = encrypted.clone();
        tampered[0] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());

        // Flip a ciphertext-body bit (body is bytes 12..26 for 14 bytes of plaintext).
        let mut tampered = encrypted.clone();
        tampered[14] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());

        // Flip a bit in the trailing authentication tag.
        let mut tampered = encrypted.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(enc.decrypt(&tampered).is_err());
    }
}