use crate::error::{AllSourceError, Result};
use aes_gcm::{
Aes256Gcm, Nonce,
aead::{Aead, KeyInit, OsRng},
};
use base64::{Engine as _, engine::general_purpose};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Runtime settings for field-level encryption.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionConfig {
/// Master switch: while `false`, `FieldEncryption::encrypt_string` and
/// `FieldEncryption::decrypt_string` both return a validation error.
pub enabled: bool,
/// Key rotation interval, in days.
/// NOTE(review): no code visible in this file reads this value — presumably an
/// external scheduler calls `rotate_keys` on this cadence; confirm.
pub key_rotation_days: u32,
/// Algorithm tag copied onto every `EncryptedData` that gets produced.
pub algorithm: EncryptionAlgorithm,
}
/// Identifies the AEAD cipher associated with a configuration or a ciphertext.
///
/// NOTE(review): the `FieldEncryption` code visible in this file always builds an
/// `Aes256Gcm` cipher; no ChaCha20-Poly1305 implementation is visible here — confirm
/// before configuring that variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum EncryptionAlgorithm {
/// AES-256 in Galois/Counter Mode.
Aes256Gcm,
/// ChaCha20 stream cipher with Poly1305 authenticator.
ChaCha20Poly1305,
}
impl Default for EncryptionConfig {
fn default() -> Self {
Self {
enabled: true,
key_rotation_days: 90,
algorithm: EncryptionAlgorithm::Aes256Gcm,
}
}
}
/// Serializable envelope produced by `FieldEncryption::encrypt_string` and consumed
/// by `FieldEncryption::decrypt_string`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
/// AEAD output, base64-encoded with the standard alphabet.
pub ciphertext: String,
/// Nonce used for this ciphertext, base64-encoded with the standard alphabet.
pub nonce: String,
/// Identifier of the data-encryption key that produced the ciphertext; decryption
/// looks the key up by this id.
pub key_id: String,
/// Algorithm tag taken from the configuration at encryption time.
pub algorithm: EncryptionAlgorithm,
/// Version number of the key identified by `key_id`.
pub version: u32,
}
/// A single data-encryption key (DEK) plus the bookkeeping stored alongside it.
///
/// `Debug` is implemented by hand (below) instead of derived so that the raw key
/// bytes can never be written to logs or panic messages through `{:?}`.
#[derive(Clone)]
struct DataEncryptionKey {
    /// Identifier of the key; the same value is used as its key in the DEK map.
    key_id: String,
    /// Raw symmetric key material. Never print or serialize this field.
    key_bytes: Vec<u8>,
    /// Version number assigned when the key was generated.
    version: u32,
    /// UTC timestamp taken when the key was generated.
    created_at: chrono::DateTime<chrono::Utc>,
    /// `true` for the most recently generated key; cleared on older keys at rotation.
    active: bool,
}

impl std::fmt::Debug for DataEncryptionKey {
    /// Formats every field except `key_bytes`, which is replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DataEncryptionKey")
            .field("key_id", &self.key_id)
            .field("key_bytes", &"<redacted>")
            .field("version", &self.version)
            .field("created_at", &self.created_at)
            .field("active", &self.active)
            .finish()
    }
}
/// Field-level encryption service holding the configuration and every
/// data-encryption key generated so far.
pub struct FieldEncryption {
/// Current settings, behind a read/write lock.
config: Arc<RwLock<EncryptionConfig>>,
/// All generated keys, indexed by key id. Older keys stay here after rotation so
/// data encrypted under them can still be decrypted.
deks: Arc<DashMap<String, DataEncryptionKey>>,
/// Id of the key used for new encryptions; `None` until the first rotation runs.
active_key_id: Arc<RwLock<Option<String>>>,
}
impl FieldEncryption {
    /// Length of an AES-GCM nonce in bytes (96 bits).
    const NONCE_LEN: usize = 12;
    /// Length of an AES-256 key in bytes.
    const KEY_LEN: usize = 32;

    /// Creates the service and immediately generates the first data-encryption key,
    /// so the returned value is ready to encrypt.
    ///
    /// # Errors
    /// Propagates any error returned by [`FieldEncryption::rotate_keys`].
    pub fn new(config: EncryptionConfig) -> Result<Self> {
        let manager = Self {
            config: Arc::new(RwLock::new(config)),
            deks: Arc::new(DashMap::new()),
            active_key_id: Arc::new(RwLock::new(None)),
        };
        manager.rotate_keys()?;
        Ok(manager)
    }

    /// Encrypts `plaintext` with the currently active key using AES-256-GCM and a
    /// fresh random 96-bit nonce.
    ///
    /// The ciphertext and nonce are returned base64-encoded (standard alphabet)
    /// together with the id and version of the key that was used.
    ///
    /// `field_name` is accepted for API compatibility but is not bound into the
    /// ciphertext.
    ///
    /// # Errors
    /// Returns `AllSourceError::ValidationError` when encryption is disabled, when
    /// there is no active key, when the active key is missing from the key store,
    /// when the key bytes are rejected by the cipher, or when encryption fails.
    pub fn encrypt_string(&self, plaintext: &str, field_name: &str) -> Result<EncryptedData> {
        // NOTE(review): `field_name` could be authenticated as associated data, but
        // doing so would make existing ciphertexts undecryptable; left unused on purpose.
        let _ = field_name;

        // Read the configuration once so `enabled` and `algorithm` come from the
        // same snapshot.
        let algorithm = {
            let config = self.config.read();
            if !config.enabled {
                return Err(AllSourceError::ValidationError(
                    "Encryption is disabled".to_string(),
                ));
            }
            config.algorithm.clone()
        };

        // Copy the id out so the read lock is released before any crypto work.
        let key_id = self
            .active_key_id
            .read()
            .as_ref()
            .cloned()
            .ok_or_else(|| {
                AllSourceError::ValidationError("No active encryption key".to_string())
            })?;

        let dek_ref = self.deks.get(&key_id).ok_or_else(|| {
            AllSourceError::ValidationError("Encryption key not found".to_string())
        })?;
        let dek = dek_ref.value();

        // NOTE(review): the cipher is always AES-256-GCM, whatever `algorithm` says;
        // the configured value is only recorded on the output.
        let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;

        // Fill the whole nonce from the OS RNG. Reusing a nonce under one key breaks
        // GCM, so all 96 bits must be random rather than a 64-bit value padded with
        // zeros.
        let mut nonce_bytes = [0u8; Self::NONCE_LEN];
        aes_gcm::aead::rand_core::RngCore::fill_bytes(&mut OsRng, &mut nonce_bytes);
        let nonce = Nonce::from_slice(&nonce_bytes);

        let ciphertext = cipher
            .encrypt(nonce, plaintext.as_bytes())
            .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;

        let version = dek.version;
        Ok(EncryptedData {
            ciphertext: general_purpose::STANDARD.encode(&ciphertext),
            nonce: general_purpose::STANDARD.encode(nonce.as_slice()),
            key_id,
            algorithm,
            version,
        })
    }

    /// Decrypts an [`EncryptedData`] envelope with the key named by its `key_id`
    /// and returns the plaintext as UTF-8.
    ///
    /// # Errors
    /// Returns `AllSourceError::ValidationError` when encryption is disabled, the
    /// key id is unknown, the ciphertext or nonce is not valid base64, the nonce is
    /// not exactly 12 bytes, the key bytes are rejected by the cipher,
    /// authentication/decryption fails, or the plaintext is not valid UTF-8.
    pub fn decrypt_string(&self, encrypted: &EncryptedData) -> Result<String> {
        if !self.config.read().enabled {
            return Err(AllSourceError::ValidationError(
                "Encryption is disabled".to_string(),
            ));
        }

        let dek_ref = self.deks.get(&encrypted.key_id).ok_or_else(|| {
            AllSourceError::ValidationError(format!(
                "Encryption key {} not found",
                encrypted.key_id
            ))
        })?;
        let dek = dek_ref.value();

        let ciphertext = general_purpose::STANDARD
            .decode(&encrypted.ciphertext)
            .map_err(|e| {
                AllSourceError::ValidationError(format!("Invalid ciphertext encoding: {e}"))
            })?;
        let nonce_bytes = general_purpose::STANDARD
            .decode(&encrypted.nonce)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid nonce encoding: {e}")))?;

        // `Nonce::from_slice` panics on a length mismatch, and the envelope may come
        // from deserialized (untrusted) input, so reject bad lengths explicitly.
        if nonce_bytes.len() != Self::NONCE_LEN {
            return Err(AllSourceError::ValidationError(format!(
                "Invalid nonce length: expected {} bytes, got {}",
                Self::NONCE_LEN,
                nonce_bytes.len()
            )));
        }
        let nonce = Nonce::from_slice(&nonce_bytes);

        let cipher = Aes256Gcm::new_from_slice(&dek.key_bytes)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
        let plaintext_bytes = cipher
            .decrypt(nonce, ciphertext.as_ref())
            .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))?;

        String::from_utf8(plaintext_bytes)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid UTF-8: {e}")))
    }

    /// Generates a new random 256-bit key, marks every existing key inactive, and
    /// makes the new key the active one. Existing keys are kept so that older
    /// ciphertexts remain decryptable.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub fn rotate_keys(&self) -> Result<()> {
        // Held for the whole rotation: serializes concurrent rotations.
        let mut active_key_id = self.active_key_id.write();

        let key_id = uuid::Uuid::new_v4().to_string();
        let mut key_bytes = vec![0u8; Self::KEY_LEN];
        aes_gcm::aead::rand_core::RngCore::fill_bytes(&mut OsRng, &mut key_bytes);

        // The version is derived from the number of stored keys; this stays
        // monotonic as long as keys are never removed from the store (no removal is
        // visible in this impl).
        let version = self.deks.len() as u32 + 1;
        let new_key = DataEncryptionKey {
            key_id: key_id.clone(),
            key_bytes,
            version,
            created_at: chrono::Utc::now(),
            active: true,
        };

        for mut entry in self.deks.iter_mut() {
            entry.value_mut().active = false;
        }
        self.deks.insert(key_id.clone(), new_key);
        *active_key_id = Some(key_id);
        Ok(())
    }

    /// Returns a snapshot of the service state: whether encryption is enabled, how
    /// many keys are stored, the version of the active key (`0` when there is
    /// none), and the configured algorithm.
    pub fn get_stats(&self) -> EncryptionStats {
        let active_key_id = self.active_key_id.read();
        let active_key_version = active_key_id
            .as_ref()
            .and_then(|id| self.deks.get(id))
            .map_or(0, |entry| entry.value().version);

        // One configuration read so both fields reflect the same snapshot.
        let config = self.config.read();
        EncryptionStats {
            enabled: config.enabled,
            total_keys: self.deks.len(),
            active_key_version,
            algorithm: config.algorithm.clone(),
        }
    }
}
/// Snapshot of the encryption service state, as returned by `FieldEncryption::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionStats {
/// Whether encryption is enabled in the current configuration.
pub enabled: bool,
/// Number of keys currently held in the key store.
pub total_keys: usize,
/// Version of the active key, or `0` when no active key could be found.
pub active_key_version: u32,
/// Algorithm named in the current configuration.
pub algorithm: EncryptionAlgorithm,
}
/// Types that can be turned into an `EncryptedData` envelope and recovered from one.
pub trait Encryptable {
/// Encrypts `self` through `encryption`; `field_name` is forwarded to the
/// encryption service unchanged.
fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData>;
/// Rebuilds a value of the implementing type from `encrypted` using `encryption`.
fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self>
where
Self: Sized;
}
impl Encryptable for String {
fn encrypt(&self, encryption: &FieldEncryption, field_name: &str) -> Result<EncryptedData> {
encryption.encrypt_string(self, field_name)
}
fn decrypt(encrypted: &EncryptedData, encryption: &FieldEncryption) -> Result<Self> {
encryption.decrypt_string(encrypted)
}
}
/// Serializes `value` to a compact JSON string and encrypts that string.
///
/// # Errors
/// Returns `AllSourceError::ValidationError` when JSON serialization fails, and
/// forwards any error from [`FieldEncryption::encrypt_string`].
pub fn encrypt_json_value(
    value: &serde_json::Value,
    encryption: &FieldEncryption,
    field_name: &str,
) -> Result<EncryptedData> {
    match serde_json::to_string(value) {
        Ok(serialized) => encryption.encrypt_string(&serialized, field_name),
        Err(e) => Err(AllSourceError::ValidationError(format!(
            "JSON serialization failed: {e}"
        ))),
    }
}
/// Decrypts `encrypted` and parses the recovered text as a JSON value.
///
/// # Errors
/// Forwards any error from [`FieldEncryption::decrypt_string`], and returns
/// `AllSourceError::ValidationError` when the decrypted text is not valid JSON.
pub fn decrypt_json_value(
    encrypted: &EncryptedData,
    encryption: &FieldEncryption,
) -> Result<serde_json::Value> {
    let decrypted_text = encryption.decrypt_string(encrypted)?;
    match serde_json::from_str(&decrypted_text) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(AllSourceError::ValidationError(format!(
            "JSON deserialization failed: {e}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a service from the default configuration.
    fn default_engine() -> FieldEncryption {
        FieldEncryption::new(EncryptionConfig::default()).unwrap()
    }

    /// Encrypts `input` under the field name `test_field` with a fresh default
    /// service and returns whatever decrypts back out.
    fn round_trip(input: &str) -> String {
        let engine = default_engine();
        let sealed = engine.encrypt_string(input, "test_field").unwrap();
        engine.decrypt_string(&sealed).unwrap()
    }

    /// A new service is enabled and owns exactly one key at version 1.
    #[test]
    fn test_encryption_creation() {
        let snapshot = default_engine().get_stats();
        assert!(snapshot.enabled);
        assert_eq!(snapshot.total_keys, 1);
        assert_eq!(snapshot.active_key_version, 1);
    }

    /// Ciphertext differs from the input and decrypts back to it.
    #[test]
    fn test_encrypt_decrypt_string() {
        let engine = default_engine();
        let secret = "sensitive data";
        let sealed = engine.encrypt_string(secret, "test_field").unwrap();
        assert_ne!(sealed.ciphertext, secret);
        let opened = engine.decrypt_string(&sealed).unwrap();
        assert_eq!(opened, secret);
    }

    /// A JSON object survives the encrypt/decrypt helpers unchanged.
    #[test]
    fn test_encrypt_decrypt_json() {
        let engine = default_engine();
        let document = serde_json::json!({
            "username": "john_doe",
            "ssn": "123-45-6789",
            "credit_card": "4111-1111-1111-1111"
        });
        let sealed = encrypt_json_value(&document, &engine, "sensitive_data").unwrap();
        let opened = decrypt_json_value(&sealed, &engine).unwrap();
        assert_eq!(opened, document);
    }

    /// Rotation switches to a new key id/version while old data stays readable.
    #[test]
    fn test_key_rotation() {
        let engine = default_engine();
        let secret = "sensitive data";

        let before = engine.encrypt_string(secret, "test").unwrap();
        let id_before = before.key_id.clone();

        engine.rotate_keys().unwrap();

        let after = engine.encrypt_string(secret, "test").unwrap();
        let id_after = after.key_id.clone();

        assert_ne!(id_before, id_after);
        assert_eq!(after.version, 2);

        let opened_before = engine.decrypt_string(&before).unwrap();
        assert_eq!(opened_before, secret);
        let opened_after = engine.decrypt_string(&after).unwrap();
        assert_eq!(opened_after, secret);
    }

    /// Data from five successive key generations all remains decryptable.
    #[test]
    fn test_multiple_key_rotations() {
        let engine = default_engine();
        let secret = "test data";

        let sealed_items: Vec<EncryptedData> = (0..5)
            .map(|_| {
                let sealed = engine.encrypt_string(secret, "test").unwrap();
                engine.rotate_keys().unwrap();
                sealed
            })
            .collect();

        for sealed in sealed_items.iter() {
            let opened = engine.decrypt_string(sealed).unwrap();
            assert_eq!(opened, secret);
        }

        let snapshot = engine.get_stats();
        assert_eq!(snapshot.total_keys, 6);
        assert_eq!(snapshot.active_key_version, 6);
    }

    /// Encrypting with a disabled configuration is an error.
    #[test]
    fn test_disabled_encryption() {
        let disabled = EncryptionConfig {
            enabled: false,
            ..Default::default()
        };
        let engine = FieldEncryption::new(disabled).unwrap();
        let outcome = engine.encrypt_string("test", "test");
        assert!(outcome.is_err());
    }

    /// The default configuration has the documented values.
    #[test]
    fn test_encryption_config_default() {
        let defaults = EncryptionConfig::default();
        assert!(defaults.enabled);
        assert_eq!(defaults.key_rotation_days, 90);
        assert_eq!(defaults.algorithm, EncryptionAlgorithm::Aes256Gcm);
    }

    /// Algorithm variants compare equal only to themselves.
    #[test]
    fn test_encryption_algorithm_equality() {
        let aes = EncryptionAlgorithm::Aes256Gcm;
        assert_eq!(aes, EncryptionAlgorithm::Aes256Gcm);
        assert_ne!(aes, EncryptionAlgorithm::ChaCha20Poly1305);
    }

    /// The configuration round-trips through JSON.
    #[test]
    fn test_encryption_config_serde() {
        let original = EncryptionConfig::default();
        let encoded = serde_json::to_string(&original).unwrap();
        let restored: EncryptionConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(restored.enabled, original.enabled);
        assert_eq!(restored.algorithm, original.algorithm);
    }

    /// The encrypted envelope round-trips through JSON.
    #[test]
    fn test_encrypted_data_serde() {
        let original = EncryptedData {
            ciphertext: "encrypted_data".to_string(),
            nonce: "nonce_value".to_string(),
            key_id: "key-123".to_string(),
            algorithm: EncryptionAlgorithm::Aes256Gcm,
            version: 1,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let restored: EncryptedData = serde_json::from_str(&encoded).unwrap();
        assert_eq!(restored.ciphertext, original.ciphertext);
        assert_eq!(restored.key_id, original.key_id);
        assert_eq!(restored.version, original.version);
    }

    /// The empty string round-trips.
    #[test]
    fn test_encrypt_empty_string() {
        let secret = "";
        assert_eq!(round_trip(secret), secret);
    }

    /// A 10 000-character string round-trips.
    #[test]
    fn test_encrypt_long_string() {
        let secret = "a".repeat(10000);
        assert_eq!(round_trip(&secret), secret);
    }

    /// Multi-byte UTF-8 text round-trips.
    #[test]
    fn test_encrypt_unicode_string() {
        let secret = "日本語テスト 🎉 émojis";
        assert_eq!(round_trip(secret), secret);
    }

    /// Stats of a fresh service also report the configured algorithm.
    #[test]
    fn test_encryption_stats() {
        let snapshot = default_engine().get_stats();
        assert!(snapshot.enabled);
        assert_eq!(snapshot.total_keys, 1);
        assert_eq!(snapshot.active_key_version, 1);
        assert_eq!(snapshot.algorithm, EncryptionAlgorithm::Aes256Gcm);
    }

    /// Data sealed by one service cannot be opened by an unrelated one.
    #[test]
    fn test_decrypt_with_invalid_key() {
        let sender = default_engine();
        let stranger = default_engine();
        let sealed = sender.encrypt_string("test data", "test").unwrap();
        let outcome = stranger.decrypt_string(&sealed);
        assert!(outcome.is_err());
    }

    /// Encrypting the same input twice yields different ciphertexts, and each
    /// field's data decrypts to its own plaintext.
    #[test]
    fn test_encryption_different_fields() {
        let engine = default_engine();
        let first = "data for field 1";
        let second = "data for field 2";

        let sealed_first = engine.encrypt_string(first, "field1").unwrap();
        let sealed_second = engine.encrypt_string(second, "field2").unwrap();
        let sealed_first_repeat = engine.encrypt_string(first, "field1").unwrap();

        assert_ne!(sealed_first.ciphertext, sealed_first_repeat.ciphertext);
        assert_eq!(engine.decrypt_string(&sealed_first).unwrap(), first);
        assert_eq!(engine.decrypt_string(&sealed_second).unwrap(), second);
    }
}