use axum::{body::Body, http::Request, middleware::Next, response::Response};
use tracing::debug;
use crate::middleware::oidc_auth::AuthUser;
/// Request header through which an authenticated client selects its organisation.
const ORG_ID_HEADER: &str = "X-Org-ID";

/// Axum middleware that resolves the tenant (organisation) for a request and
/// stores it as a [`TenantContext`] request extension.
///
/// A `TenantContext` is inserted on every request, so downstream handlers can
/// always extract it. `org_id` is `Some` only when all of these hold:
/// - an [`AuthUser`] extension is already present on the request;
/// - the `X-Org-ID` header is visible ASCII;
/// - the header value is non-empty after trimming surrounding whitespace.
///
/// In every other case `org_id` is `None`:
/// - unauthenticated requests that carry the header are logged at `warn`;
/// - non-ASCII values are logged at `warn`;
/// - empty values are logged at `debug`.
///
/// Only the first `X-Org-ID` value is consulted (`HeaderMap::get` semantics).
///
/// This layer must run after whichever layer inserts `AuthUser`. That is
/// presumably the OIDC auth middleware — confirm the layer ordering.
///
/// NOTE(review): nothing here checks that the authenticated user actually
/// belongs to the requested organisation; the value is whatever the client
/// sent. Confirm that downstream code authorises the org id before trusting it.
pub async fn tenant_middleware(mut request: Request<Body>, next: Next) -> Response {
    let org_id = resolve_org_id(&request);
    request.extensions_mut().insert(TenantContext { org_id });
    next.run(request).await
}

/// Returns the normalised org id from `X-Org-ID` for authenticated requests,
/// or `None` if the header is absent, unauthenticated, malformed, or blank.
fn resolve_org_id(request: &Request<Body>) -> Option<String> {
    let header_value = request.headers().get(ORG_ID_HEADER)?;

    let Some(auth_user) = request.extensions().get::<AuthUser>() else {
        tracing::warn!("Rejected X-Org-ID header from unauthenticated request");
        return None;
    };

    let org_id_str = match header_value.to_str() {
        Ok(raw) => raw.trim(),
        Err(_) => {
            // Previously dropped silently; log so misconfigured clients are visible.
            tracing::warn!(
                user_id = %auth_user.0.user_id,
                "Ignored X-Org-ID header containing non-visible-ASCII bytes"
            );
            return None;
        }
    };

    // A blank header must not yield `Some("")`, which would make the request
    // look tenant-scoped without identifying any tenant.
    if org_id_str.is_empty() {
        debug!(
            user_id = %auth_user.0.user_id,
            "Ignored empty X-Org-ID header"
        );
        return None;
    }

    debug!(
        user_id = %auth_user.0.user_id,
        org_id = %org_id_str,
        source = "header",
        "Extracted org_id from X-Org-ID header for authenticated user"
    );
    Some(org_id_str.to_owned())
}
/// Per-request tenant information attached as a request extension by
/// `tenant_middleware`.
///
/// `org_id` is `None` when the request did not carry a usable `X-Org-ID`
/// header or was not authenticated.
#[derive(Clone, Debug)]
pub struct TenantContext {
    /// Organisation id accepted from the `X-Org-ID` header, if any.
    pub org_id: Option<String>,
}
impl TenantContext {
    /// Reports whether an organisation id is attached to this request.
    pub fn is_tenant_scoped(&self) -> bool {
        matches!(self.org_id, Some(_))
    }

    /// Borrows the organisation id, or returns `None` for unscoped requests.
    pub fn get_org_id(&self) -> Option<&str> {
        self.org_id.as_ref().map(String::as_str)
    }

    /// Borrows the organisation id, failing for unscoped requests.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `org_id` is `None`.
    pub fn require_org_id(&self) -> Result<&str, String> {
        match self.get_org_id() {
            Some(id) => Ok(id),
            None => Err(String::from(
                "Request must be tenant-scoped (missing org_id)",
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Error text `require_org_id` reports for unscoped requests.
    const MISSING_ORG_MSG: &str = "Request must be tenant-scoped (missing org_id)";

    /// Builds a context scoped to `id`.
    fn scoped(id: &str) -> TenantContext {
        TenantContext {
            org_id: Some(id.to_owned()),
        }
    }

    /// Builds a context with no organisation.
    fn unscoped() -> TenantContext {
        TenantContext { org_id: None }
    }

    #[test]
    fn test_tenant_context_scoped() {
        let ctx = scoped("org-123");
        assert!(ctx.is_tenant_scoped());
        assert_eq!(ctx.get_org_id(), Some("org-123"));
    }

    #[test]
    fn test_tenant_context_unscoped() {
        let ctx = unscoped();
        assert!(!ctx.is_tenant_scoped());
        assert!(ctx.get_org_id().is_none());
    }

    #[test]
    fn test_require_org_id_success() {
        let ctx = scoped("org-123");
        assert_eq!(ctx.require_org_id(), Ok("org-123"));
    }

    #[test]
    fn test_require_org_id_failure() {
        let ctx = unscoped();
        assert_eq!(ctx.require_org_id(), Err(MISSING_ORG_MSG.to_string()));
    }
}