Skip to main content

autumn_web/
tenancy.rs

1use axum::{
2    extract::State,
3    http::Request,
4    middleware::Next,
5    response::{IntoResponse, Response},
6};
7use http_body::Body as HttpBody;
8use pin_project_lite::pin_project;
9use std::future::Future;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13// 1. Task-local storage for CURRENT_TENANT
14tokio::task_local! {
15    pub static CURRENT_TENANT: Option<String>;
16}
17
/// Request extractor yielding the tenant id for the current request.
///
/// The wrapped string is exactly what `extract_tenant_from_parts` returned for the
/// configured `tenancy.source`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Tenant(pub String);
21
22impl axum::extract::FromRequestParts<crate::AppState> for Tenant {
23    type Rejection = crate::AutumnError;
24
25    async fn from_request_parts(
26        parts: &mut axum::http::request::Parts,
27        state: &crate::AppState,
28    ) -> Result<Self, Self::Rejection> {
29        let config = state
30            .extension::<crate::config::AutumnConfig>()
31            .ok_or_else(|| {
32                crate::AutumnError::service_unavailable_msg("Config is not available")
33            })?;
34        let tenant_id = extract_tenant_from_parts(parts, &config).await?;
35        Ok(Self(tenant_id))
36    }
37}
38
39// Helper to run in-test tenancy contexts
40pub async fn with_tenant<F, R>(tenant_id: String, future: F) -> R
41where
42    F: Future<Output = R>,
43{
44    CURRENT_TENANT.scope(Some(tenant_id), future).await
45}
46
47// Tenant extraction logic based on configuration
48#[allow(clippy::missing_errors_doc, clippy::too_many_lines)]
49pub async fn extract_tenant_from_parts(
50    parts: &mut axum::http::request::Parts,
51    config: &crate::config::AutumnConfig,
52) -> Result<String, crate::AutumnError> {
53    if !config.tenancy.enabled {
54        return Err(crate::AutumnError::bad_request_msg("Tenancy is disabled"));
55    }
56
57    match config.tenancy.source.as_str() {
58        "header" => {
59            let header_value = parts
60                .headers
61                .get(&config.tenancy.header_name)
62                .ok_or_else(|| {
63                    crate::AutumnError::bad_request_msg(format!(
64                        "Missing required tenant header: {}",
65                        config.tenancy.header_name
66                    ))
67                })?;
68            let val = header_value
69                .to_str()
70                .map_err(|_| {
71                    crate::AutumnError::bad_request_msg(format!(
72                        "Invalid UTF-8 in tenant header: {}",
73                        config.tenancy.header_name
74                    ))
75                })?
76                .to_string();
77            if val.trim().is_empty() {
78                return Err(crate::AutumnError::bad_request_msg(format!(
79                    "Tenant header {} is empty",
80                    config.tenancy.header_name
81                )));
82            }
83            Ok(val)
84        }
85        "subdomain" => {
86            // Prefer the proxy-resolved host (honours X-Forwarded-Host from trusted
87            // upstreams); fall back to the raw Host header when the layer has not run.
88            let host_owned: String = parts
89                .extensions
90                .get::<crate::security::ResolvedClientIdentity>()
91                .and_then(|id| id.host.clone())
92                .map_or_else(
93                    || {
94                        parts
95                            .headers
96                            .get(axum::http::header::HOST)
97                            .ok_or_else(|| {
98                                crate::AutumnError::bad_request_msg(
99                                    "Missing Host header for subdomain tenancy",
100                                )
101                            })
102                            .and_then(|h| {
103                                h.to_str().map(ToOwned::to_owned).map_err(|_| {
104                                    crate::AutumnError::bad_request_msg(
105                                        "Invalid UTF-8 in Host header",
106                                    )
107                                })
108                            })
109                    },
110                    Ok,
111                )?;
112
113            let host = host_owned.as_str();
114            let host_only = host.split(':').next().unwrap_or(host).trim();
115
116            if host_only.parse::<std::net::IpAddr>().is_ok() {
117                return Err(crate::AutumnError::bad_request_msg(
118                    "IP address host not allowed in subdomain mode",
119                ));
120            }
121
122            // DNS hostnames are case-insensitive; normalise to lowercase
123            // before any matching so that e.g. `Tenant1.Example.COM` works.
124            let host_lower = host_only.to_lowercase();
125
126            if let Some(ref base_domain) = config.tenancy.base_domain {
127                let base_domain_clean = base_domain.trim().to_lowercase();
128                if !host_lower.ends_with(base_domain_clean.as_str()) {
129                    return Err(crate::AutumnError::bad_request_msg(format!(
130                        "Host does not match base domain: {base_domain_clean}"
131                    )));
132                }
133                if host_lower.len() <= base_domain_clean.len() {
134                    return Err(crate::AutumnError::bad_request_msg(
135                        "Apex domain not allowed in subdomain mode",
136                    ));
137                }
138                let prefix_len = host_lower.len() - base_domain_clean.len();
139                if !host_lower[..prefix_len].ends_with('.') {
140                    return Err(crate::AutumnError::bad_request_msg(
141                        "Invalid subdomain format",
142                    ));
143                }
144                let subdomain_part = &host_lower[..prefix_len - 1];
145                let tenant = subdomain_part.split('.').next().ok_or_else(|| {
146                    crate::AutumnError::bad_request_msg("Unable to extract subdomain tenant")
147                })?;
148                if tenant.trim().is_empty() {
149                    return Err(crate::AutumnError::bad_request_msg(
150                        "Extracted subdomain tenant is empty",
151                    ));
152                }
153                Ok(tenant.to_string())
154            } else {
155                let labels: Vec<&str> = host_lower.split('.').filter(|s| !s.is_empty()).collect();
156                if labels.is_empty() {
157                    return Err(crate::AutumnError::bad_request_msg("Empty host header"));
158                }
159
160                if labels.len() < 2 {
161                    return Err(crate::AutumnError::bad_request_msg(
162                        "Apex or local host without subdomain not allowed",
163                    ));
164                }
165
166                if labels.len() == 2 && labels[1] != "localhost" {
167                    return Err(crate::AutumnError::bad_request_msg(
168                        "Apex domain not allowed in subdomain mode",
169                    ));
170                }
171
172                let tenant = labels[0].to_string();
173                if tenant.trim().is_empty() {
174                    return Err(crate::AutumnError::bad_request_msg(
175                        "Extracted subdomain tenant is empty",
176                    ));
177                }
178                Ok(tenant)
179            }
180        }
181        "session" => {
182            let session = parts
183                .extensions
184                .get::<crate::session::Session>()
185                .ok_or_else(|| {
186                    crate::AutumnError::internal_server_error_msg(
187                        "SessionLayer not installed but session tenancy source is configured",
188                    )
189                })?;
190            let tenant = session
191                .get(&config.tenancy.session_key)
192                .await
193                .ok_or_else(|| {
194                    crate::AutumnError::unauthorized_msg(format!(
195                        "Tenant ID missing from session key: {}",
196                        config.tenancy.session_key
197                    ))
198                })?;
199            if tenant.trim().is_empty() {
200                return Err(crate::AutumnError::unauthorized_msg(format!(
201                    "Tenant ID in session key {} is empty",
202                    config.tenancy.session_key
203                )));
204            }
205            Ok(tenant)
206        }
207        "jwt" => {
208            let auth_header = parts
209                .headers
210                .get(axum::http::header::AUTHORIZATION)
211                .ok_or_else(|| {
212                    crate::AutumnError::unauthorized_msg(
213                        "Missing Authorization header for JWT tenancy",
214                    )
215                })?;
216            let auth_str = auth_header.to_str().map_err(|_| {
217                crate::AutumnError::unauthorized_msg("Invalid UTF-8 in Authorization header")
218            })?;
219
220            if auth_str.len() < 7
221                || !auth_str.is_char_boundary(7)
222                || !auth_str[..7].eq_ignore_ascii_case("bearer ")
223            {
224                return Err(crate::AutumnError::unauthorized_msg(
225                    "Invalid Authorization header format. Expected Bearer <token>",
226                ));
227            }
228            let token = &auth_str[7..];
229
230            let secret = config.tenancy.jwt_secret.as_ref().ok_or_else(|| {
231                crate::AutumnError::unauthorized_msg("JWT secret is not configured")
232            })?;
233
234            let mut validation = ::jsonwebtoken::Validation::default();
235            if let Some(ref iss) = config.tenancy.jwt_issuer {
236                validation.set_issuer(::std::slice::from_ref(iss));
237            }
238            if let Some(ref aud) = config.tenancy.jwt_audience {
239                validation.set_audience(&[aud.as_str()]);
240            } else {
241                validation.validate_aud = false;
242            }
243
244            let token_data = ::jsonwebtoken::decode::<serde_json::Value>(
245                token,
246                &::jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()),
247                &validation,
248            )
249            .map_err(|e| {
250                crate::AutumnError::unauthorized_msg(format!("JWT validation failed: {e}"))
251            })?;
252
253            // `jsonwebtoken`'s `set_audience` validates the `aud` value when
254            // the claim is *present*, but silently accepts tokens that omit the
255            // `aud` field entirely. Explicitly reject those when audience
256            // validation is enabled so legacy tokens without an `aud` claim
257            // cannot bypass the check.
258            if let Some(ref expected_aud) = config.tenancy.jwt_audience {
259                let aud_ok = token_data.claims.get("aud").is_some_and(|v| match v {
260                    serde_json::Value::String(s) => s == expected_aud,
261                    serde_json::Value::Array(arr) => arr
262                        .iter()
263                        .any(|e| e.as_str() == Some(expected_aud.as_str())),
264                    _ => false,
265                });
266                if !aud_ok {
267                    return Err(crate::AutumnError::unauthorized_msg(
268                        "JWT audience validation failed: missing or invalid aud claim",
269                    ));
270                }
271            }
272
273            let tenant = token_data
274                .claims
275                .get(&config.tenancy.jwt_claim)
276                .and_then(|v| v.as_str())
277                .ok_or_else(|| {
278                    crate::AutumnError::unauthorized_msg(format!(
279                        "Tenant claim '{}' missing from JWT payload",
280                        config.tenancy.jwt_claim
281                    ))
282                })?
283                .to_string();
284
285            if tenant.trim().is_empty() {
286                return Err(crate::AutumnError::unauthorized_msg(format!(
287                    "Tenant claim '{}' in JWT payload is empty",
288                    config.tenancy.jwt_claim
289                )));
290            }
291            Ok(tenant)
292        }
293        other => Err(crate::AutumnError::internal_server_error_msg(format!(
294            "Unsupported tenancy source: {other}"
295        ))),
296    }
297}
298
299// Tenancy middleware for Axum requests
300pub async fn tenancy_middleware(
301    State(state): State<crate::AppState>,
302    request: Request<axum::body::Body>,
303    next: Next,
304) -> Response {
305    let Some(config) = state.extension::<crate::config::AutumnConfig>() else {
306        return crate::AutumnError::internal_server_error_msg("AutumnConfig not found in AppState")
307            .into_response();
308    };
309
310    if !config.tenancy.enabled {
311        return next.run(request).await;
312    }
313
314    let (mut parts, body) = request.into_parts();
315    let tenant_id = match extract_tenant_from_parts(&mut parts, &config).await {
316        Ok(t) => t,
317        Err(e) => return e.into_response(),
318    };
319
320    // Tag the request-scoped log context (#1169) so every subsequent event
321    // automatically carries the resolved tenant id.
322    crate::log::context::set_tenant_id(&tenant_id);
323
324    let request = Request::from_parts(parts, body);
325    let tenant_id_clone = tenant_id.clone();
326    let response = CURRENT_TENANT
327        .scope(Some(tenant_id), next.run(request))
328        .await;
329
330    let (parts, body) = response.into_parts();
331    let wrapped = TenantPropagatingBody {
332        inner: body,
333        tenant_id: tenant_id_clone,
334    };
335    Response::from_parts(parts, axum::body::Body::new(wrapped))
336}
337
338pin_project! {
339    /// A response body wrapper that re-establishes the tenant context
340    /// for each poll of the inner body, so lazy/streaming bodies can
341    /// access tenant-scoped repositories during their polling phase.
342    pub struct TenantPropagatingBody<B> {
343        #[pin]
344        pub inner: B,
345        pub tenant_id: String,
346    }
347}
348
349impl<B> HttpBody for TenantPropagatingBody<B>
350where
351    B: HttpBody,
352{
353    type Data = B::Data;
354    type Error = B::Error;
355
356    fn poll_frame(
357        self: Pin<&mut Self>,
358        cx: &mut Context<'_>,
359    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
360        let this = self.project();
361        let tenant_id = this.tenant_id.clone();
362        CURRENT_TENANT.sync_scope(Some(tenant_id), || this.inner.poll_frame(cx))
363    }
364
365    fn is_end_stream(&self) -> bool {
366        self.inner.is_end_stream()
367    }
368
369    fn size_hint(&self) -> http_body::SizeHint {
370        self.inner.size_hint()
371    }
372}
373
/// Turns a model's insertable value into what is actually handed to Diesel, with
/// the tenant id attached.
///
/// Depending on the model, the tenant id is either written into an existing
/// `tenant_id` field or appended as a separate column value. This avoids
/// duplicate-column SQL errors when the model already declares a manual
/// (non-default) `tenant_id` field.
#[cfg(feature = "db")]
pub trait TenantInsertable<'a, Table> {
    /// Insertable values produced once the tenant id has been applied.
    type Values;

    /// Consumes `self` and returns its insertable values tagged with `tenant_id`.
    fn tenant_values(self, tenant_id: &'a str) -> Self::Values;
}
383
/// Describes whether a model struct declares its own `tenant_id` field.
#[cfg(feature = "db")]
pub trait ModelTenantIdMeta {
    /// `true` when the struct declares a manual `tenant_id` field.
    const HAS_MANUAL_TENANT_ID: bool;

    /// Writes `tenant_id` into the struct's `tenant_id` field, if it has one.
    fn try_set_tenant_id(&mut self, tenant_id: &str);
}
392
/// Links a Diesel table type to its `tenant_id` column.
#[cfg(feature = "db")]
pub trait HasTenantIdColumn {
    /// Diesel expression type of the table's `tenant_id` column.
    type Column: ::diesel::Expression;

    /// Returns the `tenant_id` column expression.
    fn column() -> Self::Column;
}
399
/// Dispatch helper that decides how the tenant id is attached to an insertable value.
///
/// The `HAS_MANUAL` const parameter selects the `GetInsertableValues` impl:
///
/// * `true` writes the id into the model's own `tenant_id` field.
/// * `false` pairs the model with a `tenant_id = <id>` column assignment.
#[cfg(feature = "db")]
pub struct TenantInsertableValuesSelector<'a, T, Table, const HAS_MANUAL: bool> {
    /// The insertable value being tagged.
    pub inner: T,
    /// Tenant id to attach.
    pub tenant_id: &'a str,
    /// Records the target table type without storing a value of it.
    pub _marker: std::marker::PhantomData<Table>,
}
407
/// Produces the final insertable values for one `TenantInsertableValuesSelector`
/// variant.
#[cfg(feature = "db")]
pub trait GetInsertableValues {
    /// Value type handed to the insert statement.
    type Values;

    /// Consumes the selector and yields the tenant-tagged values.
    fn get_values(self) -> Self::Values;
}
414
#[cfg(feature = "db")]
impl<T, Table> GetInsertableValues for TenantInsertableValuesSelector<'_, T, Table, true>
where
    T: ModelTenantIdMeta,
{
    type Values = T;

    /// The model declares its own `tenant_id`: store the id in that field and
    /// return the model itself.
    fn get_values(self) -> Self::Values {
        let Self {
            mut inner,
            tenant_id,
            ..
        } = self;
        inner.try_set_tenant_id(tenant_id);
        inner
    }
}
426
#[cfg(feature = "db")]
impl<'a, T, Table> GetInsertableValues for TenantInsertableValuesSelector<'a, T, Table, false>
where
    Table: HasTenantIdColumn,
    Table::Column: ::diesel::ExpressionMethods,
    <Table::Column as ::diesel::Expression>::SqlType: ::diesel::sql_types::SqlType,
    &'a str: ::diesel::expression::AsExpression<<Table::Column as ::diesel::Expression>::SqlType>,
{
    type Values = (T, ::diesel::dsl::Eq<Table::Column, &'a str>);

    /// The model has no `tenant_id` field: pair it with an explicit
    /// `tenant_id = <id>` assignment for the table's tenant column.
    fn get_values(self) -> Self::Values {
        use ::diesel::ExpressionMethods;

        let Self {
            inner, tenant_id, ..
        } = self;
        let tenant_assignment = Table::column().eq(tenant_id);
        (inner, tenant_assignment)
    }
}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::ResolvedClientIdentity;

    /// Subdomain tenancy enabled, no base domain configured.
    fn subdomain_config() -> crate::config::AutumnConfig {
        let mut config = crate::config::AutumnConfig::default();
        config.tenancy.enabled = true;
        config.tenancy.source = "subdomain".to_string();
        config
    }

    /// Subdomain tenancy enabled and matched against `base`.
    fn subdomain_config_with_base(base: &str) -> crate::config::AutumnConfig {
        let mut config = subdomain_config();
        config.tenancy.base_domain = Some(base.to_string());
        config
    }

    /// Identity extension with only the resolved host populated.
    fn identity(resolved_host: Option<&str>) -> ResolvedClientIdentity {
        ResolvedClientIdentity {
            addr: None,
            host: resolved_host.map(str::to_string),
            scheme: None,
        }
    }

    /// Request parts with `host_header` as the raw `Host` header and, when given,
    /// `extension` inserted into the request extensions.
    fn request_parts(
        host_header: &str,
        extension: Option<ResolvedClientIdentity>,
    ) -> axum::http::request::Parts {
        let (mut parts, ()) = axum::http::Request::builder()
            .uri("http://ignored/")
            .header(axum::http::header::HOST, host_header)
            .body(())
            .unwrap()
            .into_parts();
        if let Some(extension) = extension {
            parts.extensions.insert(extension);
        }
        parts
    }

    /// Runs tenant extraction and unwraps the result.
    async fn resolve(
        config: &crate::config::AutumnConfig,
        mut parts: axum::http::request::Parts,
    ) -> String {
        extract_tenant_from_parts(&mut parts, config).await.unwrap()
    }

    /// Without a `ResolvedClientIdentity` extension, subdomain mode reads the raw
    /// Host header.
    #[tokio::test]
    async fn subdomain_falls_back_to_host_header_without_extension() {
        let parts = request_parts("tenant1.example.com", None);
        assert_eq!(resolve(&subdomain_config(), parts).await, "tenant1");
    }

    /// A resolved host takes precedence over the raw Host header, so trusted
    /// X-Forwarded-Host values win.
    #[tokio::test]
    async fn subdomain_uses_resolved_host_from_extension() {
        // The raw Host is the internal address; the resolved host is the public subdomain.
        let parts = request_parts(
            "internal.cluster.local",
            Some(identity(Some("tenant1.example.com"))),
        );
        assert_eq!(resolve(&subdomain_config(), parts).await, "tenant1");
    }

    /// The resolved host is matched against a configured `base_domain`.
    #[tokio::test]
    async fn subdomain_uses_resolved_host_with_base_domain() {
        let parts = request_parts(
            "internal.cluster.local",
            Some(identity(Some("acme.example.com"))),
        );
        let config = subdomain_config_with_base("example.com");
        assert_eq!(resolve(&config, parts).await, "acme");
    }

    /// A port suffix on the resolved host is ignored.
    #[tokio::test]
    async fn subdomain_strips_port_from_resolved_host() {
        let parts = request_parts(
            "internal.cluster.local",
            Some(identity(Some("tenant2.example.com:8080"))),
        );
        let config = subdomain_config_with_base("example.com");
        assert_eq!(resolve(&config, parts).await, "tenant2");
    }

    /// An identity extension with `host: None` still falls back to the Host header.
    #[tokio::test]
    async fn subdomain_falls_back_when_resolved_host_is_none() {
        let parts = request_parts("tenant3.example.com", Some(identity(None)));
        assert_eq!(resolve(&subdomain_config(), parts).await, "tenant3");
    }
}