1pub mod context;
15pub mod lookup;
16pub mod middleware;
17#[cfg(feature = "stripe")]
18pub mod requires_plan;
19pub mod resolver;
20pub mod scope;
21#[cfg(feature = "stripe")]
22pub mod subscription;
23pub mod worker;
24
25pub use context::current_tenant;
26pub use lookup::{DbTenantLookup, TenantLookup};
27pub use middleware::TenantMiddleware;
28#[cfg(feature = "stripe")]
29pub use requires_plan::RequiresPlan;
30pub use resolver::{
31 HeaderResolver, JwtClaimResolver, PathResolver, SubdomainResolver, TenantResolver,
32};
33pub use scope::TenantScope;
34pub use worker::FrameworkTenantScopeProvider;
35
36use crate::error::FrameworkError;
37use crate::http::{FromRequest, Request};
38use async_trait::async_trait;
39
/// A resolved tenant, as made available to request handlers.
///
/// Read from the ambient tenant scope via [`current_tenant`], or extracted directly in a
/// handler through this type's [`FromRequest`] implementation. Derives `Serialize`, so it
/// can be embedded in JSON responses.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TenantContext {
    /// Numeric tenant identifier (presumably the database primary key — confirm in `lookup`).
    pub id: i64,
    /// Short textual identifier for the tenant (e.g. `"acme"`).
    pub slug: String,
    /// Human-readable tenant name.
    pub name: String,
    /// Plan name stored directly on the tenant, if any.
    ///
    /// With the `stripe` feature enabled, `current_plan` only falls back to this value
    /// when no subscription is attached.
    pub plan: Option<String>,
    /// Billing subscription attached to this tenant, if any; serialized as `null` when absent.
    #[cfg(feature = "stripe")]
    pub subscription: Option<subscription::SubscriptionInfo>,
}
59
60impl TenantContext {
61 pub fn new(id: i64, slug: String, name: String, plan: Option<String>) -> Self {
68 Self {
69 id,
70 slug,
71 name,
72 plan,
73 #[cfg(feature = "stripe")]
74 subscription: None,
75 }
76 }
77}
78
/// Billing helpers, available with the `stripe` feature.
///
/// Each predicate delegates to the attached [`subscription::SubscriptionInfo`] and is
/// `false` when no subscription is attached.
#[cfg(feature = "stripe")]
impl TenantContext {
    /// `true` if a subscription is attached and it reports being on trial.
    pub fn on_trial(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.on_trial())
    }

    /// `true` if a subscription is attached and it reports being subscribed.
    pub fn subscribed(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.subscribed())
    }

    /// `true` if a subscription is attached and it reports being in its grace period.
    pub fn on_grace_period(&self) -> bool {
        matches!(&self.subscription, Some(sub) if sub.on_grace_period())
    }

    /// The effective plan name for this tenant.
    ///
    /// When a subscription is attached, its plan wins, regardless of the subscription's
    /// status. Otherwise this falls back to the tenant's own `plan` field. Returns `None`
    /// when neither is set.
    pub fn current_plan(&self) -> Option<&str> {
        match &self.subscription {
            Some(sub) => Some(sub.plan.as_str()),
            None => self.plan.as_deref(),
        }
    }
}
114
115#[async_trait]
129impl FromRequest for TenantContext {
130 async fn from_request(_req: Request) -> Result<Self, FrameworkError> {
131 current_tenant().ok_or_else(|| {
132 FrameworkError::domain(
133 "No tenant context available. Ensure this route is behind TenantMiddleware.",
134 400,
135 )
136 })
137 }
138}
139
/// Presumably selects how a tenant-resolution failure is handled.
///
/// NOTE(review): the consumer of this enum is not visible in this file; it is probably
/// `TenantMiddleware` — confirm. The variant descriptions below are inferred from their names.
/// `Debug` is implemented by hand because the boxed closure in `Custom` has no `Debug` impl.
pub enum TenantFailureMode {
    /// Respond with a "not found" failure (presumably HTTP 404 — confirm).
    NotFound,
    /// Respond with a "forbidden" failure (presumably HTTP 403 — confirm).
    Forbidden,
    /// Presumably lets the request continue without a tenant context — confirm.
    Allow,
    /// Produce the failure response by calling the supplied closure.
    Custom(Box<dyn Fn() -> crate::http::Response + Send + Sync>),
}
151
152impl std::fmt::Debug for TenantFailureMode {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 match self {
155 Self::NotFound => write!(f, "NotFound"),
156 Self::Forbidden => write!(f, "Forbidden"),
157 Self::Allow => write!(f, "Allow"),
158 Self::Custom(_) => write!(f, "Custom(...)"),
159 }
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tenant::context::{tenant_scope, with_tenant_scope};
    use hyper_util::rt::TokioIo;
    use tokio::sync::oneshot;

    /// Builds a minimal tenant (no plan, no subscription) for the extractor tests.
    fn make_tenant(id: i64, slug: &str) -> TenantContext {
        TenantContext {
            id,
            slug: slug.to_string(),
            name: format!("Tenant {slug}"),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Produces a real framework [`Request`] by round-tripping `GET /test` through a
    /// throwaway hyper HTTP/1 server bound to an ephemeral loopback port.
    ///
    /// The server hands the first request it receives back over a oneshot channel and
    /// answers with an empty body, so the client side completes.
    async fn make_request() -> Request {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = oneshot::channel::<Request>();
        // The service closure may run once per request, but a oneshot sender can only be
        // consumed once, so it is parked in a shared `Option` and taken by the first call.
        let tx_holder = std::sync::Arc::new(std::sync::Mutex::new(Some(tx)));

        tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let io = TokioIo::new(stream);
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        io,
                        hyper::service::service_fn(move |req| {
                            let tx_holder = tx_holder.clone();
                            async move {
                                if let Some(tx) = tx_holder.lock().unwrap().take() {
                                    let _ = tx.send(Request::new(req));
                                }
                                Ok::<_, hyper::Error>(hyper::Response::new(
                                    http_body_util::Empty::<bytes::Bytes>::new(),
                                ))
                            }
                        }),
                    )
                    .await
                    .ok();
            }
        });

        let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
        let io = TokioIo::new(stream);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
        // hyper requires the client connection future to be polled for requests to progress.
        tokio::spawn(async move { conn.await.ok() });

        let req = hyper::Request::builder()
            .uri("/test")
            .body(http_body_util::Empty::<bytes::Bytes>::new())
            .unwrap();
        // Only the server-side capture matters; the client's response is discarded.
        let _ = sender.send_request(req).await;
        rx.await.unwrap()
    }

    #[tokio::test]
    async fn from_request_returns_ok_when_tenant_context_is_set() {
        let ctx = tenant_scope();
        {
            let mut guard = ctx.write().await;
            *guard = Some(make_tenant(99, "acme"));
        }

        let result = with_tenant_scope(ctx, async {
            let req = make_request().await;
            TenantContext::from_request(req).await
        })
        .await;

        assert!(
            result.is_ok(),
            "Expected Ok(TenantContext), got: {result:?}"
        );
        let tenant = result.unwrap();
        assert_eq!(tenant.id, 99);
        assert_eq!(tenant.slug, "acme");
    }

    #[tokio::test]
    async fn from_request_returns_400_error_when_no_tenant_context() {
        // No tenant scope is entered here, so the extractor is expected to find no tenant.
        let req = make_request().await;
        let result = TenantContext::from_request(req).await;

        assert!(result.is_err(), "Expected Err when no tenant context");
        let err = result.unwrap_err();
        assert_eq!(
            err.status_code(),
            400,
            "Expected 400 status code, got: {}",
            err.status_code()
        );
    }

    #[cfg(feature = "stripe")]
    mod stripe_tests {
        use super::*;
        use crate::tenant::subscription::{SubscriptionInfo, SubscriptionStatus};

        /// Builds a subscription with the given plan and status and otherwise neutral fields.
        fn make_subscription(plan: &str, status: SubscriptionStatus) -> SubscriptionInfo {
            SubscriptionInfo {
                stripe_subscription_id: "sub_test".to_string(),
                plan: plan.to_string(),
                status,
                trial_ends_at: None,
                cancel_at_period_end: false,
                current_period_end: chrono::Utc::now(),
                stripe_connect_account_id: None,
            }
        }

        /// Builds a tenant whose legacy `plan` and subscription plan are both `plan`.
        fn make_tenant_with_subscription(plan: &str, status: SubscriptionStatus) -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some(plan.to_string()),
                subscription: Some(make_subscription(plan, status)),
            }
        }

        /// Builds a tenant with neither a legacy plan nor a subscription.
        fn make_tenant_no_subscription() -> TenantContext {
            TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: None,
                subscription: None,
            }
        }

        #[test]
        fn tenant_with_none_subscription_serializes_with_null_subscription() {
            let tenant = make_tenant_no_subscription();
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(json["subscription"].is_null());
        }

        #[test]
        fn tenant_with_some_subscription_serializes_with_full_subscription_object() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            let json = serde_json::to_value(&tenant).unwrap();
            assert!(!json["subscription"].is_null());
            assert_eq!(json["subscription"]["plan"], "pro");
            assert_eq!(json["subscription"]["status"], "active");
        }

        #[test]
        fn on_trial_returns_true_when_subscription_is_trialing() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Trialing);
            assert!(tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_subscription_is_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(!tenant.on_trial());
        }

        #[test]
        fn on_trial_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_trial());
        }

        #[test]
        fn subscribed_returns_true_when_active() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            assert!(tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_canceled() {
            let tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Canceled);
            assert!(!tenant.subscribed());
        }

        #[test]
        fn subscribed_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.subscribed());
        }

        #[test]
        fn on_grace_period_returns_false_when_no_subscription() {
            let tenant = make_tenant_no_subscription();
            assert!(!tenant.on_grace_period());
        }

        #[test]
        fn on_grace_period_returns_true_when_cancel_at_period_end_and_active() {
            let mut tenant = make_tenant_with_subscription("pro", SubscriptionStatus::Active);
            if let Some(ref mut sub) = tenant.subscription {
                sub.cancel_at_period_end = true;
            }
            assert!(tenant.on_grace_period());
        }

        #[test]
        fn current_plan_returns_subscription_plan_when_present() {
            let tenant = make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_prefers_subscription_plan_over_legacy_plan() {
            // The helper sets both plans to the same value; diverge them so the test can
            // actually tell which source `current_plan` reads from.
            let mut tenant =
                make_tenant_with_subscription("enterprise", SubscriptionStatus::Active);
            tenant.plan = Some("basic".to_string());
            assert_eq!(tenant.current_plan(), Some("enterprise"));
        }

        #[test]
        fn current_plan_falls_back_to_legacy_plan_when_no_subscription() {
            let tenant = TenantContext {
                id: 1,
                slug: "acme".to_string(),
                name: "ACME Corp".to_string(),
                plan: Some("pro".to_string()),
                subscription: None,
            };
            assert_eq!(tenant.current_plan(), Some("pro"));
        }

        #[test]
        fn current_plan_returns_none_when_neither_is_set() {
            let tenant = make_tenant_no_subscription();
            assert_eq!(tenant.current_plan(), None);
        }
    }
}