prax_query/tenant/
resolver.rs1use super::context::{TenantContext, TenantId, TenantInfo};
4use crate::error::QueryResult;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10
11#[async_trait]
13pub trait TenantResolver: Send + Sync {
14 async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext>;
16
17 async fn validate(&self, tenant_id: &TenantId) -> QueryResult<bool> {
19 Ok(self.resolve(tenant_id).await.is_ok())
20 }
21
22 async fn schema_for(&self, tenant_id: &TenantId) -> QueryResult<Option<String>> {
24 let ctx = self.resolve(tenant_id).await?;
25 Ok(ctx.info.schema)
26 }
27
28 async fn database_for(&self, tenant_id: &TenantId) -> QueryResult<Option<String>> {
30 let ctx = self.resolve(tenant_id).await?;
31 Ok(ctx.info.database)
32 }
33}
34
35#[derive(Debug, Clone, Default)]
37pub struct StaticResolver {
38 tenants: Arc<RwLock<HashMap<String, TenantContext>>>,
39}
40
41impl StaticResolver {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn register(&self, tenant_id: impl Into<String>, context: TenantContext) -> &Self {
49 self.tenants
50 .write()
51 .expect("lock poisoned")
52 .insert(tenant_id.into(), context);
53 self
54 }
55
56 pub fn register_simple(&self, tenant_id: impl Into<String>) -> &Self {
58 let id: String = tenant_id.into();
59 let context = TenantContext::new(id.clone());
60 self.register(id, context)
61 }
62
63 pub fn unregister(&self, tenant_id: &str) -> Option<TenantContext> {
65 self.tenants
66 .write()
67 .expect("lock poisoned")
68 .remove(tenant_id)
69 }
70
71 pub fn contains(&self, tenant_id: &str) -> bool {
73 self.tenants
74 .read()
75 .expect("lock poisoned")
76 .contains_key(tenant_id)
77 }
78
79 pub fn len(&self) -> usize {
81 self.tenants.read().expect("lock poisoned").len()
82 }
83
84 pub fn is_empty(&self) -> bool {
86 self.len() == 0
87 }
88}
89
90#[async_trait]
91impl TenantResolver for StaticResolver {
92 async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
93 self.tenants
94 .read()
95 .expect("lock poisoned")
96 .get(tenant_id.as_str())
97 .cloned()
98 .ok_or_else(|| crate::error::QueryError::not_found(format!("Tenant {}", tenant_id)))
99 }
100
101 async fn validate(&self, tenant_id: &TenantId) -> QueryResult<bool> {
102 Ok(self.contains(tenant_id.as_str()))
103 }
104}
105
106pub type ResolverFn = Arc<
108 dyn Fn(TenantId) -> Pin<Box<dyn Future<Output = QueryResult<TenantContext>> + Send>>
109 + Send
110 + Sync,
111>;
112
113pub struct DynamicResolver {
115 resolve_fn: ResolverFn,
116}
117
118impl DynamicResolver {
119 pub fn new<F, Fut>(f: F) -> Self
121 where
122 F: Fn(TenantId) -> Fut + Send + Sync + 'static,
123 Fut: Future<Output = QueryResult<TenantContext>> + Send + 'static,
124 {
125 Self {
126 resolve_fn: Arc::new(move |id| Box::pin(f(id))),
127 }
128 }
129}
130
131impl std::fmt::Debug for DynamicResolver {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 f.debug_struct("DynamicResolver").finish()
134 }
135}
136
137#[async_trait]
138impl TenantResolver for DynamicResolver {
139 async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
140 (self.resolve_fn)(tenant_id.clone()).await
141 }
142}
143
144pub struct DatabaseResolver<F>
146where
147 F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
148 + Send
149 + Sync,
150{
151 query_fn: F,
152 cache: Arc<RwLock<HashMap<String, TenantContext>>>,
153 cache_ttl: std::time::Duration,
154}
155
156impl<F> DatabaseResolver<F>
157where
158 F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
159 + Send
160 + Sync,
161{
162 pub fn new(query_fn: F) -> Self {
164 Self {
165 query_fn,
166 cache: Arc::new(RwLock::new(HashMap::new())),
167 cache_ttl: std::time::Duration::from_secs(300), }
169 }
170
171 pub fn with_cache_ttl(mut self, ttl: std::time::Duration) -> Self {
173 self.cache_ttl = ttl;
174 self
175 }
176
177 pub fn clear_cache(&self) {
179 self.cache.write().expect("lock poisoned").clear();
180 }
181
182 pub fn invalidate(&self, tenant_id: &str) {
184 self.cache.write().expect("lock poisoned").remove(tenant_id);
185 }
186}
187
188impl<F> std::fmt::Debug for DatabaseResolver<F>
189where
190 F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
191 + Send
192 + Sync,
193{
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("DatabaseResolver")
196 .field("cache_ttl", &self.cache_ttl)
197 .field("cache_size", &self.cache.read().expect("lock").len())
198 .finish()
199 }
200}
201
202#[async_trait]
203impl<F> TenantResolver for DatabaseResolver<F>
204where
205 F: Fn(String) -> Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
206 + Send
207 + Sync,
208{
209 async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
210 if let Some(ctx) = self
212 .cache
213 .read()
214 .expect("lock poisoned")
215 .get(tenant_id.as_str())
216 {
217 return Ok(ctx.clone());
218 }
219
220 let info = (self.query_fn)(tenant_id.as_str().to_string())
222 .await?
223 .ok_or_else(|| crate::error::QueryError::not_found(format!("Tenant {}", tenant_id)))?;
224
225 let ctx = TenantContext::with_info(tenant_id.clone(), info);
226
227 self.cache
229 .write()
230 .expect("lock poisoned")
231 .insert(tenant_id.as_str().to_string(), ctx.clone());
232
233 Ok(ctx)
234 }
235}
236
237pub struct CompositeResolver {
239 resolvers: Vec<Arc<dyn TenantResolver>>,
240}
241
242impl CompositeResolver {
243 pub fn new() -> Self {
245 Self {
246 resolvers: Vec::new(),
247 }
248 }
249
250 pub fn add<R: TenantResolver + 'static>(mut self, resolver: R) -> Self {
252 self.resolvers.push(Arc::new(resolver));
253 self
254 }
255}
256
257impl Default for CompositeResolver {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
263impl std::fmt::Debug for CompositeResolver {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 f.debug_struct("CompositeResolver")
266 .field("resolver_count", &self.resolvers.len())
267 .finish()
268 }
269}
270
271#[async_trait]
272impl TenantResolver for CompositeResolver {
273 async fn resolve(&self, tenant_id: &TenantId) -> QueryResult<TenantContext> {
274 for resolver in &self.resolvers {
275 if let Ok(ctx) = resolver.resolve(tenant_id).await {
276 return Ok(ctx);
277 }
278 }
279 Err(crate::error::QueryError::not_found(format!(
280 "Tenant {} not found in any resolver",
281 tenant_id
282 )))
283 }
284
285 async fn validate(&self, tenant_id: &TenantId) -> QueryResult<bool> {
286 for resolver in &self.resolvers {
287 if resolver.validate(tenant_id).await? {
288 return Ok(true);
289 }
290 }
291 Ok(false)
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_static_resolver() {
        let resolver = StaticResolver::new();
        resolver.register_simple("tenant-1");
        resolver.register(
            "tenant-2",
            TenantContext::with_info(
                "tenant-2",
                TenantInfo::new()
                    .with_name("Acme Corp")
                    .with_schema("tenant_acme"),
            ),
        );

        let ctx1 = resolver.resolve(&TenantId::new("tenant-1")).await.unwrap();
        assert_eq!(ctx1.id.as_str(), "tenant-1");

        let ctx2 = resolver.resolve(&TenantId::new("tenant-2")).await.unwrap();
        assert_eq!(ctx2.info.name, Some("Acme Corp".to_string()));
        assert_eq!(ctx2.info.schema, Some("tenant_acme".to_string()));

        assert!(resolver.validate(&TenantId::new("tenant-1")).await.unwrap());
        assert!(!resolver.validate(&TenantId::new("unknown")).await.unwrap());
    }

    #[tokio::test]
    async fn test_static_resolver_registry_management() {
        let resolver = StaticResolver::new();
        assert!(resolver.is_empty());

        resolver
            .register_simple("tenant-1")
            .register_simple("tenant-2");
        assert_eq!(resolver.len(), 2);
        assert!(resolver.contains("tenant-1"));

        let removed = resolver
            .unregister("tenant-1")
            .expect("tenant-1 was registered");
        assert_eq!(removed.id.as_str(), "tenant-1");
        assert!(!resolver.contains("tenant-1"));
        assert!(resolver.unregister("tenant-1").is_none());
        assert_eq!(resolver.len(), 1);
        assert!(resolver.resolve(&TenantId::new("tenant-1")).await.is_err());

        // Clones share the same underlying registry.
        let shared = resolver.clone();
        shared.register_simple("tenant-3");
        assert!(resolver.contains("tenant-3"));
    }

    #[tokio::test]
    async fn test_dynamic_resolver() {
        let resolver = DynamicResolver::new(|id| async move {
            if id.as_str() == "valid" {
                Ok(TenantContext::new(id))
            } else {
                Err(crate::error::QueryError::not_found("Tenant"))
            }
        });

        assert!(resolver.resolve(&TenantId::new("valid")).await.is_ok());
        assert!(resolver.resolve(&TenantId::new("invalid")).await.is_err());
    }

    #[tokio::test]
    async fn test_database_resolver_caches_lookups() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let resolver = DatabaseResolver::new(move |id: String| {
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                let info = if id == "known" {
                    Some(TenantInfo::new().with_name("Known Corp"))
                } else {
                    None
                };
                let result: QueryResult<Option<TenantInfo>> = Ok(info);
                result
            }) as Pin<Box<dyn Future<Output = QueryResult<Option<TenantInfo>>> + Send>>
        });

        let known = TenantId::new("known");
        let first = resolver.resolve(&known).await.unwrap();
        assert_eq!(first.info.name, Some("Known Corp".to_string()));

        // A second lookup within the TTL is served from the cache.
        resolver.resolve(&known).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Invalidation forces a fresh query.
        resolver.invalidate("known");
        resolver.resolve(&known).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);

        // Clearing the cache forces a fresh query as well.
        resolver.clear_cache();
        resolver.resolve(&known).await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // A lookup that finds nothing surfaces as an error.
        assert!(resolver.resolve(&TenantId::new("missing")).await.is_err());
    }

    #[tokio::test]
    async fn test_composite_resolver() {
        let static1 = StaticResolver::new();
        static1.register_simple("tenant-a");

        let static2 = StaticResolver::new();
        static2.register_simple("tenant-b");

        let resolver = CompositeResolver::new().add(static1).add(static2);

        assert!(resolver.resolve(&TenantId::new("tenant-a")).await.is_ok());
        assert!(resolver.resolve(&TenantId::new("tenant-b")).await.is_ok());
        assert!(resolver.resolve(&TenantId::new("tenant-c")).await.is_err());

        assert!(resolver.validate(&TenantId::new("tenant-b")).await.unwrap());
        assert!(!resolver.validate(&TenantId::new("tenant-c")).await.unwrap());
    }
}