docbox_core/tenant/
tenant_cache.rs

//! # Tenant Cache
//!
//! Caches tenants in memory so that each request does not need to fetch the
//! tenant from the database.
5
6use docbox_database::{
7    DbPool, DbResult,
8    models::tenant::{Tenant, TenantId},
9};
10use moka::{future::Cache, policy::EvictionPolicy};
11use std::time::Duration;
12
/// Idle period after which a cached tenant expires: 15 minutes.
///
/// Used as the cache's time-to-idle, so the timer restarts whenever the entry is read.
const TENANT_CACHE_DURATION: Duration = Duration::from_secs(15 * 60);

/// Upper bound on the number of tenants held in the cache at once.
const TENANT_CACHE_CAPACITY: u64 = 50;
18
19/// Cache for recently used tenants
20#[derive(Clone)]
21pub struct TenantCache {
22    cache: Cache<TenantCacheKey, Tenant>,
23}
24
25/// Cache key to identify a tenant
26#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
27struct TenantCacheKey {
28    env: String,
29    tenant_id: TenantId,
30}
31
32impl Default for TenantCache {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl TenantCache {
39    /// Create a new tenant cache
40    pub fn new() -> Self {
41        let cache = Cache::builder()
42            .time_to_idle(TENANT_CACHE_DURATION)
43            .max_capacity(TENANT_CACHE_CAPACITY)
44            .eviction_policy(EvictionPolicy::tiny_lfu())
45            .build();
46
47        Self { cache }
48    }
49
50    /// Get a tenant by ID
51    pub async fn get_tenant(
52        &self,
53        db: &DbPool,
54        env: String,
55        tenant_id: TenantId,
56    ) -> DbResult<Option<Tenant>> {
57        let cache_key = TenantCacheKey { env, tenant_id };
58
59        if let Some(tenant) = self.cache.get(&cache_key).await {
60            return Ok(Some(tenant.clone()));
61        }
62
63        let tenant = Tenant::find_by_id(db, tenant_id, &cache_key.env).await?;
64
65        if let Some(tenant) = tenant.as_ref() {
66            self.cache.insert(cache_key, tenant.clone()).await;
67        }
68
69        Ok(tenant)
70    }
71
72    /// Clear the cache
73    pub async fn flush(&self) {
74        self.cache.invalidate_all();
75    }
76}