1pub mod config;
53pub mod hints;
54pub mod invalidation;
55pub mod l1_hot;
56pub mod l2_warm;
57pub mod l3_semantic;
58pub mod metrics;
59pub mod normalizer;
60pub mod result;
61
62pub use config::{CacheConfig, L1Config, L2Config, L3Config, StorageBackend};
64pub use hints::{parse_cache_hints, CacheHint};
65pub use invalidation::{InvalidationManager, InvalidationMode};
66pub use l1_hot::L1HotCache;
67pub use l2_warm::L2WarmCache;
68pub use l3_semantic::L3SemanticCache;
69pub use metrics::{CacheMetrics, CacheStatsLevelSnapshot, CacheStatsSnapshot};
70pub use normalizer::{NormalizedQuery, QueryNormalizer};
71pub use result::{CacheKey, CachedResult};
72
73use bytes::Bytes;
74use dashmap::DashMap;
75use std::sync::Arc;
76use std::time::{Duration, Instant};
77
/// Per-request information that scopes cached results.
///
/// The whole context is passed to `CacheKey::new` when building L2 keys and to
/// the L3 tier; `connection_id` additionally selects the per-connection L1 cache.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheContext {
    /// Database name (`"default"` unless set otherwise).
    pub database: String,
    /// Optional user the query runs as.
    pub user: Option<String>,
    /// Optional branch identifier.
    pub branch: Option<String>,
    /// Connection identifier; when `Some`, L1 caching is used for this connection.
    pub connection_id: Option<u64>,
}
90
91impl Default for CacheContext {
92 fn default() -> Self {
93 Self {
94 database: "default".to_string(),
95 user: None,
96 branch: None,
97 connection_id: None,
98 }
99 }
100}
101
102#[derive(Debug)]
104pub enum CacheLookup {
105 Hit {
107 result: CachedResult,
108 level: CacheLevel,
109 },
110 Miss,
112}
113
/// Identifies one tier of the query cache.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheLevel {
    /// Per-connection cache, looked up by the raw query text.
    L1Hot,
    /// Shared cache, looked up by normalized query plus context.
    L2Warm,
    /// Semantic cache, consulted only for queries carrying the semantic hint.
    L3Semantic,
}
121
122impl std::fmt::Display for CacheLevel {
123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 match self {
125 CacheLevel::L1Hot => write!(f, "L1"),
126 CacheLevel::L2Warm => write!(f, "L2"),
127 CacheLevel::L3Semantic => write!(f, "L3"),
128 }
129 }
130}
131
132pub struct QueryCache {
134 config: CacheConfig,
136
137 l1_caches: DashMap<u64, Arc<L1HotCache>>,
139
140 l2_cache: Option<Arc<L2WarmCache>>,
142
143 l3_cache: Option<Arc<L3SemanticCache>>,
145
146 normalizer: Arc<QueryNormalizer>,
148
149 invalidator: Arc<InvalidationManager>,
151
152 metrics: Arc<CacheMetrics>,
154
155 #[allow(dead_code)]
157 pending_requests: DashMap<CacheKey, Arc<tokio::sync::Notify>>,
158}
159
160impl QueryCache {
161 pub fn new(config: CacheConfig) -> Self {
163 let l2_cache = if config.l2.enabled {
164 Some(Arc::new(L2WarmCache::new(config.l2.clone())))
165 } else {
166 None
167 };
168
169 let l3_cache = if config.l3.enabled {
170 Some(Arc::new(L3SemanticCache::new(config.l3.clone())))
171 } else {
172 None
173 };
174
175 let invalidator = Arc::new(InvalidationManager::new(config.invalidation.clone()));
176
177 Self {
178 config: config.clone(),
179 l1_caches: DashMap::new(),
180 l2_cache,
181 l3_cache,
182 normalizer: Arc::new(QueryNormalizer::new()),
183 invalidator,
184 metrics: Arc::new(CacheMetrics::new()),
185 pending_requests: DashMap::new(),
186 }
187 }
188
189 pub fn get_l1_cache(&self, connection_id: u64) -> Arc<L1HotCache> {
191 self.l1_caches
192 .entry(connection_id)
193 .or_insert_with(|| Arc::new(L1HotCache::new(self.config.l1.clone())))
194 .clone()
195 }
196
197 pub fn remove_l1_cache(&self, connection_id: u64) {
199 self.l1_caches.remove(&connection_id);
200 }
201
202 pub async fn get(&self, query: &str, context: &CacheContext) -> CacheLookup {
204 let hints = parse_cache_hints(query);
206
207 if hints.skip {
209 self.metrics.record_skip();
210 return CacheLookup::Miss;
211 }
212
213 let start = Instant::now();
214
215 if self.config.l1.enabled {
217 if let Some(conn_id) = context.connection_id {
218 let l1 = self.get_l1_cache(conn_id);
219 if let Some(result) = l1.get(query) {
220 self.metrics.record_hit(CacheLevel::L1Hot, start.elapsed());
221 return CacheLookup::Hit {
222 result,
223 level: CacheLevel::L1Hot,
224 };
225 }
226 }
227 }
228
229 let normalized = self.normalizer.normalize(query);
231 let cache_key = CacheKey::new(&normalized, context);
232
233 if let Some(ref l2) = self.l2_cache {
235 if let Some(result) = l2.get(&cache_key).await {
236 self.metrics.record_hit(CacheLevel::L2Warm, start.elapsed());
237
238 if self.config.l1.enabled {
240 if let Some(conn_id) = context.connection_id {
241 let l1 = self.get_l1_cache(conn_id);
242 l1.put(query.to_string(), result.clone());
243 }
244 }
245
246 return CacheLookup::Hit {
247 result,
248 level: CacheLevel::L2Warm,
249 };
250 }
251 }
252
253 if hints.semantic_cache {
255 if let Some(ref l3) = self.l3_cache {
256 if let Some(result) = l3.get(query, context).await {
257 self.metrics
258 .record_hit(CacheLevel::L3Semantic, start.elapsed());
259 return CacheLookup::Hit {
260 result,
261 level: CacheLevel::L3Semantic,
262 };
263 }
264 }
265 }
266
267 self.metrics.record_miss(start.elapsed());
268 CacheLookup::Miss
269 }
270
271 pub async fn put(
273 &self,
274 query: &str,
275 context: &CacheContext,
276 data: Bytes,
277 row_count: usize,
278 execution_time: Duration,
279 ) {
280 let hints = parse_cache_hints(query);
282
283 if hints.skip {
285 return;
286 }
287
288 let normalized = self.normalizer.normalize(query);
290
291 let ttl = hints
293 .ttl
294 .unwrap_or_else(|| self.get_table_ttl(&normalized.tables));
295
296 if data.len() > self.config.max_result_size {
298 self.metrics.record_size_exceeded();
299 return;
300 }
301
302 let result = CachedResult {
304 data,
305 row_count,
306 cached_at: Instant::now(),
307 ttl,
308 tables: normalized.tables.clone(),
309 execution_time,
310 };
311
312 if self.config.l1.enabled {
314 if let Some(conn_id) = context.connection_id {
315 let l1 = self.get_l1_cache(conn_id);
316 l1.put(query.to_string(), result.clone());
317 }
318 }
319
320 if let Some(ref l2) = self.l2_cache {
322 let cache_key = CacheKey::new(&normalized, context);
323 l2.put(cache_key.clone(), result.clone()).await;
324
325 for table in &normalized.tables {
327 self.invalidator.register(&cache_key, table);
328 }
329 }
330
331 if hints.semantic_cache {
333 if let Some(ref l3) = self.l3_cache {
334 l3.put(query, context, result).await;
335 }
336 }
337
338 self.metrics.record_put();
339 }
340
341 pub async fn invalidate_query(&self, sql: &str) {
345 let normalized = self.normalizer.normalize(sql);
346 if !normalized.tables.is_empty() {
347 self.invalidate_tables(&normalized.tables).await;
348 }
349 }
350
351 pub async fn invalidate_tables(&self, tables: &[String]) {
353 for table in tables {
354 let keys = self.invalidator.get_keys_for_table(table);
355
356 if let Some(ref l2) = self.l2_cache {
358 for key in &keys {
359 l2.remove(key).await;
360 }
361 }
362
363 self.invalidator.invalidate_table(table);
364 }
365
366 self.metrics.record_invalidation(tables.len());
370 }
371
372 pub async fn clear(&self, levels: &[CacheLevel]) {
374 for level in levels {
375 match level {
376 CacheLevel::L1Hot => {
377 self.l1_caches.clear();
378 }
379 CacheLevel::L2Warm => {
380 if let Some(ref l2) = self.l2_cache {
381 l2.clear().await;
382 }
383 }
384 CacheLevel::L3Semantic => {
385 if let Some(ref l3) = self.l3_cache {
386 l3.clear().await;
387 }
388 }
389 }
390 }
391
392 self.metrics.record_clear();
393 }
394
395 pub fn stats(&self) -> CacheStatsSnapshot {
397 self.metrics.snapshot()
398 }
399
400 pub fn config(&self) -> &CacheConfig {
402 &self.config
403 }
404
405 pub fn invalidator(&self) -> Arc<InvalidationManager> {
407 self.invalidator.clone()
408 }
409
410 fn get_table_ttl(&self, tables: &[String]) -> Duration {
412 let mut min_ttl = self.config.default_ttl;
414
415 for table in tables {
416 if let Some(table_config) = self.config.table_configs.get(table) {
417 if table_config.ttl < min_ttl {
418 min_ttl = table_config.ttl;
419 }
420 }
421 }
422
423 min_ttl
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// A default context targets `"default"` and leaves every optional field unset.
    #[test]
    fn test_cache_context_default() {
        let CacheContext {
            database,
            user,
            branch,
            connection_id,
        } = CacheContext::default();
        assert_eq!(database, "default");
        assert!(user.is_none());
        assert!(branch.is_none());
        assert!(connection_id.is_none());
    }

    /// Each level renders as its short label.
    #[test]
    fn test_cache_level_display() {
        let cases = [
            (CacheLevel::L1Hot, "L1"),
            (CacheLevel::L2Warm, "L2"),
            (CacheLevel::L3Semantic, "L3"),
        ];
        for &(level, expected) in &cases {
            assert_eq!(level.to_string(), expected);
        }
    }

    /// The default configuration enables both L1 and L2.
    #[tokio::test]
    async fn test_query_cache_creation() {
        let cache = QueryCache::new(CacheConfig::default());
        let cfg = cache.config();
        assert!(cfg.l1.enabled);
        assert!(cfg.l2.enabled);
    }

    /// The same connection id yields the same L1 instance; different ids do not.
    #[tokio::test]
    async fn test_l1_cache_per_connection() {
        let cache = QueryCache::new(CacheConfig::default());

        let first = cache.get_l1_cache(1);
        let other = cache.get_l1_cache(2);
        let first_again = cache.get_l1_cache(1);

        assert!(Arc::ptr_eq(&first, &first_again));
        assert!(!Arc::ptr_eq(&first, &other));
    }

    /// An empty cache misses.
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = QueryCache::new(CacheConfig::default());
        let lookup = cache
            .get("SELECT * FROM users", &CacheContext::default())
            .await;
        assert!(matches!(lookup, CacheLookup::Miss));
    }
}