1use crate::error::{AllSourceError, Result};
9use aes_gcm::{
10 aead::{Aead, KeyInit, OsRng},
11 Aes256Gcm, Nonce,
12};
13use base64::{engine::general_purpose, Engine as _};
14use parking_lot::RwLock;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct EncryptionConfig {
22 pub enabled: bool,
24
25 pub key_rotation_days: u32,
27
28 pub algorithm: EncryptionAlgorithm,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
33pub enum EncryptionAlgorithm {
34 Aes256Gcm,
35 ChaCha20Poly1305,
36}
37
38impl Default for EncryptionConfig {
39 fn default() -> Self {
40 Self {
41 enabled: true,
42 key_rotation_days: 90,
43 algorithm: EncryptionAlgorithm::Aes256Gcm,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct EncryptedData {
51 pub ciphertext: String,
53
54 pub nonce: String,
56
57 pub key_id: String,
59
60 pub algorithm: EncryptionAlgorithm,
62
63 pub version: u32,
65}
66
67#[derive(Debug, Clone)]
69struct DataEncryptionKey {
70 key_id: String,
71 key_bytes: Vec<u8>,
72 version: u32,
73 created_at: chrono::DateTime<chrono::Utc>,
74 active: bool,
75}
76
77pub struct FieldEncryption {
79 config: Arc<RwLock<EncryptionConfig>>,
80
81 deks: Arc<RwLock<HashMap<String, DataEncryptionKey>>>,
83
84 active_key_id: Arc<RwLock<Option<String>>>,
86}
87
88impl FieldEncryption {
89 pub fn new(config: EncryptionConfig) -> Result<Self> {
91 let manager = Self {
92 config: Arc::new(RwLock::new(config)),
93 deks: Arc::new(RwLock::new(HashMap::new())),
94 active_key_id: Arc::new(RwLock::new(None)),
95 };
96
97 manager.rotate_keys()?;
99
100 Ok(manager)
101 }
102
103 pub fn encrypt_string(&self, plaintext: &str, field_name: &str) -> Result<EncryptedData> {
105 if !self.config.read().enabled {
106 return Err(AllSourceError::ValidationError(
107 "Encryption is disabled".to_string(),
108 ));
109 }
110
111 let active_key_id = self.active_key_id.read();
112 let key_id = active_key_id
113 .as_ref()
114 .ok_or_else(|| AllSourceError::ValidationError("No active encryption key".to_string()))?
115 .clone();
116
117 let deks = self.deks.read();
118 let dek = deks.get(&key_id).ok_or_else(|| {
119 AllSourceError::ValidationError("Encryption key not found".to_string())
120 })?;
121
122 let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
124 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
125
126 let nonce_bytes = aes_gcm::aead::rand_core::RngCore::next_u64(&mut OsRng).to_le_bytes();
128 let mut nonce_array = [0u8; 12];
129 nonce_array[..8].copy_from_slice(&nonce_bytes);
130 let nonce = Nonce::from_slice(&nonce_array);
131
132 let ciphertext = cipher
134 .encrypt(nonce, plaintext.as_bytes())
135 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
136
137 Ok(EncryptedData {
138 ciphertext: general_purpose::STANDARD.encode(&ciphertext),
139 nonce: general_purpose::STANDARD.encode(nonce.as_slice()),
140 key_id: key_id.clone(),
141 algorithm: self.config.read().algorithm.clone(),
142 version: dek.version,
143 })
144 }
145
146 pub fn decrypt_string(&self, encrypted: &EncryptedData) -> Result<String> {
148 if !self.config.read().enabled {
149 return Err(AllSourceError::ValidationError(
150 "Encryption is disabled".to_string(),
151 ));
152 }
153
154 let deks = self.deks.read();
155 let dek = deks.get(&encrypted.key_id).ok_or_else(|| {
156 AllSourceError::ValidationError(format!(
157 "Encryption key {} not found",
158 encrypted.key_id
159 ))
160 })?;
161
162 let ciphertext = general_purpose::STANDARD
164 .decode(&encrypted.ciphertext)
165 .map_err(|e| {
166 AllSourceError::ValidationError(format!("Invalid ciphertext encoding: {e}"))
167 })?;
168
169 let nonce_bytes = general_purpose::STANDARD
170 .decode(&encrypted.nonce)
171 .map_err(|e| AllSourceError::ValidationError(format!("Invalid nonce encoding: {e}")))?;
172
173 let nonce = Nonce::from_slice(&nonce_bytes);
174
175 let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
177 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
178
179 let plaintext_bytes = cipher
180 .decrypt(nonce, ciphertext.as_ref())
181 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))?;
182
183 String::from_utf8(plaintext_bytes)
184 .map_err(|e| AllSourceError::ValidationError(format!("Invalid UTF-8: {e}")))
185 }
186
187 pub fn rotate_keys(&self) -> Result<()> {
189 let mut deks = self.deks.write();
190 let mut active_key_id = self.active_key_id.write();
191
192 let key_id = uuid::Uuid::new_v4().to_string();
194 let mut key_bytes = vec![0u8; 32]; aes_gcm::aead::rand_core::RngCore::fill_bytes(&mut OsRng, &mut key_bytes);
196
197 let version = deks.len() as u32 + 1;
198
199 let new_key = DataEncryptionKey {
200 key_id: key_id.clone(),
201 key_bytes,
202 version,
203 created_at: chrono::Utc::now(),
204 active: true,
205 };
206
207 for key in deks.values_mut() {
209 key.active = false;
210 }
211
212 deks.insert(key_id.clone(), new_key);
214 *active_key_id = Some(key_id);
215
216 Ok(())
217 }
218
219 pub fn get_stats(&self) -> EncryptionStats {
221 let deks = self.deks.read();
222 let active_key_id = self.active_key_id.read();
223
224 EncryptionStats {
225 enabled: self.config.read().enabled,
226 total_keys: deks.len(),
227 active_key_version: deks
228 .get(active_key_id.as_ref().unwrap_or(&String::new()))
229 .map(|k| k.version)
230 .unwrap_or(0),
231 algorithm: self.config.read().algorithm.clone(),
232 }
233 }
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct EncryptionStats {
238 pub enabled: bool,
239 pub total_keys: usize,
240 pub active_key_version: u32,
241 pub algorithm: EncryptionAlgorithm,
242}
243
244pub trait Encryptable {
246 fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData>;
247 fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self>
248 where
249 Self: Sized;
250}
251
252impl Encryptable for String {
253 fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData> {
254 encryption.encrypt_string(self, field_name)
255 }
256
257 fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self> {
258 encryption.decrypt_string(encrypted)
259 }
260}
261
262pub fn encrypt_json_value(
264 value: &serde_json::Value,
265 encryption: &FieldEncryption,
266 field_name: &str,
267) -> Result<EncryptedData> {
268 let json_string = serde_json::to_string(value)
269 .map_err(|e| AllSourceError::ValidationError(format!("JSON serialization failed: {e}")))?;
270
271 encryption.encrypt_string(&json_string, field_name)
272}
273
274pub fn decrypt_json_value(
276 encrypted: &EncryptedData,
277 encryption: &FieldEncryption,
278) -> Result<serde_json::Value> {
279 let json_string = encryption.decrypt_string(encrypted)?;
280
281 serde_json::from_str(&json_string)
282 .map_err(|e| AllSourceError::ValidationError(format!("JSON deserialization failed: {e}")))
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_creation() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let stats = encryption.get_stats();

        assert!(stats.enabled);
        assert_eq!(stats.total_keys, 1);
        assert_eq!(stats.active_key_version, 1);
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let plaintext = "sensitive data";

        let encrypted = encryption.encrypt_string(plaintext, "test_field").unwrap();
        assert_ne!(encrypted.ciphertext, plaintext);

        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_json() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let value = serde_json::json!({
            "username": "john_doe",
            "ssn": "123-45-6789",
            "credit_card": "4111-1111-1111-1111"
        });

        let encrypted = encrypt_json_value(&value, &encryption, "sensitive_data").unwrap();
        let decrypted = decrypt_json_value(&encrypted, &encryption).unwrap();

        assert_eq!(decrypted, value);
    }

    #[test]
    fn test_key_rotation() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let plaintext = "sensitive data";

        let encrypted1 = encryption.encrypt_string(plaintext, "test").unwrap();
        let key_id1 = encrypted1.key_id.clone();

        encryption.rotate_keys().unwrap();

        let encrypted2 = encryption.encrypt_string(plaintext, "test").unwrap();
        let key_id2 = encrypted2.key_id.clone();

        assert_ne!(key_id1, key_id2);
        assert_eq!(encrypted2.version, 2);

        // Data encrypted under the retired key must still decrypt.
        let decrypted1 = encryption.decrypt_string(&encrypted1).unwrap();
        assert_eq!(decrypted1, plaintext);

        let decrypted2 = encryption.decrypt_string(&encrypted2).unwrap();
        assert_eq!(decrypted2, plaintext);
    }

    #[test]
    fn test_multiple_key_rotations() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let plaintext = "test data";

        let mut encrypted_data = Vec::new();

        for _ in 0..5 {
            let encrypted = encryption.encrypt_string(plaintext, "test").unwrap();
            encrypted_data.push(encrypted);
            encryption.rotate_keys().unwrap();
        }

        for encrypted in &encrypted_data {
            let decrypted = encryption.decrypt_string(encrypted).unwrap();
            assert_eq!(decrypted, plaintext);
        }

        let stats = encryption.get_stats();
        assert_eq!(stats.total_keys, 6);
        assert_eq!(stats.active_key_version, 6);
    }

    #[test]
    fn test_disabled_encryption() {
        let config = EncryptionConfig {
            enabled: false,
            ..Default::default()
        };

        let encryption = FieldEncryption::new(config).unwrap();
        let plaintext = "test";

        let result = encryption.encrypt_string(plaintext, "test");
        assert!(result.is_err());
    }

    #[test]
    fn test_disabled_decryption() {
        let config = EncryptionConfig {
            enabled: false,
            ..Default::default()
        };
        let encryption = FieldEncryption::new(config).unwrap();

        let encrypted = EncryptedData {
            ciphertext: String::new(),
            nonce: String::new(),
            key_id: "any".to_string(),
            algorithm: EncryptionAlgorithm::Aes256Gcm,
            version: 1,
        };

        assert!(encryption.decrypt_string(&encrypted).is_err());
    }

    #[test]
    fn test_encryption_config_default() {
        let config = EncryptionConfig::default();
        assert!(config.enabled);
        assert_eq!(config.key_rotation_days, 90);
        assert_eq!(config.algorithm, EncryptionAlgorithm::Aes256Gcm);
    }

    #[test]
    fn test_encryption_algorithm_equality() {
        assert_eq!(
            EncryptionAlgorithm::Aes256Gcm,
            EncryptionAlgorithm::Aes256Gcm
        );
        assert_ne!(
            EncryptionAlgorithm::Aes256Gcm,
            EncryptionAlgorithm::ChaCha20Poly1305
        );
    }

    #[test]
    fn test_encryption_config_serde() {
        let config = EncryptionConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: EncryptionConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.enabled, config.enabled);
        assert_eq!(parsed.key_rotation_days, config.key_rotation_days);
        assert_eq!(parsed.algorithm, config.algorithm);
    }

    #[test]
    fn test_encrypted_data_serde() {
        let encrypted = EncryptedData {
            ciphertext: "encrypted_data".to_string(),
            nonce: "nonce_value".to_string(),
            key_id: "key-123".to_string(),
            algorithm: EncryptionAlgorithm::Aes256Gcm,
            version: 1,
        };

        let json = serde_json::to_string(&encrypted).unwrap();
        let parsed: EncryptedData = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.ciphertext, encrypted.ciphertext);
        assert_eq!(parsed.key_id, encrypted.key_id);
        assert_eq!(parsed.version, encrypted.version);
    }

    #[test]
    fn test_encrypt_empty_string() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let plaintext = "";

        let encrypted = encryption.encrypt_string(plaintext, "test_field").unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_long_string() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let plaintext = "a".repeat(10000);

        let encrypted = encryption.encrypt_string(&plaintext, "test_field").unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_unicode_string() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let plaintext = "日本語テスト 🎉 émojis";

        let encrypted = encryption.encrypt_string(plaintext, "test_field").unwrap();
        let decrypted = encryption.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encryption_stats() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();

        let stats = encryption.get_stats();
        assert!(stats.enabled);
        assert_eq!(stats.total_keys, 1);
        assert_eq!(stats.active_key_version, 1);
        assert_eq!(stats.algorithm, EncryptionAlgorithm::Aes256Gcm);
    }

    #[test]
    fn test_decrypt_with_invalid_key() {
        let encryption1 = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let encryption2 = FieldEncryption::new(EncryptionConfig::default()).unwrap();

        let plaintext = "test data";
        let encrypted = encryption1.encrypt_string(plaintext, "test").unwrap();

        // encryption2 has never seen encryption1's key id.
        let result = encryption2.decrypt_string(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        use base64::Engine as _;

        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();
        let mut encrypted = encryption.encrypt_string("sensitive data", "test").unwrap();

        // Flip one bit; AES-GCM authentication must detect the modification.
        let mut bytes = base64::engine::general_purpose::STANDARD
            .decode(&encrypted.ciphertext)
            .unwrap();
        bytes[0] ^= 0x01;
        encrypted.ciphertext = base64::engine::general_purpose::STANDARD.encode(&bytes);

        assert!(encryption.decrypt_string(&encrypted).is_err());
    }

    #[test]
    fn test_encryption_different_fields() {
        let encryption = FieldEncryption::new(EncryptionConfig::default()).unwrap();

        let data1 = "data for field 1";
        let data2 = "data for field 2";

        let encrypted1 = encryption.encrypt_string(data1, "field1").unwrap();
        let encrypted2 = encryption.encrypt_string(data2, "field2").unwrap();

        // A fresh nonce per call means identical plaintexts encrypt differently.
        let encrypted1_again = encryption.encrypt_string(data1, "field1").unwrap();
        assert_ne!(encrypted1.ciphertext, encrypted1_again.ciphertext);

        assert_eq!(encryption.decrypt_string(&encrypted1).unwrap(), data1);
        assert_eq!(encryption.decrypt_string(&encrypted2).unwrap(), data2);
    }
}