/// Returns the tenant id bound to the currently executing Tokio task, if any.
///
/// With the `tenant` feature enabled this reads the task-local value installed
/// by `TenantLayer`; when called outside such a scope it yields `None`.
/// Without the `tenant` feature it always yields `None`.
pub fn current_tenant_id() -> Option<i64> {
    #[cfg(feature = "tenant")]
    {
        // `try_with` fails (rather than panicking) when no scope is active.
        inner::CURRENT_TENANT_ID.try_with(|id| *id).ok()
    }
    #[cfg(not(feature = "tenant"))]
    {
        None
    }
}
11
/// Task-local storage that backs `current_tenant_id`.
#[cfg(feature = "tenant")]
pub(crate) mod inner {
    use tokio::task_local;

    task_local! {
        /// Tenant id of the task currently being polled; installed per request
        /// by the tenant middleware.
        pub(crate) static CURRENT_TENANT_ID: i64;
    }
}
20
21#[cfg(feature = "tenant")]
22pub use middleware::{TenantLayer, TenantSource};
23
#[cfg(feature = "tenant")]
mod middleware {
    use super::inner::CURRENT_TENANT_ID;
    use std::{
        future::Future,
        pin::Pin,
        task::{Context, Poll},
    };
    use tower::{Layer, Service};

    /// Where [`TenantLayer`] looks for the tenant id on an incoming request.
    ///
    /// In every case the raw text must parse as an `i64`; anything else is
    /// treated as "no tenant id" and the layer's default (if any) is used.
    #[derive(Debug, Clone)]
    pub enum TenantSource {
        /// Read the id from the first value of the named request header.
        Header(String),
        /// Read the id from the named query-string parameter. The first
        /// occurrence whose value parses wins; values are not percent-decoded.
        QueryParam(String),
        /// Read the id from the left-most label of the request host,
        /// e.g. `42.example.com` -> `42`.
        Subdomain,
    }

    impl TenantSource {
        /// Extracts the tenant id from `req` according to this source.
        ///
        /// Returns `None` when the id is absent or does not parse as an `i64`.
        fn extract<B>(&self, req: &http::Request<B>) -> Option<i64> {
            match self {
                TenantSource::Header(name) => req
                    .headers()
                    .get(name.as_str())
                    .and_then(|value| value.to_str().ok())
                    // Tolerate surrounding whitespace the transport did not strip.
                    .and_then(|value| value.trim().parse().ok()),

                TenantSource::QueryParam(param) => req.uri().query().and_then(|query| {
                    query.split('&').find_map(|pair| match pair.split_once('=') {
                        Some((key, value)) if key == param.as_str() => value.parse().ok(),
                        _ => None,
                    })
                }),

                TenantSource::Subdomain => subdomain_tenant(req),
            }
        }
    }

    /// Resolves the tenant id from the left-most label of the request host.
    ///
    /// The host is taken from the `Host` header or, failing that, from the
    /// request URI's authority (HTTP/2 requests may carry the authority there
    /// instead of in a `Host` header). Any port is ignored. IP-address hosts
    /// and hosts without a parent domain yield `None`, so e.g. `127.0.0.1`
    /// is not mistaken for tenant `127`.
    fn subdomain_tenant<B>(req: &http::Request<B>) -> Option<i64> {
        let from_header = req
            .headers()
            .get(http::header::HOST)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.parse::<http::uri::Authority>().ok());
        let authority = from_header.or_else(|| req.uri().authority().cloned())?;
        let host = authority.host();

        // A dotted IPv4 literal would otherwise yield its first octet.
        if host.parse::<std::net::IpAddr>().is_ok() {
            return None;
        }

        // A subdomain needs a parent domain after it.
        let (label, _parent) = host.split_once('.')?;
        label.parse().ok()
    }

    /// Tower layer that extracts a tenant id from each request and makes it
    /// available (via `current_tenant_id()`) to the inner service, both while
    /// its `call` runs and while its response future is polled.
    #[derive(Debug, Clone)]
    pub struct TenantLayer {
        source: TenantSource,
        default: Option<i64>,
    }

    impl TenantLayer {
        /// Creates a layer that reads the tenant id from `source`, with no default.
        #[must_use]
        pub fn new(source: TenantSource) -> Self {
            Self {
                source,
                default: None,
            }
        }

        /// Creates a layer that reads the tenant id from the header `name`.
        #[must_use]
        pub fn from_header(name: impl Into<String>) -> Self {
            Self::new(TenantSource::Header(name.into()))
        }

        /// Creates a layer that reads the tenant id from the query parameter `name`.
        #[must_use]
        pub fn from_query_param(name: impl Into<String>) -> Self {
            Self::new(TenantSource::QueryParam(name.into()))
        }

        /// Creates a layer that reads the tenant id from the host's first label.
        #[must_use]
        pub fn from_subdomain() -> Self {
            Self::new(TenantSource::Subdomain)
        }

        /// Uses `id` whenever the configured source yields no tenant id.
        ///
        /// Note that a present-but-unparseable value also falls back to `id`.
        #[must_use]
        pub fn with_default(mut self, id: i64) -> Self {
            self.default = Some(id);
            self
        }
    }

    impl<S> Layer<S> for TenantLayer {
        type Service = TenantService<S>;

        fn layer(&self, inner: S) -> Self::Service {
            TenantService {
                inner,
                source: self.source.clone(),
                default: self.default,
            }
        }
    }

    /// Service produced by [`TenantLayer`]; wraps `S` with per-request tenant scoping.
    #[derive(Debug, Clone)]
    pub struct TenantService<S> {
        inner: S,
        source: TenantSource,
        default: Option<i64>,
    }

    impl<S, ReqBody> Service<http::Request<ReqBody>> for TenantService<S>
    where
        S: Service<http::Request<ReqBody>>,
        S::Future: Send + 'static,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
            let tenant_id = self.source.extract(&req).or(self.default);
            match tenant_id {
                Some(id) => {
                    // The inner service may do synchronous work inside `call`
                    // itself, so scope that too, not just the returned future.
                    let fut = CURRENT_TENANT_ID.sync_scope(id, || self.inner.call(req));
                    Box::pin(CURRENT_TENANT_ID.scope(id, fut))
                }
                None => Box::pin(self.inner.call(req)),
            }
        }
    }
}