use super::algorithms::{EncryptedData, EncryptionAlgorithm, EncryptionKey};
use super::errors::{EncryptionError, EncryptionResult};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Numeric identifier of a key generation.
///
/// `KeyManager` assigns version 1 to its first key and moves to `next()` on
/// every rotation. Formatted by `Display` as `v<number>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct KeyVersion(pub u64);
impl KeyVersion {
    /// Wraps a raw version number.
    pub fn new(version: u64) -> Self {
        Self(version)
    }

    /// Returns the version immediately following this one.
    ///
    /// # Panics
    ///
    /// Panics if this version is already `u64::MAX`. An unchecked `+ 1` would
    /// panic only in debug builds and silently wrap to `v0` in release builds,
    /// which could alias a previously issued key version; the check makes the
    /// failure explicit in every build profile.
    pub fn next(&self) -> Self {
        let successor = self
            .0
            .checked_add(1)
            .expect("key version overflowed u64::MAX");
        Self(successor)
    }
}
impl std::fmt::Display for KeyVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "v{}", self.0)
}
}
/// An encryption key bundled with its version and lifecycle metadata.
#[derive(Debug, Clone)]
pub struct VersionedKey {
/// Version identifying this key.
pub version: KeyVersion,
/// The key itself.
pub key: EncryptionKey,
/// Creation timestamp; `VersionedKey::new` sets it to the current UTC time.
pub created_at: DateTime<Utc>,
/// Instant from which `should_rotate` reports `true`; with `None` it never does.
pub rotate_at: Option<DateTime<Utc>>,
/// Active flag: `false` after `new`, set to `true` by `activate`.
pub is_active: bool,
}
impl VersionedKey {
    /// Wraps `key` under `version`, stamped with the current UTC time.
    ///
    /// The result is inactive and has no rotation deadline; chain
    /// `with_rotation` and `activate` to set those.
    pub fn new(version: KeyVersion, key: EncryptionKey) -> Self {
        let created_at = Utc::now();
        Self {
            version,
            key,
            created_at,
            rotate_at: None,
            is_active: false,
        }
    }

    /// Builder step: returns this key with its rotation deadline set to `rotate_at`.
    pub fn with_rotation(self, rotate_at: DateTime<Utc>) -> Self {
        Self {
            rotate_at: Some(rotate_at),
            ..self
        }
    }

    /// Builder step: returns this key marked as active.
    pub fn activate(self) -> Self {
        Self {
            is_active: true,
            ..self
        }
    }

    /// Reports whether the rotation deadline has been reached.
    ///
    /// Returns `false` when no deadline is set.
    pub fn should_rotate(&self) -> bool {
        match self.rotate_at {
            Some(deadline) => deadline <= Utc::now(),
            None => false,
        }
    }
}
/// Policy settings that drive `KeyManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyRotationConfig {
/// Days between a key's activation and its rotation deadline.
/// `KeyManager` substitutes 365 days when this value is zero or negative.
pub rotation_interval_days: i64,
/// Maximum number of retired keys `KeyManager` keeps after a rotation.
pub max_previous_keys: usize,
/// Algorithm passed to `EncryptionKey::generate` when `KeyManager` creates a key.
pub algorithm: EncryptionAlgorithm,
/// When `false`, `KeyManager::is_rotation_due` always returns `false`.
pub auto_rotate: bool,
}
impl Default for KeyRotationConfig {
fn default() -> Self {
Self {
rotation_interval_days: 30, max_previous_keys: 5, algorithm: EncryptionAlgorithm::Aes256Gcm,
auto_rotate: true,
}
}
}
/// Holds the current encryption key plus a bounded set of retired keys.
pub struct KeyManager {
// Key used by `encrypt`; swapped out by `rotate_key`.
current_key: VersionedKey,
// Retired keys indexed by version, consulted by `decrypt` for older data.
previous_keys: HashMap<KeyVersion, VersionedKey>,
// Rotation policy supplied at construction time.
config: KeyRotationConfig,
}
impl KeyManager {
pub fn new(config: KeyRotationConfig) -> EncryptionResult<Self> {
let key = EncryptionKey::generate(config.algorithm.clone())?;
let rotate_at = if config.rotation_interval_days > 0 {
Some(Utc::now() + Duration::days(config.rotation_interval_days))
} else {
None
};
let current_key = VersionedKey::new(KeyVersion::new(1), key)
.with_rotation(rotate_at.unwrap_or_else(|| Utc::now() + Duration::days(365)))
.activate();
Ok(Self {
current_key,
previous_keys: HashMap::new(),
config,
})
}
pub fn with_key(config: KeyRotationConfig, key: EncryptionKey) -> EncryptionResult<Self> {
let rotate_at = if config.rotation_interval_days > 0 {
Some(Utc::now() + Duration::days(config.rotation_interval_days))
} else {
None
};
let current_key = VersionedKey::new(KeyVersion::new(1), key)
.with_rotation(rotate_at.unwrap_or_else(|| Utc::now() + Duration::days(365)))
.activate();
Ok(Self {
current_key,
previous_keys: HashMap::new(),
config,
})
}
pub fn current_key(&self) -> &EncryptionKey {
&self.current_key.key
}
pub fn current_version(&self) -> KeyVersion {
self.current_key.version
}
pub fn is_rotation_due(&self) -> bool {
self.config.auto_rotate && self.current_key.should_rotate()
}
pub fn rotate_key(&mut self) -> EncryptionResult<KeyVersion> {
let new_key = EncryptionKey::generate(self.config.algorithm.clone())?;
let new_version = self.current_key.version.next();
let rotate_at = if self.config.rotation_interval_days > 0 {
Some(Utc::now() + Duration::days(self.config.rotation_interval_days))
} else {
None
};
let new_versioned_key = VersionedKey::new(new_version, new_key)
.with_rotation(rotate_at.unwrap_or_else(|| Utc::now() + Duration::days(365)))
.activate();
let mut old_key = self.current_key.clone();
old_key.is_active = false;
self.previous_keys.insert(old_key.version, old_key);
if self.previous_keys.len() > self.config.max_previous_keys {
let mut versions: Vec<KeyVersion> = self.previous_keys.keys().copied().collect();
versions.sort_by_key(|v| v.0);
let to_remove = versions.len() - self.config.max_previous_keys;
for version in versions.iter().take(to_remove) {
self.previous_keys.remove(version);
}
}
self.current_key = new_versioned_key;
Ok(new_version)
}
pub fn encrypt(
&self,
plaintext: &[u8],
aad: Option<&[u8]>,
) -> EncryptionResult<VersionedEncryptedData> {
let encrypted_data =
super::algorithms::EncryptionEngine::encrypt(&self.current_key.key, plaintext, aad)?;
Ok(VersionedEncryptedData {
version: self.current_key.version,
data: encrypted_data,
})
}
pub fn decrypt(&self, encrypted_data: &VersionedEncryptedData) -> EncryptionResult<Vec<u8>> {
if encrypted_data.version == self.current_key.version {
return super::algorithms::EncryptionEngine::decrypt(
&self.current_key.key,
&encrypted_data.data,
);
}
if let Some(versioned_key) = self.previous_keys.get(&encrypted_data.version) {
return super::algorithms::EncryptionEngine::decrypt(
&versioned_key.key,
&encrypted_data.data,
);
}
Err(EncryptionError::invalid_key(format!(
"No key found for version {}",
encrypted_data.version
)))
}
pub fn key_versions(&self) -> Vec<KeyVersion> {
let mut versions: Vec<KeyVersion> = self.previous_keys.keys().copied().collect();
versions.push(self.current_key.version);
versions.sort_by_key(|v| v.0);
versions
}
pub fn key_metadata(&self, version: KeyVersion) -> Option<KeyMetadata> {
if version == self.current_key.version {
Some(KeyMetadata {
version,
created_at: self.current_key.created_at,
rotate_at: self.current_key.rotate_at,
is_active: self.current_key.is_active,
})
} else {
self.previous_keys.get(&version).map(|key| KeyMetadata {
version,
created_at: key.created_at,
rotate_at: key.rotate_at,
is_active: key.is_active,
})
}
}
}
/// Serializable snapshot of a key's version and lifecycle fields.
///
/// Carries no key material; produced by `KeyManager::key_metadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyMetadata {
/// Version of the key described.
pub version: KeyVersion,
/// Copied from the key's `created_at`.
pub created_at: DateTime<Utc>,
/// Copied from the key's `rotate_at`.
pub rotate_at: Option<DateTime<Utc>>,
/// Copied from the key's `is_active` flag.
pub is_active: bool,
}
/// Encrypted payload tagged with a key version.
///
/// `KeyManager::encrypt` records the version of the key it used, and
/// `KeyManager::decrypt` uses that version to select the key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionedEncryptedData {
/// Version of the key to use for decryption.
pub version: KeyVersion,
/// The encrypted payload as returned by the encryption engine.
pub data: EncryptedData,
}
impl VersionedEncryptedData {
pub fn new(version: KeyVersion, data: EncryptedData) -> Self {
Self { version, data }
}
}
// Unit tests for key versioning, rotation, retention, and versioned
// encrypt/decrypt round trips.
#[cfg(test)]
mod tests {
use super::*;
// `next()` increments the number and `Display` renders `v<number>`.
#[test]
fn test_key_version() {
let v1 = KeyVersion::new(1);
let v2 = v1.next();
assert_eq!(v2.0, 2);
assert_eq!(v1.to_string(), "v1");
}
// A new manager starts at version 1 and holds only that version.
#[test]
fn test_key_manager_creation() {
let config = KeyRotationConfig::default();
let manager = KeyManager::new(config).unwrap();
assert_eq!(manager.current_version(), KeyVersion::new(1));
assert_eq!(manager.key_versions(), vec![KeyVersion::new(1)]);
}
// One rotation moves the current version to 2 and keeps version 1 around.
#[test]
fn test_key_rotation() {
let config = KeyRotationConfig::default();
let mut manager = KeyManager::new(config).unwrap();
let new_version = manager.rotate_key().unwrap();
assert_eq!(new_version, KeyVersion::new(2));
assert_eq!(manager.current_version(), KeyVersion::new(2));
let versions = manager.key_versions();
assert_eq!(versions.len(), 2);
assert!(versions.contains(&KeyVersion::new(1)));
assert!(versions.contains(&KeyVersion::new(2)));
}
// Data encrypted before a rotation must still decrypt afterwards, alongside
// data encrypted with the new key.
#[test]
fn test_encrypt_decrypt_with_rotation() {
let config = KeyRotationConfig::default();
let mut manager = KeyManager::new(config).unwrap();
let plaintext = b"secret data";
let encrypted_v1 = manager.encrypt(plaintext, None).unwrap();
assert_eq!(encrypted_v1.version, KeyVersion::new(1));
manager.rotate_key().unwrap();
let encrypted_v2 = manager.encrypt(plaintext, None).unwrap();
assert_eq!(encrypted_v2.version, KeyVersion::new(2));
let decrypted_v1 = manager.decrypt(&encrypted_v1).unwrap();
let decrypted_v2 = manager.decrypt(&encrypted_v2).unwrap();
assert_eq!(decrypted_v1, plaintext);
assert_eq!(decrypted_v2, plaintext);
}
// With `max_previous_keys: 2`, five rotations leave the current key plus two
// retired keys, i.e. three versions in total.
#[test]
fn test_max_previous_keys() {
let config = KeyRotationConfig {
rotation_interval_days: 30,
max_previous_keys: 2,
algorithm: EncryptionAlgorithm::Aes256Gcm,
auto_rotate: true,
};
let mut manager = KeyManager::new(config).unwrap();
for _ in 0..5 {
manager.rotate_key().unwrap();
}
let versions = manager.key_versions();
assert_eq!(versions.len(), 3);
}
// `should_rotate` is true for a deadline in the past and false for one in
// the future.
#[test]
fn test_versioned_key_should_rotate() {
let key = EncryptionKey::generate(EncryptionAlgorithm::Aes256Gcm).unwrap();
let past_rotation = Utc::now() - Duration::days(1);
let versioned_key =
VersionedKey::new(KeyVersion::new(1), key.clone()).with_rotation(past_rotation);
assert!(versioned_key.should_rotate());
let future_rotation = Utc::now() + Duration::days(1);
let versioned_key =
VersionedKey::new(KeyVersion::new(1), key).with_rotation(future_rotation);
assert!(!versioned_key.should_rotate());
}
// Metadata for the initial key reports version 1, active, with a rotation
// deadline set.
#[test]
fn test_key_metadata() {
let config = KeyRotationConfig::default();
let manager = KeyManager::new(config).unwrap();
let metadata = manager.key_metadata(KeyVersion::new(1)).unwrap();
assert_eq!(metadata.version, KeyVersion::new(1));
assert!(metadata.is_active);
assert!(metadata.rotate_at.is_some());
}
// Decrypting data tagged with a version the manager does not hold must fail.
#[test]
fn test_decrypt_with_unknown_version() {
let config = KeyRotationConfig::default();
let manager = KeyManager::new(config).unwrap();
let plaintext = b"test";
let key = EncryptionKey::generate(EncryptionAlgorithm::Aes256Gcm).unwrap();
let encrypted =
super::super::algorithms::EncryptionEngine::encrypt(&key, plaintext, None).unwrap();
let versioned = VersionedEncryptedData {
version: KeyVersion::new(999), data: encrypted,
};
let result = manager.decrypt(&versioned);
assert!(result.is_err());
}
}