1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use tower::{Layer, Service};
10
11use super::traits::{HasTenantId, TenantResolver, TenantStrategy};
12
13pub fn middleware<S, R>(strategy: S, resolver: R) -> TenantLayer<S, R>
20where
21 S: TenantStrategy,
22 R: TenantResolver,
23{
24 TenantLayer::new(strategy, resolver)
25}
26
/// Tower layer that carries a tenant strategy and a tenant resolver.
///
/// Both values are stored behind `Arc`, so cloning the layer, or producing a
/// [`TenantMiddleware`] from it, shares the same strategy/resolver instances
/// rather than copying them.
pub struct TenantLayer<S, R> {
    /// Shared strategy; handed to each middleware created by this layer.
    strategy: Arc<S>,
    /// Shared resolver; handed to each middleware created by this layer.
    resolver: Arc<R>,
}
36
37impl<S, R> Clone for TenantLayer<S, R> {
38 fn clone(&self) -> Self {
39 Self {
40 strategy: self.strategy.clone(),
41 resolver: self.resolver.clone(),
42 }
43 }
44}
45
46impl<S, R> TenantLayer<S, R> {
47 pub fn new(strategy: S, resolver: R) -> Self {
49 Self {
50 strategy: Arc::new(strategy),
51 resolver: Arc::new(resolver),
52 }
53 }
54}
55
56impl<Svc, S, R> Layer<Svc> for TenantLayer<S, R>
57where
58 S: TenantStrategy,
59 R: TenantResolver,
60{
61 type Service = TenantMiddleware<Svc, S, R>;
62
63 fn layer(&self, inner: Svc) -> Self::Service {
64 TenantMiddleware {
65 inner,
66 strategy: self.strategy.clone(),
67 resolver: self.resolver.clone(),
68 }
69 }
70}
71
/// Service wrapper that runs tenant extraction and resolution before
/// delegating to the wrapped service.
///
/// Instances are built by `TenantLayer`'s `Layer::layer` implementation.
pub struct TenantMiddleware<Svc, S, R> {
    /// The wrapped service that receives the request once a tenant is resolved.
    inner: Svc,
    /// Shared strategy used to extract a tenant identifier from request parts.
    strategy: Arc<S>,
    /// Shared resolver used to turn the extracted identifier into a tenant.
    resolver: Arc<R>,
}
91
92impl<Svc: Clone, S, R> Clone for TenantMiddleware<Svc, S, R> {
93 fn clone(&self) -> Self {
94 Self {
95 inner: self.inner.clone(),
96 strategy: self.strategy.clone(),
97 resolver: self.resolver.clone(),
98 }
99 }
100}
101
102impl<Svc, S, R> Service<Request<Body>> for TenantMiddleware<Svc, S, R>
103where
104 Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
105 Svc::Future: Send + 'static,
106 Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
107 S: TenantStrategy,
108 R: TenantResolver,
109{
110 type Response = http::Response<Body>;
111 type Error = Svc::Error;
112 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
113
114 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115 self.inner.poll_ready(cx)
116 }
117
118 fn call(&mut self, request: Request<Body>) -> Self::Future {
119 let strategy = self.strategy.clone();
120 let resolver = self.resolver.clone();
121 let mut inner = self.inner.clone();
122 std::mem::swap(&mut self.inner, &mut inner);
123
124 Box::pin(async move {
125 let (mut parts, body) = request.into_parts();
126
127 let tenant_id = match strategy.extract(&mut parts) {
129 Ok(id) => id,
130 Err(e) => return Ok(e.into_response()),
131 };
132
133 let tenant = match resolver.resolve(&tenant_id).await {
135 Ok(t) => t,
136 Err(e) => return Ok(e.into_response()),
137 };
138
139 tracing::Span::current().record("tenant_id", tenant.tenant_id());
141
142 let tenant = Arc::new(tenant);
144 parts.extensions.insert(tenant);
145
146 let request = Request::from_parts(parts, body);
147 inner.call(request).await
148 })
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tower::ServiceExt;

    use crate::error::Error;
    use crate::tenant::TenantId;

    // ---- fixtures -------------------------------------------------------

    /// Minimal tenant type carrying only an id.
    #[derive(Clone, Debug)]
    struct TestTenant {
        id: String,
    }

    impl HasTenantId for TestTenant {
        fn tenant_id(&self) -> &str {
            &self.id
        }
    }

    /// Strategy that always yields the slug `acme`.
    struct OkStrategy;
    impl TenantStrategy for OkStrategy {
        fn extract(&self, _parts: &mut http::request::Parts) -> crate::Result<TenantId> {
            Ok(TenantId::Slug("acme".into()))
        }
    }

    /// Strategy that always fails with a bad-request error.
    struct FailStrategy;
    impl TenantStrategy for FailStrategy {
        fn extract(&self, _parts: &mut http::request::Parts) -> crate::Result<TenantId> {
            Err(Error::bad_request("no tenant"))
        }
    }

    /// Resolver that always returns a tenant with id `t1`.
    struct OkResolver;
    impl TenantResolver for OkResolver {
        type Tenant = TestTenant;
        async fn resolve(&self, _id: &TenantId) -> crate::Result<TestTenant> {
            Ok(TestTenant { id: "t1".into() })
        }
    }

    /// Resolver that always fails with a not-found error.
    struct NotFoundResolver;
    impl TenantResolver for NotFoundResolver {
        type Tenant = TestTenant;
        async fn resolve(&self, _id: &TenantId) -> crate::Result<TestTenant> {
            Err(Error::not_found("tenant not found"))
        }
    }

    /// Resolver that always fails with an internal error.
    struct InternalErrorResolver;
    impl TenantResolver for InternalErrorResolver {
        type Tenant = TestTenant;
        async fn resolve(&self, _id: &TenantId) -> crate::Result<TestTenant> {
            Err(Error::internal("db failure"))
        }
    }

    /// Inner handler whose body says whether a tenant extension was present.
    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let text = match req.extensions().get::<Arc<TestTenant>>() {
            Some(_) => "ok",
            None => "no-tenant",
        };
        Ok(Response::new(Body::from(text)))
    }

    /// Builds a request with default parts and an empty body.
    fn empty_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    // ---- status mapping -------------------------------------------------

    #[tokio::test]
    async fn strategy_ok_resolver_ok_passes_through() {
        let svc = TenantLayer::new(OkStrategy, OkResolver).layer(tower::service_fn(echo_handler));

        let response = svc.oneshot(empty_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn strategy_fail_returns_400() {
        let svc = TenantLayer::new(FailStrategy, OkResolver).layer(tower::service_fn(echo_handler));

        let response = svc.oneshot(empty_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn resolver_not_found_returns_404() {
        let svc =
            TenantLayer::new(OkStrategy, NotFoundResolver).layer(tower::service_fn(echo_handler));

        let response = svc.oneshot(empty_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn resolver_internal_error_returns_500() {
        let svc = TenantLayer::new(OkStrategy, InternalErrorResolver)
            .layer(tower::service_fn(echo_handler));

        let response = svc.oneshot(empty_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // ---- short-circuiting -----------------------------------------------

    #[tokio::test]
    async fn strategy_fail_does_not_call_inner() {
        // Flag flipped by the inner service if it is ever invoked.
        let reached = Arc::new(AtomicBool::new(false));

        let inner = {
            let reached = Arc::clone(&reached);
            tower::service_fn(move |_req: Request<Body>| {
                let reached = Arc::clone(&reached);
                async move {
                    reached.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            })
        };
        let svc = TenantLayer::new(FailStrategy, OkResolver).layer(inner);

        let response = svc.oneshot(empty_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(!reached.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn resolver_fail_does_not_call_inner() {
        // Flag flipped by the inner service if it is ever invoked.
        let reached = Arc::new(AtomicBool::new(false));

        let inner = {
            let reached = Arc::clone(&reached);
            tower::service_fn(move |_req: Request<Body>| {
                let reached = Arc::clone(&reached);
                async move {
                    reached.store(true, Ordering::SeqCst);
                    Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
                }
            })
        };
        let svc = TenantLayer::new(OkStrategy, NotFoundResolver).layer(inner);

        let response = svc.oneshot(empty_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        assert!(!reached.load(Ordering::SeqCst));
    }

    // ---- extension propagation ------------------------------------------

    #[tokio::test]
    async fn tenant_in_extensions_after_resolve() {
        // The inner service itself checks that the resolved tenant arrived.
        let inner = tower::service_fn(|req: Request<Body>| async move {
            let tenant = req.extensions().get::<Arc<TestTenant>>().unwrap();
            assert_eq!(tenant.id, "t1");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });
        let svc = TenantLayer::new(OkStrategy, OkResolver).layer(inner);

        let response = svc.oneshot(empty_request()).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }
}