hyperlite 0.1.0

Lightweight HTTP framework built on hyper, tokio, and tower
Documentation
//! # Usage with Server
//!
//! The router composes with the [`crate::serve`] helper to form a complete HTTP
//! server stack:
//!
//! ```rust,no_run
//! use bytes::Bytes;
//! use hyper::{Method, Request, Response, StatusCode};
//! use hyperlite::{serve, success, BoxBody, Router};
//! use http_body_util::Full;
//! use std::net::SocketAddr;
//! use std::sync::Arc;
//!
//! #[derive(Clone)]
//! struct AppState;
//!
//! async fn list_users(
//!     _req: Request<BoxBody>,
//!     _state: Arc<AppState>,
//! ) -> Result<Response<Full<Bytes>>, hyperlite::BoxError> {
//!     Ok(success(StatusCode::OK, "ok"))
//! }
//!
//! async fn create_user(
//!     _req: Request<BoxBody>,
//!     _state: Arc<AppState>,
//! ) -> Result<Response<Full<Bytes>>, hyperlite::BoxError> {
//!     Ok(success(StatusCode::CREATED, "created"))
//! }
//!
//! #[tokio::main]
//! async fn main() -> Result<(), hyperlite::BoxError> {
//!     let router = Router::new(AppState)
//!         .route("/api/users", Method::GET, Arc::new(|req, state| Box::pin(list_users(req, state))))
//!         .route("/api/users", Method::POST, Arc::new(|req, state| Box::pin(create_user(req, state))));
//!
//!     let addr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
//!     serve(addr, router).await
//! }
//! ```
//!
//! Because the router implements Tower's [`Service`] trait, it
//! works seamlessly with middleware stacks built using `tower::ServiceBuilder`.

use std::collections::HashMap;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use bytes::Bytes;
use http_body_util::{combinators::BoxBody as HttpBoxBody, Full};
use hyper::{Method, Request, Response, StatusCode};
use matchit::Router as MatchitRouter;
use tower::Service;

use crate::BoxError;

/// Error type used by boxed bodies produced by the router.
///
/// This is an alias for [`hyper::Error`]; it is the error parameter of
/// [`BoxBody`].
pub type BodyError = hyper::Error;

/// Boxed body type leveraged by Hyperlite handlers and responses.
///
/// Alias for `http_body_util`'s boxed body combinator yielding [`Bytes`] data
/// frames and failing with [`BodyError`]. Handlers receive their request as
/// `Request<BoxBody>`.
pub type BoxBody = HttpBoxBody<Bytes, BodyError>;

/// Shared handler type used by the router.
///
/// A handler is a reference-counted, thread-safe closure that takes the
/// incoming `Request<BoxBody>` together with the shared application state
/// (`Arc<S>`) and returns a boxed, pinned, `Send` future.
///
/// The future resolves to either a fully-buffered response
/// (`Response<Full<Bytes>>`) or a [`BoxError`]. When routed through the
/// router's `Service` implementation, an `Err` is turned into a JSON
/// `500 Internal Server Error` response rather than propagated.
///
/// Plain `async fn`s are adapted with a small closure, e.g.
/// `Arc::new(|req, state| Box::pin(my_handler(req, state)))`.
pub type Handler<S> = Arc<
    dyn Fn(
            Request<BoxBody>,
            Arc<S>,
        ) -> Pin<
            Box<dyn Future<Output = Result<Response<Full<Bytes>>, BoxError>> + Send + 'static>,
        > + Send
        + Sync
        + 'static,
>;

/// Wrapper for path parameters captured during routing.
///
/// The router stores one of these in the request extensions before invoking a
/// handler. The map holds the key/value pairs reported by the underlying
/// `matchit` match; it is empty when the matched route has no parameters.
#[derive(Clone, Debug, Default)]
pub struct PathParams(pub HashMap<String, String>);

/// A single method/handler pairing registered under one path pattern.
///
/// The router keeps a `Vec<Route<S>>` per path so that several HTTP methods
/// can share the same pattern.
#[derive(Clone)]
struct Route<S>
where
    S: Clone + Send + Sync + 'static,
{
    // HTTP method this entry responds to.
    method: Method,
    // Handler invoked when both the path and `method` match.
    handler: Handler<S>,
}

/// Hyperlite router built on top of `matchit` and the Tower `Service` trait.
///
/// Each registered path pattern maps to the list of method-specific routes
/// for that path. The application state is held behind an [`Arc`] and a clone
/// of that `Arc` is handed to every handler invocation.
pub struct Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    // Path pattern -> routes (one per HTTP method) registered for it.
    routes: MatchitRouter<Vec<Route<S>>>,
    // Shared application state passed to handlers.
    state: Arc<S>,
}

impl<S> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Creates a new router instance with the provided application state.
    ///
    /// The state is wrapped in an [`Arc`] once; handlers receive clones of
    /// that `Arc`, not copies of the state itself.
    pub fn new(state: S) -> Self {
        Self {
            routes: MatchitRouter::new(),
            state: Arc::new(state),
        }
    }

    /// Registers `handler` for the given `path` pattern and HTTP `method`.
    ///
    /// * A pattern that is not registered yet gets its own entry, even when
    ///   an existing parameterised pattern would also *match* it as a request
    ///   path (for example `/users/me` registered after a pattern with a
    ///   parameter in the last segment).
    /// * Registering a second method for an already-registered pattern adds
    ///   it to that pattern's method table.
    /// * Registering the same pattern and method twice replaces the earlier
    ///   handler (a warning is logged when the `tracing` feature is enabled).
    ///
    /// Registration never panics. If the pattern is rejected by `matchit` and
    /// no existing route can absorb it, the route is dropped and, with the
    /// `tracing` feature enabled, an error is logged.
    ///
    /// Returns the router so calls can be chained builder-style.
    pub fn route(mut self, path: &str, method: Method, handler: Handler<S>) -> Self {
        #[cfg(feature = "tracing")]
        tracing::debug!(%path, %method, "registering route");

        // Insert first. Looking the pattern up with `at_mut` before inserting
        // performs request-style matching, which would wrongly merge a new
        // static path into an existing parameterised route that matches it.
        // `insert` consumes its value, so hand it cheap clones (the handler
        // is an `Arc`) and keep the originals for the fallback below.
        let fresh = vec![Route {
            method: method.clone(),
            handler: Arc::clone(&handler),
        }];

        if let Err(insert_err) = self.routes.insert(path, fresh) {
            // Insertion was rejected, typically because the pattern already
            // exists. Merge into the entry that the pattern resolves to.
            match self.routes.at_mut(path) {
                Ok(matched) => {
                    if let Some(existing) = matched
                        .value
                        .iter_mut()
                        .find(|route| route.method == method)
                    {
                        #[cfg(feature = "tracing")]
                        tracing::warn!(%path, %method, "overwriting existing route handler");
                        existing.handler = handler;
                    } else {
                        matched.value.push(Route { method, handler });
                    }
                }
                Err(_) => {
                    // Neither insertable nor mergeable: report and drop.
                    #[cfg(feature = "tracing")]
                    tracing::error!(%path, error = ?insert_err, "failed to insert route");
                    #[cfg(not(feature = "tracing"))]
                    let _ = insert_err;
                }
            }
        }

        self
    }
}

impl<S> Clone for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        Self {
            routes: self.routes.clone(),
            state: self.state.clone(),
        }
    }
}

impl<S> Service<Request<BoxBody>> for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        let method = req.method().clone();
        let path = req.uri().path().to_string();
        let state = self.state.clone();

        let mut allowed_methods = None;
        let mut params_map: Option<PathParams> = None;
        let handler = match self.routes.at(path.as_str()) {
            Ok(matched) => {
                params_map = Some(PathParams(
                    matched
                        .params
                        .iter()
                        .map(|(key, value)| (key.to_string(), value.to_string()))
                        .collect(),
                ));
                allowed_methods = Some(
                    matched
                        .value
                        .iter()
                        .map(|route| route.method.clone())
                        .collect::<Vec<_>>(),
                );

                matched
                    .value
                    .iter()
                    .find(|route| route.method == method)
                    .cloned()
            }
            Err(_) => None,
        };

        // DEBUG: print matched params when present to help diagnose path param issues in tests
        if let Some(ref p) = params_map {
            // print to stdout during tests
            println!("[router debug] path='{}' params={:?}", path, p);
        }

        Box::pin(async move {
            // If matchit provided a handler, use it. Insert path params if available.
            let response = match handler {
                Some(route) => {
                    let mut req = req;

                    // First try matchit-provided params
                    if let Some(path_params) = params_map {
                        req.extensions_mut().insert(path_params);
                    }

                    req.extensions_mut().insert(state.clone());
                    match (route.handler)(req, state).await {
                        Ok(response) => response,
                        Err(err) => error_response(err),
                    }
                }
                None => match allowed_methods {
                    Some(methods) => method_not_allowed_response(methods),
                    None => not_found_response(),
                },
            };

            Ok(response)
        })
    }
}

/// Builds the JSON `404 Not Found` response returned when no route matches
/// the request path.
///
/// Falls back to a plain-text response with the same status if the JSON body
/// cannot be serialized or the response cannot be assembled.
fn not_found_response() -> Response<Full<Bytes>> {
    let payload = serde_json::json!({
        "success": false,
        "errors": [{
            "code": "NOT_FOUND",
            "message": "Route not found",
        }]
    });

    serde_json::to_vec(&payload)
        .ok()
        .and_then(|encoded| {
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .header(hyper::header::CONTENT_TYPE, "application/json")
                .body(Full::from(Bytes::from(encoded)))
                .ok()
        })
        .unwrap_or_else(|| fallback_text_response(StatusCode::NOT_FOUND))
}

/// Builds the JSON `405 Method Not Allowed` response for a path that exists
/// but has no handler for the requested method.
///
/// `methods` lists the methods registered for the path; they are joined with
/// `", "` and sent in the `Allow` header. Falls back to a plain-text response
/// with the same status if the body cannot be serialized or the response
/// cannot be assembled.
fn method_not_allowed_response(methods: Vec<Method>) -> Response<Full<Bytes>> {
    let status = StatusCode::METHOD_NOT_ALLOWED;

    let names: Vec<&str> = methods.iter().map(|method| method.as_str()).collect();
    let allow_value = names.join(", ");

    let payload = serde_json::json!({
        "success": false,
        "errors": [{
            "code": "METHOD_NOT_ALLOWED",
            "message": "The requested method is not allowed for this path",
        }]
    });

    serde_json::to_vec(&payload)
        .ok()
        .and_then(|encoded| {
            Response::builder()
                .status(status)
                .header(hyper::header::CONTENT_TYPE, "application/json")
                .header(hyper::header::ALLOW, allow_value)
                .body(Full::from(Bytes::from(encoded)))
                .ok()
        })
        .unwrap_or_else(|| fallback_text_response(status))
}

/// Converts a handler error into a JSON `500 Internal Server Error` response.
///
/// The error is logged when the `tracing` feature is enabled and is otherwise
/// ignored; its details are never included in the response body. Falls back
/// to a plain-text response with the same status if the body cannot be
/// serialized or the response cannot be assembled.
fn error_response(err: BoxError) -> Response<Full<Bytes>> {
    #[cfg(feature = "tracing")]
    {
        tracing::error!(error = ?err, "handler error");
    }

    // Without tracing the error is intentionally unused.
    #[cfg(not(feature = "tracing"))]
    let _ = &err;

    let status = StatusCode::INTERNAL_SERVER_ERROR;
    let payload = serde_json::json!({
        "success": false,
        "errors": [{
            "code": "INTERNAL_ERROR",
            "message": "An internal error occurred",
        }]
    });

    serde_json::to_vec(&payload)
        .ok()
        .and_then(|encoded| {
            Response::builder()
                .status(status)
                .header(hyper::header::CONTENT_TYPE, "application/json")
                .body(Full::from(Bytes::from(encoded)))
                .ok()
        })
        .unwrap_or_else(|| fallback_text_response(status))
}

fn fallback_text_response(status: StatusCode) -> Response<Full<Bytes>> {
    Response::builder()
        .status(status)
        .header(hyper::header::CONTENT_TYPE, "text/plain; charset=utf-8")
        .body(Full::from(Bytes::from_static(b"internal server error")))
        .expect("static response")
}