1pub mod context;
15pub mod lookup;
16pub mod middleware;
17#[cfg(feature = "stripe")]
18pub mod requires_plan;
19pub mod resolver;
20pub mod scope;
21pub mod worker;
22
23pub use context::current_tenant;
24pub use lookup::{DbTenantLookup, TenantLookup};
25pub use middleware::TenantMiddleware;
26#[cfg(feature = "stripe")]
27pub use requires_plan::RequiresPlan;
28pub use resolver::{
29 HeaderResolver, JwtClaimResolver, PathResolver, SubdomainResolver, TenantResolver,
30};
31pub use scope::TenantScope;
32pub use worker::FrameworkTenantScopeProvider;
33
34use crate::error::FrameworkError;
35use crate::http::{FromRequest, Request};
36use async_trait::async_trait;
37
/// The tenant that the current request is scoped to.
///
/// Handlers can take it as an extractor through the `FromRequest` impl in this module.
/// That impl reads it from `current_tenant()` and fails with a 400 error when no tenant
/// is in scope; the error message points at `TenantMiddleware` as the thing that sets it.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TenantContext {
    /// Numeric tenant identifier (presumably the database id — NOTE(review): confirm
    /// against `TenantLookup`).
    pub id: i64,
    /// Short textual identifier for the tenant (presumably what the path/subdomain
    /// resolvers match on — NOTE(review): confirm).
    pub slug: String,
    /// Human-readable tenant name.
    pub name: String,
    /// Plan name stored directly on the tenant. With the `stripe` feature enabled,
    /// `current_plan()` only falls back to this when `subscription` is `None`.
    pub plan: Option<String>,
    /// Stripe subscription details, if the tenant has one. Serialized as `null` when
    /// absent.
    #[cfg(feature = "stripe")]
    pub subscription: Option<ferro_stripe::SubscriptionInfo>,
}
57
#[cfg(feature = "stripe")]
impl TenantContext {
    /// Whether the tenant's subscription is on trial, as reported by
    /// `SubscriptionInfo::on_trial`. A tenant with no subscription is never on trial.
    pub fn on_trial(&self) -> bool {
        self.subscription_satisfies(|sub| sub.on_trial())
    }

    /// Whether the tenant counts as subscribed, as reported by
    /// `SubscriptionInfo::subscribed`. Always `false` when there is no subscription.
    pub fn subscribed(&self) -> bool {
        self.subscription_satisfies(|sub| sub.subscribed())
    }

    /// Whether the tenant's subscription is in its grace period, as reported by
    /// `SubscriptionInfo::on_grace_period`. Always `false` when there is no subscription.
    pub fn on_grace_period(&self) -> bool {
        self.subscription_satisfies(|sub| sub.on_grace_period())
    }

    /// The plan the tenant is currently on.
    ///
    /// When a subscription exists, its plan takes precedence. Otherwise this falls back
    /// to the plan stored on the tenant itself. Returns `None` when neither is set.
    pub fn current_plan(&self) -> Option<&str> {
        match &self.subscription {
            Some(sub) => Some(sub.plan.as_str()),
            None => self.plan.as_deref(),
        }
    }

    /// Runs `check` against the subscription; a missing subscription yields `false`.
    fn subscription_satisfies(
        &self,
        check: impl FnOnce(&ferro_stripe::SubscriptionInfo) -> bool,
    ) -> bool {
        match &self.subscription {
            Some(sub) => check(sub),
            None => false,
        }
    }
}
93
94#[async_trait]
108impl FromRequest for TenantContext {
109 async fn from_request(_req: Request) -> Result<Self, FrameworkError> {
110 current_tenant().ok_or_else(|| {
111 FrameworkError::domain(
112 "No tenant context available. Ensure this route is behind TenantMiddleware.",
113 400,
114 )
115 })
116 }
117}
118
/// Policy for what to do when a request's tenant cannot be resolved.
///
/// NOTE(review): the variants are acted upon elsewhere (presumably in
/// `TenantMiddleware`); the per-variant notes below are inferred from the names, so
/// confirm them there.
pub enum TenantFailureMode {
    /// Reject the request as not found (presumably HTTP 404 — confirm).
    NotFound,
    /// Reject the request as forbidden (presumably HTTP 403 — confirm).
    Forbidden,
    /// Presumably lets the request continue without a tenant — confirm.
    Allow,
    /// Build the failure response with a caller-supplied closure.
    ///
    /// This variant cannot be cloned: `Clone::clone` panics on it because the boxed
    /// closure cannot be duplicated.
    Custom(Box<dyn Fn() -> crate::http::Response + Send + Sync>),
}
130
131impl std::fmt::Debug for TenantFailureMode {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 match self {
134 Self::NotFound => write!(f, "NotFound"),
135 Self::Forbidden => write!(f, "Forbidden"),
136 Self::Allow => write!(f, "Allow"),
137 Self::Custom(_) => write!(f, "Custom(...)"),
138 }
139 }
140}
141
142impl Clone for TenantFailureMode {
143 fn clone(&self) -> Self {
144 match self {
145 Self::NotFound => Self::NotFound,
146 Self::Forbidden => Self::Forbidden,
147 Self::Allow => Self::Allow,
148 Self::Custom(_) => panic!("TenantFailureMode::Custom cannot be cloned"),
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::context::{tenant_scope, with_tenant_scope};
    use hyper_util::rt::TokioIo;
    use tokio::sync::oneshot;

    /// Builds a tenant with no plan and (under `stripe`) no subscription.
    fn make_tenant(id: i64, slug: &str) -> TenantContext {
        TenantContext {
            id,
            slug: slug.to_string(),
            name: format!("Tenant {slug}"),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Produces a framework `Request` wrapping a real, server-side hyper request.
    ///
    /// A throwaway HTTP/1 server is bound to an ephemeral loopback port. The first
    /// request it receives is wrapped with `Request::new` and handed back through a
    /// oneshot channel. The round trip over a socket is presumably needed because the
    /// server-side request body type cannot be built by hand.
    async fn make_request() -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = oneshot::channel::<Request>();
        // `service_fn` may be invoked more than once, but a oneshot sender can only be
        // used once, so it lives in a shared `Option` and is taken on first use.
        let tx_holder = std::sync::Arc::new(std::sync::Mutex::new(Some(tx)));

        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let io = TokioIo::new(stream);
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        io,
                        hyper::service::service_fn(move |req| {
                            let tx_holder = tx_holder.clone();
                            async move {
                                if let Some(tx) = tx_holder.lock().unwrap().take() {
                                    let _ = tx.send(Request::new(req));
                                }
                                Ok::<_, hyper::Error>(hyper::Response::new(
                                    http_body_util::Empty::<bytes::Bytes>::new(),
                                ))
                            }
                        }),
                    )
                    .await
                    .ok();
            }
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move { conn.await.ok() });

        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        // Only the request captured on the server side matters; the response is ignored.
        let _ = sender.send_request(req).await;
        rx.await
            .expect("test server should forward the first request it receives")
    }

    /// The extractor returns the tenant stored in the active tenant scope.
    #[tokio::test]
    async fn from_request_returns_ok_when_tenant_context_is_set() {
        let ctx = tenant_scope();
        {
            let mut guard = ctx.write().await;
            *guard = Some(make_tenant(99, "acme"));
        }

        let result = with_tenant_scope(ctx, async {
            let req = make_request().await;
            TenantContext::from_request(req).await
        })
        .await;

        assert!(
            result.is_ok(),
            "Expected Ok(TenantContext), got: {result:?}"
        );
        let tenant = result.unwrap();
        assert_eq!(tenant.id, 99);
        assert_eq!(tenant.slug, "acme");
    }

    /// Outside any tenant scope, the extractor fails with a 400 error.
    #[tokio::test]
    async fn from_request_returns_400_error_when_no_tenant_context() {
        let req = make_request().await;
        // Deliberately not wrapped in `with_tenant_scope`, so no tenant should be found.
        let result = TenantContext::from_request(req).await;

        assert!(result.is_err(), "Expected Err when no tenant context");
        let err = result.unwrap_err();
        assert_eq!(
            err.status_code(),
            400,
            "Expected 400 status code, got: {}",
            err.status_code()
        );
    }

    #[cfg(feature = "stripe")]
    mod stripe_tests {
        use super::*;
        use ferro_stripe::{SubscriptionInfo, SubscriptionStatus};

        fn make_subscription(plan: &str, status: SubscriptionStatus) -> SubscriptionInfo {
            SubscriptionInfo {
                stripe_subscription_id: "sub_test".to_string(),
                plan: plan.to_string(),
                status,
                trial_ends_at: None,
                cancel_at_period_end: false,
                current_period_end: chrono::Utc::now(),
                stripe_connect_account_id: None,
            }
        }

        /// Tenant whose legacy `plan` and subscription plan are both `plan`.
        fn make_tenant_with_subscription(plan: &str, status: SubscriptionStatus) -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some(plan.to_string()),
                subscription: Some(make_subscription(plan, status)),
            }
        }

        fn make_tenant_no_subscription() -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: None,
                subscription: None,
            }
        }

        #[test]
        fn tenant_with_none_subscription_serializes_with_null_subscription() {
            let tenant = make_tenant_no_subscription();
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(json["subscription"].is_null());
        }

        #[test]
        fn tenant_with_some_subscription_serializes_with_full_subscription_object() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(!json["subscription"].is_null());
            assert_eq!(json["subscription"]["plan"], "pro");
            assert_eq!(json["subscription"]["status"], "active");
        }

        #[test]
        fn on_trial_returns_true_when_subscription_is_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_subscription_is_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_trial());
        }

        #[test]
        fn subscribed_returns_true_when_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_canceled() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Canceled);
            assert!(!tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.subscribed());
        }

        #[test]
        fn on_grace_period_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_true_when_cancel_at_period_end_and_active() {
            let mut tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            // Fail loudly if the fixture ever stops providing a subscription, rather
            // than silently skipping the mutation.
            tenant
                .subscription
                .as_mut()
                .expect("fixture should include a subscription")
                .cancel_at_period_end = true;
            assert!(tenant.on_grace_period());
        }

        #[test]
        fn current_plan_returns_subscription_plan_when_present() {
            // The legacy plan deliberately differs from the subscription plan, so this
            // test pins the precedence rule instead of passing either way.
            let tenant = TenantContext {
                plan: Some("pro".to_string()),
                ..make_tenant_with_subscription("enterprise", SubscriptionStatus::Active)
            };
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_falls_back_to_legacy_plan_when_no_subscription() {
            let tenant = TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some("pro".to_string()),
                subscription: None,
            };
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_returns_none_when_neither_is_set() {
            let tenant = make_tenant_no_subscription();
            assert_eq!(tenant.current_plan(), None);
        }
    }
}