use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Settings controlling a `QueryCache`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Master switch: when `false`, every lookup returns `None` and every insert is ignored.
    pub enabled: bool,
    /// Lifetime, in seconds, assigned to each entry at the moment it is inserted.
    pub ttl_seconds: u64,
    /// Entry limit, checked independently for each of the three caches before an insert.
    pub max_entries: usize,
}

impl Default for CacheConfig {
    /// Caching enabled, a seven-minute TTL, and room for 1000 entries per cache.
    fn default() -> Self {
        const SEVEN_MINUTES_IN_SECONDS: u64 = 7 * 60;
        const DEFAULT_CAPACITY: usize = 1_000;

        Self {
            enabled: true,
            ttl_seconds: SEVEN_MINUTES_IN_SECONDS,
            max_entries: DEFAULT_CAPACITY,
        }
    }
}
/// A cached payload together with the wall-clock time it was stored and its time-to-live.
///
/// The payload is stored behind an [`Arc`]. The struct itself carries no trait bounds;
/// impl blocks that need `T: Clone` (or anything else) state that requirement themselves.
#[derive(Debug, Serialize, Deserialize)]
struct CacheEntry<T> {
    /// Shared payload.
    data: Arc<T>,
    /// Wall-clock time at which the entry was created.
    created_at: SystemTime,
    /// How long after `created_at` the entry stays valid.
    ttl: Duration,
}
impl<T: Clone> CacheEntry<T> {
    /// Wraps `data` in an [`Arc`] and stamps it with the current wall-clock time.
    fn new(data: T, ttl: Duration) -> Self {
        let shared = Arc::new(data);
        let stamped_at = SystemTime::now();
        Self {
            data: shared,
            created_at: stamped_at,
            ttl,
        }
    }

    /// Returns `true` once the entry's age exceeds its TTL.
    ///
    /// If `created_at` lies in the future (the system clock moved backwards), the age is
    /// taken to be `Duration::MAX`. Such an entry therefore counts as expired, unless its
    /// TTL is itself `Duration::MAX`.
    fn is_expired(&self) -> bool {
        let age = match self.created_at.elapsed() {
            Ok(elapsed) => elapsed,
            Err(_) => Duration::MAX,
        };
        self.ttl < age
    }

    /// Borrows the shared payload without cloning it.
    fn data_arc(&self) -> &Arc<T> {
        &self.data
    }
}
/// Key for the `query_memory` cache: the query, its domain, an optional task type,
/// and the requested result limit. All four must match for a cache hit.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct QueryMemoryKey {
    pub query: String,
    pub domain: String,
    pub task_type: Option<String>,
    pub limit: usize,
}

impl QueryMemoryKey {
    /// Assembles a key from its parts, taking ownership of the strings.
    pub fn new(query: String, domain: String, task_type: Option<String>, limit: usize) -> Self {
        QueryMemoryKey { query, domain, task_type, limit }
    }
}
/// Key for the `analyze_patterns` cache.
///
/// `f32` implements neither `Eq` nor `Hash`, so the success-rate threshold is stored as
/// whole percentage points. Thresholds that round to the same percentage share a key.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct AnalyzePatternsKey {
    pub task_type: String,
    /// Threshold in whole percentage points, e.g. `0.53` is stored as `53`.
    pub min_success_rate: u32,
    pub limit: usize,
}

impl AnalyzePatternsKey {
    /// Builds a key, converting the fractional `min_success_rate` (e.g. `0.75`) to percent.
    ///
    /// The value is rounded to the nearest whole percent. NaN and negative rates saturate
    /// to `0`, following Rust's float-to-integer `as` semantics.
    pub fn new(task_type: String, min_success_rate: f32, limit: usize) -> Self {
        // Round instead of truncating: `0.53_f32 * 100.0` evaluates to 52.999996, which a
        // bare `as u32` would turn into 52, colliding with the key for 0.52.
        let percent = (min_success_rate * 100.0).round() as u32;
        Self {
            task_type,
            min_success_rate: percent,
            limit,
        }
    }
}
/// Key for the `execute_code` cache.
///
/// Only 64-bit fingerprints of the code and of the context input are kept, not the
/// originals. Two distinct inputs whose fingerprints happened to collide would
/// therefore share an entry.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct ExecuteCodeKey {
    /// `DefaultHasher` fingerprint of the source code.
    pub code_hash: u64,
    /// The execution context's task, stored verbatim.
    pub context_task: String,
    /// `DefaultHasher` fingerprint of the context input's `to_string()` rendering.
    pub context_input_hash: u64,
}

impl ExecuteCodeKey {
    /// Builds a key from `code` plus the task and input of `context`.
    pub fn new(code: &str, context: &super::ExecutionContext) -> Self {
        /// Hashes `value` with a freshly constructed `DefaultHasher`.
        fn fingerprint<V: Hash + ?Sized>(value: &V) -> u64 {
            let mut hasher = std::collections::hash_map::DefaultHasher::new();
            value.hash(&mut hasher);
            hasher.finish()
        }

        let code_hash = fingerprint(code);
        let rendered_input = context.input.to_string();
        let context_input_hash = fingerprint(rendered_input.as_str());

        Self {
            code_hash,
            context_task: context.task.clone(),
            context_input_hash,
        }
    }
}
/// In-memory TTL cache holding results for three operations: query memory,
/// analyze patterns, and execute code. Each operation's results live in a separate
/// map behind its own `RwLock`.
pub struct QueryCache {
/// Settings applied to all three maps.
config: CacheConfig,
/// Results keyed by `QueryMemoryKey`.
query_memory_cache: RwLock<HashMap<QueryMemoryKey, CacheEntry<serde_json::Value>>>,
/// Results keyed by `AnalyzePatternsKey`.
analyze_patterns_cache: RwLock<HashMap<AnalyzePatternsKey, CacheEntry<serde_json::Value>>>,
/// Results keyed by `ExecuteCodeKey`.
execute_code_cache: RwLock<HashMap<ExecuteCodeKey, CacheEntry<super::ExecutionResult>>>,
/// Lookup counters shared by all three maps.
/// NOTE(review): these counters could be `AtomicU64` instead of lock-guarded `u64`s.
hits: RwLock<u64>,
misses: RwLock<u64>,
}
impl Default for QueryCache {
fn default() -> Self {
Self::new()
}
}
impl QueryCache {
pub fn new() -> Self {
Self::with_config(CacheConfig::default())
}
pub fn with_config(config: CacheConfig) -> Self {
Self {
config,
query_memory_cache: RwLock::new(HashMap::new()),
analyze_patterns_cache: RwLock::new(HashMap::new()),
execute_code_cache: RwLock::new(HashMap::new()),
hits: RwLock::new(0),
misses: RwLock::new(0),
}
}
pub fn get_query_memory(&self, key: &QueryMemoryKey) -> Option<serde_json::Value> {
if !self.config.enabled {
return None;
}
let cache = self.query_memory_cache.read();
if let Some(entry) = cache.get(key) {
if !entry.is_expired() {
*self.hits.write() += 1;
return Some((**entry.data_arc()).clone());
}
}
*self.misses.write() += 1;
None
}
pub fn put_query_memory(&self, key: QueryMemoryKey, result: serde_json::Value) {
if !self.config.enabled {
return;
}
let mut cache = self.query_memory_cache.write();
self.evict_expired_entries(&mut cache);
if cache.len() >= self.config.max_entries {
self.evict_oldest(&mut cache);
}
let ttl = Duration::from_secs(self.config.ttl_seconds);
cache.insert(key, CacheEntry::new(result, ttl));
}
pub fn get_analyze_patterns(&self, key: &AnalyzePatternsKey) -> Option<serde_json::Value> {
if !self.config.enabled {
return None;
}
let cache = self.analyze_patterns_cache.read();
if let Some(entry) = cache.get(key) {
if !entry.is_expired() {
*self.hits.write() += 1;
return Some((**entry.data_arc()).clone());
}
}
*self.misses.write() += 1;
None
}
pub fn put_analyze_patterns(&self, key: AnalyzePatternsKey, result: serde_json::Value) {
if !self.config.enabled {
return;
}
let mut cache = self.analyze_patterns_cache.write();
self.evict_expired_entries(&mut cache);
if cache.len() >= self.config.max_entries {
self.evict_oldest(&mut cache);
}
let ttl = Duration::from_secs(self.config.ttl_seconds);
cache.insert(key, CacheEntry::new(result, ttl));
}
pub fn get_execute_code(&self, key: &ExecuteCodeKey) -> Option<super::ExecutionResult> {
if !self.config.enabled {
return None;
}
let cache = self.execute_code_cache.read();
if let Some(entry) = cache.get(key) {
if !entry.is_expired() {
*self.hits.write() += 1;
return Some((**entry.data_arc()).clone());
}
}
*self.misses.write() += 1;
None
}
pub fn put_execute_code(&self, key: ExecuteCodeKey, result: super::ExecutionResult) {
if !self.config.enabled {
return;
}
let mut cache = self.execute_code_cache.write();
self.evict_expired_entries(&mut cache);
if cache.len() >= self.config.max_entries {
self.evict_oldest(&mut cache);
}
let ttl = Duration::from_secs(self.config.ttl_seconds);
cache.insert(key, CacheEntry::new(result, ttl));
}
pub fn clear(&self) {
self.query_memory_cache.write().clear();
self.analyze_patterns_cache.write().clear();
self.execute_code_cache.write().clear();
}
pub fn stats(&self) -> CacheStats {
let query_memory = self.query_memory_cache.read();
let analyze_patterns = self.analyze_patterns_cache.read();
let execute_code = self.execute_code_cache.read();
let hits = *self.hits.read();
let misses = *self.misses.read();
let total = hits + misses;
let hit_rate = if total > 0 {
(hits as f64 / total as f64) * 100.0
} else {
0.0
};
CacheStats {
query_memory_entries: query_memory.len(),
analyze_patterns_entries: analyze_patterns.len(),
execute_code_entries: execute_code.len(),
total_entries: query_memory.len() + analyze_patterns.len() + execute_code.len(),
max_entries: self.config.max_entries,
enabled: self.config.enabled,
ttl_seconds: self.config.ttl_seconds,
hits,
misses,
hit_rate,
}
}
fn evict_expired_entries<T, U>(&self, cache: &mut HashMap<T, CacheEntry<U>>)
where
T: Eq + Hash + Clone,
U: Clone,
{
cache.retain(|_, entry| !entry.is_expired());
}
fn evict_oldest<T, U>(&self, cache: &mut HashMap<T, CacheEntry<U>>)
where
T: Eq + Hash + Clone,
U: Clone,
{
if cache.is_empty() {
return;
}
let mut oldest_key = None;
let mut oldest_time = SystemTime::now();
for (key, entry) in cache.iter() {
if entry.created_at < oldest_time {
oldest_time = entry.created_at;
oldest_key = Some(key.clone());
}
}
if let Some(key) = oldest_key {
cache.remove(&key);
}
}
}
/// Point-in-time snapshot of a `QueryCache`: per-map entry counts, the active
/// configuration, and the hit/miss counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
/// Entries currently held in the query-memory map. Expired entries that have not yet
/// been purged are included.
pub query_memory_entries: usize,
/// Entries currently held in the analyze-patterns map (may include unpurged expired entries).
pub analyze_patterns_entries: usize,
/// Entries currently held in the execute-code map (may include unpurged expired entries).
pub execute_code_entries: usize,
/// Sum of the three per-map counts above.
pub total_entries: usize,
/// Configured `max_entries`. It is compared against each map individually, not against
/// `total_entries`.
pub max_entries: usize,
/// Whether caching is enabled in the configuration.
pub enabled: bool,
/// Configured entry lifetime, in seconds.
pub ttl_seconds: u64,
/// Lookups, across all three maps, that returned an unexpired cached value.
pub hits: u64,
/// Lookups made while caching was enabled that found no unexpired entry.
pub misses: u64,
/// `hits / (hits + misses)` as a percentage (0.0–100.0); 0.0 before any lookup.
pub hit_rate: f64,
}
#[cfg(test)]
mod tests;