Skip to main content

ferro_rs/tenant/
middleware.rs

1//! TenantMiddleware for Ferro framework.
2//!
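//! Resolves the current tenant from a request using a configurable chain of
//! [`TenantResolver`] strategies and stores the result in task-local context.
//! When no resolver matches, [`TenantFailureMode`] decides the outcome: a 404
//! (`NotFound`) or 403 (`Forbidden`) JSON error, continuing without a tenant
//! (`Allow`), or a response produced by a user-supplied handler (`Custom`).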
7//!
8//! # Example
9//!
10//! ```rust,ignore
11//! use ferro_rs::tenant::{TenantMiddleware, TenantFailureMode, SubdomainResolver};
12//!
13//! let middleware = TenantMiddleware::new()
14//!     .resolver(SubdomainResolver::new(2, lookup.clone()))
15//!     .on_failure(TenantFailureMode::NotFound);
16//! ```
17
18use crate::http::{HttpResponse, Response};
19use crate::middleware::{Middleware, Next};
20use crate::tenant::context::{tenant_scope, with_tenant_scope};
21use crate::tenant::{TenantContext, TenantFailureMode};
22use crate::Request;
23use async_trait::async_trait;
24use serde_json::json;
25
26use super::resolver::TenantResolver;
27
28/// Middleware that resolves the current tenant and stores it in task-local context.
29///
30/// Resolvers are tried in order; the first `Some` result wins. If no resolver
31/// matches, the failure mode determines the response.
32pub struct TenantMiddleware {
33    resolvers: Vec<Box<dyn TenantResolver>>,
34    on_failure: TenantFailureMode,
35}
36
37impl TenantMiddleware {
38    /// Create a new `TenantMiddleware` with no resolvers and `NotFound` failure mode.
39    pub fn new() -> Self {
40        Self {
41            resolvers: Vec::new(),
42            on_failure: TenantFailureMode::NotFound,
43        }
44    }
45
46    /// Add a resolver to the chain (consuming builder).
47    ///
48    /// Resolvers are tried in the order they were added. The first `Some` result wins.
49    pub fn resolver(mut self, resolver: impl TenantResolver + 'static) -> Self {
50        self.resolvers.push(Box::new(resolver));
51        self
52    }
53
54    /// Set the failure mode when no resolver matches (consuming builder).
55    pub fn on_failure(mut self, mode: TenantFailureMode) -> Self {
56        self.on_failure = mode;
57        self
58    }
59}
60
61impl Default for TenantMiddleware {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67#[async_trait]
68impl Middleware for TenantMiddleware {
69    async fn handle(&self, request: Request, next: Next) -> Response {
70        // Try resolvers in order, first Some wins.
71        let mut resolved: Option<TenantContext> = None;
72        for resolver in &self.resolvers {
73            if let Some(ctx) = resolver.resolve(&request).await {
74                resolved = Some(ctx);
75                break;
76            }
77        }
78
79        match resolved {
80            Some(ctx) => {
81                // Store tenant in task-local context for the downstream chain.
82                let scope = tenant_scope();
83                {
84                    let mut guard = scope.write().await;
85                    *guard = Some(ctx);
86                }
87                with_tenant_scope(scope, next(request)).await
88            }
89            None => match &self.on_failure {
90                TenantFailureMode::NotFound => {
91                    Err(HttpResponse::json(json!({"error": "Tenant not found"})).status(404))
92                }
93                TenantFailureMode::Forbidden => {
94                    Err(HttpResponse::json(json!({"error": "Access denied"})).status(403))
95                }
96                TenantFailureMode::Allow => {
97                    // No tenant — continue with None in context.
98                    let scope = tenant_scope();
99                    with_tenant_scope(scope, next(request)).await
100                }
101                TenantFailureMode::Custom(handler) => handler(),
102            },
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::HttpResponse;
    use crate::tenant::context::current_tenant;
    use crate::tenant::{TenantContext, TenantFailureMode};
    use async_trait::async_trait;
    use hyper_util::rt::TokioIo;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::sync::Mutex;
    use tokio::sync::oneshot;

    fn make_tenant(slug: &str) -> TenantContext {
        TenantContext {
            id: 1,
            slug: slug.to_string(),
            name: "Test Corp".to_string(),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Build a real [`Request`] by sending one HTTP/1.1 request over TCP loopback.
    ///
    /// A throwaway hyper server accepts a single connection and forwards the
    /// first request it receives through a oneshot channel. The request targets
    /// `/test` and carries exactly one header, `header_name: header_value`.
    async fn make_request_with_header(header_name: &str, header_value: &str) -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // The service closure may be called more than once, so the one-shot
        // sender lives behind a mutex and is taken by the first request only.
        let tx_holder = Arc::new(Mutex::new(Some(tx)));

        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let io = TokioIo::new(stream);
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let tx_holder = tx_holder.clone();
                    async move {
                        if let Some(tx) = tx_holder.lock().unwrap().take() {
                            let _ = tx.send(Request::new(req));
                        }
                        Ok::<_, hyper::Error>(hyper::Response::new(
                            http_body_util::Empty::<bytes::Bytes>::new(),
                        ))
                    }
                });
            hyper::server::conn::http1::Builder::new()
                .serve_connection(io, service)
                .await
                .ok();
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move {
            conn.await.ok();
        });

        let req = hyper::Request::builder()
            .uri("/test")
            .header(header_name, header_value)
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();

        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    async fn make_request() -> Request {
        make_request_with_header("x-test", "1").await
    }

    /// Mock resolver that always returns a fixed TenantContext.
    struct AlwaysResolver {
        tenant: TenantContext,
    }

    #[async_trait]
    impl TenantResolver for AlwaysResolver {
        async fn resolve(&self, _req: &Request) -> Option<TenantContext> {
            Some(self.tenant.clone())
        }
    }

    /// Mock resolver that always returns None.
    struct NeverResolver;

    #[async_trait]
    impl TenantResolver for NeverResolver {
        async fn resolve(&self, _req: &Request) -> Option<TenantContext> {
            None
        }
    }

    /// Mock resolver that records how many times it was consulted.
    struct CountingResolver {
        calls: Arc<AtomicUsize>,
        result: Option<TenantContext>,
    }

    #[async_trait]
    impl TenantResolver for CountingResolver {
        async fn resolve(&self, _req: &Request) -> Option<TenantContext> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.result.clone()
        }
    }

    fn ok_next() -> Next {
        Arc::new(|_req| {
            Box::pin(async { Ok(HttpResponse::text("ok")) }) as crate::middleware::MiddlewareFuture
        })
    }

    /// Next that captures current_tenant() and returns it as JSON.
    fn tenant_capture_next() -> Next {
        Arc::new(|_req| {
            Box::pin(async move {
                let tenant = current_tenant();
                match tenant {
                    Some(t) => Ok(HttpResponse::json(serde_json::json!({"slug": t.slug}))),
                    None => Ok(HttpResponse::json(serde_json::json!({"slug": null}))),
                }
            }) as crate::middleware::MiddlewareFuture
        })
    }

    // new() creates an instance with an empty resolver vec and the NotFound default.
    #[test]
    fn new_creates_empty_instance_with_not_found_default() {
        let mw = TenantMiddleware::new();
        assert!(mw.resolvers.is_empty());
        assert!(matches!(mw.on_failure, TenantFailureMode::NotFound));
    }

    // .resolver(r) adds a resolver to the chain (consuming builder).
    #[test]
    fn resolver_adds_to_chain() {
        let mw = TenantMiddleware::new().resolver(NeverResolver);
        assert_eq!(mw.resolvers.len(), 1);
    }

    // .on_failure(mode) sets the failure mode (consuming builder).
    #[test]
    fn on_failure_sets_mode() {
        let mw = TenantMiddleware::new().on_failure(TenantFailureMode::Allow);
        assert!(matches!(mw.on_failure, TenantFailureMode::Allow));
    }

    // The middleware resolves a tenant from the first matching resolver and stores it in task-local context.
    #[tokio::test]
    async fn resolves_tenant_and_stores_in_task_local() {
        let mw = TenantMiddleware::new()
            .resolver(AlwaysResolver {
                tenant: make_tenant("acme"),
            })
            .on_failure(TenantFailureMode::NotFound);

        let req = make_request().await;
        let next = tenant_capture_next();
        let response = mw.handle(req, next).await.unwrap();

        let json: serde_json::Value = serde_json::from_str(response.body()).unwrap();
        assert_eq!(json["slug"], "acme");
    }

    // The middleware tries resolvers in order; the first Some wins.
    #[tokio::test]
    async fn tries_resolvers_in_order_first_some_wins() {
        let mw = TenantMiddleware::new()
            .resolver(NeverResolver)
            .resolver(AlwaysResolver {
                tenant: make_tenant("beta"),
            })
            .resolver(AlwaysResolver {
                tenant: make_tenant("gamma"),
            });

        let req = make_request().await;
        let next = tenant_capture_next();
        let response = mw.handle(req, next).await.unwrap();

        let json: serde_json::Value = serde_json::from_str(response.body()).unwrap();
        assert_eq!(json["slug"], "beta");
    }

    // Resolution stops at the first match: resolvers after it are never consulted.
    #[tokio::test]
    async fn stops_consulting_resolvers_after_first_match() {
        let misses = Arc::new(AtomicUsize::new(0));
        let hits = Arc::new(AtomicUsize::new(0));
        let skipped = Arc::new(AtomicUsize::new(0));
        let mw = TenantMiddleware::new()
            .resolver(CountingResolver {
                calls: Arc::clone(&misses),
                result: None,
            })
            .resolver(CountingResolver {
                calls: Arc::clone(&hits),
                result: Some(make_tenant("beta")),
            })
            .resolver(CountingResolver {
                calls: Arc::clone(&skipped),
                result: Some(make_tenant("gamma")),
            });

        let req = make_request().await;
        let response = mw.handle(req, tenant_capture_next()).await.unwrap();

        let json: serde_json::Value = serde_json::from_str(response.body()).unwrap();
        assert_eq!(json["slug"], "beta");
        assert_eq!(misses.load(Ordering::SeqCst), 1);
        assert_eq!(hits.load(Ordering::SeqCst), 1);
        assert_eq!(skipped.load(Ordering::SeqCst), 0);
    }

    // With no resolvers registered at all, the default NotFound failure mode applies.
    #[tokio::test]
    async fn empty_chain_falls_back_to_not_found() {
        let mw = TenantMiddleware::new();

        let req = make_request().await;
        let err = mw.handle(req, ok_next()).await.unwrap_err();

        assert_eq!(err.status_code(), 404);
        let json: serde_json::Value = serde_json::from_str(err.body()).unwrap();
        assert_eq!(json["error"], "Tenant not found");
    }

    // When no resolver matches and on_failure=NotFound, the middleware returns a 404 JSON error.
    #[tokio::test]
    async fn no_match_not_found_returns_404() {
        let mw = TenantMiddleware::new()
            .resolver(NeverResolver)
            .on_failure(TenantFailureMode::NotFound);

        let req = make_request().await;
        let err = mw.handle(req, ok_next()).await.unwrap_err();

        assert_eq!(err.status_code(), 404);
        let json: serde_json::Value = serde_json::from_str(err.body()).unwrap();
        assert_eq!(json["error"], "Tenant not found");
    }

    // When no resolver matches and on_failure=Forbidden, the middleware returns a 403 JSON error.
    #[tokio::test]
    async fn no_match_forbidden_returns_403() {
        let mw = TenantMiddleware::new()
            .resolver(NeverResolver)
            .on_failure(TenantFailureMode::Forbidden);

        let req = make_request().await;
        let err = mw.handle(req, ok_next()).await.unwrap_err();

        assert_eq!(err.status_code(), 403);
        let json: serde_json::Value = serde_json::from_str(err.body()).unwrap();
        assert_eq!(json["error"], "Access denied");
    }

    // When no resolver matches and on_failure=Allow, the request continues with current_tenant()=None.
    #[tokio::test]
    async fn no_match_allow_continues_with_none() {
        let mw = TenantMiddleware::new()
            .resolver(NeverResolver)
            .on_failure(TenantFailureMode::Allow);

        let req = make_request().await;
        let next = tenant_capture_next();
        let response = mw.handle(req, next).await.unwrap();

        let json: serde_json::Value = serde_json::from_str(response.body()).unwrap();
        assert!(json["slug"].is_null());
    }

    // current_tenant() returns the resolved TenantContext while the downstream handler runs.
    #[tokio::test]
    async fn current_tenant_available_in_downstream_handler() {
        let mw = TenantMiddleware::new().resolver(AlwaysResolver {
            tenant: make_tenant("downstream-test"),
        });

        let req = make_request().await;
        let next = tenant_capture_next();
        let response = mw.handle(req, next).await.unwrap();

        let json: serde_json::Value = serde_json::from_str(response.body()).unwrap();
        assert_eq!(json["slug"], "downstream-test");
    }
}