Skip to main content

ferro_rs/tenant/
mod.rs

1//! Multi-tenant middleware support for Ferro framework.
2//!
3//! Provides task-local tenant context, resolver and lookup trait contracts,
4//! and a default cached database lookup implementation.
5//!
6//! # Overview
7//!
8//! - [`TenantContext`] — holds id, slug, name, and optional plan fields
9//! - [`current_tenant()`] — reads the current tenant from task-local storage
10//! - [`TenantResolver`] — trait for pluggable tenant resolution strategies
11//! - [`TenantLookup`] / [`DbTenantLookup`] — trait + cached implementation for DB queries
12//! - [`TenantFailureMode`] — controls behavior when no tenant is resolved
13
14pub mod context;
15pub mod lookup;
16pub mod middleware;
17#[cfg(feature = "stripe")]
18pub mod requires_plan;
19pub mod resolver;
20pub mod scope;
21pub mod scoped;
22#[cfg(feature = "stripe")]
23pub mod subscription;
24pub mod worker;
25
26pub use context::current_tenant;
27pub use lookup::{DbTenantLookup, TenantLookup};
28pub use middleware::TenantMiddleware;
29#[cfg(feature = "stripe")]
30pub use requires_plan::RequiresPlan;
31pub use resolver::{
32    HeaderResolver, JwtClaimResolver, PathResolver, SubdomainResolver, TenantResolver,
33};
34pub use scope::TenantScope;
35pub use scoped::TenantScoped;
36pub use worker::FrameworkTenantScopeProvider;
37
38use crate::error::FrameworkError;
39use crate::http::{FromRequest, Request};
40use async_trait::async_trait;
41
42/// Core data for the resolved tenant.
43///
44/// Populated by [`TenantResolver`] and stored in task-local scope during a request.
45/// The `plan` field is nullable — tenants may not have a billing plan assigned
46/// until Stripe integration is complete (Phase 96).
47#[derive(Debug, Clone, serde::Serialize)]
48pub struct TenantContext {
49    /// Unique numeric tenant ID (primary key).
50    pub id: i64,
51    /// URL-safe slug used for subdomain or path-based routing.
52    pub slug: String,
53    /// Human-readable tenant name.
54    pub name: String,
55    /// Optional billing plan identifier (legacy — use subscription.plan when stripe feature is enabled).
56    pub plan: Option<String>,
57    /// Full subscription state (available when stripe feature is enabled).
58    #[cfg(feature = "stripe")]
59    pub subscription: Option<subscription::SubscriptionInfo>,
60}
61
62impl TenantContext {
63    /// Construct a tenant context from its core identity fields.
64    ///
65    /// The stripe-gated `subscription` field is initialized to `None`; populate
66    /// it afterward (when the `stripe` feature is enabled) if subscription data
67    /// is available. Using this constructor keeps call sites feature-agnostic —
68    /// they do not need to know whether `subscription` exists in the struct.
69    pub fn new(id: i64, slug: String, name: String, plan: Option<String>) -> Self {
70        Self {
71            id,
72            slug,
73            name,
74            plan,
75            #[cfg(feature = "stripe")]
76            subscription: None,
77        }
78    }
79}
80
#[cfg(feature = "stripe")]
impl TenantContext {
    /// Whether the tenant's subscription is currently in a trial period.
    ///
    /// A tenant with no subscription is never considered on trial.
    pub fn on_trial(&self) -> bool {
        match &self.subscription {
            Some(sub) => sub.on_trial(),
            None => false,
        }
    }

    /// Whether the tenant holds an active or trialing subscription.
    ///
    /// A tenant with no subscription is never considered subscribed; otherwise
    /// the answer is delegated to the subscription itself.
    pub fn subscribed(&self) -> bool {
        match &self.subscription {
            Some(sub) => sub.subscribed(),
            None => false,
        }
    }

    /// Whether the subscription is set to cancel while its billing period is still running.
    ///
    /// A tenant with no subscription is never in a grace period.
    pub fn on_grace_period(&self) -> bool {
        match &self.subscription {
            Some(sub) => sub.on_grace_period(),
            None => false,
        }
    }

    /// The effective plan identifier for this tenant.
    ///
    /// Uses the subscription's plan whenever a subscription is present, and
    /// otherwise falls back to the legacy `plan` field. Yields `None` only when
    /// both are absent.
    pub fn current_plan(&self) -> Option<&str> {
        match &self.subscription {
            Some(sub) => Some(sub.plan.as_str()),
            None => self.plan.as_deref(),
        }
    }
}
116
117/// Extracts the current tenant from task-local context.
118///
119/// Returns `Ok(TenantContext)` when called from a handler behind
120/// `TenantMiddleware`. Returns a 400 error if no tenant context exists.
121///
122/// # Example
123///
124/// ```rust,ignore
125/// #[handler]
126/// pub async fn dashboard(tenant: TenantContext) -> Response {
127///     Ok(json!({"tenant": tenant.name}))
128/// }
129/// ```
130#[async_trait]
131impl FromRequest for TenantContext {
132    async fn from_request(_req: Request) -> Result<Self, FrameworkError> {
133        current_tenant().ok_or_else(|| {
134            FrameworkError::domain(
135                "No tenant context available. Ensure this route is behind TenantMiddleware.",
136                400,
137            )
138        })
139    }
140}
141
142/// Controls framework behavior when no tenant is resolved for a request.
143pub enum TenantFailureMode {
144    /// Return 404 Not Found when the tenant cannot be resolved.
145    NotFound,
146    /// Return 403 Forbidden when the tenant cannot be resolved.
147    Forbidden,
148    /// Pass through — allow the request even without a resolved tenant.
149    Allow,
150    /// Return a custom response when the tenant cannot be resolved.
151    Custom(Box<dyn Fn() -> crate::http::Response + Send + Sync>),
152}
153
154impl std::fmt::Debug for TenantFailureMode {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        match self {
157            Self::NotFound => write!(f, "NotFound"),
158            Self::Forbidden => write!(f, "Forbidden"),
159            Self::Allow => write!(f, "Allow"),
160            Self::Custom(_) => write!(f, "Custom(...)"),
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::context::{tenant_scope, with_tenant_scope};
    use hyper_util::rt::TokioIo;
    use tokio::sync::oneshot;

    /// Build a plan-less tenant through the feature-agnostic constructor, so
    /// this helper compiles identically with or without the `stripe` feature.
    fn make_tenant(id: i64, slug: &str) -> TenantContext {
        TenantContext::new(id, slug.to_string(), format!("Tenant {slug}"), None)
    }

    /// Create a minimal Request via TCP loopback.
    ///
    /// hyper::body::Incoming has no default constructor, so we use a real
    /// TCP connection (matching the pattern in middleware tests).
    async fn make_request() -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = oneshot::channel::<Request>();
        // `service_fn` needs an `Fn` closure, so the one-shot sender is parked
        // in an Option and taken by the first request that arrives.
        let tx_holder = std::sync::Arc::new(std::sync::Mutex::new(Some(tx)));

        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let io = TokioIo::new(stream);
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        io,
                        hyper::service::service_fn(move |req| {
                            let tx_holder = tx_holder.clone();
                            async move {
                                if let Some(tx) = tx_holder.lock().unwrap().take() {
                                    let _ = tx.send(Request::new(req));
                                }
                                Ok::<_, hyper::Error>(hyper::Response::new(
                                    http_body_util::Empty::<bytes::Bytes>::new(),
                                ))
                            }
                        }),
                    )
                    .await
                    .ok();
            }
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move { conn.await.ok() });

        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    #[test]
    fn new_populates_identity_fields() {
        let tenant = TenantContext::new(
            7,
            "acme".to_string(),
            "ACME Corp".to_string(),
            Some("pro".to_string()),
        );
        assert_eq!(tenant.id, 7);
        assert_eq!(tenant.slug, "acme");
        assert_eq!(tenant.name, "ACME Corp");
        assert_eq!(tenant.plan.as_deref(), Some("pro"));
    }

    #[test]
    fn failure_mode_debug_output_names_each_variant() {
        assert_eq!(format!("{:?}", TenantFailureMode::NotFound), "NotFound");
        assert_eq!(format!("{:?}", TenantFailureMode::Forbidden), "Forbidden");
        assert_eq!(format!("{:?}", TenantFailureMode::Allow), "Allow");
    }

    /// TenantContext FromRequest returns Ok(ctx) when current_tenant() is Some.
    #[tokio::test]
    async fn from_request_returns_ok_when_tenant_context_is_set() {
        let ctx = tenant_scope();
        {
            let mut guard = ctx.write().await;
            *guard = Some(make_tenant(99, "acme"));
        }

        let result = with_tenant_scope(ctx, async {
            let req = make_request().await;
            TenantContext::from_request(req).await
        })
        .await;

        assert!(
            result.is_ok(),
            "Expected Ok(TenantContext), got: {result:?}"
        );
        let tenant = result.unwrap();
        assert_eq!(tenant.id, 99);
        assert_eq!(tenant.slug, "acme");
    }

    /// TenantContext FromRequest returns Err(FrameworkError) with status 400 when no tenant context.
    #[tokio::test]
    async fn from_request_returns_400_error_when_no_tenant_context() {
        // Call from_request without any TenantMiddleware scope
        let req = make_request().await;
        let result = TenantContext::from_request(req).await;

        assert!(result.is_err(), "Expected Err when no tenant context");
        let err = result.unwrap_err();
        assert_eq!(
            err.status_code(),
            400,
            "Expected 400 status code, got: {}",
            err.status_code()
        );
    }

    #[cfg(feature = "stripe")]
    mod stripe_tests {
        use super::*;
        use crate::tenant::subscription::{SubscriptionInfo, SubscriptionStatus};

        fn make_subscription(plan: &str, status: SubscriptionStatus) -> SubscriptionInfo {
            SubscriptionInfo {
                stripe_subscription_id: "sub_test".to_string(),
                plan: plan.to_string(),
                status,
                trial_ends_at: None,
                cancel_at_period_end: false,
                current_period_end: chrono::Utc::now(),
                stripe_connect_account_id: None,
            }
        }

        fn make_tenant_with_subscription(plan: &str, status: SubscriptionStatus) -> TenantContext {
            let mut tenant = TenantContext::new(
                1,
                "acme".to_string(),
                "ACME Corp".to_string(),
                Some(plan.to_string()),
            );
            tenant.subscription = Some(make_subscription(plan, status));
            tenant
        }

        fn make_tenant_no_subscription() -> TenantContext {
            TenantContext::new(1, "acme".to_string(), "ACME Corp".to_string(), None)
        }

        #[test]
        fn new_leaves_subscription_unset() {
            let tenant = TenantContext::new(
                1,
                "acme".to_string(),
                "ACME Corp".to_string(),
                Some("pro".to_string()),
            );
            assert!(tenant.subscription.is_none());
        }

        #[test]
        fn tenant_with_none_subscription_serializes_with_null_subscription() {
            let tenant = make_tenant_no_subscription();
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(json["subscription"].is_null());
        }

        #[test]
        fn tenant_with_some_subscription_serializes_with_full_subscription_object() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(!json["subscription"].is_null());
            assert_eq!(json["subscription"]["plan"], "pro");
            assert_eq!(json["subscription"]["status"], "active");
        }

        #[test]
        fn on_trial_returns_true_when_subscription_is_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_subscription_is_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_trial());
        }

        #[test]
        fn subscribed_returns_true_when_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_true_when_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_canceled() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Canceled);
            assert!(!tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.subscribed());
        }

        #[test]
        fn on_grace_period_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_false_when_active_without_scheduled_cancel() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_true_when_cancel_at_period_end_and_active() {
            let mut tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            if let Some(ref mut sub) = tenant.subscription {
                sub.cancel_at_period_end = true;
            }
            assert!(tenant.on_grace_period());
        }

        #[test]
        fn current_plan_returns_subscription_plan_when_present() {
            let tenant = make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_prefers_subscription_plan_over_legacy_plan() {
            let mut tenant = TenantContext::new(
                1,
                "acme".to_string(),
                "ACME Corp".to_string(),
                Some("legacy".to_string()),
            );
            tenant.subscription = Some(make_subscription("pro", SubscriptionStatus::Active));
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_falls_back_to_legacy_plan_when_no_subscription() {
            let tenant = TenantContext::new(
                1,
                "acme".to_string(),
                "ACME Corp".to_string(),
                Some("pro".to_string()),
            );
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_returns_none_when_neither_is_set() {
            let tenant = make_tenant_no_subscription();
            assert_eq!(tenant.current_plan(), None);
        }
    }
}