use async_trait::async_trait;
use reinhardt_core::exception::Result;
use std::sync::Arc;
use crate::{Request, Response};
/// Asynchronous request handler: consumes a [`Request`] and produces a
/// [`Response`].
///
/// The `Send + Sync` supertraits allow implementors to be shared as
/// `Arc<dyn Handler>`, which is how [`MiddlewareChain`] stores and composes
/// them.
#[async_trait]
pub trait Handler: Send + Sync {
/// Handles `request`.
///
/// # Errors
///
/// Returns an error when the request cannot be served. How that error is
/// surfaced depends on the caller; for example, a [`MiddlewareChain`] with at
/// least one registered middleware converts it into a response with
/// `Response::from`.
async fn handle(&self, request: Request) -> Result<Response>;
}
/// Lets an `Arc`-wrapped handler be used anywhere a [`Handler`] is expected
/// by forwarding every call to the shared inner value.
#[async_trait]
impl<T: Handler + ?Sized> Handler for Arc<T> {
    async fn handle(&self, request: Request) -> Result<Response> {
        // Deref-coerce `&Arc<T>` to `&T` so the call dispatches to the inner
        // handler instead of recursing back into this impl.
        let inner: &T = self;
        inner.handle(request).await
    }
}
/// A component that sits in front of a [`Handler`]; it can inspect or alter
/// the request before forwarding it and the response after it comes back.
#[async_trait]
pub trait Middleware: Send + Sync {
/// Processes `request`, normally by calling `next` and post-processing its
/// response. A middleware may instead answer directly without calling `next`,
/// which short-circuits every layer inside it.
///
/// # Errors
///
/// Returns an error if processing fails. Inside a [`MiddlewareChain`] such an
/// error is converted into a [`Response`] with `Response::from`.
async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
/// Returns whether this middleware should run for `request`.
///
/// [`MiddlewareChain`] asks every registered middleware once, against the
/// incoming request and before any middleware runs; those returning `false`
/// are left out of the chain built for that request. Defaults to `true`.
fn should_continue(&self, _request: &Request) -> bool {
true
}
}
/// An ordered stack of [`Middleware`] in front of a terminal [`Handler`].
///
/// Middlewares run in registration order: the first one added is the
/// outermost layer, so it sees the request first and the response last.
pub struct MiddlewareChain {
    /// Registered middlewares, outermost first.
    middlewares: Vec<Arc<dyn Middleware>>,
    /// The handler reached once every active middleware has delegated inward.
    handler: Arc<dyn Handler>,
}

impl MiddlewareChain {
    /// Creates a chain with no middleware that forwards straight to `handler`.
    pub fn new(handler: Arc<dyn Handler>) -> Self {
        let middlewares = Vec::new();
        Self {
            middlewares,
            handler,
        }
    }

    /// Builder-style form of [`add_middleware`](Self::add_middleware): appends
    /// `middleware` as the new innermost layer and hands the chain back.
    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
        self.add_middleware(middleware);
        self
    }

    /// Appends `middleware` in place; it runs inside every middleware that was
    /// registered before it.
    pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
        self.middlewares.push(middleware);
    }
}
#[async_trait]
impl Handler for MiddlewareChain {
    /// Runs `request` through the active middlewares and then the terminal
    /// handler.
    ///
    /// A middleware is active for this request when its
    /// [`Middleware::should_continue`] returns `true`. Every registered
    /// middleware is asked once, against the incoming request, before any of
    /// them runs. Active middlewares execute in registration order, with the
    /// first registered being outermost.
    ///
    /// # Errors
    ///
    /// With no registered middleware, the terminal handler is called directly
    /// and its error is returned unchanged. Otherwise, errors from the handler
    /// or from any middleware are converted into responses with
    /// `Response::from`. Enclosing middlewares can then post-process them, and
    /// this method returns `Ok`.
    async fn handle(&self, request: Request) -> Result<Response> {
        // NOTE(review): this fast path propagates handler errors as `Err`,
        // whereas a chain whose middlewares are all skipped by
        // `should_continue` converts them into responses. Confirm which
        // behavior callers rely on before unifying the two.
        if self.middlewares.is_empty() {
            return self.handler.handle(request).await;
        }

        // Innermost layer: turn handler errors into responses so middlewares
        // can still post-process them (e.g. attach security headers).
        let mut chain: Arc<dyn Handler> = Arc::new(ErrorToResponseHandler {
            inner: Arc::clone(&self.handler),
        });

        // Wrap from the last registration to the first so the first one ends
        // up outermost. The filtered iterator is consumed lazily (no
        // intermediate Vec), and `chain` is moved into each new layer rather
        // than cloned.
        for middleware in self
            .middlewares
            .iter()
            .rev()
            .filter(|middleware| middleware.should_continue(&request))
        {
            chain = Arc::new(ConditionalComposedHandler {
                middleware: Arc::clone(middleware),
                next: chain,
            });
        }

        chain.handle(request).await
    }
}
/// Wraps another [`Middleware`] and skips it for configured request paths.
///
/// A pattern that ends in `/` excludes every path starting with that pattern.
/// Any other pattern excludes only the identical path.
pub struct ExcludeMiddleware {
    /// The middleware applied to requests whose path is not excluded.
    inner: Arc<dyn Middleware>,
    /// Exclusion patterns, in the order they were registered.
    exclusions: Vec<String>,
}

impl ExcludeMiddleware {
    /// Wraps `inner` with an initially empty exclusion list.
    pub fn new(inner: Arc<dyn Middleware>) -> Self {
        Self {
            exclusions: Vec::new(),
            inner,
        }
    }

    /// Builder-style: registers `pattern` and returns `self`.
    pub fn add_exclusion(mut self, pattern: &str) -> Self {
        self.add_exclusion_mut(pattern);
        self
    }

    /// Registers `pattern` in place.
    pub fn add_exclusion_mut(&mut self, pattern: &str) {
        self.exclusions.push(pattern.to_owned());
    }

    /// Returns `true` when `path` matches at least one exclusion pattern
    /// (prefix match for patterns ending in `/`, exact match otherwise).
    fn is_excluded(&self, path: &str) -> bool {
        for pattern in &self.exclusions {
            let pattern = pattern.as_str();
            let hit = if pattern.ends_with('/') {
                path.starts_with(pattern)
            } else {
                path == pattern
            };
            if hit {
                return true;
            }
        }
        false
    }
}
#[async_trait]
impl Middleware for ExcludeMiddleware {
    /// Delegates straight to the wrapped middleware; whether it runs at all is
    /// decided earlier, in `should_continue`.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let inner = &self.inner;
        inner.process(request, next).await
    }

    /// Runs only when the request path is not excluded and the wrapped
    /// middleware itself agrees to run. The inner check is skipped entirely
    /// for excluded paths.
    fn should_continue(&self, request: &Request) -> bool {
        !self.is_excluded(request.uri.path()) && self.inner.should_continue(request)
    }
}
/// Innermost layer of a non-empty [`MiddlewareChain`]: turns an error from
/// the terminal handler into a response so surrounding middlewares still
/// receive something to post-process.
struct ErrorToResponseHandler {
    /// The terminal handler being guarded.
    inner: Arc<dyn Handler>,
}

#[async_trait]
impl Handler for ErrorToResponseHandler {
    /// Always returns `Ok`; an inner error becomes `Response::from(error)`.
    async fn handle(&self, request: Request) -> Result<Response> {
        let outcome = self.inner.handle(request).await;
        Ok(outcome.unwrap_or_else(Response::from))
    }
}
struct ConditionalComposedHandler {
middleware: Arc<dyn Middleware>,
next: Arc<dyn Handler>,
}
#[async_trait]
impl Handler for ConditionalComposedHandler {
async fn handle(&self, request: Request) -> Result<Response> {
let response = match self.middleware.process(request, self.next.clone()).await {
Ok(response) => response,
Err(e) => Response::from(e),
};
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};

    // ---- shared helpers ---------------------------------------------------

    /// Builds a bodiless HTTP/1.1 `GET` request for `path`.
    fn create_request_with_path(path: &str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(path)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Builds a bodiless HTTP/1.1 `GET` request for `/`.
    fn create_test_request() -> Request {
        create_request_with_path("/")
    }

    /// Decodes a response body as UTF-8, panicking if it is not valid text.
    fn body_text(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).expect("response body should be valid UTF-8")
    }

    /// Forwards `request` to `next` and returns a 200 response whose body is
    /// `prefix` followed by the downstream body. An invalid UTF-8 downstream
    /// body is treated as empty.
    async fn prepend_to_body(
        prefix: &str,
        request: Request,
        next: Arc<dyn Handler>,
    ) -> Result<Response> {
        let response = next.handle(request).await?;
        let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
        let new_body = format!("{}{}", prefix, current_body);
        Ok(Response::ok().with_body(new_body))
    }

    // ---- test doubles -----------------------------------------------------

    /// Terminal handler that always answers 200 with a fixed body.
    struct MockHandler {
        response_body: String,
    }

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(self.response_body.clone()))
        }
    }

    /// Middleware that always runs and prefixes the downstream body.
    struct MockMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for MockMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            prepend_to_body(&self.prefix, request, next).await
        }
    }

    /// Like `MockMiddleware`, but only runs for paths under `/api/`.
    struct ConditionalMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for ConditionalMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            prepend_to_body(&self.prefix, request, next).await
        }

        fn should_continue(&self, request: &Request) -> bool {
            request.uri.path().starts_with("/api/")
        }
    }

    /// Middleware that answers 401 without calling `next` when `should_stop`
    /// is set, and is a pass-through otherwise.
    struct ShortCircuitMiddleware {
        should_stop: bool,
    }

    #[async_trait]
    impl Middleware for ShortCircuitMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            if self.should_stop {
                return Ok(Response::unauthorized()
                    .with_body("Auth required")
                    .with_stop_chain(true));
            }
            next.handle(request).await
        }
    }

    /// Terminal handler that always fails with `Error::NotFound`.
    struct NotFoundHandler;

    #[async_trait]
    impl Handler for NotFoundHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_core::exception::Error::NotFound(
                "not found".into(),
            ))
        }
    }

    /// Terminal handler that always fails with `Error::Authentication`.
    struct UnauthorizedHandler;

    #[async_trait]
    impl Handler for UnauthorizedHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Err(reinhardt_core::exception::Error::Authentication(
                "unauthorized".into(),
            ))
        }
    }

    /// Middleware that adds one header to whatever `next` returns.
    struct HeaderAddingMiddleware {
        header_name: &'static str,
        header_value: &'static str,
    }

    #[async_trait]
    impl Middleware for HeaderAddingMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            Ok(response.with_header(self.header_name, self.header_value))
        }
    }

    /// Middleware that always fails with `Error::Authorization`.
    struct RejectingMiddleware;

    #[async_trait]
    impl Middleware for RejectingMiddleware {
        async fn process(&self, _request: Request, _next: Arc<dyn Handler>) -> Result<Response> {
            Err(reinhardt_core::exception::Error::Authorization(
                "CSRF check failed".into(),
            ))
        }
    }

    /// Middleware that forwards to `next` untouched.
    struct PassthroughMiddleware;

    #[async_trait]
    impl Middleware for PassthroughMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            next.handle(request).await
        }
    }

    // ---- handler / chain basics -------------------------------------------

    #[tokio::test]
    async fn test_handler_basic() {
        let handler = MockHandler {
            response_body: "Hello".to_string(),
        };
        let response = handler.handle(create_test_request()).await.unwrap();
        assert_eq!(body_text(&response), "Hello");
    }

    #[tokio::test]
    async fn test_middleware_basic() {
        let handler = Arc::new(MockHandler {
            response_body: "World".to_string(),
        });
        let middleware = MockMiddleware {
            prefix: "Hello, ".to_string(),
        };
        let response = middleware
            .process(create_test_request(), handler)
            .await
            .unwrap();
        assert_eq!(body_text(&response), "Hello, World");
    }

    #[tokio::test]
    async fn test_middleware_chain_empty() {
        let handler = Arc::new(MockHandler {
            response_body: "Test".to_string(),
        });
        let chain = MiddlewareChain::new(handler);
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(body_text(&response), "Test");
    }

    #[tokio::test]
    async fn test_middleware_chain_single() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler".to_string(),
        });
        let middleware1 = Arc::new(MockMiddleware {
            prefix: "MW1:".to_string(),
        });
        let chain = MiddlewareChain::new(handler).with_middleware(middleware1);
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(body_text(&response), "MW1:Handler");
    }

    #[tokio::test]
    async fn test_middleware_chain_multiple() {
        let handler = Arc::new(MockHandler {
            response_body: "Data".to_string(),
        });
        let middleware1 = Arc::new(MockMiddleware {
            prefix: "M1:".to_string(),
        });
        let middleware2 = Arc::new(MockMiddleware {
            prefix: "M2:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(middleware1)
            .with_middleware(middleware2);
        let response = chain.handle(create_test_request()).await.unwrap();
        // First registered middleware is outermost, so its prefix comes first.
        assert_eq!(body_text(&response), "M1:M2:Data");
    }

    #[tokio::test]
    async fn test_middleware_chain_add_middleware() {
        let handler = Arc::new(MockHandler {
            response_body: "Result".to_string(),
        });
        let middleware = Arc::new(MockMiddleware {
            prefix: "Prefix:".to_string(),
        });
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(middleware);
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(body_text(&response), "Prefix:Result");
    }

    // ---- conditional execution and short-circuiting -----------------------

    #[tokio::test]
    async fn test_middleware_conditional_skip() {
        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });
        let conditional_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let chain = MiddlewareChain::new(handler).with_middleware(conditional_mw);

        let response = chain
            .handle(create_request_with_path("/api/users"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "API:Response");

        // Non-API paths skip the conditional middleware entirely.
        let response = chain
            .handle(create_request_with_path("/public"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "Response");
    }

    #[tokio::test]
    async fn test_middleware_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });
        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: true });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
        assert_eq!(body_text(&response), "Auth required");
    }

    #[tokio::test]
    async fn test_middleware_no_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });
        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: false });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::OK);
        assert_eq!(body_text(&response), "Normal:Handler Response");
    }

    #[tokio::test]
    async fn test_middleware_multiple_conditions() {
        let handler = Arc::new(MockHandler {
            response_body: "Base".to_string(),
        });
        let api_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let always_mw = Arc::new(MockMiddleware {
            prefix: "Always:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(api_mw)
            .with_middleware(always_mw);

        let response = chain
            .handle(create_request_with_path("/api/test"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "API:Always:Base");

        let response = chain
            .handle(create_request_with_path("/public"))
            .await
            .unwrap();
        assert_eq!(body_text(&response), "Always:Base");
    }

    #[test]
    fn test_response_should_stop_chain() {
        let response = Response::ok();
        assert!(!response.should_stop_chain());
        let stopping_response = Response::unauthorized().with_stop_chain(true);
        assert!(stopping_response.should_stop_chain());
    }

    // ---- ExcludeMiddleware ------------------------------------------------

    #[rstest::rstest]
    #[case("/api/auth/login", true)]
    #[case("/api/auth/register", true)]
    #[case("/api/auth/", true)]
    #[case("/api/users", false)]
    #[case("/public", false)]
    fn test_exclude_middleware_prefix_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");
        let result = exclude_mw.should_continue(&create_request_with_path(path));
        assert_eq!(result, !should_exclude);
    }

    #[rstest::rstest]
    #[case("/health", true)]
    #[case("/health/check", false)]
    #[case("/healthz", false)]
    #[case("/api/health", false)]
    fn test_exclude_middleware_exact_match(#[case] path: &str, #[case] should_exclude: bool) {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/health");
        let result = exclude_mw.should_continue(&create_request_with_path(path));
        assert_eq!(result, !should_exclude);
    }

    #[rstest::rstest]
    fn test_exclude_middleware_no_match_passes_through() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner)
            .add_exclusion("/api/auth/")
            .add_exclusion("/health");
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_exclude_middleware_delegates_process() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "INNER:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/excluded/");
        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });
        let response = exclude_mw
            .process(create_request_with_path("/api/test"), handler)
            .await
            .unwrap();
        assert_eq!(body_text(&response), "INNER:Response");
    }

    #[rstest::rstest]
    fn test_exclude_middleware_multiple_exclusions() {
        let inner: Arc<dyn Middleware> = Arc::new(MockMiddleware {
            prefix: "MW:".to_string(),
        });
        let mut exclude_mw = ExcludeMiddleware::new(inner);
        exclude_mw.add_exclusion_mut("/api/auth/");
        exclude_mw.add_exclusion_mut("/admin/");
        exclude_mw.add_exclusion_mut("/health");
        assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
        assert!(!exclude_mw.should_continue(&create_request_with_path("/admin/dashboard")));
        assert!(!exclude_mw.should_continue(&create_request_with_path("/health")));
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    #[rstest::rstest]
    fn test_exclude_middleware_respects_inner_should_continue() {
        let inner: Arc<dyn Middleware> = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let exclude_mw = ExcludeMiddleware::new(inner).add_exclusion("/api/auth/");
        assert!(!exclude_mw.should_continue(&create_request_with_path("/api/auth/login")));
        assert!(!exclude_mw.should_continue(&create_request_with_path("/public")));
        assert!(exclude_mw.should_continue(&create_request_with_path("/api/users")));
    }

    // ---- error-to-response conversion inside the chain --------------------

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_post_processing_runs_on_handler_error() {
        let handler: Arc<dyn Handler> = Arc::new(NotFoundHandler);
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(Arc::new(HeaderAddingMiddleware {
            header_name: "X-Custom-Security",
            header_value: "applied",
        }));
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::NOT_FOUND);
        assert_eq!(
            response
                .headers
                .get("X-Custom-Security")
                .map(|v| v.to_str().unwrap()),
            Some("applied")
        );
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_post_processing_runs_on_middleware_error() {
        let handler = Arc::new(MockHandler {
            response_body: "OK".into(),
        });
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(Arc::new(HeaderAddingMiddleware {
            header_name: "X-Frame-Options",
            header_value: "DENY",
        }));
        chain.add_middleware(Arc::new(RejectingMiddleware));
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::FORBIDDEN);
        assert_eq!(
            response
                .headers
                .get("X-Frame-Options")
                .map(|v| v.to_str().unwrap()),
            Some("DENY")
        );
    }

    #[rstest::rstest]
    #[tokio::test]
    async fn test_chain_error_preserves_correct_status_code() {
        let handler: Arc<dyn Handler> = Arc::new(UnauthorizedHandler);
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(Arc::new(PassthroughMiddleware));
        let response = chain.handle(create_test_request()).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
    }
}