use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
use chrono::Utc;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::time::Instant;
use super::config::KeyRotationConfig;
use super::key_metadata::{KeyFormat, KeyStatus, SigningKeyMetadata};
use super::storage::KeyRotationStorage;
use crate::error::Error;
#[derive(Debug, Clone)]
pub struct CachedKey {
pub kid: String,
pub format: KeyFormat,
pub key_material: Vec<u8>,
pub status: KeyStatus,
}
struct KeyCache {
active_key: Option<CachedKey>,
verification_keys: HashMap<String, CachedKey>,
last_refresh: Instant,
}
impl KeyCache {
fn empty() -> Self {
Self {
active_key: None,
verification_keys: HashMap::new(),
last_refresh: Instant::now(),
}
}
}
#[derive(Clone)]
pub struct KeyManager {
storage: Arc<dyn KeyRotationStorage>,
service_name: String,
config: KeyRotationConfig,
cache: Arc<RwLock<KeyCache>>,
}
impl KeyManager {
pub fn new(
storage: Arc<dyn KeyRotationStorage>,
service_name: impl Into<String>,
config: KeyRotationConfig,
) -> Self {
Self {
storage,
service_name: service_name.into(),
config,
cache: Arc::new(RwLock::new(KeyCache::empty())),
}
}
pub async fn refresh_cache(&self) -> Result<(), Error> {
let active_meta = self.storage.get_active_key(&self.service_name).await?;
let verification_metas = self
.storage
.get_verification_keys(&self.service_name)
.await?;
let active_cached = active_meta.as_ref().map(decode_key_metadata).transpose()?;
let mut verification_keys = HashMap::with_capacity(verification_metas.len());
for meta in &verification_metas {
let cached = decode_key_metadata(meta)?;
verification_keys.insert(cached.kid.clone(), cached);
}
let mut cache = self.cache.write().await;
cache.active_key = active_cached;
cache.verification_keys = verification_keys;
cache.last_refresh = Instant::now();
Ok(())
}
pub async fn get_signing_key(&self) -> Result<Option<CachedKey>, Error> {
if self.is_cache_stale().await {
self.refresh_cache().await?;
}
let cache = self.cache.read().await;
Ok(cache.active_key.clone())
}
pub async fn get_verification_key(&self, kid: &str) -> Result<Option<CachedKey>, Error> {
{
let cache = self.cache.read().await;
if let Some(cached) = cache.verification_keys.get(kid) {
return Ok(Some(cached.clone()));
}
}
let meta = self.storage.get_key_by_kid(kid).await?;
match meta {
Some(ref m) if m.status == KeyStatus::Active || m.status == KeyStatus::Draining => {
let cached = decode_key_metadata(m)?;
let mut cache = self.cache.write().await;
cache
.verification_keys
.insert(cached.kid.clone(), cached.clone());
Ok(Some(cached))
}
_ => Ok(None),
}
}
pub async fn get_all_verification_keys(&self) -> Result<Vec<CachedKey>, Error> {
if self.is_cache_stale().await {
self.refresh_cache().await?;
}
let cache = self.cache.read().await;
Ok(cache.verification_keys.values().cloned().collect())
}
pub async fn rotate(&self) -> Result<SigningKeyMetadata, Error> {
let now = Utc::now();
let mut key_bytes = [0u8; 32];
rand::fill(&mut key_bytes);
let key_material_b64 = BASE64.encode(key_bytes);
let key_hash = blake3::hash(&key_bytes).to_hex().to_string();
let kid = uuid::Uuid::now_v7().to_string();
let new_key = SigningKeyMetadata {
kid: kid.clone(),
format: KeyFormat::PasetoV4Local,
key_material: key_material_b64,
status: KeyStatus::Active,
created_at: now,
activated_at: Some(now),
draining_since: None,
retired_at: None,
drain_expires_at: None,
service_name: self.service_name.clone(),
key_hash,
};
let old_active = self.storage.get_active_key(&self.service_name).await?;
if let Some(ref old) = old_active {
let drain_duration_secs =
self.config.rotation_period_secs + self.config.drain_grace_period_secs;
let drain_expires_at = now + chrono::Duration::seconds(drain_duration_secs as i64);
self.storage
.update_key_status(&old.kid, KeyStatus::Draining, now)
.await?;
tracing::info!(
kid = %old.kid,
drain_expires_at = %drain_expires_at,
"transitioned previous active key to draining"
);
}
self.storage.store_key(&new_key).await?;
self.refresh_cache().await?;
tracing::info!(
kid = %kid,
service = %self.service_name,
"rotated signing key"
);
Ok(new_key)
}
pub async fn retire_expired(&self) -> Result<u64, Error> {
let now = Utc::now();
let retired_count = self.storage.retire_expired_draining_keys(now).await?;
if retired_count > 0 {
self.refresh_cache().await?;
tracing::info!(
count = retired_count,
service = %self.service_name,
"retired expired draining keys"
);
}
Ok(retired_count)
}
pub fn service_name(&self) -> &str {
&self.service_name
}
pub fn storage(&self) -> &Arc<dyn KeyRotationStorage> {
&self.storage
}
async fn is_cache_stale(&self) -> bool {
let cache = self.cache.read().await;
let elapsed = cache.last_refresh.elapsed();
elapsed.as_secs() >= self.config.check_interval_secs
}
}
fn decode_key_metadata(meta: &SigningKeyMetadata) -> Result<CachedKey, Error> {
let key_material = BASE64.decode(&meta.key_material).map_err(|e| {
Error::Internal(format!(
"failed to decode base64 key material for kid '{}': {}",
meta.kid, e
))
})?;
Ok(CachedKey {
kid: meta.kid.clone(),
format: meta.format,
key_material,
status: meta.status,
})
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::sync::atomic::{AtomicU64, Ordering};
struct MockStorage {
keys: RwLock<Vec<SigningKeyMetadata>>,
retire_call_count: AtomicU64,
}
impl MockStorage {
fn new() -> Self {
Self {
keys: RwLock::new(Vec::new()),
retire_call_count: AtomicU64::new(0),
}
}
fn with_keys(keys: Vec<SigningKeyMetadata>) -> Self {
Self {
keys: RwLock::new(keys),
retire_call_count: AtomicU64::new(0),
}
}
}
#[async_trait]
impl KeyRotationStorage for MockStorage {
async fn store_key(&self, key: &SigningKeyMetadata) -> Result<(), Error> {
let mut keys = self.keys.write().await;
if keys.iter().any(|k| k.kid == key.kid) {
return Err(Error::Conflict(format!(
"key with kid '{}' already exists",
key.kid
)));
}
keys.push(key.clone());
Ok(())
}
async fn get_active_key(
&self,
service_name: &str,
) -> Result<Option<SigningKeyMetadata>, Error> {
let keys = self.keys.read().await;
Ok(keys
.iter()
.find(|k| k.service_name == service_name && k.status == KeyStatus::Active)
.cloned())
}
async fn get_key_by_kid(&self, kid: &str) -> Result<Option<SigningKeyMetadata>, Error> {
let keys = self.keys.read().await;
Ok(keys.iter().find(|k| k.kid == kid).cloned())
}
async fn get_verification_keys(
&self,
service_name: &str,
) -> Result<Vec<SigningKeyMetadata>, Error> {
let keys = self.keys.read().await;
Ok(keys
.iter()
.filter(|k| {
k.service_name == service_name
&& (k.status == KeyStatus::Active || k.status == KeyStatus::Draining)
})
.cloned()
.collect())
}
async fn update_key_status(
&self,
kid: &str,
new_status: KeyStatus,
timestamp: DateTime<Utc>,
) -> Result<(), Error> {
let mut keys = self.keys.write().await;
let key = keys
.iter_mut()
.find(|k| k.kid == kid)
.ok_or_else(|| Error::NotFound(format!("key '{}' not found", kid)))?;
key.status = new_status;
match new_status {
KeyStatus::Active => key.activated_at = Some(timestamp),
KeyStatus::Draining => key.draining_since = Some(timestamp),
KeyStatus::Retired => key.retired_at = Some(timestamp),
}
Ok(())
}
async fn retire_expired_draining_keys(&self, now: DateTime<Utc>) -> Result<u64, Error> {
self.retire_call_count.fetch_add(1, Ordering::Relaxed);
let mut keys = self.keys.write().await;
let mut count = 0u64;
for key in keys.iter_mut() {
if key.status == KeyStatus::Draining {
if let Some(expires) = key.drain_expires_at {
if expires <= now {
key.status = KeyStatus::Retired;
key.retired_at = Some(now);
count += 1;
}
}
}
}
Ok(count)
}
async fn initialize(&self) -> Result<(), Error> {
Ok(())
}
}
fn test_config() -> KeyRotationConfig {
KeyRotationConfig {
enabled: true,
rotation_period_secs: 3600,
drain_grace_period_secs: 300,
check_interval_secs: 60,
retention_days: 90,
bootstrap_key_path: None,
}
}
fn sample_metadata(kid: &str, status: KeyStatus, service: &str) -> SigningKeyMetadata {
let key_bytes = b"test-key-material-32-bytes!!!!!!";
let key_b64 = BASE64.encode(key_bytes);
let key_hash = blake3::hash(key_bytes).to_hex().to_string();
SigningKeyMetadata {
kid: kid.to_string(),
format: KeyFormat::PasetoV4Local,
key_material: key_b64,
status,
created_at: Utc::now(),
activated_at: if status == KeyStatus::Active {
Some(Utc::now())
} else {
None
},
draining_since: if status == KeyStatus::Draining {
Some(Utc::now())
} else {
None
},
retired_at: None,
drain_expires_at: if status == KeyStatus::Draining {
Some(Utc::now() + chrono::Duration::seconds(3600))
} else {
None
},
service_name: service.to_string(),
key_hash,
}
}
#[test]
fn test_decode_key_metadata_valid_base64() {
let meta = sample_metadata("kid-1", KeyStatus::Active, "svc");
let cached = decode_key_metadata(&meta).expect("should decode");
assert_eq!(cached.kid, "kid-1");
assert_eq!(cached.format, KeyFormat::PasetoV4Local);
assert_eq!(cached.status, KeyStatus::Active);
assert_eq!(cached.key_material.len(), 32);
}
#[test]
fn test_decode_key_metadata_invalid_base64() {
let mut meta = sample_metadata("kid-bad", KeyStatus::Active, "svc");
meta.key_material = "not-valid-base64!!!".to_string();
let result = decode_key_metadata(&meta);
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("base64"));
}
#[test]
fn test_cached_key_clone() {
let meta = sample_metadata("kid-clone", KeyStatus::Draining, "svc");
let cached = decode_key_metadata(&meta).expect("should decode");
let cloned = cached.clone();
assert_eq!(cached.kid, cloned.kid);
assert_eq!(cached.key_material, cloned.key_material);
}
#[test]
fn test_key_manager_new() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let config = test_config();
let mgr = KeyManager::new(storage, "my-service", config.clone());
assert_eq!(mgr.service_name(), "my-service");
assert_eq!(mgr.config.rotation_period_secs, 3600);
}
#[test]
fn test_key_manager_is_clone() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mgr = KeyManager::new(storage, "svc", test_config());
let cloned = mgr.clone();
assert_eq!(cloned.service_name(), "svc");
}
#[test]
fn test_key_manager_storage_accessor() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mgr = KeyManager::new(Arc::clone(&storage), "svc", test_config());
let _storage_ref = mgr.storage();
}
#[tokio::test]
async fn test_refresh_cache_empty_storage() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mgr = KeyManager::new(storage, "svc", test_config());
mgr.refresh_cache().await.expect("refresh should succeed");
let cache = mgr.cache.read().await;
assert!(cache.active_key.is_none());
assert!(cache.verification_keys.is_empty());
}
#[tokio::test]
async fn test_refresh_cache_with_active_key() {
let key = sample_metadata("kid-active", KeyStatus::Active, "svc");
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::with_keys(vec![key]));
let mgr = KeyManager::new(storage, "svc", test_config());
mgr.refresh_cache().await.expect("refresh should succeed");
let cache = mgr.cache.read().await;
assert!(cache.active_key.is_some());
assert_eq!(cache.active_key.as_ref().unwrap().kid, "kid-active");
assert_eq!(cache.verification_keys.len(), 1);
assert!(cache.verification_keys.contains_key("kid-active"));
}
#[tokio::test]
async fn test_refresh_cache_with_active_and_draining() {
let active = sample_metadata("kid-active", KeyStatus::Active, "svc");
let draining = sample_metadata("kid-drain", KeyStatus::Draining, "svc");
let storage: Arc<dyn KeyRotationStorage> =
Arc::new(MockStorage::with_keys(vec![active, draining]));
let mgr = KeyManager::new(storage, "svc", test_config());
mgr.refresh_cache().await.expect("refresh should succeed");
let cache = mgr.cache.read().await;
assert!(cache.active_key.is_some());
assert_eq!(cache.verification_keys.len(), 2);
assert!(cache.verification_keys.contains_key("kid-active"));
assert!(cache.verification_keys.contains_key("kid-drain"));
}
#[tokio::test]
async fn test_get_signing_key_returns_active() {
let key = sample_metadata("kid-sign", KeyStatus::Active, "svc");
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::with_keys(vec![key]));
let mut config = test_config();
config.check_interval_secs = 9999;
let mgr = KeyManager::new(storage, "svc", config);
mgr.refresh_cache().await.unwrap();
let signing_key = mgr.get_signing_key().await.unwrap();
assert!(signing_key.is_some());
assert_eq!(signing_key.unwrap().kid, "kid-sign");
}
#[tokio::test]
async fn test_get_signing_key_none_when_no_active() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mut config = test_config();
config.check_interval_secs = 9999;
let mgr = KeyManager::new(storage, "svc", config);
mgr.refresh_cache().await.unwrap();
let signing_key = mgr.get_signing_key().await.unwrap();
assert!(signing_key.is_none());
}
#[tokio::test]
async fn test_get_verification_key_from_cache() {
let key = sample_metadata("kid-verify", KeyStatus::Active, "svc");
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::with_keys(vec![key]));
let mgr = KeyManager::new(storage, "svc", test_config());
mgr.refresh_cache().await.unwrap();
let result = mgr.get_verification_key("kid-verify").await.unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().kid, "kid-verify");
}
#[tokio::test]
async fn test_get_verification_key_cache_miss_falls_back_to_storage() {
let key = sample_metadata("kid-miss", KeyStatus::Active, "svc");
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::with_keys(vec![key]));
let mgr = KeyManager::new(storage, "svc", test_config());
let result = mgr.get_verification_key("kid-miss").await.unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().kid, "kid-miss");
let cache = mgr.cache.read().await;
assert!(cache.verification_keys.contains_key("kid-miss"));
}
#[tokio::test]
async fn test_get_verification_key_retired_returns_none() {
let key = sample_metadata("kid-retired", KeyStatus::Retired, "svc");
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::with_keys(vec![key]));
let mgr = KeyManager::new(storage, "svc", test_config());
let result = mgr.get_verification_key("kid-retired").await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_get_verification_key_nonexistent_returns_none() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mgr = KeyManager::new(storage, "svc", test_config());
let result = mgr
.get_verification_key("kid-does-not-exist")
.await
.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_get_all_verification_keys() {
let active = sample_metadata("kid-a", KeyStatus::Active, "svc");
let draining = sample_metadata("kid-d", KeyStatus::Draining, "svc");
let storage: Arc<dyn KeyRotationStorage> =
Arc::new(MockStorage::with_keys(vec![active, draining]));
let mut config = test_config();
config.check_interval_secs = 9999;
let mgr = KeyManager::new(storage, "svc", config);
mgr.refresh_cache().await.unwrap();
let keys = mgr.get_all_verification_keys().await.unwrap();
assert_eq!(keys.len(), 2);
let kids: Vec<&str> = keys.iter().map(|k| k.kid.as_str()).collect();
assert!(kids.contains(&"kid-a"));
assert!(kids.contains(&"kid-d"));
}
#[tokio::test]
async fn test_rotate_creates_new_key() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mgr = KeyManager::new(storage, "svc", test_config());
let new_key = mgr.rotate().await.unwrap();
assert_eq!(new_key.status, KeyStatus::Active);
assert_eq!(new_key.service_name, "svc");
assert_eq!(new_key.format, KeyFormat::PasetoV4Local);
assert!(new_key.activated_at.is_some());
assert!(!new_key.kid.is_empty());
assert!(!new_key.key_hash.is_empty());
let parsed = uuid::Uuid::parse_str(&new_key.kid);
assert!(parsed.is_ok());
let decoded = BASE64.decode(&new_key.key_material).unwrap();
assert_eq!(decoded.len(), 32);
}
#[tokio::test]
async fn test_rotate_transitions_old_active_to_draining() {
let old = sample_metadata("kid-old", KeyStatus::Active, "svc");
let storage = Arc::new(MockStorage::with_keys(vec![old]));
let mgr = KeyManager::new(
Arc::clone(&storage) as Arc<dyn KeyRotationStorage>,
"svc",
test_config(),
);
let new_key = mgr.rotate().await.unwrap();
assert_ne!(new_key.kid, "kid-old");
let old_key = storage
.get_key_by_kid("kid-old")
.await
.unwrap()
.expect("old key should exist");
assert_eq!(old_key.status, KeyStatus::Draining);
assert!(old_key.draining_since.is_some());
}
#[tokio::test]
async fn test_rotate_refreshes_cache() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mut config = test_config();
config.check_interval_secs = 9999;
let mgr = KeyManager::new(storage, "svc", config);
let new_key = mgr.rotate().await.unwrap();
let cache = mgr.cache.read().await;
assert!(cache.active_key.is_some());
assert_eq!(cache.active_key.as_ref().unwrap().kid, new_key.kid);
}
#[tokio::test]
async fn test_retire_expired_delegates_to_storage() {
let storage = Arc::new(MockStorage::new());
let mgr = KeyManager::new(
Arc::clone(&storage) as Arc<dyn KeyRotationStorage>,
"svc",
test_config(),
);
let count = mgr.retire_expired().await.unwrap();
assert_eq!(count, 0);
assert_eq!(storage.retire_call_count.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_retire_expired_refreshes_cache_when_keys_retired() {
let mut draining = sample_metadata("kid-expired", KeyStatus::Draining, "svc");
draining.drain_expires_at = Some(Utc::now() - chrono::Duration::seconds(100));
let storage = Arc::new(MockStorage::with_keys(vec![draining]));
let mut config = test_config();
config.check_interval_secs = 9999;
let mgr = KeyManager::new(
Arc::clone(&storage) as Arc<dyn KeyRotationStorage>,
"svc",
config,
);
mgr.refresh_cache().await.unwrap();
assert_eq!(mgr.cache.read().await.verification_keys.len(), 1);
let count = mgr.retire_expired().await.unwrap();
assert_eq!(count, 1);
let cache = mgr.cache.read().await;
assert!(cache.verification_keys.is_empty());
}
#[tokio::test]
async fn test_rotate_key_hash_matches_blake3() {
let storage: Arc<dyn KeyRotationStorage> = Arc::new(MockStorage::new());
let mgr = KeyManager::new(storage, "svc", test_config());
let new_key = mgr.rotate().await.unwrap();
let decoded = BASE64.decode(&new_key.key_material).unwrap();
let expected_hash = blake3::hash(&decoded).to_hex().to_string();
assert_eq!(new_key.key_hash, expected_hash);
}
#[tokio::test]
async fn test_different_service_names_isolated() {
let key_a = sample_metadata("kid-a", KeyStatus::Active, "svc-a");
let key_b = sample_metadata("kid-b", KeyStatus::Active, "svc-b");
let storage: Arc<dyn KeyRotationStorage> =
Arc::new(MockStorage::with_keys(vec![key_a, key_b]));
let mgr_a = KeyManager::new(Arc::clone(&storage), "svc-a", test_config());
let mgr_b = KeyManager::new(Arc::clone(&storage), "svc-b", test_config());
mgr_a.refresh_cache().await.unwrap();
mgr_b.refresh_cache().await.unwrap();
let key_a_cached = mgr_a.cache.read().await;
let key_b_cached = mgr_b.cache.read().await;
assert_eq!(key_a_cached.active_key.as_ref().unwrap().kid, "kid-a");
assert_eq!(key_b_cached.active_key.as_ref().unwrap().kid, "kid-b");
}
}