prax_query/tenant/cache.rs

//! High-performance tenant caching with TTL and background refresh.
//!
//! This module provides an efficient caching layer for tenant lookups with:
//!
//! - **TTL-based expiration** with configurable durations
//! - **LRU eviction** when the cache is full
//! - **Background refresh** to avoid cache stampedes
//! - **Negative caching** to prevent repeated lookups of invalid tenants
//! - **Metrics** for monitoring cache performance
//!
//! # Example
//!
//! ```rust,ignore
//! use std::time::Duration;
//!
//! use prax_query::tenant::cache::{CacheConfig, TenantCache};
//! use prax_query::tenant::context::TenantId;
//!
//! let cache = TenantCache::new(CacheConfig {
//!     max_entries: 10_000,
//!     ttl: Duration::from_secs(300),
//!     negative_ttl: Duration::from_secs(60),
//!     ..Default::default()
//! });
//!
//! // Get or fetch the tenant; on a cache miss the closure runs and must
//! // return an `Option<TenantContext>`.
//! let tenant_id = TenantId::new("tenant-123");
//! let ctx = cache.get_or_fetch(&tenant_id, || async {
//!     // Fetch from the database and map the row into an Option<TenantContext>
//!     db.query("SELECT * FROM tenants WHERE id = $1", &[&"tenant-123"]).await
//! }).await;
//! ```

use parking_lot::RwLock;
use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};

use super::context::{TenantContext, TenantId};

/// Configuration for the tenant cache.
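///
/// All fields have defaults; `CacheConfig::new` plus the builder methods below
/// cover the common cases. A sketch with illustrative values:
///
/// ```rust,ignore
/// use std::time::Duration;
/// use prax_query::tenant::cache::CacheConfig;
///
/// let config = CacheConfig::new(5_000)
///     .with_ttl(Duration::from_secs(120))
///     .with_negative_ttl(Duration::from_secs(30))
///     .with_refresh_threshold(0.9);
/// ```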
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries in the cache.
    pub max_entries: usize,
    /// Time-to-live for cached entries.
    pub ttl: Duration,
    /// Time-to-live for negative cache entries (tenant not found).
    pub negative_ttl: Duration,
    /// Enable background refresh before TTL expires.
    pub background_refresh: bool,
    /// Fraction of the TTL after which a background refresh is triggered
    /// (e.g., 0.8 = refresh once 80% of the TTL has elapsed).
    pub refresh_threshold: f64,
    /// Enable metrics collection.
    pub enable_metrics: bool,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            max_entries: 10_000,
            ttl: Duration::from_secs(300),           // 5 minutes
            negative_ttl: Duration::from_secs(60),   // 1 minute
            background_refresh: true,
            refresh_threshold: 0.8,
            enable_metrics: true,
        }
    }
}

impl CacheConfig {
    /// Create a new config with the given max entries.
    pub fn new(max_entries: usize) -> Self {
        Self {
            max_entries,
            ..Default::default()
        }
    }

    /// Set the TTL.
    pub fn with_ttl(mut self, ttl: Duration) -> Self {
        self.ttl = ttl;
        self
    }

    /// Set the negative TTL.
    pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
        self.negative_ttl = ttl;
        self
    }

    /// Disable background refresh.
    pub fn without_background_refresh(mut self) -> Self {
        self.background_refresh = false;
        self
    }

    /// Set the refresh threshold.
    pub fn with_refresh_threshold(mut self, threshold: f64) -> Self {
        self.refresh_threshold = threshold.clamp(0.1, 0.99);
        self
    }

    /// Disable metrics.
    pub fn without_metrics(mut self) -> Self {
        self.enable_metrics = false;
        self
    }
}

/// A cached tenant entry.
#[derive(Debug)]
struct CacheEntry {
    /// The cached context (None = negative cache).
    context: Option<TenantContext>,
    /// When this entry was created.
    created_at: Instant,
    /// When this entry expires.
    expires_at: Instant,
    /// Whether a background refresh is in progress.
    refreshing: bool,
    /// Access count used for eviction ordering, bumped on each non-expired lookup.
    access_count: AtomicU64,
}

impl CacheEntry {
    /// Create a positive cache entry.
    fn positive(context: TenantContext, ttl: Duration) -> Self {
        let now = Instant::now();
        Self {
            context: Some(context),
            created_at: now,
            expires_at: now + ttl,
            refreshing: false,
            access_count: AtomicU64::new(1),
        }
    }

    /// Create a negative cache entry.
    fn negative(ttl: Duration) -> Self {
        let now = Instant::now();
        Self {
            context: None,
            created_at: now,
            expires_at: now + ttl,
            refreshing: false,
            access_count: AtomicU64::new(1),
        }
    }

    /// Check if the entry is expired.
    fn is_expired(&self) -> bool {
        Instant::now() >= self.expires_at
    }

    /// Check if the entry should be refreshed.
    fn should_refresh(&self, threshold: f64) -> bool {
        if self.refreshing {
            return false;
        }

        let ttl = self.expires_at.duration_since(self.created_at);
        let elapsed = self.created_at.elapsed();
        let threshold_duration = ttl.mul_f64(threshold);

        elapsed >= threshold_duration
    }

    /// Get remaining TTL.
    fn remaining_ttl(&self) -> Duration {
        self.expires_at.saturating_duration_since(Instant::now())
    }
}

/// Cache metrics.
#[derive(Debug, Clone, Default)]
pub struct CacheMetrics {
    /// Total cache hits.
    pub hits: u64,
    /// Total cache misses.
    pub misses: u64,
    /// Negative cache hits.
    pub negative_hits: u64,
    /// Evictions due to capacity.
    pub evictions: u64,
    /// Evictions due to TTL expiration.
    pub expirations: u64,
    /// Background refreshes triggered.
    pub background_refreshes: u64,
    /// Current cache size.
    pub size: usize,
}

impl CacheMetrics {
    /// Calculate hit rate.
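    ///
    /// The hit rate is `hits / (hits + misses)`, or `0.0` when no hits or
    /// misses have been recorded yet.
    ///
    /// ```rust,ignore
    /// let m = CacheMetrics { hits: 90, misses: 10, ..Default::default() };
    /// assert!((m.hit_rate() - 0.9).abs() < 1e-9);
    /// ```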
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}

/// Thread-safe atomic metrics.
pub struct AtomicCacheMetrics {
    hits: AtomicU64,
    misses: AtomicU64,
    negative_hits: AtomicU64,
    evictions: AtomicU64,
    expirations: AtomicU64,
    background_refreshes: AtomicU64,
}

impl Default for AtomicCacheMetrics {
    fn default() -> Self {
        Self::new()
    }
}

impl AtomicCacheMetrics {
    /// Create new atomic metrics.
    pub fn new() -> Self {
        Self {
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            negative_hits: AtomicU64::new(0),
            evictions: AtomicU64::new(0),
            expirations: AtomicU64::new(0),
            background_refreshes: AtomicU64::new(0),
        }
    }

    /// Record a hit.
    #[inline]
    pub fn record_hit(&self) {
        self.hits.fetch_add(1, Ordering::Relaxed);
    }

    /// Record a miss.
    #[inline]
    pub fn record_miss(&self) {
        self.misses.fetch_add(1, Ordering::Relaxed);
    }

    /// Record a negative hit.
    #[inline]
    pub fn record_negative_hit(&self) {
        self.negative_hits.fetch_add(1, Ordering::Relaxed);
    }

    /// Record an eviction.
    #[inline]
    pub fn record_eviction(&self) {
        self.evictions.fetch_add(1, Ordering::Relaxed);
    }

    /// Record an expiration.
    #[inline]
    pub fn record_expiration(&self) {
        self.expirations.fetch_add(1, Ordering::Relaxed);
    }

    /// Record a background refresh.
    #[inline]
    pub fn record_background_refresh(&self) {
        self.background_refreshes.fetch_add(1, Ordering::Relaxed);
    }

    /// Get a snapshot of the metrics.
    pub fn snapshot(&self, size: usize) -> CacheMetrics {
        CacheMetrics {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            negative_hits: self.negative_hits.load(Ordering::Relaxed),
            evictions: self.evictions.load(Ordering::Relaxed),
            expirations: self.expirations.load(Ordering::Relaxed),
            background_refreshes: self.background_refreshes.load(Ordering::Relaxed),
            size,
        }
    }

    /// Reset all metrics.
    pub fn reset(&self) {
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
        self.negative_hits.store(0, Ordering::Relaxed);
        self.evictions.store(0, Ordering::Relaxed);
        self.expirations.store(0, Ordering::Relaxed);
        self.background_refreshes.store(0, Ordering::Relaxed);
    }
}

/// Result of a cache lookup.
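///
/// # Example
///
/// A sketch of handling each variant returned by [`TenantCache::lookup`]; the
/// `serve_with`, `reject_unknown_tenant`, and `load_and_insert` helpers are
/// hypothetical:
///
/// ```rust,ignore
/// match cache.lookup(&tenant_id) {
///     // Fresh entry: use it directly.
///     CacheLookup::Hit(ctx) => serve_with(ctx),
///     // Still usable but nearing its TTL; serve it and refresh in the background.
///     CacheLookup::Stale(ctx) => serve_with(ctx),
///     // The tenant is known not to exist; fail fast without hitting the database.
///     CacheLookup::NegativeHit => reject_unknown_tenant(),
///     // Not cached (or expired); load from the source of truth.
///     CacheLookup::Miss => load_and_insert(),
/// }
/// ```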
#[derive(Debug, Clone)]
pub enum CacheLookup {
    /// Found valid entry.
    Hit(TenantContext),
    /// Found negative entry (tenant doesn't exist).
    NegativeHit,
    /// Entry not found or expired.
    Miss,
    /// Entry found but should be refreshed.
    Stale(TenantContext),
}

/// High-performance tenant cache.
pub struct TenantCache {
    config: CacheConfig,
    entries: RwLock<HashMap<String, CacheEntry>>,
    metrics: AtomicCacheMetrics,
}

impl TenantCache {
    /// Create a new tenant cache with the given config.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            entries: RwLock::new(HashMap::with_capacity(config.max_entries)),
            config,
            metrics: AtomicCacheMetrics::new(),
        }
    }

    /// Create with default config.
    pub fn default_config() -> Self {
        Self::new(CacheConfig::default())
    }

    /// Get the cache config.
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }

    /// Look up a tenant in the cache.
    pub fn lookup(&self, tenant_id: &TenantId) -> CacheLookup {
        let key = tenant_id.as_str();

        let entries = self.entries.read();
        match entries.get(key) {
            Some(entry) if entry.is_expired() => {
                self.metrics.record_expiration();
                CacheLookup::Miss
            }
            Some(entry) => {
                // Bump the access count used for eviction ordering; Relaxed is
                // sufficient because this is only an approximate recency signal.
                entry.access_count.fetch_add(1, Ordering::Relaxed);

                if entry.context.is_none() {
                    self.metrics.record_negative_hit();
                    CacheLookup::NegativeHit
                } else if self.config.background_refresh
                    && entry.should_refresh(self.config.refresh_threshold)
                {
                    self.metrics.record_hit();
                    CacheLookup::Stale(entry.context.clone().unwrap())
                } else {
                    self.metrics.record_hit();
                    CacheLookup::Hit(entry.context.clone().unwrap())
                }
            }
            None => {
                self.metrics.record_miss();
                CacheLookup::Miss
            }
        }
    }

    /// Insert a tenant into the cache.
    pub fn insert(&self, tenant_id: TenantId, context: TenantContext) {
        let key = tenant_id.as_str().to_string();
        let entry = CacheEntry::positive(context, self.config.ttl);

        let mut entries = self.entries.write();

        // Check capacity and evict if necessary
        if entries.len() >= self.config.max_entries && !entries.contains_key(&key) {
            self.evict_one(&mut entries);
        }

        entries.insert(key, entry);
    }

    /// Insert a negative entry (tenant not found).
    pub fn insert_negative(&self, tenant_id: TenantId) {
        let key = tenant_id.as_str().to_string();
        let entry = CacheEntry::negative(self.config.negative_ttl);

        let mut entries = self.entries.write();

        if entries.len() >= self.config.max_entries && !entries.contains_key(&key) {
            self.evict_one(&mut entries);
        }

        entries.insert(key, entry);
    }

    /// Invalidate a specific tenant.
    pub fn invalidate(&self, tenant_id: &TenantId) {
        self.entries.write().remove(tenant_id.as_str());
    }

    /// Invalidate all tenants matching a predicate.
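    ///
    /// The predicate receives the cache key and the cached context; negative
    /// entries are always retained. For example, dropping every tenant whose
    /// key carries a given prefix:
    ///
    /// ```rust,ignore
    /// cache.invalidate_if(|key, _ctx| key.starts_with("staging-"));
    /// ```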
    pub fn invalidate_if<F>(&self, predicate: F)
    where
        F: Fn(&str, &TenantContext) -> bool,
    {
        let mut entries = self.entries.write();
        entries.retain(|k, v| {
            if let Some(ref ctx) = v.context {
                !predicate(k, ctx)
            } else {
                true
            }
        });
    }

    /// Clear the entire cache.
    pub fn clear(&self) {
        self.entries.write().clear();
    }

    /// Get the current cache size.
    pub fn len(&self) -> usize {
        self.entries.read().len()
    }

    /// Check if the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get cache metrics.
    pub fn metrics(&self) -> CacheMetrics {
        self.metrics.snapshot(self.len())
    }

    /// Reset metrics.
    pub fn reset_metrics(&self) {
        self.metrics.reset();
    }

    /// Evict expired entries.
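    ///
    /// Expired entries are not removed on lookup, so long-running processes
    /// should sweep periodically. A sketch, assuming a Tokio runtime and the
    /// cache shared behind an `Arc`:
    ///
    /// ```rust,ignore
    /// use std::sync::Arc;
    /// use std::time::Duration;
    ///
    /// let cache = Arc::new(TenantCache::default_config());
    /// let sweeper = Arc::clone(&cache);
    /// tokio::spawn(async move {
    ///     let mut tick = tokio::time::interval(Duration::from_secs(60));
    ///     loop {
    ///         tick.tick().await;
    ///         sweeper.evict_expired();
    ///     }
    /// });
    /// ```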
    pub fn evict_expired(&self) -> usize {
        let mut entries = self.entries.write();
        let before = entries.len();

        entries.retain(|_, entry| !entry.is_expired());

        let evicted = before - entries.len();
        for _ in 0..evicted {
            self.metrics.record_expiration();
        }
        evicted
    }

    /// Mark an entry as refreshing (to prevent thundering herd).
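    ///
    /// Returns `true` only for the first caller while the entry is stale, so a
    /// single task performs the reload and everyone else keeps serving the
    /// cached context. A sketch of the flow, assuming a Tokio runtime, the
    /// cache behind an `Arc`, and a hypothetical `reload_tenant` loader:
    ///
    /// ```rust,ignore
    /// // `cache` is an Arc<TenantCache>; `reload_tenant` is a hypothetical loader.
    /// if let CacheLookup::Stale(ctx) = cache.lookup(&tenant_id) {
    ///     if cache.mark_refreshing(&tenant_id) {
    ///         let (cache, id) = (cache.clone(), tenant_id.clone());
    ///         tokio::spawn(async move {
    ///             if let Some(fresh) = reload_tenant(&id).await {
    ///                 cache.complete_refresh(id, fresh);
    ///             }
    ///         });
    ///     }
    ///     // Serve `ctx` (the stale context) while the refresh runs.
    /// }
    /// ```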
    pub fn mark_refreshing(&self, tenant_id: &TenantId) -> bool {
        let mut entries = self.entries.write();
        if let Some(entry) = entries.get_mut(tenant_id.as_str()) {
            if !entry.refreshing {
                entry.refreshing = true;
                self.metrics.record_background_refresh();
                return true;
            }
        }
        false
    }

    /// Complete a refresh with a new context.
    pub fn complete_refresh(&self, tenant_id: TenantId, context: TenantContext) {
        let key = tenant_id.as_str().to_string();
        let entry = CacheEntry::positive(context, self.config.ttl);

        self.entries.write().insert(key, entry);
    }

    /// Get or fetch a tenant, using the cache.
    pub async fn get_or_fetch<F, Fut>(
        &self,
        tenant_id: &TenantId,
        fetch: F,
    ) -> Option<TenantContext>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Option<TenantContext>>,
    {
        match self.lookup(tenant_id) {
            CacheLookup::Hit(ctx) => Some(ctx),
            CacheLookup::NegativeHit => None,
            CacheLookup::Stale(ctx) => {
                // Return the stale data; a background refresh can be triggered
                // separately via `mark_refreshing` / `complete_refresh`.
                Some(ctx)
            }
            CacheLookup::Miss => {
                // Fetch from source
                match fetch().await {
                    Some(ctx) => {
                        self.insert(tenant_id.clone(), ctx.clone());
                        Some(ctx)
                    }
                    None => {
                        self.insert_negative(tenant_id.clone());
                        None
                    }
                }
            }
        }
    }

    /// Evict one entry (LRU).
    fn evict_one(&self, entries: &mut HashMap<String, CacheEntry>) {
        // First try to evict expired entries
        let expired_key = entries
            .iter()
            .find(|(_, e)| e.is_expired())
            .map(|(k, _)| k.clone());

        if let Some(key) = expired_key {
            entries.remove(&key);
            self.metrics.record_expiration();
            return;
        }

        // Otherwise evict the least-used entry (lowest access count)
        let lru_key = entries
            .iter()
            .min_by_key(|(_, e)| e.access_count.load(Ordering::Relaxed))
            .map(|(k, _)| k.clone());

        if let Some(key) = lru_key {
            entries.remove(&key);
            self.metrics.record_eviction();
        }
    }
}

/// A sharded cache for high-concurrency scenarios.
///
/// Uses multiple shards to reduce lock contention under heavy load.
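///
/// # Example
///
/// A sketch; the shard count and capacity are illustrative:
///
/// ```rust,ignore
/// let cache = ShardedTenantCache::new(8, CacheConfig::new(80_000));
/// cache.insert(tenant_id.clone(), context);
/// assert!(matches!(cache.lookup(&tenant_id), CacheLookup::Hit(_)));
/// ```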
pub struct ShardedTenantCache {
    shards: Vec<TenantCache>,
    shard_count: usize,
}

impl ShardedTenantCache {
    /// Create a new sharded cache.
    pub fn new(shard_count: usize, config: CacheConfig) -> Self {
        // Guard against a zero shard count, which would otherwise divide by zero.
        let shard_count = shard_count.max(1);
        let per_shard_max = config.max_entries / shard_count;
        let shard_config = CacheConfig {
            max_entries: per_shard_max.max(100),
            ..config
        };

        let shards = (0..shard_count)
            .map(|_| TenantCache::new(shard_config.clone()))
            .collect();

        Self {
            shards,
            shard_count,
        }
    }

    /// Create with reasonable defaults for high-concurrency workloads.
    pub fn high_concurrency(max_entries: usize) -> Self {
        // Use number of CPUs for shard count
        let shard_count = num_cpus::get().max(4);
        Self::new(shard_count, CacheConfig::new(max_entries))
    }

    /// Get the shard for a tenant ID.
    fn shard(&self, tenant_id: &TenantId) -> &TenantCache {
        let hash = tenant_id.as_str().bytes().fold(0u64, |acc, b| {
            acc.wrapping_mul(31).wrapping_add(b as u64)
        });
        &self.shards[(hash as usize) % self.shard_count]
    }

    /// Look up a tenant.
    pub fn lookup(&self, tenant_id: &TenantId) -> CacheLookup {
        self.shard(tenant_id).lookup(tenant_id)
    }

    /// Insert a tenant.
    pub fn insert(&self, tenant_id: TenantId, context: TenantContext) {
        self.shard(&tenant_id).insert(tenant_id, context);
    }

    /// Insert a negative entry.
    pub fn insert_negative(&self, tenant_id: TenantId) {
        self.shard(&tenant_id).insert_negative(tenant_id);
    }

    /// Invalidate a tenant.
    pub fn invalidate(&self, tenant_id: &TenantId) {
        self.shard(tenant_id).invalidate(tenant_id);
    }

    /// Clear all shards.
    pub fn clear(&self) {
        for shard in &self.shards {
            shard.clear();
        }
    }

    /// Get total size.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.len()).sum()
    }

    /// Check if empty.
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.is_empty())
    }

    /// Get aggregated metrics.
    pub fn metrics(&self) -> CacheMetrics {
        let mut total = CacheMetrics::default();
        for shard in &self.shards {
            let m = shard.metrics();
            total.hits += m.hits;
            total.misses += m.misses;
            total.negative_hits += m.negative_hits;
            total.evictions += m.evictions;
            total.expirations += m.expirations;
            total.background_refreshes += m.background_refreshes;
            total.size += m.size;
        }
        total
    }

    /// Evict expired entries from all shards.
    pub fn evict_expired(&self) -> usize {
        self.shards.iter().map(|s| s.evict_expired()).sum()
    }

    /// Get or fetch a tenant.
    pub async fn get_or_fetch<F, Fut>(
        &self,
        tenant_id: &TenantId,
        fetch: F,
    ) -> Option<TenantContext>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Option<TenantContext>>,
    {
        self.shard(tenant_id).get_or_fetch(tenant_id, fetch).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_hit() {
        let cache = TenantCache::new(CacheConfig::new(100));
        let tenant_id = TenantId::new("test-tenant");
        let context = TenantContext::new(tenant_id.clone());

        cache.insert(tenant_id.clone(), context);

        match cache.lookup(&tenant_id) {
            CacheLookup::Hit(ctx) => assert_eq!(ctx.id.as_str(), "test-tenant"),
            _ => panic!("Expected hit"),
        }
    }

    #[test]
    fn test_cache_miss() {
        let cache = TenantCache::new(CacheConfig::new(100));
        let tenant_id = TenantId::new("unknown");

        match cache.lookup(&tenant_id) {
            CacheLookup::Miss => {}
            _ => panic!("Expected miss"),
        }
    }

    #[test]
    fn test_negative_cache() {
        let cache = TenantCache::new(CacheConfig::new(100));
        let tenant_id = TenantId::new("deleted-tenant");

        cache.insert_negative(tenant_id.clone());

        match cache.lookup(&tenant_id) {
            CacheLookup::NegativeHit => {}
            _ => panic!("Expected negative hit"),
        }
    }

    #[test]
    fn test_cache_eviction() {
        let cache = TenantCache::new(CacheConfig::new(2));

        for i in 0..3 {
            let id = TenantId::new(format!("tenant-{}", i));
            cache.insert(id.clone(), TenantContext::new(id));
        }

        // Should have evicted one
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_metrics() {
        let cache = TenantCache::new(CacheConfig::new(100));
        let id = TenantId::new("test");

        // Miss
        cache.lookup(&id);
        assert_eq!(cache.metrics().misses, 1);

        // Insert and hit
        cache.insert(id.clone(), TenantContext::new(id.clone()));
        cache.lookup(&id);
        assert_eq!(cache.metrics().hits, 1);
    }

    #[test]
    fn test_sharded_cache() {
        let cache = ShardedTenantCache::new(4, CacheConfig::new(100));

        for i in 0..10 {
            let id = TenantId::new(format!("tenant-{}", i));
            cache.insert(id.clone(), TenantContext::new(id));
        }

        assert_eq!(cache.len(), 10);

        for i in 0..10 {
            let id = TenantId::new(format!("tenant-{}", i));
            match cache.lookup(&id) {
                CacheLookup::Hit(_) => {}
                _ => panic!("Expected hit for tenant-{}", i),
            }
        }
    }

    #[tokio::test]
    async fn test_get_or_fetch() {
        let cache = TenantCache::new(CacheConfig::new(100));
        let id = TenantId::new("fetch-tenant");

        // First call should fetch
        let result = cache
            .get_or_fetch(&id, || async { Some(TenantContext::new("fetch-tenant")) })
            .await;

        assert!(result.is_some());

        // Second call should hit cache
        let result2 = cache
            .get_or_fetch(&id, || async {
                panic!("Should not fetch again");
            })
            .await;

        assert!(result2.is_some());
    }
}