Skip to main content

ferro_rs/tenant/
resolver.rs

1//! Tenant resolver trait and concrete implementations for Ferro framework.
2//!
3//! Defines the contract for pluggable tenant resolution strategies and provides
4//! four built-in implementations: subdomain, header, path parameter, and JWT claim.
5
6use crate::tenant::TenantContext;
7use crate::Request;
8use async_trait::async_trait;
9use std::sync::Arc;
10
11use super::lookup::TenantLookup;
12
13/// Resolves the current tenant from an incoming request.
14///
15/// Implement this trait to provide custom resolution strategies such as
16/// subdomain parsing, header inspection, or JWT claim extraction.
17///
18/// # Example
19///
20/// ```rust,ignore
21/// use ferro_rs::tenant::{TenantContext, TenantResolver};
22/// use ferro_rs::Request;
23/// use async_trait::async_trait;
24///
25/// struct SubdomainResolver;
26///
27/// #[async_trait]
28/// impl TenantResolver for SubdomainResolver {
29///     async fn resolve(&self, req: &Request) -> Option<TenantContext> {
30///         // Extract subdomain from Host header and resolve tenant
31///         None
32///     }
33/// }
34/// ```
35#[async_trait]
36pub trait TenantResolver: Send + Sync {
37    /// Resolve the tenant from the given request.
38    ///
39    /// Returns `None` if no tenant could be determined.
40    async fn resolve(&self, req: &Request) -> Option<TenantContext>;
41}
42
43/// Resolves tenant from the Host header subdomain.
44///
45/// Extracts the leftmost subdomain segment by stripping the rightmost
46/// `base_domain_parts` segments. For example, with `base_domain_parts = 2`:
47/// - `acme.yourapp.com` → slug `"acme"`
48/// - `yourapp.com` → `None` (no subdomain)
49/// - `acme.yourapp.com:8080` → slug `"acme"` (port stripped)
50///
51/// # Example
52///
53/// ```rust,ignore
54/// use ferro_rs::tenant::{SubdomainResolver, DbTenantLookup};
55/// use std::sync::Arc;
56///
57/// let resolver = SubdomainResolver::new(2, Arc::new(lookup));
58/// ```
59pub struct SubdomainResolver {
60    base_domain_parts: usize,
61    tenant_lookup: Arc<dyn TenantLookup>,
62}
63
64impl SubdomainResolver {
65    /// Create a new `SubdomainResolver`.
66    ///
67    /// - `base_domain_parts` — number of base domain segments to strip (e.g. 2 for `yourapp.com`)
68    /// - `tenant_lookup` — lookup implementation for DB verification
69    pub fn new(base_domain_parts: usize, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
70        Self {
71            base_domain_parts,
72            tenant_lookup,
73        }
74    }
75}
76
77#[async_trait]
78impl TenantResolver for SubdomainResolver {
79    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
80        let host = req.header("host")?;
81        // Strip port: "acme.yourapp.com:8080" -> "acme.yourapp.com"
82        let host_no_port = host.split(':').next()?;
83        let parts: Vec<&str> = host_no_port.split('.').collect();
84        if parts.len() <= self.base_domain_parts {
85            return None;
86        }
87        let slug = parts[0];
88        self.tenant_lookup.find_by_slug(slug).await
89    }
90}
91
92/// Resolves tenant from a configurable HTTP header.
93///
94/// Reads the header named `header_name` and passes its value to the lookup
95/// as a slug.
96///
97/// # Example
98///
99/// ```rust,ignore
100/// use ferro_rs::tenant::{HeaderResolver, DbTenantLookup};
101/// use std::sync::Arc;
102///
103/// let resolver = HeaderResolver::new("X-Tenant-ID", Arc::new(lookup));
104/// ```
105pub struct HeaderResolver {
106    header_name: String,
107    tenant_lookup: Arc<dyn TenantLookup>,
108}
109
110impl HeaderResolver {
111    /// Create a new `HeaderResolver`.
112    ///
113    /// - `header_name` — the HTTP header to read the tenant slug from
114    /// - `tenant_lookup` — lookup implementation for DB verification
115    pub fn new(header_name: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
116        Self {
117            header_name: header_name.into(),
118            tenant_lookup,
119        }
120    }
121}
122
123#[async_trait]
124impl TenantResolver for HeaderResolver {
125    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
126        let value = req.header(&self.header_name)?;
127        self.tenant_lookup.find_by_slug(value).await
128    }
129}
130
131/// Resolves tenant from a route path parameter.
132///
133/// Reads the route parameter named `param_name` and passes it to the lookup
134/// as a slug.
135///
136/// # Example
137///
138/// ```rust,ignore
139/// use ferro_rs::tenant::{PathResolver, DbTenantLookup};
140/// use std::sync::Arc;
141///
142/// // For routes like /{tenant_slug}/dashboard
143/// let resolver = PathResolver::new("tenant_slug", Arc::new(lookup));
144/// ```
145pub struct PathResolver {
146    param_name: String,
147    tenant_lookup: Arc<dyn TenantLookup>,
148}
149
150impl PathResolver {
151    /// Create a new `PathResolver`.
152    ///
153    /// - `param_name` — the route parameter name containing the tenant slug
154    /// - `tenant_lookup` — lookup implementation for DB verification
155    pub fn new(param_name: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
156        Self {
157            param_name: param_name.into(),
158            tenant_lookup,
159        }
160    }
161}
162
163#[async_trait]
164impl TenantResolver for PathResolver {
165    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
166        let value = req.param(&self.param_name).ok()?;
167        self.tenant_lookup.find_by_slug(value).await
168    }
169}
170
171/// Resolves tenant from a JWT claim stored in request extensions.
172///
173/// Reads a `serde_json::Value` inserted by an upstream JWT middleware, extracts
174/// the claim field named `claim_field` as an `i64`, and passes it to the lookup
175/// as a tenant ID.
176///
177/// The upstream JWT/auth middleware must insert the parsed claims as
178/// `serde_json::Value` via `req.insert::<serde_json::Value>(claims)`.
179///
180/// # Example
181///
182/// ```rust,ignore
183/// use ferro_rs::tenant::{JwtClaimResolver, DbTenantLookup};
184/// use std::sync::Arc;
185///
186/// // Expects upstream JWT middleware to insert serde_json::Value with a "tenant_id" field
187/// let resolver = JwtClaimResolver::new("tenant_id", Arc::new(lookup));
188/// ```
189pub struct JwtClaimResolver {
190    claim_field: String,
191    tenant_lookup: Arc<dyn TenantLookup>,
192}
193
194impl JwtClaimResolver {
195    /// Create a new `JwtClaimResolver`.
196    ///
197    /// - `claim_field` — the claim key whose value is the tenant ID (i64)
198    /// - `tenant_lookup` — lookup implementation for DB verification
199    pub fn new(claim_field: impl Into<String>, tenant_lookup: Arc<dyn TenantLookup>) -> Self {
200        Self {
201            claim_field: claim_field.into(),
202            tenant_lookup,
203        }
204    }
205}
206
207#[async_trait]
208impl TenantResolver for JwtClaimResolver {
209    async fn resolve(&self, req: &Request) -> Option<TenantContext> {
210        let claims = req.get::<serde_json::Value>()?;
211        let id = claims[&self.claim_field].as_i64()?;
212        self.tenant_lookup.find_by_id(id).await
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::{TenantContext, TenantLookup};
    use async_trait::async_trait;
    use bytes::Bytes;
    use http_body_util::Empty;
    use hyper_util::rt::TokioIo;
    use std::sync::Arc;
    use std::sync::Mutex;
    use tokio::sync::oneshot;

    fn make_tenant(slug: &str) -> TenantContext {
        TenantContext {
            id: 42,
            slug: slug.to_string(),
            name: "Test Corp".to_string(),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Mock lookup that returns a tenant for known slugs/ids and None otherwise.
    struct MockLookup;

    #[async_trait]
    impl TenantLookup for MockLookup {
        async fn find_by_slug(&self, slug: &str) -> Option<TenantContext> {
            match slug {
                "acme" | "beta" => Some(make_tenant(slug)),
                _ => None,
            }
        }

        async fn find_by_id(&self, id: i64) -> Option<TenantContext> {
            if id == 42 {
                Some(make_tenant("acme"))
            } else {
                None
            }
        }
    }

    /// Lookup that never finds a tenant but records every slug it is asked for,
    /// so tests can assert that (and with what) the lookup was consulted.
    #[derive(Default)]
    struct RecordingLookup {
        slugs: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl TenantLookup for RecordingLookup {
        async fn find_by_slug(&self, slug: &str) -> Option<TenantContext> {
            self.slugs.lock().unwrap().push(slug.to_string());
            None
        }

        async fn find_by_id(&self, _id: i64) -> Option<TenantContext> {
            None
        }
    }

    /// Create a test Request via TCP loopback with optional host, header, and path params.
    ///
    /// Starts a hyper server on an ephemeral loopback port, sends one request
    /// to `/test`, and returns the server-side `Request` (with `params`
    /// attached) through a oneshot channel. The server and client connection
    /// tasks are spawned detached.
    async fn make_request_opts(
        host: Option<&str>,
        extra_header: Option<(&str, &str)>,
        params: std::collections::HashMap<String, String>,
    ) -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = oneshot::channel();
        // `service_fn` takes a reusable closure, so the single-use sender lives
        // in a shared slot and is taken by the first request.
        let tx_holder = Arc::new(Mutex::new(Some(tx)));

        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let io = TokioIo::new(stream);
            let service =
                hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
                    let tx_holder = tx_holder.clone();
                    let params = params.clone();
                    async move {
                        if let Some(tx) = tx_holder.lock().unwrap().take() {
                            let _ = tx.send(Request::new(req).with_params(params));
                        }
                        Ok::<_, hyper::Error>(hyper::Response::new(Empty::<Bytes>::new()))
                    }
                });
            hyper::server::conn::http1::Builder::new()
                .serve_connection(io, service)
                .await
                .ok();
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move {
            conn.await.ok();
        });

        let mut builder = hyper::Request::builder().uri("/test");
        if let Some(h) = host {
            builder = builder.header("host", h);
        }
        if let Some((name, value)) = extra_header {
            builder = builder.header(name, value);
        }
        let req = builder.body(Empty::<Bytes>::new()).unwrap();

        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    // Test 1: SubdomainResolver extracts "acme" from Host "acme.yourapp.com" with base_domain_parts=2
    #[tokio::test]
    async fn subdomain_resolver_extracts_slug_from_host() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(Some("acme.yourapp.com"), None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme");
    }

    // Test 2: SubdomainResolver returns None for "yourapp.com" (no subdomain, only base parts)
    #[tokio::test]
    async fn subdomain_resolver_returns_none_for_no_subdomain() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(Some("yourapp.com"), None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    // Test 3: SubdomainResolver strips port from Host header
    #[tokio::test]
    async fn subdomain_resolver_strips_port() {
        let lookup = Arc::new(MockLookup);
        let resolver = SubdomainResolver::new(2, lookup);
        let req = make_request_opts(Some("acme.yourapp.com:8080"), None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme");
    }

    // Test 4: SubdomainResolver calls tenant_lookup.find_by_slug() with extracted slug
    #[tokio::test]
    async fn subdomain_resolver_calls_find_by_slug() {
        let lookup = Arc::new(RecordingLookup::default());
        let resolver = SubdomainResolver::new(2, lookup.clone());
        let req = make_request_opts(Some("unknown.yourapp.com"), None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none()); // RecordingLookup never returns a tenant
        // The lookup must have been consulted exactly once, with the extracted slug.
        assert_eq!(*lookup.slugs.lock().unwrap(), vec!["unknown".to_string()]);
    }

    // Test 5: HeaderResolver extracts from X-Tenant-ID header and calls find_by_slug
    #[tokio::test]
    async fn header_resolver_extracts_from_header() {
        let lookup = Arc::new(MockLookup);
        let resolver = HeaderResolver::new("x-tenant-id", lookup);
        let req = make_request_opts(None, Some(("x-tenant-id", "acme")), Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme");
    }

    // Test 6: HeaderResolver returns None when header is absent
    #[tokio::test]
    async fn header_resolver_returns_none_when_absent() {
        let lookup = Arc::new(MockLookup);
        let resolver = HeaderResolver::new("x-tenant-id", lookup);
        // No x-tenant-id header in request
        let req = make_request_opts(None, None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    // Test 7: PathResolver extracts from request.param("tenant_slug") and calls find_by_slug
    #[tokio::test]
    async fn path_resolver_extracts_from_param() {
        let lookup = Arc::new(MockLookup);
        let resolver = PathResolver::new("tenant_slug", lookup);
        let mut params = std::collections::HashMap::new();
        params.insert("tenant_slug".to_string(), "beta".to_string());
        let req = make_request_opts(None, None, params).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "beta");
    }

    // Test 8: PathResolver returns None when path parameter is absent
    #[tokio::test]
    async fn path_resolver_returns_none_when_param_absent() {
        let lookup = Arc::new(MockLookup);
        let resolver = PathResolver::new("tenant_slug", lookup);
        // No tenant_slug param
        let req = make_request_opts(None, None, Default::default()).await;

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    // Test 9: JwtClaimResolver extracts tenant_id from request extension (pre-parsed JWT claims)
    #[tokio::test]
    async fn jwt_claim_resolver_extracts_from_extensions() {
        let lookup = Arc::new(MockLookup);
        let resolver = JwtClaimResolver::new("tenant_id", lookup);
        let mut req = make_request_opts(None, None, Default::default()).await;

        // Upstream JWT middleware inserts claims as serde_json::Value
        req.insert::<serde_json::Value>(serde_json::json!({"tenant_id": 42, "sub": "user1"}));

        let result = resolver.resolve(&req).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().slug, "acme"); // MockLookup returns "acme" for id=42
    }

    // Test 10: JwtClaimResolver returns None when no JWT claims in request
    #[tokio::test]
    async fn jwt_claim_resolver_returns_none_without_claims() {
        let lookup = Arc::new(MockLookup);
        let resolver = JwtClaimResolver::new("tenant_id", lookup);
        let req = make_request_opts(None, None, Default::default()).await;
        // No claims inserted into extensions

        let result = resolver.resolve(&req).await;
        assert!(result.is_none());
    }

    // Compile-time check that TenantResolver can be used as a trait object.
    #[test]
    fn tenant_resolver_is_object_safe() {
        // If TenantResolver were not object-safe, this would not compile.
        let _: Box<dyn TenantResolver>;
    }
}