1pub mod context;
15pub mod lookup;
16pub mod middleware;
17#[cfg(feature = "stripe")]
18pub mod requires_plan;
19pub mod resolver;
20pub mod scope;
21#[cfg(feature = "stripe")]
22pub mod subscription;
23pub mod worker;
24
25pub use context::current_tenant;
26pub use lookup::{DbTenantLookup, TenantLookup};
27pub use middleware::TenantMiddleware;
28#[cfg(feature = "stripe")]
29pub use requires_plan::RequiresPlan;
30pub use resolver::{
31 HeaderResolver, JwtClaimResolver, PathResolver, SubdomainResolver, TenantResolver,
32};
33pub use scope::TenantScope;
34pub use worker::FrameworkTenantScopeProvider;
35
36use crate::error::FrameworkError;
37use crate::http::{FromRequest, Request};
38use async_trait::async_trait;
39
40#[derive(Debug, Clone, serde::Serialize)]
46pub struct TenantContext {
47 pub id: i64,
49 pub slug: String,
51 pub name: String,
53 pub plan: Option<String>,
55 #[cfg(feature = "stripe")]
57 pub subscription: Option<subscription::SubscriptionInfo>,
58}
59
/// Subscription helpers; compiled only with the `stripe` feature.
///
/// The boolean helpers delegate to the method of the same name on the attached
/// `SubscriptionInfo` and report `false` when the tenant has no subscription.
#[cfg(feature = "stripe")]
impl TenantContext {
    /// Returns `true` when a subscription is present and its `on_trial()` is `true`.
    pub fn on_trial(&self) -> bool {
        self.subscription.as_ref().is_some_and(|sub| sub.on_trial())
    }

    /// Returns `true` when a subscription is present and its `subscribed()` is `true`.
    pub fn subscribed(&self) -> bool {
        self.subscription
            .as_ref()
            .is_some_and(|sub| sub.subscribed())
    }

    /// Returns `true` when a subscription is present and its `on_grace_period()`
    /// is `true`.
    pub fn on_grace_period(&self) -> bool {
        self.subscription
            .as_ref()
            .is_some_and(|sub| sub.on_grace_period())
    }

    /// Returns the effective plan name.
    ///
    /// The subscription's plan wins whenever a subscription exists; otherwise
    /// the tenant's own `plan` field is used. Returns `None` when neither is set.
    pub fn current_plan(&self) -> Option<&str> {
        match &self.subscription {
            Some(sub) => Some(sub.plan.as_str()),
            None => self.plan.as_deref(),
        }
    }
}

96#[async_trait]
110impl FromRequest for TenantContext {
111 async fn from_request(_req: Request) -> Result<Self, FrameworkError> {
112 current_tenant().ok_or_else(|| {
113 FrameworkError::domain(
114 "No tenant context available. Ensure this route is behind TenantMiddleware.",
115 400,
116 )
117 })
118 }
119}
120
121pub enum TenantFailureMode {
123 NotFound,
125 Forbidden,
127 Allow,
129 Custom(Box<dyn Fn() -> crate::http::Response + Send + Sync>),
131}
132
133impl std::fmt::Debug for TenantFailureMode {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 Self::NotFound => write!(f, "NotFound"),
137 Self::Forbidden => write!(f, "Forbidden"),
138 Self::Allow => write!(f, "Allow"),
139 Self::Custom(_) => write!(f, "Custom(...)"),
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::context::{tenant_scope, with_tenant_scope};
    use hyper_util::rt::TokioIo;
    use tokio::sync::oneshot;

    /// Builds a tenant with the given id and slug, a name derived from the
    /// slug, no plan and (with the `stripe` feature) no subscription.
    fn make_tenant(id: i64, slug: &str) -> TenantContext {
        TenantContext {
            id,
            slug: slug.to_string(),
            name: format!("Tenant {slug}"),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Produces a framework `Request` by sending a request for `/test` over a
    /// loopback TCP connection and capturing what the server side receives.
    ///
    /// Panics (via `unwrap`) if the socket setup, the handshake, or the
    /// hand-off of the captured request fails.
    async fn make_request() -> Request {
        // Port 0 asks the OS for any free port.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The service closure may be invoked more than once, but a oneshot
        // sender can only be used once; keeping it in `Mutex<Option<_>>` lets
        // the first call `take()` it.
        let (tx, rx) = oneshot::channel::<Request>();
        let tx_holder = std::sync::Arc::new(std::sync::Mutex::new(Some(tx)));

        // Server side: accept a single connection and forward the first
        // request it carries through the oneshot channel.
        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let io = TokioIo::new(stream);
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        io,
                        hyper::service::service_fn(move |req| {
                            let tx_holder = tx_holder.clone();
                            async move {
                                if let Some(tx) = tx_holder.lock().unwrap().take() {
                                    let _ = tx.send(Request::new(req));
                                }
                                Ok::<_, hyper::Error>(hyper::Response::new(
                                    http_body_util::Empty::<bytes::Bytes>::new(),
                                ))
                            }
                        }),
                    )
                    .await
                    .ok();
            }
        });

        // Client side: connect, drive the connection in the background and
        // send one request with an empty body.
        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        tokio::spawn(async move { conn.await.ok() });

        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        // Only the request captured on the server side matters; the client's
        // response is discarded.
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    // With a tenant stored in the scope, extraction must yield that tenant.
    #[tokio::test]
    async fn from_request_returns_ok_when_tenant_context_is_set() {
        let ctx = tenant_scope();
        {
            let mut guard = ctx.write().await;
            *guard = Some(make_tenant(99, "acme"));
        }

        let result = with_tenant_scope(ctx, async {
            let req = make_request().await;
            TenantContext::from_request(req).await
        })
        .await;

        assert!(
            result.is_ok(),
            "Expected Ok(TenantContext), got: {result:?}"
        );
        let tenant = result.unwrap();
        assert_eq!(tenant.id, 99);
        assert_eq!(tenant.slug, "acme");
    }

    // Without any tenant scope, extraction must fail with status 400.
    #[tokio::test]
    async fn from_request_returns_400_error_when_no_tenant_context() {
        let req = make_request().await;
        // Deliberately not wrapped in `with_tenant_scope`.
        let result = TenantContext::from_request(req).await;

        assert!(result.is_err(), "Expected Err when no tenant context");
        let err = result.unwrap_err();
        assert_eq!(
            err.status_code(),
            400,
            "Expected 400 status code, got: {}",
            err.status_code()
        );
    }

    #[cfg(feature = "stripe")]
    mod stripe_tests {
        use super::*;
        use crate::tenant::subscription::{SubscriptionInfo, SubscriptionStatus};

        /// Builds a subscription for `plan` in the given `status`, with no
        /// trial end, no pending cancellation and a period end of "now".
        fn make_subscription(plan: &str, status: SubscriptionStatus) -> SubscriptionInfo {
            SubscriptionInfo {
                stripe_subscription_id: "sub_test".to_string(),
                plan: plan.to_string(),
                status,
                trial_ends_at: None,
                cancel_at_period_end: false,
                current_period_end: chrono::Utc::now(),
                stripe_connect_account_id: None,
            }
        }

        /// Builds a tenant whose own plan and subscription plan are both `plan`.
        fn make_tenant_with_subscription(plan: &str, status: SubscriptionStatus) -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some(plan.to_string()),
                subscription: Some(make_subscription(plan, status)),
            }
        }

        /// Builds a tenant with neither a plan nor a subscription.
        fn make_tenant_no_subscription() -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: None,
                subscription: None,
            }
        }

        // --- serialization ---

        #[test]
        fn tenant_with_none_subscription_serializes_with_null_subscription() {
            let tenant = make_tenant_no_subscription();
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(json["subscription"].is_null());
        }

        #[test]
        fn tenant_with_some_subscription_serializes_with_full_subscription_object() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(!json["subscription"].is_null());
            assert_eq!(json["subscription"]["plan"], "pro");
            assert_eq!(json["subscription"]["status"], "active");
        }

        // --- on_trial ---

        #[test]
        fn on_trial_returns_true_when_subscription_is_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_subscription_is_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_trial());
        }

        // --- subscribed ---

        #[test]
        fn subscribed_returns_true_when_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_canceled() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Canceled);
            assert!(!tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.subscribed());
        }

        // --- on_grace_period ---

        #[test]
        fn on_grace_period_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_true_when_cancel_at_period_end_and_active() {
            let mut tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            if let Some(ref mut sub) = tenant.subscription {
                sub.cancel_at_period_end = true;
            }
            assert!(tenant.on_grace_period());
        }

        // --- current_plan ---

        #[test]
        fn current_plan_returns_subscription_plan_when_present() {
            let tenant = make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_falls_back_to_legacy_plan_when_no_subscription() {
            let tenant = TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some("pro".to_string()),
                subscription: None,
            };
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_returns_none_when_neither_is_set() {
            let tenant = make_tenant_no_subscription();
            assert_eq!(tenant.current_plan(), None);
        }
    }
}