use chrono::{DateTime, Utc};
use ohttp::{
KeyConfig, Server as OhttpServer, SymmetricSuite,
hpke::{Aead, Kdf, Kem},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{error, info};
/// One OHTTP key tracked by `KeyManager`.
///
/// NOTE(review): the derived `Debug` includes `config` and `server`; confirm their `Debug`
/// impls do not print private key material before logging values of this type.
#[derive(Clone, Debug)]
pub struct KeyInfo {
    /// OHTTP key identifier; also the key under which this entry is stored in the manager's map.
    pub id: u8,
    /// The key configuration this entry was generated with.
    pub config: KeyConfig,
    /// `ohttp` server instance constructed from a clone of `config`.
    pub server: OhttpServer,
    /// While active: creation time plus the rotation interval. After the key is replaced by a
    /// rotation: the rotation time plus the retention period.
    pub expires_at: DateTime<Utc>,
    /// Set when the key is generated; cleared when a rotation replaces it.
    pub is_active: bool,
}
/// Tunables for `KeyManager`.
///
/// The `Default` impl yields a 30-day rotation interval, a 7-day retention period, automatic
/// rotation enabled, and two X25519 / HKDF-SHA256 suites (AES-128-GCM and ChaCha20-Poly1305).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KeyManagerConfig {
    /// Lifetime of a newly generated key: its `expires_at` is creation time plus this interval.
    /// Also drives `KeyManager::should_rotate` and the background rotation scheduler.
    pub rotation_interval: Duration,
    /// How long a key is kept after a rotation has replaced it.
    pub key_retention_period: Duration,
    /// Whether `KeyManager::start_rotation_scheduler` spawns the background rotation task.
    pub auto_rotation_enabled: bool,
    /// Cipher suites placed, in this order, into every generated key configuration.
    pub cipher_suites: Vec<CipherSuiteConfig>,
}
/// A cipher suite described by algorithm names, as read from configuration.
///
/// Names recognized by `KeyManager`:
/// - `kdf`: `HKDF_SHA256`, `HKDF_SHA384`, `HKDF_SHA512`
/// - `aead`: `AES_128_GCM`, `AES_256_GCM`, `CHACHA20_POLY1305`
///
/// Keys are always generated with the X25519-SHA256 KEM, so `kem` is expected to be
/// `X25519_SHA256`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CipherSuiteConfig {
    /// KEM name (see type-level docs).
    pub kem: String,
    /// KDF name (see type-level docs).
    pub kdf: String,
    /// AEAD name (see type-level docs).
    pub aead: String,
}
impl Default for KeyManagerConfig {
    /// Stock configuration: rotate every 30 days, keep replaced keys for 7 days, rotate
    /// automatically, and offer X25519 / HKDF-SHA256 with AES-128-GCM followed by
    /// ChaCha20-Poly1305.
    fn default() -> Self {
        const SECS_PER_DAY: u64 = 24 * 60 * 60;

        // Both default suites share KEM and KDF; only the AEAD differs.
        let suite_with_aead = |aead: &str| CipherSuiteConfig {
            kem: String::from("X25519_SHA256"),
            kdf: String::from("HKDF_SHA256"),
            aead: String::from(aead),
        };

        Self {
            rotation_interval: Duration::from_secs(30 * SECS_PER_DAY),
            key_retention_period: Duration::from_secs(7 * SECS_PER_DAY),
            auto_rotation_enabled: true,
            cipher_suites: vec![
                suite_with_aead("AES_128_GCM"),
                suite_with_aead("CHACHA20_POLY1305"),
            ],
        }
    }
}
/// Owns the set of OHTTP keys, tracks which one is active, and rotates them.
pub struct KeyManager {
    /// Every key currently held (the active one and any retained, replaced keys), by key ID.
    keys: Arc<RwLock<HashMap<u8, KeyInfo>>>,
    /// ID of the key used by `get_current_server` and `get_encoded_config`.
    active_key_id: Arc<RwLock<u8>>,
    /// Rotation / retention / cipher-suite settings.
    config: KeyManagerConfig,
    /// Next key ID to hand out; a `u8` that wraps after 255.
    next_key_id: Arc<RwLock<u8>>,
    /// When present, key material is derived from this seed plus the key ID (via
    /// `KeyConfig::derive`) instead of via `KeyConfig::new`.
    /// NOTE(review): stored in a plain `Vec<u8>` and not zeroized on drop.
    seed: Option<Vec<u8>>,
}
impl KeyManager {
pub async fn new(config: KeyManagerConfig) -> Result<Self, Box<dyn std::error::Error>> {
let manager = Self {
keys: Arc::new(RwLock::new(HashMap::new())),
active_key_id: Arc::new(RwLock::new(0)),
config,
next_key_id: Arc::new(RwLock::new(1)),
seed: None,
};
let initial_key = manager.generate_new_key().await?;
{
let mut keys = manager.keys.write().await;
let mut active_id = manager.active_key_id.write().await;
keys.insert(initial_key.id, initial_key.clone());
*active_id = initial_key.id;
}
info!("KeyManager initialized with key ID: {}", initial_key.id);
Ok(manager)
}
pub async fn new_with_seed(
config: KeyManagerConfig,
seed: Vec<u8>,
) -> Result<Self, Box<dyn std::error::Error>> {
if seed.len() < 32 {
return Err("Seed must be at least 32 bytes".into());
}
let manager = Self {
keys: Arc::new(RwLock::new(HashMap::new())),
active_key_id: Arc::new(RwLock::new(0)),
config,
next_key_id: Arc::new(RwLock::new(1)),
seed: Some(seed),
};
let initial_key = manager.generate_new_key().await?;
{
let mut keys = manager.keys.write().await;
let mut active_id = manager.active_key_id.write().await;
keys.insert(initial_key.id, initial_key.clone());
*active_id = initial_key.id;
}
info!("KeyManager initialized with key ID: {}", initial_key.id);
Ok(manager)
}
async fn generate_new_key(&self) -> Result<KeyInfo, Box<dyn std::error::Error>> {
let key_id = {
let mut next_id = self.next_key_id.write().await;
let id = *next_id;
*next_id = next_id.wrapping_add(1);
id
};
let mut symmetric_suites = Vec::new();
for suite in &self.config.cipher_suites {
let kdf = match suite.kdf.as_str() {
"HKDF_SHA256" => Kdf::HkdfSha256,
"HKDF_SHA384" => Kdf::HkdfSha384,
"HKDF_SHA512" => Kdf::HkdfSha512,
_ => Kdf::HkdfSha256,
};
let aead = match suite.aead.as_str() {
"AES_128_GCM" => Aead::Aes128Gcm,
"AES_256_GCM" => Aead::Aes256Gcm,
"CHACHA20_POLY1305" => Aead::ChaCha20Poly1305,
_ => Aead::Aes128Gcm,
};
symmetric_suites.push(SymmetricSuite::new(kdf, aead));
}
if symmetric_suites.is_empty() {
return Err("No valid cipher suites configured".into());
}
let kem = Kem::X25519Sha256;
let key_config = if let Some(seed) = &self.seed {
let mut key_seed = seed.clone();
key_seed.push(key_id);
KeyConfig::derive(key_id, kem, symmetric_suites, &key_seed)?
} else {
KeyConfig::new(key_id, kem, symmetric_suites)?
};
let server = OhttpServer::new(key_config.clone())?;
let now = Utc::now();
Ok(KeyInfo {
id: key_id,
config: key_config,
server,
expires_at: now + chrono::Duration::from_std(self.config.rotation_interval)?,
is_active: true,
})
}
pub async fn get_current_server(&self) -> Result<OhttpServer, Box<dyn std::error::Error>> {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
keys.get(&*active_id)
.map(|info| info.server.clone())
.ok_or_else(|| "No active key found".into())
}
pub async fn get_server_by_id(&self, key_id: u8) -> Option<OhttpServer> {
let keys = self.keys.read().await;
keys.get(&key_id).map(|info| info.server.clone())
}
pub async fn get_encoded_config(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
let cfg_bytes = keys
.get(&*active_id)
.ok_or("no active key")?
.config
.encode()?;
let mut out = Vec::with_capacity(cfg_bytes.len() + 2);
out.extend_from_slice(&(cfg_bytes.len() as u16).to_be_bytes()); out.extend_from_slice(&cfg_bytes);
Ok(out)
}
pub async fn rotate_keys(&self) -> Result<(), Box<dyn std::error::Error>> {
info!("Starting key rotation");
let new_key = self.generate_new_key().await?;
let new_key_id = new_key.id;
{
let mut keys = self.keys.write().await;
let mut active_id = self.active_key_id.write().await;
let now = Utc::now();
if let Some(current_key) = keys.get_mut(&*active_id) {
current_key.is_active = false;
current_key.expires_at =
now + chrono::Duration::from_std(self.config.key_retention_period)?;
}
keys.insert(new_key_id, new_key);
*active_id = new_key_id;
keys.retain(|_, info| info.expires_at > now);
info!(
"Key rotation completed. New active key ID: {}, total keys: {}",
new_key_id,
keys.len()
);
}
Ok(())
}
pub async fn should_rotate(&self) -> bool {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
if let Some(active_key) = keys.get(&*active_id) {
let time_until_expiry = active_key.expires_at.signed_duration_since(Utc::now());
let threshold = chrono::Duration::from_std(self.config.rotation_interval / 10)
.unwrap_or_else(|_| chrono::Duration::days(3));
time_until_expiry < threshold
} else {
true }
}
pub async fn start_rotation_scheduler(self: Arc<Self>) {
if !self.config.auto_rotation_enabled {
info!("Automatic key rotation is disabled");
return;
}
let manager = self;
tokio::spawn(async move {
let mut interval = tokio::time::interval(manager.config.rotation_interval);
loop {
interval.tick().await;
if manager.should_rotate().await {
if let Err(e) = manager.rotate_keys().await {
error!("Key rotation failed: {}", e);
}
}
manager.cleanup_expired_keys().await;
}
});
}
async fn cleanup_expired_keys(&self) {
let mut keys = self.keys.write().await;
let now = Utc::now();
let before_count = keys.len();
keys.retain(|id, info| {
if info.expires_at <= now {
info!("Removing expired key ID: {}", id);
false
} else {
true
}
});
let removed = before_count - keys.len();
if removed > 0 {
info!("Cleaned up {} expired keys", removed);
}
}
pub async fn get_stats(&self) -> KeyManagerStats {
let keys = self.keys.read().await;
let active_id = self.active_key_id.read().await;
let now = Utc::now();
let active_keys = keys.values().filter(|k| k.is_active).count();
let total_keys = keys.len();
let expired_keys = keys.values().filter(|k| k.expires_at <= now).count();
KeyManagerStats {
active_key_id: *active_id,
total_keys,
active_keys,
expired_keys,
rotation_interval: self.config.rotation_interval,
auto_rotation_enabled: self.config.auto_rotation_enabled,
}
}
}
/// Point-in-time snapshot of `KeyManager` state, produced by `KeyManager::get_stats`.
#[derive(Debug, Serialize)]
pub struct KeyManagerStats {
    /// ID of the currently active key.
    pub active_key_id: u8,
    /// Number of keys currently held (active and retained).
    pub total_keys: usize,
    /// Number of held keys whose `is_active` flag is set.
    pub active_keys: usize,
    /// Number of held keys whose `expires_at` is at or before the snapshot time.
    pub expired_keys: usize,
    /// Copy of `KeyManagerConfig::rotation_interval`.
    pub rotation_interval: Duration,
    /// Copy of `KeyManagerConfig::auto_rotation_enabled`.
    pub auto_rotation_enabled: bool,
}
unsafe impl Send for KeyManager {}
unsafe impl Sync for KeyManager {}