Skip to main content

modo/tenant/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use tower::{Layer, Service};
10
11use super::traits::{HasTenantId, TenantResolver, TenantStrategy};
12
13/// Creates a tenant middleware layer from a strategy and resolver.
14///
15/// This is the primary entry point for wiring tenant resolution into a router.
16/// The returned [`TenantLayer`] should be applied with `.layer()` for all
17/// strategies except [`crate::tenant::PathParamStrategy`], which requires
18/// `.route_layer()`.
19pub fn middleware<S, R>(strategy: S, resolver: R) -> TenantLayer<S, R>
20where
21    S: TenantStrategy,
22    R: TenantResolver,
23{
24    TenantLayer::new(strategy, resolver)
25}
26
/// Tower [`Layer`] that puts tenant resolution in front of an inner service.
///
/// Usually obtained from [`middleware`]. The strategy and resolver are held
/// behind [`Arc`]s, so cloning the layer or applying it to many services
/// shares them rather than copying them.
///
/// Attach it with `.layer()`, except when using
/// [`crate::tenant::PathParamStrategy`], which has to go through
/// `.route_layer()`.
pub struct TenantLayer<S, R> {
    /// Extracts the tenant identifier from a request.
    strategy: Arc<S>,
    /// Turns an extracted identifier into a tenant value.
    resolver: Arc<R>,
}
36
37impl<S, R> Clone for TenantLayer<S, R> {
38    fn clone(&self) -> Self {
39        Self {
40            strategy: self.strategy.clone(),
41            resolver: self.resolver.clone(),
42        }
43    }
44}
45
46impl<S, R> TenantLayer<S, R> {
47    /// Creates a new `TenantLayer` wrapping the given strategy and resolver.
48    pub fn new(strategy: S, resolver: R) -> Self {
49        Self {
50            strategy: Arc::new(strategy),
51            resolver: Arc::new(resolver),
52        }
53    }
54}
55
56impl<Svc, S, R> Layer<Svc> for TenantLayer<S, R>
57where
58    S: TenantStrategy,
59    R: TenantResolver,
60{
61    type Service = TenantMiddleware<Svc, S, R>;
62
63    fn layer(&self, inner: Svc) -> Self::Service {
64        TenantMiddleware {
65            inner,
66            strategy: self.strategy.clone(),
67            resolver: self.resolver.clone(),
68        }
69    }
70}
71
/// Tower [`Service`] built by [`TenantLayer`]. It resolves the tenant for each
/// request before handing the request to the wrapped service.
///
/// For every request it:
/// 1. asks the [`TenantStrategy`] for a [`crate::tenant::TenantId`];
/// 2. passes that identifier to the [`TenantResolver`] to obtain the tenant;
/// 3. records the resolved tenant's id as the `tenant_id` field of the current
///    tracing span, via `Span::current().record()`;
/// 4. stores the tenant as an `Arc<T>` in the request extensions and then
///    calls the inner service.
///
/// Step 3 only has an effect when the surrounding span was created with a
/// `tenant_id = tracing::field::Empty` field. Calling `record()` on a span that
/// never declared the field does nothing.
///
/// If step 1 or step 2 fails, the error is converted into an HTTP response
/// with [`IntoResponse`] and returned as `Ok`. The inner service is never
/// called in that case.
pub struct TenantMiddleware<Svc, S, R> {
    inner: Svc,
    strategy: Arc<S>,
    resolver: Arc<R>,
}
91
92impl<Svc: Clone, S, R> Clone for TenantMiddleware<Svc, S, R> {
93    fn clone(&self) -> Self {
94        Self {
95            inner: self.inner.clone(),
96            strategy: self.strategy.clone(),
97            resolver: self.resolver.clone(),
98        }
99    }
100}
101
102impl<Svc, S, R> Service<Request<Body>> for TenantMiddleware<Svc, S, R>
103where
104    Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
105    Svc::Future: Send + 'static,
106    Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
107    S: TenantStrategy,
108    R: TenantResolver,
109{
110    type Response = http::Response<Body>;
111    type Error = Svc::Error;
112    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
113
114    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115        self.inner.poll_ready(cx)
116    }
117
118    fn call(&mut self, request: Request<Body>) -> Self::Future {
119        let strategy = self.strategy.clone();
120        let resolver = self.resolver.clone();
121        let mut inner = self.inner.clone();
122        std::mem::swap(&mut self.inner, &mut inner);
123
124        Box::pin(async move {
125            let (mut parts, body) = request.into_parts();
126
127            // Step 1: Extract tenant identifier
128            let tenant_id = match strategy.extract(&mut parts) {
129                Ok(id) => id,
130                Err(e) => return Ok(e.into_response()),
131            };
132
133            // Step 2: Resolve tenant
134            let tenant = match resolver.resolve(&tenant_id).await {
135                Ok(t) => t,
136                Err(e) => return Ok(e.into_response()),
137            };
138
139            // Step 3: Record tenant_id in tracing span
140            tracing::Span::current().record("tenant_id", tenant.tenant_id());
141
142            // Step 4: Insert into extensions
143            let tenant = Arc::new(tenant);
144            parts.extensions.insert(tenant);
145
146            let request = Request::from_parts(parts, body);
147            inner.call(request).await
148        })
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    use crate::error::Error;
    use crate::tenant::TenantId;

    #[derive(Clone, Debug)]
    struct TestTenant {
        id: String,
    }

    impl HasTenantId for TestTenant {
        fn tenant_id(&self) -> &str {
            &self.id
        }
    }

    /// Strategy that always extracts the slug `acme`.
    struct OkStrategy;
    impl TenantStrategy for OkStrategy {
        fn extract(&self, _parts: &mut http::request::Parts) -> crate::Result<TenantId> {
            Ok(TenantId::Slug("acme".into()))
        }
    }

    /// Strategy that always fails with a 400.
    struct FailStrategy;
    impl TenantStrategy for FailStrategy {
        fn extract(&self, _parts: &mut http::request::Parts) -> crate::Result<TenantId> {
            Err(Error::bad_request("no tenant"))
        }
    }

    /// Resolver that always yields the tenant `t1`.
    struct OkResolver;
    impl TenantResolver for OkResolver {
        type Tenant = TestTenant;
        async fn resolve(&self, _id: &TenantId) -> crate::Result<TestTenant> {
            Ok(TestTenant { id: "t1".into() })
        }
    }

    /// Resolver that always fails with a 404.
    struct NotFoundResolver;
    impl TenantResolver for NotFoundResolver {
        type Tenant = TestTenant;
        async fn resolve(&self, _id: &TenantId) -> crate::Result<TestTenant> {
            Err(Error::not_found("tenant not found"))
        }
    }

    /// Resolver that always fails with a 500.
    struct InternalErrorResolver;
    impl TenantResolver for InternalErrorResolver {
        type Tenant = TestTenant;
        async fn resolve(&self, _id: &TenantId) -> crate::Result<TestTenant> {
            Err(Error::internal("db failure"))
        }
    }

    /// Inner handler whose body reports whether an `Arc<TestTenant>` reached it:
    /// `"ok"` when the extension is present, `"no-tenant"` otherwise.
    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_tenant = req.extensions().get::<Arc<TestTenant>>().is_some();
        let body = if has_tenant { "ok" } else { "no-tenant" };
        Ok(Response::new(Body::from(body)))
    }

    /// Collects a response body into a `String` so tests can assert on its content.
    async fn body_text(resp: Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .expect("in-memory test body should be readable");
        String::from_utf8(bytes.to_vec()).expect("test body should be UTF-8")
    }

    #[tokio::test]
    async fn strategy_ok_resolver_ok_passes_through() {
        let layer = TenantLayer::new(OkStrategy, OkResolver);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        // `echo_handler` answers 200 either way; only the body proves the
        // resolved tenant actually reached the inner service.
        assert_eq!(body_text(resp).await, "ok");
    }

    #[tokio::test]
    async fn strategy_fail_returns_400() {
        let layer = TenantLayer::new(FailStrategy, OkResolver);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn resolver_not_found_returns_404() {
        let layer = TenantLayer::new(OkStrategy, NotFoundResolver);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn resolver_internal_error_returns_500() {
        let layer = TenantLayer::new(OkStrategy, InternalErrorResolver);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn strategy_fail_does_not_call_inner() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let layer = TenantLayer::new(FailStrategy, OkResolver);
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn resolver_fail_does_not_call_inner() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let layer = TenantLayer::new(OkStrategy, NotFoundResolver);
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn tenant_in_extensions_after_resolve() {
        let layer = TenantLayer::new(OkStrategy, OkResolver);

        // Custom inner service that asserts the resolved tenant is in extensions.
        let inner = tower::service_fn(|req: Request<Body>| async move {
            let tenant = req.extensions().get::<Arc<TestTenant>>().unwrap();
            assert_eq!(tenant.id, "t1");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}