use crate::error::{AllSourceError, Result};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
/// Selects which [`KmsClient`] implementation [`KmsManager::new`] builds.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KmsProvider {
    /// In-process [`LocalKms`]; key material lives only in memory.
    Local,
}
/// Settings used to construct a KMS client via [`KmsManager::new`] or [`LocalKms::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KmsConfig {
    /// Backend to construct.
    pub provider: KmsProvider,
    /// Free-form provider-specific settings; not read by any code in this file.
    pub config: HashMap<String, String>,
    /// Presumably enables automatic key rotation; not acted on by any code in this file.
    pub auto_rotate: bool,
    /// Rotation interval (the name indicates days); not acted on by any code in this file.
    pub rotation_period_days: u32,
}
impl Default for KmsConfig {
    /// Local provider, no provider-specific settings, automatic rotation
    /// enabled with a 90-day period.
    fn default() -> Self {
        const DEFAULT_ROTATION_PERIOD_DAYS: u32 = 90;
        Self {
            rotation_period_days: DEFAULT_ROTATION_PERIOD_DAYS,
            auto_rotate: true,
            config: HashMap::default(),
            provider: KmsProvider::Local,
        }
    }
}
/// Public description of a key managed by a [`KmsClient`]; it carries no key material.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyMetadata {
    /// Identifier passed to every `KmsClient` call; `LocalKms` assigns a random UUID v4 string.
    pub key_id: String,
    /// Caller-supplied name; `LocalKms` does not check it for uniqueness.
    pub alias: String,
    /// Declared intended use of the key.
    pub purpose: KeyPurpose,
    /// Algorithm the key material was generated for.
    pub algorithm: KeyAlgorithm,
    /// Creation timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Timestamp of the most recent rotation; `None` if the key was never rotated.
    pub last_rotated: Option<chrono::DateTime<chrono::Utc>>,
    /// Lifecycle state; `LocalKms` only encrypts with `Active` keys.
    pub status: KeyStatus,
    /// Starts at 1 when created; `LocalKms` increments it on every rotation.
    pub version: u32,
}
/// Declared intended use of a key.
///
/// `LocalKms` records it in the key's metadata but does not enforce it on
/// encrypt/decrypt.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyPurpose {
    DataEncryption,
    JwtSigning,
    ApiKeySigning,
    DatabaseEncryption,
    /// Caller-defined purpose carried as a free-form string.
    Custom(String),
}
/// Cryptographic algorithm a key is created for.
///
/// `LocalKms`'s `create_key` accepts only `Aes256Gcm` and `Aes128Gcm`; every
/// other variant is rejected with a validation error.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyAlgorithm {
    Aes256Gcm,
    Aes128Gcm,
    ChaCha20Poly1305,
    RsaOaep,
    EcdsaP256,
    Ed25519,
}
/// Lifecycle state of a key.
///
/// In `LocalKms`:
/// - keys are created `Active`;
/// - `disable_key` sets `Deprecated` and `enable_key` sets `Active`;
/// - `encrypt` fails for any status other than `Active`.
///
/// `Rotating` and `Destroyed` are not assigned by any code in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyStatus {
    Active,
    Rotating,
    Deprecated,
    Destroyed,
}
/// Backend-agnostic interface to a key management service.
///
/// The `Send + Sync` bound lets a client be shared as `Arc<dyn KmsClient>`,
/// which is how [`KmsManager`] holds it.
#[async_trait::async_trait]
pub trait KmsClient: Send + Sync {
    /// Creates a key with the given alias, purpose and algorithm and returns its metadata.
    async fn create_key(
        &self,
        alias: String,
        purpose: KeyPurpose,
        algorithm: KeyAlgorithm,
    ) -> Result<KeyMetadata>;
    /// Returns the metadata of the key identified by `key_id`.
    async fn get_key(&self, key_id: &str) -> Result<KeyMetadata>;
    /// Returns the metadata of every key known to this client.
    async fn list_keys(&self) -> Result<Vec<KeyMetadata>>;
    /// Encrypts `plaintext` under `key_id`.
    ///
    /// The output is meant to be passed back to `decrypt` with the same key id.
    async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>>;
    /// Reverses `encrypt` for the same `key_id`.
    async fn decrypt(&self, key_id: &str, ciphertext: &[u8]) -> Result<Vec<u8>>;
    /// Rotates the key to a new version and returns the updated metadata.
    async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata>;
    /// Disables the key for new encryptions.
    ///
    /// `LocalKms` does this by setting the status to `Deprecated`.
    async fn disable_key(&self, key_id: &str) -> Result<()>;
    /// Re-enables the key (`LocalKms` sets its status back to `Active`).
    async fn enable_key(&self, key_id: &str) -> Result<()>;
    /// Generates a data-encryption key (DEK) protected by `key_id`.
    ///
    /// Returns `(plaintext_dek, encrypted_dek)`. [`KmsManager`] encrypts data
    /// with the first element and stores the second.
    async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)>;
}
/// In-process [`KmsClient`] that generates and holds key material in memory.
///
/// Nothing is persisted. Keys, and therefore anything encrypted under them,
/// are only usable while this value is alive.
pub struct LocalKms {
    /// Key id -> metadata and key material.
    ///
    /// `DashMap` permits insertion and mutation through `&self`.
    keys: Arc<DashMap<String, StoredKey>>,
    /// Configuration passed to `new`; not read by any method in this file.
    config: KmsConfig,
}
/// Nonce length in bytes for the AES-GCM variants used by [`LocalKms`].
const NONCE_LEN: usize = 12;

/// Everything [`LocalKms`] keeps for one key: public metadata plus raw key bytes.
struct StoredKey {
    metadata: KeyMetadata,
    /// Material of the current version; all new encryptions use it.
    key_material: Vec<u8>,
    /// Material of retired versions, oldest first.
    ///
    /// Kept so that ciphertext produced before a rotation (including envelope
    /// DEKs) can still be decrypted.
    previous_key_material: Vec<Vec<u8>>,
}

impl LocalKms {
    /// Creates an empty in-memory KMS. Keys exist only as long as this value does.
    pub fn new(config: KmsConfig) -> Self {
        Self {
            keys: Arc::new(DashMap::new()),
            config,
        }
    }

    /// Draws fresh random key material of the length `algorithm` requires.
    ///
    /// # Errors
    /// `ValidationError` for any algorithm other than AES-128-GCM / AES-256-GCM.
    fn generate_key_material(algorithm: &KeyAlgorithm) -> Result<Vec<u8>> {
        use aes_gcm::aead::{OsRng, rand_core::RngCore};
        let len = match algorithm {
            KeyAlgorithm::Aes256Gcm => 32,
            KeyAlgorithm::Aes128Gcm => 16,
            KeyAlgorithm::ChaCha20Poly1305
            | KeyAlgorithm::RsaOaep
            | KeyAlgorithm::EcdsaP256
            | KeyAlgorithm::Ed25519 => {
                return Err(AllSourceError::ValidationError(format!(
                    "Algorithm {algorithm:?} not supported in local KMS"
                )));
            }
        };
        let mut key = vec![0u8; len];
        RngCore::fill_bytes(&mut OsRng, &mut key);
        Ok(key)
    }

    /// Builds an AES-GCM cipher from raw key bytes, mapping a bad length to `ValidationError`.
    fn new_cipher<C: aes_gcm::KeyInit>(key: &[u8]) -> Result<C> {
        C::new_from_slice(key)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))
    }

    /// AES-GCM-encrypts `plaintext` with the variant matching `algorithm`.
    ///
    /// `nonce` must be exactly `NONCE_LEN` bytes (callers guarantee this).
    fn seal(
        algorithm: &KeyAlgorithm,
        key: &[u8],
        nonce: &[u8],
        plaintext: &[u8],
    ) -> Result<Vec<u8>> {
        use aes_gcm::{Aes128Gcm, Aes256Gcm, Nonce, aead::Aead};
        let nonce = Nonce::from_slice(nonce);
        let sealed = match algorithm {
            KeyAlgorithm::Aes128Gcm => Self::new_cipher::<Aes128Gcm>(key)?.encrypt(nonce, plaintext),
            // `create_key` only stores AES-GCM keys, so in practice this arm is AES-256-GCM.
            _ => Self::new_cipher::<Aes256Gcm>(key)?.encrypt(nonce, plaintext),
        };
        sealed.map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))
    }

    /// Inverse of [`LocalKms::seal`].
    ///
    /// Fails if `key` is not the key the data was sealed with, because the
    /// AES-GCM authentication check rejects it.
    fn open(
        algorithm: &KeyAlgorithm,
        key: &[u8],
        nonce: &[u8],
        ciphertext: &[u8],
    ) -> Result<Vec<u8>> {
        use aes_gcm::{Aes128Gcm, Aes256Gcm, Nonce, aead::Aead};
        let nonce = Nonce::from_slice(nonce);
        let opened = match algorithm {
            KeyAlgorithm::Aes128Gcm => Self::new_cipher::<Aes128Gcm>(key)?.decrypt(nonce, ciphertext),
            _ => Self::new_cipher::<Aes256Gcm>(key)?.decrypt(nonce, ciphertext),
        };
        opened.map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
    }
}

#[async_trait::async_trait]
impl KmsClient for LocalKms {
    /// Generates random material for `algorithm` and registers a new `Active`
    /// key (version 1) under a fresh UUID v4 id.
    async fn create_key(
        &self,
        alias: String,
        purpose: KeyPurpose,
        algorithm: KeyAlgorithm,
    ) -> Result<KeyMetadata> {
        let key_material = Self::generate_key_material(&algorithm)?;
        let key_id = uuid::Uuid::new_v4().to_string();
        let metadata = KeyMetadata {
            key_id: key_id.clone(),
            alias,
            purpose,
            algorithm,
            created_at: chrono::Utc::now(),
            last_rotated: None,
            status: KeyStatus::Active,
            version: 1,
        };
        let stored_key = StoredKey {
            metadata: metadata.clone(),
            key_material,
            previous_key_material: Vec::new(),
        };
        self.keys.insert(key_id, stored_key);
        Ok(metadata)
    }

    async fn get_key(&self, key_id: &str) -> Result<KeyMetadata> {
        self.keys
            .get(key_id)
            .map(|entry| entry.value().metadata.clone())
            .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))
    }

    async fn list_keys(&self) -> Result<Vec<KeyMetadata>> {
        Ok(self
            .keys
            .iter()
            .map(|entry| entry.value().metadata.clone())
            .collect())
    }

    /// Encrypts with the key's current version.
    ///
    /// Returns `nonce || ciphertext`. Fails if the key is not `Active`.
    async fn encrypt(&self, key_id: &str, plaintext: &[u8]) -> Result<Vec<u8>> {
        use aes_gcm::aead::{OsRng, rand_core::RngCore};
        let stored_key = self
            .keys
            .get(key_id)
            .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
        if stored_key.metadata.status != KeyStatus::Active {
            return Err(AllSourceError::ValidationError(
                "Key is not active".to_string(),
            ));
        }
        // The master key encrypts many messages, so the whole 96-bit nonce must be
        // random. Leaving bytes fixed shrinks the space and raises GCM nonce-reuse odds.
        let mut nonce = [0u8; NONCE_LEN];
        RngCore::fill_bytes(&mut OsRng, &mut nonce);
        let ciphertext = Self::seal(
            &stored_key.metadata.algorithm,
            &stored_key.key_material,
            &nonce,
            plaintext,
        )?;
        let mut result = Vec::with_capacity(NONCE_LEN + ciphertext.len());
        result.extend_from_slice(&nonce);
        result.extend_from_slice(&ciphertext);
        Ok(result)
    }

    /// Decrypts `nonce || ciphertext` as produced by `encrypt`.
    ///
    /// The current key version is tried first, then retired versions from
    /// newest to oldest, so data sealed before a rotation stays readable. On
    /// total failure the error from the current version is returned. The
    /// key's status is not checked.
    async fn decrypt(&self, key_id: &str, ciphertext_with_nonce: &[u8]) -> Result<Vec<u8>> {
        if ciphertext_with_nonce.len() < NONCE_LEN {
            return Err(AllSourceError::ValidationError(
                "Invalid ciphertext".to_string(),
            ));
        }
        let stored_key = self
            .keys
            .get(key_id)
            .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
        let (nonce, ciphertext) = ciphertext_with_nonce.split_at(NONCE_LEN);
        let algorithm = &stored_key.metadata.algorithm;
        Self::open(algorithm, &stored_key.key_material, nonce, ciphertext).or_else(|err| {
            stored_key
                .previous_key_material
                .iter()
                .rev()
                .find_map(|old_key| Self::open(algorithm, old_key, nonce, ciphertext).ok())
                .ok_or(err)
        })
    }

    /// Replaces the key material with a new version of the same algorithm.
    ///
    /// The old material is retired rather than discarded. Also bumps `version`
    /// and stamps `last_rotated`.
    async fn rotate_key(&self, key_id: &str) -> Result<KeyMetadata> {
        let mut stored_key = self
            .keys
            .get_mut(key_id)
            .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
        let new_key_material = Self::generate_key_material(&stored_key.metadata.algorithm)?;
        let retired = std::mem::replace(&mut stored_key.key_material, new_key_material);
        stored_key.previous_key_material.push(retired);
        stored_key.metadata.version += 1;
        stored_key.metadata.last_rotated = Some(chrono::Utc::now());
        Ok(stored_key.metadata.clone())
    }

    /// Sets the status to `Deprecated`, which makes `encrypt` refuse the key.
    async fn disable_key(&self, key_id: &str) -> Result<()> {
        let mut stored_key = self
            .keys
            .get_mut(key_id)
            .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
        stored_key.metadata.status = KeyStatus::Deprecated;
        Ok(())
    }

    /// Sets the status back to `Active`.
    async fn enable_key(&self, key_id: &str) -> Result<()> {
        let mut stored_key = self
            .keys
            .get_mut(key_id)
            .ok_or_else(|| AllSourceError::ValidationError(format!("Key {key_id} not found")))?;
        stored_key.metadata.status = KeyStatus::Active;
        Ok(())
    }

    /// Generates a random 32-byte DEK.
    ///
    /// Returns it together with its encryption under `key_id`.
    async fn generate_data_key(&self, key_id: &str) -> Result<(Vec<u8>, Vec<u8>)> {
        let mut dek = vec![0u8; 32];
        use aes_gcm::aead::{OsRng, rand_core::RngCore};
        RngCore::fill_bytes(&mut OsRng, &mut dek);
        let encrypted_dek = self.encrypt(key_id, &dek).await?;
        Ok((dek, encrypted_dek))
    }
}
/// High-level entry point.
///
/// Owns the configured [`KmsClient`] and adds envelope encryption on top of it.
pub struct KmsManager {
    /// Backend selected from `config.provider` in [`KmsManager::new`].
    client: Arc<dyn KmsClient>,
    /// Configuration the manager was built from; not read after construction in this file.
    config: KmsConfig,
}
impl KmsManager {
    /// Length in bytes of the AES-256-GCM nonce stored in [`EnvelopeEncryptedData`].
    const NONCE_LEN: usize = 12;

    /// Builds a manager backed by the client selected by `config.provider`.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub fn new(config: KmsConfig) -> Result<Self> {
        let client: Arc<dyn KmsClient> = match config.provider {
            KmsProvider::Local => Arc::new(LocalKms::new(config.clone())),
        };
        Ok(Self { client, config })
    }

    /// Returns the underlying KMS client, e.g. to create or rotate keys.
    pub fn client(&self) -> &Arc<dyn KmsClient> {
        &self.client
    }

    /// Envelope-encrypts `plaintext`.
    ///
    /// A fresh DEK is obtained from the master key `master_key_id`, the
    /// payload is encrypted with AES-256-GCM under that DEK, and the result is
    /// returned together with the encrypted DEK.
    ///
    /// # Errors
    /// Propagates KMS errors from `generate_data_key`. Returns
    /// `ValidationError` if the DEK is unusable or encryption fails.
    pub async fn envelope_encrypt(
        &self,
        master_key_id: &str,
        plaintext: &[u8],
    ) -> Result<EnvelopeEncryptedData> {
        use aes_gcm::aead::{OsRng, rand_core::RngCore};
        use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
        let (dek, encrypted_dek) = self.client.generate_data_key(master_key_id).await?;
        let cipher = Aes256Gcm::new_from_slice(&dek)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
        // Randomize every nonce byte (not just the first 8) so the nonce is
        // uniformly random over all 96 bits.
        let mut nonce = [0u8; Self::NONCE_LEN];
        RngCore::fill_bytes(&mut OsRng, &mut nonce);
        let ciphertext = cipher
            .encrypt(Nonce::from_slice(&nonce), plaintext)
            .map_err(|e| AllSourceError::ValidationError(format!("Encryption failed: {e}")))?;
        Ok(EnvelopeEncryptedData {
            ciphertext,
            nonce: nonce.to_vec(),
            encrypted_dek,
            master_key_id: master_key_id.to_string(),
        })
    }

    /// Reverses [`KmsManager::envelope_encrypt`].
    ///
    /// The DEK is decrypted with the master key, then used to decrypt the payload.
    ///
    /// # Errors
    /// - `ValidationError` if `encrypted.nonce` is not 12 bytes. `encrypted`
    ///   may come from deserialized input, and `Nonce::from_slice` would
    ///   otherwise panic.
    /// - Propagates KMS errors from decrypting the DEK.
    /// - `ValidationError` if payload decryption or authentication fails.
    pub async fn envelope_decrypt(&self, encrypted: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
        use aes_gcm::{Aes256Gcm, KeyInit, Nonce, aead::Aead};
        if encrypted.nonce.len() != Self::NONCE_LEN {
            return Err(AllSourceError::ValidationError(format!(
                "Invalid nonce length: expected {}, got {}",
                Self::NONCE_LEN,
                encrypted.nonce.len()
            )));
        }
        let dek = self
            .client
            .decrypt(&encrypted.master_key_id, &encrypted.encrypted_dek)
            .await?;
        let cipher = Aes256Gcm::new_from_slice(&dek)
            .map_err(|e| AllSourceError::ValidationError(format!("Invalid key: {e}")))?;
        let nonce = Nonce::from_slice(&encrypted.nonce);
        cipher
            .decrypt(nonce, encrypted.ciphertext.as_ref())
            .map_err(|e| AllSourceError::ValidationError(format!("Decryption failed: {e}")))
    }
}
/// Output of [`KmsManager::envelope_encrypt`].
///
/// Holds data encrypted under a per-call data-encryption key (DEK), plus that
/// DEK encrypted under a master key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvelopeEncryptedData {
    /// AES-256-GCM output for the payload under the DEK.
    pub ciphertext: Vec<u8>,
    /// Nonce used to produce `ciphertext` (12 bytes as written by `envelope_encrypt`).
    pub nonce: Vec<u8>,
    /// The DEK as encrypted by the master key via `generate_data_key`.
    pub encrypted_dek: Vec<u8>,
    /// Id of the master key required to decrypt `encrypted_dek`.
    pub master_key_id: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory KMS with the default configuration.
    fn local_kms() -> LocalKms {
        LocalKms::new(KmsConfig::default())
    }

    /// Creates an AES-256-GCM data-encryption key called `alias`, panicking on failure.
    async fn create_aes256_key(kms: &dyn KmsClient, alias: &str) -> KeyMetadata {
        kms.create_key(
            alias.to_string(),
            KeyPurpose::DataEncryption,
            KeyAlgorithm::Aes256Gcm,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_local_kms_create_key() {
        let kms = local_kms();
        let metadata = create_aes256_key(&kms, "test-key").await;
        assert_eq!(metadata.alias, "test-key");
        assert_eq!(metadata.status, KeyStatus::Active);
        assert_eq!(metadata.version, 1);
    }

    #[tokio::test]
    async fn test_local_kms_encrypt_decrypt() {
        let kms = local_kms();
        let key_id = create_aes256_key(&kms, "test-key").await.key_id;
        let plaintext = b"sensitive data";
        let sealed = kms.encrypt(&key_id, plaintext).await.unwrap();
        let opened = kms.decrypt(&key_id, &sealed).await.unwrap();
        assert_eq!(opened, plaintext);
    }

    #[tokio::test]
    async fn test_key_rotation() {
        let kms = local_kms();
        let key_id = create_aes256_key(&kms, "test-key").await.key_id;
        let rotated = kms.rotate_key(&key_id).await.unwrap();
        assert_eq!(rotated.version, 2);
        assert!(rotated.last_rotated.is_some());
    }

    #[tokio::test]
    async fn test_envelope_encryption() {
        let manager = KmsManager::new(KmsConfig::default()).unwrap();
        let master_key_id = create_aes256_key(manager.client(), "master-key")
            .await
            .key_id;
        let plaintext = b"sensitive data for envelope encryption";
        let envelope = manager
            .envelope_encrypt(&master_key_id, plaintext)
            .await
            .unwrap();
        let opened = manager.envelope_decrypt(&envelope).await.unwrap();
        assert_eq!(opened, plaintext);
    }

    #[tokio::test]
    async fn test_disable_enable_key() {
        let kms = local_kms();
        let key_id = create_aes256_key(&kms, "test-key").await.key_id;

        // Disabled: metadata reflects it and encryption is refused.
        kms.disable_key(&key_id).await.unwrap();
        let status = kms.get_key(&key_id).await.unwrap().status;
        assert_eq!(status, KeyStatus::Deprecated);
        assert!(kms.encrypt(&key_id, b"test").await.is_err());

        // Re-enabled: back to Active and encryption works again.
        kms.enable_key(&key_id).await.unwrap();
        let status = kms.get_key(&key_id).await.unwrap().status;
        assert_eq!(status, KeyStatus::Active);
        assert!(kms.encrypt(&key_id, b"test").await.is_ok());
    }
}