1pub mod context;
15pub mod lookup;
16pub mod middleware;
17#[cfg(feature = "stripe")]
18pub mod requires_plan;
19pub mod resolver;
20pub mod scope;
21pub mod scoped;
22#[cfg(feature = "stripe")]
23pub mod subscription;
24pub mod worker;
25
26pub use context::current_tenant;
27pub use lookup::{DbTenantLookup, TenantLookup};
28pub use middleware::TenantMiddleware;
29#[cfg(feature = "stripe")]
30pub use requires_plan::RequiresPlan;
31pub use resolver::{
32 HeaderResolver, JwtClaimResolver, PathResolver, SubdomainResolver, TenantResolver,
33};
34pub use scope::TenantScope;
35pub use scoped::TenantScoped;
36pub use worker::FrameworkTenantScopeProvider;
37
38use crate::error::FrameworkError;
39use crate::http::{FromRequest, Request};
40use async_trait::async_trait;
41
42#[derive(Debug, Clone, serde::Serialize)]
48pub struct TenantContext {
49 pub id: i64,
51 pub slug: String,
53 pub name: String,
55 pub plan: Option<String>,
57 #[cfg(feature = "stripe")]
59 pub subscription: Option<subscription::SubscriptionInfo>,
60}
61
62impl TenantContext {
63 pub fn new(id: i64, slug: String, name: String, plan: Option<String>) -> Self {
70 Self {
71 id,
72 slug,
73 name,
74 plan,
75 #[cfg(feature = "stripe")]
76 subscription: None,
77 }
78 }
79}
80
#[cfg(feature = "stripe")]
impl TenantContext {
    /// Returns `true` when a subscription is attached and its `on_trial()`
    /// check passes; `false` when there is no subscription.
    pub fn on_trial(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.on_trial())
    }

    /// Returns `true` when a subscription is attached and its `subscribed()`
    /// check passes; `false` when there is no subscription.
    pub fn subscribed(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.subscribed())
    }

    /// Returns `true` when a subscription is attached and its
    /// `on_grace_period()` check passes; `false` when there is no subscription.
    pub fn on_grace_period(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.on_grace_period())
    }

    /// Returns the plan name that currently applies to the tenant.
    ///
    /// The plan of an attached subscription wins; without a subscription the
    /// tenant's own `plan` field is used. Returns `None` when neither is set.
    pub fn current_plan(&self) -> Option<&str> {
        match &self.subscription {
            Some(sub) => Some(sub.plan.as_str()),
            None => self.plan.as_deref(),
        }
    }
}
116
117#[async_trait]
131impl FromRequest for TenantContext {
132 async fn from_request(_req: Request) -> Result<Self, FrameworkError> {
133 current_tenant().ok_or_else(|| {
134 FrameworkError::domain(
135 "No tenant context available. Ensure this route is behind TenantMiddleware.",
136 400,
137 )
138 })
139 }
140}
141
142pub enum TenantFailureMode {
144 NotFound,
146 Forbidden,
148 Allow,
150 Custom(Box<dyn Fn() -> crate::http::Response + Send + Sync>),
152}
153
154impl std::fmt::Debug for TenantFailureMode {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 match self {
157 Self::NotFound => write!(f, "NotFound"),
158 Self::Forbidden => write!(f, "Forbidden"),
159 Self::Allow => write!(f, "Allow"),
160 Self::Custom(_) => write!(f, "Custom(...)"),
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::context::{tenant_scope, with_tenant_scope};
    use hyper_util::rt::TokioIo;
    use tokio::sync::oneshot;

    /// Builds a tenant with the given id and slug, no plan and (in `stripe`
    /// builds) no subscription.
    fn make_tenant(id: i64, slug: &str) -> TenantContext {
        TenantContext::new(id, slug.to_string(), format!("Tenant {slug}"), None)
    }

    /// Produces a real `Request` for the extractor tests.
    ///
    /// A throwaway hyper HTTP/1 server is bound to an ephemeral loopback port;
    /// a client then sends `GET /test` to it, and the first request the server
    /// receives is wrapped in `Request::new` and handed back through a oneshot
    /// channel.
    async fn make_request() -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The sender sits in an `Option` so the service closure (which may be
        // invoked more than once) can move it out exactly once.
        let (tx, rx) = oneshot::channel::<Request>();
        let tx_slot = std::sync::Arc::new(std::sync::Mutex::new(Some(tx)));

        // Server side: accept a single connection and capture its first request.
        tokio::spawn(async move {
            let Ok((stream, _)) = listener.accept().await else {
                return;
            };
            hyper::server::conn::http1::Builder::new()
                .serve_connection(
                    TokioIo::new(stream),
                    hyper::service::service_fn(move |req| {
                        let tx_slot = tx_slot.clone();
                        async move {
                            if let Some(tx) = tx_slot.lock().unwrap().take() {
                                let _ = tx.send(Request::new(req));
                            }
                            Ok::<_, hyper::Error>(hyper::Response::new(
                                http_body_util::Empty::<bytes::Bytes>::new(),
                            ))
                        }
                    }),
                )
                .await
                .ok();
        });

        // Client side: connect, drive the connection in the background and
        // fire one request at the server.
        let client_io = TokioIo::new(tokio::net::TcpStream::connect(addr).await.unwrap());
        let (mut sender, conn) = hyper::client::conn::http1::handshake(client_io)
            .await
            .unwrap();
        tokio::spawn(async move { conn.await.ok() });

        let outgoing = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        // The response is irrelevant; only the request captured by the server matters.
        let _ = sender.send_request(outgoing).await;
        rx.await.unwrap()
    }

    /// With a tenant stored in the scope, the extractor yields that tenant.
    #[tokio::test]
    async fn from_request_returns_ok_when_tenant_context_is_set() {
        let ctx = tenant_scope();
        // The write guard is a temporary and is released at the end of this statement.
        *ctx.write().await = Some(make_tenant(99, "acme"));

        let result = with_tenant_scope(ctx, async {
            let req = make_request().await;
            TenantContext::from_request(req).await
        })
        .await;

        assert!(
            result.is_ok(),
            "Expected Ok(TenantContext), got: {result:?}"
        );
        let tenant = result.unwrap();
        assert_eq!(tenant.id, 99);
        assert_eq!(tenant.slug, "acme");
    }

    /// Outside of any tenant scope, the extractor fails with a 400 error.
    #[tokio::test]
    async fn from_request_returns_400_error_when_no_tenant_context() {
        let req = make_request().await;
        // Deliberately not wrapped in `with_tenant_scope`.
        let result = TenantContext::from_request(req).await;

        assert!(result.is_err(), "Expected Err when no tenant context");
        let status = result.unwrap_err().status_code();
        assert_eq!(status, 400, "Expected 400 status code, got: {}", status);
    }

    #[cfg(feature = "stripe")]
    mod stripe_tests {
        use super::*;
        use crate::tenant::subscription::{SubscriptionInfo, SubscriptionStatus};

        /// Builds a subscription for `plan` in the given `status`, with no
        /// trial end, no pending cancellation and a period end of "now".
        fn make_subscription(plan: &str, status: SubscriptionStatus) -> SubscriptionInfo {
            SubscriptionInfo {
                stripe_subscription_id: "sub_test".to_string(),
                plan: plan.to_string(),
                status,
                trial_ends_at: None,
                cancel_at_period_end: false,
                current_period_end: chrono::Utc::now(),
                stripe_connect_account_id: None,
            }
        }

        /// Builds the "ACME Corp" tenant with `plan` set both on the tenant
        /// and on an attached subscription.
        fn make_tenant_with_subscription(plan: &str, status: SubscriptionStatus) -> TenantContext {
            TenantContext {
                subscription: Some(make_subscription(plan, status)),
                ..TenantContext::new(
                    1,
                    "acme".to_string(),
                    "ACME Corp".to_string(),
                    Some(plan.to_string()),
                )
            }
        }

        /// Builds the "ACME Corp" tenant with neither a plan nor a subscription.
        fn make_tenant_no_subscription() -> TenantContext {
            TenantContext::new(1, "acme".to_string(), "ACME Corp".to_string(), None)
        }

        #[test]
        fn tenant_with_none_subscription_serializes_with_null_subscription() {
            let json = serde_json::to_value(&make_tenant_no_subscription()).unwrap();
            assert!(json["subscription"].is_null());
        }

        #[test]
        fn tenant_with_some_subscription_serializes_with_full_subscription_object() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(!json["subscription"].is_null());
            assert_eq!(json["subscription"]["plan"], "pro");
            assert_eq!(json["subscription"]["status"], "active");
        }

        #[test]
        fn on_trial_returns_true_when_subscription_is_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_subscription_is_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_trial());
        }

        #[test]
        fn subscribed_returns_true_when_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_canceled() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Canceled);
            assert!(!tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.subscribed());
        }

        #[test]
        fn on_grace_period_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_true_when_cancel_at_period_end_and_active() {
            let mut tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            if let Some(sub) = tenant.subscription.as_mut() {
                sub.cancel_at_period_end = true;
            }
            assert!(tenant.on_grace_period());
        }

        #[test]
        fn current_plan_returns_subscription_plan_when_present() {
            let tenant = make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_falls_back_to_legacy_plan_when_no_subscription() {
            // Plan set on the tenant itself, but no subscription attached.
            let tenant = TenantContext::new(
                1,
                "acme".to_string(),
                "ACME Corp".to_string(),
                Some("pro".to_string()),
            );
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_returns_none_when_neither_is_set() {
            let tenant = make_tenant_no_subscription();
            assert_eq!(tenant.current_plan(), None);
        }
    }
}