use crate::encryption::{AtRestEncryptor, EncryptedData, EncryptionAlgorithm};
use crate::error::{Result, SecurityError};
use serde::{Deserialize, Serialize};
/// Source of a key-encryption key (KEK) that wraps and unwraps data-encryption keys (DEKs).
///
/// The `Send + Sync` supertraits allow implementations to be shared across threads.
pub trait KekProvider: Send + Sync {
/// Wraps (encrypts) the plaintext DEK `dek` with this provider's KEK.
///
/// The returned bytes are opaque to callers and are meant to be passed back to
/// [`KekProvider::decrypt_dek`].
///
/// # Errors
/// Returns an error if wrapping fails; the causes are implementation-defined.
fn encrypt_dek(&self, dek: &[u8]) -> Result<Vec<u8>>;
/// Unwraps (decrypts) bytes previously returned by [`KekProvider::encrypt_dek`].
///
/// # Errors
/// Returns an error if `encrypted_dek` cannot be unwrapped by this provider.
fn decrypt_dek(&self, encrypted_dek: &[u8]) -> Result<Vec<u8>>;
/// Identifier of this provider's KEK. [`EnvelopeEncryptor`] records it in every envelope
/// and compares it against the envelope's value before decrypting.
fn kek_id(&self) -> &str;
}
/// Output of envelope encryption: a KEK-wrapped DEK plus the payload encrypted under that DEK.
///
/// Serialized with serde; fields are emitted in declaration order, so reordering them
/// changes the bytes produced by [`EnvelopeEncryptedData::to_json_bytes`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvelopeEncryptedData {
/// The DEK as wrapped by [`KekProvider::encrypt_dek`]; opaque bytes.
pub encrypted_dek: Vec<u8>,
/// ID of the KEK that wrapped `encrypted_dek`; must match the provider used to decrypt.
pub kek_id: String,
/// The caller's payload encrypted with the DEK.
pub encrypted_payload: EncryptedData,
/// Algorithm of the DEK, used to rebuild the payload decryptor.
pub dek_algorithm: EncryptionAlgorithm,
}
impl EnvelopeEncryptedData {
    /// Encodes this envelope as JSON bytes.
    ///
    /// # Errors
    /// Returns the `serde_json` failure converted into a [`SecurityError`].
    pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
        let encoded = serde_json::to_vec(self)?;
        Ok(encoded)
    }

    /// Decodes an envelope from JSON bytes such as those produced by [`Self::to_json_bytes`].
    ///
    /// # Errors
    /// Returns the `serde_json` failure converted into a [`SecurityError`] when `bytes`
    /// is not a valid encoded envelope.
    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
        let decoded: Self = serde_json::from_slice(bytes)?;
        Ok(decoded)
    }
}
/// Envelope encryptor: each payload is encrypted under a freshly generated data-encryption
/// key (DEK), and that DEK is wrapped by a key-encryption key (KEK) from a [`KekProvider`].
pub struct EnvelopeEncryptor {
    /// Provider used to wrap DEKs on encryption and unwrap them on decryption.
    kek_provider: Box<dyn KekProvider>,
    /// Algorithm used for the DEKs generated by [`EnvelopeEncryptor::encrypt`].
    dek_algorithm: EncryptionAlgorithm,
}

impl EnvelopeEncryptor {
    /// Creates an encryptor that wraps DEKs with `kek_provider` and generates DEKs for
    /// `dek_algorithm`.
    pub fn new(kek_provider: Box<dyn KekProvider>, dek_algorithm: EncryptionAlgorithm) -> Self {
        Self {
            kek_provider,
            dek_algorithm,
        }
    }

    /// Encrypts `plaintext` under a new random DEK and returns it with the KEK-wrapped DEK.
    ///
    /// `aad`, if provided, is forwarded unchanged to the payload encryptor. The payload's
    /// key ID is a fresh random UUID string.
    ///
    /// # Errors
    /// Propagates errors from wrapping the DEK with the KEK provider, from constructing the
    /// [`AtRestEncryptor`] for the DEK, and from encrypting the payload.
    pub fn encrypt(&self, plaintext: &[u8], aad: Option<&[u8]>) -> Result<EnvelopeEncryptedData> {
        let dek = AtRestEncryptor::generate_key(self.dek_algorithm);
        // Wrap the DEK *before* handing it to the payload encryptor so the key material can
        // be moved instead of cloned: one fewer plaintext copy of the DEK left in memory.
        let encrypted_dek = self.kek_provider.encrypt_dek(&dek)?;
        let dek_id = uuid::Uuid::new_v4().to_string();
        let encryptor = AtRestEncryptor::new(self.dek_algorithm, dek, dek_id)?;
        let encrypted_payload = encryptor.encrypt(plaintext, aad)?;
        Ok(EnvelopeEncryptedData {
            encrypted_dek,
            kek_id: self.kek_provider.kek_id().to_string(),
            encrypted_payload,
            dek_algorithm: self.dek_algorithm,
        })
    }

    /// Decrypts an envelope produced by [`EnvelopeEncryptor::encrypt`].
    ///
    /// The envelope's `kek_id` must equal this encryptor's provider ID. The DEK is then
    /// unwrapped through the provider, and the payload is decrypted with an
    /// [`AtRestEncryptor`] rebuilt from `envelope.dek_algorithm` and the key ID recorded in
    /// the payload metadata.
    ///
    /// NOTE(review): this method takes no AAD. Whether AAD passed to `encrypt` is actually
    /// verified depends on how `AtRestEncryptor::decrypt` treats `EncryptedData` — confirm
    /// it is not simply trusted from the (attacker-modifiable) envelope.
    ///
    /// # Errors
    /// Returns a decryption error on KEK ID mismatch, and propagates errors from unwrapping
    /// the DEK, constructing the payload encryptor, or decrypting the payload.
    pub fn decrypt(&self, envelope: &EnvelopeEncryptedData) -> Result<Vec<u8>> {
        let expected_kek_id = self.kek_provider.kek_id();
        if envelope.kek_id != expected_kek_id {
            return Err(SecurityError::decryption(format!(
                "KEK ID mismatch: expected {}, got {}",
                expected_kek_id, envelope.kek_id
            )));
        }
        let dek = self.kek_provider.decrypt_dek(&envelope.encrypted_dek)?;
        let dek_id = envelope.encrypted_payload.metadata.key_id.clone();
        let encryptor = AtRestEncryptor::new(envelope.dek_algorithm, dek, dek_id)?;
        encryptor.decrypt(&envelope.encrypted_payload)
    }

    /// ID of the KEK used by this encryptor's provider.
    pub fn kek_id(&self) -> &str {
        self.kek_provider.kek_id()
    }

    /// Algorithm used for newly generated DEKs.
    pub fn dek_algorithm(&self) -> EncryptionAlgorithm {
        self.dek_algorithm
    }
}
/// [`KekProvider`] whose AES-256-GCM KEK is held by an in-process [`AtRestEncryptor`].
///
/// Wrapped DEKs are the JSON encoding of the [`EncryptedData`] the inner encryptor produces.
pub struct InMemoryKekProvider {
    /// Value reported by [`KekProvider::kek_id`]; also passed as the inner encryptor's key ID.
    kek_id: String,
    /// Encryptor holding the KEK.
    encryptor: AtRestEncryptor,
}

impl InMemoryKekProvider {
    /// Creates a provider with a freshly generated AES-256-GCM KEK.
    ///
    /// # Errors
    /// Propagates any error from [`AtRestEncryptor::new`].
    pub fn new(kek_id: String) -> Result<Self> {
        let fresh_kek = AtRestEncryptor::generate_key(EncryptionAlgorithm::Aes256Gcm);
        Self::with_kek(kek_id, fresh_kek)
    }

    /// Creates a provider from caller-supplied AES-256-GCM key material.
    ///
    /// # Errors
    /// Propagates any error from [`AtRestEncryptor::new`], e.g. if it rejects `kek`.
    pub fn with_kek(kek_id: String, kek: Vec<u8>) -> Result<Self> {
        let encryptor =
            AtRestEncryptor::new(EncryptionAlgorithm::Aes256Gcm, kek, kek_id.clone())?;
        Ok(Self { kek_id, encryptor })
    }
}

impl KekProvider for InMemoryKekProvider {
    /// Encrypts `dek` under the KEK (without AAD) and returns the JSON-encoded result.
    fn encrypt_dek(&self, dek: &[u8]) -> Result<Vec<u8>> {
        self.encryptor.encrypt(dek, None)?.to_json_bytes()
    }

    /// Parses JSON-encoded [`EncryptedData`] and decrypts it with the KEK.
    fn decrypt_dek(&self, encrypted_dek: &[u8]) -> Result<Vec<u8>> {
        let wrapped = EncryptedData::from_json_bytes(encrypted_dek)?;
        self.encryptor.decrypt(&wrapped)
    }

    fn kek_id(&self) -> &str {
        self.kek_id.as_str()
    }
}
/// [`KekProvider`] that wraps DEKs with a primary provider and, when unwrapping fails,
/// retries with an optional secondary provider.
///
/// [`KekProvider::encrypt_dek`] uses only the primary; bytes wrapped by the secondary are
/// obtainable only through [`MultiRegionKekProvider::encrypt_with_both`].
pub struct MultiRegionKekProvider {
    /// Value reported by [`KekProvider::kek_id`], independent of the inner providers' IDs.
    kek_id: String,
    /// Used for all wrapping and tried first when unwrapping.
    primary: Box<dyn KekProvider>,
    /// Optional fallback for unwrapping and second target of `encrypt_with_both`.
    secondary: Option<Box<dyn KekProvider>>,
}

impl MultiRegionKekProvider {
    /// Creates a provider reporting `kek_id` that delegates to `primary` and, optionally,
    /// `secondary`.
    pub fn new(
        kek_id: String,
        primary: Box<dyn KekProvider>,
        secondary: Option<Box<dyn KekProvider>>,
    ) -> Self {
        Self {
            kek_id,
            primary,
            secondary,
        }
    }

    /// Wraps `dek` with the primary provider and, if one is configured, the secondary.
    ///
    /// Returns `(primary_wrapped, secondary_wrapped)`; the second element is `None` when
    /// there is no secondary.
    ///
    /// # Errors
    /// Fails if either provider fails. The primary is invoked first, so a primary failure
    /// means the secondary is never called.
    pub fn encrypt_with_both(&self, dek: &[u8]) -> Result<(Vec<u8>, Option<Vec<u8>>)> {
        let primary_wrapped = self.primary.encrypt_dek(dek)?;
        let secondary_wrapped = self
            .secondary
            .as_deref()
            .map(|provider| provider.encrypt_dek(dek))
            .transpose()?;
        Ok((primary_wrapped, secondary_wrapped))
    }
}

impl KekProvider for MultiRegionKekProvider {
    /// Delegates to the primary provider only.
    fn encrypt_dek(&self, dek: &[u8]) -> Result<Vec<u8>> {
        self.primary.encrypt_dek(dek)
    }

    /// Tries the primary provider first. If it fails and a secondary is configured, the
    /// secondary's result is returned (the primary's error is discarded); otherwise the
    /// primary's error is returned.
    fn decrypt_dek(&self, encrypted_dek: &[u8]) -> Result<Vec<u8>> {
        self.primary
            .decrypt_dek(encrypted_dek)
            .or_else(|primary_err| match self.secondary.as_deref() {
                Some(fallback) => fallback.decrypt_dek(encrypted_dek),
                None => Err(primary_err),
            })
    }

    fn kek_id(&self) -> &str {
        &self.kek_id
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an AES-256-GCM envelope encryptor backed by a fresh in-memory KEK.
    fn in_memory_envelope_encryptor(kek_id: &str) -> EnvelopeEncryptor {
        let kek_provider =
            InMemoryKekProvider::new(kek_id.to_string()).expect("Failed to create provider");
        EnvelopeEncryptor::new(Box::new(kek_provider), EncryptionAlgorithm::Aes256Gcm)
    }

    #[test]
    fn test_in_memory_kek_provider() {
        let provider =
            InMemoryKekProvider::new("test-kek".to_string()).expect("Failed to create provider");
        let dek = AtRestEncryptor::generate_key(EncryptionAlgorithm::Aes256Gcm);
        let encrypted_dek = provider.encrypt_dek(&dek).expect("Encryption failed");
        assert_ne!(encrypted_dek, dek);
        let decrypted_dek = provider
            .decrypt_dek(&encrypted_dek)
            .expect("Decryption failed");
        assert_eq!(decrypted_dek, dek);
    }

    #[test]
    fn test_envelope_encryption() {
        let kek_provider =
            InMemoryKekProvider::new("test-kek".to_string()).expect("Failed to create provider");
        let encryptor =
            EnvelopeEncryptor::new(Box::new(kek_provider), EncryptionAlgorithm::Aes256Gcm);
        let plaintext = b"sensitive data";
        let envelope = encryptor
            .encrypt(plaintext, None)
            .expect("Encryption failed");
        assert_ne!(envelope.encrypted_payload.ciphertext, plaintext);
        assert!(!envelope.encrypted_dek.is_empty());
        let decrypted = encryptor.decrypt(&envelope).expect("Decryption failed");
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_envelope_with_aad() {
        let kek_provider =
            InMemoryKekProvider::new("test-kek".to_string()).expect("Failed to create provider");
        let encryptor =
            EnvelopeEncryptor::new(Box::new(kek_provider), EncryptionAlgorithm::Aes256Gcm);
        let plaintext = b"sensitive data";
        let aad = b"additional data";
        let envelope = encryptor
            .encrypt(plaintext, Some(aad))
            .expect("Encryption failed");
        let decrypted = encryptor.decrypt(&envelope).expect("Decryption failed");
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_multi_region_kek_provider() {
        let primary =
            InMemoryKekProvider::new("primary-kek".to_string()).expect("Failed to create primary");
        let secondary = InMemoryKekProvider::new("secondary-kek".to_string())
            .expect("Failed to create secondary");
        let multi = MultiRegionKekProvider::new(
            "multi-kek".to_string(),
            Box::new(primary),
            Some(Box::new(secondary)),
        );
        let dek = AtRestEncryptor::generate_key(EncryptionAlgorithm::Aes256Gcm);
        let encrypted_dek = multi.encrypt_dek(&dek).expect("Encryption failed");
        let decrypted_dek = multi
            .decrypt_dek(&encrypted_dek)
            .expect("Decryption failed");
        assert_eq!(decrypted_dek, dek);
    }

    #[test]
    fn test_envelope_serialization() {
        let kek_provider =
            InMemoryKekProvider::new("test-kek".to_string()).expect("Failed to create provider");
        let encryptor =
            EnvelopeEncryptor::new(Box::new(kek_provider), EncryptionAlgorithm::Aes256Gcm);
        let plaintext = b"sensitive data";
        let envelope = encryptor
            .encrypt(plaintext, None)
            .expect("Encryption failed");
        let json = envelope.to_json_bytes().expect("Serialization failed");
        let deserialized =
            EnvelopeEncryptedData::from_json_bytes(&json).expect("Deserialization failed");
        let decrypted = encryptor.decrypt(&deserialized).expect("Decryption failed");
        assert_eq!(decrypted, plaintext);
    }

    /// An envelope wrapped under one KEK must be rejected by an encryptor with another KEK ID.
    #[test]
    fn test_decrypt_rejects_kek_id_mismatch() {
        let encryptor_a = in_memory_envelope_encryptor("kek-a");
        let encryptor_b = in_memory_envelope_encryptor("kek-b");
        let envelope = encryptor_a
            .encrypt(b"sensitive data", None)
            .expect("Encryption failed");
        assert!(encryptor_b.decrypt(&envelope).is_err());
    }

    /// A corrupted wrapped DEK must surface as an error, not a panic or bogus plaintext.
    #[test]
    fn test_decrypt_rejects_tampered_dek() {
        let encryptor = in_memory_envelope_encryptor("test-kek");
        let mut envelope = encryptor
            .encrypt(b"sensitive data", None)
            .expect("Encryption failed");
        envelope.encrypted_dek = b"not a wrapped key".to_vec();
        assert!(encryptor.decrypt(&envelope).is_err());
    }

    /// Bytes wrapped by the secondary must be unwrapped via the fallback path.
    #[test]
    fn test_multi_region_falls_back_to_secondary() {
        let primary =
            InMemoryKekProvider::new("primary-kek".to_string()).expect("Failed to create primary");
        let secondary = InMemoryKekProvider::new("secondary-kek".to_string())
            .expect("Failed to create secondary");
        let multi = MultiRegionKekProvider::new(
            "multi-kek".to_string(),
            Box::new(primary),
            Some(Box::new(secondary)),
        );
        let dek = AtRestEncryptor::generate_key(EncryptionAlgorithm::Aes256Gcm);
        let (primary_wrapped, secondary_wrapped) =
            multi.encrypt_with_both(&dek).expect("Encryption failed");
        let secondary_wrapped = secondary_wrapped.expect("Secondary is configured");
        assert_eq!(
            multi
                .decrypt_dek(&primary_wrapped)
                .expect("Primary decryption failed"),
            dek
        );
        assert_eq!(
            multi
                .decrypt_dek(&secondary_wrapped)
                .expect("Fallback decryption failed"),
            dek
        );
    }

    /// Without a secondary, a primary unwrap failure is returned to the caller.
    #[test]
    fn test_multi_region_without_secondary_propagates_error() {
        let primary =
            InMemoryKekProvider::new("primary-kek".to_string()).expect("Failed to create primary");
        let multi = MultiRegionKekProvider::new("multi-kek".to_string(), Box::new(primary), None);
        let dek = AtRestEncryptor::generate_key(EncryptionAlgorithm::Aes256Gcm);
        let (_, secondary_wrapped) = multi.encrypt_with_both(&dek).expect("Encryption failed");
        assert!(secondary_wrapped.is_none());
        assert!(multi.decrypt_dek(b"garbage").is_err());
    }
}