prax_query/tenant/
task_local.rs

1//! Zero-allocation tenant context using task-local storage.
2//!
3//! This module provides high-performance tenant context propagation using
4//! Tokio's task-local storage, eliminating the need for `Arc<RwLock>` in
5//! the hot path.
6//!
7//! # Performance Benefits
8//!
9//! - **Zero heap allocation** for context access
10//! - **No locking** on the hot path
11//! - **Automatic cleanup** when task completes
12//! - **Async-aware** - works across `.await` points
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use prax_query::tenant::task_local::{with_tenant, current_tenant, TenantScope};
18//!
19//! // Set tenant for async block
20//! with_tenant("tenant-123", async {
21//!     // All queries in this block use tenant-123
22//!     let users = client.user().find_many().exec().await?;
23//!
24//!     // Nested calls also see the tenant
25//!     do_something_else().await?;
26//!
27//!     Ok(())
28//! }).await?;
29//!
//! // Or build a reusable scope; the tenant is active only inside `run`
//! let scope = TenantScope::new("tenant-123");
//! scope.run(async { /* tenant-123 visible here */ }).await;
33//! ```
34
35use std::cell::Cell;
36use std::future::Future;
37
38use super::context::{TenantContext, TenantId};
39
40tokio::task_local! {
41    /// Task-local tenant context.
42    static TENANT_CONTEXT: TenantContext;
43}
44
45thread_local! {
46    /// Thread-local tenant ID for sync code paths.
47    /// Uses Cell for interior mutability without runtime cost.
48    static SYNC_TENANT_ID: Cell<Option<TenantId>> = const { Cell::new(None) };
49}
50
51/// Execute an async block with the given tenant context.
52///
53/// This is the most efficient way to set tenant context for async code.
54/// The context is automatically available to all nested async calls.
55///
56/// # Example
57///
58/// ```rust,ignore
59/// use prax_query::tenant::task_local::with_tenant;
60///
61/// with_tenant("tenant-123", async {
62///     // All code here sees tenant-123
63///     let users = client.user().find_many().exec().await?;
64///     Ok(())
65/// }).await?;
66/// ```
67pub async fn with_tenant<F, T>(tenant_id: impl Into<TenantId>, f: F) -> T
68where
69    F: Future<Output = T>,
70{
71    let ctx = TenantContext::new(tenant_id);
72    TENANT_CONTEXT.scope(ctx, f).await
73}
74
75/// Execute an async block with a full tenant context.
76pub async fn with_context<F, T>(ctx: TenantContext, f: F) -> T
77where
78    F: Future<Output = T>,
79{
80    TENANT_CONTEXT.scope(ctx, f).await
81}
82
83/// Get the current tenant context if set.
84///
85/// Returns `None` if no tenant context is active.
86///
87/// # Example
88///
89/// ```rust,ignore
90/// use prax_query::tenant::task_local::current_tenant;
91///
92/// if let Some(ctx) = current_tenant() {
93///     println!("Current tenant: {}", ctx.id);
94/// }
95/// ```
96#[inline]
97pub fn current_tenant() -> Option<TenantContext> {
98    TENANT_CONTEXT.try_with(|ctx| ctx.clone()).ok()
99}
100
101/// Get the current tenant ID if set.
102///
103/// More efficient than `current_tenant()` when you only need the ID.
104#[inline]
105pub fn current_tenant_id() -> Option<TenantId> {
106    TENANT_CONTEXT.try_with(|ctx| ctx.id.clone()).ok()
107}
108
/// Placeholder that always returns the empty string.
///
/// A task-local value cannot be lent out as `&'static str`, so this function
/// never reports the active tenant: it returns `""` even inside a
/// [`with_tenant`] scope. Treating the result as a tenant ID would silently
/// scope work to an empty tenant.
///
/// Use [`current_tenant_id`] (owned ID) or [`with_current_tenant`] (borrowed,
/// closure-scoped access) instead.
#[deprecated(
    note = "always returns \"\"; use `current_tenant_id()` or `with_current_tenant(|ctx| ..)` instead"
)]
#[inline]
pub fn current_tenant_id_str() -> &'static str {
    // Deliberately not implemented: returning the real ID as `&'static str`
    // would require leaking or interning every tenant string for the life of
    // the process, and tenant IDs can come from untrusted request headers.
    ""
}
118
119/// Check if a tenant context is currently active.
120#[inline]
121pub fn has_tenant() -> bool {
122    TENANT_CONTEXT.try_with(|_| ()).is_ok()
123}
124
125/// Execute a closure with the current tenant context.
126///
127/// Returns `None` if no tenant context is active.
128#[inline]
129pub fn with_current_tenant<F, T>(f: F) -> Option<T>
130where
131    F: FnOnce(&TenantContext) -> T,
132{
133    TENANT_CONTEXT.try_with(f).ok()
134}
135
136/// Require a tenant context, returning an error if not set.
137#[inline]
138pub fn require_tenant() -> Result<TenantContext, TenantNotSetError> {
139    current_tenant().ok_or(TenantNotSetError)
140}
141
/// Error returned by [`require_tenant`] when no task-local tenant context is
/// active.
///
/// It is a field-less marker, so it also derives the comparison and hashing
/// traits, which lets callers `assert_eq!` on it or match it inside `Result`s.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct TenantNotSetError;

impl std::fmt::Display for TenantNotSetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("tenant context not set")
    }
}

impl std::error::Error for TenantNotSetError {}
153
154// ============================================================================
155// Sync Context (Thread-Local)
156// ============================================================================
157
158/// Set the tenant ID for synchronous code on the current thread.
159///
160/// This is useful for sync code paths or when you can't use async scope.
161/// The tenant is automatically cleared when the guard is dropped.
162///
163/// # Example
164///
165/// ```rust,ignore
166/// use prax_query::tenant::task_local::set_sync_tenant;
167///
168/// let _guard = set_sync_tenant("tenant-123");
169/// // tenant available for sync code
170/// ```
171pub fn set_sync_tenant(tenant_id: impl Into<TenantId>) -> SyncTenantGuard {
172    let id = tenant_id.into();
173    let previous = SYNC_TENANT_ID.with(|cell| cell.replace(Some(id)));
174    SyncTenantGuard { previous }
175}
176
177/// Get the current sync tenant ID.
178#[inline]
179pub fn sync_tenant_id() -> Option<TenantId> {
180    SYNC_TENANT_ID.with(|cell| {
181        // SAFETY: We only read, not modify
182        unsafe { &*cell.as_ptr() }.clone()
183    })
184}
185
186/// Guard that resets the sync tenant when dropped.
187pub struct SyncTenantGuard {
188    previous: Option<TenantId>,
189}
190
191impl Drop for SyncTenantGuard {
192    fn drop(&mut self) {
193        SYNC_TENANT_ID.with(|cell| cell.set(self.previous.take()));
194    }
195}
196
197// ============================================================================
198// Scoped Guard (Alternative API)
199// ============================================================================
200
201/// A scoped tenant context that tracks whether it's been entered.
202///
203/// This provides an alternative to `with_tenant` for cases where you
204/// need more control over the scope.
205///
206/// # Example
207///
208/// ```rust,ignore
209/// use prax_query::tenant::task_local::TenantScope;
210///
211/// async fn handle_request(tenant_id: &str) {
212///     let scope = TenantScope::new(tenant_id);
213///
214///     scope.run(async {
215///         // tenant context active here
216///     }).await;
217/// }
218/// ```
219#[derive(Debug, Clone)]
220pub struct TenantScope {
221    context: TenantContext,
222}
223
224impl TenantScope {
225    /// Create a new tenant scope.
226    pub fn new(tenant_id: impl Into<TenantId>) -> Self {
227        Self {
228            context: TenantContext::new(tenant_id),
229        }
230    }
231
232    /// Create from a full context.
233    pub fn from_context(context: TenantContext) -> Self {
234        Self { context }
235    }
236
237    /// Get the tenant ID.
238    pub fn tenant_id(&self) -> &TenantId {
239        &self.context.id
240    }
241
242    /// Get the full context.
243    pub fn context(&self) -> &TenantContext {
244        &self.context
245    }
246
247    /// Run an async function within this tenant scope.
248    pub async fn run<F, T>(&self, f: F) -> T
249    where
250        F: Future<Output = T>,
251    {
252        TENANT_CONTEXT.scope(self.context.clone(), f).await
253    }
254
255    /// Run a sync closure within this tenant scope (thread-local).
256    pub fn run_sync<F, T>(&self, f: F) -> T
257    where
258        F: FnOnce() -> T,
259    {
260        let _guard = set_sync_tenant(self.context.id.clone());
261        f()
262    }
263}
264
265// ============================================================================
266// Middleware Integration
267// ============================================================================
268
269/// Extract tenant from various sources.
270pub trait TenantExtractor: Send + Sync {
271    /// Extract tenant ID from a request/context.
272    fn extract(&self, headers: &[(String, String)]) -> Option<TenantId>;
273}
274
275/// Extract tenant from a header.
276#[derive(Debug, Clone)]
277pub struct HeaderExtractor {
278    header_name: String,
279}
280
281impl HeaderExtractor {
282    /// Create a new header extractor.
283    pub fn new(header_name: impl Into<String>) -> Self {
284        Self {
285            header_name: header_name.into(),
286        }
287    }
288
289    /// Create with default header name "X-Tenant-ID".
290    pub fn default_header() -> Self {
291        Self::new("X-Tenant-ID")
292    }
293}
294
295impl TenantExtractor for HeaderExtractor {
296    fn extract(&self, headers: &[(String, String)]) -> Option<TenantId> {
297        headers
298            .iter()
299            .find(|(k, _)| k.eq_ignore_ascii_case(&self.header_name))
300            .map(|(_, v)| TenantId::new(v.clone()))
301    }
302}
303
304/// Extract tenant from a JWT claim.
305#[derive(Debug, Clone)]
306pub struct JwtClaimExtractor {
307    claim_name: String,
308}
309
310impl JwtClaimExtractor {
311    /// Create a new JWT claim extractor.
312    pub fn new(claim_name: impl Into<String>) -> Self {
313        Self {
314            claim_name: claim_name.into(),
315        }
316    }
317
318    /// Create with default claim name "tenant_id".
319    pub fn default_claim() -> Self {
320        Self::new("tenant_id")
321    }
322
323    /// Get the claim name.
324    pub fn claim_name(&self) -> &str {
325        &self.claim_name
326    }
327}
328
329impl TenantExtractor for JwtClaimExtractor {
330    fn extract(&self, _headers: &[(String, String)]) -> Option<TenantId> {
331        // JWT extraction would be implemented by the framework integration
332        // This is a placeholder that frameworks can override
333        None
334    }
335}
336
337/// Composite extractor that tries multiple sources.
338pub struct CompositeExtractor {
339    extractors: Vec<Box<dyn TenantExtractor>>,
340}
341
342impl CompositeExtractor {
343    /// Create a new composite extractor.
344    pub fn new() -> Self {
345        Self {
346            extractors: Vec::new(),
347        }
348    }
349
350    /// Add an extractor.
351    pub fn add<E: TenantExtractor + 'static>(mut self, extractor: E) -> Self {
352        self.extractors.push(Box::new(extractor));
353        self
354    }
355}
356
357impl Default for CompositeExtractor {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
363impl TenantExtractor for CompositeExtractor {
364    fn extract(&self, headers: &[(String, String)]) -> Option<TenantId> {
365        for extractor in &self.extractors {
366            if let Some(id) = extractor.extract(headers) {
367                return Some(id);
368            }
369        }
370        None
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_with_tenant() {
        let result = with_tenant("test-tenant", async { current_tenant_id() }).await;

        assert_eq!(result.unwrap().as_str(), "test-tenant");
    }

    #[tokio::test]
    async fn test_no_tenant() {
        assert!(current_tenant().is_none());
        assert!(!has_tenant());
    }

    #[tokio::test]
    async fn test_nested_tenant() {
        with_tenant("outer", async {
            assert_eq!(current_tenant_id().unwrap().as_str(), "outer");

            with_tenant("inner", async {
                assert_eq!(current_tenant_id().unwrap().as_str(), "inner");
            })
            .await;

            // Should be back to outer
            assert_eq!(current_tenant_id().unwrap().as_str(), "outer");
        })
        .await;
    }

    #[tokio::test]
    async fn test_tenant_scope() {
        let scope = TenantScope::new("scoped-tenant");

        let result = scope
            .run(async { current_tenant_id().map(|id| id.as_str().to_string()) })
            .await;

        assert_eq!(result, Some("scoped-tenant".to_string()));
    }

    #[tokio::test]
    async fn test_require_and_inspect_inside_scope() {
        with_tenant("req-tenant", async {
            assert!(has_tenant());
            let ctx = require_tenant().expect("tenant is set inside with_tenant");
            assert_eq!(ctx.id.as_str(), "req-tenant");
            assert_eq!(
                with_current_tenant(|c| c.id.as_str().to_string()),
                Some("req-tenant".to_string())
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_require_tenant_outside_scope() {
        let err = require_tenant().unwrap_err();
        assert_eq!(err.to_string(), "tenant context not set");
        assert!(with_current_tenant(|_| ()).is_none());
    }

    #[test]
    fn test_sync_tenant() {
        {
            let _guard = set_sync_tenant("sync-tenant");
            assert_eq!(sync_tenant_id().unwrap().as_str(), "sync-tenant");
        }

        // Should be cleared after guard drop
        assert!(sync_tenant_id().is_none());
    }

    #[test]
    fn test_nested_sync_tenant_restores_previous() {
        let outer = set_sync_tenant("outer");
        {
            let _inner = set_sync_tenant("inner");
            assert_eq!(sync_tenant_id().unwrap().as_str(), "inner");
        }
        // Dropping the inner guard restores the outer tenant, not `None`.
        assert_eq!(sync_tenant_id().unwrap().as_str(), "outer");
        drop(outer);
        assert!(sync_tenant_id().is_none());
    }

    #[test]
    fn test_sync_tenant_is_independent_of_task_local() {
        let _guard = set_sync_tenant("sync-only");
        assert!(!has_tenant());
        assert!(current_tenant_id().is_none());
    }

    #[test]
    fn test_run_sync() {
        let scope = TenantScope::new("sync-scope");
        let seen = scope.run_sync(|| sync_tenant_id().map(|id| id.as_str().to_string()));
        assert_eq!(seen, Some("sync-scope".to_string()));
        assert!(sync_tenant_id().is_none());
    }

    #[test]
    fn test_header_extractor() {
        let extractor = HeaderExtractor::new("X-Tenant-ID");

        let headers = vec![
            ("Content-Type".to_string(), "application/json".to_string()),
            ("X-Tenant-ID".to_string(), "tenant-from-header".to_string()),
        ];

        let id = extractor.extract(&headers);
        assert_eq!(id.unwrap().as_str(), "tenant-from-header");
    }

    #[test]
    fn test_header_extractor_case_insensitive() {
        let extractor = HeaderExtractor::default_header();
        let headers = vec![("x-tenant-id".to_string(), "lower".to_string())];
        assert_eq!(extractor.extract(&headers).unwrap().as_str(), "lower");
        assert!(extractor.extract(&[]).is_none());
    }

    #[test]
    fn test_composite_extractor() {
        let extractor = CompositeExtractor::new()
            .add(HeaderExtractor::new("X-Organization-ID"))
            .add(HeaderExtractor::new("X-Tenant-ID"));

        let headers = vec![("X-Tenant-ID".to_string(), "fallback-tenant".to_string())];

        let id = extractor.extract(&headers);
        assert_eq!(id.unwrap().as_str(), "fallback-tenant");
    }
}
457
458