hyperlite 0.1.0

Lightweight HTTP framework built on hyper, tokio, and tower
Documentation
//! Production-ready Hyperlite server demonstrating Tower middleware
//! composition, distributed tracing, and request correlation.
//!
//! Run with:
//! ```bash
//! RUST_LOG=info cargo run --example with_middleware
//! ```
//!
//! Handy curl commands:
//! ```bash
//! curl http://127.0.0.1:3000/health
//!
//! curl -v http://127.0.0.1:3000/protected
//!
//! curl -X POST http://127.0.0.1:3000/echo \
//!   -H "Content-Type: application/json" \
//!   -d '{"test":"data"}'
//!
//! curl -X OPTIONS http://127.0.0.1:3000/echo \
//!   -H "Origin: http://localhost:5173" \
//!   -H "Access-Control-Request-Method: POST"
//! ```

use bytes::Bytes;
use http::header::{HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use http_body_util::Full;
use hyper::{Method, Request, Response, StatusCode};
use hyperlite::{parse_json_body, serve, success, BoxBody, BoxError, Router};
use serde::Serialize;
use serde_json::Value;
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
use tower::util::BoxCloneSyncService;
use tower::{Layer, Service, ServiceBuilder};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::request_id::{
    MakeRequestUuid, PropagateRequestIdLayer, RequestId, SetRequestIdLayer,
};
use tracing::{info, info_span, warn, Instrument};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};

/// Shared application state handed to every handler and middleware.
///
/// Clones share the request counter: each clone holds the same
/// `Arc<Mutex<u64>>`, so an increment made through one clone is visible
/// through all of them. `app_name` is copied per clone.
#[derive(Clone)]
struct AppState {
    request_count: Arc<Mutex<u64>>,
    app_name: String,
}

impl AppState {
    /// Builds a fresh state named `name` with the request counter at zero.
    fn new(name: impl Into<String>) -> Self {
        let app_name: String = name.into();
        let request_count = Arc::new(Mutex::new(0));
        Self {
            request_count,
            app_name,
        }
    }
}

/// Tower layer that bumps the shared request counter once per request.
///
/// Every service this layer produces receives a clone of the same
/// `Arc<AppState>`, so they all increment one counter.
#[derive(Clone)]
struct RequestCounterLayer {
    state: Arc<AppState>,
}

impl RequestCounterLayer {
    /// Creates a layer that records requests into `state`'s counter.
    fn new(state: Arc<AppState>) -> Self {
        RequestCounterLayer { state }
    }
}

/// Service produced by [`RequestCounterLayer`]: increments the shared counter,
/// then forwards the request unchanged to `inner`.
#[derive(Clone)]
struct RequestCounterService<S> {
    state: Arc<AppState>,
    inner: S,
}

impl<S> Layer<S> for RequestCounterLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Service = RequestCounterService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        RequestCounterService {
            inner,
            state: self.state.clone(),
        }
    }
}

/// Counts every request that reaches this service, then delegates to `inner`.
///
/// The inner future is returned as-is (`S::Future`) instead of being boxed.
/// The counter is updated synchronously in `call`, so nothing needs to run
/// after the inner future resolves, and skipping the box saves one heap
/// allocation per request.
impl<S> Service<Request<BoxBody>> for RequestCounterService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Increments the shared counter and forwards `req` to the inner service.
    ///
    /// # Panics
    ///
    /// Panics if the counter mutex is poisoned.
    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        // Scope the guard so the lock is released before the inner service runs.
        {
            let mut total = self
                .state
                .request_count
                .lock()
                .expect("request counter poisoned");
            *total += 1;
        }

        self.inner.call(req)
    }
}

/// Health check response payload exposing request statistics.
///
/// Built by `health_handler` and handed to `success` for serialization.
/// `#[derive(Serialize)]` emits fields in declaration order, so reordering the
/// fields below changes the key order of the response body.
#[derive(Serialize)]
struct HealthResponse {
    /// Liveness marker; `health_handler` always sets this to `"ok"`.
    status: &'static str,
    /// Snapshot of the shared request counter taken when the handler ran.
    total_requests: u64,
    /// Service display name, copied from `AppState::app_name`.
    service: String,
}

/// Protected response showing the request identifier issued by middleware.
///
/// NOTE(review): nothing in this example checks credentials; "protected" only
/// describes the route name.
#[derive(Serialize)]
struct ProtectedResponse {
    /// Fixed confirmation text set by `protected_handler`.
    message: &'static str,
    /// The request's `RequestId` extension rendered as a string; `None` when
    /// the extension is absent or its header value fails `to_str()`.
    request_id: Option<String>,
}

/// Echo response that returns the received JSON payload.
#[derive(Serialize)]
struct EchoResponse {
    /// The request body as parsed by `parse_json_body::<Value>`, returned
    /// unchanged as a JSON value.
    received: Value,
}

/// Reports service liveness plus the number of requests counted so far.
///
/// The incoming request is ignored; only `state` is consulted.
///
/// # Panics
///
/// Panics if the request counter mutex is poisoned.
async fn health_handler(
    _req: Request<BoxBody>,
    state: Arc<AppState>,
) -> Result<Response<Full<Bytes>>, BoxError> {
    // Copy the count out inside a block so the guard is dropped immediately.
    let total_requests = {
        let guard = state
            .request_count
            .lock()
            .expect("request counter poisoned");
        *guard
    };

    Ok(success(
        StatusCode::OK,
        HealthResponse {
            status: "ok",
            total_requests,
            service: state.app_name.clone(),
        },
    ))
}

async fn protected_handler(
    req: Request<BoxBody>,
    _state: Arc<AppState>,
) -> Result<Response<Full<Bytes>>, BoxError> {
    let request_id = req
        .extensions()
        .get::<RequestId>()
        .and_then(|id| id.header_value().to_str().ok())
        .map(|value| value.to_string());

    if request_id.is_none() {
        warn!("request id missing from protected handler");
    }

    let payload = ProtectedResponse {
        message: "Authenticated request received",
        request_id,
    };

    Ok(success(StatusCode::OK, payload))
}

/// Parses the request body as arbitrary JSON and sends it back under `received`.
///
/// # Errors
///
/// Propagates the error from `parse_json_body` when the body cannot be parsed.
async fn echo_handler(
    req: Request<BoxBody>,
    _state: Arc<AppState>,
) -> Result<Response<Full<Bytes>>, BoxError> {
    let received = parse_json_body::<Value>(req).await?;
    Ok(success(StatusCode::OK, EchoResponse { received }))
}

/// Tower layer that wraps a service in this example's full middleware stack
/// (request IDs, tracing, CORS, request counting) and type-erases the result.
#[derive(Clone)]
struct MiddlewareStack {
    state: Arc<AppState>,
}

impl MiddlewareStack {
    /// Creates the stack; `state` supplies the counter used by the counting layer.
    fn new(state: Arc<AppState>) -> Self {
        MiddlewareStack { state }
    }
}

/// Stateless layer that wraps a service in [`RequestTracingService`].
#[derive(Default, Clone)]
struct RequestTracingLayer;

/// Service that logs each request's outcome and latency after `inner` responds.
#[derive(Clone)]
struct RequestTracingService<S> {
    inner: S,
}

impl<S> Layer<S> for MiddlewareStack
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    type Service = BoxCloneSyncService<Request<BoxBody>, Response<Full<Bytes>>, Infallible>;

    fn layer(&self, inner: S) -> Self::Service {
        let cors_layer = CorsLayer::new()
            .allow_origin(AllowOrigin::list(vec![
                HeaderValue::from_static("http://localhost:3000"),
                HeaderValue::from_static("http://localhost:5173"),
            ]))
            .allow_methods([
                Method::GET,
                Method::POST,
                Method::PUT,
                Method::DELETE,
                Method::OPTIONS,
            ])
            .allow_headers([ACCEPT, AUTHORIZATION, CONTENT_TYPE])
            .allow_credentials(true);

        let request_id_layer = SetRequestIdLayer::x_request_id(MakeRequestUuid);
        let propagate_layer = PropagateRequestIdLayer::x_request_id();
        let trace_layer = RequestTracingLayer;

        let service = ServiceBuilder::new()
            // Request IDs must be generated before any other middleware executes.
            .layer(request_id_layer)
            // After the ID exists, ensure it flows back to the client.
            .layer(propagate_layer)
            // Tracing captures spans enriched with the request ID.
            .layer(trace_layer)
            // CORS negotiates browser requests before hitting business logic.
            .layer(cors_layer)
            // Finally, update shared request metrics.
            .layer(RequestCounterLayer::new(self.state.clone()))
            .service(inner);

        BoxCloneSyncService::new(service)
    }
}

impl<S> Layer<S> for RequestTracingLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Service = RequestTracingService<S>;

    /// Wraps `inner`; the layer carries no configuration of its own.
    fn layer(&self, inner: S) -> Self::Service {
        RequestTracingService::<S> { inner }
    }
}

impl<S> Service<Request<BoxBody>> for RequestTracingService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        let method = req.method().clone();
        let path = req.uri().path().to_owned();
        let request_id = req
            .headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
            .map(|value| value.to_string());
        let start = Instant::now();

        let future = self.inner.call(req);
        Box::pin(async move {
            let response = future.await?;
            let status = response.status();
            let elapsed_ms = start.elapsed().as_millis();

            if let Some(id) = request_id {
                info!(%method, %path, %status, elapsed_ms, %id, "request completed");
            } else {
                info!(%method, %path, %status, elapsed_ms, "request completed");
            }

            Ok(response)
        })
    }
}

fn build_middleware_stack(state: Arc<AppState>) -> MiddlewareStack {
    MiddlewareStack::new(state)
}

/// Entry point: installs the tracing subscriber, builds the router and its
/// middleware stack, and serves on `127.0.0.1:3000`.
///
/// Log filtering comes from `RUST_LOG`, falling back to `info` when the
/// variable cannot be used.
///
/// # Errors
///
/// Returns whatever error `serve` reports.
#[tokio::main]
async fn main() -> Result<(), BoxError> {
    tracing_subscriber::registry()
        .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
        .with(tracing_subscriber::fmt::layer())
        .init();

    // `AppState` clones share one `Arc<Mutex<u64>>` counter, so the router's copy
    // and the middleware's copy observe the same request count.
    let base_state = AppState::new("Hyperlite Middleware Demo");
    let state = Arc::new(base_state.clone());

    let router = Router::new(base_state)
        .route(
            "/health",
            Method::GET,
            Arc::new(|req, state| Box::pin(health_handler(req, state))),
        )
        .route(
            "/protected",
            Method::GET,
            Arc::new(|req, state| Box::pin(protected_handler(req, state))),
        )
        .route(
            "/echo",
            Method::POST,
            Arc::new(|req, state| Box::pin(echo_handler(req, state))),
        );

    // `state` is not needed afterwards, so hand it over instead of cloning it.
    let service = build_middleware_stack(state).layer(router);

    // Built from parts: infallible, unlike parsing a string literal at runtime.
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    info!("Middleware example running on http://{addr}");

    serve(addr, service).await
}