use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::HeaderName;
use http::StatusCode;
use tako_rs_core::body::TakoBody;
use tako_rs_core::middleware::IntoMiddleware;
use tako_rs_core::middleware::Next;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;
/// Tenant identifier resolved by [`TenantMiddleware`] for the current request.
///
/// The middleware inserts this value into the request extensions only after
/// the raw identifier has passed its validation (non-empty, at most 64 bytes,
/// ASCII alphanumerics plus `_`, `-` and `.`, and not `.`/`..`).
///
/// Equality and hashing compare the wrapped string, so a `Tenant` can be used
/// directly as a map/set key or compared against an expected tenant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Tenant(pub String);
/// Where [`TenantMiddleware`] looks for the tenant identifier in a request.
#[derive(Clone)]
pub enum TenantStrategy {
    /// Read the identifier from the named request header. Surrounding
    /// whitespace is trimmed and an empty value is treated as absent.
    Header(HeaderName),
    /// Use the left-most DNS label of the request host (e.g. `acme` in
    /// `acme.example.com`), lowercased.
    Subdomain,
    /// Use the path segment at this zero-based index, counting only
    /// non-empty segments (so `//acme/x` has `acme` at index 0).
    PathPrefix(usize),
    /// Delegate to a user-supplied resolver.
    Custom(TenantCustomFn),
}

// `Debug` cannot be derived because `Custom` holds a `dyn Fn`; closures are
// rendered as an opaque placeholder instead.
impl std::fmt::Debug for TenantStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Header(name) => f.debug_tuple("Header").field(name).finish(),
            Self::Subdomain => f.write_str("Subdomain"),
            Self::PathPrefix(index) => f.debug_tuple("PathPrefix").field(index).finish(),
            Self::Custom(_) => f.write_str("Custom(<fn>)"),
        }
    }
}

/// User-supplied tenant resolver used by [`TenantStrategy::Custom`].
///
/// Returns the raw tenant identifier for the request, or `None` when the
/// request carries no tenant. The returned value is still validated by the
/// middleware before it is accepted.
pub type TenantCustomFn = Arc<dyn Fn(&Request) -> Option<String> + Send + Sync + 'static>;
/// Middleware that resolves a [`Tenant`] for every incoming request.
///
/// Construct it with `from_header`, `from_subdomain`, `from_path_segment` or
/// `custom`, then optionally call `require(true)` to reject requests without
/// a valid tenant.
pub struct TenantMiddleware {
    /// When `true`, requests without a valid tenant are answered with
    /// `400 Bad Request` instead of being forwarded to the next handler.
    required: bool,
    /// How the tenant identifier is located in the request.
    strategy: TenantStrategy,
}
impl TenantMiddleware {
pub fn from_header(name: HeaderName) -> Self {
Self {
strategy: TenantStrategy::Header(name),
required: false,
}
}
pub fn from_subdomain() -> Self {
Self {
strategy: TenantStrategy::Subdomain,
required: false,
}
}
pub fn from_path_segment(index: usize) -> Self {
Self {
strategy: TenantStrategy::PathPrefix(index),
required: false,
}
}
pub fn custom<F>(f: F) -> Self
where
F: Fn(&Request) -> Option<String> + Send + Sync + 'static,
{
Self {
strategy: TenantStrategy::Custom(Arc::new(f)),
required: false,
}
}
pub fn require(mut self, required: bool) -> Self {
self.required = required;
self
}
}
/// Returns the left-most DNS label of `host`, lowercased, as a tenant candidate.
///
/// Any `:port` suffix is ignored. The host must contain at least two labels,
/// so `acme.example.com` and `acme.localhost` both yield `"acme"`. A
/// single-label host such as `localhost`, or an empty first label
/// (`.example.com`), yields `None`. Note that a bare domain such as
/// `example.com` still yields `"example"`.
///
/// IPv4 literals (e.g. `10.0.0.5:8080`) yield `None`. Their first octet is
/// not a subdomain, and treating it as one would attribute requests made by
/// IP address to a bogus tenant such as `"10"`.
fn extract_subdomain(host: &str) -> Option<String> {
    let host = host.split_once(':').map_or(host, |(name, _port)| name);
    if host.parse::<std::net::Ipv4Addr>().is_ok() {
        return None;
    }
    // Requiring a '.' guarantees at least two labels.
    let (first, _rest) = host.split_once('.')?;
    if first.is_empty() {
        return None;
    }
    Some(first.to_ascii_lowercase())
}
/// Returns the `index`-th (zero-based) non-empty `/`-separated segment of
/// `path` as an owned string, or `None` if there are not enough segments.
///
/// Empty segments from leading, trailing or repeated slashes are skipped, so
/// `//acme//users/` has `acme` at index 0 and `users` at index 1.
fn extract_path_segment(path: &str, index: usize) -> Option<String> {
    let mut to_skip = index;
    for segment in path.split('/') {
        if segment.is_empty() {
            continue;
        }
        if to_skip == 0 {
            return Some(segment.to_owned());
        }
        to_skip -= 1;
    }
    None
}
/// Reports whether `id` is acceptable as a tenant identifier.
///
/// Accepted identifiers are 1 to 64 bytes long and consist solely of ASCII
/// letters, digits, `_`, `-` and `.`. The path-like values `.` and `..` are
/// rejected outright.
fn is_valid_tenant_id(id: &str) -> bool {
    const MAX_LEN: usize = 64;
    let is_allowed_byte = |byte: u8| byte.is_ascii_alphanumeric() || b"_-.".contains(&byte);

    (1..=MAX_LEN).contains(&id.len())
        && !matches!(id, "." | "..")
        && id.bytes().all(is_allowed_byte)
}
impl IntoMiddleware for TenantMiddleware {
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let strategy = self.strategy;
let required = self.required;
move |mut req: Request, next: Next| {
let strategy = strategy.clone();
Box::pin(async move {
let tenant = match &strategy {
TenantStrategy::Header(h) => req
.headers()
.get(h)
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|s| !s.is_empty())
.map(str::to_string),
TenantStrategy::Subdomain => req
.headers()
.get(http::header::HOST)
.and_then(|v| v.to_str().ok())
.and_then(extract_subdomain),
TenantStrategy::PathPrefix(idx) => extract_path_segment(req.uri().path(), *idx),
TenantStrategy::Custom(f) => f(&req),
};
let tenant = tenant.filter(|t| is_valid_tenant_id(t));
match tenant {
Some(t) => {
req.extensions_mut().insert(Tenant(t));
next.run(req).await
}
None if required => http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(TakoBody::from("missing tenant identifier"))
.expect("valid response"),
None => next.run(req).await,
}
})
}
}
}