1use axum::{
2 extract::State,
3 http::Request,
4 middleware::Next,
5 response::{IntoResponse, Response},
6};
7use http_body::Body as HttpBody;
8use pin_project_lite::pin_project;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
tokio::task_local! {
    /// Tenant identifier for code running inside a tenant scope.
    ///
    /// A scope is entered by [`with_tenant`], by `tenancy_middleware` around the inner
    /// service, and by `TenantPropagatingBody` around each poll of a response body. All of
    /// these set the value to `Some(tenant_id)`.
    pub static CURRENT_TENANT: Option<String>;
}
17
/// Axum extractor holding the tenant identifier of the current request.
///
/// It is resolved by [`extract_tenant_from_parts`] using the `AutumnConfig` stored in the
/// application state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Tenant(pub String);
21
22impl axum::extract::FromRequestParts<crate::AppState> for Tenant {
23 type Rejection = crate::AutumnError;
24
25 async fn from_request_parts(
26 parts: &mut axum::http::request::Parts,
27 state: &crate::AppState,
28 ) -> Result<Self, Self::Rejection> {
29 let config = state
30 .extension::<crate::config::AutumnConfig>()
31 .ok_or_else(|| {
32 crate::AutumnError::service_unavailable_msg("Config is not available")
33 })?;
34 let tenant_id = extract_tenant_from_parts(parts, &config).await?;
35 Ok(Self(tenant_id))
36 }
37}
38
39pub async fn with_tenant<F, R>(tenant_id: String, future: F) -> R
41where
42 F: Future<Output = R>,
43{
44 CURRENT_TENANT.scope(Some(tenant_id), future).await
45}
46
47#[allow(clippy::missing_errors_doc, clippy::too_many_lines)]
49pub async fn extract_tenant_from_parts(
50 parts: &mut axum::http::request::Parts,
51 config: &crate::config::AutumnConfig,
52) -> Result<String, crate::AutumnError> {
53 if !config.tenancy.enabled {
54 return Err(crate::AutumnError::bad_request_msg("Tenancy is disabled"));
55 }
56
57 match config.tenancy.source.as_str() {
58 "header" => {
59 let header_value = parts
60 .headers
61 .get(&config.tenancy.header_name)
62 .ok_or_else(|| {
63 crate::AutumnError::bad_request_msg(format!(
64 "Missing required tenant header: {}",
65 config.tenancy.header_name
66 ))
67 })?;
68 let val = header_value
69 .to_str()
70 .map_err(|_| {
71 crate::AutumnError::bad_request_msg(format!(
72 "Invalid UTF-8 in tenant header: {}",
73 config.tenancy.header_name
74 ))
75 })?
76 .to_string();
77 if val.trim().is_empty() {
78 return Err(crate::AutumnError::bad_request_msg(format!(
79 "Tenant header {} is empty",
80 config.tenancy.header_name
81 )));
82 }
83 Ok(val)
84 }
85 "subdomain" => {
86 let host_owned: String = parts
89 .extensions
90 .get::<crate::security::ResolvedClientIdentity>()
91 .and_then(|id| id.host.clone())
92 .map_or_else(
93 || {
94 parts
95 .headers
96 .get(axum::http::header::HOST)
97 .ok_or_else(|| {
98 crate::AutumnError::bad_request_msg(
99 "Missing Host header for subdomain tenancy",
100 )
101 })
102 .and_then(|h| {
103 h.to_str().map(ToOwned::to_owned).map_err(|_| {
104 crate::AutumnError::bad_request_msg(
105 "Invalid UTF-8 in Host header",
106 )
107 })
108 })
109 },
110 Ok,
111 )?;
112
113 let host = host_owned.as_str();
114 let host_only = host.split(':').next().unwrap_or(host).trim();
115
116 if host_only.parse::<std::net::IpAddr>().is_ok() {
117 return Err(crate::AutumnError::bad_request_msg(
118 "IP address host not allowed in subdomain mode",
119 ));
120 }
121
122 let host_lower = host_only.to_lowercase();
125
126 if let Some(ref base_domain) = config.tenancy.base_domain {
127 let base_domain_clean = base_domain.trim().to_lowercase();
128 if !host_lower.ends_with(base_domain_clean.as_str()) {
129 return Err(crate::AutumnError::bad_request_msg(format!(
130 "Host does not match base domain: {base_domain_clean}"
131 )));
132 }
133 if host_lower.len() <= base_domain_clean.len() {
134 return Err(crate::AutumnError::bad_request_msg(
135 "Apex domain not allowed in subdomain mode",
136 ));
137 }
138 let prefix_len = host_lower.len() - base_domain_clean.len();
139 if !host_lower[..prefix_len].ends_with('.') {
140 return Err(crate::AutumnError::bad_request_msg(
141 "Invalid subdomain format",
142 ));
143 }
144 let subdomain_part = &host_lower[..prefix_len - 1];
145 let tenant = subdomain_part.split('.').next().ok_or_else(|| {
146 crate::AutumnError::bad_request_msg("Unable to extract subdomain tenant")
147 })?;
148 if tenant.trim().is_empty() {
149 return Err(crate::AutumnError::bad_request_msg(
150 "Extracted subdomain tenant is empty",
151 ));
152 }
153 Ok(tenant.to_string())
154 } else {
155 let labels: Vec<&str> = host_lower.split('.').filter(|s| !s.is_empty()).collect();
156 if labels.is_empty() {
157 return Err(crate::AutumnError::bad_request_msg("Empty host header"));
158 }
159
160 if labels.len() < 2 {
161 return Err(crate::AutumnError::bad_request_msg(
162 "Apex or local host without subdomain not allowed",
163 ));
164 }
165
166 if labels.len() == 2 && labels[1] != "localhost" {
167 return Err(crate::AutumnError::bad_request_msg(
168 "Apex domain not allowed in subdomain mode",
169 ));
170 }
171
172 let tenant = labels[0].to_string();
173 if tenant.trim().is_empty() {
174 return Err(crate::AutumnError::bad_request_msg(
175 "Extracted subdomain tenant is empty",
176 ));
177 }
178 Ok(tenant)
179 }
180 }
181 "session" => {
182 let session = parts
183 .extensions
184 .get::<crate::session::Session>()
185 .ok_or_else(|| {
186 crate::AutumnError::internal_server_error_msg(
187 "SessionLayer not installed but session tenancy source is configured",
188 )
189 })?;
190 let tenant = session
191 .get(&config.tenancy.session_key)
192 .await
193 .ok_or_else(|| {
194 crate::AutumnError::unauthorized_msg(format!(
195 "Tenant ID missing from session key: {}",
196 config.tenancy.session_key
197 ))
198 })?;
199 if tenant.trim().is_empty() {
200 return Err(crate::AutumnError::unauthorized_msg(format!(
201 "Tenant ID in session key {} is empty",
202 config.tenancy.session_key
203 )));
204 }
205 Ok(tenant)
206 }
207 "jwt" => {
208 let auth_header = parts
209 .headers
210 .get(axum::http::header::AUTHORIZATION)
211 .ok_or_else(|| {
212 crate::AutumnError::unauthorized_msg(
213 "Missing Authorization header for JWT tenancy",
214 )
215 })?;
216 let auth_str = auth_header.to_str().map_err(|_| {
217 crate::AutumnError::unauthorized_msg("Invalid UTF-8 in Authorization header")
218 })?;
219
220 if auth_str.len() < 7
221 || !auth_str.is_char_boundary(7)
222 || !auth_str[..7].eq_ignore_ascii_case("bearer ")
223 {
224 return Err(crate::AutumnError::unauthorized_msg(
225 "Invalid Authorization header format. Expected Bearer <token>",
226 ));
227 }
228 let token = &auth_str[7..];
229
230 let secret = config.tenancy.jwt_secret.as_ref().ok_or_else(|| {
231 crate::AutumnError::unauthorized_msg("JWT secret is not configured")
232 })?;
233
234 let mut validation = ::jsonwebtoken::Validation::default();
235 if let Some(ref iss) = config.tenancy.jwt_issuer {
236 validation.set_issuer(::std::slice::from_ref(iss));
237 }
238 if let Some(ref aud) = config.tenancy.jwt_audience {
239 validation.set_audience(&[aud.as_str()]);
240 } else {
241 validation.validate_aud = false;
242 }
243
244 let token_data = ::jsonwebtoken::decode::<serde_json::Value>(
245 token,
246 &::jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()),
247 &validation,
248 )
249 .map_err(|e| {
250 crate::AutumnError::unauthorized_msg(format!("JWT validation failed: {e}"))
251 })?;
252
253 if let Some(ref expected_aud) = config.tenancy.jwt_audience {
259 let aud_ok = token_data.claims.get("aud").is_some_and(|v| match v {
260 serde_json::Value::String(s) => s == expected_aud,
261 serde_json::Value::Array(arr) => arr
262 .iter()
263 .any(|e| e.as_str() == Some(expected_aud.as_str())),
264 _ => false,
265 });
266 if !aud_ok {
267 return Err(crate::AutumnError::unauthorized_msg(
268 "JWT audience validation failed: missing or invalid aud claim",
269 ));
270 }
271 }
272
273 let tenant = token_data
274 .claims
275 .get(&config.tenancy.jwt_claim)
276 .and_then(|v| v.as_str())
277 .ok_or_else(|| {
278 crate::AutumnError::unauthorized_msg(format!(
279 "Tenant claim '{}' missing from JWT payload",
280 config.tenancy.jwt_claim
281 ))
282 })?
283 .to_string();
284
285 if tenant.trim().is_empty() {
286 return Err(crate::AutumnError::unauthorized_msg(format!(
287 "Tenant claim '{}' in JWT payload is empty",
288 config.tenancy.jwt_claim
289 )));
290 }
291 Ok(tenant)
292 }
293 other => Err(crate::AutumnError::internal_server_error_msg(format!(
294 "Unsupported tenancy source: {other}"
295 ))),
296 }
297}
298
299pub async fn tenancy_middleware(
301 State(state): State<crate::AppState>,
302 request: Request<axum::body::Body>,
303 next: Next,
304) -> Response {
305 let Some(config) = state.extension::<crate::config::AutumnConfig>() else {
306 return crate::AutumnError::internal_server_error_msg("AutumnConfig not found in AppState")
307 .into_response();
308 };
309
310 if !config.tenancy.enabled {
311 return next.run(request).await;
312 }
313
314 let (mut parts, body) = request.into_parts();
315 let tenant_id = match extract_tenant_from_parts(&mut parts, &config).await {
316 Ok(t) => t,
317 Err(e) => return e.into_response(),
318 };
319
320 crate::log::context::set_tenant_id(&tenant_id);
323
324 let request = Request::from_parts(parts, body);
325 let tenant_id_clone = tenant_id.clone();
326 let response = CURRENT_TENANT
327 .scope(Some(tenant_id), next.run(request))
328 .await;
329
330 let (parts, body) = response.into_parts();
331 let wrapped = TenantPropagatingBody {
332 inner: body,
333 tenant_id: tenant_id_clone,
334 };
335 Response::from_parts(parts, axum::body::Body::new(wrapped))
336}
337
pin_project! {
    /// Response-body wrapper that re-enters the `CURRENT_TENANT` scope on every poll.
    ///
    /// The handler's task-local scope ends when the handler future completes. Body frames
    /// produced later would otherwise see no tenant; this wrapper sets `CURRENT_TENANT` to
    /// `Some(tenant_id)` around each `poll_frame` of `inner`.
    pub struct TenantPropagatingBody<B> {
        // The wrapped body. Structurally pinned so it can be polled through a projection.
        #[pin]
        pub inner: B,
        // Tenant identifier installed into `CURRENT_TENANT` while `inner` is polled.
        pub tenant_id: String,
    }
}
348
349impl<B> HttpBody for TenantPropagatingBody<B>
350where
351 B: HttpBody,
352{
353 type Data = B::Data;
354 type Error = B::Error;
355
356 fn poll_frame(
357 self: Pin<&mut Self>,
358 cx: &mut Context<'_>,
359 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
360 let this = self.project();
361 let tenant_id = this.tenant_id.clone();
362 CURRENT_TENANT.sync_scope(Some(tenant_id), || this.inner.poll_frame(cx))
363 }
364
365 fn is_end_stream(&self) -> bool {
366 self.inner.is_end_stream()
367 }
368
369 fn size_hint(&self) -> http_body::SizeHint {
370 self.inner.size_hint()
371 }
372}
373
#[cfg(feature = "db")]
/// Combines an insertable value with a tenant ID for insertion into `Table`.
///
/// No implementations are visible in this file. Presumably they are generated per model;
/// TODO confirm at the derive site.
pub trait TenantInsertable<'a, Table> {
    /// The value type produced for insertion.
    type Values;
    /// Consumes `self` and attaches `tenant_id`, producing the values to insert.
    fn tenant_values(self, tenant_id: &'a str) -> Self::Values;
}
383
#[cfg(feature = "db")]
/// Per-model metadata about how the tenant ID is stored on the model.
pub trait ModelTenantIdMeta {
    /// Whether the model declares its own tenant ID field.
    ///
    /// Presumably this mirrors the `HAS_MANUAL` parameter of
    /// `TenantInsertableValuesSelector`; confirm where it is consumed.
    const HAS_MANUAL_TENANT_ID: bool;
    /// Stores `tenant_id` on the model.
    ///
    /// The `HAS_MANUAL = true` `GetInsertableValues` impl calls this before inserting.
    fn try_set_tenant_id(&mut self, tenant_id: &str);
}
392
#[cfg(feature = "db")]
/// Exposes the Diesel column of a table that receives the tenant ID.
pub trait HasTenantIdColumn {
    /// The column expression type.
    type Column: ::diesel::Expression;
    /// Returns the column.
    ///
    /// The `HAS_MANUAL = false` `GetInsertableValues` impl builds `column = tenant_id` from it.
    fn column() -> Self::Column;
}
399
#[cfg(feature = "db")]
/// Selects how the tenant ID is attached to an insert, based on the const `HAS_MANUAL`.
///
/// * `true`: the tenant ID is written onto the model itself.
/// * `false`: the model is paired with a `tenant_column = tenant_id` assignment.
///
/// Both cases are implemented through [`GetInsertableValues`].
pub struct TenantInsertableValuesSelector<'a, T, Table, const HAS_MANUAL: bool> {
    /// The value being inserted.
    pub inner: T,
    /// Tenant ID to attach.
    pub tenant_id: &'a str,
    /// Ties the selector to the target table type without storing one.
    pub _marker: std::marker::PhantomData<Table>,
}
407
#[cfg(feature = "db")]
/// Produces the final insertable values from a `TenantInsertableValuesSelector`.
pub trait GetInsertableValues {
    /// The resulting insertable value type.
    type Values;
    /// Consumes the selector and returns the values to insert.
    fn get_values(self) -> Self::Values;
}
414
#[cfg(feature = "db")]
impl<T, Table> GetInsertableValues for TenantInsertableValuesSelector<'_, T, Table, true>
where
    T: ModelTenantIdMeta,
{
    type Values = T;

    /// The model carries its own tenant field: write the tenant ID onto it and insert it as-is.
    fn get_values(self) -> Self::Values {
        let Self {
            mut inner,
            tenant_id,
            ..
        } = self;
        inner.try_set_tenant_id(tenant_id);
        inner
    }
}
426
#[cfg(feature = "db")]
impl<'a, T, Table> GetInsertableValues for TenantInsertableValuesSelector<'a, T, Table, false>
where
    Table: HasTenantIdColumn,
    Table::Column: ::diesel::ExpressionMethods,
    <Table::Column as ::diesel::Expression>::SqlType: ::diesel::sql_types::SqlType,
    &'a str: ::diesel::expression::AsExpression<<Table::Column as ::diesel::Expression>::SqlType>,
{
    type Values = (T, ::diesel::dsl::Eq<Table::Column, &'a str>);

    /// The model has no tenant field: pair it with a `tenant_column = tenant_id` assignment.
    fn get_values(self) -> Self::Values {
        use ::diesel::ExpressionMethods;
        let Self {
            inner, tenant_id, ..
        } = self;
        let tenant_assignment = Table::column().eq(tenant_id);
        (inner, tenant_assignment)
    }
}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::ResolvedClientIdentity;

    /// Config with tenancy enabled and the subdomain source selected.
    fn subdomain_config() -> crate::config::AutumnConfig {
        let mut config = crate::config::AutumnConfig::default();
        config.tenancy.enabled = true;
        config.tenancy.source = String::from("subdomain");
        config
    }

    /// Like `subdomain_config`, but restricted to hosts under `base`.
    fn subdomain_config_with_base(base: &str) -> crate::config::AutumnConfig {
        let mut config = subdomain_config();
        config.tenancy.base_domain = Some(String::from(base));
        config
    }

    /// Request parts whose only header is `Host: <host>`.
    fn make_parts(host: &str) -> axum::http::request::Parts {
        let request = axum::http::Request::builder()
            .uri("http://ignored/")
            .header(axum::http::header::HOST, host)
            .body(())
            .unwrap();
        request.into_parts().0
    }

    /// Attaches a `ResolvedClientIdentity` carrying `host` (or no host) to `parts`.
    fn insert_identity(parts: &mut axum::http::request::Parts, host: Option<&str>) {
        parts.extensions.insert(ResolvedClientIdentity {
            addr: None,
            host: host.map(ToOwned::to_owned),
            scheme: None,
        });
    }

    /// Parts with a raw `Host` header plus a different resolved host in the extension.
    fn make_parts_with_identity(
        host_header: &str,
        resolved_host: &str,
    ) -> axum::http::request::Parts {
        let mut parts = make_parts(host_header);
        insert_identity(&mut parts, Some(resolved_host));
        parts
    }

    /// Runs the extractor and returns the tenant, failing the test on error.
    async fn tenant_of(
        parts: &mut axum::http::request::Parts,
        config: &crate::config::AutumnConfig,
    ) -> String {
        extract_tenant_from_parts(parts, config).await.unwrap()
    }

    #[tokio::test]
    async fn subdomain_falls_back_to_host_header_without_extension() {
        let config = subdomain_config();
        let mut parts = make_parts("tenant1.example.com");
        assert_eq!(tenant_of(&mut parts, &config).await, "tenant1");
    }

    #[tokio::test]
    async fn subdomain_uses_resolved_host_from_extension() {
        let config = subdomain_config();
        let mut parts = make_parts_with_identity("internal.cluster.local", "tenant1.example.com");
        assert_eq!(tenant_of(&mut parts, &config).await, "tenant1");
    }

    #[tokio::test]
    async fn subdomain_uses_resolved_host_with_base_domain() {
        let config = subdomain_config_with_base("example.com");
        let mut parts = make_parts_with_identity("internal.cluster.local", "acme.example.com");
        assert_eq!(tenant_of(&mut parts, &config).await, "acme");
    }

    #[tokio::test]
    async fn subdomain_strips_port_from_resolved_host() {
        let config = subdomain_config_with_base("example.com");
        let mut parts =
            make_parts_with_identity("internal.cluster.local", "tenant2.example.com:8080");
        assert_eq!(tenant_of(&mut parts, &config).await, "tenant2");
    }

    #[tokio::test]
    async fn subdomain_falls_back_when_resolved_host_is_none() {
        let config = subdomain_config();
        let mut parts = make_parts("tenant3.example.com");
        insert_identity(&mut parts, None);
        assert_eq!(tenant_of(&mut parts, &config).await, "tenant3");
    }
}