use async_trait::async_trait;
use hyper::header::HeaderName;
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// A site served by this application, identified by the domain it answers on.
#[derive(Debug, Clone, PartialEq)]
pub struct Site {
    /// Numeric identifier; `SiteMiddleware` reports it in the `X-Site-ID` header.
    pub id: u64,
    /// Domain used as the lookup key in `SiteRegistry`.
    pub domain: String,
    /// Human-readable display name.
    pub name: String,
}

impl Site {
    /// Builds a site from its id, domain and display name.
    pub fn new(id: u64, domain: String, name: String) -> Self {
        Site { name, domain, id }
    }
}
#[derive(Debug, Default)]
pub struct SiteRegistry {
sites: RwLock<HashMap<String, Site>>,
default_site: RwLock<Option<Site>>,
}
impl SiteRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register(&self, site: Site) {
let domain = site.domain.clone();
self.sites
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(domain, site);
}
pub fn set_default(&self, site: Site) {
*self.default_site.write().unwrap_or_else(|e| e.into_inner()) = Some(site);
}
pub fn get_by_domain(&self, domain: &str) -> Option<Site> {
if let Some(site) = self
.sites
.read()
.unwrap_or_else(|e| e.into_inner())
.get(domain)
{
return Some(site.clone());
}
if let Some(without_www) = domain.strip_prefix("www.")
&& let Some(site) = self
.sites
.read()
.unwrap_or_else(|e| e.into_inner())
.get(without_www)
{
return Some(site.clone());
}
None
}
pub fn default_site(&self) -> Option<Site> {
self.default_site
.read()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
pub fn all(&self) -> Vec<Site> {
self.sites
.read()
.unwrap_or_else(|e| e.into_inner())
.values()
.cloned()
.collect()
}
pub fn clear(&self) {
self.sites
.write()
.unwrap_or_else(|e| e.into_inner())
.clear();
*self.default_site.write().unwrap_or_else(|e| e.into_inner()) = None;
}
}
/// Name of the response header that carries the resolved [`Site`]'s numeric `id`.
pub const SITE_ID_HEADER: &str = "X-Site-ID";
/// Runtime switches for [`SiteMiddleware`].
///
/// Both flags start out `true`; the consuming builder methods switch them off.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SiteConfig {
    /// When `false`, the middleware forwards every request untouched.
    pub enabled: bool,
    /// When `true`, the registry's default site is used if no domain matches.
    pub fallback_enabled: bool,
}

impl SiteConfig {
    /// Returns the standard configuration: enabled, with default-site fallback.
    pub fn new() -> Self {
        Self::default()
    }

    /// Consumes the config and returns it with default-site fallback turned off.
    pub fn without_fallback(self) -> Self {
        Self {
            fallback_enabled: false,
            ..self
        }
    }

    /// Consumes the config and returns it with the middleware turned off.
    pub fn disabled(self) -> Self {
        Self {
            enabled: false,
            ..self
        }
    }
}

impl Default for SiteConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            fallback_enabled: true,
        }
    }
}
/// Middleware that resolves the current [`Site`] from the request's `Host` header
/// and reports its id in the [`SITE_ID_HEADER`] response header.
pub struct SiteMiddleware {
    config: SiteConfig,
    /// Sites known to this middleware; public so callers can register sites
    /// after construction.
    pub registry: Arc<SiteRegistry>,
}

impl SiteMiddleware {
    /// Creates a middleware with `config` and a fresh, empty registry.
    pub fn new(config: SiteConfig) -> Self {
        Self {
            config,
            registry: Arc::new(SiteRegistry::new()),
        }
    }

    /// Creates a middleware using [`SiteConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(SiteConfig::default())
    }

    /// Extracts the host name (without any `:port`) from the `Host` header.
    ///
    /// Returns `None` if the header is absent or `HeaderValue::to_str` rejects it.
    fn get_host(&self, request: &Request) -> Option<String> {
        let authority = request.headers.get(hyper::header::HOST)?.to_str().ok()?;
        Some(Self::strip_port(authority).to_string())
    }

    /// Removes an optional `:port` suffix from a `Host` header value.
    ///
    /// A bracketed IPv6 literal such as `[::1]:8080` keeps its brackets and loses
    /// only the port. Splitting it on the first `:` would yield just `[`.
    /// An unterminated `[` is returned unchanged.
    fn strip_port(authority: &str) -> &str {
        if authority.starts_with('[') {
            return match authority.find(']') {
                // `]` is ASCII, so `end` is a valid char boundary.
                Some(end) => &authority[..=end],
                None => authority,
            };
        }
        authority
            .split_once(':')
            .map_or(authority, |(host, _port)| host)
    }
}

impl Default for SiteMiddleware {
    fn default() -> Self {
        Self::with_defaults()
    }
}
#[async_trait]
impl Middleware for SiteMiddleware {
async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
if !self.config.enabled {
return handler.handle(request).await;
}
let host = match self.get_host(&request) {
Some(h) => h,
None => {
if self.config.fallback_enabled {
let default_site = self
.registry
.default_site
.read()
.unwrap_or_else(|e| e.into_inner())
.clone();
if let Some(site) = default_site {
let mut response = match handler.handle(request).await {
Ok(resp) => resp,
Err(e) => Response::from(e),
};
if let (Ok(header_name), Ok(header_value)) = (
SITE_ID_HEADER.parse::<HeaderName>(),
site.id.to_string().parse(),
) {
response.headers.insert(header_name, header_value);
}
return Ok(response);
}
}
return handler.handle(request).await;
}
};
let mut site = self.registry.get_by_domain(&host);
if site.is_none() && self.config.fallback_enabled {
site = self.registry.default_site();
}
let mut response = match handler.handle(request).await {
Ok(resp) => resp,
Err(e) => Response::from(e),
};
if let Some(site) = site
&& let (Ok(header_name), Ok(header_value)) = (
SITE_ID_HEADER.parse::<HeaderName>(),
site.id.to_string().parse(),
) {
response.headers.insert(header_name, header_value);
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Handler that always answers `200 OK` with body `"OK"`.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from("OK")))
        }
    }

    /// Builds a `GET /test` HTTP/1.1 request, with a `Host` header when `host` is `Some`.
    fn request_with_host(host: Option<&str>) -> Request {
        let mut headers = HeaderMap::new();
        if let Some(host) = host {
            headers.insert(hyper::header::HOST, host.parse().unwrap());
        }
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Shorthand for `Site::new` with borrowed strings.
    fn site(id: u64, domain: &str, name: &str) -> Site {
        Site::new(id, domain.to_string(), name.to_string())
    }

    #[tokio::test]
    async fn test_site_detection() {
        let middleware = SiteMiddleware::new(SiteConfig::new());
        middleware
            .registry
            .register(site(1, "example.com", "Example Site"));
        let response = middleware
            .process(request_with_host(Some("example.com")), Arc::new(TestHandler))
            .await
            .unwrap();
        assert!(response.headers.contains_key(SITE_ID_HEADER));
        assert_eq!(response.headers.get(SITE_ID_HEADER).unwrap(), "1");
    }

    #[tokio::test]
    async fn test_www_subdomain_handling() {
        let middleware = SiteMiddleware::new(SiteConfig::new());
        middleware.registry.register(site(1, "example.com", "Example"));
        let response = middleware
            .process(request_with_host(Some("www.example.com")), Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.headers.get(SITE_ID_HEADER).unwrap(), "1");
    }

    #[tokio::test]
    async fn test_default_site_fallback() {
        let middleware = SiteMiddleware::new(SiteConfig::new());
        middleware
            .registry
            .set_default(site(99, "default.com", "Default"));
        let response = middleware
            .process(request_with_host(Some("unknown.com")), Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.headers.get(SITE_ID_HEADER).unwrap(), "99");
    }

    #[tokio::test]
    async fn test_no_fallback() {
        let middleware = SiteMiddleware::new(SiteConfig::new().without_fallback());
        middleware
            .registry
            .set_default(site(99, "default.com", "Default"));
        let response = middleware
            .process(request_with_host(Some("unknown.com")), Arc::new(TestHandler))
            .await
            .unwrap();
        assert!(!response.headers.contains_key(SITE_ID_HEADER));
    }

    #[tokio::test]
    async fn test_multiple_sites() {
        let middleware = Arc::new(SiteMiddleware::new(SiteConfig::new()));
        middleware.registry.register(site(1, "site1.com", "Site 1"));
        middleware.registry.register(site(2, "site2.com", "Site 2"));
        let handler = Arc::new(TestHandler);

        let response1 = middleware
            .process(request_with_host(Some("site1.com")), handler.clone())
            .await
            .unwrap();
        assert_eq!(response1.headers.get(SITE_ID_HEADER).unwrap(), "1");

        let response2 = middleware
            .process(request_with_host(Some("site2.com")), handler)
            .await
            .unwrap();
        assert_eq!(response2.headers.get(SITE_ID_HEADER).unwrap(), "2");
    }

    #[tokio::test]
    async fn test_disabled_middleware() {
        let middleware = SiteMiddleware::new(SiteConfig::new().disabled());
        middleware.registry.register(site(1, "example.com", "Example"));
        let response = middleware
            .process(request_with_host(Some("example.com")), Arc::new(TestHandler))
            .await
            .unwrap();
        assert!(!response.headers.contains_key(SITE_ID_HEADER));
    }

    #[tokio::test]
    async fn test_port_handling() {
        let middleware = SiteMiddleware::new(SiteConfig::new());
        middleware.registry.register(site(1, "example.com", "Example"));
        let response = middleware
            .process(request_with_host(Some("example.com:8080")), Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.headers.get(SITE_ID_HEADER).unwrap(), "1");
    }

    #[tokio::test]
    async fn test_no_host_header() {
        let middleware = SiteMiddleware::new(SiteConfig::new());
        let response = middleware
            .process(request_with_host(None), Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
    }

    #[test]
    fn test_all_sites() {
        let registry = SiteRegistry::new();
        registry.register(site(1, "site1.com", "Site 1"));
        registry.register(site(2, "site2.com", "Site 2"));
        assert_eq!(registry.all().len(), 2);
    }

    #[test]
    fn test_clear_registry() {
        let registry = SiteRegistry::new();
        registry.register(site(1, "site1.com", "Site 1"));
        registry.set_default(site(99, "default.com", "Default"));
        registry.clear();
        assert_eq!(registry.all().len(), 0);
        assert!(registry.default_site().is_none());
    }

    #[tokio::test]
    async fn test_default_middleware() {
        let middleware = SiteMiddleware::default();
        let response = middleware
            .process(request_with_host(None), Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
    }

    #[rstest::rstest]
    fn test_rwlock_poison_recovery_site_registry() {
        let registry = Arc::new(SiteRegistry::new());
        registry.register(site(1, "example.com", "Example"));

        // Poison the `sites` lock by panicking while holding its write guard.
        let registry_clone = Arc::clone(&registry);
        let _ = std::thread::spawn(move || {
            let _guard = registry_clone.sites.write().unwrap();
            panic!("intentional panic to poison lock");
        })
        .join();

        registry.register(site(2, "test.com", "Test"));
        assert!(registry.get_by_domain("example.com").is_some());
        assert!(registry.get_by_domain("test.com").is_some());
        assert_eq!(registry.all().len(), 2);

        registry.set_default(site(99, "default.com", "Default"));
        assert!(registry.default_site().is_some());

        registry.clear();
        assert_eq!(registry.all().len(), 0);
    }
}