Skip to main content

ohttp_gateway/
key_manager.rs

use chrono::{DateTime, Utc};
use ohttp::{
    KeyConfig, Server as OhttpServer, SymmetricSuite,
    hpke::{Aead, Kdf, Kem},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{error, info, warn};

13/// Represents a key with its metadata
14#[derive(Clone, Debug)]
15pub struct KeyInfo {
16    pub id: u8,
17    pub config: KeyConfig,
18    pub server: OhttpServer,
19    pub expires_at: DateTime<Utc>,
20    pub is_active: bool,
21}
22
23/// Configuration for key management
24#[derive(Clone, Debug, Deserialize, Serialize)]
25pub struct KeyManagerConfig {
26    /// How often to rotate keys (default: 30 days)
27    pub rotation_interval: Duration,
28    /// How long to keep old keys for decryption (default: 7 days)
29    pub key_retention_period: Duration,
30    /// Whether to enable automatic rotation
31    pub auto_rotation_enabled: bool,
32    /// Supported cipher suites
33    pub cipher_suites: Vec<CipherSuiteConfig>,
34}
35
36#[derive(Clone, Debug, Deserialize, Serialize)]
37pub struct CipherSuiteConfig {
38    pub kem: String,
39    pub kdf: String,
40    pub aead: String,
41}
42
43impl Default for KeyManagerConfig {
44    fn default() -> Self {
45        Self {
46            rotation_interval: Duration::from_secs(30 * 24 * 60 * 60), // 30 days
47            key_retention_period: Duration::from_secs(7 * 24 * 60 * 60), // 7 days
48            auto_rotation_enabled: true,
49            cipher_suites: vec![
50                CipherSuiteConfig {
51                    kem: "X25519_SHA256".to_string(),
52                    kdf: "HKDF_SHA256".to_string(),
53                    aead: "AES_128_GCM".to_string(),
54                },
55                CipherSuiteConfig {
56                    kem: "X25519_SHA256".to_string(),
57                    kdf: "HKDF_SHA256".to_string(),
58                    aead: "CHACHA20_POLY1305".to_string(),
59                },
60            ],
61        }
62    }
63}
64
65pub struct KeyManager {
66    /// All keys indexed by ID
67    keys: Arc<RwLock<HashMap<u8, KeyInfo>>>,
68    /// Current active key ID
69    active_key_id: Arc<RwLock<u8>>,
70    /// Configuration
71    config: KeyManagerConfig,
72    /// Key ID counter (wraps around after 255)
73    next_key_id: Arc<RwLock<u8>>,
74    /// Seed for deterministic key generation (optional)
75    seed: Option<Vec<u8>>,
76}
77
78impl KeyManager {
79    pub async fn new(config: KeyManagerConfig) -> Result<Self, Box<dyn std::error::Error>> {
80        let manager = Self {
81            keys: Arc::new(RwLock::new(HashMap::new())),
82            active_key_id: Arc::new(RwLock::new(0)),
83            config,
84            next_key_id: Arc::new(RwLock::new(1)),
85            seed: None,
86        };
87
88        // Generate initial key
89        let initial_key = manager.generate_new_key().await?;
90        {
91            let mut keys = manager.keys.write().await;
92            let mut active_id = manager.active_key_id.write().await;
93
94            keys.insert(initial_key.id, initial_key.clone());
95            *active_id = initial_key.id;
96        }
97
98        info!("KeyManager initialized with key ID: {}", initial_key.id);
99        Ok(manager)
100    }
101
102    /// Create a key manager with a seed for deterministic key generation
103    pub async fn new_with_seed(
104        config: KeyManagerConfig,
105        seed: Vec<u8>,
106    ) -> Result<Self, Box<dyn std::error::Error>> {
107        if seed.len() < 32 {
108            return Err("Seed must be at least 32 bytes".into());
109        }
110
111        let manager = Self {
112            keys: Arc::new(RwLock::new(HashMap::new())),
113            active_key_id: Arc::new(RwLock::new(0)),
114            config,
115            next_key_id: Arc::new(RwLock::new(1)),
116            seed: Some(seed),
117        };
118
119        // Generate initial key (will now use the seed)
120        let initial_key = manager.generate_new_key().await?;
121        {
122            let mut keys = manager.keys.write().await;
123            let mut active_id = manager.active_key_id.write().await;
124
125            keys.insert(initial_key.id, initial_key.clone());
126            *active_id = initial_key.id;
127        }
128
129        info!("KeyManager initialized with key ID: {}", initial_key.id);
130        Ok(manager)
131    }
132
133    /// Generate a new key configuration
134    async fn generate_new_key(&self) -> Result<KeyInfo, Box<dyn std::error::Error>> {
135        let key_id = {
136            let mut next_id = self.next_key_id.write().await;
137            let id = *next_id;
138            *next_id = next_id.wrapping_add(1);
139            id
140        };
141
142        // Parse cipher suites from config
143        let mut symmetric_suites = Vec::new();
144        for suite in &self.config.cipher_suites {
145            let kdf = match suite.kdf.as_str() {
146                "HKDF_SHA256" => Kdf::HkdfSha256,
147                "HKDF_SHA384" => Kdf::HkdfSha384,
148                "HKDF_SHA512" => Kdf::HkdfSha512,
149                _ => Kdf::HkdfSha256,
150            };
151
152            let aead = match suite.aead.as_str() {
153                "AES_128_GCM" => Aead::Aes128Gcm,
154                "AES_256_GCM" => Aead::Aes256Gcm,
155                "CHACHA20_POLY1305" => Aead::ChaCha20Poly1305,
156                _ => Aead::Aes128Gcm,
157            };
158
159            symmetric_suites.push(SymmetricSuite::new(kdf, aead));
160        }
161
162        // Validate that we have at least one cipher suite
163        if symmetric_suites.is_empty() {
164            return Err("No valid cipher suites configured".into());
165        }
166
167        // Determine KEM based on config - only X25519 is supported by ohttp crate
168        let kem = Kem::X25519Sha256;
169
170        // Generate key config
171        let key_config = if let Some(seed) = &self.seed {
172            // Deterministic generation using seed + key_id
173            let mut key_seed = seed.clone();
174            key_seed.push(key_id);
175
176            KeyConfig::derive(key_id, kem, symmetric_suites, &key_seed)?
177        } else {
178            KeyConfig::new(key_id, kem, symmetric_suites)?
179        };
180
181        let server = OhttpServer::new(key_config.clone())?;
182        let now = Utc::now();
183
184        Ok(KeyInfo {
185            id: key_id,
186            config: key_config,
187            server,
188            expires_at: now + chrono::Duration::from_std(self.config.rotation_interval)?,
189            is_active: true,
190        })
191    }
192
193    /// Get the current active server for decryption
194    pub async fn get_current_server(&self) -> Result<OhttpServer, Box<dyn std::error::Error>> {
195        let keys = self.keys.read().await;
196        let active_id = self.active_key_id.read().await;
197
198        keys.get(&*active_id)
199            .map(|info| info.server.clone())
200            .ok_or_else(|| "No active key found".into())
201    }
202
203    /// Get a server by key ID (for handling requests with specific key IDs)
204    pub async fn get_server_by_id(&self, key_id: u8) -> Option<OhttpServer> {
205        let keys = self.keys.read().await;
206        keys.get(&key_id).map(|info| info.server.clone())
207    }
208
209    /// Get encoded config with length prefix per RFC 9458 Section 3.2
210    pub async fn get_encoded_config(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
211        let keys = self.keys.read().await;
212        let active_id = self.active_key_id.read().await;
213        let cfg_bytes = keys
214            .get(&*active_id)
215            .ok_or("no active key")?
216            .config
217            .encode()?;
218
219        let mut out = Vec::with_capacity(2 + cfg_bytes.len());
220        // Add 2-byte length prefix in network byte order per RFC 9458
221        out.extend_from_slice(&(cfg_bytes.len() as u16).to_be_bytes());
222        out.extend_from_slice(&cfg_bytes);
223        Ok(out)
224    }
225
226    /// Rotate keys by generating a new key and marking old ones for expiration
227    pub async fn rotate_keys(&self) -> Result<(), Box<dyn std::error::Error>> {
228        info!("Starting key rotation");
229
230        // Generate new key
231        let new_key = self.generate_new_key().await?;
232        let new_key_id = new_key.id;
233
234        // Update key store
235        {
236            let mut keys = self.keys.write().await;
237            let mut active_id = self.active_key_id.write().await;
238            let now = Utc::now();
239
240            // Mark current active key for future expiration
241            if let Some(current_key) = keys.get_mut(&*active_id) {
242                current_key.is_active = false;
243                // Keep it around for the retention period
244                current_key.expires_at =
245                    now + chrono::Duration::from_std(self.config.key_retention_period)?;
246            }
247
248            // Add new key
249            keys.insert(new_key_id, new_key);
250
251            // Update active key ID
252            *active_id = new_key_id;
253
254            // Clean up expired keys
255            keys.retain(|_, info| info.expires_at > now);
256
257            info!(
258                "Key rotation completed. New active key ID: {}, total keys: {}",
259                new_key_id,
260                keys.len()
261            );
262        }
263
264        Ok(())
265    }
266
267    /// Check if rotation is needed
268    pub async fn should_rotate(&self) -> bool {
269        let keys = self.keys.read().await;
270        let active_id = self.active_key_id.read().await;
271
272        if let Some(active_key) = keys.get(&*active_id) {
273            let time_until_expiry = active_key.expires_at.signed_duration_since(Utc::now());
274
275            // Rotate if less than 10% of the rotation interval remains
276            let threshold = chrono::Duration::from_std(self.config.rotation_interval / 10)
277                .unwrap_or_else(|_| chrono::Duration::days(3));
278
279            time_until_expiry < threshold
280        } else {
281            true // No active key, definitely need to rotate
282        }
283    }
284
285    /// Start automatic key rotation scheduler
286    pub async fn start_rotation_scheduler(self: Arc<Self>) {
287        if !self.config.auto_rotation_enabled {
288            info!("Automatic key rotation is disabled");
289            return;
290        }
291
292        let manager = self;
293        tokio::spawn(async move {
294            // Use the configured rotation interval for the scheduler
295            let mut interval = tokio::time::interval(manager.config.rotation_interval);
296
297            loop {
298                interval.tick().await;
299
300                if manager.should_rotate().await
301                    && let Err(e) = manager.rotate_keys().await
302                {
303                    error!("Key rotation failed: {}", e);
304                }
305
306                // Also clean up expired keys
307                manager.cleanup_expired_keys().await;
308            }
309        });
310    }
311
312    /// Clean up expired keys
313    async fn cleanup_expired_keys(&self) {
314        let mut keys = self.keys.write().await;
315        let now = Utc::now();
316        let before_count = keys.len();
317
318        keys.retain(|id, info| {
319            if info.expires_at <= now {
320                info!("Removing expired key ID: {}", id);
321                false
322            } else {
323                true
324            }
325        });
326
327        let removed = before_count - keys.len();
328        if removed > 0 {
329            info!("Cleaned up {} expired keys", removed);
330        }
331    }
332
333    /// Get key manager statistics
334    pub async fn get_stats(&self) -> KeyManagerStats {
335        let keys = self.keys.read().await;
336        let active_id = self.active_key_id.read().await;
337        let now = Utc::now();
338
339        let active_keys = keys.values().filter(|k| k.is_active).count();
340        let total_keys = keys.len();
341        let expired_keys = keys.values().filter(|k| k.expires_at <= now).count();
342
343        KeyManagerStats {
344            active_key_id: *active_id,
345            total_keys,
346            active_keys,
347            expired_keys,
348            rotation_interval: self.config.rotation_interval,
349            auto_rotation_enabled: self.config.auto_rotation_enabled,
350        }
351    }
352}
353
354#[derive(Debug, Serialize)]
355pub struct KeyManagerStats {
356    pub active_key_id: u8,
357    pub total_keys: usize,
358    pub active_keys: usize,
359    pub expired_keys: usize,
360    pub rotation_interval: Duration,
361    pub auto_rotation_enabled: bool,
362}
363
364// Ensure thread safety
365unsafe impl Send for KeyManager {}
366unsafe impl Sync for KeyManager {}