Skip to main content

ferro_rs/tenant/
mod.rs

//! Multi-tenant middleware support for the Ferro framework.
//!
//! Provides task-local tenant context, resolver and lookup trait contracts,
//! and a default cached database lookup implementation.
//!
//! # Overview
//!
//! - [`TenantContext`] — holds the tenant's id, slug, name, and optional plan
//!   (plus subscription state when the `stripe` feature is enabled)
//! - [`current_tenant()`] — reads the current tenant from task-local storage
//! - [`TenantMiddleware`] — middleware that routes must sit behind for handlers
//!   to extract a [`TenantContext`]
//! - [`TenantResolver`] — trait for pluggable tenant resolution strategies, with
//!   built-in [`HeaderResolver`], [`PathResolver`], [`SubdomainResolver`], and
//!   [`JwtClaimResolver`]
//! - [`TenantLookup`] / [`DbTenantLookup`] — trait + cached implementation for DB queries
//! - [`TenantFailureMode`] — controls behavior when no tenant is resolved
13
14pub mod context;
15pub mod lookup;
16pub mod middleware;
17#[cfg(feature = "stripe")]
18pub mod requires_plan;
19pub mod resolver;
20pub mod scope;
21#[cfg(feature = "stripe")]
22pub mod subscription;
23pub mod worker;
24
25pub use context::current_tenant;
26pub use lookup::{DbTenantLookup, TenantLookup};
27pub use middleware::TenantMiddleware;
28#[cfg(feature = "stripe")]
29pub use requires_plan::RequiresPlan;
30pub use resolver::{
31    HeaderResolver, JwtClaimResolver, PathResolver, SubdomainResolver, TenantResolver,
32};
33pub use scope::TenantScope;
34pub use worker::FrameworkTenantScopeProvider;
35
36use crate::error::FrameworkError;
37use crate::http::{FromRequest, Request};
38use async_trait::async_trait;
39
40/// Core data for the resolved tenant.
41///
42/// Populated by [`TenantResolver`] and stored in task-local scope during a request.
43/// The `plan` field is nullable — tenants may not have a billing plan assigned
44/// until Stripe integration is complete (Phase 96).
45#[derive(Debug, Clone, serde::Serialize)]
46pub struct TenantContext {
47    /// Unique numeric tenant ID (primary key).
48    pub id: i64,
49    /// URL-safe slug used for subdomain or path-based routing.
50    pub slug: String,
51    /// Human-readable tenant name.
52    pub name: String,
53    /// Optional billing plan identifier (legacy — use subscription.plan when stripe feature is enabled).
54    pub plan: Option<String>,
55    /// Full subscription state (available when stripe feature is enabled).
56    #[cfg(feature = "stripe")]
57    pub subscription: Option<subscription::SubscriptionInfo>,
58}
59
60impl TenantContext {
61    /// Construct a tenant context from its core identity fields.
62    ///
63    /// The stripe-gated `subscription` field is initialized to `None`; populate
64    /// it afterward (when the `stripe` feature is enabled) if subscription data
65    /// is available. Using this constructor keeps call sites feature-agnostic —
66    /// they do not need to know whether `subscription` exists in the struct.
67    pub fn new(id: i64, slug: String, name: String, plan: Option<String>) -> Self {
68        Self {
69            id,
70            slug,
71            name,
72            plan,
73            #[cfg(feature = "stripe")]
74            subscription: None,
75        }
76    }
77}
78
#[cfg(feature = "stripe")]
impl TenantContext {
    /// Reports whether the tenant's subscription is currently in a trial period.
    ///
    /// A tenant with no subscription at all is never considered on trial.
    pub fn on_trial(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.on_trial())
    }

    /// Reports whether the tenant holds an active or trialing subscription.
    ///
    /// A tenant with no subscription, or one in any other status, is not subscribed.
    pub fn subscribed(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.subscribed())
    }

    /// Reports whether the subscription is set to cancel while its billing
    /// period is still running.
    ///
    /// A tenant with no subscription is never in a grace period.
    pub fn on_grace_period(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.on_grace_period())
    }

    /// Resolves the tenant's effective plan identifier.
    ///
    /// The subscription's plan wins whenever a subscription is present;
    /// otherwise the legacy `plan` field is used. Yields `None` when neither
    /// source provides a plan.
    pub fn current_plan(&self) -> Option<&str> {
        match &self.subscription {
            Some(sub) => Some(sub.plan.as_str()),
            None => self.plan.as_deref(),
        }
    }
}
114
115/// Extracts the current tenant from task-local context.
116///
117/// Returns `Ok(TenantContext)` when called from a handler behind
118/// `TenantMiddleware`. Returns a 400 error if no tenant context exists.
119///
120/// # Example
121///
122/// ```rust,ignore
123/// #[handler]
124/// pub async fn dashboard(tenant: TenantContext) -> Response {
125///     Ok(json!({"tenant": tenant.name}))
126/// }
127/// ```
128#[async_trait]
129impl FromRequest for TenantContext {
130    async fn from_request(_req: Request) -> Result<Self, FrameworkError> {
131        current_tenant().ok_or_else(|| {
132            FrameworkError::domain(
133                "No tenant context available. Ensure this route is behind TenantMiddleware.",
134                400,
135            )
136        })
137    }
138}
139
140/// Controls framework behavior when no tenant is resolved for a request.
141pub enum TenantFailureMode {
142    /// Return 404 Not Found when the tenant cannot be resolved.
143    NotFound,
144    /// Return 403 Forbidden when the tenant cannot be resolved.
145    Forbidden,
146    /// Pass through — allow the request even without a resolved tenant.
147    Allow,
148    /// Return a custom response when the tenant cannot be resolved.
149    Custom(Box<dyn Fn() -> crate::http::Response + Send + Sync>),
150}
151
152impl std::fmt::Debug for TenantFailureMode {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        match self {
155            Self::NotFound => write!(f, "NotFound"),
156            Self::Forbidden => write!(f, "Forbidden"),
157            Self::Allow => write!(f, "Allow"),
158            Self::Custom(_) => write!(f, "Custom(...)"),
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::context::{tenant_scope, with_tenant_scope};
    use hyper_util::rt::TokioIo;
    use tokio::sync::oneshot;

    /// Builds a plan-less tenant fixture whose name is derived from its slug.
    ///
    /// Goes through [`TenantContext::new`] so the fixture compiles unchanged
    /// whether or not the `stripe` feature (and its `subscription` field) is enabled.
    fn make_tenant(id: i64, slug: &str) -> TenantContext {
        TenantContext::new(id, slug.to_string(), format!("Tenant {slug}"), None)
    }

    /// Create a minimal Request via TCP loopback.
    ///
    /// `hyper::body::Incoming` has no default constructor, so we use a real
    /// TCP connection (matching the pattern in middleware tests). The loopback
    /// server forwards the first request it receives over a oneshot channel.
    /// The sender lives in an `Arc<Mutex<Option<_>>>` because the service
    /// closure may run more than once, while the sender can be used only once.
    async fn make_request() -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = oneshot::channel::<Request>();
        let tx_holder = std::sync::Arc::new(std::sync::Mutex::new(Some(tx)));

        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let io = TokioIo::new(stream);
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        io,
                        hyper::service::service_fn(move |req| {
                            let tx_holder = tx_holder.clone();
                            async move {
                                if let Some(tx) = tx_holder.lock().unwrap().take() {
                                    let _ = tx.send(Request::new(req));
                                }
                                Ok::<_, hyper::Error>(hyper::Response::new(
                                    http_body_util::Empty::<bytes::Bytes>::new(),
                                ))
                            }
                        }),
                    )
                    .await
                    .ok();
            }
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move { conn.await.ok() });

        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    /// `TenantContext::new` passes the identity fields through unchanged.
    #[test]
    fn new_populates_identity_fields() {
        let tenant = TenantContext::new(
            7,
            "acme".to_string(),
            "ACME Corp".to_string(),
            Some("pro".to_string()),
        );
        assert_eq!(tenant.id, 7);
        assert_eq!(tenant.slug, "acme");
        assert_eq!(tenant.name, "ACME Corp");
        assert_eq!(tenant.plan.as_deref(), Some("pro"));
        // The constructor never invents subscription data.
        #[cfg(feature = "stripe")]
        assert!(tenant.subscription.is_none());
    }

    /// FromRequest yields the tenant stored in the surrounding tenant scope.
    #[tokio::test]
    async fn from_request_returns_ok_when_tenant_context_is_set() {
        let ctx = tenant_scope();
        {
            let mut guard = ctx.write().await;
            *guard = Some(make_tenant(99, "acme"));
        }

        let result = with_tenant_scope(ctx, async {
            let req = make_request().await;
            TenantContext::from_request(req).await
        })
        .await;

        assert!(
            result.is_ok(),
            "Expected Ok(TenantContext), got: {result:?}"
        );
        let tenant = result.unwrap();
        assert_eq!(tenant.id, 99);
        assert_eq!(tenant.slug, "acme");
    }

    /// FromRequest fails with a 400 FrameworkError outside any tenant scope.
    #[tokio::test]
    async fn from_request_returns_400_error_when_no_tenant_context() {
        // Call from_request without any TenantMiddleware scope
        let req = make_request().await;
        let result = TenantContext::from_request(req).await;

        assert!(result.is_err(), "Expected Err when no tenant context");
        let err = result.unwrap_err();
        assert_eq!(
            err.status_code(),
            400,
            "Expected 400 status code, got: {}",
            err.status_code()
        );
    }

    #[cfg(feature = "stripe")]
    mod stripe_tests {
        use super::*;
        use crate::tenant::subscription::{SubscriptionInfo, SubscriptionStatus};

        /// Builds a subscription fixture whose billing period ends 30 days from
        /// now, so the period is genuinely still running when assertions execute.
        fn make_subscription(plan: &str, status: SubscriptionStatus) -> SubscriptionInfo {
            SubscriptionInfo {
                stripe_subscription_id: "sub_test".to_string(),
                plan: plan.to_string(),
                status,
                trial_ends_at: None,
                cancel_at_period_end: false,
                current_period_end: chrono::Utc::now() + chrono::Duration::days(30),
                stripe_connect_account_id: None,
            }
        }

        /// Tenant whose legacy plan and subscription plan are both `plan`.
        fn make_tenant_with_subscription(plan: &str, status: SubscriptionStatus) -> TenantContext {
            let mut tenant = TenantContext::new(
                1,
                "acme".to_string(),
                "ACME Corp".to_string(),
                Some(plan.to_string()),
            );
            tenant.subscription = Some(make_subscription(plan, status));
            tenant
        }

        /// Tenant with neither a legacy plan nor a subscription.
        fn make_tenant_no_subscription() -> TenantContext {
            TenantContext::new(1, "acme".to_string(), "ACME Corp".to_string(), None)
        }

        #[test]
        fn tenant_with_none_subscription_serializes_with_null_subscription() {
            let tenant = make_tenant_no_subscription();
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(json["subscription"].is_null());
        }

        #[test]
        fn tenant_with_some_subscription_serializes_with_full_subscription_object() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(!json["subscription"].is_null());
            assert_eq!(json["subscription"]["plan"], "pro");
            assert_eq!(json["subscription"]["status"], "active");
        }

        #[test]
        fn on_trial_returns_true_when_subscription_is_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_subscription_is_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_trial());
        }

        #[test]
        fn subscribed_returns_true_when_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_canceled() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Canceled);
            assert!(!tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.subscribed());
        }

        #[test]
        fn on_grace_period_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_false_when_not_scheduled_to_cancel() {
            // Fixture default: cancel_at_period_end == false.
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_true_when_cancel_at_period_end_and_active() {
            let mut tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            tenant
                .subscription
                .as_mut()
                .expect("fixture always has a subscription")
                .cancel_at_period_end = true;
            assert!(tenant.on_grace_period());
        }

        #[test]
        fn current_plan_returns_subscription_plan_when_present() {
            let mut tenant =
                make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            // Give the legacy field a different value so the assertion actually
            // proves the subscription plan takes precedence over it.
            tenant.plan = Some("basic".to_string());
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_falls_back_to_legacy_plan_when_no_subscription() {
            let tenant = TenantContext::new(
                1,
                "acme".to_string(),
                "ACME Corp".to_string(),
                Some("pro".to_string()),
            );
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_returns_none_when_neither_is_set() {
            let tenant = make_tenant_no_subscription();
            assert_eq!(tenant.current_plan(), None);
        }
    }
}