1use crate::tenant::TenantContext;
7use crate::Request;
8use async_trait::async_trait;
9use std::sync::Arc;
10
11use super::lookup::TenantLookup;
12
/// Strategy for determining which tenant an incoming request belongs to.
///
/// The implementations in this module extract an identifier from the request
/// (host, header, parameter, or stored claims) and resolve it through a
/// `TenantLookup`. The trait is used as a trait object
/// (`Box<dyn TenantResolver>`); the `Send + Sync` bounds allow implementations
/// to be shared across threads.
#[async_trait]
pub trait TenantResolver: Send + Sync {
    /// Returns the tenant for `req`, or `None` when this resolver cannot
    /// determine one.
    async fn resolve(&self, req: &Request) -> Option<TenantContext>;
}
42
/// Resolves the tenant from the leftmost subdomain label of the request's
/// `Host` header (e.g. `acme` in `acme.yourapp.com`).
pub struct SubdomainResolver {
    /// Number of dot-separated labels in the base domain (e.g. `2` for
    /// `yourapp.com`); hosts with no more labels than this have no subdomain.
    base_domain_parts: usize,
    /// Lookup that maps the extracted slug to a tenant.
    tenant_lookup: Arc<dyn TenantLookup>,
}
63
64impl SubdomainResolver {
65 pub fn new(base_domain_parts: usize, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
70 Self {
71 base_domain_parts,
72 tenant_lookup,
73 }
74 }
75}
76
77#[async_trait]
78impl TenantResolver for SubdomainResolver {
79 async fn resolve(&self, req: &Request) -> Option<TenantContext> {
80 let host = req.header("host")?;
81 let host_no_port = host.split(':').next()?;
83 let parts: Vec<&str> = host_no_port.split('.').collect();
84 if parts.len() <= self.base_domain_parts {
85 return None;
86 }
87 let slug = parts[0];
88 self.tenant_lookup.find_by_slug(slug).await
89 }
90}
91
/// Resolves the tenant from the value of a configurable request header.
///
/// The header value is looked up as a tenant *slug* (via `find_by_slug`),
/// not as a numeric id.
pub struct HeaderResolver {
    /// Name of the header whose value identifies the tenant.
    header_name: String,
    /// Lookup that maps the header value to a tenant.
    tenant_lookup: Arc<dyn TenantLookup>,
}
109
110impl HeaderResolver {
111 pub fn new(header_name: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
116 Self {
117 header_name: header_name.into(),
118 tenant_lookup,
119 }
120 }
121}
122
123#[async_trait]
124impl TenantResolver for HeaderResolver {
125 async fn resolve(&self, req: &Request) -> Option<TenantContext> {
126 let value = req.header(&self.header_name)?;
127 self.tenant_lookup.find_by_slug(value).await
128 }
129}
130
/// Resolves the tenant from a named request parameter (read via
/// `Request::param`), whose value is looked up as a tenant slug.
pub struct PathResolver {
    /// Name of the parameter holding the tenant slug.
    param_name: String,
    /// Lookup that maps the parameter value to a tenant.
    tenant_lookup: Arc<dyn TenantLookup>,
}
149
150impl PathResolver {
151 pub fn new(param_name: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
156 Self {
157 param_name: param_name.into(),
158 tenant_lookup,
159 }
160 }
161}
162
163#[async_trait]
164impl TenantResolver for PathResolver {
165 async fn resolve(&self, req: &Request) -> Option<TenantContext> {
166 let value = req.param(&self.param_name).ok()?;
167 self.tenant_lookup.find_by_slug(value).await
168 }
169}
170
/// Resolves the tenant from a claim in the `serde_json::Value` stored on the
/// request (presumably decoded JWT claims inserted by auth middleware —
/// confirm against whatever populates it).
pub struct JwtClaimResolver {
    /// Name of the claim holding the tenant id.
    claim_field: String,
    /// Lookup that maps the tenant id to a tenant (via `find_by_id`).
    tenant_lookup: Arc<dyn TenantLookup>,
}
193
194impl JwtClaimResolver {
195 pub fn new(claim_field: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
200 Self {
201 claim_field: claim_field.into(),
202 tenant_lookup,
203 }
204 }
205}
206
207#[async_trait]
208impl TenantResolver for JwtClaimResolver {
209 async fn resolve(&self, req: &Request) -> Option<TenantContext> {
210 let claims = req.get::<serde_json::Value>()?;
211 let id = claims[&self.claim_field].as_i64()?;
212 self.tenant_lookup.find_by_id(id).await
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::{TenantContext, TenantLookup};
    use async_trait::async_trait;
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper_util::rt::TokioIo;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::Mutex;
    use tokio::sync::oneshot;

    /// Builds a tenant fixture with id `42` and the given slug.
    fn make_tenant(slug: &str) -> TenantContext {
        TenantContext {
            id: 42,
            slug: slug.to_string(),
            name: "Test Corp".to_string(),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// In-memory lookup that knows the slugs `acme` and `beta`, and maps id
    /// `42` to `acme`; everything else is unknown.
    struct MockLookup;

    #[async_trait]
    impl TenantLookup for MockLookup {
        async fn find_by_slug(&self, slug: &str) -> Option<TenantContext> {
            match slug {
                "acme" | "beta" => Some(make_tenant(slug)),
                _ => None,
            }
        }

        async fn find_by_id(&self, id: i64) -> Option<TenantContext> {
            if id == 42 {
                Some(make_tenant("acme"))
            } else {
                None
            }
        }
    }

    /// Produces a crate `Request` by sending a real HTTP/1.1 request over a
    /// loopback socket and capturing it on the server side.
    ///
    /// `host` and `extra_header` are added to the outgoing request when
    /// present; `params` are attached to the captured request via
    /// `Request::with_params`.
    async fn make_request_opts(
        host: Option<&str>,
        extra_header: Option<(&str, &str)>,
        params: HashMap<String, String>,
    ) -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // `service_fn` takes an `Fn` closure, so the single-use sender sits
        // behind a mutex and is taken by the first request only.
        let tx_holder = Arc::new(Mutex::new(Some(tx)));

        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let io = TokioIo::new(stream);
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    // Per-call copies: the closure itself must remain callable.
                    let tx_holder = Arc::clone(&tx_holder);
                    let params = params.clone();
                    async move {
                        if let Some(tx) = tx_holder.lock().unwrap().take() {
                            let _ = tx.send(Request::new(req).with_params(params));
                        }
                        Ok::<_, hyper::Error>(hyper::Response::new(Empty::<Bytes>::new()))
                    }
                });
            hyper::server::conn::http1::Builder::new()
                .serve_connection(io, service)
                .await
                .ok();
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move {
            conn.await.ok();
        });

        let mut builder = hyper::Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        if let Some((name, value)) = extra_header {
            builder = builder.header(name, value);
        }
        let req = builder.body(Empty::<Bytes>::new()).unwrap();

        // The response is irrelevant; only the server-side capture matters.
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    #[tokio::test]
    async fn subdomain_resolver_extracts_slug_from_host() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(Some("acme.yourapp.com"), None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme");
    }

    #[tokio::test]
    async fn subdomain_resolver_returns_none_for_no_subdomain() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(Some("yourapp.com"), None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn subdomain_resolver_strips_port() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(Some("acme.yourapp.com:8080"), None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme");
    }

    #[tokio::test]
    async fn subdomain_resolver_returns_none_for_unknown_slug() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(Some("unknown.yourapp.com"), None, Default::default()).await;

        // A slug is extracted, but the lookup does not know it.
        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn subdomain_resolver_returns_none_without_host_header() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(None, None, Default::default()).await;

        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn header_resolver_extracts_from_header() {
        let lookup = Arc::new(MockLookup);
        let resolver = HeaderResolver::new("x-tenant-id", lookup);
        let req = make_request_opts(None, Some(("x-tenant-id", "acme")), Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme");
    }

    #[tokio::test]
    async fn header_resolver_returns_none_when_absent() {
        let lookup = Arc::new(MockLookup);
        let resolver = HeaderResolver::new("x-tenant-id", lookup);
        let req = make_request_opts(None, None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn header_resolver_returns_none_for_unknown_tenant() {
        let lookup = Arc::new(MockLookup);
        let resolver = HeaderResolver::new("x-tenant-id", lookup);
        let req =
            make_request_opts(None, Some(("x-tenant-id", "nobody")), Default::default()).await;

        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn path_resolver_extracts_from_param() {
        let lookup = Arc::new(MockLookup);
        let resolver = PathResolver::new("tenant_slug", lookup);
        let params = HashMap::from([("tenant_slug".to_string(), "beta".to_string())]);
        let req = make_request_opts(None, None, params).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "beta");
    }

    #[tokio::test]
    async fn path_resolver_returns_none_when_param_absent() {
        let lookup = Arc::new(MockLookup);
        let resolver = PathResolver::new("tenant_slug", lookup);
        let req = make_request_opts(None, None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn path_resolver_returns_none_for_unknown_slug() {
        let lookup = Arc::new(MockLookup);
        let resolver = PathResolver::new("tenant_slug", lookup);
        let params = HashMap::from([("tenant_slug".to_string(), "ghost".to_string())]);
        let req = make_request_opts(None, None, params).await;

        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn jwt_claim_resolver_extracts_from_extensions() {
        let lookup = Arc::new(MockLookup);
        let resolver = JwtClaimResolver::new("tenant_id", lookup);
        let mut req = make_request_opts(None, None, Default::default()).await;

        req.insert::<serde_json::Value>(serde_json::json!({"tenant_id": 42, "sub": "user1"}));

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme");
    }

    #[tokio::test]
    async fn jwt_claim_resolver_returns_none_without_claims() {
        let lookup = Arc::new(MockLookup);
        let resolver = JwtClaimResolver::new("tenant_id", lookup);
        let req = make_request_opts(None, None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn jwt_claim_resolver_returns_none_when_claim_missing() {
        let lookup = Arc::new(MockLookup);
        let resolver = JwtClaimResolver::new("tenant_id", lookup);
        let mut req = make_request_opts(None, None, Default::default()).await;

        req.insert::<serde_json::Value>(serde_json::json!({"sub": "user1"}));

        assert!(resolver.resolve(&req).await.is_none());
    }

    #[tokio::test]
    async fn jwt_claim_resolver_returns_none_for_unknown_id() {
        let lookup = Arc::new(MockLookup);
        let resolver = JwtClaimResolver::new("tenant_id", lookup);
        let mut req = make_request_opts(None, None, Default::default()).await;

        req.insert::<serde_json::Value>(serde_json::json!({"tenant_id": 7}));

        assert!(resolver.resolve(&req).await.is_none());
    }

    #[test]
    fn tenant_resolver_is_object_safe() {
        // Compiles only if `TenantResolver` can be used as a trait object.
        let _: Box<dyn TenantResolver>;
    }
}