use crate::http::{HttpResponse, Response};
use crate::middleware::{Middleware, Next};
use crate::tenant::context::{tenant_scope, with_tenant_scope};
use crate::tenant::{TenantContext, TenantFailureMode};
use crate::Request;
use async_trait::async_trait;
use serde_json::json;
use super::resolver::TenantResolver;
/// Middleware that determines the tenant for each incoming request.
///
/// It owns an ordered chain of [`TenantResolver`]s and a
/// [`TenantFailureMode`] describing what to do when none of them yields a
/// tenant. Build one with [`TenantMiddleware::new`] and the chained
/// `resolver` / `on_failure` methods.
pub struct TenantMiddleware {
/// Resolvers consulted in insertion order; the first one returning a tenant wins.
resolvers: Vec<Box<dyn TenantResolver>>,
/// Policy applied when every resolver returns `None`.
on_failure: TenantFailureMode,
}
impl TenantMiddleware {
    /// Creates a middleware with no resolvers registered.
    ///
    /// The failure mode starts out as [`TenantFailureMode::NotFound`].
    pub fn new() -> Self {
        let resolvers = Vec::new();
        Self {
            resolvers,
            on_failure: TenantFailureMode::NotFound,
        }
    }

    /// Appends `resolver` to the end of the chain and hands the builder back.
    ///
    /// Resolvers are tried in the order they were added.
    pub fn resolver(mut self, resolver: impl TenantResolver + 'static) -> Self {
        let boxed: Box<dyn TenantResolver> = Box::new(resolver);
        self.resolvers.push(boxed);
        self
    }

    /// Replaces the policy used when no resolver produces a tenant.
    pub fn on_failure(self, mode: TenantFailureMode) -> Self {
        // Keep every other field, swap only the failure mode.
        Self {
            on_failure: mode,
            ..self
        }
    }
}
impl Default for TenantMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for TenantMiddleware {
async fn handle(&self, request: Request, next: Next) -> Response {
let mut resolved: Option<TenantContext> = None;
for resolver in &self.resolvers {
if let Some(ctx) = resolver.resolve(&request).await {
resolved = Some(ctx);
break;
}
}
match resolved {
Some(ctx) => {
let scope = tenant_scope();
{
let mut guard = scope.write().await;
*guard = Some(ctx);
}
with_tenant_scope(scope, next(request)).await
}
None => match &self.on_failure {
TenantFailureMode::NotFound => {
Err(HttpResponse::json(json!({"error": "Tenant not found"})).status(404))
}
TenantFailureMode::Forbidden => {
Err(HttpResponse::json(json!({"error": "Access denied"})).status(403))
}
TenantFailureMode::Allow => {
let scope = tenant_scope();
with_tenant_scope(scope, next(request)).await
}
TenantFailureMode::Custom(handler) => handler(),
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::http::HttpResponse;
use crate::tenant::context::current_tenant;
use crate::tenant::{TenantContext, TenantFailureMode};
use async_trait::async_trait;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::oneshot;
/// Builds a fixture tenant with id `1`, the name `"Test Corp"`, the given
/// `slug`, and every optional field left as `None`.
fn make_tenant(slug: &str) -> TenantContext {
    let slug = slug.to_owned();
    let name = String::from("Test Corp");
    TenantContext {
        id: 1,
        slug,
        name,
        plan: None,
        #[cfg(feature = "stripe")]
        subscription: None,
    }
}
/// Produces a real [`Request`] by round-tripping an HTTP/1 request over a
/// loopback TCP connection.
///
/// A throwaway hyper server is spawned on `127.0.0.1` with an OS-assigned
/// port. Its service wraps the first incoming hyper request in
/// `Request::new` and ships it back through a oneshot channel. A hyper
/// client then sends a request for `/test` carrying the single header
/// `header_name: header_value` and an empty body.
///
/// # Panics
///
/// Panics if binding, accepting, connecting, the client handshake, building
/// the request, or receiving from the oneshot channel fails.
async fn make_request_with_header(header_name: &str, header_value: &str) -> Request {
// Port 0 lets the OS pick a free port; read it back for the client side.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (tx, rx) = oneshot::channel();
// The sender sits in a Mutex<Option<_>> so the service closure can
// `take()` it once, on the first request.
let tx_holder = Arc::new(Mutex::new(Some(tx)));
// Server side: accept exactly one connection and serve it.
tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let io = TokioIo::new(stream);
let tx_holder = tx_holder.clone();
let service =
hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
let tx_holder = tx_holder.clone();
async move {
// Only the first request finds a sender; later ones skip the send.
if let Some(tx) = tx_holder.lock().unwrap().take() {
let _ = tx.send(Request::new(req));
}
// Reply with an empty body so the client request can complete.
Ok::<_, hyper::Error>(hyper::Response::new(http_body_util::Empty::<
bytes::Bytes,
>::new(
)))
}
});
// Connection errors are deliberately discarded.
hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
.ok();
});
// Client side: connect, handshake, and drive the connection in the background.
let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
let io = TokioIo::new(stream);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move {
conn.await.ok();
});
let req = hyper::Request::builder()
.uri("/test")
.header(header_name, header_value)
.body(http_body_util::Empty::<bytes::Bytes>::new())
.unwrap();
// The client-side response is discarded; the result comes from the channel.
let _ = sender.send_request(req).await;
rx.await.unwrap()
}
/// Shorthand for [`make_request_with_header`] with the header `x-test: 1`.
async fn make_request() -> Request {
    let (name, value) = ("x-test", "1");
    make_request_with_header(name, value).await
}
/// Test resolver that returns a copy of a fixed tenant for every request.
struct AlwaysResolver {
    /// Tenant cloned out on each `resolve` call.
    tenant: TenantContext,
}

#[async_trait]
impl TenantResolver for AlwaysResolver {
    /// Ignores the request and yields a clone of the stored tenant.
    async fn resolve(&self, _request: &Request) -> Option<TenantContext> {
        let tenant = self.tenant.clone();
        Some(tenant)
    }
}
/// Test resolver that never finds a tenant.
struct NeverResolver;

#[async_trait]
impl TenantResolver for NeverResolver {
    /// Ignores the request and always reports "no tenant".
    async fn resolve(&self, _request: &Request) -> Option<TenantContext> {
        None
    }
}
/// Terminal handler that ignores the request and replies with the text `"ok"`.
fn ok_next() -> Next {
    Arc::new(|_request| -> crate::middleware::MiddlewareFuture {
        Box::pin(async { Ok(HttpResponse::text("ok")) })
    })
}
/// Terminal handler that reports what `current_tenant()` returns.
///
/// The response is the JSON object `{"slug": <slug>}`, with `null` in place
/// of the slug when no tenant is set.
fn tenant_capture_next() -> Next {
    Arc::new(|_request| -> crate::middleware::MiddlewareFuture {
        Box::pin(async move {
            // `Option<String>` serialises to a JSON string or `null`.
            let slug: Option<String> = current_tenant().map(|tenant| tenant.slug.clone());
            Ok(HttpResponse::json(serde_json::json!({ "slug": slug })))
        })
    })
}
/// A freshly built middleware has no resolvers and defaults to `NotFound`.
#[test]
fn new_creates_empty_instance_with_not_found_default() {
    let TenantMiddleware {
        resolvers,
        on_failure,
    } = TenantMiddleware::new();
    assert!(resolvers.is_empty());
    assert!(matches!(on_failure, TenantFailureMode::NotFound));
}
/// Registering one resolver leaves exactly one entry in the chain.
#[test]
fn resolver_adds_to_chain() {
    let chain_len = TenantMiddleware::new()
        .resolver(NeverResolver)
        .resolvers
        .len();
    assert_eq!(chain_len, 1);
}
/// `on_failure` overrides the default failure mode.
#[test]
fn on_failure_sets_mode() {
    let TenantMiddleware { on_failure, .. } =
        TenantMiddleware::new().on_failure(TenantFailureMode::Allow);
    assert!(matches!(on_failure, TenantFailureMode::Allow));
}
/// A resolved tenant is visible to the downstream handler via `current_tenant()`.
#[tokio::test]
async fn resolves_tenant_and_stores_in_task_local() {
    let resolver = AlwaysResolver {
        tenant: make_tenant("acme"),
    };
    let middleware = TenantMiddleware::new()
        .resolver(resolver)
        .on_failure(TenantFailureMode::NotFound);

    let request = make_request().await;
    let response = middleware
        .handle(request, tenant_capture_next())
        .await
        .unwrap();

    let body: serde_json::Value = serde_json::from_str(response.body()).unwrap();
    assert_eq!(body["slug"], "acme");
}
/// Resolvers run in registration order and the first `Some` ends the search:
/// "beta" is reported, "gamma" is never reached.
#[tokio::test]
async fn tries_resolvers_in_order_first_some_wins() {
    let beta = AlwaysResolver {
        tenant: make_tenant("beta"),
    };
    let gamma = AlwaysResolver {
        tenant: make_tenant("gamma"),
    };
    let middleware = TenantMiddleware::new()
        .resolver(NeverResolver)
        .resolver(beta)
        .resolver(gamma);

    let request = make_request().await;
    let outcome = middleware.handle(request, tenant_capture_next()).await;

    let response = outcome.unwrap();
    let body: serde_json::Value = serde_json::from_str(response.body()).unwrap();
    assert_eq!(body["slug"], "beta");
}
/// With `NotFound`, an unresolved request is rejected with a 404 JSON error.
#[tokio::test]
async fn no_match_not_found_returns_404() {
    let middleware = TenantMiddleware::new()
        .resolver(NeverResolver)
        .on_failure(TenantFailureMode::NotFound);

    let request = make_request().await;
    let rejection = middleware.handle(request, ok_next()).await.unwrap_err();

    assert_eq!(rejection.status_code(), 404);
    let body: serde_json::Value = serde_json::from_str(rejection.body()).unwrap();
    assert_eq!(body["error"], "Tenant not found");
}
/// With `Forbidden`, an unresolved request is rejected with a 403 JSON error.
#[tokio::test]
async fn no_match_forbidden_returns_403() {
    let middleware = TenantMiddleware::new()
        .resolver(NeverResolver)
        .on_failure(TenantFailureMode::Forbidden);

    let request = make_request().await;
    let rejection = middleware.handle(request, ok_next()).await.unwrap_err();

    assert_eq!(rejection.status_code(), 403);
    let body: serde_json::Value = serde_json::from_str(rejection.body()).unwrap();
    assert_eq!(body["error"], "Access denied");
}
/// With `Allow`, the chain continues and the handler observes no tenant.
#[tokio::test]
async fn no_match_allow_continues_with_none() {
    let middleware = TenantMiddleware::new()
        .resolver(NeverResolver)
        .on_failure(TenantFailureMode::Allow);

    let request = make_request().await;
    let response = middleware
        .handle(request, tenant_capture_next())
        .await
        .unwrap();

    let body: serde_json::Value = serde_json::from_str(response.body()).unwrap();
    assert!(body["slug"].is_null());
}
/// With the default failure mode, a resolved tenant still reaches the
/// downstream handler through `current_tenant()`.
#[tokio::test]
async fn current_tenant_available_in_downstream_handler() {
    let resolver = AlwaysResolver {
        tenant: make_tenant("downstream-test"),
    };
    let middleware = TenantMiddleware::new().resolver(resolver);

    let request = make_request().await;
    let outcome = middleware.handle(request, tenant_capture_next()).await;

    let response = outcome.unwrap();
    let body: serde_json::Value = serde_json::from_str(response.body()).unwrap();
    assert_eq!(body["slug"], "downstream-test");
}
}