prax_query/tenant/
task_local.rs1use std::cell::Cell;
36use std::future::Future;
37
38use super::context::{TenantContext, TenantId};
39
tokio::task_local! {
    /// Task-local slot holding the active [`TenantContext`].
    ///
    /// The slot is populated through `TENANT_CONTEXT.scope(..)` (see `with_tenant`,
    /// `with_context` and `TenantScope::run`) and read through `try_with` by the
    /// accessor functions in this module.
    static TENANT_CONTEXT: TenantContext;
}
44
thread_local! {
    /// Per-thread tenant ID used by the synchronous API.
    ///
    /// Starts as `None` on every thread. It is written by `set_sync_tenant` and by
    /// `SyncTenantGuard`'s `Drop` impl, and read by `sync_tenant_id`.
    static SYNC_TENANT_ID: Cell<Option<TenantId>> = const { Cell::new(None) };
}
50
51pub async fn with_tenant<F, T>(tenant_id: impl Into<TenantId>, f: F) -> T
68where
69 F: Future<Output = T>,
70{
71 let ctx = TenantContext::new(tenant_id);
72 TENANT_CONTEXT.scope(ctx, f).await
73}
74
75pub async fn with_context<F, T>(ctx: TenantContext, f: F) -> T
77where
78 F: Future<Output = T>,
79{
80 TENANT_CONTEXT.scope(ctx, f).await
81}
82
83#[inline]
97pub fn current_tenant() -> Option<TenantContext> {
98 TENANT_CONTEXT.try_with(|ctx| ctx.clone()).ok()
99}
100
101#[inline]
105pub fn current_tenant_id() -> Option<TenantId> {
106 TENANT_CONTEXT.try_with(|ctx| ctx.id.clone()).ok()
107}
108
/// Placeholder accessor that always returns the empty string.
///
/// This function does not consult the task-local tenant context: whatever
/// tenant is (or is not) active, the result is `""`. Use `current_tenant_id`
/// or `with_current_tenant` to read the actual tenant ID.
#[inline]
pub fn current_tenant_id_str() -> &'static str {
    // The constant is the only value this function ever produces.
    const PLACEHOLDER: &str = "";
    PLACEHOLDER
}
118
119#[inline]
121pub fn has_tenant() -> bool {
122 TENANT_CONTEXT.try_with(|_| ()).is_ok()
123}
124
125#[inline]
129pub fn with_current_tenant<F, T>(f: F) -> Option<T>
130where
131 F: FnOnce(&TenantContext) -> T,
132{
133 TENANT_CONTEXT.try_with(f).ok()
134}
135
136#[inline]
138pub fn require_tenant() -> Result<TenantContext, TenantNotSetError> {
139 current_tenant().ok_or(TenantNotSetError)
140}
141
/// Error returned when a tenant context is required but none is set.
#[derive(Debug, Clone, Copy)]
pub struct TenantNotSetError;

impl std::fmt::Display for TenantNotSetError {
    /// Writes the fixed message `tenant context not set`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tenant context not set")
    }
}

impl std::error::Error for TenantNotSetError {}
153
154pub fn set_sync_tenant(tenant_id: impl Into<TenantId>) -> SyncTenantGuard {
172 let id = tenant_id.into();
173 let previous = SYNC_TENANT_ID.with(|cell| cell.replace(Some(id)));
174 SyncTenantGuard { previous }
175}
176
177#[inline]
179pub fn sync_tenant_id() -> Option<TenantId> {
180 SYNC_TENANT_ID.with(|cell| {
181 unsafe { &*cell.as_ptr() }.clone()
183 })
184}
185
186pub struct SyncTenantGuard {
188 previous: Option<TenantId>,
189}
190
191impl Drop for SyncTenantGuard {
192 fn drop(&mut self) {
193 SYNC_TENANT_ID.with(|cell| cell.set(self.previous.take()));
194 }
195}
196
197#[derive(Debug, Clone)]
220pub struct TenantScope {
221 context: TenantContext,
222}
223
224impl TenantScope {
225 pub fn new(tenant_id: impl Into<TenantId>) -> Self {
227 Self {
228 context: TenantContext::new(tenant_id),
229 }
230 }
231
232 pub fn from_context(context: TenantContext) -> Self {
234 Self { context }
235 }
236
237 pub fn tenant_id(&self) -> &TenantId {
239 &self.context.id
240 }
241
242 pub fn context(&self) -> &TenantContext {
244 &self.context
245 }
246
247 pub async fn run<F, T>(&self, f: F) -> T
249 where
250 F: Future<Output = T>,
251 {
252 TENANT_CONTEXT.scope(self.context.clone(), f).await
253 }
254
255 pub fn run_sync<F, T>(&self, f: F) -> T
257 where
258 F: FnOnce() -> T,
259 {
260 let _guard = set_sync_tenant(self.context.id.clone());
261 f()
262 }
263}
264
/// Strategy for deriving a [`TenantId`] from a list of headers.
///
/// Implementors must be `Send + Sync`.
pub trait TenantExtractor: Send + Sync {
    /// Attempts to extract a tenant ID from `headers`, given as
    /// `(name, value)` pairs. Returns `None` when this extractor finds nothing.
    fn extract(&self, headers: &[(String, String)]) -> Option<TenantId>;
}
274
/// Extractor that takes the tenant ID from a single named header.
#[derive(Debug, Clone)]
pub struct HeaderExtractor {
    /// Name of the header to look for.
    header_name: String,
}

impl HeaderExtractor {
    /// Creates an extractor that looks for the header called `header_name`.
    pub fn new(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        HeaderExtractor { header_name }
    }

    /// Creates an extractor for the `X-Tenant-ID` header.
    pub fn default_header() -> Self {
        HeaderExtractor {
            header_name: String::from("X-Tenant-ID"),
        }
    }
}
294
295impl TenantExtractor for HeaderExtractor {
296 fn extract(&self, headers: &[(String, String)]) -> Option<TenantId> {
297 headers
298 .iter()
299 .find(|(k, _)| k.eq_ignore_ascii_case(&self.header_name))
300 .map(|(_, v)| TenantId::new(v.clone()))
301 }
302}
303
/// Extractor configured with the name of a JWT claim carrying the tenant ID.
#[derive(Debug, Clone)]
pub struct JwtClaimExtractor {
    /// Name of the claim to read.
    claim_name: String,
}

impl JwtClaimExtractor {
    /// Creates an extractor for the claim called `claim_name`.
    pub fn new(claim_name: impl Into<String>) -> Self {
        let claim_name = claim_name.into();
        JwtClaimExtractor { claim_name }
    }

    /// Creates an extractor for the `tenant_id` claim.
    pub fn default_claim() -> Self {
        JwtClaimExtractor {
            claim_name: String::from("tenant_id"),
        }
    }

    /// Returns the configured claim name.
    pub fn claim_name(&self) -> &str {
        self.claim_name.as_str()
    }
}
328
impl TenantExtractor for JwtClaimExtractor {
    /// Placeholder implementation: always returns `None`.
    ///
    /// The headers are ignored and no token is decoded here, so this extractor
    /// never yields a tenant ID on its own.
    fn extract(&self, _headers: &[(String, String)]) -> Option<TenantId> {
        None
    }
}
336
337pub struct CompositeExtractor {
339 extractors: Vec<Box<dyn TenantExtractor>>,
340}
341
342impl CompositeExtractor {
343 pub fn new() -> Self {
345 Self {
346 extractors: Vec::new(),
347 }
348 }
349
350 pub fn add<E: TenantExtractor + 'static>(mut self, extractor: E) -> Self {
352 self.extractors.push(Box::new(extractor));
353 self
354 }
355}
356
357impl Default for CompositeExtractor {
358 fn default() -> Self {
359 Self::new()
360 }
361}
362
363impl TenantExtractor for CompositeExtractor {
364 fn extract(&self, headers: &[(String, String)]) -> Option<TenantId> {
365 for extractor in &self.extractors {
366 if let Some(id) = extractor.extract(headers) {
367 return Some(id);
368 }
369 }
370 None
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `(name, value)` header pair from string slices.
    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    /// Returns the active task-local tenant ID as an owned string, panicking
    /// when no tenant is set.
    fn active_id() -> String {
        current_tenant_id().unwrap().as_str().to_string()
    }

    // The tenant ID passed to `with_tenant` is visible inside the future.
    #[tokio::test]
    async fn test_with_tenant() {
        let seen = with_tenant("test-tenant", async { current_tenant_id() }).await;

        let id = seen.unwrap();
        assert_eq!(id.as_str(), "test-tenant");
    }

    // Outside any scope there is no tenant.
    #[tokio::test]
    async fn test_no_tenant() {
        let ctx = current_tenant();
        assert!(ctx.is_none());
        assert!(!has_tenant());
    }

    // An inner scope shadows the outer one and the outer one is visible again
    // once the inner scope completes.
    #[tokio::test]
    async fn test_nested_tenant() {
        let outer = async {
            assert_eq!(active_id(), "outer");

            let inner = async {
                assert_eq!(active_id(), "inner");
            };
            with_tenant("inner", inner).await;

            assert_eq!(active_id(), "outer");
        };

        with_tenant("outer", outer).await;
    }

    // `TenantScope::run` installs the scope's tenant for the future.
    #[tokio::test]
    async fn test_tenant_scope() {
        let scope = TenantScope::new("scoped-tenant");

        let observed = scope
            .run(async { current_tenant_id().map(|id| id.as_str().to_string()) })
            .await;

        assert_eq!(observed, Some(String::from("scoped-tenant")));
    }

    // The sync tenant is set while the guard lives and cleared afterwards.
    #[test]
    fn test_sync_tenant() {
        let guard = set_sync_tenant("sync-tenant");
        assert_eq!(sync_tenant_id().unwrap().as_str(), "sync-tenant");
        drop(guard);

        assert!(sync_tenant_id().is_none());
    }

    // The header extractor picks the value of the configured header.
    #[test]
    fn test_header_extractor() {
        let extractor = HeaderExtractor::new("X-Tenant-ID");

        let headers = vec![
            header("Content-Type", "application/json"),
            header("X-Tenant-ID", "tenant-from-header"),
        ];

        let found = extractor.extract(&headers).unwrap();
        assert_eq!(found.as_str(), "tenant-from-header");
    }

    // The composite falls through to the second extractor when the first
    // finds nothing.
    #[test]
    fn test_composite_extractor() {
        let primary = HeaderExtractor::new("X-Organization-ID");
        let fallback = HeaderExtractor::new("X-Tenant-ID");
        let extractor = CompositeExtractor::new().add(primary).add(fallback);

        let headers = vec![header("X-Tenant-ID", "fallback-tenant")];

        let found = extractor.extract(&headers).unwrap();
        assert_eq!(found.as_str(), "fallback-tenant");
    }
}
457
458