1pub mod config;
53pub mod l1_hot;
54pub mod l2_warm;
55pub mod l3_semantic;
56pub mod normalizer;
57pub mod invalidation;
58pub mod metrics;
59pub mod hints;
60pub mod result;
61
62pub use config::{CacheConfig, L1Config, L2Config, L3Config, StorageBackend};
64pub use l1_hot::L1HotCache;
65pub use l2_warm::L2WarmCache;
66pub use l3_semantic::L3SemanticCache;
67pub use normalizer::{QueryNormalizer, NormalizedQuery};
68pub use invalidation::{InvalidationManager, InvalidationMode};
69pub use metrics::{CacheMetrics, CacheStatsSnapshot, CacheStatsLevelSnapshot};
70pub use hints::{CacheHint, parse_cache_hints};
71pub use result::{CachedResult, CacheKey};
72
73use bytes::Bytes;
74use dashmap::DashMap;
75use std::sync::Arc;
76use std::time::{Duration, Instant};
77
/// Scope information attached to every cache lookup and insert.
///
/// The whole context participates in equality and hashing, so two lookups
/// only share an entry when every field matches.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheContext {
    /// Database name; `"default"` when built via `Default`.
    pub database: String,
    /// Optional user identifier.
    pub user: Option<String>,
    /// Optional branch identifier.
    pub branch: Option<String>,
    /// Optional connection id; when present, selects the per-connection L1 cache.
    pub connection_id: Option<u64>,
}

impl Default for CacheContext {
    /// The `"default"` database with no user, branch, or connection attached.
    fn default() -> Self {
        let database = String::from("default");
        CacheContext {
            database,
            connection_id: None,
            branch: None,
            user: None,
        }
    }
}
101
/// Outcome of a `QueryCache::get` lookup.
#[derive(Debug)]
pub enum CacheLookup {
    /// A cached result was found.
    Hit {
        /// The cached result that was served.
        result: CachedResult,
        /// The cache tier that served it.
        level: CacheLevel,
    },
    /// No enabled tier had the query, or the query's hints asked to skip the cache.
    Miss,
}
113
/// Identifies which tier of the cache hierarchy served a lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheLevel {
    /// Per-connection cache keyed by raw query text.
    L1Hot,
    /// Shared cache keyed by the normalized query plus context.
    L2Warm,
    /// Semantic cache, consulted only when the query's hints request it.
    L3Semantic,
}

impl std::fmt::Display for CacheLevel {
    /// Writes the short tier tag: `L1`, `L2` or `L3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match *self {
            Self::L1Hot => "L1",
            Self::L2Warm => "L2",
            Self::L3Semantic => "L3",
        };
        f.write_str(tag)
    }
}
131
/// Multi-tier query result cache.
///
/// Lookups consult a per-connection L1 cache, then an optional shared L2
/// cache, then an optional L3 semantic cache.
pub struct QueryCache {
    /// Configuration the cache was built from.
    config: CacheConfig,

    /// Per-connection L1 caches, keyed by connection id and created lazily.
    l1_caches: DashMap<u64, Arc<L1HotCache>>,

    /// L2 cache keyed by `CacheKey` (normalized query + context); `None` when disabled.
    l2_cache: Option<Arc<L2WarmCache>>,

    /// L3 semantic cache; `None` when disabled.
    l3_cache: Option<Arc<L3SemanticCache>>,

    /// Normalizer used to build L2 keys and extract the tables a query references.
    normalizer: Arc<QueryNormalizer>,

    /// Records which L2 keys depend on which tables, for table-based invalidation.
    invalidator: Arc<InvalidationManager>,

    /// Hit / miss / put / invalidation counters, exposed via a snapshot.
    metrics: Arc<CacheMetrics>,

    /// NOTE(review): only initialized in `new`; nothing else in this file reads or
    /// writes it. Presumably intended for coalescing concurrent misses on the same
    /// key — confirm whether it is used elsewhere, otherwise remove it.
    pending_requests: DashMap<CacheKey, Arc<tokio::sync::Notify>>,
}
158
159impl QueryCache {
160 pub fn new(config: CacheConfig) -> Self {
162 let l2_cache = if config.l2.enabled {
163 Some(Arc::new(L2WarmCache::new(config.l2.clone())))
164 } else {
165 None
166 };
167
168 let l3_cache = if config.l3.enabled {
169 Some(Arc::new(L3SemanticCache::new(config.l3.clone())))
170 } else {
171 None
172 };
173
174 let invalidator = Arc::new(InvalidationManager::new(config.invalidation.clone()));
175
176 Self {
177 config: config.clone(),
178 l1_caches: DashMap::new(),
179 l2_cache,
180 l3_cache,
181 normalizer: Arc::new(QueryNormalizer::new()),
182 invalidator,
183 metrics: Arc::new(CacheMetrics::new()),
184 pending_requests: DashMap::new(),
185 }
186 }
187
188 pub fn get_l1_cache(&self, connection_id: u64) -> Arc<L1HotCache> {
190 self.l1_caches
191 .entry(connection_id)
192 .or_insert_with(|| Arc::new(L1HotCache::new(self.config.l1.clone())))
193 .clone()
194 }
195
196 pub fn remove_l1_cache(&self, connection_id: u64) {
198 self.l1_caches.remove(&connection_id);
199 }
200
201 pub async fn get(&self, query: &str, context: &CacheContext) -> CacheLookup {
203 let hints = parse_cache_hints(query);
205
206 if hints.skip {
208 self.metrics.record_skip();
209 return CacheLookup::Miss;
210 }
211
212 let start = Instant::now();
213
214 if self.config.l1.enabled {
216 if let Some(conn_id) = context.connection_id {
217 let l1 = self.get_l1_cache(conn_id);
218 if let Some(result) = l1.get(query) {
219 self.metrics.record_hit(CacheLevel::L1Hot, start.elapsed());
220 return CacheLookup::Hit {
221 result,
222 level: CacheLevel::L1Hot,
223 };
224 }
225 }
226 }
227
228 let normalized = self.normalizer.normalize(query);
230 let cache_key = CacheKey::new(&normalized, context);
231
232 if let Some(ref l2) = self.l2_cache {
234 if let Some(result) = l2.get(&cache_key).await {
235 self.metrics.record_hit(CacheLevel::L2Warm, start.elapsed());
236
237 if self.config.l1.enabled {
239 if let Some(conn_id) = context.connection_id {
240 let l1 = self.get_l1_cache(conn_id);
241 l1.put(query.to_string(), result.clone());
242 }
243 }
244
245 return CacheLookup::Hit {
246 result,
247 level: CacheLevel::L2Warm,
248 };
249 }
250 }
251
252 if hints.semantic_cache {
254 if let Some(ref l3) = self.l3_cache {
255 if let Some(result) = l3.get(query, context).await {
256 self.metrics.record_hit(CacheLevel::L3Semantic, start.elapsed());
257 return CacheLookup::Hit {
258 result,
259 level: CacheLevel::L3Semantic,
260 };
261 }
262 }
263 }
264
265 self.metrics.record_miss(start.elapsed());
266 CacheLookup::Miss
267 }
268
269 pub async fn put(
271 &self,
272 query: &str,
273 context: &CacheContext,
274 data: Bytes,
275 row_count: usize,
276 execution_time: Duration,
277 ) {
278 let hints = parse_cache_hints(query);
280
281 if hints.skip {
283 return;
284 }
285
286 let normalized = self.normalizer.normalize(query);
288
289 let ttl = hints.ttl.unwrap_or_else(|| {
291 self.get_table_ttl(&normalized.tables)
292 });
293
294 if data.len() > self.config.max_result_size {
296 self.metrics.record_size_exceeded();
297 return;
298 }
299
300 let result = CachedResult {
302 data,
303 row_count,
304 cached_at: Instant::now(),
305 ttl,
306 tables: normalized.tables.clone(),
307 execution_time,
308 };
309
310 if self.config.l1.enabled {
312 if let Some(conn_id) = context.connection_id {
313 let l1 = self.get_l1_cache(conn_id);
314 l1.put(query.to_string(), result.clone());
315 }
316 }
317
318 if let Some(ref l2) = self.l2_cache {
320 let cache_key = CacheKey::new(&normalized, context);
321 l2.put(cache_key.clone(), result.clone()).await;
322
323 for table in &normalized.tables {
325 self.invalidator.register(&cache_key, table);
326 }
327 }
328
329 if hints.semantic_cache {
331 if let Some(ref l3) = self.l3_cache {
332 l3.put(query, context, result).await;
333 }
334 }
335
336 self.metrics.record_put();
337 }
338
339 pub async fn invalidate_tables(&self, tables: &[String]) {
341 for table in tables {
342 let keys = self.invalidator.get_keys_for_table(table);
343
344 if let Some(ref l2) = self.l2_cache {
346 for key in &keys {
347 l2.remove(key).await;
348 }
349 }
350
351 self.invalidator.invalidate_table(table);
352 }
353
354 self.metrics.record_invalidation(tables.len());
358 }
359
360 pub async fn clear(&self, levels: &[CacheLevel]) {
362 for level in levels {
363 match level {
364 CacheLevel::L1Hot => {
365 self.l1_caches.clear();
366 }
367 CacheLevel::L2Warm => {
368 if let Some(ref l2) = self.l2_cache {
369 l2.clear().await;
370 }
371 }
372 CacheLevel::L3Semantic => {
373 if let Some(ref l3) = self.l3_cache {
374 l3.clear().await;
375 }
376 }
377 }
378 }
379
380 self.metrics.record_clear();
381 }
382
383 pub fn stats(&self) -> CacheStatsSnapshot {
385 self.metrics.snapshot()
386 }
387
388 pub fn config(&self) -> &CacheConfig {
390 &self.config
391 }
392
393 pub fn invalidator(&self) -> Arc<InvalidationManager> {
395 self.invalidator.clone()
396 }
397
398 fn get_table_ttl(&self, tables: &[String]) -> Duration {
400 let mut min_ttl = self.config.default_ttl;
402
403 for table in tables {
404 if let Some(table_config) = self.config.table_configs.get(table) {
405 if table_config.ttl < min_ttl {
406 min_ttl = table_config.ttl;
407 }
408 }
409 }
410
411 min_ttl
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_context_default() {
        // Every field must match the documented defaults.
        let expected = CacheContext {
            database: String::from("default"),
            user: None,
            branch: None,
            connection_id: None,
        };
        assert_eq!(CacheContext::default(), expected);
    }

    #[test]
    fn test_cache_level_display() {
        let cases = [
            (CacheLevel::L1Hot, "L1"),
            (CacheLevel::L2Warm, "L2"),
            (CacheLevel::L3Semantic, "L3"),
        ];
        for &(level, label) in &cases {
            assert_eq!(level.to_string(), label);
        }
    }

    #[tokio::test]
    async fn test_query_cache_creation() {
        let cache = QueryCache::new(CacheConfig::default());
        let cfg = cache.config();
        assert!(cfg.l1.enabled);
        assert!(cfg.l2.enabled);
    }

    #[tokio::test]
    async fn test_l1_cache_per_connection() {
        let cache = QueryCache::new(CacheConfig::default());

        let first = cache.get_l1_cache(1);
        let other = cache.get_l1_cache(2);
        let first_again = cache.get_l1_cache(1);

        // The same connection id must hand back the same shared instance...
        assert!(Arc::ptr_eq(&first, &first_again));
        // ...while distinct ids get independent caches.
        assert!(!Arc::ptr_eq(&first, &other));
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = QueryCache::new(CacheConfig::default());
        let lookup = cache
            .get("SELECT * FROM users", &CacheContext::default())
            .await;
        assert!(matches!(lookup, CacheLookup::Miss));
    }
}