prax_query/tenant/resolver.rs

//! Tenant resolvers for dynamic tenant lookup.
2
3use super::context::{TenantContext, TenantId, TenantInfo};
4use crate::error::QueryResult;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10
11/// A resolver that can look up tenant information.
12#[async_trait]
13pub trait TenantResolver: Send + Sync {
14    /// Resolve a tenant ID to a full tenant context.
15    async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext>;
16
17    /// Validate that a tenant exists and is active.
18    async fn validate(&self, tenant_id: &TenantId) -> QueryResult<bool> {
19        Ok(self.resolve(tenant_id).await.is_ok())
20    }
21
22    /// Get the schema name for a tenant (schema-based isolation).
23    async fn schema_for(&self, tenant_id: &TenantId) -> QueryResult<Option<String>> {
24        let ctx = self.resolve(tenant_id).await?;
25        Ok(ctx.info.schema)
26    }
27
28    /// Get the database name for a tenant (database-based isolation).
29    async fn database_for(&self, tenant_id: &TenantId) -> QueryResult<Option<String>> {
30        let ctx = self.resolve(tenant_id).await?;
31        Ok(ctx.info.database)
32    }
33}
34
35/// A static resolver that maps tenant IDs to contexts.
36#[derive(Debug, Clone, Default)]
37pub struct StaticResolver {
38    tenants: Arc<RwLock<HashMap<String, TenantContext>>>,
39}
40
41impl StaticResolver {
42    /// Create a new static resolver.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Register a tenant.
48    pub fn register(&self, tenant_id: impl Into<String>, context: TenantContext) -> &Self {
49        self.tenants
50            .write()
51            .expect("lock poisoned")
52            .insert(tenant_id.into(), context);
53        self
54    }
55
56    /// Register a simple tenant with just an ID.
57    pub fn register_simple(&self, tenant_id: impl Into<String>) -> &Self {
58        let id: String = tenant_id.into();
59        let context = TenantContext::new(id.clone());
60        self.register(id, context)
61    }
62
63    /// Unregister a tenant.
64    pub fn unregister(&self, tenant_id: &str) -> Option<TenantContext> {
65        self.tenants
66            .write()
67            .expect("lock poisoned")
68            .remove(tenant_id)
69    }
70
71    /// Check if a tenant is registered.
72    pub fn contains(&self, tenant_id: &str) -> bool {
73        self.tenants
74            .read()
75            .expect("lock poisoned")
76            .contains_key(tenant_id)
77    }
78
79    /// Get the number of registered tenants.
80    pub fn len(&self) -> usize {
81        self.tenants.read().expect("lock poisoned").len()
82    }
83
84    /// Check if empty.
85    pub fn is_empty(&self) -> bool {
86        self.len() == 0
87    }
88}
89
90#[async_trait]
91impl TenantResolver for StaticResolver {
92    async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
93        self.tenants
94            .read()
95            .expect("lock poisoned")
96            .get(tenant_id.as_str())
97            .cloned()
98            .ok_or_else(|| crate::error::QueryError::not_found(format!("Tenant {}", tenant_id)))
99    }
100
101    async fn validate(&self, tenant_id: &TenantId) -> QueryResult<bool> {
102        Ok(self.contains(tenant_id.as_str()))
103    }
104}
105
106/// Type alias for async resolver functions.
107pub type ResolverFn = Arc<
108    dyn Fn(TenantId) -> Pin<Box<dyn Future<Output = QueryResult<TenantContext>> + Send>>
109        + Send
110        + Sync,
111>;
112
113/// A dynamic resolver using a callback function.
114pub struct DynamicResolver {
115    resolve_fn: ResolverFn,
116}
117
118impl DynamicResolver {
119    /// Create a new dynamic resolver with a callback.
120    pub fn new<F, Fut>(f: F) -> Self
121    where
122        F: Fn(TenantId) -> Fut + Send + Sync + 'static,
123        Fut: Future<Output = QueryResult<TenantContext>> + Send + 'static,
124    {
125        Self {
126            resolve_fn: Arc::new(move |id| Box::pin(f(id))),
127        }
128    }
129}
130
131impl std::fmt::Debug for DynamicResolver {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        f.debug_struct("DynamicResolver").finish()
134    }
135}
136
137#[async_trait]
138impl TenantResolver for DynamicResolver {
139    async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
140        (self.resolve_fn)(tenant_id.clone()).await
141    }
142}
143
144/// A resolver that looks up tenants from the database.
145pub struct DatabaseResolver<F>
146where
147    F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
148        + Send
149        + Sync,
150{
151    query_fn: F,
152    cache: Arc<RwLock<HashMap<String, TenantContext>>>,
153    cache_ttl: std::time::Duration,
154}
155
156impl<F> DatabaseResolver<F>
157where
158    F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
159        + Send
160        + Sync,
161{
162    /// Create a new database resolver.
163    pub fn new(query_fn: F) -> Self {
164        Self {
165            query_fn,
166            cache: Arc::new(RwLock::new(HashMap::new())),
167            cache_ttl: std::time::Duration::from_secs(300), // 5 minutes
168        }
169    }
170
171    /// Set the cache TTL.
172    pub fn with_cache_ttl(mut self, ttl: std::time::Duration) -> Self {
173        self.cache_ttl = ttl;
174        self
175    }
176
177    /// Clear the cache.
178    pub fn clear_cache(&self) {
179        self.cache.write().expect("lock poisoned").clear();
180    }
181
182    /// Invalidate a specific tenant in the cache.
183    pub fn invalidate(&self, tenant_id: &str) {
184        self.cache.write().expect("lock poisoned").remove(tenant_id);
185    }
186}
187
188impl<F> std::fmt::Debug for DatabaseResolver<F>
189where
190    F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
191        + Send
192        + Sync,
193{
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        f.debug_struct("DatabaseResolver")
196            .field("cache_ttl", &self.cache_ttl)
197            .field("cache_size", &self.cache.read().expect("lock").len())
198            .finish()
199    }
200}
201
202#[async_trait]
203impl<F> TenantResolver for DatabaseResolver<F>
204where
205    F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
206        + Send
207        + Sync,
208{
209    async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
210        // Check cache first
211        if let Some(ctx) = self
212            .cache
213            .read()
214            .expect("lock poisoned")
215            .get(tenant_id.as_str())
216        {
217            return Ok(ctx.clone());
218        }
219
220        // Query database
221        let info = (self.query_fn)(tenant_id.as_str().to_string())
222            .await?
223            .ok_or_else(|| crate::error::QueryError::not_found(format!("Tenant {}", tenant_id)))?;
224
225        let ctx = TenantContext::with_info(tenant_id.clone(), info);
226
227        // Cache the result
228        self.cache
229            .write()
230            .expect("lock poisoned")
231            .insert(tenant_id.as_str().to_string(), ctx.clone());
232
233        Ok(ctx)
234    }
235}
236
237/// A composite resolver that tries multiple resolvers in order.
238pub struct CompositeResolver {
239    resolvers: Vec<Arc<dyn TenantResolver>>,
240}
241
242impl CompositeResolver {
243    /// Create a new composite resolver.
244    pub fn new() -> Self {
245        Self {
246            resolvers: Vec::new(),
247        }
248    }
249
250    /// Add a resolver to the chain.
251    pub fn add<R: TenantResolver + 'static>(mut self, resolver: R) -> Self {
252        self.resolvers.push(Arc::new(resolver));
253        self
254    }
255}
256
257impl Default for CompositeResolver {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262
263impl std::fmt::Debug for CompositeResolver {
264    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265        f.debug_struct("CompositeResolver")
266            .field("resolver_count", &self.resolvers.len())
267            .finish()
268    }
269}
270
271#[async_trait]
272impl TenantResolver for CompositeResolver {
273    async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
274        for resolver in &self.resolvers {
275            if let Ok(ctx) = resolver.resolve(tenant_id).await {
276                return Ok(ctx);
277            }
278        }
279        Err(crate::error::QueryError::not_found(format!(
280            "Tenant {} not found in any resolver",
281            tenant_id
282        )))
283    }
284
285    async fn validate(&self, tenant_id: &TenantId) -> QueryResult<bool> {
286        for resolver in &self.resolvers {
287            if resolver.validate(tenant_id).await? {
288                return Ok(true);
289            }
290        }
291        Ok(false)
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_static_resolver() {
        // One tenant registered from its ID alone, one with explicit metadata.
        let acme = TenantContext::with_info(
            "tenant-2",
            TenantInfo::new()
                .with_name("Acme Corp")
                .with_schema("tenant_acme"),
        );
        let resolver = StaticResolver::new();
        resolver.register_simple("tenant-1").register("tenant-2", acme);

        let simple = resolver.resolve(&TenantId::new("tenant-1")).await.unwrap();
        assert_eq!(simple.id.as_str(), "tenant-1");

        let detailed = resolver.resolve(&TenantId::new("tenant-2")).await.unwrap();
        assert_eq!(detailed.info.name.as_deref(), Some("Acme Corp"));
        assert_eq!(detailed.info.schema.as_deref(), Some("tenant_acme"));

        let known = resolver.validate(&TenantId::new("tenant-1")).await.unwrap();
        let unknown = resolver.validate(&TenantId::new("unknown")).await.unwrap();
        assert!(known);
        assert!(!unknown);
    }

    #[tokio::test]
    async fn test_dynamic_resolver() {
        // Only the literal ID "valid" resolves; everything else is rejected.
        let resolver = DynamicResolver::new(|id| async move {
            if id.as_str() != "valid" {
                return Err(crate::error::QueryError::not_found("Tenant"));
            }
            Ok(TenantContext::new(id))
        });

        let accepted = resolver.resolve(&TenantId::new("valid")).await;
        let rejected = resolver.resolve(&TenantId::new("invalid")).await;
        assert!(accepted.is_ok());
        assert!(rejected.is_err());
    }

    #[tokio::test]
    async fn test_composite_resolver() {
        let first = StaticResolver::new();
        first.register_simple("tenant-a");
        let second = StaticResolver::new();
        second.register_simple("tenant-b");

        let resolver = CompositeResolver::new().add(first).add(second);

        // Each tenant is known to exactly one member of the chain.
        for known in ["tenant-a", "tenant-b"] {
            assert!(resolver.resolve(&TenantId::new(known)).await.is_ok());
        }
        assert!(resolver.resolve(&TenantId::new("tenant-c")).await.is_err());
    }
}