prax_query/tenant/
prepared.rs

1//! Tenant-aware prepared statement caching.
2//!
3//! This module provides efficient prepared statement management for multi-tenant
4//! applications. It supports:
5//!
6//! - **Global statement cache** for RLS-based isolation (same statements work for all tenants)
7//! - **Per-tenant statement cache** for schema-based isolation
//! - **Explicit statement invalidation** (per statement or per tenant), e.g. after schema changes
9//! - **LRU eviction** with configurable limits
10//!
11//! # Performance Benefits
12//!
13//! Prepared statements provide significant performance benefits:
14//! - **Query planning cached** - Database doesn't re-plan the query
15//! - **Parameter binding optimized** - Type checking done once
16//! - **Network efficiency** - Only parameters sent, not full SQL
17//!
18//! With RLS, the same prepared statement works for all tenants because the
19//! tenant filtering happens via session variables, not query changes.
20//!
21//! # Example
22//!
23//! ```rust,ignore
//! use prax_query::tenant::prepared::{CacheMode, StatementCache, StatementKey};
//!
//! // For RLS-based tenancy (shared statements)
//! let cache = StatementCache::new(CacheMode::Global { max_statements: 1000 });
//!
//! // For schema-based tenancy (per-tenant statements)
//! let cache = StatementCache::new(CacheMode::PerTenant {
//!     max_tenants: 100,
//!     statements_per_tenant: 100,
//! });
//!
//! // Look up a statement, preparing and caching it on a miss.
//! // `get_for_tenant`/`insert_for_tenant` work in every mode; in global mode
//! // they use the shared cache.
//! let key = StatementKey::new("find_user_by_id", "SELECT * FROM users WHERE id = $1");
//! let stmt = match cache.get_for_tenant(&tenant_id, &key) {
//!     Some(stmt) => stmt,
//!     None => {
//!         let stmt = conn.prepare(&key.sql).await?;
//!         cache.insert_for_tenant(&tenant_id, key, stmt.clone());
//!         stmt
//!     }
//! };
39//! ```
40
41use parking_lot::RwLock;
42use std::collections::HashMap;
43use std::hash::{Hash, Hasher};
44use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
45use std::time::Instant;
46
47use super::context::TenantId;
48
/// How a `StatementCache` partitions its prepared statements.
#[derive(Debug, Clone)]
pub enum CacheMode {
    /// One cache shared by every tenant (suited to RLS-based isolation, where
    /// the SQL text is identical for all tenants).
    Global {
        /// Upper bound on the number of cached statements.
        max_statements: usize,
    },

    /// A separate cache for each tenant (suited to schema-based isolation,
    /// where each tenant's statements target a different schema).
    PerTenant {
        /// Upper bound on the number of tenants that have a cache.
        max_tenants: usize,
        /// Upper bound on the number of cached statements for each tenant.
        statements_per_tenant: usize,
    },

    /// Caching is turned off; nothing is stored.
    Disabled,
}

/// The default is a single global cache holding up to 1,000 statements.
impl Default for CacheMode {
    fn default() -> Self {
        CacheMode::Global {
            max_statements: 1_000,
        }
    }
}
79
/// Identifies a prepared statement by its logical name plus its SQL text.
///
/// Equality and hashing cover both fields, so two keys with the same SQL but
/// different names are distinct cache entries.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StatementKey {
    /// Logical name for the statement (e.g., "find_user_by_id").
    pub name: String,
    /// SQL query text.
    pub sql: String,
}

impl StatementKey {
    /// Build a key from an explicit name and SQL text.
    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
        let name = name.into();
        let sql = sql.into();
        StatementKey { name, sql }
    }

    /// Build a key whose name is `stmt_` followed by the lowercase hex hash of
    /// the SQL text (see `hash_sql`).
    pub fn from_sql(sql: impl Into<String>) -> Self {
        let sql: String = sql.into();
        let digest = hash_sql(&sql);
        StatementKey {
            name: format!("stmt_{:x}", digest),
            sql,
        }
    }
}

/// Hash SQL text with the standard library's `DefaultHasher` to derive a
/// statement name.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to stay the same
/// across Rust releases, so derived names should not be persisted across builds.
fn hash_sql(sql: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(sql, &mut state);
    state.finish()
}
113
/// Usage metadata tracked for every cached statement.
#[derive(Debug, Clone)]
pub struct StatementMeta {
    /// Moment the cache entry was created.
    pub prepared_at: Instant,
    /// How many executions have been recorded.
    pub execution_count: u64,
    /// Moment of the most recently recorded execution (the creation time
    /// until the first execution is recorded).
    pub last_used: Instant,
    /// Running mean of the recorded execution durations, in microseconds.
    pub avg_execution_us: f64,
}

impl StatementMeta {
    /// Metadata for an entry created right now, with no executions yet.
    fn new() -> Self {
        let created = Instant::now();
        StatementMeta {
            prepared_at: created,
            execution_count: 0,
            last_used: created,
            avg_execution_us: 0.0,
        }
    }

    /// Count one execution lasting `duration_us` microseconds, refresh
    /// `last_used`, and fold the duration into the running mean.
    fn record_execution(&mut self, duration_us: f64) {
        self.last_used = Instant::now();
        self.execution_count += 1;

        // Incremental mean: new = old * (n - 1) / n + sample / n.
        let n = self.execution_count as f64;
        let carried = self.avg_execution_us * (n - 1.0) / n;
        self.avg_execution_us = carried + duration_us / n;
    }
}
149
150/// A cached statement entry.
151struct CacheEntry<S> {
152    /// The prepared statement handle.
153    statement: S,
154    /// Metadata about the statement.
155    meta: StatementMeta,
156}
157
158impl<S> CacheEntry<S> {
159    fn new(statement: S) -> Self {
160        Self {
161            statement,
162            meta: StatementMeta::new(),
163        }
164    }
165}
166
/// Plain-value copy of the cache counters (see `AtomicCacheStats::snapshot`).
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that found a cached statement.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Statements inserted into the cache.
    pub statements_prepared: u64,
    /// Statements removed to make room for others.
    pub statements_evicted: u64,
    /// Number of statements cached when the snapshot was taken.
    pub size: usize,
    /// Rough estimate of preparation time avoided by hits, in milliseconds.
    pub time_saved_ms: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
195
196/// Thread-safe cache statistics.
197pub struct AtomicCacheStats {
198    hits: AtomicU64,
199    misses: AtomicU64,
200    statements_prepared: AtomicU64,
201    statements_evicted: AtomicU64,
202    size: AtomicUsize,
203    time_saved_ms: AtomicU64,
204}
205
206impl Default for AtomicCacheStats {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212impl AtomicCacheStats {
213    /// Create new stats.
214    pub fn new() -> Self {
215        Self {
216            hits: AtomicU64::new(0),
217            misses: AtomicU64::new(0),
218            statements_prepared: AtomicU64::new(0),
219            statements_evicted: AtomicU64::new(0),
220            size: AtomicUsize::new(0),
221            time_saved_ms: AtomicU64::new(0),
222        }
223    }
224
225    #[inline]
226    pub fn record_hit(&self) {
227        self.hits.fetch_add(1, Ordering::Relaxed);
228    }
229
230    #[inline]
231    pub fn record_miss(&self) {
232        self.misses.fetch_add(1, Ordering::Relaxed);
233    }
234
235    #[inline]
236    pub fn record_prepare(&self) {
237        self.statements_prepared.fetch_add(1, Ordering::Relaxed);
238    }
239
240    #[inline]
241    pub fn record_eviction(&self) {
242        self.statements_evicted.fetch_add(1, Ordering::Relaxed);
243    }
244
245    #[inline]
246    pub fn set_size(&self, size: usize) {
247        self.size.store(size, Ordering::Relaxed);
248    }
249
250    #[inline]
251    pub fn add_time_saved(&self, ms: u64) {
252        self.time_saved_ms.fetch_add(ms, Ordering::Relaxed);
253    }
254
255    /// Get a snapshot.
256    pub fn snapshot(&self) -> CacheStats {
257        CacheStats {
258            hits: self.hits.load(Ordering::Relaxed),
259            misses: self.misses.load(Ordering::Relaxed),
260            statements_prepared: self.statements_prepared.load(Ordering::Relaxed),
261            statements_evicted: self.statements_evicted.load(Ordering::Relaxed),
262            size: self.size.load(Ordering::Relaxed),
263            time_saved_ms: self.time_saved_ms.load(Ordering::Relaxed),
264        }
265    }
266}
267
268/// Generic statement cache that works with any statement type.
269pub struct StatementCache<S> {
270    mode: CacheMode,
271    /// Global cache (for CacheMode::Global).
272    global_cache: RwLock<HashMap<StatementKey, CacheEntry<S>>>,
273    /// Per-tenant caches (for CacheMode::PerTenant).
274    tenant_caches: RwLock<HashMap<String, HashMap<StatementKey, CacheEntry<S>>>>,
275    /// Statistics.
276    stats: AtomicCacheStats,
277}
278
279impl<S: Clone> StatementCache<S> {
280    /// Create a new statement cache.
281    pub fn new(mode: CacheMode) -> Self {
282        let capacity = match &mode {
283            CacheMode::Global { max_statements } => *max_statements,
284            CacheMode::PerTenant { max_tenants, .. } => *max_tenants,
285            CacheMode::Disabled => 0,
286        };
287
288        Self {
289            mode,
290            global_cache: RwLock::new(HashMap::with_capacity(capacity)),
291            tenant_caches: RwLock::new(HashMap::with_capacity(capacity)),
292            stats: AtomicCacheStats::new(),
293        }
294    }
295
296    /// Create a global cache with the given max size.
297    pub fn global(max_statements: usize) -> Self {
298        Self::new(CacheMode::Global { max_statements })
299    }
300
301    /// Create a per-tenant cache.
302    pub fn per_tenant(max_tenants: usize, statements_per_tenant: usize) -> Self {
303        Self::new(CacheMode::PerTenant {
304            max_tenants,
305            statements_per_tenant,
306        })
307    }
308
309    /// Get the cache mode.
310    pub fn mode(&self) -> &CacheMode {
311        &self.mode
312    }
313
314    /// Get cache statistics.
315    pub fn stats(&self) -> CacheStats {
316        let size = match &self.mode {
317            CacheMode::Global { .. } => self.global_cache.read().len(),
318            CacheMode::PerTenant { .. } => {
319                self.tenant_caches.read().values().map(|c| c.len()).sum()
320            }
321            CacheMode::Disabled => 0,
322        };
323        self.stats.set_size(size);
324        self.stats.snapshot()
325    }
326
327    /// Get a cached statement (global mode).
328    pub fn get(&self, key: &StatementKey) -> Option<S> {
329        if matches!(self.mode, CacheMode::Disabled) {
330            return None;
331        }
332
333        let cache = self.global_cache.read();
334        if let Some(entry) = cache.get(key) {
335            self.stats.record_hit();
336            // Estimate 1ms saved per cache hit (prepare time avoided)
337            self.stats.add_time_saved(1);
338            Some(entry.statement.clone())
339        } else {
340            self.stats.record_miss();
341            None
342        }
343    }
344
345    /// Get a cached statement for a tenant.
346    pub fn get_for_tenant(&self, tenant_id: &TenantId, key: &StatementKey) -> Option<S> {
347        match &self.mode {
348            CacheMode::Disabled => None,
349            CacheMode::Global { .. } => self.get(key),
350            CacheMode::PerTenant { .. } => {
351                let caches = self.tenant_caches.read();
352                if let Some(cache) = caches.get(tenant_id.as_str()) {
353                    if let Some(entry) = cache.get(key) {
354                        self.stats.record_hit();
355                        self.stats.add_time_saved(1);
356                        return Some(entry.statement.clone());
357                    }
358                }
359                self.stats.record_miss();
360                None
361            }
362        }
363    }
364
365    /// Insert a statement into the global cache.
366    pub fn insert(&self, key: StatementKey, statement: S) {
367        if matches!(self.mode, CacheMode::Disabled) {
368            return;
369        }
370
371        let max = match &self.mode {
372            CacheMode::Global { max_statements } => *max_statements,
373            _ => return self.insert_for_tenant(&TenantId::new("global"), key, statement),
374        };
375
376        let mut cache = self.global_cache.write();
377
378        // Evict if necessary
379        if cache.len() >= max && !cache.contains_key(&key) {
380            self.evict_lru(&mut cache);
381        }
382
383        cache.insert(key, CacheEntry::new(statement));
384        self.stats.record_prepare();
385    }
386
387    /// Insert a statement for a specific tenant.
388    pub fn insert_for_tenant(&self, tenant_id: &TenantId, key: StatementKey, statement: S) {
389        match &self.mode {
390            CacheMode::Disabled => {}
391            CacheMode::Global { .. } => self.insert(key, statement),
392            CacheMode::PerTenant {
393                max_tenants,
394                statements_per_tenant,
395            } => {
396                let mut caches = self.tenant_caches.write();
397
398                // Evict tenant if too many
399                if !caches.contains_key(tenant_id.as_str()) && caches.len() >= *max_tenants {
400                    self.evict_lru_tenant(&mut caches);
401                }
402
403                let cache = caches
404                    .entry(tenant_id.as_str().to_string())
405                    .or_insert_with(|| HashMap::with_capacity(*statements_per_tenant));
406
407                // Evict statement if too many
408                if cache.len() >= *statements_per_tenant && !cache.contains_key(&key) {
409                    self.evict_lru(cache);
410                }
411
412                cache.insert(key, CacheEntry::new(statement));
413                self.stats.record_prepare();
414            }
415        }
416    }
417
418    /// Record an execution for statistics.
419    pub fn record_execution(&self, key: &StatementKey, duration_us: f64) {
420        if matches!(self.mode, CacheMode::Disabled) {
421            return;
422        }
423
424        let mut cache = self.global_cache.write();
425        if let Some(entry) = cache.get_mut(key) {
426            entry.meta.record_execution(duration_us);
427        }
428    }
429
430    /// Record an execution for a tenant.
431    pub fn record_tenant_execution(
432        &self,
433        tenant_id: &TenantId,
434        key: &StatementKey,
435        duration_us: f64,
436    ) {
437        match &self.mode {
438            CacheMode::Disabled => {}
439            CacheMode::Global { .. } => self.record_execution(key, duration_us),
440            CacheMode::PerTenant { .. } => {
441                let mut caches = self.tenant_caches.write();
442                if let Some(cache) = caches.get_mut(tenant_id.as_str()) {
443                    if let Some(entry) = cache.get_mut(key) {
444                        entry.meta.record_execution(duration_us);
445                    }
446                }
447            }
448        }
449    }
450
451    /// Invalidate all statements for a tenant.
452    pub fn invalidate_tenant(&self, tenant_id: &TenantId) {
453        if let CacheMode::PerTenant { .. } = &self.mode {
454            self.tenant_caches.write().remove(tenant_id.as_str());
455        }
456    }
457
458    /// Invalidate a specific statement globally.
459    pub fn invalidate(&self, key: &StatementKey) {
460        self.global_cache.write().remove(key);
461    }
462
463    /// Clear all cached statements.
464    pub fn clear(&self) {
465        self.global_cache.write().clear();
466        self.tenant_caches.write().clear();
467    }
468
469    /// Evict LRU statement from a cache.
470    fn evict_lru(&self, cache: &mut HashMap<StatementKey, CacheEntry<S>>) {
471        let lru_key = cache
472            .iter()
473            .min_by_key(|(_, e)| e.meta.last_used)
474            .map(|(k, _)| k.clone());
475
476        if let Some(key) = lru_key {
477            cache.remove(&key);
478            self.stats.record_eviction();
479        }
480    }
481
482    /// Evict LRU tenant cache.
483    fn evict_lru_tenant(&self, caches: &mut HashMap<String, HashMap<StatementKey, CacheEntry<S>>>) {
484        let lru_tenant = caches
485            .iter()
486            .filter_map(|(tenant, cache)| {
487                cache
488                    .values()
489                    .map(|e| e.meta.last_used)
490                    .max()
491                    .map(|last| (tenant.clone(), last))
492            })
493            .min_by_key(|(_, last)| *last)
494            .map(|(tenant, _)| tenant);
495
496        if let Some(tenant) = lru_tenant {
497            caches.remove(&tenant);
498        }
499    }
500}
501
502/// A prepared statement registry that tracks statements by name.
503///
504/// This is useful for debugging and monitoring which statements are cached.
505#[derive(Default)]
506pub struct StatementRegistry {
507    statements: RwLock<HashMap<String, StatementInfo>>,
508}
509
/// Description of one registered prepared statement.
#[derive(Debug, Clone)]
pub struct StatementInfo {
    /// Name the statement is registered under (the registry key).
    pub name: String,
    /// SQL text of the statement.
    pub sql: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Number of bind parameters the statement expects.
    pub param_count: usize,
    /// `true` when the statement is marked as tenant-scoped.
    pub tenant_scoped: bool,
}
524
525impl StatementRegistry {
526    /// Create a new registry.
527    pub fn new() -> Self {
528        Self::default()
529    }
530
531    /// Register a statement.
532    pub fn register(&self, info: StatementInfo) {
533        self.statements.write().insert(info.name.clone(), info);
534    }
535
536    /// Get a statement by name.
537    pub fn get(&self, name: &str) -> Option<StatementInfo> {
538        self.statements.read().get(name).cloned()
539    }
540
541    /// List all registered statements.
542    pub fn list(&self) -> Vec<StatementInfo> {
543        self.statements.read().values().cloned().collect()
544    }
545
546    /// Check if a statement is registered.
547    pub fn contains(&self, name: &str) -> bool {
548        self.statements.read().contains_key(name)
549    }
550}
551
552/// Builder for statement registration.
553pub struct StatementBuilder {
554    name: String,
555    sql: String,
556    description: Option<String>,
557    param_count: usize,
558    tenant_scoped: bool,
559}
560
561impl StatementBuilder {
562    /// Create a new builder.
563    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
564        Self {
565            name: name.into(),
566            sql: sql.into(),
567            description: None,
568            param_count: 0,
569            tenant_scoped: false,
570        }
571    }
572
573    /// Set description.
574    pub fn description(mut self, desc: impl Into<String>) -> Self {
575        self.description = Some(desc.into());
576        self
577    }
578
579    /// Set parameter count.
580    pub fn params(mut self, count: usize) -> Self {
581        self.param_count = count;
582        self
583    }
584
585    /// Mark as tenant-scoped.
586    pub fn tenant_scoped(mut self) -> Self {
587        self.tenant_scoped = true;
588        self
589    }
590
591    /// Build the statement info.
592    pub fn build(self) -> StatementInfo {
593        StatementInfo {
594            name: self.name,
595            sql: self.sql,
596            description: self.description,
597            param_count: self.param_count,
598            tenant_scoped: self.tenant_scoped,
599        }
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    const FIND_USER_SQL: &str = "SELECT * FROM users WHERE id = $1";

    #[test]
    fn test_statement_key() {
        let named = StatementKey::new("find_user", FIND_USER_SQL);
        let derived = StatementKey::from_sql(FIND_USER_SQL);

        assert_eq!(named.sql, derived.sql);
        assert!(derived.name.starts_with("stmt_"));
    }

    #[test]
    fn test_global_cache() {
        let cache = StatementCache::<String>::global(100);
        let key = StatementKey::new("test", "SELECT 1");

        assert!(cache.get(&key).is_none());

        cache.insert(key.clone(), String::from("prepared_handle"));
        assert_eq!(cache.get(&key).as_deref(), Some("prepared_handle"));
    }

    #[test]
    fn test_per_tenant_cache() {
        let cache = StatementCache::<String>::per_tenant(10, 50);
        let key = StatementKey::new("test", "SELECT 1");
        let tenants = [
            (TenantId::new("tenant-1"), "handle_1"),
            (TenantId::new("tenant-2"), "handle_2"),
        ];

        for (tenant, handle) in &tenants {
            cache.insert_for_tenant(tenant, key.clone(), handle.to_string());
        }

        // Each tenant must see only its own handle for the shared key.
        for (tenant, handle) in &tenants {
            assert_eq!(cache.get_for_tenant(tenant, &key).as_deref(), Some(*handle));
        }
    }

    #[test]
    fn test_cache_eviction() {
        let cache = StatementCache::<i32>::global(2);

        // Three distinct keys into a two-slot cache forces exactly one eviction.
        (0..3).for_each(|i| {
            cache.insert(
                StatementKey::new(format!("stmt_{}", i), format!("SELECT {}", i)),
                i,
            )
        });

        assert_eq!(cache.stats().statements_evicted, 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = StatementCache::<String>::global(100);
        let key = StatementKey::new("test", "SELECT 1");

        let _ = cache.get(&key);
        assert_eq!(cache.stats().misses, 1);

        cache.insert(key.clone(), String::from("handle"));

        let _ = cache.get(&key);
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_statement_registry() {
        let registry = StatementRegistry::new();
        let info = StatementBuilder::new("find_user", FIND_USER_SQL)
            .description("Find user by ID")
            .params(1)
            .build();
        registry.register(info);

        assert!(registry.contains("find_user"));
        let found = registry
            .get("find_user")
            .expect("statement was registered above");
        assert_eq!(found.param_count, 1);
    }
}
696