use async_trait::async_trait;
use reinhardt_core::exception::Result;
use std::sync::Arc;
use crate::{Request, Response};
/// An asynchronous request handler: turns a [`Request`] into a [`Response`].
///
/// The `Send + Sync` supertraits allow handlers to be shared across threads behind an
/// [`Arc`]; this module stores them as `Arc<dyn Handler>`.
#[async_trait]
pub trait Handler: Send + Sync {
/// Handles `request`, producing a response or an error.
async fn handle(&self, request: Request) -> Result<Response>;
}
/// Lets a shared handler (including `Arc<dyn Handler>`) be used anywhere a [`Handler`]
/// is expected by forwarding to the value the `Arc` points at.
#[async_trait]
impl<T: Handler + ?Sized> Handler for Arc<T> {
    async fn handle(&self, request: Request) -> Result<Response> {
        // Deref-coerce `&Arc<T>` to `&T` so the inner implementation is the one invoked.
        let inner: &T = self;
        inner.handle(request).await
    }
}
/// A layer that runs around the rest of a handler pipeline.
#[async_trait]
pub trait Middleware: Send + Sync {
/// Processes `request`; `next` is the remainder of the pipeline (further middleware and,
/// ultimately, the terminal handler).
///
/// To short-circuit, return a response without calling `next`. The remaining middleware
/// and the terminal handler are then never run.
async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
/// Returns whether this middleware takes part in handling `request`.
///
/// `MiddlewareChain` checks this once per request, before any middleware runs, and leaves
/// the middleware out of the pipeline when it returns `false`. Defaults to `true`.
fn should_continue(&self, _request: &Request) -> bool {
true
}
}
/// A terminal [`Handler`] wrapped by an ordered list of [`Middleware`].
///
/// `MiddlewareChain` itself implements [`Handler`], so a chain can be passed anywhere an
/// `Arc<dyn Handler>` is accepted, including as the terminal handler of another chain.
pub struct MiddlewareChain {
/// Middleware in registration order; the first entry is the outermost layer.
middlewares: Vec<Arc<dyn Middleware>>,
/// Handler reached at the innermost end of the pipeline.
handler: Arc<dyn Handler>,
}
impl MiddlewareChain {
    /// Creates a chain with no middleware that dispatches straight to `handler`.
    pub fn new(handler: Arc<dyn Handler>) -> Self {
        let middlewares = Vec::new();
        Self { handler, middlewares }
    }

    /// Builder-style form of [`MiddlewareChain::add_middleware`]: appends `middleware` and
    /// hands the chain back so registrations can be chained.
    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
        self.add_middleware(middleware);
        self
    }

    /// Appends `middleware` after every previously registered one.
    ///
    /// Earlier registrations wrap later ones, so they see the request first and the
    /// response last.
    pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
        self.middlewares.push(middleware);
    }
}
#[async_trait]
impl Handler for MiddlewareChain {
    /// Runs `request` through every applicable middleware and then the terminal handler.
    ///
    /// Each middleware's [`Middleware::should_continue`] is evaluated once against the
    /// incoming request, before anything runs. Middleware returning `false` is left out of
    /// the pipeline for this request. The remaining middleware executes in registration
    /// order: the first one added is the outermost layer and sees the request first.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by a middleware or by the terminal handler.
    async fn handle(&self, request: Request) -> Result<Response> {
        // Fast path: nothing to wrap, so skip building a pipeline at all.
        if self.middlewares.is_empty() {
            return self.handler.handle(request).await;
        }

        // Build the pipeline inside-out. Iterating in reverse makes the last-registered
        // middleware wrap the handler first, leaving the first-registered one outermost.
        // `fold` consumes the iterator here, so every `should_continue` check (and the
        // borrow of `request` it needs) finishes before the request is moved below.
        let pipeline = self
            .middlewares
            .iter()
            .rev()
            .filter(|mw| mw.should_continue(&request))
            .fold(Arc::clone(&self.handler), |next, mw| -> Arc<dyn Handler> {
                Arc::new(ConditionalComposedHandler {
                    middleware: Arc::clone(mw),
                    next,
                })
            });

        pipeline.handle(request).await
    }
}
/// One layer of a per-request pipeline: runs `middleware`, handing it `next` as the rest
/// of the chain.
///
/// Built by `MiddlewareChain`'s `Handler` implementation, only for middleware whose
/// `should_continue` returned `true` for the current request.
struct ConditionalComposedHandler {
/// The middleware executed at this layer.
middleware: Arc<dyn Middleware>,
/// The remaining pipeline: further layers or the terminal handler.
next: Arc<dyn Handler>,
}
#[async_trait]
impl Handler for ConditionalComposedHandler {
async fn handle(&self, request: Request) -> Result<Response> {
let response = self.middleware.process(request, self.next.clone()).await?;
if response.should_stop_chain() {
return Ok(response);
}
Ok(response)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};

    /// Builds a bodiless HTTP/1.1 `GET` request for `uri`.
    fn get_request(uri: &'static str) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Decodes a response body as UTF-8, failing the test if it is not valid UTF-8.
    fn body_text(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).expect("response body should be valid UTF-8")
    }

    /// Returns a fresh `200 OK` response whose body is `prefix` followed by `response`'s body.
    fn prefixed(prefix: &str, response: &Response) -> Response {
        let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
        Response::ok().with_body(format!("{}{}", prefix, current_body))
    }

    /// Terminal handler that always answers `200 OK` with a fixed body.
    struct MockHandler {
        response_body: String,
    }

    #[async_trait]
    impl Handler for MockHandler {
        async fn handle(&self, _request: Request) -> Result<Response> {
            Ok(Response::ok().with_body(self.response_body.clone()))
        }
    }

    /// Always-active middleware: calls `next`, then prepends `prefix` to the body.
    struct MockMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for MockMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            Ok(prefixed(&self.prefix, &response))
        }
    }

    /// Like `MockMiddleware`, but only participates for paths under `/api/`.
    struct ConditionalMiddleware {
        prefix: String,
    }

    #[async_trait]
    impl Middleware for ConditionalMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            let response = next.handle(request).await?;
            Ok(prefixed(&self.prefix, &response))
        }

        fn should_continue(&self, request: &Request) -> bool {
            request.uri.path().starts_with("/api/")
        }
    }

    /// When `should_stop` is set, answers `401` itself without ever calling `next`.
    struct ShortCircuitMiddleware {
        should_stop: bool,
    }

    #[async_trait]
    impl Middleware for ShortCircuitMiddleware {
        async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
            if self.should_stop {
                return Ok(Response::unauthorized()
                    .with_body("Auth required")
                    .with_stop_chain(true));
            }
            next.handle(request).await
        }
    }

    #[tokio::test]
    async fn test_handler_basic() {
        let handler = MockHandler {
            response_body: "Hello".to_string(),
        };
        let response = handler.handle(get_request("/")).await.unwrap();
        assert_eq!(body_text(&response), "Hello");
    }

    #[tokio::test]
    async fn test_arc_handler_delegates_to_inner() {
        let shared = Arc::new(MockHandler {
            response_body: "Shared".to_string(),
        });
        // Call through the blanket `impl Handler for Arc<T>` explicitly.
        let response = Handler::handle(&shared, get_request("/")).await.unwrap();
        assert_eq!(body_text(&response), "Shared");
    }

    #[test]
    fn test_should_continue_defaults_to_true() {
        let middleware = MockMiddleware {
            prefix: String::new(),
        };
        assert!(middleware.should_continue(&get_request("/")));
    }

    #[tokio::test]
    async fn test_middleware_basic() {
        let handler = Arc::new(MockHandler {
            response_body: "World".to_string(),
        });
        let middleware = MockMiddleware {
            prefix: "Hello, ".to_string(),
        };
        let response = middleware.process(get_request("/"), handler).await.unwrap();
        assert_eq!(body_text(&response), "Hello, World");
    }

    #[tokio::test]
    async fn test_middleware_chain_empty() {
        let handler = Arc::new(MockHandler {
            response_body: "Test".to_string(),
        });
        let chain = MiddlewareChain::new(handler);
        let response = chain.handle(get_request("/")).await.unwrap();
        assert_eq!(body_text(&response), "Test");
    }

    #[tokio::test]
    async fn test_middleware_chain_single() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler".to_string(),
        });
        let middleware1 = Arc::new(MockMiddleware {
            prefix: "MW1:".to_string(),
        });
        let chain = MiddlewareChain::new(handler).with_middleware(middleware1);
        let response = chain.handle(get_request("/")).await.unwrap();
        assert_eq!(body_text(&response), "MW1:Handler");
    }

    #[tokio::test]
    async fn test_middleware_chain_multiple() {
        let handler = Arc::new(MockHandler {
            response_body: "Data".to_string(),
        });
        let middleware1 = Arc::new(MockMiddleware {
            prefix: "M1:".to_string(),
        });
        let middleware2 = Arc::new(MockMiddleware {
            prefix: "M2:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(middleware1)
            .with_middleware(middleware2);
        let response = chain.handle(get_request("/")).await.unwrap();
        // First-registered middleware is outermost, so its prefix ends up first.
        assert_eq!(body_text(&response), "M1:M2:Data");
    }

    #[tokio::test]
    async fn test_middleware_chain_add_middleware() {
        let handler = Arc::new(MockHandler {
            response_body: "Result".to_string(),
        });
        let middleware = Arc::new(MockMiddleware {
            prefix: "Prefix:".to_string(),
        });
        let mut chain = MiddlewareChain::new(handler);
        chain.add_middleware(middleware);
        let response = chain.handle(get_request("/")).await.unwrap();
        assert_eq!(body_text(&response), "Prefix:Result");
    }

    #[tokio::test]
    async fn test_middleware_conditional_skip() {
        let handler = Arc::new(MockHandler {
            response_body: "Response".to_string(),
        });
        let conditional_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let chain = MiddlewareChain::new(handler).with_middleware(conditional_mw);

        let response = chain.handle(get_request("/api/users")).await.unwrap();
        assert_eq!(body_text(&response), "API:Response");

        // Outside `/api/` the conditional middleware is skipped entirely.
        let response = chain.handle(get_request("/public")).await.unwrap();
        assert_eq!(body_text(&response), "Response");
    }

    #[tokio::test]
    async fn test_middleware_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });
        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: true });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);
        let response = chain.handle(get_request("/")).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::UNAUTHORIZED);
        assert_eq!(body_text(&response), "Auth required");
    }

    #[tokio::test]
    async fn test_middleware_no_short_circuit() {
        let handler = Arc::new(MockHandler {
            response_body: "Handler Response".to_string(),
        });
        let short_circuit_mw = Arc::new(ShortCircuitMiddleware { should_stop: false });
        let normal_mw = Arc::new(MockMiddleware {
            prefix: "Normal:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(short_circuit_mw)
            .with_middleware(normal_mw);
        let response = chain.handle(get_request("/")).await.unwrap();
        assert_eq!(response.status, hyper::StatusCode::OK);
        assert_eq!(body_text(&response), "Normal:Handler Response");
    }

    #[tokio::test]
    async fn test_middleware_multiple_conditions() {
        let handler = Arc::new(MockHandler {
            response_body: "Base".to_string(),
        });
        let api_mw = Arc::new(ConditionalMiddleware {
            prefix: "API:".to_string(),
        });
        let always_mw = Arc::new(MockMiddleware {
            prefix: "Always:".to_string(),
        });
        let chain = MiddlewareChain::new(handler)
            .with_middleware(api_mw)
            .with_middleware(always_mw);

        let response = chain.handle(get_request("/api/test")).await.unwrap();
        assert_eq!(body_text(&response), "API:Always:Base");

        let response = chain.handle(get_request("/public")).await.unwrap();
        assert_eq!(body_text(&response), "Always:Base");
    }

    #[tokio::test]
    async fn test_response_should_stop_chain() {
        let response = Response::ok();
        assert!(!response.should_stop_chain());
        let stopping_response = Response::unauthorized().with_stop_chain(true);
        assert!(stopping_response.should_stop_chain());
    }
}