use crate::query::QueryResult;
use lru::LruCache;
use std::hash::Hash;
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
/// Cache key derived from a query string.
///
/// Two keys are equal exactly when their normalized query texts are equal.
/// Normalization trims surrounding whitespace and lowercases everything
/// *outside* quoted spans. Text inside `'…'`, `"…"` or `` `…` `` keeps its
/// case, because string literals and quoted identifiers are case-sensitive.
/// Interior whitespace is not collapsed.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct QueryCacheKey {
    // Normalized query text; the only input to equality and hashing.
    query: String,
}

impl QueryCacheKey {
    /// Builds a key from `query` after normalizing it (see the type docs).
    pub fn new(query: impl Into<String>) -> Self {
        let query: String = query.into();
        Self {
            query: Self::normalize(&query),
        }
    }

    /// Trims `query` and lowercases it, except inside quoted spans.
    ///
    /// Lowercasing literals would make e.g. `name = 'Alice'` and
    /// `name = 'alice'` collide and return one query's result for the other.
    /// A doubled quote (`'it''s'`) naturally closes and reopens the span.
    /// A backslash inside a span escapes the next character. If the scan ever
    /// misjudges a span boundary, the effect is only that more text keeps its
    /// case, which costs cache hits but never produces a wrong hit.
    fn normalize(query: &str) -> String {
        let trimmed = query.trim();
        let mut normalized = String::with_capacity(trimmed.len());
        // Quote character of the span we are currently inside, if any.
        let mut open_quote: Option<char> = None;
        // Set when the previous character inside a quoted span was a backslash.
        let mut escaped = false;

        for c in trimmed.chars() {
            match open_quote {
                Some(quote) => {
                    normalized.push(c);
                    if escaped {
                        escaped = false;
                    } else if c == '\\' {
                        escaped = true;
                    } else if c == quote {
                        open_quote = None;
                    }
                }
                None => {
                    if matches!(c, '\'' | '"' | '`') {
                        open_quote = Some(c);
                        normalized.push(c);
                    } else {
                        normalized.extend(c.to_lowercase());
                    }
                }
            }
        }
        normalized
    }
}
/// Bounded LRU cache of query results that also counts lookup hits and misses.
///
/// The stored entries and both counters sit behind a single mutex. Every
/// method therefore takes exactly one lock, and `stats` reports hits, misses
/// and entry count from the same instant.
pub struct QueryCache {
    // NOTE(review): `QueryCache` is not `Clone`, so nothing shares this `Arc`;
    // a plain `Mutex` would suffice.
    state: Arc<Mutex<QueryCacheState>>,
}

/// Everything guarded by the `QueryCache` lock.
struct QueryCacheState {
    // Stored results, bounded by the capacity given to `QueryCache::new`.
    entries: LruCache<QueryCacheKey, QueryResult>,
    // `get` calls that found an entry since creation or the last `clear`.
    hits: usize,
    // `get` calls that found nothing since creation or the last `clear`.
    misses: usize,
}

impl QueryCache {
    /// Capacity used when `new` is called with `0`.
    const DEFAULT_CAPACITY: usize = 100;

    /// Creates an empty cache holding at most `capacity` results.
    ///
    /// A `capacity` of `0` is not rejected: it silently falls back to 100
    /// entries.
    pub fn new(capacity: usize) -> Self {
        let capacity = NonZeroUsize::new(capacity)
            .or_else(|| NonZeroUsize::new(Self::DEFAULT_CAPACITY))
            .expect("DEFAULT_CAPACITY is non-zero");
        Self {
            state: Arc::new(Mutex::new(QueryCacheState {
                entries: LruCache::new(capacity),
                hits: 0,
                misses: 0,
            })),
        }
    }

    /// Acquires the single state lock.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned, i.e. another thread panicked while
    /// holding it.
    fn lock(&self) -> std::sync::MutexGuard<'_, QueryCacheState> {
        self.state.lock().expect("QueryCache lock poisoned")
    }

    /// Returns a clone of the result cached under `key`, or `None`.
    ///
    /// Every call increments either the hit or the miss counter.
    pub fn get(&self, key: &QueryCacheKey) -> Option<QueryResult> {
        let mut state = self.lock();
        let found = state.entries.get(key).cloned();
        if found.is_some() {
            state.hits += 1;
        } else {
            state.misses += 1;
        }
        found
    }

    /// Stores `result` under `key`, replacing any previous value for that key.
    ///
    /// Does not touch the hit/miss counters.
    pub fn put(&self, key: QueryCacheKey, result: QueryResult) {
        self.lock().entries.put(key, result);
    }

    /// Removes every entry and resets the hit and miss counters to zero.
    ///
    /// All of this happens under one lock acquisition.
    pub fn clear(&self) {
        let mut state = self.lock();
        state.entries.clear();
        state.hits = 0;
        state.misses = 0;
    }

    /// Number of results currently stored.
    pub fn len(&self) -> usize {
        self.lock().entries.len()
    }

    /// `true` when no results are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Snapshot of the counters and current entry count.
    ///
    /// `hit_rate` is a percentage (0–100) and is `0.0` when there have been
    /// no lookups.
    pub fn stats(&self) -> CacheStats {
        // Read everything under one lock so the numbers agree with each other.
        // Do not call `self.len()` here: it would try to re-lock the mutex.
        let (hits, misses, entries) = {
            let state = self.lock();
            (state.hits, state.misses, state.entries.len())
        };
        let total_requests = hits + misses;
        let hit_rate = if total_requests > 0 {
            (hits as f64 / total_requests as f64) * 100.0
        } else {
            0.0
        };
        CacheStats {
            hits,
            misses,
            total_requests,
            hit_rate,
            entries,
        }
    }
}
/// Counters describing cache usage, as produced by `QueryCache::stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct CacheStats {
    /// Lookups that found a cached result.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// `hits + misses` when filled in by `QueryCache::stats`.
    pub total_requests: usize,
    /// Share of lookups that were hits, as a percentage (0–100, not a 0–1
    /// ratio); `QueryCache::stats` reports 0.0 when there were no lookups.
    pub hit_rate: f64,
    /// Number of results stored when the snapshot was taken.
    pub entries: usize,
}

impl std::fmt::Display for CacheStats {
    /// One-line summary. `total_requests` is not printed, and `hit_rate` is
    /// shown with two decimals followed by `%`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let CacheStats {
            hits,
            misses,
            hit_rate,
            entries,
            ..
        } = self;
        write!(
            f,
            "Cache Stats: {hits} hits, {misses} misses, {hit_rate:.2}% hit rate, {entries} entries"
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a placeholder result; these tests only check presence in the cache.
    fn create_dummy_result() -> QueryResult {
        QueryResult::new_for_testing(Vec::new(), 0)
    }

    #[test]
    fn test_cache_key_normalization() {
        let key1 = QueryCacheKey::new("SELECT * FROM cube");
        let key2 = QueryCacheKey::new(" select * from cube ");
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_lookup_uses_normalized_key() {
        let cache = QueryCache::new(10);
        cache.put(QueryCacheKey::new("SELECT * FROM cube"), create_dummy_result());
        assert!(cache.get(&QueryCacheKey::new("  select * FROM CUBE ")).is_some());
    }

    #[test]
    fn test_cache_put_get() {
        let cache = QueryCache::new(10);
        let key = QueryCacheKey::new("SELECT * FROM cube");
        let result = create_dummy_result();
        cache.put(key.clone(), result.clone());
        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().row_count(), result.row_count());
    }

    #[test]
    fn test_put_same_key_replaces_entry() {
        let cache = QueryCache::new(10);
        let key = QueryCacheKey::new("query1");
        cache.put(key.clone(), create_dummy_result());
        cache.put(key.clone(), create_dummy_result());
        assert_eq!(cache.len(), 1);
        assert!(cache.get(&key).is_some());
    }

    #[test]
    fn test_cache_miss() {
        let cache = QueryCache::new(10);
        let key = QueryCacheKey::new("SELECT * FROM cube");
        let cached = cache.get(&key);
        assert!(cached.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = QueryCache::new(2);
        cache.put(QueryCacheKey::new("query1"), create_dummy_result());
        cache.put(QueryCacheKey::new("query2"), create_dummy_result());
        cache.put(QueryCacheKey::new("query3"), create_dummy_result());
        assert_eq!(cache.len(), 2);
        // The oldest, never-read entry is the one dropped.
        assert!(cache.get(&QueryCacheKey::new("query1")).is_none());
        assert!(cache.get(&QueryCacheKey::new("query2")).is_some());
        assert!(cache.get(&QueryCacheKey::new("query3")).is_some());
    }

    #[test]
    fn test_get_refreshes_recency() {
        let cache = QueryCache::new(2);
        cache.put(QueryCacheKey::new("query1"), create_dummy_result());
        cache.put(QueryCacheKey::new("query2"), create_dummy_result());
        // Reading query1 makes query2 the least recently used entry.
        assert!(cache.get(&QueryCacheKey::new("query1")).is_some());
        cache.put(QueryCacheKey::new("query3"), create_dummy_result());
        assert!(cache.get(&QueryCacheKey::new("query1")).is_some());
        assert!(cache.get(&QueryCacheKey::new("query2")).is_none());
    }

    #[test]
    fn test_zero_capacity_uses_default() {
        let cache = QueryCache::new(0);
        for i in 0..101 {
            cache.put(QueryCacheKey::new(format!("query{}", i)), create_dummy_result());
        }
        assert_eq!(cache.len(), 100);
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryCache::new(10);
        cache.put(QueryCacheKey::new("query1"), create_dummy_result());
        cache.put(QueryCacheKey::new("query2"), create_dummy_result());
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&QueryCacheKey::new("query1")).is_some());
        assert!(cache.get(&QueryCacheKey::new("missing")).is_none());

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.total_requests, 0);
    }

    #[test]
    fn test_cache_stats() {
        let cache = QueryCache::new(10);
        let key = QueryCacheKey::new("SELECT * FROM cube");
        cache.put(key.clone(), create_dummy_result());
        assert!(cache.get(&key).is_some());
        assert!(cache.get(&QueryCacheKey::new("nonexistent")).is_none());

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.hit_rate, 50.0);
        assert_eq!(stats.entries, 1);
        assert_eq!(
            stats.to_string(),
            "Cache Stats: 1 hits, 1 misses, 50.00% hit rate, 1 entries"
        );
    }

    #[test]
    fn test_stats_without_requests() {
        let stats = QueryCache::new(10).stats();
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.hit_rate, 0.0);
        assert_eq!(stats.entries, 0);
    }
}