// ferro_rs/tenant/mod.rs

//! Multi-tenant middleware support for the Ferro framework.
//!
//! Provides a task-local tenant context, resolver and lookup trait contracts,
//! and a default cached database lookup implementation.
//!
//! # Overview
//!
//! - [`TenantContext`] — holds id, slug, name, and optional plan fields
//!   (plus subscription state when the `stripe` feature is enabled)
//! - [`current_tenant()`] — reads the current tenant from task-local storage
//! - [`TenantResolver`] — trait for pluggable tenant resolution strategies
//! - [`TenantLookup`] / [`DbTenantLookup`] — trait + cached implementation for DB queries
//! - [`TenantFailureMode`] — controls behavior when no tenant is resolved

14pub mod context;
15pub mod lookup;
16pub mod middleware;
17#[cfg(feature = "stripe")]
18pub mod requires_plan;
19pub mod resolver;
20pub mod scope;
21pub mod worker;
22
23pub use context::current_tenant;
24pub use lookup::{DbTenantLookup, TenantLookup};
25pub use middleware::TenantMiddleware;
26#[cfg(feature = "stripe")]
27pub use requires_plan::RequiresPlan;
28pub use resolver::{
29    HeaderResolver, JwtClaimResolver, PathResolver, SubdomainResolver, TenantResolver,
30};
31pub use scope::TenantScope;
32pub use worker::FrameworkTenantScopeProvider;
33
34use crate::error::FrameworkError;
35use crate::http::{FromRequest, Request};
36use async_trait::async_trait;
37
38/// Core data for the resolved tenant.
39///
40/// Populated by [`TenantResolver`] and stored in task-local scope during a request.
41/// The `plan` field is nullable — tenants may not have a billing plan assigned
42/// until Stripe integration is complete (Phase 96).
43#[derive(Debug, Clone, serde::Serialize)]
44pub struct TenantContext {
45    /// Unique numeric tenant ID (primary key).
46    pub id: i64,
47    /// URL-safe slug used for subdomain or path-based routing.
48    pub slug: String,
49    /// Human-readable tenant name.
50    pub name: String,
51    /// Optional billing plan identifier (legacy — use subscription.plan when stripe feature is enabled).
52    pub plan: Option<String>,
53    /// Full subscription state (available when stripe feature is enabled).
54    #[cfg(feature = "stripe")]
55    pub subscription: Option<ferro_stripe::SubscriptionInfo>,
56}
57
#[cfg(feature = "stripe")]
impl TenantContext {
    /// Reports whether the tenant's subscription is currently in a trial period.
    ///
    /// A tenant without a subscription is never considered on trial.
    pub fn on_trial(&self) -> bool {
        match &self.subscription {
            Some(sub) => sub.on_trial(),
            None => false,
        }
    }

    /// Reports whether the tenant holds an active or trialing subscription.
    ///
    /// A tenant without a subscription is never considered subscribed.
    pub fn subscribed(&self) -> bool {
        match &self.subscription {
            Some(sub) => sub.subscribed(),
            None => false,
        }
    }

    /// Reports whether the subscription is scheduled to cancel while its
    /// billing period is still running.
    ///
    /// A tenant without a subscription is never in a grace period.
    pub fn on_grace_period(&self) -> bool {
        match &self.subscription {
            Some(sub) => sub.on_grace_period(),
            None => false,
        }
    }

    /// The tenant's effective plan identifier.
    ///
    /// The subscription's plan wins whenever a subscription exists; otherwise
    /// the legacy `plan` field is used. Yields `None` when neither is set.
    pub fn current_plan(&self) -> Option<&str> {
        match &self.subscription {
            Some(sub) => Some(sub.plan.as_str()),
            None => self.plan.as_deref(),
        }
    }
}

94/// Extracts the current tenant from task-local context.
95///
96/// Returns `Ok(TenantContext)` when called from a handler behind
97/// `TenantMiddleware`. Returns a 400 error if no tenant context exists.
98///
99/// # Example
100///
101/// ```rust,ignore
102/// #[handler]
103/// pub async fn dashboard(tenant: TenantContext) -> Response {
104///     Ok(json!({"tenant": tenant.name}))
105/// }
106/// ```
107#[async_trait]
108impl FromRequest for TenantContext {
109    async fn from_request(_req: Request) -> Result<Self, FrameworkError> {
110        current_tenant().ok_or_else(|| {
111            FrameworkError::domain(
112                "No tenant context available. Ensure this route is behind TenantMiddleware.",
113                400,
114            )
115        })
116    }
117}
118
119/// Controls framework behavior when no tenant is resolved for a request.
120pub enum TenantFailureMode {
121    /// Return 404 Not Found when the tenant cannot be resolved.
122    NotFound,
123    /// Return 403 Forbidden when the tenant cannot be resolved.
124    Forbidden,
125    /// Pass through — allow the request even without a resolved tenant.
126    Allow,
127    /// Return a custom response when the tenant cannot be resolved.
128    Custom(Box<dyn Fn() -> crate::http::Response + Send + Sync>),
129}
130
131impl std::fmt::Debug for TenantFailureMode {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        match self {
134            Self::NotFound => write!(f, "NotFound"),
135            Self::Forbidden => write!(f, "Forbidden"),
136            Self::Allow => write!(f, "Allow"),
137            Self::Custom(_) => write!(f, "Custom(...)"),
138        }
139    }
140}
141
142impl Clone for TenantFailureMode {
143    fn clone(&self) -> Self {
144        match self {
145            Self::NotFound => Self::NotFound,
146            Self::Forbidden => Self::Forbidden,
147            Self::Allow => Self::Allow,
148            Self::Custom(_) => panic!("TenantFailureMode::Custom cannot be cloned"),
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::context::{tenant_scope, with_tenant_scope};
    use hyper_util::rt::TokioIo;
    use tokio::sync::oneshot;

    fn make_tenant(id: i64, slug: &str) -> TenantContext {
        TenantContext {
            id,
            slug: slug.to_string(),
            name: format!("Tenant {slug}"),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Create a minimal Request via TCP loopback.
    ///
    /// `hyper::body::Incoming` has no default constructor, so we use a real
    /// TCP connection (matching the pattern in middleware tests).
    async fn make_request() -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = oneshot::channel::<Request>();
        // The service closure can be invoked more than once, but a oneshot
        // sender is single-use, so it sits behind a mutex and is taken by the
        // first request only.
        let tx_holder = std::sync::Arc::new(std::sync::Mutex::new(Some(tx)));

        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let io = TokioIo::new(stream);
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        io,
                        hyper::service::service_fn(move |req| {
                            let tx_holder = tx_holder.clone();
                            async move {
                                if let Some(tx) = tx_holder.lock().unwrap().take() {
                                    let _ = tx.send(Request::new(req));
                                }
                                Ok::<_, hyper::Error>(hyper::Response::new(
                                    http_body_util::Empty::<bytes::Bytes>::new(),
                                ))
                            }
                        }),
                    )
                    .await
                    .ok();
            }
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move { conn.await.ok() });

        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        let _ = sender.send_request(req).await;
        rx.await.expect("loopback server did not capture the request")
    }

    /// FromRequest yields the tenant stored in the task-local scope.
    #[tokio::test]
    async fn from_request_returns_ok_when_tenant_context_is_set() {
        let ctx = tenant_scope();
        {
            let mut guard = ctx.write().await;
            *guard = Some(make_tenant(99, "acme"));
        }

        let result = with_tenant_scope(ctx, async {
            let req = make_request().await;
            TenantContext::from_request(req).await
        })
        .await;

        assert!(
            result.is_ok(),
            "Expected Ok(TenantContext), got: {result:?}"
        );
        let tenant = result.unwrap();
        assert_eq!(tenant.id, 99);
        assert_eq!(tenant.slug, "acme");
    }

    /// FromRequest fails with status 400 when no tenant scope is active.
    #[tokio::test]
    async fn from_request_returns_400_error_when_no_tenant_context() {
        // Call from_request without any TenantMiddleware scope
        let req = make_request().await;
        let result = TenantContext::from_request(req).await;

        assert!(result.is_err(), "Expected Err when no tenant context");
        let err = result.unwrap_err();
        assert_eq!(
            err.status_code(),
            400,
            "Expected 400 status code, got: {}",
            err.status_code()
        );
    }

    #[test]
    fn failure_mode_debug_prints_builtin_variant_names() {
        assert_eq!(format!("{:?}", TenantFailureMode::NotFound), "NotFound");
        assert_eq!(format!("{:?}", TenantFailureMode::Forbidden), "Forbidden");
        assert_eq!(format!("{:?}", TenantFailureMode::Allow), "Allow");
    }

    #[test]
    fn failure_mode_clone_preserves_builtin_variants() {
        assert!(matches!(
            TenantFailureMode::NotFound.clone(),
            TenantFailureMode::NotFound
        ));
        assert!(matches!(
            TenantFailureMode::Forbidden.clone(),
            TenantFailureMode::Forbidden
        ));
        assert!(matches!(
            TenantFailureMode::Allow.clone(),
            TenantFailureMode::Allow
        ));
    }

    #[cfg(feature = "stripe")]
    mod stripe_tests {
        use super::*;
        use ferro_stripe::{SubscriptionInfo, SubscriptionStatus};

        fn make_subscription(plan: &str, status: SubscriptionStatus) -> SubscriptionInfo {
            SubscriptionInfo {
                stripe_subscription_id: "sub_test".to_string(),
                plan: plan.to_string(),
                status,
                trial_ends_at: None,
                cancel_at_period_end: false,
                // Keep the billing period running into the future so fixtures
                // describe a live period; a timestamp of exactly "now" is
                // already in the past by the time any clock-based check runs.
                current_period_end: chrono::Utc::now() + chrono::Duration::days(30),
                stripe_connect_account_id: None,
            }
        }

        fn make_tenant_with_subscription(plan: &str, status: SubscriptionStatus) -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some(plan.to_string()),
                subscription: Some(make_subscription(plan, status)),
            }
        }

        fn make_tenant_no_subscription() -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: None,
                subscription: None,
            }
        }

        #[test]
        fn tenant_with_none_subscription_serializes_with_null_subscription() {
            let tenant = make_tenant_no_subscription();
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(json["subscription"].is_null());
        }

        #[test]
        fn tenant_with_some_subscription_serializes_with_full_subscription_object() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(!json["subscription"].is_null());
            assert_eq!(json["subscription"]["plan"], "pro");
            assert_eq!(json["subscription"]["status"], "active");
        }

        #[test]
        fn on_trial_returns_true_when_subscription_is_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_subscription_is_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_trial());
        }

        #[test]
        fn subscribed_returns_true_when_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_canceled() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Canceled);
            assert!(!tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.subscribed());
        }

        #[test]
        fn on_grace_period_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_false_when_not_scheduled_to_cancel() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_true_when_cancel_at_period_end_and_active() {
            let mut tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            // Fail loudly if the fixture ever stops attaching a subscription,
            // rather than silently skipping the mutation.
            tenant
                .subscription
                .as_mut()
                .expect("fixture must carry a subscription")
                .cancel_at_period_end = true;
            assert!(tenant.on_grace_period());
        }

        #[test]
        fn current_plan_returns_subscription_plan_when_present() {
            let tenant = make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_prefers_subscription_plan_over_legacy_plan() {
            let mut tenant = make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            tenant.plan = Some("legacy".to_string());
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_falls_back_to_legacy_plan_when_no_subscription() {
            let tenant = TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some("pro".to_string()),
                subscription: None,
            };
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_returns_none_when_neither_is_set() {
            let tenant = make_tenant_no_subscription();
            assert_eq!(tenant.current_plan(), None);
        }
    }
}