use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use sha2::{Sha256, Digest};
use serde_json::Value;
use tracing::{debug, info, warn};
use crate::error::{Error, Result};
#[derive(Debug, Clone)]
pub struct CacheConfig {
pub max_entries: usize,
pub default_ttl: Duration,
pub cleanup_interval: Duration,
pub max_response_size: usize,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 1000,
default_ttl: Duration::from_secs(300), cleanup_interval: Duration::from_secs(60), max_response_size: 1024 * 1024, }
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheKey {
pub tool_name: String,
pub arguments_hash: String,
pub feature_flags_hash: String,
}
impl CacheKey {
pub fn new(tool_name: &str, arguments: &Value) -> Result<Self> {
let arguments_json = serde_json::to_string(arguments)
.map_err(|e| Error::Validation(format!("Failed to serialize arguments: {}", e)))?;
let mut hasher = Sha256::new();
hasher.update(arguments_json.as_bytes());
let arguments_hash = format!("{:x}", hasher.finalize());
let feature_flags = get_active_feature_flags();
let mut feature_hasher = Sha256::new();
feature_hasher.update(feature_flags.as_bytes());
let feature_flags_hash = format!("{:x}", feature_hasher.finalize());
Ok(Self {
tool_name: tool_name.to_string(),
arguments_hash,
feature_flags_hash,
})
}
pub fn to_string_key(&self) -> String {
format!("{}:{}:{}", self.tool_name, self.arguments_hash, self.feature_flags_hash)
}
}
#[derive(Debug, Clone)]
pub struct CachedResult {
pub response: Value,
pub cached_at: Instant,
pub ttl: Duration,
pub size_bytes: usize,
pub hit_count: u64,
pub tags: Vec<String>,
}
impl CachedResult {
pub fn is_expired(&self) -> bool {
self.cached_at.elapsed() > self.ttl
}
pub fn age(&self) -> Duration {
self.cached_at.elapsed()
}
pub fn record_hit(&mut self) {
self.hit_count += 1;
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatistics {
pub total_entries: usize,
pub total_hits: u64,
pub total_misses: u64,
pub total_size_bytes: usize,
pub hit_rate: f64,
pub average_entry_size: f64,
pub oldest_entry_age_seconds: f64,
pub newest_entry_age_seconds: f64,
pub cleanup_runs: u64,
pub evicted_entries: u64,
}
pub struct CommandCache {
cache: Arc<RwLock<HashMap<String, CachedResult>>>,
config: CacheConfig,
stats: Arc<RwLock<CacheStatistics>>,
access_order: Arc<RwLock<Vec<String>>>,
}
impl CommandCache {
pub fn new(config: CacheConfig) -> Self {
let cache = Self {
cache: Arc::new(RwLock::new(HashMap::new())),
config,
stats: Arc::new(RwLock::new(CacheStatistics {
total_entries: 0,
total_hits: 0,
total_misses: 0,
total_size_bytes: 0,
hit_rate: 0.0,
average_entry_size: 0.0,
oldest_entry_age_seconds: 0.0,
newest_entry_age_seconds: 0.0,
cleanup_runs: 0,
evicted_entries: 0,
})),
access_order: Arc::new(RwLock::new(Vec::new())),
};
cache.start_cleanup_task();
cache
}
pub async fn get(&self, key: &CacheKey) -> Option<Value> {
let string_key = key.to_string_key();
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
if let Some(cached_result) = cache.get_mut(&string_key) {
if !cached_result.is_expired() {
cached_result.record_hit();
stats.total_hits += 1;
self.update_access_order(&string_key).await;
debug!("Cache hit for {}", key.tool_name);
return Some(cached_result.response.clone());
} else {
let removed = cache.remove(&string_key);
if let Some(removed) = removed {
stats.total_size_bytes = stats.total_size_bytes.saturating_sub(removed.size_bytes);
stats.total_entries = stats.total_entries.saturating_sub(1);
}
debug!("Cache entry expired for {}", key.tool_name);
}
}
stats.total_misses += 1;
self.update_hit_rate(&mut stats).await;
debug!("Cache miss for {}", key.tool_name);
None
}
pub async fn put(&self, key: &CacheKey, response: Value, tags: Vec<String>) -> Result<()> {
let string_key = key.to_string_key();
let response_json = serde_json::to_string(&response)
.map_err(|e| Error::Validation(format!("Failed to serialize response: {}", e)))?;
let size_bytes = response_json.len();
if size_bytes > self.config.max_response_size {
warn!("Response too large to cache: {} bytes > {} bytes limit",
size_bytes, self.config.max_response_size);
return Ok(());
}
let cached_result = CachedResult {
response,
cached_at: Instant::now(),
ttl: self.config.default_ttl,
size_bytes,
hit_count: 0,
tags,
};
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
if cache.len() >= self.config.max_entries {
self.evict_lru_entries(&mut cache, &mut stats).await;
}
if let Some(old_entry) = cache.insert(string_key.clone(), cached_result) {
stats.total_size_bytes = stats.total_size_bytes.saturating_sub(old_entry.size_bytes);
} else {
stats.total_entries += 1;
}
stats.total_size_bytes += size_bytes;
self.update_access_order(&string_key).await;
debug!("Cached result for {} ({} bytes)", key.tool_name, size_bytes);
Ok(())
}
pub async fn put_with_ttl(&self, key: &CacheKey, response: Value, ttl: Duration, tags: Vec<String>) -> Result<()> {
let string_key = key.to_string_key();
let response_json = serde_json::to_string(&response)
.map_err(|e| Error::Validation(format!("Failed to serialize response: {}", e)))?;
let size_bytes = response_json.len();
if size_bytes > self.config.max_response_size {
warn!("Response too large to cache: {} bytes > {} bytes limit",
size_bytes, self.config.max_response_size);
return Ok(());
}
let cached_result = CachedResult {
response,
cached_at: Instant::now(),
ttl,
size_bytes,
hit_count: 0,
tags,
};
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
if cache.len() >= self.config.max_entries {
self.evict_lru_entries(&mut cache, &mut stats).await;
}
if let Some(old_entry) = cache.insert(string_key.clone(), cached_result) {
stats.total_size_bytes = stats.total_size_bytes.saturating_sub(old_entry.size_bytes);
} else {
stats.total_entries += 1;
}
stats.total_size_bytes += size_bytes;
self.update_access_order(&string_key).await;
debug!("Cached result for {} with custom TTL {:?} ({} bytes)", key.tool_name, ttl, size_bytes);
Ok(())
}
pub async fn invalidate_by_tag(&self, tag: &str) -> usize {
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
let mut access_order = self.access_order.write().await;
let mut to_remove = Vec::new();
for (key, cached_result) in cache.iter() {
if cached_result.tags.contains(&tag.to_string()) {
to_remove.push(key.clone());
}
}
let count = to_remove.len();
for key in to_remove {
if let Some(removed) = cache.remove(&key) {
stats.total_size_bytes = stats.total_size_bytes.saturating_sub(removed.size_bytes);
stats.total_entries = stats.total_entries.saturating_sub(1);
stats.evicted_entries += 1;
}
access_order.retain(|k| k != &key);
}
info!("Invalidated {} cache entries with tag '{}'", count, tag);
count
}
pub async fn clear(&self) {
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
let mut access_order = self.access_order.write().await;
let count = cache.len();
cache.clear();
access_order.clear();
stats.total_entries = 0;
stats.total_size_bytes = 0;
stats.evicted_entries += count as u64;
info!("Cleared {} cache entries", count);
}
pub async fn get_statistics(&self) -> CacheStatistics {
let cache = self.cache.read().await;
let mut stats = self.stats.write().await;
stats.total_entries = cache.len();
self.update_hit_rate(&mut stats).await;
stats.average_entry_size = if stats.total_entries > 0 {
stats.total_size_bytes as f64 / stats.total_entries as f64
} else {
0.0
};
let now = Instant::now();
let mut oldest_age = Duration::ZERO;
let mut newest_age = Duration::MAX;
for cached_result in cache.values() {
let age = now.duration_since(cached_result.cached_at);
if age > oldest_age {
oldest_age = age;
}
if age < newest_age {
newest_age = age;
}
}
stats.oldest_entry_age_seconds = oldest_age.as_secs_f64();
stats.newest_entry_age_seconds = if newest_age == Duration::MAX {
0.0
} else {
newest_age.as_secs_f64()
};
stats.clone()
}
async fn update_access_order(&self, key: &str) {
let mut access_order = self.access_order.write().await;
access_order.retain(|k| k != key);
access_order.insert(0, key.to_string());
}
async fn evict_lru_entries(&self, cache: &mut HashMap<String, CachedResult>, stats: &mut CacheStatistics) {
let access_order = self.access_order.read().await;
let entries_to_evict = (cache.len() + 1).saturating_sub(self.config.max_entries);
let mut evicted = 0;
for key in access_order.iter().rev().take(entries_to_evict) {
if let Some(removed) = cache.remove(key) {
stats.total_size_bytes = stats.total_size_bytes.saturating_sub(removed.size_bytes);
stats.total_entries = stats.total_entries.saturating_sub(1);
stats.evicted_entries += 1;
evicted += 1;
}
}
debug!("Evicted {} LRU cache entries", evicted);
}
async fn update_hit_rate(&self, stats: &mut CacheStatistics) {
let total_requests = stats.total_hits + stats.total_misses;
stats.hit_rate = if total_requests > 0 {
stats.total_hits as f64 / total_requests as f64
} else {
0.0
};
}
fn start_cleanup_task(&self) {
let cache = self.cache.clone();
let stats = self.stats.clone();
let access_order = self.access_order.clone();
let cleanup_interval = self.config.cleanup_interval;
tokio::spawn(async move {
let mut interval = tokio::time::interval(cleanup_interval);
loop {
interval.tick().await;
let mut cache_guard = cache.write().await;
let mut stats_guard = stats.write().await;
let mut access_order_guard = access_order.write().await;
let mut expired_keys = Vec::new();
let now = Instant::now();
for (key, cached_result) in cache_guard.iter() {
if cached_result.is_expired() {
expired_keys.push(key.clone());
}
}
let expired_count = expired_keys.len();
for key in &expired_keys {
if let Some(removed) = cache_guard.remove(key) {
stats_guard.total_size_bytes = stats_guard.total_size_bytes.saturating_sub(removed.size_bytes);
stats_guard.total_entries = stats_guard.total_entries.saturating_sub(1);
}
access_order_guard.retain(|k| k != key);
}
stats_guard.cleanup_runs += 1;
if expired_count > 0 {
debug!("Cache cleanup removed {} expired entries", expired_count);
}
}
});
}
}
fn get_active_feature_flags() -> String {
let mut flags = Vec::new();
#[cfg(feature = "basic-debugging")]
flags.push("basic-debugging");
#[cfg(feature = "entity-inspection")]
flags.push("entity-inspection");
#[cfg(feature = "performance-profiling")]
flags.push("performance-profiling");
#[cfg(feature = "visual-debugging")]
flags.push("visual-debugging");
#[cfg(feature = "session-management")]
flags.push("session-management");
#[cfg(feature = "issue-detection")]
flags.push("issue-detection");
#[cfg(feature = "memory-profiling")]
flags.push("memory-profiling");
#[cfg(feature = "stress-testing")]
flags.push("stress-testing");
#[cfg(feature = "time-travel")]
flags.push("time-travel");
#[cfg(feature = "orchestration")]
flags.push("orchestration");
#[cfg(feature = "optimizations")]
flags.push("optimizations");
#[cfg(feature = "caching")]
flags.push("caching");
#[cfg(feature = "pooling")]
flags.push("pooling");
#[cfg(feature = "lazy-init")]
flags.push("lazy-init");
flags.join(",")
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[tokio::test]
async fn test_cache_basic_operations() {
let config = CacheConfig {
max_entries: 2,
default_ttl: Duration::from_secs(1),
cleanup_interval: Duration::from_secs(10),
max_response_size: 1024,
};
let cache = CommandCache::new(config);
let key1 = CacheKey::new("test_tool", &json!({"param": "value1"})).unwrap();
let key2 = CacheKey::new("test_tool", &json!({"param": "value2"})).unwrap();
let response1 = json!({"result": "data1"});
let response2 = json!({"result": "data2"});
assert!(cache.get(&key1).await.is_none());
cache.put(&key1, response1.clone(), vec!["tag1".to_string()]).await.unwrap();
assert_eq!(cache.get(&key1).await.unwrap(), response1);
cache.put(&key2, response2.clone(), vec!["tag2".to_string()]).await.unwrap();
assert_eq!(cache.get(&key2).await.unwrap(), response2);
}
#[tokio::test]
async fn test_cache_expiration() {
let config = CacheConfig {
max_entries: 10,
default_ttl: Duration::from_millis(50),
cleanup_interval: Duration::from_secs(10),
max_response_size: 1024,
};
let cache = CommandCache::new(config);
let key = CacheKey::new("test_tool", &json!({"param": "value"})).unwrap();
let response = json!({"result": "data"});
cache.put(&key, response.clone(), vec![]).await.unwrap();
assert_eq!(cache.get(&key).await.unwrap(), response);
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(cache.get(&key).await.is_none());
}
#[tokio::test]
async fn test_tag_invalidation() {
let config = CacheConfig::default();
let cache = CommandCache::new(config);
let key1 = CacheKey::new("tool1", &json!({"param": "value1"})).unwrap();
let key2 = CacheKey::new("tool2", &json!({"param": "value2"})).unwrap();
let response1 = json!({"result": "data1"});
let response2 = json!({"result": "data2"});
cache.put(&key1, response1.clone(), vec!["entity_data".to_string()]).await.unwrap();
cache.put(&key2, response2.clone(), vec!["performance_data".to_string()]).await.unwrap();
assert_eq!(cache.get(&key1).await.unwrap(), response1);
assert_eq!(cache.get(&key2).await.unwrap(), response2);
let invalidated = cache.invalidate_by_tag("entity_data").await;
assert_eq!(invalidated, 1);
assert!(cache.get(&key1).await.is_none());
assert_eq!(cache.get(&key2).await.unwrap(), response2);
}
}