mod entry;
mod pattern;
mod persistence;
mod stats;
pub use pattern::QueryPattern;
pub use stats::CacheStats;
use bytes::Bytes;
use entry::CacheEntry;
use metrics::{counter, gauge};
use persistence::PersistentCacheState;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// Bounded, TTL-aware cache of SQL SELECT results keyed by `(sql, etag)`,
/// together with per-query execution history used to suggest TTLs.
///
/// Each piece of shared state sits behind its own async `RwLock`; methods
/// that take more than one lock acquire them in the order
/// `cache` → `query_patterns` → `stats`.
pub struct SelectResultCache {
    /// Maximum number of results kept in `cache`.
    max_entries: usize,
    /// Maximum summed `size_bytes` of the results kept in `cache`.
    max_memory_bytes: u64,
    /// Cached results, keyed by a hash of the SQL text and the object ETag.
    cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// Hit/miss/eviction/expiration counters and current occupancy.
    stats: Arc<RwLock<CacheStats>>,
    /// Execution history keyed by `bucket:key:sql`.
    query_patterns: Arc<RwLock<HashMap<String, QueryPattern>>>,
}
impl SelectResultCache {
pub fn new(max_entries: usize, max_memory_bytes: u64) -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(CacheStats {
max_entries,
max_memory_bytes,
..Default::default()
})),
query_patterns: Arc::new(RwLock::new(HashMap::new())),
max_entries,
max_memory_bytes,
}
}
fn cache_key(sql: &str, etag: &str) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
sql.hash(&mut hasher);
etag.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
pub async fn get(&self, sql: &str, etag: &str) -> Option<Bytes> {
let key = Self::cache_key(sql, etag);
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
stats.gets += 1;
if let Some(entry) = cache.get_mut(&key) {
if entry.is_expired() {
cache.remove(&key);
stats.expirations += 1;
stats.misses += 1;
stats.current_entries = cache.len();
counter!("select_cache_expirations").increment(1);
counter!("select_cache_misses").increment(1);
gauge!("select_cache_entries").set(cache.len() as f64);
crate::metrics::record_cache_operation("select_cache", false);
return None;
}
if entry.etag != etag {
cache.remove(&key);
stats.misses += 1;
stats.current_entries = cache.len();
counter!("select_cache_misses").increment(1);
gauge!("select_cache_entries").set(cache.len() as f64);
crate::metrics::record_cache_operation("select_cache", false);
return None;
}
entry.touch();
stats.hits += 1;
counter!("select_cache_hits").increment(1);
crate::metrics::record_cache_operation("select_cache", true);
Some(entry.result.clone())
} else {
stats.misses += 1;
counter!("select_cache_misses").increment(1);
crate::metrics::record_cache_operation("select_cache", false);
None
}
}
pub async fn put(&self, sql: &str, etag: &str, result: Bytes, ttl_seconds: u64) {
let key = Self::cache_key(sql, etag);
let size = result.len();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
let entry = CacheEntry {
result,
etag: etag.to_string(),
created_at: now,
ttl_seconds,
last_accessed: now,
size_bytes: size,
};
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
while cache.len() >= self.max_entries
|| stats.memory_bytes + size as u64 > self.max_memory_bytes
{
if let Some(lru_key) = self.find_lru_entry(&cache) {
if let Some(evicted) = cache.remove(&lru_key) {
stats.memory_bytes =
stats.memory_bytes.saturating_sub(evicted.size_bytes as u64);
stats.evictions += 1;
counter!("select_cache_evictions").increment(1);
}
} else {
break; }
}
cache.insert(key, entry);
stats.memory_bytes += size as u64;
stats.current_entries = cache.len();
gauge!("select_cache_entries").set(cache.len() as f64);
gauge!("select_cache_memory_bytes").set(stats.memory_bytes as f64);
}
fn find_lru_entry(&self, cache: &HashMap<String, CacheEntry>) -> Option<String> {
cache
.iter()
.min_by_key(|(_, entry)| entry.last_accessed)
.map(|(key, _)| key.clone())
}
pub async fn invalidate(&self, sql: &str, etag: &str) {
let key = Self::cache_key(sql, etag);
let mut cache = self.cache.write().await;
if let Some(entry) = cache.remove(&key) {
let mut stats = self.stats.write().await;
stats.memory_bytes = stats.memory_bytes.saturating_sub(entry.size_bytes as u64);
stats.current_entries = cache.len();
gauge!("select_cache_entries").set(cache.len() as f64);
gauge!("select_cache_memory_bytes").set(stats.memory_bytes as f64);
}
}
pub async fn invalidate_object(&self, etag: &str) {
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
let keys_to_remove: Vec<String> = cache
.iter()
.filter(|(_, entry)| entry.etag == etag)
.map(|(key, _)| key.clone())
.collect();
for key in keys_to_remove {
if let Some(entry) = cache.remove(&key) {
stats.memory_bytes = stats.memory_bytes.saturating_sub(entry.size_bytes as u64);
}
}
stats.current_entries = cache.len();
gauge!("select_cache_entries").set(cache.len() as f64);
gauge!("select_cache_memory_bytes").set(stats.memory_bytes as f64);
}
pub async fn clear(&self) {
let mut cache = self.cache.write().await;
cache.clear();
let mut stats = self.stats.write().await;
stats.memory_bytes = 0;
stats.current_entries = 0;
gauge!("select_cache_entries").set(0.0);
gauge!("select_cache_memory_bytes").set(0.0);
}
pub async fn stats(&self) -> CacheStats {
self.stats.read().await.clone()
}
pub async fn cleanup_expired(&self) {
let mut cache = self.cache.write().await;
let mut stats = self.stats.write().await;
let expired_keys: Vec<String> = cache
.iter()
.filter(|(_, entry)| entry.is_expired())
.map(|(key, _)| key.clone())
.collect();
let mut expired_count = 0;
for key in expired_keys {
if let Some(entry) = cache.remove(&key) {
stats.memory_bytes = stats.memory_bytes.saturating_sub(entry.size_bytes as u64);
stats.expirations += 1;
expired_count += 1;
}
}
stats.current_entries = cache.len();
if expired_count > 0 {
counter!("select_cache_expirations").increment(expired_count);
gauge!("select_cache_entries").set(cache.len() as f64);
gauge!("select_cache_memory_bytes").set(stats.memory_bytes as f64);
}
}
fn pattern_key(bucket: &str, key: &str, sql: &str) -> String {
format!("{}:{}:{}", bucket, key, sql)
}
pub async fn record_query_pattern(
&self,
bucket: &str,
key: &str,
sql: &str,
execution_time_ms: u64,
) {
let pattern_key = Self::pattern_key(bucket, key, sql);
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
let mut patterns = self.query_patterns.write().await;
patterns
.entry(pattern_key)
.and_modify(|pattern| {
pattern.execution_count += 1;
pattern.last_executed = now;
pattern.avg_execution_ms =
(pattern.avg_execution_ms * (pattern.execution_count - 1) + execution_time_ms)
/ pattern.execution_count;
})
.or_insert_with(|| QueryPattern {
sql: sql.to_string(),
bucket: bucket.to_string(),
key: key.to_string(),
execution_count: 1,
last_executed: now,
avg_execution_ms: execution_time_ms,
});
}
pub async fn get_top_queries(&self, limit: usize) -> Vec<QueryPattern> {
let patterns = self.query_patterns.read().await;
let mut pattern_list: Vec<QueryPattern> = patterns.values().cloned().collect();
pattern_list.sort_by_key(|b| std::cmp::Reverse(b.execution_count));
pattern_list.truncate(limit);
pattern_list
}
pub async fn get_recent_queries(&self, within_seconds: i64) -> Vec<QueryPattern> {
let patterns = self.query_patterns.read().await;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
let cutoff = now - within_seconds;
patterns
.values()
.filter(|p| p.last_executed >= cutoff)
.cloned()
.collect()
}
pub async fn warm(
&self,
bucket: &str,
key: &str,
sql: &str,
etag: &str,
result: Bytes,
ttl_seconds: u64,
) {
self.put(sql, etag, result, ttl_seconds).await;
counter!("select_cache_warmings").increment(1);
tracing::info!(
bucket = %bucket,
key = %key,
sql = %sql,
"Cache warming completed"
);
}
pub async fn pattern_stats(&self) -> HashMap<String, serde_json::Value> {
let patterns = self.query_patterns.read().await;
let total_patterns = patterns.len();
let total_executions: u64 = patterns.values().map(|p| p.execution_count).sum();
let avg_executions = if total_patterns > 0 {
total_executions as f64 / total_patterns as f64
} else {
0.0
};
let mut stats = HashMap::new();
stats.insert(
"total_patterns".to_string(),
serde_json::json!(total_patterns),
);
stats.insert(
"total_executions".to_string(),
serde_json::json!(total_executions),
);
stats.insert(
"avg_executions_per_pattern".to_string(),
serde_json::json!(avg_executions),
);
stats
}
pub async fn clear_patterns(&self) {
let mut patterns = self.query_patterns.write().await;
patterns.clear();
}
pub async fn save_to_file(
&self,
path: &std::path::Path,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let cache = self.cache.read().await;
let patterns = self.query_patterns.read().await;
let stats = self.stats.read().await;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
let state = PersistentCacheState {
cache: cache.clone(),
query_patterns: patterns.clone(),
stats: stats.clone(),
version: 1,
saved_at: now,
};
let json = serde_json::to_string_pretty(&state)?;
let temp_path = path.with_extension("tmp");
tokio::fs::write(&temp_path, json).await?;
tokio::fs::rename(&temp_path, path).await?;
tracing::info!(
path = %path.display(),
entries = cache.len(),
patterns = patterns.len(),
"Cache state saved to disk"
);
counter!("select_cache_saves").increment(1);
Ok(())
}
pub async fn load_from_file(
&self,
path: &std::path::Path,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let json = tokio::fs::read_to_string(path).await?;
let state: PersistentCacheState = serde_json::from_str(&json)?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
let mut cache = self.cache.write().await;
let mut patterns = self.query_patterns.write().await;
let mut stats = self.stats.write().await;
cache.clear();
patterns.clear();
let mut loaded_entries = 0;
let mut expired_entries = 0;
let mut total_size = 0u64;
for (key, entry) in state.cache {
if !entry.is_expired() {
total_size += entry.size_bytes as u64;
cache.insert(key, entry);
loaded_entries += 1;
} else {
expired_entries += 1;
}
}
*patterns = state.query_patterns;
stats.gets = state.stats.gets;
stats.hits = state.stats.hits;
stats.misses = state.stats.misses;
stats.evictions = state.stats.evictions;
stats.expirations = state.stats.expirations + expired_entries;
stats.current_entries = cache.len();
stats.memory_bytes = total_size;
gauge!("select_cache_entries").set(cache.len() as f64);
gauge!("select_cache_memory_bytes").set(total_size as f64);
counter!("select_cache_loads").increment(1);
tracing::info!(
path = %path.display(),
loaded_entries,
expired_entries,
loaded_patterns = patterns.len(),
age_seconds = now - state.saved_at,
"Cache state loaded from disk"
);
Ok(())
}
pub fn start_background_save(
self: Arc<Self>,
path: std::path::PathBuf,
interval_seconds: u64,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let mut interval =
tokio::time::interval(std::time::Duration::from_secs(interval_seconds));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
match self.save_to_file(&path).await {
Ok(()) => {
tracing::debug!("Background cache save successful");
}
Err(e) => {
tracing::error!(
error = %e,
"Background cache save failed"
);
}
}
}
})
}
pub async fn calculate_adaptive_ttl(
&self,
bucket: &str,
key: &str,
sql: &str,
execution_time_ms: u64,
result_size_bytes: usize,
) -> u64 {
const MIN_TTL: u64 = 60; const MAX_TTL: u64 = 7200; const BASE_TTL: u64 = 600;
let pattern_key = Self::pattern_key(bucket, key, sql);
let patterns = self.query_patterns.read().await;
if let Some(pattern) = patterns.get(&pattern_key) {
let frequency_factor = if pattern.execution_count > 100 {
2.0
} else if pattern.execution_count > 50 {
1.5
} else if pattern.execution_count > 10 {
1.2
} else {
1.0
};
let time_factor = if execution_time_ms > 10000 {
2.0
} else if execution_time_ms > 5000 {
1.5
} else if execution_time_ms > 1000 {
1.2
} else {
1.0
};
let size_factor = if result_size_bytes < 10_000 {
1.3
} else if result_size_bytes < 100_000 {
1.1
} else if result_size_bytes < 1_000_000 {
1.0
} else {
0.8 };
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or(0);
let seconds_since_last = now - pattern.last_executed;
let recency_factor = if seconds_since_last < 300 {
1.5
} else if seconds_since_last < 3600 {
1.2
} else {
1.0
};
let calculated_ttl =
(BASE_TTL as f64 * frequency_factor * time_factor * size_factor * recency_factor)
as u64;
calculated_ttl.clamp(MIN_TTL, MAX_TTL)
} else {
let time_factor = if execution_time_ms > 5000 {
1.5
} else if execution_time_ms > 1000 {
1.2
} else {
1.0
};
let calculated_ttl = (BASE_TTL as f64 * time_factor) as u64;
calculated_ttl.clamp(MIN_TTL, MAX_TTL)
}
}
pub async fn get_recommended_ttl(&self, bucket: &str, key: &str, sql: &str) -> Option<u64> {
let pattern_key = Self::pattern_key(bucket, key, sql);
let patterns = self.query_patterns.read().await;
patterns.get(&pattern_key).map(|pattern| {
self.calculate_ttl_from_pattern(pattern)
})
}
fn calculate_ttl_from_pattern(&self, pattern: &QueryPattern) -> u64 {
const MIN_TTL: u64 = 60;
const MAX_TTL: u64 = 7200;
const BASE_TTL: u64 = 600;
let frequency_factor = if pattern.execution_count > 100 {
2.0
} else if pattern.execution_count > 50 {
1.5
} else if pattern.execution_count > 10 {
1.2
} else {
1.0
};
let time_factor = if pattern.avg_execution_ms > 10000 {
2.0
} else if pattern.avg_execution_ms > 5000 {
1.5
} else if pattern.avg_execution_ms > 1000 {
1.2
} else {
1.0
};
let calculated_ttl = (BASE_TTL as f64 * frequency_factor * time_factor) as u64;
calculated_ttl.clamp(MIN_TTL, MAX_TTL)
}
}
#[cfg(test)]
mod tests {
    // Behavioural tests for `SelectResultCache`: lookups, eviction, expiry,
    // query-pattern tracking, persistence and adaptive TTLs.
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_cache_creation() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let stats = cache.stats().await;
        assert_eq!(stats.max_entries, 100);
        assert_eq!(stats.max_memory_bytes, 1024 * 1024);
        assert_eq!(stats.current_entries, 0);
        assert_eq!(stats.memory_bytes, 0);
    }

    #[tokio::test]
    async fn test_cache_put_and_get() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let sql = "SELECT * FROM s3object WHERE age > 25";
        let etag = "etag123";
        let result = Bytes::from("test result data");
        cache.put(sql, etag, result.clone(), 0).await;
        assert_eq!(cache.get(sql, etag).await, Some(result));
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.current_entries, 1);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result = cache.get("SELECT * FROM s3object", "etag123").await;
        assert!(result.is_none());
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    #[tokio::test]
    async fn test_etag_invalidation() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let sql = "SELECT * FROM s3object";
        let result = Bytes::from("test result");
        cache.put(sql, "etag1", result.clone(), 0).await;
        let cached = cache.get(sql, "etag2").await;
        assert!(cached.is_none());
        let cached = cache.get(sql, "etag1").await;
        assert!(cached.is_some());
    }

    #[tokio::test]
    async fn test_lru_eviction() {
        let cache = SelectResultCache::new(2, 1024 * 1024);
        let result = Bytes::from("test result");
        cache.put("query1", "etag1", result.clone(), 0).await;
        cache.put("query2", "etag2", result.clone(), 0).await;
        let stats = cache.stats().await;
        assert_eq!(stats.current_entries, 2);
        // `put` stamps `last_accessed` in whole seconds, so a short pause
        // cannot order query1/query2; the test only requires that exactly one
        // of them is evicted to make room for query3.
        cache.put("query3", "etag3", result.clone(), 0).await;
        let stats = cache.stats().await;
        assert_eq!(stats.evictions, 1);
        assert_eq!(stats.current_entries, 2);
        let cached1 = cache.get("query1", "etag1").await;
        let cached2 = cache.get("query2", "etag2").await;
        assert!(
            cached1.is_none() || cached2.is_none(),
            "One entry should be evicted"
        );
        let cached3 = cache.get("query3", "etag3").await;
        assert!(cached3.is_some());
    }

    #[tokio::test]
    async fn test_ttl_expiration() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let sql = "SELECT * FROM s3object";
        let etag = "etag123";
        let result = Bytes::from("test result");
        cache.put(sql, etag, result, 1).await;
        let cached = cache.get(sql, etag).await;
        assert!(cached.is_some());
        // TTLs have one-second granularity, so wait past a full second.
        tokio::time::sleep(Duration::from_secs(2)).await;
        let cached = cache.get(sql, etag).await;
        assert!(cached.is_none());
        let stats = cache.stats().await;
        assert_eq!(stats.expirations, 1);
    }

    #[tokio::test]
    async fn test_invalidate_object() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let etag = "etag123";
        let result = Bytes::from("test result");
        cache.put("query1", etag, result.clone(), 0).await;
        cache.put("query2", etag, result.clone(), 0).await;
        cache.put("query3", "other_etag", result.clone(), 0).await;
        cache.invalidate_object(etag).await;
        assert!(cache.get("query1", etag).await.is_none());
        assert!(cache.get("query2", etag).await.is_none());
        assert!(cache.get("query3", "other_etag").await.is_some());
        let stats = cache.stats().await;
        assert_eq!(stats.current_entries, 1);
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result = Bytes::from("test result");
        cache.put("query1", "etag1", result.clone(), 0).await;
        cache.put("query2", "etag2", result.clone(), 0).await;
        let stats = cache.stats().await;
        assert_eq!(stats.current_entries, 2);
        cache.clear().await;
        let stats = cache.stats().await;
        assert_eq!(stats.current_entries, 0);
        assert_eq!(stats.memory_bytes, 0);
    }

    #[tokio::test]
    async fn test_hit_rate_calculation() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result = Bytes::from("test");
        cache.put("query", "etag", result, 0).await;
        cache.get("query", "etag").await;
        cache.get("query", "etag").await;
        cache.get("other", "etag").await;
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 66.666).abs() < 0.1);
    }

    #[tokio::test]
    async fn test_query_pattern_recording() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        cache
            .record_query_pattern("bucket1", "key1", "SELECT * FROM s3object", 100)
            .await;
        cache
            .record_query_pattern("bucket1", "key1", "SELECT * FROM s3object", 150)
            .await;
        cache
            .record_query_pattern("bucket1", "key1", "SELECT * FROM s3object", 120)
            .await;
        let top_queries = cache.get_top_queries(10).await;
        assert_eq!(top_queries.len(), 1);
        let pattern = &top_queries[0];
        assert_eq!(pattern.execution_count, 3);
        assert_eq!(pattern.bucket, "bucket1");
        assert_eq!(pattern.key, "key1");
        assert_eq!(pattern.sql, "SELECT * FROM s3object");
        // Integer running mean: (100 + 150) / 2 = 125, then (125 * 2 + 120) / 3 = 123.
        assert_eq!(pattern.avg_execution_ms, 123);
    }

    #[tokio::test]
    async fn test_get_top_queries() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        for _ in 0..5 {
            cache
                .record_query_pattern("b1", "k1", "SELECT * FROM s3object", 100)
                .await;
        }
        for _ in 0..3 {
            cache
                .record_query_pattern("b1", "k2", "SELECT id FROM s3object", 100)
                .await;
        }
        for _ in 0..7 {
            cache
                .record_query_pattern("b1", "k3", "SELECT name FROM s3object", 100)
                .await;
        }
        let top_queries = cache.get_top_queries(2).await;
        assert_eq!(top_queries.len(), 2);
        assert_eq!(top_queries[0].execution_count, 7);
        assert_eq!(top_queries[0].sql, "SELECT name FROM s3object");
        assert_eq!(top_queries[1].execution_count, 5);
        assert_eq!(top_queries[1].sql, "SELECT * FROM s3object");
    }

    #[tokio::test]
    async fn test_get_recent_queries() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        cache
            .record_query_pattern("bucket1", "key1", "SELECT * FROM s3object", 100)
            .await;
        let recent = cache.get_recent_queries(3600).await;
        assert_eq!(recent.len(), 1);
        let recent_future = cache.get_recent_queries(-1).await;
        assert_eq!(recent_future.len(), 0);
    }

    #[tokio::test]
    async fn test_cache_warming() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result = Bytes::from("warmed result");
        cache
            .warm(
                "bucket1",
                "key1",
                "SELECT * FROM s3object",
                "etag123",
                result.clone(),
                3600,
            )
            .await;
        let cached = cache.get("SELECT * FROM s3object", "etag123").await;
        assert_eq!(cached, Some(result));
        let stats = cache.stats().await;
        assert_eq!(stats.current_entries, 1);
    }

    #[tokio::test]
    async fn test_pattern_stats() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        cache.record_query_pattern("b1", "k1", "query1", 100).await;
        cache.record_query_pattern("b1", "k1", "query1", 100).await;
        cache.record_query_pattern("b1", "k2", "query2", 100).await;
        let stats = cache.pattern_stats().await;
        assert_eq!(
            stats.get("total_patterns").and_then(|v| v.as_u64()),
            Some(2)
        );
        assert_eq!(
            stats.get("total_executions").and_then(|v| v.as_u64()),
            Some(3)
        );
    }

    #[tokio::test]
    async fn test_clear_patterns() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        cache.record_query_pattern("b1", "k1", "query1", 100).await;
        cache.record_query_pattern("b1", "k2", "query2", 100).await;
        let stats = cache.pattern_stats().await;
        assert_eq!(
            stats.get("total_patterns").and_then(|v| v.as_u64()),
            Some(2)
        );
        cache.clear_patterns().await;
        let stats = cache.pattern_stats().await;
        assert_eq!(
            stats.get("total_patterns").and_then(|v| v.as_u64()),
            Some(0)
        );
    }

    #[tokio::test]
    async fn test_cache_persistence_save_and_load() {
        let temp_dir = std::env::temp_dir();
        let cache_file = temp_dir.join(format!("test_cache_{}.json", uuid::Uuid::new_v4()));
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result1 = Bytes::from("test result 1");
        let result2 = Bytes::from("test result 2");
        cache.put("query1", "etag1", result1.clone(), 0).await;
        cache.put("query2", "etag2", result2.clone(), 0).await;
        cache
            .record_query_pattern("bucket1", "key1", "query1", 100)
            .await;
        let save_result = cache.save_to_file(&cache_file).await;
        assert!(save_result.is_ok(), "Save failed: {:?}", save_result.err());
        assert!(cache_file.exists(), "Cache file was not created");
        let cache2 = SelectResultCache::new(100, 1024 * 1024);
        let load_result = cache2.load_from_file(&cache_file).await;
        assert!(load_result.is_ok(), "Load failed: {:?}", load_result.err());
        assert_eq!(cache2.get("query1", "etag1").await, Some(result1));
        assert_eq!(cache2.get("query2", "etag2").await, Some(result2));
        let patterns = cache2.get_top_queries(10).await;
        assert_eq!(patterns.len(), 1);
        assert_eq!(patterns[0].sql, "query1");
        let _ = tokio::fs::remove_file(&cache_file).await;
    }

    #[tokio::test]
    async fn test_cache_persistence_expired_entries_filtered() {
        let temp_dir = std::env::temp_dir();
        let cache_file =
            temp_dir.join(format!("test_cache_expired_{}.json", uuid::Uuid::new_v4()));
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result1 = Bytes::from("result with no expiry");
        let result2 = Bytes::from("result with short ttl");
        cache.put("query1", "etag1", result1.clone(), 0).await;
        cache.put("query2", "etag2", result2.clone(), 1).await;
        // Let query2's one-second TTL lapse before snapshotting.
        tokio::time::sleep(Duration::from_secs(2)).await;
        cache.save_to_file(&cache_file).await.expect("Save failed");
        let cache2 = SelectResultCache::new(100, 1024 * 1024);
        cache2
            .load_from_file(&cache_file)
            .await
            .expect("Load failed");
        let loaded1 = cache2.get("query1", "etag1").await;
        assert!(loaded1.is_some(), "Non-expired entry should be loaded");
        let loaded2 = cache2.get("query2", "etag2").await;
        assert!(loaded2.is_none(), "Expired entry should not be loaded");
        let stats = cache2.stats().await;
        assert_eq!(stats.current_entries, 1, "Should only have 1 entry");
        assert!(stats.expirations > 0, "Should count expired entries");
        let _ = tokio::fs::remove_file(&cache_file).await;
    }

    #[tokio::test]
    async fn test_cache_persistence_atomic_write() {
        let temp_dir = std::env::temp_dir();
        let cache_file =
            temp_dir.join(format!("test_cache_atomic_{}.json", uuid::Uuid::new_v4()));
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result = Bytes::from("test data");
        cache.put("query", "etag", result, 0).await;
        cache.save_to_file(&cache_file).await.expect("Save failed");
        let temp_file = cache_file.with_extension("tmp");
        assert!(
            !temp_file.exists(),
            "Temporary file should be removed after save"
        );
        assert!(cache_file.exists(), "Final cache file should exist");
        let _ = tokio::fs::remove_file(&cache_file).await;
    }

    #[tokio::test]
    async fn test_cache_persistence_stats_preserved() {
        let temp_dir = std::env::temp_dir();
        let cache_file =
            temp_dir.join(format!("test_cache_stats_{}.json", uuid::Uuid::new_v4()));
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let result = Bytes::from("test");
        cache.put("query", "etag", result, 0).await;
        cache.get("query", "etag").await;
        cache.get("other", "etag").await;
        let original_stats = cache.stats().await;
        assert_eq!(original_stats.hits, 1);
        assert_eq!(original_stats.misses, 1);
        cache.save_to_file(&cache_file).await.expect("Save failed");
        let cache2 = SelectResultCache::new(100, 1024 * 1024);
        cache2
            .load_from_file(&cache_file)
            .await
            .expect("Load failed");
        let loaded_stats = cache2.stats().await;
        assert_eq!(loaded_stats.hits, original_stats.hits);
        assert_eq!(loaded_stats.misses, original_stats.misses);
        assert_eq!(loaded_stats.current_entries, original_stats.current_entries);
        let _ = tokio::fs::remove_file(&cache_file).await;
    }

    #[tokio::test]
    async fn test_background_save_starts() {
        let temp_dir = std::env::temp_dir();
        let cache_file = temp_dir.join(format!("test_cache_bg_{}.json", uuid::Uuid::new_v4()));
        let cache = Arc::new(SelectResultCache::new(100, 1024 * 1024));
        cache.put("query", "etag", Bytes::from("test"), 0).await;
        let handle = cache.clone().start_background_save(cache_file.clone(), 1);
        // tokio's `interval` completes its first tick immediately, so poll for
        // the file instead of sleeping a fixed amount; the deadline leaves
        // slack for slow machines.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
        while !cache_file.exists() && tokio::time::Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        assert!(cache_file.exists(), "Background save should create file");
        handle.abort();
        let _ = tokio::fs::remove_file(&cache_file).await;
        let _ = tokio::fs::remove_file(cache_file.with_extension("tmp")).await;
    }

    #[tokio::test]
    async fn test_adaptive_ttl_no_pattern() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let ttl = cache
            .calculate_adaptive_ttl("bucket", "key", "SELECT * FROM s3object", 500, 5000)
            .await;
        assert!((60..=7200).contains(&ttl), "TTL should be within bounds");
        assert!(
            (600..=720).contains(&ttl),
            "TTL should be around base value for unknown pattern"
        );
    }

    #[tokio::test]
    async fn test_adaptive_ttl_expensive_query() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let ttl = cache
            .calculate_adaptive_ttl("bucket", "key", "SELECT * FROM s3object", 15000, 5000)
            .await;
        assert!(
            ttl > 600,
            "Expensive queries should get longer TTL (got {})",
            ttl
        );
        assert!(ttl <= 7200, "TTL should not exceed maximum");
    }

    #[tokio::test]
    async fn test_adaptive_ttl_small_result() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let (bucket, key, sql) = ("bucket", "key", "SELECT * FROM s3object");
        // Result size only affects the TTL once the query has history; with
        // no pattern the TTL is the same for every size.
        cache.record_query_pattern(bucket, key, sql, 500).await;
        let small = cache
            .calculate_adaptive_ttl(bucket, key, sql, 500, 8000)
            .await;
        let medium = cache
            .calculate_adaptive_ttl(bucket, key, sql, 500, 500_000)
            .await;
        assert!(
            small > medium,
            "Small results should get longer TTL (small {}, medium {})",
            small,
            medium
        );
    }

    #[tokio::test]
    async fn test_adaptive_ttl_with_pattern_history() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let bucket = "test-bucket";
        let key = "test-key";
        let sql = "SELECT * FROM s3object WHERE age > 25";
        for _ in 0..15 {
            cache.record_query_pattern(bucket, key, sql, 1200).await;
        }
        let ttl = cache
            .calculate_adaptive_ttl(bucket, key, sql, 1200, 10000)
            .await;
        assert!(
            ttl > 600,
            "Frequently executed queries should get longer TTL (got {})",
            ttl
        );
        assert!(
            ttl >= 720,
            "With 15 executions, TTL should be at least 1.2x base"
        );
    }

    #[tokio::test]
    async fn test_adaptive_ttl_high_frequency() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let bucket = "test-bucket";
        let key = "test-key";
        let sql = "SELECT * FROM s3object WHERE status = 'active'";
        for _ in 0..105 {
            cache.record_query_pattern(bucket, key, sql, 800).await;
        }
        let ttl = cache
            .calculate_adaptive_ttl(bucket, key, sql, 800, 5000)
            .await;
        assert!(
            ttl >= 1200,
            "Very frequent queries should get 2x+ TTL (got {})",
            ttl
        );
    }

    #[tokio::test]
    async fn test_adaptive_ttl_combined_factors() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let bucket = "test-bucket";
        let key = "test-key";
        let sql = "SELECT COUNT(*) FROM s3object GROUP BY category";
        for _ in 0..55 {
            cache.record_query_pattern(bucket, key, sql, 6000).await;
        }
        let ttl = cache
            .calculate_adaptive_ttl(bucket, key, sql, 6000, 3000)
            .await;
        assert!(
            ttl >= 1400,
            "Combined factors should significantly increase TTL (got {})",
            ttl
        );
        assert!(ttl <= 7200, "TTL should not exceed maximum");
    }

    #[tokio::test]
    async fn test_adaptive_ttl_large_result() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let (bucket, key, sql) = ("bucket", "key", "SELECT * FROM s3object");
        // As above, size is only considered for queries with history.
        cache.record_query_pattern(bucket, key, sql, 500).await;
        let large = cache
            .calculate_adaptive_ttl(bucket, key, sql, 500, 2_000_000)
            .await;
        let medium = cache
            .calculate_adaptive_ttl(bucket, key, sql, 500, 500_000)
            .await;
        assert!(large >= 60, "TTL should be at least minimum");
        assert!(
            large < medium,
            "Large results should get shorter TTL (large {}, medium {})",
            large,
            medium
        );
    }

    #[tokio::test]
    async fn test_recommended_ttl() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let bucket = "test-bucket";
        let key = "test-key";
        let sql = "SELECT * FROM s3object LIMIT 100";
        let ttl = cache.get_recommended_ttl(bucket, key, sql).await;
        assert!(ttl.is_none(), "Should return None when no pattern exists");
        for _ in 0..20 {
            cache.record_query_pattern(bucket, key, sql, 1500).await;
        }
        let ttl = cache.get_recommended_ttl(bucket, key, sql).await;
        assert!(ttl.is_some(), "Should return Some when pattern exists");
        let ttl_value = ttl.expect("TTL should be present");
        assert!(
            (60..=7200).contains(&ttl_value),
            "Recommended TTL should be within bounds (got {})",
            ttl_value
        );
    }

    #[tokio::test]
    async fn test_adaptive_ttl_bounds() {
        let cache = SelectResultCache::new(100, 1024 * 1024);
        let bucket = "test-bucket";
        let key = "test-key";
        let sql = "SELECT * FROM s3object";
        for _ in 0..150 {
            cache.record_query_pattern(bucket, key, sql, 20000).await;
        }
        let ttl = cache
            .calculate_adaptive_ttl(bucket, key, sql, 20000, 1000)
            .await;
        assert!(
            ttl <= 7200,
            "TTL should be clamped to maximum (got {})",
            ttl
        );
        assert!(ttl >= 60, "TTL should be at least minimum (got {})", ttl);
    }
}