1use chrono::{DateTime, Utc};
2use ohttp::{
3 KeyConfig, Server as OhttpServer, SymmetricSuite,
4 hpke::{Aead, Kdf, Kem},
5};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::RwLock;
11use tracing::{error, info};
12
13#[derive(Clone, Debug)]
15pub struct KeyInfo {
16 pub id: u8,
17 pub config: KeyConfig,
18 pub server: OhttpServer,
19 pub expires_at: DateTime<Utc>,
20 pub is_active: bool,
21}
22
23#[derive(Clone, Debug, Deserialize, Serialize)]
25pub struct KeyManagerConfig {
26 pub rotation_interval: Duration,
28 pub key_retention_period: Duration,
30 pub auto_rotation_enabled: bool,
32 pub cipher_suites: Vec<CipherSuiteConfig>,
34}
35
36#[derive(Clone, Debug, Deserialize, Serialize)]
37pub struct CipherSuiteConfig {
38 pub kem: String,
39 pub kdf: String,
40 pub aead: String,
41}
42
43impl Default for KeyManagerConfig {
44 fn default() -> Self {
45 Self {
46 rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), auto_rotation_enabled: true,
49 cipher_suites: vec![
50 CipherSuiteConfig {
51 kem: "X25519_SHA256".to_string(),
52 kdf: "HKDF_SHA256".to_string(),
53 aead: "AES_128_GCM".to_string(),
54 },
55 CipherSuiteConfig {
56 kem: "X25519_SHA256".to_string(),
57 kdf: "HKDF_SHA256".to_string(),
58 aead: "CHACHA20_POLY1305".to_string(),
59 },
60 ],
61 }
62 }
63}
64
65pub struct KeyManager {
66 keys: Arc<RwLock<HashMap<u8, KeyInfo>>>,
68 active_key_id: Arc<RwLock<u8>>,
70 config: KeyManagerConfig,
72 next_key_id: Arc<RwLock<u8>>,
74 seed: Option<Vec<u8>>,
76}
77
78impl KeyManager {
79 pub async fn new(config: KeyManagerConfig) -> Result<Self, Box<dyn std::error::Error>> {
80 let manager = Self {
81 keys: Arc::new(RwLock::new(HashMap::new())),
82 active_key_id: Arc::new(RwLock::new(0)),
83 config,
84 next_key_id: Arc::new(RwLock::new(1)),
85 seed: None,
86 };
87
88 let initial_key = manager.generate_new_key().await?;
90 {
91 let mut keys = manager.keys.write().await;
92 let mut active_id = manager.active_key_id.write().await;
93
94 keys.insert(initial_key.id, initial_key.clone());
95 *active_id = initial_key.id;
96 }
97
98 info!("KeyManager initialized with key ID: {}", initial_key.id);
99 Ok(manager)
100 }
101
102 pub async fn new_with_seed(
104 config: KeyManagerConfig,
105 seed: Vec<u8>,
106 ) -> Result<Self, Box<dyn std::error::Error>> {
107 if seed.len() < 32 {
108 return Err("Seed must be at least 32 bytes".into());
109 }
110
111 let manager = Self {
112 keys: Arc::new(RwLock::new(HashMap::new())),
113 active_key_id: Arc::new(RwLock::new(0)),
114 config,
115 next_key_id: Arc::new(RwLock::new(1)),
116 seed: Some(seed),
117 };
118
119 let initial_key = manager.generate_new_key().await?;
121 {
122 let mut keys = manager.keys.write().await;
123 let mut active_id = manager.active_key_id.write().await;
124
125 keys.insert(initial_key.id, initial_key.clone());
126 *active_id = initial_key.id;
127 }
128
129 info!("KeyManager initialized with key ID: {}", initial_key.id);
130 Ok(manager)
131 }
132
133 async fn generate_new_key(&self) -> Result<KeyInfo, Box<dyn std::error::Error>> {
135 let key_id = {
136 let mut next_id = self.next_key_id.write().await;
137 let id = *next_id;
138 *next_id = next_id.wrapping_add(1);
139 id
140 };
141
142 let mut symmetric_suites = Vec::new();
144 for suite in &self.config.cipher_suites {
145 let kdf = match suite.kdf.as_str() {
146 "HKDF_SHA256" => Kdf::HkdfSha256,
147 "HKDF_SHA384" => Kdf::HkdfSha384,
148 "HKDF_SHA512" => Kdf::HkdfSha512,
149 _ => Kdf::HkdfSha256,
150 };
151
152 let aead = match suite.aead.as_str() {
153 "AES_128_GCM" => Aead::Aes128Gcm,
154 "AES_256_GCM" => Aead::Aes256Gcm,
155 "CHACHA20_POLY1305" => Aead::ChaCha20Poly1305,
156 _ => Aead::Aes128Gcm,
157 };
158
159 symmetric_suites.push(SymmetricSuite::new(kdf, aead));
160 }
161
162 if symmetric_suites.is_empty() {
164 return Err("No valid cipher suites configured".into());
165 }
166
167 let kem = Kem::X25519Sha256;
169
170 let key_config = if let Some(seed) = &self.seed {
172 let mut key_seed = seed.clone();
174 key_seed.push(key_id);
175
176 KeyConfig::derive(key_id, kem, symmetric_suites, &key_seed)?
177 } else {
178 KeyConfig::new(key_id, kem, symmetric_suites)?
179 };
180
181 let server = OhttpServer::new(key_config.clone())?;
182 let now = Utc::now();
183
184 Ok(KeyInfo {
185 id: key_id,
186 config: key_config,
187 server,
188 expires_at: now + chrono::Duration::from_std(self.config.rotation_interval)?,
189 is_active: true,
190 })
191 }
192
193 pub async fn get_current_server(&self) -> Result<OhttpServer, Box<dyn std::error::Error>> {
195 let keys = self.keys.read().await;
196 let active_id = self.active_key_id.read().await;
197
198 keys.get(&*active_id)
199 .map(|info| info.server.clone())
200 .ok_or_else(|| "No active key found".into())
201 }
202
203 pub async fn get_server_by_id(&self, key_id: u8) -> Option<OhttpServer> {
205 let keys = self.keys.read().await;
206 keys.get(&key_id).map(|info| info.server.clone())
207 }
208
209 pub async fn get_encoded_config(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
211 let keys = self.keys.read().await;
212 let active_id = self.active_key_id.read().await;
213 let cfg_bytes = keys
214 .get(&*active_id)
215 .ok_or("no active key")?
216 .config
217 .encode()?;
218
219 let mut out = Vec::with_capacity(2 + cfg_bytes.len());
220 out.extend_from_slice(&(cfg_bytes.len() as u16).to_be_bytes());
222 out.extend_from_slice(&cfg_bytes);
223 Ok(out)
224 }
225
226 pub async fn rotate_keys(&self) -> Result<(), Box<dyn std::error::Error>> {
228 info!("Starting key rotation");
229
230 let new_key = self.generate_new_key().await?;
232 let new_key_id = new_key.id;
233
234 {
236 let mut keys = self.keys.write().await;
237 let mut active_id = self.active_key_id.write().await;
238 let now = Utc::now();
239
240 if let Some(current_key) = keys.get_mut(&*active_id) {
242 current_key.is_active = false;
243 current_key.expires_at =
245 now + chrono::Duration::from_std(self.config.key_retention_period)?;
246 }
247
248 keys.insert(new_key_id, new_key);
250
251 *active_id = new_key_id;
253
254 keys.retain(|_, info| info.expires_at > now);
256
257 info!(
258 "Key rotation completed. New active key ID: {}, total keys: {}",
259 new_key_id,
260 keys.len()
261 );
262 }
263
264 Ok(())
265 }
266
267 pub async fn should_rotate(&self) -> bool {
269 let keys = self.keys.read().await;
270 let active_id = self.active_key_id.read().await;
271
272 if let Some(active_key) = keys.get(&*active_id) {
273 let time_until_expiry = active_key.expires_at.signed_duration_since(Utc::now());
274
275 let threshold = chrono::Duration::from_std(self.config.rotation_interval / 10)
277 .unwrap_or_else(|_| chrono::Duration::days(3));
278
279 time_until_expiry < threshold
280 } else {
281 true }
283 }
284
285 pub async fn start_rotation_scheduler(self: Arc<Self>) {
287 if !self.config.auto_rotation_enabled {
288 info!("Automatic key rotation is disabled");
289 return;
290 }
291
292 let manager = self;
293 tokio::spawn(async move {
294 let mut interval = tokio::time::interval(manager.config.rotation_interval);
296
297 loop {
298 interval.tick().await;
299
300 if manager.should_rotate().await
301 && let Err(e) = manager.rotate_keys().await
302 {
303 error!("Key rotation failed: {}", e);
304 }
305
306 manager.cleanup_expired_keys().await;
308 }
309 });
310 }
311
312 async fn cleanup_expired_keys(&self) {
314 let mut keys = self.keys.write().await;
315 let now = Utc::now();
316 let before_count = keys.len();
317
318 keys.retain(|id, info| {
319 if info.expires_at <= now {
320 info!("Removing expired key ID: {}", id);
321 false
322 } else {
323 true
324 }
325 });
326
327 let removed = before_count - keys.len();
328 if removed > 0 {
329 info!("Cleaned up {} expired keys", removed);
330 }
331 }
332
333 pub async fn get_stats(&self) -> KeyManagerStats {
335 let keys = self.keys.read().await;
336 let active_id = self.active_key_id.read().await;
337 let now = Utc::now();
338
339 let active_keys = keys.values().filter(|k| k.is_active).count();
340 let total_keys = keys.len();
341 let expired_keys = keys.values().filter(|k| k.expires_at <= now).count();
342
343 KeyManagerStats {
344 active_key_id: *active_id,
345 total_keys,
346 active_keys,
347 expired_keys,
348 rotation_interval: self.config.rotation_interval,
349 auto_rotation_enabled: self.config.auto_rotation_enabled,
350 }
351 }
352}
353
354#[derive(Debug, Serialize)]
355pub struct KeyManagerStats {
356 pub active_key_id: u8,
357 pub total_keys: usize,
358 pub active_keys: usize,
359 pub expired_keys: usize,
360 pub rotation_interval: Duration,
361 pub auto_rotation_enabled: bool,
362}
363
// SAFETY: NOTE(review) — nothing visible in this file establishes why these manual impls are
// sound. `KeyManager` holds only `Arc<RwLock<..>>` fields, a `KeyManagerConfig` and an
// `Option<Vec<u8>>`. If `KeyInfo` (that is, `KeyConfig` and `OhttpServer`) is already
// `Send + Sync`, the compiler derives both traits automatically and these impls are redundant
// and should be deleted. If it is not, they assert thread-safety the compiler could not
// verify. TODO: confirm against the `ohttp` crypto backend in use.
unsafe impl Send for KeyManager {}
unsafe impl Sync for KeyManager {}