use hyper::StatusCode;
use reinhardt_core::signals::{
RequestFinishedEvent, RequestStartedEvent, request_finished, request_started,
};
use reinhardt_http::Handler;
use reinhardt_http::{Request, Response};
use reinhardt_urls::routers::DefaultRouter;
use std::sync::Arc;
use tracing::{debug, error, trace, warn};
use crate::DispatchError;
/// Entry point that dispatches incoming requests through an optional router.
///
/// Each dispatch is bracketed by the `request_started` / `request_finished`
/// signals, and routing failures are mapped to HTTP responses.
pub struct BaseHandler {
    /// Whether the handler is flagged as asynchronous. This is read through
    /// `is_async()`, so no `dead_code` suppression is needed.
    is_async: bool,
    /// Router used to resolve requests; `None` means every request is a 404.
    router: Option<Arc<DefaultRouter>>,
}
impl BaseHandler {
    /// Creates a handler with no router attached.
    ///
    /// Every request dispatched through a router-less handler resolves to an
    /// empty `404 Not Found` response.
    pub fn new() -> Self {
        Self {
            is_async: true,
            router: None,
        }
    }

    /// Creates a handler that dispatches requests through `router`.
    pub fn with_router(router: Arc<DefaultRouter>) -> Self {
        Self {
            is_async: true,
            router: Some(router),
        }
    }

    /// Dispatches `request` and returns the resulting response.
    ///
    /// The `request_started` signal is sent before dispatch and the
    /// `request_finished` signal after it, whether or not dispatch succeeded.
    /// Failure to deliver either signal is logged as a warning and does not
    /// affect the returned result.
    ///
    /// # Errors
    ///
    /// Returns [`DispatchError::View`] when the router's handler fails with
    /// any error other than `NotFound`. Unmatched routes and a missing router
    /// both yield `Ok` with a `404 Not Found` response instead.
    pub async fn handle_request(
        &self,
        request: Request,
    ) -> std::result::Result<Response, DispatchError> {
        trace!("Handling request: {:?}", request.uri);

        let event = RequestStartedEvent::new();
        if let Err(e) = request_started().send(event).await {
            warn!("Failed to send request_started signal: {}", e);
        }

        let response = Self::get_response_async(request, self.router.as_ref()).await;

        // Sent unconditionally so listeners observe a finish for every start,
        // even when dispatch produced an error.
        let event = RequestFinishedEvent::new();
        if let Err(e) = request_finished().send(event).await {
            warn!("Failed to send request_finished signal: {}", e);
        }

        response
    }

    /// Resolves `request` to a response using `router`, if one is configured.
    ///
    /// * No router configured: `Ok` with an empty `404 Not Found`.
    /// * Router reports `NotFound`: `Ok` with an empty `404 Not Found`.
    /// * Any other router error: `Err(DispatchError::View)` carrying the
    ///   error's `Display` text. The error is logged here; callers are
    ///   responsible for not exposing that text to clients.
    async fn get_response_async(
        request: Request,
        router: Option<&Arc<DefaultRouter>>,
    ) -> std::result::Result<Response, DispatchError> {
        debug!("Getting response for: {}", request.uri.path());

        let Some(router) = router else {
            debug!("No router configured, returning 404 Not Found");
            return Ok(Response::new(StatusCode::NOT_FOUND));
        };

        trace!("Attempting to route request through router");
        match router.handle(request).await {
            Ok(response) => {
                trace!("Route handled successfully");
                Ok(response)
            }
            Err(reinhardt_core::exception::Error::NotFound(msg)) => {
                debug!("No route matched: {}", msg);
                Ok(Response::new(StatusCode::NOT_FOUND))
            }
            Err(e) => {
                error!("Handler error: {}", e);
                Err(DispatchError::View(e.to_string()))
            }
        }
    }

    /// Converts a dispatch error into a generic `500 Internal Server Error`.
    ///
    /// The error is logged, but its details are kept out of the response body
    /// so internal information is not leaked to clients.
    pub async fn handle_exception(&self, _request: &Request, error: DispatchError) -> Response {
        error!("Handling exception: {}", error);
        crate::build_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error")
    }

    /// Returns whether the handler is flagged as asynchronous.
    ///
    /// Within this type the flag is only stored and reported; dispatch does
    /// not consult it.
    pub fn is_async(&self) -> bool {
        self.is_async
    }

    /// Sets the asynchronous flag reported by [`BaseHandler::is_async`].
    pub fn set_async(&mut self, is_async: bool) {
        self.is_async = is_async;
    }
}
impl Default for BaseHandler {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl Handler for BaseHandler {
async fn handle(&self, request: Request) -> reinhardt_core::exception::Result<Response> {
match self.handle_request(request).await {
Ok(response) => Ok(response),
Err(e) => {
error!("Handler error in BaseHandler::handle: {}", e);
Ok(crate::build_error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal Server Error",
))
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_urls::routers::{DefaultRouter, Router, path};

    /// Handler that always answers with `Response::ok()` and a fixed body.
    struct TestHandler {
        response_body: String,
    }

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _req: Request) -> reinhardt_core::exception::Result<Response> {
            Ok(Response::ok().with_body(self.response_body.clone()))
        }
    }

    /// Builds an HTTP/1.1 request for `uri` with no headers and an empty body.
    fn build_request(method: Method, uri: &'static str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Decodes a response body as UTF-8 text, panicking on invalid UTF-8.
    fn body_text(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    // Construction is synchronous, so no async runtime is needed here.
    #[test]
    fn test_base_handler_new() {
        let handler = BaseHandler::new();
        assert!(handler.is_async());
    }

    #[tokio::test]
    async fn test_base_handler_handle_request() {
        let handler = BaseHandler::new();
        let request = build_request(Method::GET, "/");
        let resp = handler.handle_request(request).await.unwrap();
        assert_eq!(resp.status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_base_handler_handle_exception() {
        let handler = BaseHandler::new();
        let request = build_request(Method::GET, "/");
        let error = DispatchError::View("Test error".to_string());
        let response = handler.handle_exception(&request, error).await;
        assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn test_handle_exception_does_not_expose_internal_details() {
        let handler = BaseHandler::new();
        let request = build_request(Method::GET, "/");
        let sensitive_detail = "database connection refused at postgres://admin:secret@db:5432";
        let error = DispatchError::Internal(sensitive_detail.to_string());
        let response = handler.handle_exception(&request, error).await;
        let body = body_text(&response);
        assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!body.contains("database"));
        assert!(!body.contains("postgres"));
        assert!(!body.contains("secret"));
        assert_eq!(body, "Internal Server Error");
    }

    #[tokio::test]
    async fn test_handler_impl_does_not_expose_error_in_body() {
        struct FailingHandler;

        #[async_trait]
        impl Handler for FailingHandler {
            async fn handle(&self, _req: Request) -> reinhardt_core::exception::Result<Response> {
                Err(reinhardt_core::exception::Error::Internal(
                    "module::secret_handler panicked at /src/app/handlers.rs:42".to_string(),
                ))
            }
        }

        let mut router = DefaultRouter::new();
        let failing = Arc::new(FailingHandler);
        router.add_route(path("/fail", failing).with_name("fail"));
        let handler = BaseHandler::with_router(Arc::new(router));

        let request = build_request(Method::GET, "/fail");
        let response = handler.handle(request).await.unwrap();
        let body = body_text(&response);
        assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(!body.contains("panicked"));
        assert!(!body.contains("handlers.rs"));
        assert!(!body.contains("secret_handler"));
        assert_eq!(body, "Internal Server Error");
    }

    #[test]
    fn test_base_handler_async_mode() {
        let mut handler = BaseHandler::new();
        assert!(handler.is_async());
        handler.set_async(false);
        assert!(!handler.is_async());
    }

    #[tokio::test]
    async fn test_base_handler_different_methods() {
        let handler = BaseHandler::new();
        for method in [Method::GET, Method::POST, Method::PUT, Method::DELETE] {
            let request = build_request(method, "/");
            let response = handler.handle_request(request).await;
            assert!(response.is_ok());
        }
    }

    #[tokio::test]
    async fn test_base_handler_different_uris() {
        let handler = BaseHandler::new();
        // Named `uri` rather than `path` to avoid shadowing the route constructor.
        for uri in ["/", "/test", "/api/v1/users", "/admin/login"] {
            let request = build_request(Method::GET, uri);
            let response = handler.handle_request(request).await;
            assert!(response.is_ok());
        }
    }

    #[tokio::test]
    async fn test_handler_with_router() {
        let mut router = DefaultRouter::new();
        let test_handler = Arc::new(TestHandler {
            response_body: "Test response".to_string(),
        });
        router.add_route(path("/test", test_handler).with_name("test"));
        let handler = BaseHandler::with_router(Arc::new(router));

        let request = build_request(Method::GET, "/test");
        let resp = handler.handle_request(request).await.unwrap();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(body_text(&resp), "Test response");
    }

    #[tokio::test]
    async fn test_handler_404_not_found() {
        let router = DefaultRouter::new();
        let handler = BaseHandler::with_router(Arc::new(router));
        let request = build_request(Method::GET, "/nonexistent");
        let resp = handler.handle_request(request).await.unwrap();
        assert_eq!(resp.status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_handler_multiple_routes() {
        let mut router = DefaultRouter::new();
        let hello_handler = Arc::new(TestHandler {
            response_body: "Hello".to_string(),
        });
        router.add_route(path("/hello", hello_handler).with_name("hello"));
        let world_handler = Arc::new(TestHandler {
            response_body: "World".to_string(),
        });
        router.add_route(path("/world", world_handler).with_name("world"));
        let handler = BaseHandler::with_router(Arc::new(router));

        let response = handler
            .handle_request(build_request(Method::GET, "/hello"))
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(body_text(&response), "Hello");

        let response = handler
            .handle_request(build_request(Method::GET, "/world"))
            .await
            .unwrap();
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(body_text(&response), "World");
    }
}