ferro_rs/tenant/
middleware.rs1use crate::http::{HttpResponse, Response};
19use crate::middleware::{Middleware, Next};
20use crate::tenant::context::{tenant_scope, with_tenant_scope};
21use crate::tenant::{TenantContext, TenantFailureMode};
22use crate::Request;
23use async_trait::async_trait;
24use serde_json::json;
25
26use super::resolver::TenantResolver;
27
28pub struct TenantMiddleware {
33 resolvers: Vec<Box<dyn TenantResolver>>,
34 on_failure: TenantFailureMode,
35}
36
37impl TenantMiddleware {
38 pub fn new() -> Self {
40 Self {
41 resolvers: Vec::new(),
42 on_failure: TenantFailureMode::NotFound,
43 }
44 }
45
46 pub fn resolver(mut self, resolver: impl TenantResolver + 'static) -> Self {
50 self.resolvers.push(Box::new(resolver));
51 self
52 }
53
54 pub fn on_failure(mut self, mode: TenantFailureMode) -> Self {
56 self.on_failure = mode;
57 self
58 }
59}
60
61impl Default for TenantMiddleware {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67#[async_trait]
68impl Middleware for TenantMiddleware {
69 async fn handle(&self, request: Request, next: Next) -> Response {
70 let mut resolved: Option<TenantContext> = None;
72 for resolver in &self.resolvers {
73 if let Some(ctx) = resolver.resolve(&request).await {
74 resolved = Some(ctx);
75 break;
76 }
77 }
78
79 match resolved {
80 Some(ctx) => {
81 let scope = tenant_scope();
83 {
84 let mut guard = scope.write().await;
85 *guard = Some(ctx);
86 }
87 with_tenant_scope(scope, next(request)).await
88 }
89 None => match &self.on_failure {
90 TenantFailureMode::NotFound => {
91 Err(HttpResponse::json(json!({"error": "Tenant not found"})).status(404))
92 }
93 TenantFailureMode::Forbidden => {
94 Err(HttpResponse::json(json!({"error": "Access denied"})).status(403))
95 }
96 TenantFailureMode::Allow => {
97 let scope = tenant_scope();
99 with_tenant_scope(scope, next(request)).await
100 }
101 TenantFailureMode::Custom(handler) => handler(),
102 },
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::HttpResponse;
    use crate::tenant::context::current_tenant;
    use crate::tenant::{TenantContext, TenantFailureMode};
    use async_trait::async_trait;
    use hyper_util::rt::TokioIo;
    use std::sync::Arc;
    use std::sync::Mutex;
    use tokio::sync::oneshot;

    /// Builds a tenant fixture whose only varying field is `slug`.
    fn make_tenant(slug: &str) -> TenantContext {
        TenantContext {
            id: 1,
            slug: slug.to_string(),
            name: "Test Corp".to_string(),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Parses a response body as JSON, panicking if it is not valid JSON.
    fn body_json(body: &str) -> serde_json::Value {
        serde_json::from_str(body).unwrap()
    }

    /// Produces a real `Request` by sending `GET /test` with one header over
    /// a loopback HTTP/1 connection and capturing what the server side sees.
    async fn make_request_with_header(header_name: &str, header_value: &str) -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // A oneshot sender is consumed by `send`, so it sits in an `Option`
        // the service closure can `take()` from.
        let tx_slot = Arc::new(Mutex::new(Some(tx)));

        // Server half: accept one connection and hand the first request back.
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let server_io = TokioIo::new(stream);
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let tx_slot = Arc::clone(&tx_slot);
                    async move {
                        let pending = tx_slot.lock().unwrap().take();
                        if let Some(tx) = pending {
                            let _ = tx.send(Request::new(req));
                        }
                        let empty = http_body_util::Empty::<bytes::Bytes>::new();
                        Ok::<_, hyper::Error>(hyper::Response::new(empty))
                    }
                });
            hyper::server::conn::http1::Builder::new()
                .serve_connection(server_io, service)
                .await
                .ok();
        });

        // Client half: connect, drive the connection in the background, send.
        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let client_io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(client_io)
            .await
            .unwrap();
        tokio::spawn(async move {
            conn.await.ok();
        });

        let outgoing = hyper::Request::builder()
            .uri("/test")
            .header(header_name, header_value)
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();

        let _ = sender.send_request(outgoing).await;
        rx.await.unwrap()
    }

    /// Shorthand for a request carrying a throwaway `x-test: 1` header.
    async fn make_request() -> Request {
        make_request_with_header("x-test", "1").await
    }

    /// Resolver that always yields a clone of its stored tenant.
    struct AlwaysResolver {
        tenant: TenantContext,
    }

    #[async_trait]
    impl TenantResolver for AlwaysResolver {
        async fn resolve(&self, _req: &Request) -> Option<TenantContext> {
            Some(self.tenant.clone())
        }
    }

    /// Resolver that never yields a tenant.
    struct NeverResolver;

    #[async_trait]
    impl TenantResolver for NeverResolver {
        async fn resolve(&self, _req: &Request) -> Option<TenantContext> {
            None
        }
    }

    /// Downstream handler that ignores the request and answers with text "ok".
    fn ok_next() -> Next {
        Arc::new(|_req| {
            Box::pin(async { Ok(HttpResponse::text("ok")) }) as crate::middleware::MiddlewareFuture
        })
    }

    /// Downstream handler that reports the slug of `current_tenant()` as JSON
    /// (`null` when no tenant is set).
    fn tenant_capture_next() -> Next {
        Arc::new(|_req| {
            Box::pin(async move {
                let payload = match current_tenant() {
                    Some(t) => serde_json::json!({"slug": t.slug}),
                    None => serde_json::json!({"slug": null}),
                };
                Ok(HttpResponse::json(payload))
            }) as crate::middleware::MiddlewareFuture
        })
    }

    #[test]
    fn new_creates_empty_instance_with_not_found_default() {
        let middleware = TenantMiddleware::new();
        assert!(middleware.resolvers.is_empty());
        assert!(matches!(
            middleware.on_failure,
            TenantFailureMode::NotFound
        ));
    }

    #[test]
    fn resolver_adds_to_chain() {
        let middleware = TenantMiddleware::new().resolver(NeverResolver);
        assert_eq!(middleware.resolvers.len(), 1);
    }

    #[test]
    fn on_failure_sets_mode() {
        let middleware = TenantMiddleware::new().on_failure(TenantFailureMode::Allow);
        assert!(matches!(middleware.on_failure, TenantFailureMode::Allow));
    }

    #[tokio::test]
    async fn resolves_tenant_and_stores_in_task_local() {
        let middleware = TenantMiddleware::new()
            .resolver(AlwaysResolver {
                tenant: make_tenant("acme"),
            })
            .on_failure(TenantFailureMode::NotFound);

        let request = make_request().await;
        let response = middleware
            .handle(request, tenant_capture_next())
            .await
            .unwrap();

        let json = body_json(response.body());
        assert_eq!(json["slug"], "acme");
    }

    #[tokio::test]
    async fn tries_resolvers_in_order_first_some_wins() {
        let middleware = TenantMiddleware::new()
            .resolver(NeverResolver)
            .resolver(AlwaysResolver {
                tenant: make_tenant("beta"),
            })
            .resolver(AlwaysResolver {
                tenant: make_tenant("gamma"),
            });

        let request = make_request().await;
        let response = middleware
            .handle(request, tenant_capture_next())
            .await
            .unwrap();

        let json = body_json(response.body());
        assert_eq!(json["slug"], "beta");
    }

    #[tokio::test]
    async fn no_match_not_found_returns_404() {
        let middleware = TenantMiddleware::new()
            .resolver(NeverResolver)
            .on_failure(TenantFailureMode::NotFound);

        let request = make_request().await;
        let rejection = middleware.handle(request, ok_next()).await.unwrap_err();

        assert_eq!(rejection.status_code(), 404);
        let json = body_json(rejection.body());
        assert_eq!(json["error"], "Tenant not found");
    }

    #[tokio::test]
    async fn no_match_forbidden_returns_403() {
        let middleware = TenantMiddleware::new()
            .resolver(NeverResolver)
            .on_failure(TenantFailureMode::Forbidden);

        let request = make_request().await;
        let rejection = middleware.handle(request, ok_next()).await.unwrap_err();

        assert_eq!(rejection.status_code(), 403);
        let json = body_json(rejection.body());
        assert_eq!(json["error"], "Access denied");
    }

    #[tokio::test]
    async fn no_match_allow_continues_with_none() {
        let middleware = TenantMiddleware::new()
            .resolver(NeverResolver)
            .on_failure(TenantFailureMode::Allow);

        let request = make_request().await;
        let response = middleware
            .handle(request, tenant_capture_next())
            .await
            .unwrap();

        let json = body_json(response.body());
        assert!(json["slug"].is_null());
    }

    #[tokio::test]
    async fn current_tenant_available_in_downstream_handler() {
        let middleware = TenantMiddleware::new().resolver(AlwaysResolver {
            tenant: make_tenant("downstream-test"),
        });

        let request = make_request().await;
        let response = middleware
            .handle(request, tenant_capture_next())
            .await
            .unwrap();

        let json = body_json(response.body());
        assert_eq!(json["slug"], "downstream-test");
    }
}