1use parking_lot::RwLock;
42use std::collections::HashMap;
43use std::hash::{Hash, Hasher};
44use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
45use std::time::Instant;
46
47use super::context::TenantId;
48
/// Selects how a statement cache organises its entries and how large it may grow.
#[derive(Debug, Clone)]
pub enum CacheMode {
    /// One cache shared by every caller.
    Global {
        /// Limit on the number of statements kept in the shared cache.
        max_statements: usize,
    },

    /// A separate cache for each tenant.
    PerTenant {
        /// Limit on the number of tenants that hold a cache at the same time.
        max_tenants: usize,
        /// Limit on the number of statements kept for each tenant.
        statements_per_tenant: usize,
    },

    /// Caching is switched off.
    Disabled,
}

impl Default for CacheMode {
    /// Returns [`CacheMode::Global`] limited to 1000 statements.
    fn default() -> Self {
        let max_statements = 1000;
        CacheMode::Global { max_statements }
    }
}
79
/// Identifies a prepared statement by its name together with its SQL text.
///
/// Equality and hashing are derived over both fields, so two keys that share
/// the same SQL but carry different names are distinct keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StatementKey {
    /// Name of the statement.
    pub name: String,
    /// SQL text of the statement.
    pub sql: String,
}

impl StatementKey {
    /// Builds a key from an explicit name and SQL text.
    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
        let (name, sql) = (name.into(), sql.into());
        StatementKey { name, sql }
    }

    /// Builds a key whose name is generated from the SQL text.
    ///
    /// The generated name is `stmt_` followed by the lowercase hexadecimal
    /// hash of `sql`, so the same SQL always yields the same name within one
    /// build of the program.
    pub fn from_sql(sql: impl Into<String>) -> Self {
        let sql: String = sql.into();
        let digest = hash_sql(&sql);
        StatementKey {
            name: format!("stmt_{:x}", digest),
            sql,
        }
    }
}

/// Hashes SQL text with the standard library's `DefaultHasher`.
///
/// NOTE(review): the standard library does not promise that `DefaultHasher`
/// output is stable across Rust releases, so the value should not be
/// persisted or compared between differently built binaries.
fn hash_sql(sql: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(sql, &mut state);
    state.finish()
}
113
/// Usage bookkeeping kept alongside a cached statement.
#[derive(Debug, Clone)]
pub struct StatementMeta {
    /// Moment at which this metadata was created.
    pub prepared_at: Instant,
    /// Number of executions recorded so far.
    pub execution_count: u64,
    /// Moment of creation or of the most recently recorded execution.
    pub last_used: Instant,
    /// Running mean of the recorded execution durations (microseconds, going
    /// by the field name).
    pub avg_execution_us: f64,
}

impl StatementMeta {
    /// Creates metadata with no recorded executions; `prepared_at` and
    /// `last_used` both start at the same instant.
    fn new() -> Self {
        let created = Instant::now();
        StatementMeta {
            prepared_at: created,
            execution_count: 0,
            last_used: created,
            avg_execution_us: 0.0,
        }
    }

    /// Records one execution that took `duration_us`, refreshing `last_used`
    /// and folding the duration into the running mean.
    fn record_execution(&mut self, duration_us: f64) {
        self.last_used = Instant::now();
        self.execution_count += 1;

        // Incremental mean: scale the previous mean by (n - 1) / n and add the
        // new sample's 1 / n share.
        let samples = self.execution_count as f64;
        let carried = self.avg_execution_us * (samples - 1.0) / samples;
        self.avg_execution_us = carried + duration_us / samples;
    }
}
149
150struct CacheEntry<S> {
152 statement: S,
154 meta: StatementMeta,
156}
157
158impl<S> CacheEntry<S> {
159 fn new(statement: S) -> Self {
160 Self {
161 statement,
162 meta: StatementMeta::new(),
163 }
164 }
165}
166
/// Point-in-time copy of the cache counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups that found an entry.
    pub hits: u64,
    /// Number of lookups that found nothing.
    pub misses: u64,
    /// Number of statements recorded as prepared.
    pub statements_prepared: u64,
    /// Number of statements recorded as evicted.
    pub statements_evicted: u64,
    /// Number of entries held when the snapshot was taken.
    pub size: usize,
    /// Accumulated time-saved figure (milliseconds, going by the field name).
    pub time_saved_ms: u64,
}

impl CacheStats {
    /// Returns the fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Yields `0.0` when no lookups have been counted at all.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
195
196pub struct AtomicCacheStats {
198 hits: AtomicU64,
199 misses: AtomicU64,
200 statements_prepared: AtomicU64,
201 statements_evicted: AtomicU64,
202 size: AtomicUsize,
203 time_saved_ms: AtomicU64,
204}
205
206impl Default for AtomicCacheStats {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl AtomicCacheStats {
213 pub fn new() -> Self {
215 Self {
216 hits: AtomicU64::new(0),
217 misses: AtomicU64::new(0),
218 statements_prepared: AtomicU64::new(0),
219 statements_evicted: AtomicU64::new(0),
220 size: AtomicUsize::new(0),
221 time_saved_ms: AtomicU64::new(0),
222 }
223 }
224
225 #[inline]
226 pub fn record_hit(&self) {
227 self.hits.fetch_add(1, Ordering::Relaxed);
228 }
229
230 #[inline]
231 pub fn record_miss(&self) {
232 self.misses.fetch_add(1, Ordering::Relaxed);
233 }
234
235 #[inline]
236 pub fn record_prepare(&self) {
237 self.statements_prepared.fetch_add(1, Ordering::Relaxed);
238 }
239
240 #[inline]
241 pub fn record_eviction(&self) {
242 self.statements_evicted.fetch_add(1, Ordering::Relaxed);
243 }
244
245 #[inline]
246 pub fn set_size(&self, size: usize) {
247 self.size.store(size, Ordering::Relaxed);
248 }
249
250 #[inline]
251 pub fn add_time_saved(&self, ms: u64) {
252 self.time_saved_ms.fetch_add(ms, Ordering::Relaxed);
253 }
254
255 pub fn snapshot(&self) -> CacheStats {
257 CacheStats {
258 hits: self.hits.load(Ordering::Relaxed),
259 misses: self.misses.load(Ordering::Relaxed),
260 statements_prepared: self.statements_prepared.load(Ordering::Relaxed),
261 statements_evicted: self.statements_evicted.load(Ordering::Relaxed),
262 size: self.size.load(Ordering::Relaxed),
263 time_saved_ms: self.time_saved_ms.load(Ordering::Relaxed),
264 }
265 }
266}
267
268pub struct StatementCache<S> {
270 mode: CacheMode,
271 global_cache: RwLock<HashMap<StatementKey, CacheEntry<S>>>,
273 tenant_caches: RwLock<HashMap<String, HashMap<StatementKey, CacheEntry<S>>>>,
275 stats: AtomicCacheStats,
277}
278
279impl<S: Clone> StatementCache<S> {
280 pub fn new(mode: CacheMode) -> Self {
282 let capacity = match &mode {
283 CacheMode::Global { max_statements } => *max_statements,
284 CacheMode::PerTenant { max_tenants, .. } => *max_tenants,
285 CacheMode::Disabled => 0,
286 };
287
288 Self {
289 mode,
290 global_cache: RwLock::new(HashMap::with_capacity(capacity)),
291 tenant_caches: RwLock::new(HashMap::with_capacity(capacity)),
292 stats: AtomicCacheStats::new(),
293 }
294 }
295
296 pub fn global(max_statements: usize) -> Self {
298 Self::new(CacheMode::Global { max_statements })
299 }
300
301 pub fn per_tenant(max_tenants: usize, statements_per_tenant: usize) -> Self {
303 Self::new(CacheMode::PerTenant {
304 max_tenants,
305 statements_per_tenant,
306 })
307 }
308
309 pub fn mode(&self) -> &CacheMode {
311 &self.mode
312 }
313
314 pub fn stats(&self) -> CacheStats {
316 let size = match &self.mode {
317 CacheMode::Global { .. } => self.global_cache.read().len(),
318 CacheMode::PerTenant { .. } => {
319 self.tenant_caches.read().values().map(|c| c.len()).sum()
320 }
321 CacheMode::Disabled => 0,
322 };
323 self.stats.set_size(size);
324 self.stats.snapshot()
325 }
326
327 pub fn get(&self, key: &StatementKey) -> Option<S> {
329 if matches!(self.mode, CacheMode::Disabled) {
330 return None;
331 }
332
333 let cache = self.global_cache.read();
334 if let Some(entry) = cache.get(key) {
335 self.stats.record_hit();
336 self.stats.add_time_saved(1);
338 Some(entry.statement.clone())
339 } else {
340 self.stats.record_miss();
341 None
342 }
343 }
344
345 pub fn get_for_tenant(&self, tenant_id: &TenantId, key: &StatementKey) -> Option<S> {
347 match &self.mode {
348 CacheMode::Disabled => None,
349 CacheMode::Global { .. } => self.get(key),
350 CacheMode::PerTenant { .. } => {
351 let caches = self.tenant_caches.read();
352 if let Some(cache) = caches.get(tenant_id.as_str()) {
353 if let Some(entry) = cache.get(key) {
354 self.stats.record_hit();
355 self.stats.add_time_saved(1);
356 return Some(entry.statement.clone());
357 }
358 }
359 self.stats.record_miss();
360 None
361 }
362 }
363 }
364
365 pub fn insert(&self, key: StatementKey, statement: S) {
367 if matches!(self.mode, CacheMode::Disabled) {
368 return;
369 }
370
371 let max = match &self.mode {
372 CacheMode::Global { max_statements } => *max_statements,
373 _ => return self.insert_for_tenant(&TenantId::new("global"), key, statement),
374 };
375
376 let mut cache = self.global_cache.write();
377
378 if cache.len() >= max && !cache.contains_key(&key) {
380 self.evict_lru(&mut cache);
381 }
382
383 cache.insert(key, CacheEntry::new(statement));
384 self.stats.record_prepare();
385 }
386
387 pub fn insert_for_tenant(&self, tenant_id: &TenantId, key: StatementKey, statement: S) {
389 match &self.mode {
390 CacheMode::Disabled => {}
391 CacheMode::Global { .. } => self.insert(key, statement),
392 CacheMode::PerTenant {
393 max_tenants,
394 statements_per_tenant,
395 } => {
396 let mut caches = self.tenant_caches.write();
397
398 if !caches.contains_key(tenant_id.as_str()) && caches.len() >= *max_tenants {
400 self.evict_lru_tenant(&mut caches);
401 }
402
403 let cache = caches
404 .entry(tenant_id.as_str().to_string())
405 .or_insert_with(|| HashMap::with_capacity(*statements_per_tenant));
406
407 if cache.len() >= *statements_per_tenant && !cache.contains_key(&key) {
409 self.evict_lru(cache);
410 }
411
412 cache.insert(key, CacheEntry::new(statement));
413 self.stats.record_prepare();
414 }
415 }
416 }
417
418 pub fn record_execution(&self, key: &StatementKey, duration_us: f64) {
420 if matches!(self.mode, CacheMode::Disabled) {
421 return;
422 }
423
424 let mut cache = self.global_cache.write();
425 if let Some(entry) = cache.get_mut(key) {
426 entry.meta.record_execution(duration_us);
427 }
428 }
429
430 pub fn record_tenant_execution(
432 &self,
433 tenant_id: &TenantId,
434 key: &StatementKey,
435 duration_us: f64,
436 ) {
437 match &self.mode {
438 CacheMode::Disabled => {}
439 CacheMode::Global { .. } => self.record_execution(key, duration_us),
440 CacheMode::PerTenant { .. } => {
441 let mut caches = self.tenant_caches.write();
442 if let Some(cache) = caches.get_mut(tenant_id.as_str()) {
443 if let Some(entry) = cache.get_mut(key) {
444 entry.meta.record_execution(duration_us);
445 }
446 }
447 }
448 }
449 }
450
451 pub fn invalidate_tenant(&self, tenant_id: &TenantId) {
453 if let CacheMode::PerTenant { .. } = &self.mode {
454 self.tenant_caches.write().remove(tenant_id.as_str());
455 }
456 }
457
458 pub fn invalidate(&self, key: &StatementKey) {
460 self.global_cache.write().remove(key);
461 }
462
463 pub fn clear(&self) {
465 self.global_cache.write().clear();
466 self.tenant_caches.write().clear();
467 }
468
469 fn evict_lru(&self, cache: &mut HashMap<StatementKey, CacheEntry<S>>) {
471 let lru_key = cache
472 .iter()
473 .min_by_key(|(_, e)| e.meta.last_used)
474 .map(|(k, _)| k.clone());
475
476 if let Some(key) = lru_key {
477 cache.remove(&key);
478 self.stats.record_eviction();
479 }
480 }
481
482 fn evict_lru_tenant(&self, caches: &mut HashMap<String, HashMap<StatementKey, CacheEntry<S>>>) {
484 let lru_tenant = caches
485 .iter()
486 .filter_map(|(tenant, cache)| {
487 cache
488 .values()
489 .map(|e| e.meta.last_used)
490 .max()
491 .map(|last| (tenant.clone(), last))
492 })
493 .min_by_key(|(_, last)| *last)
494 .map(|(tenant, _)| tenant);
495
496 if let Some(tenant) = lru_tenant {
497 caches.remove(&tenant);
498 }
499 }
500}
501
502#[derive(Default)]
506pub struct StatementRegistry {
507 statements: RwLock<HashMap<String, StatementInfo>>,
508}
509
/// Descriptive record for a named SQL statement.
#[derive(Debug, Clone)]
pub struct StatementInfo {
    /// Name of the statement.
    pub name: String,
    /// SQL text of the statement.
    pub sql: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Number of parameters declared for the statement.
    pub param_count: usize,
    /// Flag marking the statement as tenant-scoped.
    pub tenant_scoped: bool,
}
524
525impl StatementRegistry {
526 pub fn new() -> Self {
528 Self::default()
529 }
530
531 pub fn register(&self, info: StatementInfo) {
533 self.statements.write().insert(info.name.clone(), info);
534 }
535
536 pub fn get(&self, name: &str) -> Option<StatementInfo> {
538 self.statements.read().get(name).cloned()
539 }
540
541 pub fn list(&self) -> Vec<StatementInfo> {
543 self.statements.read().values().cloned().collect()
544 }
545
546 pub fn contains(&self, name: &str) -> bool {
548 self.statements.read().contains_key(name)
549 }
550}
551
552pub struct StatementBuilder {
554 name: String,
555 sql: String,
556 description: Option<String>,
557 param_count: usize,
558 tenant_scoped: bool,
559}
560
561impl StatementBuilder {
562 pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
564 Self {
565 name: name.into(),
566 sql: sql.into(),
567 description: None,
568 param_count: 0,
569 tenant_scoped: false,
570 }
571 }
572
573 pub fn description(mut self, desc: impl Into<String>) -> Self {
575 self.description = Some(desc.into());
576 self
577 }
578
579 pub fn params(mut self, count: usize) -> Self {
581 self.param_count = count;
582 self
583 }
584
585 pub fn tenant_scoped(mut self) -> Self {
587 self.tenant_scoped = true;
588 self
589 }
590
591 pub fn build(self) -> StatementInfo {
593 StatementInfo {
594 name: self.name,
595 sql: self.sql,
596 description: self.description,
597 param_count: self.param_count,
598 tenant_scoped: self.tenant_scoped,
599 }
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_sql` keeps the SQL text and generates a `stmt_`-prefixed name.
    #[test]
    fn test_statement_key() {
        let sql = "SELECT * FROM users WHERE id = $1";
        let named = StatementKey::new("find_user", sql);
        let derived = StatementKey::from_sql(sql);

        assert_eq!(named.sql, derived.sql);
        assert!(derived.name.starts_with("stmt_"));
    }

    /// A global cache misses before an insert and hits afterwards.
    #[test]
    fn test_global_cache() {
        let cache = StatementCache::<String>::global(100);
        let key = StatementKey::new("test", "SELECT 1");

        assert!(cache.get(&key).is_none());

        cache.insert(key.clone(), String::from("prepared_handle"));
        assert_eq!(cache.get(&key), Some(String::from("prepared_handle")));
    }

    /// The same key stored for two tenants resolves to each tenant's own handle.
    #[test]
    fn test_per_tenant_cache() {
        let cache = StatementCache::<String>::per_tenant(10, 50);
        let key = StatementKey::new("test", "SELECT 1");
        let first = TenantId::new("tenant-1");
        let second = TenantId::new("tenant-2");

        cache.insert_for_tenant(&first, key.clone(), String::from("handle_1"));
        cache.insert_for_tenant(&second, key.clone(), String::from("handle_2"));

        let from_first = cache.get_for_tenant(&first, &key);
        let from_second = cache.get_for_tenant(&second, &key);
        assert_eq!(from_first, Some(String::from("handle_1")));
        assert_eq!(from_second, Some(String::from("handle_2")));
    }

    /// Inserting three statements into a cache limited to two evicts one.
    #[test]
    fn test_cache_eviction() {
        let cache = StatementCache::<i32>::global(2);

        (0..3).for_each(|i| {
            let key = StatementKey::new(format!("stmt_{}", i), format!("SELECT {}", i));
            cache.insert(key, i);
        });

        assert_eq!(cache.stats().statements_evicted, 1);
    }

    /// A lookup before the insert counts a miss; one after it counts a hit.
    #[test]
    fn test_cache_stats() {
        let cache = StatementCache::<String>::global(100);
        let key = StatementKey::new("test", "SELECT 1");

        let _ = cache.get(&key);
        assert_eq!(cache.stats().misses, 1);

        cache.insert(key.clone(), String::from("handle"));

        let _ = cache.get(&key);
        assert_eq!(cache.stats().hits, 1);
    }

    /// A statement built with the builder can be registered and read back.
    #[test]
    fn test_statement_registry() {
        let registry = StatementRegistry::new();
        let info = StatementBuilder::new("find_user", "SELECT * FROM users WHERE id = $1")
            .description("Find user by ID")
            .params(1)
            .build();

        registry.register(info);

        assert!(registry.contains("find_user"));
        let stored = registry.get("find_user").unwrap();
        assert_eq!(stored.param_count, 1);
    }
}
696