1use crate::error::{AllSourceError, Result};
9use aes_gcm::{
10 aead::{Aead, KeyInit, OsRng},
11 Aes256Gcm, Nonce,
12};
13use base64::{engine::general_purpose, Engine as _};
14use parking_lot::RwLock;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct EncryptionConfig {
22 pub enabled: bool,
24
25 pub key_rotation_days: u32,
27
28 pub algorithm: EncryptionAlgorithm,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
33pub enum EncryptionAlgorithm {
34 Aes256Gcm,
35 ChaCha20Poly1305,
36}
37
38impl Default for EncryptionConfig {
39 fn default() -> Self {
40 Self {
41 enabled: true,
42 key_rotation_days: 90,
43 algorithm: EncryptionAlgorithm::Aes256Gcm,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct EncryptedData {
51 pub ciphertext: String,
53
54 pub nonce: String,
56
57 pub key_id: String,
59
60 pub algorithm: EncryptionAlgorithm,
62
63 pub version: u32,
65}
66
67#[derive(Debug, Clone)]
69struct DataEncryptionKey {
70 key_id: String,
71 key_bytes: Vec<u8>,
72 version: u32,
73 created_at: chrono::DateTime<chrono::Utc>,
74 active: bool,
75}
76
77pub struct FieldEncryption {
79 config: Arc<RwLock<EncryptionConfig>>,
80
81 deks: Arc<RwLock<HashMap<String, DataEncryptionKey>>>,
83
84 active_key_id: Arc<RwLock<Option<String>>>,
86}
87
88impl FieldEncryption {
89 pub fn new(config: EncryptionConfig) -> Result<Self> {
91 let manager = Self {
92 config: Arc::new(RwLock::new(config)),
93 deks: Arc::new(RwLock::new(HashMap::new())),
94 active_key_id: Arc::new(RwLock::new(None)),
95 };
96
97 manager.rotate_keys()?;
99
100 Ok(manager)
101 }
102
103 pub fn encrypt_string(&self, plaintext: &str, field_name: &str) -> Result<EncryptedData> {
105 if !self.config.read().enabled {
106 return Err(AllSourceError::ValidationError(
107 "Encryption is disabled".to_string(),
108 ));
109 }
110
111 let active_key_id = self.active_key_id.read();
112 let key_id = active_key_id
113 .as_ref()
114 .ok_or_else(|| AllSourceError::ValidationError("No active encryption key".to_string()))?
115 .clone();
116
117 let deks = self.deks.read();
118 let dek = deks.get(&key_id).ok_or_else(|| {
119 AllSourceError::ValidationError("Encryption key not found".to_string())
120 })?;
121
122 let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
124 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
125
126 let nonce_bytes = aes_gcm::aead::rand_core::RngCore::next_u64(&mut OsRng).to_le_bytes();
128 let mut nonce_array = [0u8; 12];
129 nonce_array[..8].copy_from_slice(&nonce_bytes);
130 let nonce = Nonce::from_slice(&nonce_array);
131
132 let ciphertext = cipher
134 .encrypt(nonce, plaintext.as_bytes())
135 .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {}", e)))?;
136
137 Ok(EncryptedData {
138 ciphertext: general_purpose::STANDARD.encode(&ciphertext),
139 nonce: general_purpose::STANDARD.encode(nonce.as_slice()),
140 key_id: key_id.clone(),
141 algorithm: self.config.read().algorithm.clone(),
142 version: dek.version,
143 })
144 }
145
146 pub fn decrypt_string(&self, encrypted: &EncryptedData) -> Result<String> {
148 if !self.config.read().enabled {
149 return Err(AllSourceError::ValidationError(
150 "Encryption is disabled".to_string(),
151 ));
152 }
153
154 let deks = self.deks.read();
155 let dek = deks.get(&encrypted.key_id).ok_or_else(|| {
156 AllSourceError::ValidationError(format!(
157 "Encryption key {} not found",
158 encrypted.key_id
159 ))
160 })?;
161
162 let ciphertext = general_purpose::STANDARD
164 .decode(&encrypted.ciphertext)
165 .map_err(|e| {
166 AllSourceError::ValidationError(format!("Invalid ciphertext encoding: {}", e))
167 })?;
168
169 let nonce_bytes = general_purpose::STANDARD
170 .decode(&encrypted.nonce)
171 .map_err(|e| {
172 AllSourceError::ValidationError(format!("Invalid nonce encoding: {}", e))
173 })?;
174
175 let nonce = Nonce::from_slice(&nonce_bytes);
176
177 let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
179 .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {}", e)))?;
180
181 let plaintext_bytes = cipher
182 .decrypt(nonce, ciphertext.as_ref())
183 .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {}", e)))?;
184
185 String::from_utf8(plaintext_bytes)
186 .map_err(|e| AllSourceError::ValidationError(format!("Invalid UTF-8: {}", e)))
187 }
188
189 pub fn rotate_keys(&self) -> Result<()> {
191 let mut deks = self.deks.write();
192 let mut active_key_id = self.active_key_id.write();
193
194 let key_id = uuid::Uuid::new_v4().to_string();
196 let mut key_bytes = vec![0u8; 32]; aes_gcm::aead::rand_core::RngCore::fill_bytes(&mut OsRng, &mut key_bytes);
198
199 let version = deks.len() as u32 + 1;
200
201 let new_key = DataEncryptionKey {
202 key_id: key_id.clone(),
203 key_bytes,
204 version,
205 created_at: chrono::Utc::now(),
206 active: true,
207 };
208
209 for key in deks.values_mut() {
211 key.active = false;
212 }
213
214 deks.insert(key_id.clone(), new_key);
216 *active_key_id = Some(key_id);
217
218 Ok(())
219 }
220
221 pub fn get_stats(&self) -> EncryptionStats {
223 let deks = self.deks.read();
224 let active_key_id = self.active_key_id.read();
225
226 EncryptionStats {
227 enabled: self.config.read().enabled,
228 total_keys: deks.len(),
229 active_key_version: deks
230 .get(active_key_id.as_ref().unwrap_or(&String::new()))
231 .map(|k| k.version)
232 .unwrap_or(0),
233 algorithm: self.config.read().algorithm.clone(),
234 }
235 }
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct EncryptionStats {
240 pub enabled: bool,
241 pub total_keys: usize,
242 pub active_key_version: u32,
243 pub algorithm: EncryptionAlgorithm,
244}
245
246pub trait Encryptable {
248 fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData>;
249 fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self>
250 where
251 Self: Sized;
252}
253
254impl Encryptable for String {
255 fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData> {
256 encryption.encrypt_string(self, field_name)
257 }
258
259 fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self> {
260 encryption.decrypt_string(encrypted)
261 }
262}
263
264pub fn encrypt_json_value(
266 value: &serde_json::Value,
267 encryption: &FieldEncryption,
268 field_name: &str,
269) -> Result<EncryptedData> {
270 let json_string = serde_json::to_string(value).map_err(|e| {
271 AllSourceError::ValidationError(format!("JSON serialization failed: {}", e))
272 })?;
273
274 encryption.encrypt_string(&json_string, field_name)
275}
276
277pub fn decrypt_json_value(
279 encrypted: &EncryptedData,
280 encryption: &FieldEncryption,
281) -> Result<serde_json::Value> {
282 let json_string = encryption.decrypt_string(encrypted)?;
283
284 serde_json::from_str(&json_string)
285 .map_err(|e| AllSourceError::ValidationError(format!("JSON deserialization failed: {}", e)))
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with the default configuration (enabled, AES-256-GCM).
    fn default_manager() -> FieldEncryption {
        FieldEncryption::new(EncryptionConfig::default()).unwrap()
    }

    #[test]
    fn test_encryption_creation() {
        let stats = default_manager().get_stats();

        assert!(stats.enabled);
        assert_eq!(stats.total_keys, 1);
        assert_eq!(stats.active_key_version, 1);
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let manager = default_manager();
        let secret = "sensitive data";

        let sealed = manager.encrypt_string(secret, "test_field").unwrap();
        // The stored ciphertext must not simply be the plaintext.
        assert_ne!(sealed.ciphertext, secret);

        let opened = manager.decrypt_string(&sealed).unwrap();
        assert_eq!(opened, secret);
    }

    #[test]
    fn test_encrypt_decrypt_json() {
        let manager = default_manager();
        let document = serde_json::json!({
            "username": "john_doe",
            "ssn": "123-45-6789",
            "credit_card": "4111-1111-1111-1111"
        });

        let sealed = encrypt_json_value(&document, &manager, "sensitive_data").unwrap();
        let opened = decrypt_json_value(&sealed, &manager).unwrap();

        assert_eq!(opened, document);
    }

    #[test]
    fn test_key_rotation() {
        let manager = default_manager();
        let secret = "sensitive data";

        let before = manager.encrypt_string(secret, "test").unwrap();
        manager.rotate_keys().unwrap();
        let after = manager.encrypt_string(secret, "test").unwrap();

        // Rotation moves new encryptions onto a fresh key...
        assert_ne!(before.key_id, after.key_id);
        assert_eq!(after.version, 2);

        // ...while data sealed under the retired key stays readable.
        assert_eq!(manager.decrypt_string(&before).unwrap(), secret);
        assert_eq!(manager.decrypt_string(&after).unwrap(), secret);
    }

    #[test]
    fn test_multiple_key_rotations() {
        let manager = default_manager();
        let secret = "test data";

        // Encrypt once under each of five successive keys, rotating after each.
        let sealed: Vec<EncryptedData> = (0..5)
            .map(|_| {
                let item = manager.encrypt_string(secret, "test").unwrap();
                manager.rotate_keys().unwrap();
                item
            })
            .collect();

        for item in &sealed {
            assert_eq!(manager.decrypt_string(item).unwrap(), secret);
        }

        let stats = manager.get_stats();
        assert_eq!(stats.total_keys, 6);
        assert_eq!(stats.active_key_version, 6);
    }

    #[test]
    fn test_disabled_encryption() {
        let config = EncryptionConfig {
            enabled: false,
            ..EncryptionConfig::default()
        };
        let manager = FieldEncryption::new(config).unwrap();

        assert!(manager.encrypt_string("test", "test").is_err());
    }
}