use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::StatusCode;
use hyper::body::Incoming;
use hyper::server::conn::http2;
use hyper::service::Service;
use hyper_util::rt::TokioIo;
use reinhardt_http::Handler;
use reinhardt_http::{Request, Response};
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use crate::shutdown::ShutdownCoordinator;
/// HTTP/2 server that dispatches every request to a single shared [`Handler`].
///
/// Connections are accepted on a plain TCP listener (no TLS). Each one is served
/// with hyper's HTTP/2 connection builder on its own Tokio task.
pub struct Http2Server {
    /// Application handler, shared by every connection task.
    handler: Arc<dyn Handler>,
}

impl Http2Server {
    /// Creates a server around `handler`.
    ///
    /// The handler is moved into an [`Arc`] so connection tasks can share it.
    pub fn new<H: Handler + 'static>(handler: H) -> Self {
        let shared: Arc<dyn Handler> = Arc::new(handler);
        Self { handler: shared }
    }

    /// Binds `addr` and serves incoming connections forever.
    ///
    /// Each accepted connection runs on a spawned task. Errors from an individual
    /// connection are printed to stderr and do not stop the accept loop.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails or if `accept` itself fails.
    pub async fn listen(self, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
        let listener = TcpListener::bind(addr).await?;
        println!("HTTP/2 server listening on http://{}", addr);
        loop {
            let (socket, _peer) = listener.accept().await?;
            tokio::task::spawn(Self::serve_logged(socket, Arc::clone(&self.handler)));
        }
    }

    /// Like [`Self::listen`], but stops accepting connections once `coordinator`
    /// broadcasts a shutdown signal.
    ///
    /// Every connection task holds its own shutdown subscription. When the signal
    /// arrives, the task drops its in-progress connection. After the accept loop
    /// exits, `coordinator.notify_shutdown_complete()` is called.
    ///
    /// NOTE(review): completion is reported without waiting for the spawned
    /// connection tasks to finish. Confirm that this is what the coordinator expects.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `addr` fails or if `accept` fails. In the
    /// `accept` case, `notify_shutdown_complete` is not called.
    pub async fn listen_with_shutdown(
        self,
        addr: SocketAddr,
        coordinator: ShutdownCoordinator,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let listener = TcpListener::bind(addr).await?;
        println!("HTTP/2 server listening on http://{}", addr);
        let mut server_shutdown = coordinator.subscribe();

        loop {
            tokio::select! {
                accepted = listener.accept() => {
                    let (socket, _peer) = accepted?;
                    let shared_handler = Arc::clone(&self.handler);
                    let mut task_shutdown = coordinator.subscribe();
                    tokio::task::spawn(async move {
                        // Whichever finishes first wins: the connection itself,
                        // or the shutdown signal (which drops the connection).
                        tokio::select! {
                            () = Self::serve_logged(socket, shared_handler) => {}
                            _ = task_shutdown.recv() => {}
                        }
                    });
                }
                _ = server_shutdown.recv() => {
                    println!("Shutdown signal received, stopping HTTP/2 server...");
                    break;
                }
            }
        }

        coordinator.notify_shutdown_complete();
        Ok(())
    }

    /// Serves a single TCP connection as HTTP/2 until hyper reports that it has
    /// finished.
    ///
    /// Requests are answered by a `RequestService` wrapping `handler`, using
    /// `DEFAULT_MAX_BODY_SIZE` as the request-body limit.
    ///
    /// # Errors
    ///
    /// Returns the connection-level error reported by hyper, boxed.
    pub async fn handle_connection(
        stream: TcpStream,
        handler: Arc<dyn Handler>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let service = RequestService {
            handler,
            max_body_size: DEFAULT_MAX_BODY_SIZE,
        };
        http2::Builder::new(hyper_util::rt::TokioExecutor::new())
            .serve_connection(TokioIo::new(stream), service)
            .await
            .map_err(Into::into)
    }

    /// Runs [`Self::handle_connection`] and prints any error to stderr instead of
    /// returning it, so the future can be spawned fire-and-forget.
    async fn serve_logged(stream: TcpStream, handler: Arc<dyn Handler>) {
        if let Err(err) = Self::handle_connection(stream, handler).await {
            eprintln!("Error handling HTTP/2 connection: {:?}", err);
        }
    }
}
const DEFAULT_MAX_BODY_SIZE: u64 = 10 * 1024 * 1024;
struct RequestService {
handler: Arc<dyn Handler>,
max_body_size: u64,
}
impl Service<hyper::Request<Incoming>> for RequestService {
type Response = hyper::Response<Full<Bytes>>;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: hyper::Request<Incoming>) -> Self::Future {
let handler = self.handler.clone();
let max_body_size = self.max_body_size;
Box::pin(async move {
if let Some(content_length) = req.headers().get(hyper::header::CONTENT_LENGTH)
&& let Ok(len_str) = content_length.to_str()
&& let Ok(len) = len_str.parse::<u64>()
&& len > max_body_size
{
return Ok(hyper::Response::builder()
.status(StatusCode::PAYLOAD_TOO_LARGE)
.body(Full::new(Bytes::from("Request body too large")))
.expect("Failed to build 413 response"));
}
let (parts, body) = req.into_parts();
let body_bytes = http_body_util::Limited::new(body, max_body_size as usize)
.collect()
.await
.map_err(|_| {
Box::new(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Request body exceeds size limit",
)) as Box<dyn std::error::Error + Send + Sync>
})?
.to_bytes();
let request = Request::builder()
.method(parts.method)
.uri(parts.uri)
.version(parts.version)
.headers(parts.headers)
.body(body_bytes)
.build()
.expect("Failed to build request");
let response = handler
.handle(request)
.await
.unwrap_or_else(|_| Response::internal_server_error());
let mut hyper_response = hyper::Response::builder().status(response.status);
for (key, value) in response.headers.iter() {
hyper_response = hyper_response.header(key, value);
}
Ok(hyper_response.body(Full::new(response.body))?)
})
}
}
/// Binds `addr` and serves `handler` over HTTP/2 until accepting a connection
/// fails.
///
/// Shorthand for `Http2Server::new(handler).listen(addr)`.
///
/// # Errors
///
/// Propagates bind and accept errors from [`Http2Server::listen`].
pub async fn serve_http2<H: Handler + 'static>(
    addr: SocketAddr,
    handler: H,
) -> Result<(), Box<dyn std::error::Error>> {
    Http2Server::new(handler).listen(addr).await
}
/// Binds `addr` and serves `handler` over HTTP/2 until `coordinator` signals
/// shutdown.
///
/// Shorthand for `Http2Server::new(handler).listen_with_shutdown(addr, coordinator)`.
///
/// # Errors
///
/// Propagates bind and accept errors from [`Http2Server::listen_with_shutdown`].
pub async fn serve_http2_with_shutdown<H: Handler + 'static>(
    addr: SocketAddr,
    handler: H,
    coordinator: ShutdownCoordinator,
) -> Result<(), Box<dyn std::error::Error>> {
    Http2Server::new(handler)
        .listen_with_shutdown(addr, coordinator)
        .await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that answers every request with the same fixed body.
    struct HelloHandler;

    #[async_trait::async_trait]
    impl Handler for HelloHandler {
        async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
            let greeting = Response::ok().with_body("Hello from HTTP/2!");
            Ok(greeting)
        }
    }

    /// `Http2Server::new` only moves the handler into an `Arc`, so constructing
    /// a server needs neither a socket nor an async runtime.
    #[test]
    fn test_http2_server_creation() {
        let server = Http2Server::new(HelloHandler);
        drop(server);
    }
}