// rok_orm_core/tenant.rs

//! Row-level multi-tenancy: a task-local tenant ID plus Tower middleware that sets it per request.
2
/// Returns the tenant ID scoped to the current task, if there is one.
///
/// A value is present only while running inside a tenant scope (established per
/// request by the `TenantLayer` middleware); outside such a scope this yields `None`.
/// Task-local values are not inherited by tasks spawned with `tokio::spawn`.
///
/// When the `tenant` feature is disabled this always returns `None`.
pub fn current_tenant_id() -> Option<i64> {
    #[cfg(feature = "tenant")]
    {
        // `try_with` fails (rather than panicking) when no scope is active.
        inner::CURRENT_TENANT_ID.try_with(|id| *id).ok()
    }
    #[cfg(not(feature = "tenant"))]
    {
        None
    }
}
11
// ── Feature-gated implementation ──────────────────────────────────────────────
13
#[cfg(feature = "tenant")]
pub(crate) mod inner {
    use tokio::task_local;

    task_local! {
        /// Tenant ID of the request being handled on the current task.
        ///
        /// Written by the `TenantLayer` middleware via `scope`; read through
        /// `current_tenant_id()`.
        pub(crate) static CURRENT_TENANT_ID: i64;
    }
}
20
// `TenantService` is the `Layer::Service` type produced by `TenantLayer`; re-export it so
// downstream code can name it in type signatures (it lives in a private module).
#[cfg(feature = "tenant")]
pub use middleware::{TenantLayer, TenantService, TenantSource};
23
#[cfg(feature = "tenant")]
mod middleware {
    use super::inner::CURRENT_TENANT_ID;
    use std::{
        future::Future,
        net::IpAddr,
        pin::Pin,
        task::{Context, Poll},
    };
    use tower::{Layer, Service};

    /// How the [`TenantLayer`] extracts the tenant ID from each request.
    ///
    /// The ID is read directly from client-controlled request data and is not
    /// authenticated here. Only rely on it when an upstream component (proxy,
    /// auth layer, ...) guarantees the value; otherwise any client can pick any tenant.
    #[derive(Debug, Clone)]
    pub enum TenantSource {
        /// Parse tenant ID from a request header value (e.g. `X-Tenant-ID: 42`).
        Header(String),
        /// Parse tenant ID from a query parameter (e.g. `?tenant=42`).
        ///
        /// Keys and values are compared and parsed as-is (no percent-decoding).
        QueryParam(String),
        /// Parse tenant ID from the first subdomain label (e.g. `42.app.example.com`).
        ///
        /// Reads the `Host` header, falling back to the request URI's host (HTTP/2
        /// requests may carry the authority only there). A `:port` suffix is ignored;
        /// IP-address hosts and hosts without a parent domain never yield a tenant.
        Subdomain,
    }

    impl TenantSource {
        /// Returns the tenant ID carried by `req` according to this source, or `None`
        /// when it is absent or does not parse as an `i64`.
        fn extract<B>(&self, req: &http::Request<B>) -> Option<i64> {
            match self {
                TenantSource::Header(name) => req
                    .headers()
                    .get(name.as_str())
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.trim().parse().ok()),

                TenantSource::QueryParam(param) => req.uri().query().and_then(|q| {
                    // First `key=value` pair whose key matches and whose value parses.
                    q.split('&').find_map(|pair| {
                        let (key, value) = pair.split_once('=')?;
                        if key == param.as_str() {
                            value.parse().ok()
                        } else {
                            None
                        }
                    })
                }),

                TenantSource::Subdomain => req
                    .headers()
                    .get(http::header::HOST)
                    .and_then(|v| v.to_str().ok())
                    .or_else(|| req.uri().host())
                    .and_then(subdomain_tenant),
            }
        }
    }

    /// Parses the first DNS label of `host` (optionally with `:port`) as a tenant ID.
    ///
    /// Returns `None` for IP-address hosts (otherwise `10.0.0.1` would map to tenant 10)
    /// and for hosts with no parent domain (a bare `42` is not a subdomain).
    fn subdomain_tenant(host: &str) -> Option<i64> {
        // Drop an optional `:port` suffix. Bracketed IPv6 literals get mangled here,
        // but they can never produce a numeric first label anyway.
        let host = host.rsplit_once(':').map_or(host, |(name, _port)| name);
        if host.parse::<IpAddr>().is_ok() {
            return None;
        }
        let (label, parent) = host.split_once('.')?;
        if parent.is_empty() {
            return None;
        }
        label.parse().ok()
    }

    /// Tower layer that extracts a tenant ID per request and scopes it in
    /// task-local storage, so that [`current_tenant_id`](super::current_tenant_id)
    /// (and `Model::query()` tenant filters built on it) see it while the wrapped
    /// service handles that request.
    #[derive(Debug, Clone)]
    pub struct TenantLayer {
        source: TenantSource,
        default: Option<i64>,
    }

    impl TenantLayer {
        /// Extract the tenant ID using an explicit [`TenantSource`].
        pub fn new(source: TenantSource) -> Self {
            Self {
                source,
                default: None,
            }
        }

        /// Extract tenant ID from the named request header (e.g. `"X-Tenant-ID"`).
        pub fn from_header(name: impl Into<String>) -> Self {
            Self::new(TenantSource::Header(name.into()))
        }

        /// Extract tenant ID from the named query parameter (e.g. `"tenant"`).
        pub fn from_query_param(name: impl Into<String>) -> Self {
            Self::new(TenantSource::QueryParam(name.into()))
        }

        /// Extract tenant ID from the first subdomain label of the request host.
        pub fn from_subdomain() -> Self {
            Self::new(TenantSource::Subdomain)
        }

        /// Fall back to `id` when the source yields no tenant (missing or unparseable).
        #[must_use]
        pub fn with_default(mut self, id: i64) -> Self {
            self.default = Some(id);
            self
        }
    }

    impl<S> Layer<S> for TenantLayer {
        type Service = TenantService<S>;

        fn layer(&self, inner: S) -> Self::Service {
            TenantService {
                inner,
                source: self.source.clone(),
                default: self.default,
            }
        }
    }

    /// Service produced by [`TenantLayer`]: resolves the tenant for each request and
    /// runs the wrapped service inside that tenant's task-local scope.
    #[derive(Debug, Clone)]
    pub struct TenantService<S> {
        inner: S,
        source: TenantSource,
        default: Option<i64>,
    }

    impl<S, ReqBody> Service<http::Request<ReqBody>> for TenantService<S>
    where
        S: Service<http::Request<ReqBody>>,
        S::Future: Send + 'static,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
            match self.source.extract(&req).or(self.default) {
                Some(id) => {
                    // Scope the synchronous part of `inner.call` as well as the returned
                    // future: a service may read the tenant while building its future,
                    // not only when that future is polled.
                    let fut = CURRENT_TENANT_ID.sync_scope(id, || self.inner.call(req));
                    Box::pin(CURRENT_TENANT_ID.scope(id, fut))
                }
                None => Box::pin(self.inner.call(req)),
            }
        }
    }
}