1use crate::error::{AllSourceError, Result};
10use aes_gcm::{
11 aead::{Aead, KeyInit, OsRng},
12 Aes256Gcm, Nonce,
13};
14use base64::{Engine as _, engine::general_purpose};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18use parking_lot::RwLock;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EncryptionConfig {
23 pub enabled: bool,
25
26 pub key_rotation_days: u32,
28
29 pub algorithm: EncryptionAlgorithm,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34pub enum EncryptionAlgorithm {
35 Aes256Gcm,
36 ChaCha20Poly1305,
37}
38
39impl Default for EncryptionConfig {
40 fn default() -> Self {
41 Self {
42 enabled: true,
43 key_rotation_days: 90,
44 algorithm: EncryptionAlgorithm::Aes256Gcm,
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct EncryptedData {
52 pub ciphertext: String,
54
55 pub nonce: String,
57
58 pub key_id: String,
60
61 pub algorithm: EncryptionAlgorithm,
63
64 pub version: u32,
66}
67
68#[derive(Debug, Clone)]
70struct DataEncryptionKey {
71 key_id: String,
72 key_bytes: Vec<u8>,
73 version: u32,
74 created_at: chrono::DateTime<chrono::Utc>,
75 active: bool,
76}
77
78pub struct FieldEncryption {
80 config: Arc<RwLock<EncryptionConfig>>,
81
82 deks: Arc<RwLock<HashMap<String, DataEncryptionKey>>>,
84
85 active_key_id: Arc<RwLock<Option<String>>>,
87}
88
89impl FieldEncryption {
90 pub fn new(config: EncryptionConfig) -> Result<Self> {
92 let manager = Self {
93 config: Arc::new(RwLock::new(config)),
94 deks: Arc::new(RwLock::new(HashMap::new())),
95 active_key_id: Arc::new(RwLock::new(None)),
96 };
97
98 manager.rotate_keys()?;
100
101 Ok(manager)
102 }
103
104 pub fn encrypt_string(&self, plaintext: &str, field_name: &str) -> Result<EncryptedData> {
106 if !self.config.read().enabled {
107 return Err(AllSourceError::ValidationError(
108 "Encryption is disabled".to_string(),
109 ));
110 }
111
112 let active_key_id = self.active_key_id.read();
113 let key_id = active_key_id.as_ref()
114 .ok_or_else(|| AllSourceError::ValidationError("No active encryption key".to_string()))?
115 .clone();
116
117 let deks = self.deks.read();
118 let dek = deks.get(&key_id)
119 .ok_or_else(|| AllSourceError::ValidationError("Encryption key not found".to_string()))?;
120
121 let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
123 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
124
125 let nonce_bytes = aes_gcm::aead::rand_core::RngCore::next_u64(&mut OsRng).to_le_bytes();
127 let mut nonce_array = [0u8; 12];
128 nonce_array[..8].copy_from_slice(&nonce_bytes);
129 let nonce = Nonce::from_slice(&nonce_array);
130
131 let ciphertext = cipher.encrypt(nonce, plaintext.as_bytes())
133 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {}", e)))?;
134
135 Ok(EncryptedData {
136 ciphertext: general_purpose::STANDARD.encode(&ciphertext),
137 nonce: general_purpose::STANDARD.encode(nonce.as_slice()),
138 key_id: key_id.clone(),
139 algorithm: self.config.read().algorithm.clone(),
140 version: dek.version,
141 })
142 }
143
144 pub fn decrypt_string(&self, encrypted: &EncryptedData) -> Result<String> {
146 if !self.config.read().enabled {
147 return Err(AllSourceError::ValidationError(
148 "Encryption is disabled".to_string(),
149 ));
150 }
151
152 let deks = self.deks.read();
153 let dek = deks.get(&encrypted.key_id)
154 .ok_or_else(|| AllSourceError::ValidationError(
155 format!("Encryption key {} not found", encrypted.key_id)
156 ))?;
157
158 let ciphertext = general_purpose::STANDARD.decode(&encrypted.ciphertext)
160 .map_err(|e| AllSourceError::ValidationError(format!("Invalid ciphertext encoding: {}", e)))?;
161
162 let nonce_bytes = general_purpose::STANDARD.decode(&encrypted.nonce)
163 .map_err(|e| AllSourceError::ValidationError(format!("Invalid nonce encoding: {}", e)))?;
164
165 let nonce = Nonce::from_slice(&nonce_bytes);
166
167 let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
169 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
170
171 let plaintext_bytes = cipher.decrypt(nonce, ciphertext.as_ref())
172 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {}", e)))?;
173
174 String::from_utf8(plaintext_bytes)
175 .map_err(|e| AllSourceError::ValidationError(format!("Invalid UTF-8: {}", e)))
176 }
177
178 pub fn rotate_keys(&self) -> Result<()> {
180 let mut deks = self.deks.write();
181 let mut active_key_id = self.active_key_id.write();
182
183 let key_id = uuid::Uuid::new_v4().to_string();
185 let mut key_bytes = vec![0u8; 32]; aes_gcm::aead::rand_core::RngCore::fill_bytes(&mut OsRng, &mut key_bytes);
187
188 let version = deks.len() as u32 + 1;
189
190 let new_key = DataEncryptionKey {
191 key_id: key_id.clone(),
192 key_bytes,
193 version,
194 created_at: chrono::Utc::now(),
195 active: true,
196 };
197
198 for key in deks.values_mut() {
200 key.active = false;
201 }
202
203 deks.insert(key_id.clone(), new_key);
205 *active_key_id = Some(key_id);
206
207 Ok(())
208 }
209
210 pub fn get_stats(&self) -> EncryptionStats {
212 let deks = self.deks.read();
213 let active_key_id = self.active_key_id.read();
214
215 EncryptionStats {
216 enabled: self.config.read().enabled,
217 total_keys: deks.len(),
218 active_key_version: deks.get(active_key_id.as_ref().unwrap_or(&String::new()))
219 .map(|k| k.version)
220 .unwrap_or(0),
221 algorithm: self.config.read().algorithm.clone(),
222 }
223 }
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct EncryptionStats {
228 pub enabled: bool,
229 pub total_keys: usize,
230 pub active_key_version: u32,
231 pub algorithm: EncryptionAlgorithm,
232}
233
234pub trait Encryptable {
236 fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData>;
237 fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self>
238 where
239 Self: Sized;
240}
241
242impl Encryptable for String {
243 fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData> {
244 encryption.encrypt_string(self, field_name)
245 }
246
247 fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self> {
248 encryption.decrypt_string(encrypted)
249 }
250}
251
252pub fn encrypt_json_value(
254 value: &serde_json::Value,
255 encryption: &FieldEncryption,
256 field_name: &str,
257) -> Result<EncryptedData> {
258 let json_string = serde_json::to_string(value)
259 .map_err(|e| AllSourceError::ValidationError(format!("JSON serialization failed: {}", e)))?;
260
261 encryption.encrypt_string(&json_string, field_name)
262}
263
264pub fn decrypt_json_value(
266 encrypted: &EncryptedData,
267 encryption: &FieldEncryption,
268) -> Result<serde_json::Value> {
269 let json_string = encryption.decrypt_string(encrypted)?;
270
271 serde_json::from_str(&json_string)
272 .map_err(|e| AllSourceError::ValidationError(format!("JSON deserialization failed: {}", e)))
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with the default (enabled, AES-256-GCM) configuration.
    fn default_encryption() -> FieldEncryption {
        FieldEncryption::new(EncryptionConfig::default()).unwrap()
    }

    #[test]
    fn test_encryption_creation() {
        let encryption = default_encryption();
        let stats = encryption.get_stats();

        assert!(stats.enabled);
        assert_eq!(stats.total_keys, 1);
        assert_eq!(stats.active_key_version, 1);
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let encryption = default_encryption();
        let plaintext = "sensitive data";

        let encrypted = encryption.encrypt_string(plaintext, "test_field").unwrap();
        assert_ne!(encrypted.ciphertext, plaintext);

        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_json() {
        let encryption = default_encryption();
        let value = serde_json::json!({
            "username": "john_doe",
            "ssn": "123-45-6789",
            "credit_card": "4111-1111-1111-1111"
        });

        let encrypted = encrypt_json_value(&value, &encryption, "sensitive_data").unwrap();
        let decrypted = decrypt_json_value(&encrypted, &encryption).unwrap();

        assert_eq!(decrypted, value);
    }

    #[test]
    fn test_key_rotation() {
        let encryption = default_encryption();
        let plaintext = "sensitive data";

        let encrypted1 = encryption.encrypt_string(plaintext, "test").unwrap();
        let key_id1 = encrypted1.key_id.clone();

        encryption.rotate_keys().unwrap();

        let encrypted2 = encryption.encrypt_string(plaintext, "test").unwrap();
        let key_id2 = encrypted2.key_id.clone();

        assert_ne!(key_id1, key_id2);
        assert_eq!(encrypted2.version, 2);

        // Data encrypted under the retired key must still decrypt.
        let decrypted1 = encryption.decrypt_string(&encrypted1).unwrap();
        assert_eq!(decrypted1, plaintext);

        let decrypted2 = encryption.decrypt_string(&encrypted2).unwrap();
        assert_eq!(decrypted2, plaintext);
    }

    #[test]
    fn test_multiple_key_rotations() {
        let encryption = default_encryption();
        let plaintext = "test data";

        let mut encrypted_data = Vec::new();
        for _ in 0..5 {
            encrypted_data.push(encryption.encrypt_string(plaintext, "test").unwrap());
            encryption.rotate_keys().unwrap();
        }

        for encrypted in &encrypted_data {
            let decrypted = encryption.decrypt_string(encrypted).unwrap();
            assert_eq!(decrypted, plaintext);
        }

        let stats = encryption.get_stats();
        assert_eq!(stats.total_keys, 6);
        assert_eq!(stats.active_key_version, 6);
    }

    #[test]
    fn test_disabled_encryption() {
        let config = EncryptionConfig {
            enabled: false,
            ..EncryptionConfig::default()
        };

        let encryption = FieldEncryption::new(config).unwrap();
        let result = encryption.encrypt_string("test", "test");
        assert!(result.is_err());
    }

    #[test]
    fn test_disabled_decryption() {
        let config = EncryptionConfig {
            enabled: false,
            ..EncryptionConfig::default()
        };
        let encryption = FieldEncryption::new(config).unwrap();

        // The payload contents are irrelevant: the enabled check runs first.
        let encrypted = EncryptedData {
            ciphertext: String::new(),
            nonce: String::new(),
            key_id: String::new(),
            algorithm: EncryptionAlgorithm::Aes256Gcm,
            version: 1,
        };
        assert!(encryption.decrypt_string(&encrypted).is_err());
    }

    #[test]
    fn test_decrypt_rejects_unknown_key_id() {
        let encryption = default_encryption();
        let mut encrypted = encryption.encrypt_string("secret", "test").unwrap();
        encrypted.key_id = "no-such-key".to_string();

        assert!(encryption.decrypt_string(&encrypted).is_err());
    }

    #[test]
    fn test_decrypt_rejects_key_from_other_instance() {
        let owner = default_encryption();
        let stranger = default_encryption();

        let encrypted = owner.encrypt_string("secret", "test").unwrap();
        assert!(stranger.decrypt_string(&encrypted).is_err());
    }

    #[test]
    fn test_decrypt_rejects_mismatched_nonce() {
        let encryption = default_encryption();
        let first = encryption.encrypt_string("first", "test").unwrap();
        let second = encryption.encrypt_string("second", "test").unwrap();

        // Pair `first`'s ciphertext with `second`'s nonce: GCM authentication must fail.
        let mut forged = first.clone();
        forged.nonce = second.nonce.clone();
        assert!(encryption.decrypt_string(&forged).is_err());
    }

    #[test]
    fn test_repeated_encryption_uses_fresh_nonces() {
        let encryption = default_encryption();
        let a = encryption.encrypt_string("same", "test").unwrap();
        let b = encryption.encrypt_string("same", "test").unwrap();

        assert_ne!(a.nonce, b.nonce);
        assert_ne!(a.ciphertext, b.ciphertext);
        assert_eq!(encryption.decrypt_string(&a).unwrap(), "same");
        assert_eq!(encryption.decrypt_string(&b).unwrap(), "same");
    }
}