// hyperlite 0.1.0
//
// Lightweight HTTP framework built on hyper, tokio, and tower.
// Documentation
//! Server utilities for running Hyperlite applications.
//!
//! This module provides a reusable [`serve`] function that binds a TCP listener,
//! accepts incoming HTTP connections, and forwards each request to a Tower
//! `Service`. It mirrors the manual server loop used in the backend binary while
//! adding graceful shutdown handling out of the box.
//!
//! # Features
//! - Runs any Tower `Service` that accepts Hyperlite's boxed request bodies
//! - Spawns a task per connection and supports concurrent handling
//! - Automatically listens for shutdown signals (Ctrl+C on every platform, plus SIGTERM on Unix)
//! - Currently supports HTTP/1.1 via Hyper's connection builder
//!
//! # Example
//! ```rust,no_run
//! use bytes::Bytes;
//! use hyper::{Method, Request, Response, StatusCode};
//! use http_body_util::Full;
//! use hyperlite::{serve, success, BoxBody, Router};
//! use serde::Serialize;
//! use std::net::SocketAddr;
//! use std::sync::Arc;
//!
//! #[derive(Clone)]
//! struct AppState;
//!
//! #[derive(Serialize)]
//! struct Greeting {
//!     message: String,
//! }
//!
//! async fn hello(
//!     _req: Request<BoxBody>,
//!     _state: Arc<AppState>,
//! ) -> Result<Response<Full<Bytes>>, hyperlite::BoxError> {
//!     Ok(success(
//!         StatusCode::OK,
//!         Greeting {
//!             message: "Hello, World!".to_string(),
//!         },
//!     ))
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), hyperlite::BoxError> {
//!     let router = Router::new(AppState).route(
//!         "/hello",
//!         Method::GET,
//!         Arc::new(|req, state| Box::pin(hello(req, state))),
//!     );
//!
//!     let addr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
//!     serve(addr, router).await
//! }
//! ```

use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;

use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::server::conn::http1;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use hyper_util::service::TowerToHyperService;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tower::{Service, ServiceExt};

use crate::{BodyError, BoxBody, BoxError};

/// Runs the provided Tower service on the supplied address until a shutdown signal is received.
///
/// Binds a [`TcpListener`] to `addr` and then loops. Every accepted connection is served over
/// HTTP/1 on its own task, tracked in a [`JoinSet`], using a clone of `service`. Once the
/// future returned by `shutdown_signal` resolves, the loop stops accepting new connections and
/// the function waits for every tracked connection task to finish before returning `Ok(())`.
///
/// # Errors
///
/// Returns an error when the listener cannot be bound, or when `accept` fails with anything
/// other than a per-connection error. `ConnectionRefused`, `ConnectionAborted` and
/// `ConnectionReset` only concern the peer that was being accepted, so they are logged (when
/// the `tracing` feature is enabled) and the loop keeps accepting instead of tearing the whole
/// server down.
///
/// Note that returning early with an error drops the `JoinSet`, which aborts the connection
/// tasks that are still running.
pub async fn serve<S>(addr: impl Into<SocketAddr>, service: S) -> Result<(), BoxError>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    let address = addr.into();
    let listener = TcpListener::bind(address)
        .await
        .map_err(|err| -> BoxError { Box::new(err) })?;

    #[cfg(feature = "tracing")]
    tracing::info!(addr = %listener.local_addr().unwrap_or(address), "listening for HTTP connections");

    // `shutdown_signal` already hands back a pinned, boxed future; it is `Unpin`, so it can be
    // polled through `&mut` across loop iterations without boxing it a second time.
    let mut shutdown = shutdown_signal();
    let mut tasks = JoinSet::new();

    loop {
        tokio::select! {
            _ = &mut shutdown => {
                #[cfg(feature = "tracing")]
                tracing::info!("shutdown signal received");
                break;
            }
            accept_result = listener.accept() => {
                let (stream, peer_addr) = match accept_result {
                    Ok(accepted) => accepted,
                    // The failure belongs to the single peer being accepted, not to the
                    // listener itself, so keep serving everyone else.
                    Err(err) if is_connection_error(&err) => {
                        #[cfg(feature = "tracing")]
                        tracing::warn!(error = %err, "transient error accepting connection");
                        continue;
                    }
                    Err(err) => {
                        let fatal: BoxError = Box::new(err);
                        return Err(fatal);
                    }
                };
                // Only the tracing-enabled build reads `peer_addr`; touch it otherwise so the
                // binding does not trigger an unused-variable warning.
                #[cfg(not(feature = "tracing"))]
                let _ = &peer_addr;

                let service = service.clone();

                tasks.spawn(async move {
                    let io = TokioIo::new(stream);
                    let handler = ConnectionHandler::new(service);
                    let hyper_service = TowerToHyperService::new(handler);

                    if let Err(error) = http1::Builder::new().serve_connection(io, hyper_service).await {
                        #[cfg(feature = "tracing")]
                        tracing::error!(?peer_addr, ?error, "error serving connection");
                        #[cfg(not(feature = "tracing"))]
                        let _ = error;
                    }
                });
            }
            // Reap finished connection tasks while running so the set does not grow without
            // bound; the guard disables this branch while there is nothing to join.
            join_result = tasks.join_next(), if !tasks.is_empty() => {
                if let Some(Err(join_error)) = join_result {
                    #[cfg(feature = "tracing")]
                    tracing::error!(?join_error, "connection task failed");
                    #[cfg(not(feature = "tracing"))]
                    let _ = join_error;
                }
            }
        }
    }

    // Shutdown was requested: stop accepting and wait for the remaining connection tasks.
    while let Some(result) = tasks.join_next().await {
        if let Err(join_error) = result {
            #[cfg(feature = "tracing")]
            tracing::error!(?join_error, "connection task failed during shutdown");
            #[cfg(not(feature = "tracing"))]
            let _ = join_error;
        }
    }

    #[cfg(feature = "tracing")]
    tracing::info!("server shutdown complete");

    Ok(())
}

/// Returns `true` when an `accept` error concerns only the connection being accepted.
fn is_connection_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionRefused
            | std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
    )
}

/// Adapter that lets a service working on `Request<BoxBody>` be driven with the
/// `Request<Incoming>` values Hyper produces for each connection.
///
/// The struct itself carries no trait bounds; the `impl` blocks that need them state them.
struct ConnectionHandler<S> {
    /// The wrapped service; it is cloned for every incoming request.
    service: S,
}

impl<S> ConnectionHandler<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    fn new(service: S) -> Self {
        Self { service }
    }
}

impl<S> Clone for ConnectionHandler<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    /// Produces a new handler around a clone of the wrapped service.
    fn clone(&self) -> Self {
        let service = self.service.clone();
        Self { service }
    }
}

impl<S> Service<Request<Incoming>> for ConnectionHandler<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = BoxError;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Forwards readiness to the wrapped service.
    ///
    /// The wrapped service's error type is `Infallible`, so the error arm can never be taken;
    /// the empty `match` only converts the type to `BoxError`.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.service
            .poll_ready(cx)
            .map_err(|never| -> BoxError { match never {} })
    }

    /// Boxes the incoming body and dispatches the request to a clone of the wrapped service.
    fn call(&mut self, req: Request<Incoming>) -> Self::Future {
        // The returned future must be `'static`, so it owns its own clone of the service.
        let mut inner = self.service.clone();

        Box::pin(async move {
            // Swap Hyper's `Incoming` body for the boxed body type the wrapped service expects,
            // leaving the request head untouched.
            let request =
                req.map(|body| -> BoxBody { body.map_err(|err| -> BodyError { err }).boxed() });

            // The clone was not the value polled in `poll_ready`, so wait for it to become
            // ready before calling it.
            let ready = inner.ready().await.unwrap_or_else(|never| match never {});

            ready
                .call(request)
                .await
                .map_err(|never| -> BoxError { match never {} })
        })
    }
}

/// Builds a future that resolves once the process is asked to stop.
///
/// The future completes on Ctrl+C on every platform and, on Unix, also when the terminate
/// signal is delivered.
///
/// # Panics
///
/// The returned future panics when polled if either signal handler cannot be installed.
fn shutdown_signal() -> Pin<Box<dyn Future<Output = ()> + Send>> {
    let wait_for_signal = async {
        let interrupt = async {
            let outcome = tokio::signal::ctrl_c().await;
            outcome.expect("failed to install Ctrl+C handler");
        };

        #[cfg(unix)]
        let terminate = async {
            let kind = tokio::signal::unix::SignalKind::terminate();
            let mut stream = tokio::signal::unix::signal(kind)
                .expect("failed to install terminate signal handler");
            stream.recv().await;
        };

        // Platforms without Unix signals only react to Ctrl+C, so this branch never fires.
        #[cfg(not(unix))]
        let terminate = std::future::pending::<()>();

        // Whichever signal arrives first ends the wait.
        tokio::select! {
            _ = interrupt => {}
            _ = terminate => {}
        }
    };

    Box::pin(wait_for_signal)
}