1use crate::error::{EventError, Result};
7use aes_gcm::aead::{Aead, KeyInit, OsRng};
8use aes_gcm::{Aes256Gcm, AeadCore, Nonce};
9use base64::engine::general_purpose::STANDARD as BASE64;
10use base64::Engine;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::RwLock;
14
/// JSON envelope that carries an encrypted event payload.
///
/// Serialized with camelCase field names, i.e.
/// `{"keyId": "...", "nonce": "...", "ciphertext": "...", "encrypted": true}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncryptedPayload {
    /// Identifier of the key the payload was encrypted with; on decryption it
    /// selects which registered key to use.
    pub key_id: String,

    /// Nonce used for this payload, as text (`Aes256GcmEncryptor` writes it
    /// base64-encoded with the standard alphabet).
    pub nonce: String,

    /// Encrypted bytes, as text (`Aes256GcmEncryptor` writes them
    /// base64-encoded with the standard alphabet).
    pub ciphertext: String,

    /// Marker distinguishing encrypted envelopes from plain payloads.
    /// When the field is absent during deserialization it defaults to `true`.
    #[serde(default = "default_encrypted")]
    pub encrypted: bool,
}
32
/// Serde default for `EncryptedPayload::encrypted`: an envelope that omits the
/// `encrypted` field is deserialized as encrypted.
fn default_encrypted() -> bool {
    true
}
36
37impl EncryptedPayload {
38 pub fn is_encrypted(value: &serde_json::Value) -> bool {
40 value
41 .get("encrypted")
42 .and_then(|v| v.as_bool())
43 .unwrap_or(false)
44 }
45}
46
/// Encrypts and decrypts event payloads represented as JSON values.
///
/// Implementors must be `Send + Sync`.
pub trait EventEncryptor: Send + Sync {
    /// Encrypts `payload` and returns its encrypted representation as JSON.
    fn encrypt(&self, payload: &serde_json::Value) -> Result<serde_json::Value>;

    /// Recovers the original payload from a value produced by `encrypt`.
    fn decrypt(&self, encrypted: &serde_json::Value) -> Result<serde_json::Value>;

    /// Identifier of the key currently used for new encryptions.
    fn active_key_id(&self) -> &str;
}
58
/// `EventEncryptor` backed by AES-256-GCM, supporting multiple registered keys.
///
/// New payloads are encrypted with the active key. Decryption looks the key up
/// by the envelope's `keyId`, so payloads encrypted before a rotation remain
/// decryptable as long as their key stays registered.
pub struct Aes256GcmEncryptor {
    /// Id of the key used for encryption; `new` registers it and `rotate_to`
    /// only accepts ids already present in `keys`.
    active_key_id: String,

    /// Registered ciphers by key id. Wrapped in a `RwLock` so keys can be
    /// added through `&self`.
    keys: RwLock<HashMap<String, Aes256Gcm>>,
}
70
71impl Aes256GcmEncryptor {
72 pub fn new(key_id: impl Into<String>, key: &[u8; 32]) -> Self {
76 let key_id = key_id.into();
77 let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key");
78 let mut keys = HashMap::new();
79 keys.insert(key_id.clone(), cipher);
80
81 Self {
82 active_key_id: key_id,
83 keys: RwLock::new(keys),
84 }
85 }
86
87 pub fn add_key(&self, key_id: impl Into<String>, key: &[u8; 32]) -> Result<()> {
91 let cipher = Aes256Gcm::new_from_slice(key).expect("32-byte key");
92 let mut keys = self.keys.write().map_err(|e| {
93 EventError::Config(format!("Failed to acquire key lock: {}", e))
94 })?;
95 keys.insert(key_id.into(), cipher);
96 Ok(())
97 }
98
99 pub fn rotate_to(&mut self, key_id: &str) -> Result<()> {
103 let keys = self.keys.read().map_err(|e| {
104 EventError::Config(format!("Failed to acquire key lock: {}", e))
105 })?;
106 if !keys.contains_key(key_id) {
107 return Err(EventError::Config(format!(
108 "Key '{}' not registered, add it first",
109 key_id
110 )));
111 }
112 self.active_key_id = key_id.to_string();
113 Ok(())
114 }
115
116 pub fn key_ids(&self) -> Vec<String> {
118 self.keys
119 .read()
120 .map(|keys| keys.keys().cloned().collect())
121 .unwrap_or_default()
122 }
123}
124
125impl EventEncryptor for Aes256GcmEncryptor {
126 fn encrypt(&self, payload: &serde_json::Value) -> Result<serde_json::Value> {
127 let plaintext = serde_json::to_vec(payload)?;
128
129 let keys = self.keys.read().map_err(|e| {
130 EventError::Config(format!("Failed to acquire key lock: {}", e))
131 })?;
132 let cipher = keys.get(&self.active_key_id).ok_or_else(|| {
133 EventError::Config(format!("Active key '{}' not found", self.active_key_id))
134 })?;
135
136 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
137 let ciphertext = cipher.encrypt(&nonce, plaintext.as_ref()).map_err(|e| {
138 EventError::Config(format!("Encryption failed: {}", e))
139 })?;
140
141 let envelope = EncryptedPayload {
142 key_id: self.active_key_id.clone(),
143 nonce: BASE64.encode(nonce),
144 ciphertext: BASE64.encode(ciphertext),
145 encrypted: true,
146 };
147
148 serde_json::to_value(envelope).map_err(Into::into)
149 }
150
151 fn decrypt(&self, encrypted: &serde_json::Value) -> Result<serde_json::Value> {
152 let envelope: EncryptedPayload = serde_json::from_value(encrypted.clone())?;
153
154 let keys = self.keys.read().map_err(|e| {
155 EventError::Config(format!("Failed to acquire key lock: {}", e))
156 })?;
157 let cipher = keys.get(&envelope.key_id).ok_or_else(|| {
158 EventError::Config(format!(
159 "Decryption key '{}' not registered",
160 envelope.key_id
161 ))
162 })?;
163
164 let nonce_bytes = BASE64.decode(&envelope.nonce).map_err(|e| {
165 EventError::Config(format!("Invalid nonce encoding: {}", e))
166 })?;
167 let nonce = Nonce::from_slice(&nonce_bytes);
168
169 let ciphertext = BASE64.decode(&envelope.ciphertext).map_err(|e| {
170 EventError::Config(format!("Invalid ciphertext encoding: {}", e))
171 })?;
172
173 let plaintext = cipher.decrypt(nonce, ciphertext.as_ref()).map_err(|e| {
174 EventError::Config(format!("Decryption failed: {}", e))
175 })?;
176
177 serde_json::from_slice(&plaintext).map_err(Into::into)
178 }
179
180 fn active_key_id(&self) -> &str {
181 &self.active_key_id
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed, distinct 256-bit keys so test results are deterministic.
    fn test_key() -> [u8; 32] {
        [0x42; 32]
    }

    fn test_key_2() -> [u8; 32] {
        [0x7A; 32]
    }

    // Encrypt followed by decrypt on the same encryptor returns the original
    // payload, and the encrypted form carries the `encrypted` marker.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let enc = Aes256GcmEncryptor::new("key-1", &test_key());
        let payload = serde_json::json!({"rate": 7.35, "currency": "USD/CNY"});

        let encrypted = enc.encrypt(&payload).unwrap();
        assert!(EncryptedPayload::is_encrypted(&encrypted));

        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, payload);
    }

    // The envelope uses camelCase field names and string-encoded nonce/ciphertext.
    #[test]
    fn test_encrypted_payload_marker() {
        let enc = Aes256GcmEncryptor::new("key-1", &test_key());
        let encrypted = enc.encrypt(&serde_json::json!({"data": 1})).unwrap();

        assert_eq!(encrypted["encrypted"], true);
        assert!(encrypted["keyId"].is_string());
        assert!(encrypted["nonce"].is_string());
        assert!(encrypted["ciphertext"].is_string());
    }

    // A payload without the marker is not treated as encrypted.
    #[test]
    fn test_is_encrypted_false_for_plain() {
        let plain = serde_json::json!({"rate": 7.35});
        assert!(!EncryptedPayload::is_encrypted(&plain));
    }

    #[test]
    fn test_key_rotation() {
        let mut enc = Aes256GcmEncryptor::new("key-1", &test_key());

        let payload = serde_json::json!({"secret": "data"});
        // Encrypt under the initial key before rotating.
        let encrypted_v1 = enc.encrypt(&payload).unwrap();

        enc.add_key("key-2", &test_key_2()).unwrap();
        // rotate_to only accepts keys that have already been registered.
        enc.rotate_to("key-2").unwrap();
        assert_eq!(enc.active_key_id(), "key-2");

        let encrypted_v2 = enc.encrypt(&payload).unwrap();

        // The pre-rotation envelope still decrypts because key-1 stays registered.
        assert_eq!(enc.decrypt(&encrypted_v1).unwrap(), payload);
        assert_eq!(enc.decrypt(&encrypted_v2).unwrap(), payload);

        // Each envelope records the key it was encrypted with.
        assert_eq!(encrypted_v1["keyId"], "key-1");
        assert_eq!(encrypted_v2["keyId"], "key-2");
    }

    #[test]
    fn test_rotate_to_unknown_key_fails() {
        let mut enc = Aes256GcmEncryptor::new("key-1", &test_key());
        let result = enc.rotate_to("nonexistent");
        assert!(result.is_err());
    }

    // enc2 has no key registered under "key-1", so key lookup fails.
    #[test]
    fn test_decrypt_with_missing_key_fails() {
        let enc1 = Aes256GcmEncryptor::new("key-1", &test_key());
        let enc2 = Aes256GcmEncryptor::new("key-2", &test_key_2());

        let encrypted = enc1.encrypt(&serde_json::json!({"data": 1})).unwrap();
        let result = enc2.decrypt(&encrypted);
        assert!(result.is_err());
    }

    // enc2 has a *different* key registered under the same id "key-1", so the
    // lookup succeeds but decryption itself must fail.
    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let enc1 = Aes256GcmEncryptor::new("key-1", &test_key());
        let enc2 = Aes256GcmEncryptor::new("key-2", &test_key_2());
        enc2.add_key("key-1", &[0xFF; 32]).unwrap();

        let encrypted = enc1.encrypt(&serde_json::json!({"data": 1})).unwrap();
        let result = enc2.decrypt(&encrypted);
        assert!(result.is_err());
    }

    // key_ids order is unspecified, so sort before comparing.
    #[test]
    fn test_key_ids() {
        let enc = Aes256GcmEncryptor::new("key-1", &test_key());
        enc.add_key("key-2", &test_key_2()).unwrap();

        let mut ids = enc.key_ids();
        ids.sort();
        assert_eq!(ids, vec!["key-1", "key-2"]);
    }

    // Nested objects and arrays survive the roundtrip unchanged.
    #[test]
    fn test_encrypt_complex_payload() {
        let enc = Aes256GcmEncryptor::new("key-1", &test_key());
        let payload = serde_json::json!({
            "user": "[email]",
            "action": "login",
            "nested": {"deep": [1, 2, 3]},
            "tags": ["pii", "audit"]
        });

        let encrypted = enc.encrypt(&payload).unwrap();
        let decrypted = enc.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, payload);
    }

    #[test]
    fn test_each_encryption_unique_nonce() {
        let enc = Aes256GcmEncryptor::new("key-1", &test_key());
        let payload = serde_json::json!({"data": "same"});

        let e1 = enc.encrypt(&payload).unwrap();
        let e2 = enc.encrypt(&payload).unwrap();

        assert_ne!(e1["nonce"], e2["nonce"]);
        // With different nonces, identical plaintexts give different ciphertexts.
        assert_ne!(e1["ciphertext"], e2["ciphertext"]);
    }
}