use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use crate::aead::AeadEncryptor;
use crate::encryption::{EncryptionError, KeyStoreProvider};
use crate::key_unwrap::RsaKeyUnwrapper;
/// Key store that keeps its RSA key unwrappers in process memory.
///
/// Keys are registered under a string key path and looked up by that same
/// path when a column encryption key has to be decrypted.
pub struct InMemoryKeyStore {
/// Registered unwrappers, indexed by the key path they were added under.
keys: HashMap<String, RsaKeyUnwrapper>,
}
impl InMemoryKeyStore {
    /// Creates a key store with no keys registered.
    pub fn new() -> Self {
        let keys = HashMap::new();
        Self { keys }
    }

    /// Builds an [`RsaKeyUnwrapper`] from PEM text and registers it under `key_path`.
    ///
    /// A key already stored under the same path is replaced.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`RsaKeyUnwrapper::from_pem`]; nothing is
    /// inserted in that case.
    pub fn add_key(&mut self, key_path: &str, pem: &str) -> Result<(), EncryptionError> {
        let parsed = RsaKeyUnwrapper::from_pem(pem)?;
        self.keys.insert(key_path.to_owned(), parsed);
        Ok(())
    }

    /// Builds an [`RsaKeyUnwrapper`] from DER bytes and registers it under `key_path`.
    ///
    /// A key already stored under the same path is replaced.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`RsaKeyUnwrapper::from_der`]; nothing is
    /// inserted in that case.
    pub fn add_key_der(&mut self, key_path: &str, der: &[u8]) -> Result<(), EncryptionError> {
        let parsed = RsaKeyUnwrapper::from_der(der)?;
        self.keys.insert(key_path.to_owned(), parsed);
        Ok(())
    }

    /// Returns `true` when a key is registered under `key_path`.
    pub fn has_key(&self, key_path: &str) -> bool {
        let registered = &self.keys;
        registered.contains_key(key_path)
    }

    /// Removes the key registered under `key_path`.
    ///
    /// Returns `true` if a key was present and has been removed, `false` if
    /// there was nothing to remove.
    pub fn remove_key(&mut self, key_path: &str) -> bool {
        let removed = self.keys.remove(key_path);
        removed.is_some()
    }

    /// Number of keys currently registered.
    pub fn len(&self) -> usize {
        let registered = &self.keys;
        registered.len()
    }

    /// Returns `true` when no keys are registered.
    pub fn is_empty(&self) -> bool {
        let registered = &self.keys;
        registered.is_empty()
    }
}
impl Default for InMemoryKeyStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl KeyStoreProvider for InMemoryKeyStore {
    /// Name under which this provider identifies itself.
    fn provider_name(&self) -> &str {
        "IN_MEMORY_KEY_STORE"
    }

    /// Decrypts `encrypted_cek` with the key registered under `cmk_path`.
    ///
    /// The `_algorithm` argument is accepted for trait compatibility and is
    /// not consulted.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::KeyStoreNotFound`] when no key is registered
    /// under `cmk_path`; otherwise returns whatever the unwrapper's
    /// `decrypt_cek` reports.
    async fn decrypt_cek(
        &self,
        cmk_path: &str,
        _algorithm: &str,
        encrypted_cek: &[u8],
    ) -> Result<Vec<u8>, EncryptionError> {
        match self.keys.get(cmk_path) {
            Some(unwrapper) => unwrapper.decrypt_cek(encrypted_cek),
            None => Err(EncryptionError::KeyStoreNotFound(format!(
                "Key not found: {cmk_path}"
            ))),
        }
    }
}
/// One cached column encryption key: the key bytes, the encryptor built from
/// them, and the time the entry was created.
struct CekCacheEntry {
// NOTE(review): this field is written on insert but never read afterwards
// (hence the `allow`). Confirm whether keeping the raw key bytes in memory
// alongside the encryptor is intentional.
/// Key bytes the encryptor was constructed from.
#[allow(dead_code)]
cek: Vec<u8>,
/// Shared encryptor handed out to cache users.
encryptor: Arc<AeadEncryptor>,
/// Creation time of the entry; its age is compared against the cache TTL.
created_at: Instant,
}
/// Cache of AEAD encryptors keyed by [`CekCacheKey`], where every entry is
/// considered valid only while it is younger than a fixed time-to-live.
pub struct CekCache {
/// Cached entries, guarded by a read/write lock.
entries: RwLock<HashMap<CekCacheKey, CekCacheEntry>>,
/// Age at which an entry stops being returned by lookups.
ttl: Duration,
}
/// Identifies one cached column encryption key by database, key id and key
/// version.
///
/// Two keys compare equal only when all three components match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CekCacheKey {
    /// Identifier of the database the key belongs to.
    pub database_id: u32,
    /// Identifier of the column encryption key.
    pub cek_id: u32,
    /// Version of the column encryption key.
    pub cek_version: u32,
}

impl CekCacheKey {
    /// Bundles the three identifying components into a cache key.
    pub fn new(database_id: u32, cek_id: u32, cek_version: u32) -> Self {
        Self {
            database_id,
            cek_id,
            cek_version,
        }
    }
}
impl CekCache {
    /// Creates a cache whose entries stay valid for two hours.
    pub fn new() -> Self {
        // Two hours, expressed in seconds.
        const DEFAULT_TTL_SECS: u64 = 2 * 60 * 60;
        Self::with_ttl(Duration::from_secs(DEFAULT_TTL_SECS))
    }

    /// Creates an empty cache whose entries stay valid for `ttl`.
    pub fn with_ttl(ttl: Duration) -> Self {
        let entries = RwLock::new(HashMap::new());
        Self { entries, ttl }
    }

    /// Looks up the encryptor cached under `key`.
    ///
    /// Returns `None` when there is no entry or when the entry is at least as
    /// old as the TTL. An expired entry is not removed by this call.
    pub fn get(&self, key: &CekCacheKey) -> Option<Arc<AeadEncryptor>> {
        let guard = self.entries.read();
        guard
            .get(key)
            .filter(|entry| entry.created_at.elapsed() < self.ttl)
            .map(|entry| Arc::clone(&entry.encryptor))
    }

    /// Builds an encryptor from `cek`, stores it under `key` and returns it.
    ///
    /// An entry already stored under `key` is replaced.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`AeadEncryptor::new`]; the cache is left
    /// untouched in that case.
    pub fn insert(
        &self,
        key: CekCacheKey,
        cek: Vec<u8>,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError> {
        // The encryptor and timestamp are produced before the write lock is taken.
        let shared = Arc::new(AeadEncryptor::new(&cek)?);
        let fresh = CekCacheEntry {
            created_at: Instant::now(),
            encryptor: Arc::clone(&shared),
            cek,
        };
        self.entries.write().insert(key, fresh);
        Ok(shared)
    }

    /// Returns the cached encryptor for `key`, or obtains the key bytes from
    /// `get_cek`, caches the resulting encryptor and returns it.
    ///
    /// `get_cek` is only invoked when the lookup misses (absent or expired).
    /// No lock is held while its future is awaited.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_cek` or from [`CekCache::insert`].
    pub async fn get_or_insert<F, Fut>(
        &self,
        key: CekCacheKey,
        get_cek: F,
    ) -> Result<Arc<AeadEncryptor>, EncryptionError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<Vec<u8>, EncryptionError>>,
    {
        if let Some(hit) = self.get(&key) {
            return Ok(hit);
        }
        let fetched = get_cek().await?;
        self.insert(key, fetched)
    }

    /// Removes the entry stored under `key`.
    ///
    /// Returns `true` if an entry existed (expired or not), `false` otherwise.
    pub fn remove(&self, key: &CekCacheKey) -> bool {
        self.entries.write().remove(key).is_some()
    }

    /// Drops every entry whose age has reached the TTL.
    pub fn cleanup_expired(&self) {
        let mut guard = self.entries.write();
        let max_age = self.ttl;
        guard.retain(|_, entry| entry.created_at.elapsed() < max_age);
    }

    /// Removes all entries.
    pub fn clear(&self) {
        self.entries.write().clear();
    }

    /// Number of stored entries, including expired ones not yet cleaned up.
    pub fn len(&self) -> usize {
        let guard = self.entries.read();
        guard.len()
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        let guard = self.entries.read();
        guard.is_empty()
    }
}
impl Default for CekCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use rsa::{RsaPrivateKey, pkcs8::EncodePrivateKey};

    /// Generates a fresh 2048-bit RSA private key encoded as PKCS#8 PEM.
    fn generate_test_key_pem() -> String {
        let mut rng = rand::thread_rng();
        let key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        key.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .unwrap()
            .to_string()
    }

    #[test]
    fn test_in_memory_key_store_new() {
        let store = InMemoryKeyStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_in_memory_key_store_add_key() {
        let mut store = InMemoryKeyStore::new();
        let pem = generate_test_key_pem();
        store.add_key("TestKey", &pem).unwrap();
        assert!(store.has_key("TestKey"));
        assert!(!store.has_key("OtherKey"));
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_in_memory_key_store_remove_key() {
        let mut store = InMemoryKeyStore::new();
        let pem = generate_test_key_pem();
        store.add_key("TestKey", &pem).unwrap();
        assert!(store.remove_key("TestKey"));
        assert!(!store.has_key("TestKey"));
        // A second removal finds nothing.
        assert!(!store.remove_key("TestKey"));
    }

    #[test]
    fn test_in_memory_key_store_provider_name() {
        let store = InMemoryKeyStore::new();
        assert_eq!(store.provider_name(), "IN_MEMORY_KEY_STORE");
    }

    #[test]
    fn test_cek_cache_key() {
        let key1 = CekCacheKey::new(1, 2, 3);
        let key2 = CekCacheKey::new(1, 2, 3);
        let key3 = CekCacheKey::new(1, 2, 4);
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cek_cache_insert_and_get() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];
        let encryptor = cache.insert(key.clone(), cek).unwrap();
        assert_eq!(cache.len(), 1);
        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
        // The cache hands back the very same shared encryptor, not a copy.
        assert!(Arc::ptr_eq(&encryptor, &retrieved.unwrap()));
    }

    #[test]
    fn test_cek_cache_miss() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_expiration() {
        // `thread::sleep` never returns early but may overshoot, so the only
        // timing-sensitive assertion is the initial hit. The TTL is sized so
        // that a scheduling hiccup between insert and lookup cannot expire
        // the entry prematurely.
        let ttl = Duration::from_millis(100);
        let cache = CekCache::with_ttl(ttl);
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];
        cache.insert(key.clone(), cek).unwrap();
        assert!(cache.get(&key).is_some());
        // Sleep comfortably past the TTL; the entry must then be a miss.
        std::thread::sleep(ttl + Duration::from_millis(50));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_remove() {
        let cache = CekCache::new();
        let key = CekCacheKey::new(1, 1, 1);
        let cek = vec![0x42u8; 32];
        cache.insert(key.clone(), cek).unwrap();
        assert!(cache.remove(&key));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cek_cache_clear() {
        let cache = CekCache::new();
        for i in 0..5 {
            let key = CekCacheKey::new(i, 1, 1);
            let cek = vec![0x42u8; 32];
            cache.insert(key, cek).unwrap();
        }
        assert_eq!(cache.len(), 5);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cek_cache_cleanup_expired() {
        // Timeline (lower bounds, since sleeps can only overshoot):
        //   t =   0 ms  insert key1
        //   t = 200 ms  insert key2
        //   t = 350 ms  cleanup: key1 is >= 350 ms old (expired),
        //               key2 is >= 150 ms old with ~150 ms of slack left.
        // The slack keeps the "key2 survives" assertions stable even when the
        // sleeps overshoot on a loaded machine.
        let cache = CekCache::with_ttl(Duration::from_millis(300));
        let key1 = CekCacheKey::new(1, 1, 1);
        cache.insert(key1.clone(), vec![0x42u8; 32]).unwrap();
        std::thread::sleep(Duration::from_millis(200));
        let key2 = CekCacheKey::new(2, 1, 1);
        cache.insert(key2.clone(), vec![0x43u8; 32]).unwrap();
        assert_eq!(cache.len(), 2);
        std::thread::sleep(Duration::from_millis(150));
        cache.cleanup_expired();
        assert_eq!(cache.len(), 1);
        assert!(cache.get(&key1).is_none());
        assert!(cache.get(&key2).is_some());
    }
}