use llm_shield_core::ScanResult;
use std::collections::{BTreeMap, HashMap, hash_map::DefaultHasher};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Tuning knobs for a `ResultCache`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Maximum number of stored entries. Adding a new key to a full cache evicts the
    /// least recently used entry; `0` disables caching entirely.
    pub max_size: usize,
    /// How long an entry stays valid after it was inserted; older entries are treated
    /// as misses when looked up.
    pub ttl: Duration,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_size: 10_000,
ttl: Duration::from_secs(300), }
}
}
/// Thread-safe, size-bounded cache of [`ScanResult`]s keyed by string, with a per-entry
/// time-to-live and least-recently-used eviction.
///
/// All state sits behind a single `Arc<RwLock<_>>`, so cloning a `ResultCache` yields
/// another handle to the same entries and statistics rather than an independent copy.
pub struct ResultCache {
    /// Shared, lock-protected cache state.
    inner: Arc<RwLock<CacheInner>>,
}
struct CacheInner {
config: CacheConfig,
entries: HashMap<String, CacheEntry>,
access_order: Vec<String>, stats: CacheStats,
}
struct CacheEntry {
result: ScanResult,
inserted_at: Instant,
}
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
pub hits: u64,
pub misses: u64,
}
impl CacheStats {
pub fn total_requests(&self) -> u64 {
self.hits + self.misses
}
pub fn hit_rate(&self) -> f64 {
let total = self.total_requests();
if total == 0 {
0.0
} else {
self.hits as f64 / total as f64
}
}
}
impl ResultCache {
pub fn new(config: CacheConfig) -> Self {
Self {
inner: Arc::new(RwLock::new(CacheInner {
config,
entries: HashMap::new(),
access_order: Vec::new(),
stats: CacheStats::default(),
})),
}
}
pub fn get(&self, key: &str) -> Option<ScanResult> {
let mut inner = self.inner.write().unwrap();
if let Some(entry) = inner.entries.get(key) {
if entry.inserted_at.elapsed() < inner.config.ttl {
let result = entry.result.clone();
inner.stats.hits += 1;
inner.access_order.retain(|k| k != key);
inner.access_order.push(key.to_string());
return Some(result);
} else {
inner.entries.remove(key);
inner.access_order.retain(|k| k != key);
}
}
inner.stats.misses += 1;
None
}
pub fn insert(&self, key: String, result: ScanResult) {
let mut inner = self.inner.write().unwrap();
if inner.config.max_size == 0 {
return;
}
if inner.entries.contains_key(&key) {
inner.access_order.retain(|k| k != &key);
} else if inner.entries.len() >= inner.config.max_size {
if let Some(oldest_key) = inner.access_order.first().cloned() {
inner.entries.remove(&oldest_key);
inner.access_order.remove(0);
}
}
inner.entries.insert(
key.clone(),
CacheEntry {
result,
inserted_at: Instant::now(),
},
);
inner.access_order.push(key);
}
pub fn clear(&self) {
let mut inner = self.inner.write().unwrap();
inner.entries.clear();
inner.access_order.clear();
}
pub fn len(&self) -> usize {
self.inner.read().unwrap().entries.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn stats(&self) -> CacheStats {
self.inner.read().unwrap().stats.clone()
}
pub fn reset_stats(&self) {
let mut inner = self.inner.write().unwrap();
inner.stats = CacheStats::default();
}
pub fn hash_key(input: &str) -> String {
let mut hasher = DefaultHasher::new();
input.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
}
impl Clone for ResultCache {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_result(text: &str) -> ScanResult {
        ScanResult::pass(text.to_string())
    }

    /// Builds a cache with the given capacity and a TTL long enough never to expire mid-test.
    fn cache_with_capacity(max_size: usize) -> ResultCache {
        ResultCache::new(CacheConfig {
            max_size,
            ttl: Duration::from_secs(60),
        })
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_size, 10_000);
        assert_eq!(config.ttl, Duration::from_secs(300));
    }

    #[test]
    fn test_cache_stats_empty() {
        let stats = CacheStats::default();
        assert_eq!(stats.total_requests(), 0);
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_stats_calculation() {
        let stats = CacheStats {
            hits: 7,
            misses: 3,
        };
        assert_eq!(stats.total_requests(), 10);
        assert!((stats.hit_rate() - 0.7).abs() < 0.001);
    }

    #[test]
    fn test_basic_insert_get() {
        let cache = ResultCache::new(CacheConfig {
            max_size: 10,
            ttl: Duration::from_secs(60),
        });
        let result = create_test_result("test");
        cache.insert("key1".to_string(), result.clone());
        assert_eq!(cache.get("key1"), Some(result));
    }

    #[test]
    fn test_cache_miss() {
        let cache = ResultCache::new(CacheConfig {
            max_size: 10,
            ttl: Duration::from_secs(60),
        });
        assert_eq!(cache.get("nonexistent"), None);
    }

    #[test]
    fn test_is_empty() {
        let cache = ResultCache::new(CacheConfig::default());
        assert!(cache.is_empty());
        cache.insert("key".to_string(), create_test_result("test"));
        assert!(!cache.is_empty());
    }

    #[test]
    fn test_hash_key_deterministic() {
        let key1 = ResultCache::hash_key("test input");
        let key2 = ResultCache::hash_key("test input");
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_hash_key_different_inputs() {
        let key1 = ResultCache::hash_key("input1");
        let key2 = ResultCache::hash_key("input2");
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_lru_eviction_drops_least_recently_used() {
        let cache = cache_with_capacity(2);
        cache.insert("a".to_string(), create_test_result("a"));
        cache.insert("b".to_string(), create_test_result("b"));
        // Touch "a" so that "b" becomes the least recently used entry.
        assert!(cache.get("a").is_some());
        cache.insert("c".to_string(), create_test_result("c"));
        assert_eq!(cache.len(), 2);
        assert!(cache.get("b").is_none());
        assert!(cache.get("a").is_some());
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn test_reinsert_existing_key_does_not_evict() {
        let cache = cache_with_capacity(2);
        cache.insert("a".to_string(), create_test_result("a1"));
        cache.insert("b".to_string(), create_test_result("b"));
        let replacement = create_test_result("a2");
        cache.insert("a".to_string(), replacement.clone());
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get("a"), Some(replacement));
        assert!(cache.get("b").is_some());
    }

    #[test]
    fn test_size_stays_bounded() {
        let cache = cache_with_capacity(3);
        for i in 0..10 {
            cache.insert(format!("k{}", i), create_test_result("v"));
        }
        assert_eq!(cache.len(), 3);
        assert!(cache.get("k9").is_some());
        assert!(cache.get("k7").is_some());
        assert!(cache.get("k6").is_none());
    }

    #[test]
    fn test_expired_entry_is_removed_and_counted_as_miss() {
        // A zero TTL means `elapsed() < ttl` never holds, so every entry is already expired.
        let cache = ResultCache::new(CacheConfig {
            max_size: 10,
            ttl: Duration::from_secs(0),
        });
        cache.insert("key".to_string(), create_test_result("test"));
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("key"), None);
        assert!(cache.is_empty());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_zero_max_size_disables_caching() {
        let cache = cache_with_capacity(0);
        cache.insert("key".to_string(), create_test_result("test"));
        assert!(cache.is_empty());
        assert_eq!(cache.get("key"), None);
    }

    #[test]
    fn test_stats_track_hits_and_misses() {
        let cache = cache_with_capacity(10);
        cache.insert("key".to_string(), create_test_result("test"));
        assert!(cache.get("key").is_some());
        assert!(cache.get("missing").is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        cache.reset_stats();
        assert_eq!(cache.stats().total_requests(), 0);
    }

    #[test]
    fn test_clear_keeps_stats() {
        let cache = cache_with_capacity(10);
        cache.insert("key".to_string(), create_test_result("test"));
        assert!(cache.get("key").is_some());
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_clones_share_state() {
        let cache = cache_with_capacity(10);
        let handle = cache.clone();
        handle.insert("key".to_string(), create_test_result("test"));
        assert!(cache.get("key").is_some());
        assert_eq!(handle.stats().hits, 1);
    }
}