pub mod config;
pub mod l1_hot;
pub mod l2_warm;
pub mod l3_semantic;
pub mod normalizer;
pub mod invalidation;
pub mod metrics;
pub mod hints;
pub mod result;
pub use config::{CacheConfig, L1Config, L2Config, L3Config, StorageBackend};
pub use l1_hot::L1HotCache;
pub use l2_warm::L2WarmCache;
pub use l3_semantic::L3SemanticCache;
pub use normalizer::{QueryNormalizer, NormalizedQuery};
pub use invalidation::{InvalidationManager, InvalidationMode};
pub use metrics::{CacheMetrics, CacheStatsSnapshot, CacheStatsLevelSnapshot};
pub use hints::{CacheHint, parse_cache_hints};
pub use result::{CachedResult, CacheKey};
use bytes::Bytes;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Context passed alongside a query to every cache lookup and insertion.
///
/// It participates in building the shared L2 key and is handed to the L3
/// cache, so two identical query strings under different contexts can map to
/// different cache entries.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheContext {
    /// Database name the query is associated with.
    pub database: String,
    /// Optional user identifier.
    pub user: Option<String>,
    /// Optional branch name.
    pub branch: Option<String>,
    /// Connection the query arrived on; when present it selects that
    /// connection's L1 cache.
    pub connection_id: Option<u64>,
}

impl Default for CacheContext {
    /// Returns a context for the `"default"` database with no user, branch or
    /// connection id.
    fn default() -> Self {
        let database = String::from("default");
        Self {
            database,
            user: None,
            branch: None,
            connection_id: None,
        }
    }
}
/// Outcome of a [`QueryCache::get`] lookup.
#[derive(Debug)]
pub enum CacheLookup {
    /// A cached result was found in one of the tiers.
    Hit {
        /// The cached payload together with its metadata.
        result: CachedResult,
        /// The tier that served the hit.
        level: CacheLevel,
    },
    /// No tier held a matching entry, or the query's hints asked to skip the
    /// cache entirely.
    Miss,
}
/// Identifies which tier of the cache hierarchy served a lookup.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheLevel {
    /// Per-connection cache keyed by the raw query text.
    L1Hot,
    /// Single shared cache keyed by the normalized query and its context.
    L2Warm,
    /// Semantic cache, consulted only when the query opts in via a cache hint.
    L3Semantic,
}

impl std::fmt::Display for CacheLevel {
    /// Writes the short tier label: `"L1"`, `"L2"` or `"L3"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::L1Hot => "L1",
            Self::L2Warm => "L2",
            Self::L3Semantic => "L3",
        };
        f.write_str(label)
    }
}
/// Multi-tier query result cache.
///
/// Lookups go L1 (per connection) → L2 (shared) → L3 (semantic, opt-in via a
/// query hint); see [`QueryCache::get`].
pub struct QueryCache {
    /// Configuration the cache was built with.
    config: CacheConfig,
    /// One L1 cache per connection id, created on first use.
    l1_caches: DashMap<u64, Arc<L1HotCache>>,
    /// Shared L2 cache; `None` when disabled in the configuration.
    l2_cache: Option<Arc<L2WarmCache>>,
    /// Semantic L3 cache; `None` when disabled in the configuration.
    l3_cache: Option<Arc<L3SemanticCache>>,
    /// Normalizes query text to build L2 keys and extract referenced tables.
    normalizer: Arc<QueryNormalizer>,
    /// Tracks which L2 keys were registered against which tables.
    invalidator: Arc<InvalidationManager>,
    /// Counters for hits, misses, skips, puts, invalidations and clears.
    metrics: Arc<CacheMetrics>,
    /// NOTE(review): initialized in `new` but not read or written anywhere else
    /// in this module; presumably intended for request coalescing — confirm or
    /// remove.
    pending_requests: DashMap<CacheKey, Arc<tokio::sync::Notify>>,
}
impl QueryCache {
pub fn new(config: CacheConfig) -> Self {
let l2_cache = if config.l2.enabled {
Some(Arc::new(L2WarmCache::new(config.l2.clone())))
} else {
None
};
let l3_cache = if config.l3.enabled {
Some(Arc::new(L3SemanticCache::new(config.l3.clone())))
} else {
None
};
let invalidator = Arc::new(InvalidationManager::new(config.invalidation.clone()));
Self {
config: config.clone(),
l1_caches: DashMap::new(),
l2_cache,
l3_cache,
normalizer: Arc::new(QueryNormalizer::new()),
invalidator,
metrics: Arc::new(CacheMetrics::new()),
pending_requests: DashMap::new(),
}
}
pub fn get_l1_cache(&self, connection_id: u64) -> Arc<L1HotCache> {
self.l1_caches
.entry(connection_id)
.or_insert_with(|| Arc::new(L1HotCache::new(self.config.l1.clone())))
.clone()
}
pub fn remove_l1_cache(&self, connection_id: u64) {
self.l1_caches.remove(&connection_id);
}
pub async fn get(&self, query: &str, context: &CacheContext) -> CacheLookup {
let hints = parse_cache_hints(query);
if hints.skip {
self.metrics.record_skip();
return CacheLookup::Miss;
}
let start = Instant::now();
if self.config.l1.enabled {
if let Some(conn_id) = context.connection_id {
let l1 = self.get_l1_cache(conn_id);
if let Some(result) = l1.get(query) {
self.metrics.record_hit(CacheLevel::L1Hot, start.elapsed());
return CacheLookup::Hit {
result,
level: CacheLevel::L1Hot,
};
}
}
}
let normalized = self.normalizer.normalize(query);
let cache_key = CacheKey::new(&normalized, context);
if let Some(ref l2) = self.l2_cache {
if let Some(result) = l2.get(&cache_key).await {
self.metrics.record_hit(CacheLevel::L2Warm, start.elapsed());
if self.config.l1.enabled {
if let Some(conn_id) = context.connection_id {
let l1 = self.get_l1_cache(conn_id);
l1.put(query.to_string(), result.clone());
}
}
return CacheLookup::Hit {
result,
level: CacheLevel::L2Warm,
};
}
}
if hints.semantic_cache {
if let Some(ref l3) = self.l3_cache {
if let Some(result) = l3.get(query, context).await {
self.metrics.record_hit(CacheLevel::L3Semantic, start.elapsed());
return CacheLookup::Hit {
result,
level: CacheLevel::L3Semantic,
};
}
}
}
self.metrics.record_miss(start.elapsed());
CacheLookup::Miss
}
pub async fn put(
&self,
query: &str,
context: &CacheContext,
data: Bytes,
row_count: usize,
execution_time: Duration,
) {
let hints = parse_cache_hints(query);
if hints.skip {
return;
}
let normalized = self.normalizer.normalize(query);
let ttl = hints.ttl.unwrap_or_else(|| {
self.get_table_ttl(&normalized.tables)
});
if data.len() > self.config.max_result_size {
self.metrics.record_size_exceeded();
return;
}
let result = CachedResult {
data,
row_count,
cached_at: Instant::now(),
ttl,
tables: normalized.tables.clone(),
execution_time,
};
if self.config.l1.enabled {
if let Some(conn_id) = context.connection_id {
let l1 = self.get_l1_cache(conn_id);
l1.put(query.to_string(), result.clone());
}
}
if let Some(ref l2) = self.l2_cache {
let cache_key = CacheKey::new(&normalized, context);
l2.put(cache_key.clone(), result.clone()).await;
for table in &normalized.tables {
self.invalidator.register(&cache_key, table);
}
}
if hints.semantic_cache {
if let Some(ref l3) = self.l3_cache {
l3.put(query, context, result).await;
}
}
self.metrics.record_put();
}
pub async fn invalidate_tables(&self, tables: &[String]) {
for table in tables {
let keys = self.invalidator.get_keys_for_table(table);
if let Some(ref l2) = self.l2_cache {
for key in &keys {
l2.remove(key).await;
}
}
self.invalidator.invalidate_table(table);
}
self.metrics.record_invalidation(tables.len());
}
pub async fn clear(&self, levels: &[CacheLevel]) {
for level in levels {
match level {
CacheLevel::L1Hot => {
self.l1_caches.clear();
}
CacheLevel::L2Warm => {
if let Some(ref l2) = self.l2_cache {
l2.clear().await;
}
}
CacheLevel::L3Semantic => {
if let Some(ref l3) = self.l3_cache {
l3.clear().await;
}
}
}
}
self.metrics.record_clear();
}
pub fn stats(&self) -> CacheStatsSnapshot {
self.metrics.snapshot()
}
pub fn config(&self) -> &CacheConfig {
&self.config
}
pub fn invalidator(&self) -> Arc<InvalidationManager> {
self.invalidator.clone()
}
fn get_table_ttl(&self, tables: &[String]) -> Duration {
let mut min_ttl = self.config.default_ttl;
for table in tables {
if let Some(table_config) = self.config.table_configs.get(table) {
if table_config.ttl < min_ttl {
min_ttl = table_config.ttl;
}
}
}
min_ttl
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `QueryCache` from the default configuration.
    fn default_cache() -> QueryCache {
        QueryCache::new(CacheConfig::default())
    }

    #[test]
    fn test_cache_context_default() {
        let context = CacheContext::default();
        assert_eq!(context.database, "default");
        assert_eq!(context.user, None);
        assert_eq!(context.branch, None);
        assert_eq!(context.connection_id, None);
    }

    #[test]
    fn test_cache_level_display() {
        let cases = [
            (CacheLevel::L1Hot, "L1"),
            (CacheLevel::L2Warm, "L2"),
            (CacheLevel::L3Semantic, "L3"),
        ];
        for (level, expected) in cases {
            assert_eq!(format!("{}", level), expected);
        }
    }

    #[tokio::test]
    async fn test_query_cache_creation() {
        let cache = default_cache();
        assert!(cache.config.l1.enabled);
        assert!(cache.config.l2.enabled);
    }

    #[tokio::test]
    async fn test_l1_cache_per_connection() {
        let cache = default_cache();
        let first_conn = cache.get_l1_cache(1);
        let second_conn = cache.get_l1_cache(2);
        let first_conn_again = cache.get_l1_cache(1);
        // Same connection id must yield the same instance; different ids must not.
        assert!(Arc::ptr_eq(&first_conn, &first_conn_again));
        assert!(!Arc::ptr_eq(&first_conn, &second_conn));
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = default_cache();
        let lookup = cache
            .get("SELECT * FROM users", &CacheContext::default())
            .await;
        assert!(matches!(lookup, CacheLookup::Miss));
    }
}