1use std::collections::HashMap;
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31
32use parking_lot::RwLock;
33
34use crate::aead::AeadEncryptor;
35use crate::encryption::{EncryptionError, KeyStoreProvider};
36use crate::key_unwrap::RsaKeyUnwrapper;
37
38pub struct InMemoryKeyStore {
45 keys: HashMap<String, RsaKeyUnwrapper>,
47}
48
49impl InMemoryKeyStore {
50 pub fn new() -> Self {
52 Self {
53 keys: HashMap::new(),
54 }
55 }
56
57 pub fn add_key(&mut self, key_path: &str, pem: &str) -> Result<(), EncryptionError> {
68 let unwrapper = RsaKeyUnwrapper::from_pem(pem)?;
69 self.keys.insert(key_path.to_string(), unwrapper);
70 Ok(())
71 }
72
73 pub fn add_key_der(&mut self, key_path: &str, der: &[u8]) -> Result<(), EncryptionError> {
84 let unwrapper = RsaKeyUnwrapper::from_der(der)?;
85 self.keys.insert(key_path.to_string(), unwrapper);
86 Ok(())
87 }
88
89 pub fn has_key(&self, key_path: &str) -> bool {
91 self.keys.contains_key(key_path)
92 }
93
94 pub fn remove_key(&mut self, key_path: &str) -> bool {
96 self.keys.remove(key_path).is_some()
97 }
98
99 pub fn len(&self) -> usize {
101 self.keys.len()
102 }
103
104 pub fn is_empty(&self) -> bool {
106 self.keys.is_empty()
107 }
108}
109
110impl Default for InMemoryKeyStore {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116#[async_trait::async_trait]
117impl KeyStoreProvider for InMemoryKeyStore {
118 fn provider_name(&self) -> &str {
119 "IN_MEMORY_KEY_STORE"
120 }
121
122 async fn decrypt_cek(
123 &self,
124 cmk_path: &str,
125 _algorithm: &str,
126 encrypted_cek: &[u8],
127 ) -> Result<Vec<u8>, EncryptionError> {
128 let unwrapper = self.keys.get(cmk_path).ok_or_else(|| {
129 EncryptionError::KeyStoreNotFound(format!("Key not found: {cmk_path}"))
130 })?;
131
132 unwrapper.decrypt_cek(encrypted_cek)
133 }
134}
135
136struct CekCacheEntry {
138 #[allow(dead_code)]
140 cek: Vec<u8>,
141 encryptor: Arc<AeadEncryptor>,
143 created_at: Instant,
145}
146
147pub struct CekCache {
161 entries: RwLock<HashMap<CekCacheKey, CekCacheEntry>>,
163 ttl: Duration,
165}
166
/// Identifies one cached CEK: the database it belongs to, the key id and the
/// key version.
///
/// Two keys are equal (and hash identically) only when all three components
/// match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CekCacheKey {
    /// Identifier of the database the CEK belongs to.
    pub database_id: u32,
    /// Identifier of the CEK.
    pub cek_id: u32,
    /// Version of the CEK.
    pub cek_version: u32,
}

impl CekCacheKey {
    /// Builds a cache key from its three components.
    ///
    /// This is a `const fn`, so keys can also be created in `const` and
    /// `static` items.
    #[must_use]
    pub const fn new(database_id: u32, cek_id: u32, cek_version: u32) -> Self {
        Self {
            database_id,
            cek_id,
            cek_version,
        }
    }
}
188
189impl CekCache {
190 pub fn new() -> Self {
192 Self::with_ttl(Duration::from_secs(2 * 60 * 60))
193 }
194
195 pub fn with_ttl(ttl: Duration) -> Self {
197 Self {
198 entries: RwLock::new(HashMap::new()),
199 ttl,
200 }
201 }
202
203 pub fn get(&self, key: &CekCacheKey) -> Option<Arc<AeadEncryptor>> {
207 let entries = self.entries.read();
208 if let Some(entry) = entries.get(key) {
209 if entry.created_at.elapsed() < self.ttl {
210 return Some(Arc::clone(&entry.encryptor));
211 }
212 }
213 None
214 }
215
216 pub fn insert(
229 &self,
230 key: CekCacheKey,
231 cek: Vec<u8>,
232 ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
233 let encryptor = Arc::new(AeadEncryptor::new(&cek)?);
234
235 let entry = CekCacheEntry {
236 cek,
237 encryptor: Arc::clone(&encryptor),
238 created_at: Instant::now(),
239 };
240
241 let mut entries = self.entries.write();
242 entries.insert(key, entry);
243
244 Ok(encryptor)
245 }
246
247 pub async fn get_or_insert<F, Fut>(
258 &self,
259 key: CekCacheKey,
260 get_cek: F,
261 ) -> Result<Arc<AeadEncryptor>, EncryptionError>
262 where
263 F: FnOnce() -> Fut,
264 Fut: std::future::Future<Output = Result<Vec<u8>, EncryptionError>>,
265 {
266 if let Some(encryptor) = self.get(&key) {
268 return Ok(encryptor);
269 }
270
271 let cek = get_cek().await?;
273 self.insert(key, cek)
274 }
275
276 pub fn remove(&self, key: &CekCacheKey) -> bool {
280 let mut entries = self.entries.write();
281 entries.remove(key).is_some()
282 }
283
284 pub fn cleanup_expired(&self) {
286 let mut entries = self.entries.write();
287 entries.retain(|_, entry| entry.created_at.elapsed() < self.ttl);
288 }
289
290 pub fn clear(&self) {
292 let mut entries = self.entries.write();
293 entries.clear();
294 }
295
296 pub fn len(&self) -> usize {
298 self.entries.read().len()
299 }
300
301 pub fn is_empty(&self) -> bool {
303 self.entries.read().is_empty()
304 }
305}
306
307impl Default for CekCache {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use rsa::{RsaPrivateKey, pkcs8::EncodePrivateKey};

    /// Generates a fresh 2048-bit RSA private key, PKCS#8 PEM encoded.
    fn generate_test_key_pem() -> String {
        let mut rng = rand::thread_rng();
        let key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        key.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap()
            .to_string()
    }

    #[test]
    fn test_in_memory_key_store_new() {
        let store = InMemoryKeyStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_in_memory_key_store_add_key() {
        let mut store = InMemoryKeyStore::new();
        let pem = generate_test_key_pem();

        store.add_key("TestKey", &pem).unwrap();
        assert!(store.has_key("TestKey"));
        assert!(!store.has_key("OtherKey"));
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_in_memory_key_store_remove_key() {
        let mut store = InMemoryKeyStore::new();
        let pem = generate_test_key_pem();

        store.add_key("TestKey", &pem).unwrap();
        assert!(store.remove_key("TestKey"));
        assert!(!store.has_key("TestKey"));
        // A second removal finds nothing.
        assert!(!store.remove_key("TestKey"));
    }

    #[test]
    fn test_in_memory_key_store_provider_name() {
        let store = InMemoryKeyStore::new();
        assert_eq!(store.provider_name(), "IN_MEMORY_KEY_STORE");
    }

    #[test]
    fn test_cek_cache_key() {
        let key1 = CekCacheKey::new(1, 2, 3);
        let key2 = CekCacheKey::new(1, 2, 3);
        let key3 = CekCacheKey::new(1, 2, 4);

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cek_cache_insert_and_get() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];

        let encryptor = cache.insert(key.clone(), cek).unwrap();
        assert_eq!(cache.len(), 1);

        // A hit must hand back the very same shared encryptor.
        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
        assert!(Arc::ptr_eq(&encryptor, &retrieved.unwrap()));
    }

    #[test]
    fn test_cek_cache_miss() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_expiration() {
        let ttl = Duration::from_millis(10);
        let cache = CekCache::with_ttl(ttl);
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];

        // Taken before the insert, so `started.elapsed()` is an upper bound on
        // the entry's age at any later point.
        let started = Instant::now();
        cache.insert(key.clone(), cek).unwrap();
        assert_eq!(cache.len(), 1);

        let fresh = cache.get(&key);
        // With a TTL this short the entry may already be stale if the thread
        // was descheduled between insert and lookup. Only require a hit when
        // the upper bound on its age, measured after the lookup, is still
        // below the TTL; this keeps the check free of timing flakiness.
        if started.elapsed() < ttl {
            assert!(fresh.is_some());
        }

        // `sleep` never returns early, so the entry is past its TTL afterwards.
        std::thread::sleep(Duration::from_millis(20));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_remove() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];

        cache.insert(key.clone(), cek).unwrap();
        assert!(cache.remove(&key));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_clear() {
        let cache = CekCache::new();

        for i in 0..5 {
            let key = CekCacheKey::new(i, 1, 1);
            let cek = vec![0x42u8; 32];
            cache.insert(key, cek).unwrap();
        }

        assert_eq!(cache.len(), 5);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cek_cache_cleanup_removes_expired() {
        let cache = CekCache::with_ttl(Duration::from_millis(20));
        // `len` counts entries regardless of age, so this assertion does not
        // depend on timing.
        cache
            .insert(CekCacheKey::new(1, 1, 1), vec![0x42u8; 32])
            .unwrap();
        assert_eq!(cache.len(), 1);

        std::thread::sleep(Duration::from_millis(150));
        cache.cleanup_expired();

        assert!(cache.is_empty());
    }

    #[test]
    fn test_cek_cache_cleanup_keeps_fresh() {
        let cache = CekCache::with_ttl(Duration::from_secs(3600));
        // One-hour TTL: the entry cannot expire during the test.
        let key = CekCacheKey::new(1, 1, 1);
        cache.insert(key.clone(), vec![0x42u8; 32]).unwrap();

        cache.cleanup_expired();

        assert_eq!(cache.len(), 1);
        assert!(cache.get(&key).is_some());
    }
}