use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use argon2::{password_hash::rand_core::RngCore, Algorithm, Argon2, Params, Version};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use chacha20poly1305::{ChaCha20Poly1305, Key as ChaChaKey, Nonce as ChaChaNonce};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use zeroize::Zeroizing;
/// Errors returned by key management and encryption/decryption operations.
#[derive(Debug, Error)]
pub enum EncryptionError {
    /// The cipher failed while encrypting, or an encrypted envelope could not be
    /// serialized to JSON.
    #[error("Encryption failed: {0}")]
    EncryptionFailed(String),
    /// Authenticated decryption failed, a serialized envelope could not be parsed,
    /// or decrypted bytes were not valid UTF-8 when a `String` was requested.
    #[error("Decryption failed: {0}")]
    DecryptionFailed(String),
    /// Argon2 parameter construction or password hashing failed.
    #[error("Key derivation failed: {0}")]
    KeyDerivationFailed(String),
    /// Raw key bytes did not have the length required by the algorithm.
    #[error("Invalid key length: expected {expected}, got {actual}")]
    InvalidKeyLength { expected: usize, actual: usize },
    /// A decoded nonce did not have the length required by the key's algorithm.
    #[error("Invalid nonce length: expected {expected}, got {actual}")]
    InvalidNonceLength { expected: usize, actual: usize },
    /// No key is registered under the given id; when no default key is set the
    /// id is reported as `"default"`.
    #[error("Key not found: {0}")]
    KeyNotFound(String),
    /// Base64 decoding of a nonce, ciphertext, or imported key failed.
    #[error("Invalid ciphertext format")]
    InvalidCiphertextFormat,
    /// Key rotation failed. NOTE(review): not constructed by any code visible in
    /// this file — confirm whether other modules produce it.
    #[error("Key rotation failed: {0}")]
    KeyRotationFailed(String),
}
/// AEAD cipher associated with a key and with the data encrypted under it.
///
/// Both variants use 32-byte keys and 12-byte nonces (see
/// `EncryptionKey::key_length` / `EncryptionKey::nonce_length`). There are no
/// serde `rename` attributes, so serialized forms use the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EncryptionAlgorithm {
    /// AES with a 256-bit key in Galois/Counter Mode.
    Aes256Gcm,
    /// ChaCha20 stream cipher with the Poly1305 authenticator.
    ChaCha20Poly1305,
}
impl Default for EncryptionAlgorithm {
fn default() -> Self {
Self::Aes256Gcm
}
}
impl std::fmt::Display for EncryptionAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Aes256Gcm => write!(f, "AES-256-GCM"),
Self::ChaCha20Poly1305 => write!(f, "ChaCha20-Poly1305"),
}
}
}
/// Symmetric key material plus identifying metadata.
///
/// Cloning is cheap: clones share the same `Arc`'d buffer, which `Zeroizing`
/// wipes when the last clone is dropped. `Debug` is implemented manually so the
/// key bytes are never printed.
#[derive(Clone)]
pub struct EncryptionKey {
    /// Raw key bytes, zeroized when the final `Arc` reference is dropped.
    key_data: Arc<Zeroizing<Vec<u8>>>,
    /// Algorithm this key is used with.
    algorithm: EncryptionAlgorithm,
    /// Random identifier of the form `key_<uuid>`, assigned at construction.
    key_id: String,
    /// UTC timestamp taken when this key object was constructed.
    created_at: chrono::DateTime<chrono::Utc>,
}
impl EncryptionKey {
    /// Wraps `key_bytes` as a key for `algorithm`, assigning a fresh random id
    /// and the current UTC time as its creation timestamp.
    ///
    /// The bytes are copied into a zeroize-on-drop buffer. Wiping the caller's
    /// `key_bytes` remains the caller's responsibility.
    ///
    /// # Errors
    /// [`EncryptionError::InvalidKeyLength`] if `key_bytes` is not exactly
    /// `EncryptionKey::key_length(algorithm)` bytes long.
    pub fn from_bytes(
        key_bytes: &[u8],
        algorithm: EncryptionAlgorithm,
    ) -> Result<Self, EncryptionError> {
        let expected_len = Self::key_length(algorithm);
        if key_bytes.len() != expected_len {
            return Err(EncryptionError::InvalidKeyLength {
                expected: expected_len,
                actual: key_bytes.len(),
            });
        }
        Ok(Self {
            key_data: Arc::new(Zeroizing::new(key_bytes.to_vec())),
            algorithm,
            key_id: generate_key_id(),
            created_at: chrono::Utc::now(),
        })
    }

    /// Generates a new random key for `algorithm` using the OS RNG.
    ///
    /// # Errors
    /// Only propagates errors from [`EncryptionKey::from_bytes`], which cannot
    /// occur here because the buffer is sized with `key_length`.
    pub fn generate(algorithm: EncryptionAlgorithm) -> Result<Self, EncryptionError> {
        // Zeroizing scratch buffer: without it this temporary copy of the secret
        // would be left behind in freed heap memory after `from_bytes` copies it.
        let mut key_bytes = Zeroizing::new(vec![0u8; Self::key_length(algorithm)]);
        OsRng.fill_bytes(key_bytes.as_mut_slice());
        Self::from_bytes(key_bytes.as_slice(), algorithm)
    }

    /// Derives a key for `algorithm` from `password` and `salt` using Argon2id
    /// (version 0x13, with the argon2 crate's default memory, time, and
    /// parallelism costs).
    ///
    /// The same password, salt, and algorithm always yield the same key bytes,
    /// but every call assigns a new random key id.
    ///
    /// # Errors
    /// [`EncryptionError::KeyDerivationFailed`] if Argon2 rejects the parameters
    /// or the inputs (for example, a salt shorter than Argon2's minimum).
    pub fn derive_from_password(
        password: &str,
        salt: &[u8],
        algorithm: EncryptionAlgorithm,
    ) -> Result<Self, EncryptionError> {
        let key_len = Self::key_length(algorithm);
        let params = Params::new(
            Params::DEFAULT_M_COST,
            Params::DEFAULT_T_COST,
            Params::DEFAULT_P_COST,
            Some(key_len),
        )
        .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
        let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
        // Zeroizing so the derived secret is wiped on every exit path,
        // including the early return when hashing fails.
        let mut key_bytes = Zeroizing::new(vec![0u8; key_len]);
        argon2
            .hash_password_into(password.as_bytes(), salt, key_bytes.as_mut_slice())
            .map_err(|e| EncryptionError::KeyDerivationFailed(e.to_string()))?;
        Self::from_bytes(key_bytes.as_slice(), algorithm)
    }

    /// Key length in bytes required by `algorithm`.
    pub fn key_length(algorithm: EncryptionAlgorithm) -> usize {
        match algorithm {
            EncryptionAlgorithm::Aes256Gcm => 32,
            EncryptionAlgorithm::ChaCha20Poly1305 => 32,
        }
    }

    /// Nonce length in bytes required by `algorithm`.
    pub fn nonce_length(algorithm: EncryptionAlgorithm) -> usize {
        match algorithm {
            EncryptionAlgorithm::Aes256Gcm => 12,
            EncryptionAlgorithm::ChaCha20Poly1305 => 12,
        }
    }

    /// Raw key material. Avoid copying it into buffers that are not zeroized.
    pub fn as_bytes(&self) -> &[u8] {
        &self.key_data
    }

    /// Identifier assigned when this key was constructed.
    pub fn key_id(&self) -> &str {
        &self.key_id
    }

    /// Algorithm this key is used with.
    pub fn algorithm(&self) -> EncryptionAlgorithm {
        self.algorithm
    }

    /// UTC time at which this key object was constructed.
    pub fn created_at(&self) -> chrono::DateTime<chrono::Utc> {
        self.created_at
    }
}
impl std::fmt::Debug for EncryptionKey {
    /// Formats only the key's metadata. The key material is always rendered as
    /// `<redacted>`, so keys can be logged safely.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        let mut out = f.debug_struct("EncryptionKey");
        out.field("key_id", &self.key_id);
        out.field("algorithm", &self.algorithm);
        out.field("created_at", &self.created_at);
        out.field("key_data", &REDACTED);
        out.finish()
    }
}
/// Self-describing encrypted envelope, serializable to JSON via
/// [`EncryptedData::to_string`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
    /// Algorithm of the key that produced `ciphertext`.
    pub algorithm: EncryptionAlgorithm,
    /// Standard-alphabet base64 encoding of the random nonce.
    pub nonce: String,
    /// Standard-alphabet base64 encoding of the AEAD `encrypt` output
    /// (ciphertext with the authentication tag appended).
    pub ciphertext: String,
    /// Id of the encrypting key; `EncryptionEngine::decrypt` looks the key up by it.
    pub key_id: String,
    /// Optional associated data; omitted from serialized output when `None`.
    ///
    /// NOTE(review): the engine's encrypt path always sets this to `None` and
    /// its decrypt path never reads it, so the value is NOT authenticated by
    /// the AEAD. Do not rely on it for integrity.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub associated_data: Option<String>,
}
impl EncryptedData {
pub fn to_string(&self) -> Result<String, EncryptionError> {
serde_json::to_string(self).map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))
}
pub fn from_string(s: &str) -> Result<Self, EncryptionError> {
serde_json::from_str(s).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
}
}
/// Tunables for an `EncryptionEngine`, fixed at construction.
#[derive(Debug, Clone)]
pub struct EncryptionConfig {
    /// Algorithm used by `generate_key`, `derive_key_from_password`, and `rotate_key`.
    pub default_algorithm: EncryptionAlgorithm,
    /// NOTE(review): not read by any code visible in this file, so setting it has
    /// no effect here. Confirm whether another module consumes it.
    pub auto_key_rotation: bool,
    /// Age in seconds after which `get_keys_requiring_rotation` reports a key.
    pub key_validity_secs: u64,
    /// Maximum number of stored keys. Generating a key while the store is full
    /// first evicts the oldest key (by creation time).
    pub max_keys: usize,
}
impl Default for EncryptionConfig {
    /// AES-256-GCM, no automatic rotation, a 90-day key validity, and at most
    /// 100 stored keys.
    fn default() -> Self {
        const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
        Self {
            max_keys: 100,
            key_validity_secs: 90 * SECONDS_PER_DAY,
            auto_key_rotation: false,
            default_algorithm: EncryptionAlgorithm::Aes256Gcm,
        }
    }
}
/// In-memory key store with AEAD encryption and decryption helpers.
///
/// Mutable state sits behind `RwLock`s, so the engine can be shared across
/// threads (for example, inside an `Arc`). At most one key is designated the
/// default key used by `encrypt`.
pub struct EncryptionEngine {
    /// All registered keys, indexed by `EncryptionKey::key_id`.
    keys: RwLock<HashMap<String, EncryptionKey>>,
    /// Id of the key used by `encrypt`. It is `None` until a key is added, and
    /// again after the default key is removed.
    default_key_id: RwLock<Option<String>>,
    /// Configuration supplied at construction.
    config: EncryptionConfig,
}
impl EncryptionEngine {
    /// Creates an empty engine using [`EncryptionConfig::default`].
    pub fn new() -> Self {
        Self::with_config(EncryptionConfig::default())
    }

    /// Creates an empty engine (no keys, no default key) with `config`.
    pub fn with_config(config: EncryptionConfig) -> Self {
        Self {
            keys: RwLock::new(HashMap::new()),
            default_key_id: RwLock::new(None),
            config,
        }
    }

    /// Registers `key`, makes it the default key, and returns its id.
    ///
    /// Every insertion path (`generate_key*`, `derive_key_from_password`,
    /// `add_key`, `import_key`) goes through this helper, so `config.max_keys`
    /// is enforced uniformly:
    ///
    /// - While the store is full, the oldest key (by `created_at`) is evicted.
    ///   Data encrypted under an evicted key can no longer be decrypted through
    ///   this engine.
    /// - Re-inserting an id that is already present replaces that entry and
    ///   evicts nothing.
    ///
    /// Locks are always taken in the order `keys` -> `default_key_id`. The `keys`
    /// write lock is held while the default id is updated, so a reader never
    /// observes a default id whose key has not been inserted yet.
    fn insert_key(&self, key: EncryptionKey) -> String {
        let key_id = key.key_id().to_string();
        let mut keys = self.keys.write();
        if !keys.contains_key(&key_id) {
            while keys.len() >= self.config.max_keys {
                let oldest_id = keys
                    .iter()
                    .min_by_key(|(_, k)| k.created_at())
                    .map(|(id, _)| id.clone());
                match oldest_id {
                    Some(id) => {
                        keys.remove(&id);
                        tracing::info!("Evicted oldest encryption key: {}", id);
                    }
                    // Store already empty (only reachable when max_keys == 0).
                    None => break,
                }
            }
        }
        keys.insert(key_id.clone(), key);
        *self.default_key_id.write() = Some(key_id.clone());
        key_id
    }

    /// Generates a key with `config.default_algorithm` and makes it the default.
    pub fn generate_key(&self) -> Result<String, EncryptionError> {
        self.generate_key_with_algorithm(self.config.default_algorithm)
    }

    /// Generates a random key for `algorithm`, stores it (evicting the oldest key
    /// if `max_keys` is reached), and makes it the default key.
    pub fn generate_key_with_algorithm(
        &self,
        algorithm: EncryptionAlgorithm,
    ) -> Result<String, EncryptionError> {
        let key = EncryptionKey::generate(algorithm)?;
        let key_id = self.insert_key(key);
        tracing::info!("Generated new encryption key: {}", key_id);
        Ok(key_id)
    }

    /// Derives a key from `password` and `salt` (Argon2id, default algorithm),
    /// stores it, and makes it the default key.
    ///
    /// Each call registers a new key id, even for the same password and salt.
    pub fn derive_key_from_password(
        &self,
        password: &str,
        salt: &[u8],
    ) -> Result<String, EncryptionError> {
        let key =
            EncryptionKey::derive_from_password(password, salt, self.config.default_algorithm)?;
        let key_id = self.insert_key(key);
        tracing::info!("Derived encryption key from password: {}", key_id);
        Ok(key_id)
    }

    /// Stores an existing key and makes it the default key.
    pub fn add_key(&self, key: EncryptionKey) -> Result<String, EncryptionError> {
        Ok(self.insert_key(key))
    }

    /// Returns a clone of the key registered under `key_id`. The clone is cheap
    /// because it shares the key buffer.
    ///
    /// # Errors
    /// [`EncryptionError::KeyNotFound`] if no such key exists.
    pub fn get_key(&self, key_id: &str) -> Result<EncryptionKey, EncryptionError> {
        self.keys
            .read()
            .get(key_id)
            .cloned()
            .ok_or_else(|| EncryptionError::KeyNotFound(key_id.to_string()))
    }

    /// Removes `key_id`, clearing the default if it pointed at that key.
    ///
    /// Returns whether a key was removed. It never returns `Err`.
    pub fn remove_key(&self, key_id: &str) -> Result<bool, EncryptionError> {
        // Keep `keys` locked while touching the default id (lock order
        // keys -> default_key_id) so no concurrent insert interleaves.
        let mut keys = self.keys.write();
        let removed = keys.remove(key_id).is_some();
        if removed {
            let mut default_id = self.default_key_id.write();
            if default_id.as_deref() == Some(key_id) {
                *default_id = None;
            }
        }
        Ok(removed)
    }

    /// Ids of all registered keys, in unspecified order.
    pub fn list_keys(&self) -> Vec<String> {
        self.keys.read().keys().cloned().collect()
    }

    /// Returns the current default key.
    ///
    /// # Errors
    /// [`EncryptionError::KeyNotFound`] with id `"default"` if none is set, or
    /// with the default id if that key is no longer registered.
    pub fn get_default_key(&self) -> Result<EncryptionKey, EncryptionError> {
        let default_id = self
            .default_key_id
            .read()
            .clone()
            .ok_or_else(|| EncryptionError::KeyNotFound("default".to_string()))?;
        self.get_key(&default_id)
    }

    /// Makes the already-registered `key_id` the default key.
    ///
    /// # Errors
    /// [`EncryptionError::KeyNotFound`] if `key_id` is not registered.
    pub fn set_default_key(&self, key_id: &str) -> Result<(), EncryptionError> {
        // Hold the read lock across the update so the key cannot be removed
        // between the existence check and the assignment.
        let keys = self.keys.read();
        if !keys.contains_key(key_id) {
            return Err(EncryptionError::KeyNotFound(key_id.to_string()));
        }
        *self.default_key_id.write() = Some(key_id.to_string());
        Ok(())
    }

    /// Encrypts `plaintext` with the default key.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, EncryptionError> {
        let key = self.get_default_key()?;
        self.encrypt_with_key(&key, plaintext)
    }

    /// Encrypts `plaintext` with `key` under a fresh random nonce.
    ///
    /// Nonces are 96 random bits per message. NIST SP 800-38D guidance for
    /// random GCM nonces is to stay well below ~2^32 messages per key.
    pub fn encrypt_with_key(
        &self,
        key: &EncryptionKey,
        plaintext: &[u8],
    ) -> Result<EncryptedData, EncryptionError> {
        let nonce_len = EncryptionKey::nonce_length(key.algorithm);
        let mut nonce_bytes = vec![0u8; nonce_len];
        OsRng.fill_bytes(&mut nonce_bytes);
        let ciphertext = match key.algorithm {
            EncryptionAlgorithm::Aes256Gcm => {
                let cipher_key = Key::<Aes256Gcm>::from_slice(key.as_bytes());
                let cipher = Aes256Gcm::new(cipher_key);
                let nonce = Nonce::from_slice(&nonce_bytes);
                cipher
                    .encrypt(nonce, plaintext)
                    .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?
            }
            EncryptionAlgorithm::ChaCha20Poly1305 => {
                let cipher_key = ChaChaKey::from_slice(key.as_bytes());
                let cipher = ChaCha20Poly1305::new(cipher_key);
                let nonce = ChaChaNonce::from_slice(&nonce_bytes);
                cipher
                    .encrypt(nonce, plaintext)
                    .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?
            }
        };
        Ok(EncryptedData {
            algorithm: key.algorithm,
            nonce: BASE64.encode(&nonce_bytes),
            ciphertext: BASE64.encode(&ciphertext),
            key_id: key.key_id().to_string(),
            associated_data: None,
        })
    }

    /// Encrypts the UTF-8 bytes of `plaintext` with the default key.
    pub fn encrypt_string(&self, plaintext: &str) -> Result<EncryptedData, EncryptionError> {
        self.encrypt(plaintext.as_bytes())
    }

    /// Decrypts `encrypted` using the registered key named by its `key_id`.
    pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, EncryptionError> {
        let key = self.get_key(&encrypted.key_id)?;
        self.decrypt_with_key(&key, encrypted)
    }

    /// Decrypts and authenticates `encrypted` with `key`.
    ///
    /// # Errors
    /// - [`EncryptionError::DecryptionFailed`] if the key's algorithm differs
    ///   from the one recorded in `encrypted`, or if authentication fails
    ///   (wrong key or tampered data).
    /// - [`EncryptionError::InvalidCiphertextFormat`] on malformed base64.
    /// - [`EncryptionError::InvalidNonceLength`] if the nonce has the wrong size.
    ///
    /// `encrypted.associated_data` is not read.
    pub fn decrypt_with_key(
        &self,
        key: &EncryptionKey,
        encrypted: &EncryptedData,
    ) -> Result<Vec<u8>, EncryptionError> {
        // A mismatch can never authenticate, so report it explicitly rather than
        // as an opaque AEAD failure.
        if encrypted.algorithm != key.algorithm {
            return Err(EncryptionError::DecryptionFailed(format!(
                "algorithm mismatch: data was encrypted with {}, key is {}",
                encrypted.algorithm, key.algorithm
            )));
        }
        let nonce_bytes = BASE64
            .decode(&encrypted.nonce)
            .map_err(|_| EncryptionError::InvalidCiphertextFormat)?;
        let ciphertext = BASE64
            .decode(&encrypted.ciphertext)
            .map_err(|_| EncryptionError::InvalidCiphertextFormat)?;
        let expected_nonce_len = EncryptionKey::nonce_length(key.algorithm);
        if nonce_bytes.len() != expected_nonce_len {
            return Err(EncryptionError::InvalidNonceLength {
                expected: expected_nonce_len,
                actual: nonce_bytes.len(),
            });
        }
        let plaintext = match key.algorithm {
            EncryptionAlgorithm::Aes256Gcm => {
                let cipher_key = Key::<Aes256Gcm>::from_slice(key.as_bytes());
                let cipher = Aes256Gcm::new(cipher_key);
                let nonce = Nonce::from_slice(&nonce_bytes);
                cipher
                    .decrypt(nonce, ciphertext.as_slice())
                    .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?
            }
            EncryptionAlgorithm::ChaCha20Poly1305 => {
                let cipher_key = ChaChaKey::from_slice(key.as_bytes());
                let cipher = ChaCha20Poly1305::new(cipher_key);
                let nonce = ChaChaNonce::from_slice(&nonce_bytes);
                cipher
                    .decrypt(nonce, ciphertext.as_slice())
                    .map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))?
            }
        };
        Ok(plaintext)
    }

    /// Decrypts `encrypted` and interprets the plaintext as UTF-8.
    ///
    /// # Errors
    /// As [`EncryptionEngine::decrypt`], plus [`EncryptionError::DecryptionFailed`]
    /// if the plaintext is not valid UTF-8.
    pub fn decrypt_to_string(&self, encrypted: &EncryptedData) -> Result<String, EncryptionError> {
        let plaintext = self.decrypt(encrypted)?;
        String::from_utf8(plaintext).map_err(|e| EncryptionError::DecryptionFailed(e.to_string()))
    }

    /// Generates a replacement for `old_key_id` and makes it the default key.
    ///
    /// The old key is kept so existing ciphertexts stay decryptable. Remove it
    /// with [`EncryptionEngine::remove_key`] once the data has been re-encrypted.
    /// The new key uses `config.default_algorithm`, not necessarily the old key's
    /// algorithm.
    ///
    /// # Errors
    /// [`EncryptionError::KeyNotFound`] if `old_key_id` is not registered.
    pub fn rotate_key(&self, old_key_id: &str) -> Result<String, EncryptionError> {
        if !self.keys.read().contains_key(old_key_id) {
            return Err(EncryptionError::KeyNotFound(old_key_id.to_string()));
        }
        let new_key_id = self.generate_key()?;
        tracing::info!("Key rotated: {} -> {}", old_key_id, new_key_id);
        Ok(new_key_id)
    }

    /// Ids of keys older than `config.key_validity_secs`, in unspecified order.
    pub fn get_keys_requiring_rotation(&self) -> Vec<String> {
        let now = chrono::Utc::now();
        // `chrono::Duration::seconds` panics above i64::MAX / 1000 seconds, and a
        // bare `as i64` cast would wrap huge u64 values to a negative duration
        // (flagging every key). Clamp instead: a validity that long never expires.
        const MAX_VALIDITY_SECS: u64 = (i64::MAX / 1_000) as u64;
        let validity_secs = self.config.key_validity_secs.min(MAX_VALIDITY_SECS) as i64;
        let validity = chrono::Duration::seconds(validity_secs);
        self.keys
            .read()
            .iter()
            .filter(|(_, key)| now.signed_duration_since(key.created_at()) > validity)
            .map(|(id, _)| id.clone())
            .collect()
    }

    /// Returns the raw key bytes of `key_id` as standard base64.
    ///
    /// The returned `String` holds key material and is not zeroized.
    pub fn export_key(&self, key_id: &str) -> Result<String, EncryptionError> {
        let key = self.get_key(key_id)?;
        Ok(BASE64.encode(key.as_bytes()))
    }

    /// Decodes a base64 key for `algorithm`, stores it under a new id, and
    /// makes it the default key.
    ///
    /// # Errors
    /// [`EncryptionError::InvalidCiphertextFormat`] on malformed base64, or
    /// [`EncryptionError::InvalidKeyLength`] on the wrong length.
    pub fn import_key(
        &self,
        key_b64: &str,
        algorithm: EncryptionAlgorithm,
    ) -> Result<String, EncryptionError> {
        // The decoded buffer is a copy of the secret; wipe it when dropped.
        let key_bytes = Zeroizing::new(
            BASE64
                .decode(key_b64)
                .map_err(|_| EncryptionError::InvalidCiphertextFormat)?,
        );
        let key = EncryptionKey::from_bytes(key_bytes.as_slice(), algorithm)?;
        self.add_key(key)
    }

    /// Number of registered keys.
    pub fn key_count(&self) -> usize {
        self.keys.read().len()
    }

    /// Whether a default key id is currently set.
    pub fn has_default_key(&self) -> bool {
        self.default_key_id.read().is_some()
    }
}
impl Default for EncryptionEngine {
fn default() -> Self {
Self::new()
}
}
/// Builds a fresh random key identifier of the form `key_<uuid-v4>`.
fn generate_key_id() -> String {
    use uuid::Uuid;
    let mut id = String::from("key_");
    id.push_str(&Uuid::new_v4().to_string());
    id
}
/// Returns 16 random bytes from the OS RNG, suitable as a password-derivation salt.
pub fn generate_salt() -> Vec<u8> {
    const SALT_LEN: usize = 16;
    let mut buf = [0u8; SALT_LEN];
    OsRng.fill_bytes(&mut buf);
    buf.to_vec()
}
/// Convenience wrapper around [`EncryptionKey::derive_from_password`]: derives
/// an Argon2id key for `algorithm` from `password` and `salt`.
pub fn derive_key_from_password(
    password: &str,
    salt: &[u8],
    algorithm: EncryptionAlgorithm,
) -> Result<EncryptionKey, EncryptionError> {
    let derived = EncryptionKey::derive_from_password(password, salt, algorithm)?;
    Ok(derived)
}
#[cfg(test)]
mod tests {
    //! Unit tests for key management, encryption round-trips, and error paths.
    use super::*;

    #[test]
    fn test_encryption_key_generation() {
        let key = EncryptionKey::generate(EncryptionAlgorithm::Aes256Gcm).unwrap();
        assert_eq!(key.as_bytes().len(), 32);
        assert!(!key.key_id().is_empty());
    }

    #[test]
    fn test_encryption_key_from_bytes() {
        let key_bytes = vec![0u8; 32];
        let key = EncryptionKey::from_bytes(&key_bytes, EncryptionAlgorithm::Aes256Gcm).unwrap();
        assert_eq!(key.as_bytes().len(), 32);
    }

    #[test]
    fn test_encryption_key_invalid_length() {
        let key_bytes = vec![0u8; 16];
        let result = EncryptionKey::from_bytes(&key_bytes, EncryptionAlgorithm::Aes256Gcm);
        assert!(matches!(
            result,
            Err(EncryptionError::InvalidKeyLength {
                expected: 32,
                actual: 16
            })
        ));
    }

    #[test]
    fn test_key_derivation() {
        let salt = generate_salt();
        let key =
            derive_key_from_password("password123", &salt, EncryptionAlgorithm::Aes256Gcm).unwrap();
        assert_eq!(key.as_bytes().len(), 32);
    }

    #[test]
    fn test_password_derivation_is_deterministic() {
        let salt = generate_salt();
        let a = derive_key_from_password("correct horse", &salt, EncryptionAlgorithm::Aes256Gcm)
            .unwrap();
        let b = derive_key_from_password("correct horse", &salt, EncryptionAlgorithm::Aes256Gcm)
            .unwrap();
        // Same inputs -> same key material, but each key object gets its own id.
        assert_eq!(a.as_bytes(), b.as_bytes());
        assert_ne!(a.key_id(), b.key_id());
    }

    #[test]
    fn test_aes_gcm_encryption_decryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let plaintext = b"Hello, World!";
        let encrypted = engine.encrypt(plaintext).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_chacha_encryption_decryption() {
        let mut config = EncryptionConfig::default();
        config.default_algorithm = EncryptionAlgorithm::ChaCha20Poly1305;
        let engine = EncryptionEngine::with_config(config);
        engine.generate_key().unwrap();
        let plaintext = b"Hello, ChaCha20!";
        let encrypted = engine.encrypt(plaintext).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn test_string_encryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let plaintext = "Secret message";
        let encrypted = engine.encrypt_string(plaintext).unwrap();
        let decrypted = engine.decrypt_to_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_multiple_keys() {
        let engine = EncryptionEngine::new();
        let key_id1 = engine.generate_key().unwrap();
        let key_id2 = engine
            .generate_key_with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305)
            .unwrap();
        assert_eq!(engine.key_count(), 2);
        assert!(engine.list_keys().contains(&key_id1));
        assert!(engine.list_keys().contains(&key_id2));
    }

    #[test]
    fn test_key_removal() {
        let engine = EncryptionEngine::new();
        let key_id = engine.generate_key().unwrap();
        assert!(engine.remove_key(&key_id).unwrap());
        assert!(!engine.list_keys().contains(&key_id));
    }

    #[test]
    fn test_removing_default_key_clears_default() {
        let engine = EncryptionEngine::new();
        let key_id = engine.generate_key().unwrap();
        assert!(engine.remove_key(&key_id).unwrap());
        assert!(!engine.has_default_key());
        assert!(matches!(
            engine.encrypt(b"data"),
            Err(EncryptionError::KeyNotFound(_))
        ));
        // Removing an absent key reports `false` rather than an error.
        assert!(!engine.remove_key(&key_id).unwrap());
    }

    #[test]
    fn test_max_keys_is_enforced_on_generation() {
        let config = EncryptionConfig {
            max_keys: 2,
            ..EncryptionConfig::default()
        };
        let engine = EncryptionEngine::with_config(config);
        engine.generate_key().unwrap();
        engine.generate_key().unwrap();
        let newest = engine.generate_key().unwrap();
        assert_eq!(engine.key_count(), 2);
        assert!(engine.list_keys().contains(&newest));
        assert_eq!(engine.get_default_key().unwrap().key_id(), newest);
    }

    #[test]
    fn test_key_rotation() {
        let engine = EncryptionEngine::new();
        let old_key_id = engine.generate_key().unwrap();
        let new_key_id = engine.rotate_key(&old_key_id).unwrap();
        assert_ne!(old_key_id, new_key_id);
        assert!(engine.list_keys().contains(&new_key_id));
        // The old key is retained so existing ciphertexts stay decryptable.
        assert!(engine.list_keys().contains(&old_key_id));
    }

    #[test]
    fn test_rotate_unknown_key() {
        let engine = EncryptionEngine::new();
        assert!(matches!(
            engine.rotate_key("key_missing"),
            Err(EncryptionError::KeyNotFound(_))
        ));
    }

    #[test]
    fn test_fresh_keys_do_not_require_rotation() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        assert!(engine.get_keys_requiring_rotation().is_empty());
    }

    #[test]
    fn test_encrypted_data_serialization() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let encrypted = engine.encrypt_string("test").unwrap();
        let serialized = encrypted.to_string().unwrap();
        let deserialized = EncryptedData::from_string(&serialized).unwrap();
        assert_eq!(encrypted.key_id, deserialized.key_id);
        assert_eq!(encrypted.nonce, deserialized.nonce);
        assert_eq!(encrypted.ciphertext, deserialized.ciphertext);
    }

    #[test]
    fn test_encrypted_data_from_invalid_string() {
        assert!(matches!(
            EncryptedData::from_string("not json"),
            Err(EncryptionError::DecryptionFailed(_))
        ));
    }

    #[test]
    fn test_key_export_import() {
        let engine = EncryptionEngine::new();
        let key_id = engine.generate_key().unwrap();
        let encrypted = engine.encrypt(b"round trip").unwrap();
        let exported = engine.export_key(&key_id).unwrap();
        engine.remove_key(&key_id).unwrap();
        let imported_id = engine
            .import_key(&exported, EncryptionAlgorithm::Aes256Gcm)
            .unwrap();
        assert!(engine.list_keys().contains(&imported_id));
        // The imported key has a new id, so decrypt with it explicitly.
        let imported = engine.get_key(&imported_id).unwrap();
        assert_eq!(
            engine.decrypt_with_key(&imported, &encrypted).unwrap(),
            b"round trip".to_vec()
        );
    }

    #[test]
    fn test_import_rejects_bad_input() {
        let engine = EncryptionEngine::new();
        assert!(matches!(
            engine.import_key("%%% not base64 %%%", EncryptionAlgorithm::Aes256Gcm),
            Err(EncryptionError::InvalidCiphertextFormat)
        ));
        let short_key = BASE64.encode([0u8; 16]);
        assert!(matches!(
            engine.import_key(&short_key, EncryptionAlgorithm::Aes256Gcm),
            Err(EncryptionError::InvalidKeyLength {
                expected: 32,
                actual: 16
            })
        ));
        assert_eq!(engine.key_count(), 0);
    }

    #[test]
    fn test_same_plaintext_yields_distinct_ciphertexts() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let encrypted1 = engine.encrypt(b"test").unwrap();
        let encrypted2 = engine.encrypt(b"test").unwrap();
        assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
    }

    #[test]
    fn test_tampered_ciphertext_is_rejected() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let mut encrypted = engine.encrypt(b"integrity matters").unwrap();
        let mut raw = BASE64.decode(&encrypted.ciphertext).unwrap();
        raw[0] ^= 0x01;
        encrypted.ciphertext = BASE64.encode(&raw);
        assert!(matches!(
            engine.decrypt(&encrypted),
            Err(EncryptionError::DecryptionFailed(_))
        ));
    }

    #[test]
    fn test_invalid_nonce_length_is_rejected() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let mut encrypted = engine.encrypt(b"data").unwrap();
        encrypted.nonce = BASE64.encode([0u8; 8]);
        assert!(matches!(
            engine.decrypt(&encrypted),
            Err(EncryptionError::InvalidNonceLength {
                expected: 12,
                actual: 8
            })
        ));
    }

    #[test]
    fn test_invalid_base64_is_rejected() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let mut encrypted = engine.encrypt(b"data").unwrap();
        encrypted.ciphertext = "not base64!".to_string();
        assert!(matches!(
            engine.decrypt(&encrypted),
            Err(EncryptionError::InvalidCiphertextFormat)
        ));
    }

    #[test]
    fn test_decrypt_with_unknown_key_id() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let mut encrypted = engine.encrypt(b"data").unwrap();
        encrypted.key_id = "key_missing".to_string();
        assert!(matches!(
            engine.decrypt(&encrypted),
            Err(EncryptionError::KeyNotFound(_))
        ));
    }

    #[test]
    fn test_large_data_encryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let large_data = vec![0u8; 1_000_000];
        let encrypted = engine.encrypt(&large_data).unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert_eq!(large_data, decrypted);
    }

    #[test]
    fn test_empty_data_encryption() {
        let engine = EncryptionEngine::new();
        engine.generate_key().unwrap();
        let encrypted = engine.encrypt(b"").unwrap();
        let decrypted = engine.decrypt(&encrypted).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_algorithm_display() {
        assert_eq!(format!("{}", EncryptionAlgorithm::Aes256Gcm), "AES-256-GCM");
        assert_eq!(
            format!("{}", EncryptionAlgorithm::ChaCha20Poly1305),
            "ChaCha20-Poly1305"
        );
    }

    #[test]
    fn test_encryption_key_debug_redaction() {
        let key = EncryptionKey::generate(EncryptionAlgorithm::Aes256Gcm).unwrap();
        let debug_output = format!("{:?}", key);
        assert!(debug_output.contains("<redacted>"));
        assert!(!debug_output.contains(&BASE64.encode(key.as_bytes())));
    }

    #[test]
    fn test_default_key() {
        let engine = EncryptionEngine::new();
        assert!(!engine.has_default_key());
        engine.generate_key().unwrap();
        assert!(engine.has_default_key());
    }

    #[test]
    fn test_set_default_key() {
        let engine = EncryptionEngine::new();
        let key_id1 = engine.generate_key().unwrap();
        let key_id2 = engine.generate_key().unwrap();
        let default_key = engine.get_default_key().unwrap();
        assert_eq!(default_key.key_id(), key_id2);
        engine.set_default_key(&key_id1).unwrap();
        let default_key = engine.get_default_key().unwrap();
        assert_eq!(default_key.key_id(), key_id1);
    }

    #[test]
    fn test_set_default_key_unknown_id() {
        let engine = EncryptionEngine::new();
        assert!(matches!(
            engine.set_default_key("key_missing"),
            Err(EncryptionError::KeyNotFound(_))
        ));
        assert!(!engine.has_default_key());
    }

    #[test]
    fn test_wrong_algorithm_decryption() {
        let engine = EncryptionEngine::new();
        let aes_key = engine
            .generate_key_with_algorithm(EncryptionAlgorithm::Aes256Gcm)
            .unwrap();
        let chacha_key = engine
            .generate_key_with_algorithm(EncryptionAlgorithm::ChaCha20Poly1305)
            .unwrap();
        let aes_encrypted = engine
            .encrypt_with_key(&engine.get_key(&aes_key).unwrap(), b"test")
            .unwrap();
        let result = engine.decrypt_with_key(&engine.get_key(&chacha_key).unwrap(), &aes_encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_concurrent_encryption() {
        use std::sync::Arc;
        use std::thread;
        let engine = Arc::new(EncryptionEngine::new());
        engine.generate_key().unwrap();
        let mut handles = vec![];
        for i in 0..10 {
            let e = Arc::clone(&engine);
            handles.push(thread::spawn(move || {
                let data = format!("message_{}", i);
                let encrypted = e.encrypt_string(&data).unwrap();
                let decrypted = e.decrypt_to_string(&encrypted).unwrap();
                assert_eq!(data, decrypted);
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
    }
}