use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::StatusCode;
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::service::Service;
use hyper_util::rt::TokioIo;
use reinhardt_di::InjectionContext;
use reinhardt_http::{Handler, Middleware, MiddlewareChain};
use reinhardt_http::{Request, Response};
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use crate::shutdown::ShutdownCoordinator;
/// HTTP/1 server that dispatches every request to a [`Handler`], optionally wrapped in
/// middleware and given a dependency-injection context.
pub struct HttpServer {
    /// Innermost request handler, as passed to [`HttpServer::new`].
    handler: Arc<dyn Handler>,
    /// Injection context attached to each incoming request, if one was configured.
    di_context: Option<Arc<InjectionContext>>,
    /// Registered middlewares in registration order (crate-visible; read directly by this
    /// file's tests).
    pub(crate) middlewares: Vec<Arc<dyn Middleware>>,
}
impl HttpServer {
    /// Creates a server that dispatches every request to `handler`, with no middleware and no
    /// dependency-injection context.
    pub fn new<H: Handler + 'static>(handler: H) -> Self {
        Self {
            handler: Arc::new(handler),
            middlewares: Vec::new(),
            di_context: None,
        }
    }

    /// Registers `middleware` and returns the server (builder style).
    ///
    /// Middlewares are added to the [`MiddlewareChain`] in registration order; the
    /// first-registered middleware ends up outermost (it sees the response last).
    pub fn with_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
        self.middlewares.push(Arc::new(middleware));
        self
    }

    /// Attaches a dependency-injection context that is set on every incoming request.
    pub fn with_di_context(mut self, context: Arc<InjectionContext>) -> Self {
        self.di_context = Some(context);
        self
    }

    /// Returns a shared handle to the bare handler, without any middleware applied.
    pub fn handler(&self) -> Arc<dyn Handler> {
        self.handler.clone()
    }

    /// Builds the effective handler for serving.
    ///
    /// Returns the bare handler when no middleware is registered. Otherwise returns a
    /// [`MiddlewareChain`] wrapping it with every registered middleware.
    fn build_handler(&self) -> Arc<dyn Handler> {
        if self.middlewares.is_empty() {
            return self.handler.clone();
        }
        let mut chain = MiddlewareChain::new(self.handler.clone());
        for middleware in &self.middlewares {
            chain.add_middleware(middleware.clone());
        }
        Arc::new(chain)
    }

    /// Returns `true` when an `accept()` failure concerns only the single pending connection.
    ///
    /// This covers a peer that aborted or reset during the handshake. Such errors should not
    /// stop the listener.
    fn is_per_connection_accept_error(err: &std::io::Error) -> bool {
        matches!(
            err.kind(),
            std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
                | std::io::ErrorKind::ConnectionRefused
        )
    }

    /// Binds `addr` and serves connections until a fatal error occurs.
    ///
    /// Each accepted connection is served on its own tokio task. Per-connection failures are
    /// logged to stderr and do not stop the server.
    ///
    /// # Errors
    /// Returns an error if binding fails, or if `accept()` fails with an error that is not
    /// connection-scoped.
    pub async fn listen(self, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
        let listener = TcpListener::bind(addr).await?;
        let handler = self.build_handler();
        let di_context = self.di_context.clone();
        loop {
            let (stream, socket_addr) = match listener.accept().await {
                Ok(accepted) => accepted,
                // A client that gives up mid-handshake must not take the whole server down.
                Err(err) if Self::is_per_connection_accept_error(&err) => {
                    eprintln!("Error accepting connection: {:?}", err);
                    continue;
                }
                Err(err) => return Err(err.into()),
            };
            let handler = handler.clone();
            let di_context = di_context.clone();
            tokio::task::spawn(async move {
                if let Err(err) =
                    Self::handle_connection(stream, socket_addr, handler, di_context).await
                {
                    eprintln!("Error handling connection: {:?}", err);
                }
            });
        }
    }

    /// Like [`HttpServer::listen`], but stops accepting when `coordinator` signals shutdown.
    ///
    /// Each connection task races its connection against its own `coordinator.subscribe()`
    /// receiver. When that receiver fires, the connection future is dropped, so in-flight
    /// requests are not drained. After the accept loop exits, `notify_shutdown_complete` is
    /// called; it does not wait for the spawned connection tasks.
    ///
    /// # Errors
    /// Returns an error if binding fails, or if `accept()` fails with an error that is not
    /// connection-scoped. In that case `notify_shutdown_complete` is not called.
    pub async fn listen_with_shutdown(
        self,
        addr: SocketAddr,
        coordinator: ShutdownCoordinator,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let listener = TcpListener::bind(addr).await?;
        let handler = self.build_handler();
        let di_context = self.di_context.clone();
        let mut shutdown_rx = coordinator.subscribe();
        loop {
            tokio::select! {
                result = listener.accept() => {
                    let (stream, socket_addr) = match result {
                        Ok(accepted) => accepted,
                        // Connection-scoped failure: skip this peer, keep listening.
                        Err(err) if Self::is_per_connection_accept_error(&err) => {
                            eprintln!("Error accepting connection: {:?}", err);
                            continue;
                        }
                        Err(err) => return Err(err.into()),
                    };
                    let handler = handler.clone();
                    let di_context = di_context.clone();
                    let mut conn_shutdown = coordinator.subscribe();
                    tokio::task::spawn(async move {
                        tokio::select! {
                            result = Self::handle_connection(stream, socket_addr, handler, di_context) => {
                                if let Err(err) = result {
                                    eprintln!("Error handling connection: {:?}", err);
                                }
                            }
                            _ = conn_shutdown.recv() => {
                            }
                        }
                    });
                }
                _ = shutdown_rx.recv() => {
                    println!("Shutdown signal received, stopping server...");
                    break;
                }
            }
        }
        coordinator.notify_shutdown_complete();
        Ok(())
    }

    /// Serves a single TCP connection with hyper's HTTP/1 implementation.
    ///
    /// Request bodies are limited to `DEFAULT_MAX_BODY_SIZE` bytes.
    ///
    /// # Errors
    /// Returns any error hyper reports while serving the connection.
    pub async fn handle_connection(
        stream: TcpStream,
        socket_addr: SocketAddr,
        handler: Arc<dyn Handler>,
        di_context: Option<Arc<InjectionContext>>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let io = TokioIo::new(stream);
        let service = RequestService {
            handler,
            remote_addr: socket_addr,
            di_context,
            max_body_size: DEFAULT_MAX_BODY_SIZE,
        };
        http1::Builder::new().serve_connection(io, service).await?;
        Ok(())
    }
}
/// Default upper bound on an accepted request body, in bytes (10 MiB).
const DEFAULT_MAX_BODY_SIZE: u64 = 10 << 20;
/// Per-connection hyper service that adapts hyper requests to reinhardt's [`Handler`].
struct RequestService {
    /// Handler (possibly a middleware chain) that serves each request.
    handler: Arc<dyn Handler>,
    /// Largest request body, in bytes, that will be read and forwarded.
    max_body_size: u64,
    /// Peer address of the connection, recorded on every request built from it.
    remote_addr: SocketAddr,
    /// Optional DI context copied onto every request.
    di_context: Option<Arc<InjectionContext>>,
}
impl Service<hyper::Request<Incoming>> for RequestService {
    type Response = hyper::Response<Full<Bytes>>;
    type Error = Box<dyn std::error::Error + Send + Sync>;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Handles one request end to end.
    ///
    /// Converts the hyper request into a reinhardt [`Request`], dispatches it to the handler,
    /// and converts the resulting [`Response`] back into a hyper response.
    ///
    /// Bodies larger than `max_body_size` get `413 Payload Too Large`. This applies whether
    /// the size is announced up front via `Content-Length` or only discovered while streaming
    /// (e.g. chunked transfer encoding). Handler errors are turned into a response with
    /// `Response::from`.
    ///
    /// # Errors
    /// Returns an error, which makes hyper close the connection, when:
    /// - reading the body fails for a reason other than the size limit, or
    /// - the hyper response cannot be assembled from the handler's status and headers.
    fn call(&self, req: hyper::Request<Incoming>) -> Self::Future {
        let handler = self.handler.clone();
        let remote_addr = self.remote_addr;
        let di_context = self.di_context.clone();
        let max_body_size = self.max_body_size;
        Box::pin(async move {
            // Fast path: reject before reading any body bytes when the client declares an
            // oversized Content-Length.
            if let Some(content_length) = req.headers().get(hyper::header::CONTENT_LENGTH)
                && let Ok(len_str) = content_length.to_str()
                && let Ok(len) = len_str.parse::<u64>()
                && len > max_body_size
            {
                return Ok(payload_too_large());
            }
            let (parts, body) = req.into_parts();
            // A u64 limit may not fit in usize on 32-bit targets; saturate instead of truncating.
            let limit = usize::try_from(max_body_size).unwrap_or(usize::MAX);
            let body_bytes = match http_body_util::Limited::new(body, limit).collect().await {
                Ok(collected) => collected.to_bytes(),
                // Bodies without Content-Length (e.g. chunked) only hit the limit while
                // streaming; answer them with the same 413 as the fast path.
                Err(err) if err.is::<http_body_util::LengthLimitError>() => {
                    return Ok(payload_too_large());
                }
                // Any other failure is a genuine transport/protocol error; propagate it as-is.
                Err(err) => return Err(err),
            };
            let mut request = Request::builder()
                .method(parts.method)
                .uri(parts.uri)
                .version(parts.version)
                .headers(parts.headers)
                .body(body_bytes)
                .remote_addr(remote_addr)
                .build()
                .expect("Failed to build request");
            if let Some(ctx) = di_context {
                request.set_di_context(ctx);
            }
            // Captured before the request is moved into the handler, for the warning below.
            let request_path = request.uri.path().to_string();
            let response = handler.handle(request).await.unwrap_or_else(|e| {
                if request_path.contains('.') && !request_path.ends_with(".json") {
                    eprintln!(
                        "[reinhardt WARN] Non-API request hit error-to-JSON conversion: path={}, error={}",
                        request_path, e
                    );
                }
                Response::from(e)
            });
            let mut hyper_response = hyper::Response::builder().status(response.status);
            for (key, value) in response.headers.iter() {
                hyper_response = hyper_response.header(key, value);
            }
            Ok(hyper_response.body(Full::new(response.body))?)
        })
    }
}

/// Builds the fixed `413 Payload Too Large` response used for oversized request bodies.
fn payload_too_large() -> hyper::Response<Full<Bytes>> {
    hyper::Response::builder()
        .status(StatusCode::PAYLOAD_TOO_LARGE)
        .body(Full::new(Bytes::from("Request body too large")))
        .expect("Failed to build 413 response")
}
/// Binds `addr` and serves `handler` until an error occurs, with no middleware and no DI
/// context.
///
/// Shorthand for `HttpServer::new(handler).listen(addr)`.
///
/// # Errors
/// Propagates any error returned by [`HttpServer::listen`], e.g. failing to bind `addr`.
pub async fn serve<H: Handler + 'static>(
    addr: SocketAddr,
    handler: H,
) -> Result<(), Box<dyn std::error::Error>> {
    HttpServer::new(handler).listen(addr).await
}
/// Binds `addr` and serves `handler` until `coordinator` signals shutdown.
///
/// Shorthand for `HttpServer::new(handler).listen_with_shutdown(addr, coordinator)`.
///
/// # Errors
/// Propagates any error returned by [`HttpServer::listen_with_shutdown`].
pub async fn serve_with_shutdown<H: Handler + 'static>(
    addr: SocketAddr,
    handler: H,
    coordinator: ShutdownCoordinator,
) -> Result<(), Box<dyn std::error::Error>> {
    HttpServer::new(handler)
        .listen_with_shutdown(addr, coordinator)
        .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Handler that always answers with the body `Hello, World!`.
    struct TestHandler;

    #[async_trait::async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
            Ok(Response::ok().with_body("Hello, World!"))
        }
    }

    /// Middleware that calls the next handler and prepends `prefix` to its body.
    ///
    /// Shared by every middleware test instead of being redefined inside each one.
    struct PrefixMiddleware {
        prefix: String,
    }

    impl PrefixMiddleware {
        fn new(prefix: &str) -> Self {
            Self {
                prefix: prefix.to_string(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Middleware for PrefixMiddleware {
        async fn process(
            &self,
            request: Request,
            next: Arc<dyn Handler>,
        ) -> reinhardt_core::exception::Result<Response> {
            let response = next.handle(request).await?;
            let current_body = String::from_utf8(response.body.to_vec()).unwrap_or_default();
            let new_body = format!("{}{}", self.prefix, current_body);
            Ok(Response::ok().with_body(new_body))
        }
    }

    /// Builds an empty `GET /` HTTP/1.1 request for driving handlers directly.
    fn get_root_request() -> Request {
        Request::builder()
            .method(hyper::Method::GET)
            .uri("/")
            .version(hyper::Version::HTTP_11)
            .headers(hyper::HeaderMap::new())
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_http_server_creation() {
        let server = HttpServer::new(TestHandler);
        assert!(server.middlewares.is_empty());
    }

    #[tokio::test]
    async fn test_http_server_with_middleware() {
        let server =
            HttpServer::new(TestHandler).with_middleware(PrefixMiddleware::new("Middleware: "));
        assert_eq!(server.middlewares.len(), 1);
    }

    #[tokio::test]
    async fn test_http_server_multiple_middlewares() {
        let server = HttpServer::new(TestHandler)
            .with_middleware(PrefixMiddleware::new("MW1:"))
            .with_middleware(PrefixMiddleware::new("MW2:"));
        assert_eq!(server.middlewares.len(), 2);
    }

    #[tokio::test]
    async fn test_middleware_chain_execution() {
        let server = HttpServer::new(TestHandler)
            .with_middleware(PrefixMiddleware::new("First:"))
            .with_middleware(PrefixMiddleware::new("Second:"));
        let handler = server.build_handler();
        let response = handler.handle(get_root_request()).await.unwrap();
        let body = String::from_utf8(response.body.to_vec()).unwrap();
        assert_eq!(body, "First:Second:Hello, World!");
    }

    /// Handler that always fails with a database error carrying `error_message`.
    struct ErrorHandler {
        error_message: String,
    }

    #[async_trait::async_trait]
    impl Handler for ErrorHandler {
        async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
            Err(reinhardt_core::exception::Error::Database(
                self.error_message.clone(),
            ))
        }
    }

    #[rstest]
    #[case::database_connection_string(
        "postgres://admin:s3cret@10.0.0.5/prod_db: connection refused",
        "postgres"
    )]
    #[case::internal_file_path("/opt/app/config/secrets.yml: file not found", "/opt/app")]
    #[case::sql_query_details(
        "SELECT * FROM users WHERE password = 'hash123': syntax error",
        "SELECT"
    )]
    #[tokio::test]
    async fn test_error_handler_does_not_leak_internal_details(
        #[case] sensitive_message: &str,
        #[case] leaked_fragment: &str,
    ) {
        let server = HttpServer::new(ErrorHandler {
            error_message: sensitive_message.to_string(),
        });
        let handler = server.build_handler();
        // Mirrors the server's own fallback: handler errors become responses via `Response::from`.
        let response = handler
            .handle(get_root_request())
            .await
            .unwrap_or_else(Response::from);
        let body = String::from_utf8(response.body.to_vec()).unwrap();
        assert_eq!(response.status, StatusCode::INTERNAL_SERVER_ERROR);
        assert!(
            !body.contains(leaked_fragment),
            "Response body must not contain internal details '{leaked_fragment}', but got: {body}"
        );
    }
}