use std::sync::Arc;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use parking_lot::RwLock;
use std::time::{Duration, Instant};
use ring::digest;
/// LRU cache of key material, looked up by the [`KeyHash`] fingerprint of the key bytes.
///
/// The LRU map sits behind a single `RwLock`; the usage counters live in atomics so that
/// `stats()` can read them without taking the cache lock.
pub struct KeyCache {
    /// Bounded LRU map from key fingerprint to the cached entry.
    cache: RwLock<lru::LruCache<KeyHash, CachedKey>>,
    /// Hit / miss / eviction counters reported through `stats()`.
    stats: CacheStats,
}
/// 32-byte SHA-256 fingerprint of raw key bytes, used as the lookup key of the cache.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct KeyHash([u8; 32]);

impl KeyHash {
    /// Hashes `data` with SHA-256 and wraps the 32-byte digest.
    pub fn from_bytes(data: &[u8]) -> Self {
        let sha = digest::digest(&digest::SHA256, data);
        // A SHA-256 digest is always exactly 32 bytes, so this conversion cannot fail.
        let bytes: [u8; 32] = sha
            .as_ref()
            .try_into()
            .expect("SHA-256 digest must be 32 bytes");
        Self(bytes)
    }
}
/// One cache entry: an owned copy of the key bytes plus bookkeeping metadata.
///
/// NOTE(review): `data` may hold secret key material (e.g. `KeyType::MlKemSecret`). A plain
/// `Vec<u8>` is not zeroized on drop, and every clone makes another copy. Consider a
/// zeroizing buffer if that matters here.
#[derive(Clone)]
pub struct CachedKey {
    /// Raw key bytes as passed to the cache on insertion.
    pub data: Vec<u8>,
    /// Moment the entry was (re)inserted; used for age-based expiry.
    pub cached_at: Instant,
    /// Number of successful lookups of this entry since it was (re)inserted.
    pub access_count: usize,
    /// Category of key this entry holds.
    pub key_type: KeyType,
}

/// Category of key material stored in a [`CachedKey`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KeyType {
    /// ML-KEM public key.
    MlKemPublic,
    /// ML-KEM secret key.
    MlKemSecret,
    /// Transport key.
    Transport,
    /// Symmetric key.
    Symmetric,
}
/// Atomic usage counters for the key cache; `Default` starts every counter at zero.
#[derive(Default)]
struct CacheStats {
    /// Lookups that found an entry.
    hits: std::sync::atomic::AtomicUsize,
    /// Lookups that found nothing.
    misses: std::sync::atomic::AtomicUsize,
    /// Entries removed by the cache.
    evictions: std::sync::atomic::AtomicUsize,
}
impl KeyCache {
    /// Creates an empty cache that holds at most `capacity` keys.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` cannot be converted to the capacity type taken by
    /// `lru::LruCache::new`. For `lru` versions that take a `NonZeroUsize`, this means a
    /// `capacity` of zero.
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: RwLock::new(lru::LruCache::new(
                capacity
                    .try_into()
                    .expect("KeyCache capacity must be non-zero"),
            )),
            stats: CacheStats::default(),
        }
    }

    /// Stores a copy of `key_data` under its [`KeyHash`] and returns that hash.
    ///
    /// Inserting bytes that are already cached replaces the existing entry, which resets
    /// its `cached_at` timestamp and its `access_count` to zero. If inserting a new key
    /// forces the LRU cache to drop a different entry, that drop counts as an eviction.
    pub fn insert(&self, key_data: &[u8], key_type: KeyType) -> KeyHash {
        let key_hash = KeyHash::from_bytes(key_data);
        let cached_key = CachedKey {
            data: key_data.to_vec(),
            cached_at: Instant::now(),
            access_count: 0,
            key_type,
        };
        // `push` returns the displaced entry in two cases:
        // - `key_hash` was already present (an in-place replacement);
        // - a *different* entry was evicted because the cache was full.
        // Only the second case is an eviction. `put`, by contrast, reports only
        // replacements and cannot observe capacity evictions at all.
        let displaced = self.cache.write().push(key_hash, cached_key);
        if matches!(&displaced, Some((old_hash, _)) if *old_hash != key_hash) {
            self.stats
                .evictions
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        }
        key_hash
    }

    /// Looks up `key_hash` and returns a clone of the entry on a hit.
    ///
    /// A hit increments the entry's `access_count` (the returned clone already reflects the
    /// new count) and the `hits` counter. A miss increments the `misses` counter. This
    /// takes the write lock because `LruCache::get_mut` requires `&mut` access.
    pub fn get(&self, key_hash: &KeyHash) -> Option<CachedKey> {
        let mut cache = self.cache.write();
        match cache.get_mut(key_hash) {
            Some(entry) => {
                entry.access_count += 1;
                self.stats
                    .hits
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                Some(entry.clone())
            }
            None => {
                self.stats
                    .misses
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                None
            }
        }
    }

    /// Returns `true` if an entry for `key_hash` is cached.
    ///
    /// Uses only the read lock and does not touch the hit/miss counters.
    pub fn contains(&self, key_hash: &KeyHash) -> bool {
        self.cache.read().contains(key_hash)
    }

    /// Removes every entry whose age is strictly greater than `max_age`.
    ///
    /// Each removal counts as an eviction.
    pub fn cleanup_expired(&self, max_age: Duration) {
        let mut cache = self.cache.write();
        let now = Instant::now();
        // Collect the expired hashes first, because the cache cannot be mutated while it
        // is being iterated.
        let expired: Vec<KeyHash> = cache
            .iter()
            .filter(|(_, entry)| now.duration_since(entry.cached_at) > max_age)
            .map(|(hash, _)| *hash)
            .collect();
        for key_hash in &expired {
            cache.pop(key_hash);
        }
        self.stats
            .evictions
            .fetch_add(expired.len(), std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns a snapshot of the counters, the hit rate and the current number of entries.
    ///
    /// Each counter and the size are read separately, so under concurrent use the snapshot
    /// is not one consistent point in time. `hit_rate` is `0.0` before any lookup.
    pub fn stats(&self) -> CacheStatistics {
        let hits = self.stats.hits.load(std::sync::atomic::Ordering::Relaxed);
        let misses = self.stats.misses.load(std::sync::atomic::Ordering::Relaxed);
        let evictions = self.stats.evictions.load(std::sync::atomic::Ordering::Relaxed);
        let total_requests = hits + misses;
        let hit_rate = if total_requests > 0 {
            hits as f64 / total_requests as f64
        } else {
            0.0
        };
        let cache_size = self.cache.read().len();
        CacheStatistics {
            hits,
            misses,
            evictions,
            hit_rate,
            current_size: cache_size,
        }
    }

    /// Removes all entries. The statistics counters are left unchanged.
    pub fn clear(&self) {
        self.cache.write().clear();
    }

    /// Maximum number of entries the cache can hold.
    pub fn capacity(&self) -> usize {
        self.cache.read().cap().into()
    }

    /// Number of entries currently cached.
    pub fn len(&self) -> usize {
        self.cache.read().len()
    }

    /// Returns `true` if the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.cache.read().is_empty()
    }
}
/// Point-in-time copy of the cache's usage counters, as returned by `KeyCache::stats`.
#[derive(Clone, Debug)]
pub struct CacheStatistics {
    /// Lookups that found an entry.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Entries removed by the cache.
    pub evictions: usize,
    /// `hits / (hits + misses)` as a fraction in `[0, 1]`.
    pub hit_rate: f64,
    /// Number of entries cached when the snapshot was taken.
    pub current_size: usize,
}

impl std::fmt::Display for CacheStatistics {
    /// Writes a one-line summary, showing the hit rate as a percentage with two decimals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            hits,
            misses,
            evictions,
            hit_rate,
            current_size,
        } = self;
        let hit_pct = hit_rate * 100.0;
        write!(
            f,
            "Cache Stats: {} hits, {} misses, {:.2}% hit rate, {} evictions, {} items",
            hits, misses, hit_pct, evictions, current_size
        )
    }
}
/// Data prepared ahead of time for a key, tagged with the key's fingerprint.
pub struct PrecomputedKeyContext {
    /// SHA-256 based fingerprint of the key bytes the context was built from.
    pub key_hash: KeyHash,
    /// Precomputed values: 256 entries for ML-KEM contexts, empty for transport contexts.
    pub precomputed_values: Vec<u8>,
    /// Caller-supplied derivation info: `Some` for transport contexts, `None` for ML-KEM ones.
    pub derivation_context: Option<Vec<u8>>,
}

impl PrecomputedKeyContext {
    /// Number of entries in the table built by [`PrecomputedKeyContext::for_ml_kem`].
    const ML_KEM_TABLE_LEN: usize = 256;

    /// Builds a context for an ML-KEM public key.
    ///
    /// Only `key_hash` depends on `public_key`; the `precomputed_values` table is identical
    /// for every key. Entry `i` is `(i * 31) % 3329`, truncated to its low byte.
    ///
    /// NOTE(review): this looks like placeholder data. 3329 is the ML-KEM modulus q, but
    /// reduced values up to 3328 do not fit in a `u8`, so the `as u8` cast drops the high
    /// bits. A lossless table would need wider elements (e.g. `Vec<u16>`), which would change
    /// this public field's type. Confirm the intended contents before relying on them.
    pub fn for_ml_kem(public_key: &[u8]) -> Self {
        // Collecting from an exact-size iterator allocates exactly the table length,
        // instead of over-reserving.
        let precomputed_values = (0..Self::ML_KEM_TABLE_LEN)
            .map(|i| ((i * 31) % 3329) as u8)
            .collect();
        Self {
            key_hash: KeyHash::from_bytes(public_key),
            precomputed_values,
            derivation_context: None,
        }
    }

    /// Builds a context for a transport key and stores an owned copy of `derivation_info`.
    ///
    /// No values are precomputed for this context.
    pub fn for_transport(key_data: &[u8], derivation_info: &[u8]) -> Self {
        Self {
            key_hash: KeyHash::from_bytes(key_data),
            precomputed_values: Vec::new(),
            derivation_context: Some(derivation_info.to_vec()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a key through the cache and checks that each lookup bumps `access_count`.
    #[test]
    fn test_key_cache_basic() {
        let cache = KeyCache::new(10);
        let material = b"test key data";
        let handle = cache.insert(material, KeyType::MlKemPublic);

        let first = cache.get(&handle).expect("inserted key should be cached");
        assert_eq!(first.data, material);
        assert_eq!(first.key_type, KeyType::MlKemPublic);
        assert_eq!(first.access_count, 1);

        let second = cache.get(&handle).expect("inserted key should still be cached");
        assert_eq!(second.access_count, 2);
    }

    /// Equal inputs produce equal fingerprints; different inputs produce different ones.
    #[test]
    fn test_key_hash() {
        let same_a = KeyHash::from_bytes(b"test data");
        let same_b = KeyHash::from_bytes(b"test data");
        let other = KeyHash::from_bytes(b"different data");
        assert_eq!(same_a, same_b);
        assert_ne!(same_a, other);
    }

    /// One miss followed by one hit yields a 50% hit rate.
    #[test]
    fn test_cache_stats() {
        let cache = KeyCache::new(10);
        let material = b"test key";

        // Miss: look the key up before it has been inserted.
        assert!(cache.get(&KeyHash::from_bytes(material)).is_none());

        // Hit: insert it, then read it back.
        let handle = cache.insert(material, KeyType::Symmetric);
        assert!(cache.get(&handle).is_some());

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.hit_rate, 0.5);
    }

    /// Entries older than the expiry window are removed by `cleanup_expired`.
    #[test]
    fn test_cache_cleanup() {
        let cache = KeyCache::new(10);
        let handle = cache.insert(b"test key", KeyType::Transport);
        assert!(cache.contains(&handle));

        let window = Duration::from_nanos(1);
        cache.cleanup_expired(window);
        // Sleep so the entry is guaranteed to be older than the 1 ns window.
        std::thread::sleep(Duration::from_millis(1));
        cache.cleanup_expired(window);
        assert!(!cache.contains(&handle));
    }

    /// ML-KEM contexts carry a 256-entry table; transport contexts carry derivation info.
    #[test]
    fn test_precomputed_context() {
        let public_key = vec![1u8; 1184];
        let ml_kem = PrecomputedKeyContext::for_ml_kem(&public_key);
        assert_eq!(ml_kem.precomputed_values.len(), 256);
        assert!(ml_kem.derivation_context.is_none());

        let transport =
            PrecomputedKeyContext::for_transport(&[2u8; 32], b"transport key derivation");
        assert!(transport.derivation_context.is_some());
    }
}