hyperlite 0.1.0

Lightweight HTTP framework built on hyper, tokio, and tower
Documentation
use std::net::{SocketAddr, TcpListener};
use std::sync::Arc;
use std::time::Duration;

use bytes::Bytes;
use http::header::HeaderValue;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::client::conn::http1;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use serde::Deserialize;
use serde_json::json;
use tokio::net::TcpStream;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tower::{Service, ServiceBuilder};
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};

use hyperlite::{parse_json_body, serve, success, BoxBody, BoxError, Router};

mod test_helpers;
use test_helpers::*;

/// Picks a loopback address with an OS-assigned free port.
///
/// A throwaway listener is bound to `127.0.0.1:0`, its concrete address is read back, and the
/// listener is released when the function returns so the caller can bind the same address.
/// Nothing reserves the port in between, so another process could in principle claim it first.
///
/// # Panics
///
/// Panics if the listener cannot be bound or its local address cannot be queried.
fn available_addr() -> SocketAddr {
    let wildcard_port = SocketAddr::from(([127, 0, 0, 1], 0));
    let probe = TcpListener::bind(wildcard_port).expect("bind test listener");
    // `probe` is dropped at the end of this scope, which frees the port again.
    probe.local_addr().expect("local addr")
}

/// Starts `serve(addr, service)` on a background Tokio task.
///
/// Returns the task's [`JoinHandle`], which resolves to whatever `serve` returns. The function
/// itself does not wait for the server to become reachable.
async fn spawn_server<S>(addr: SocketAddr, service: S) -> JoinHandle<Result<(), BoxError>>
where
    S: Service<
            Request<BoxBody>,
            Response = Response<Full<Bytes>>,
            Error = std::convert::Infallible,
        > + Clone
        + Send
        + Sync
        + 'static,
    S::Future: Send + 'static,
{
    // Build the server future first, then hand it to the runtime.
    let server = async move { serve(addr, service).await };
    tokio::spawn(server)
}

/// Polls `addr` until a TCP connection succeeds.
///
/// Makes up to ten connection attempts, sleeping 25 ms after each failed one.
///
/// # Panics
///
/// Panics if none of the attempts connects.
async fn wait_for_server(addr: SocketAddr) {
    const MAX_ATTEMPTS: u32 = 10;
    const RETRY_DELAY: Duration = Duration::from_millis(25);

    let mut attempt = 0;
    while attempt < MAX_ATTEMPTS {
        let reachable = tokio::net::TcpStream::connect(addr).await.is_ok();
        if reachable {
            return;
        }
        tokio::time::sleep(RETRY_DELAY).await;
        attempt += 1;
    }

    panic!("server did not start on {addr}");
}

/// Sends `request` to `addr` over a newly opened HTTP/1 connection and returns the response.
///
/// Each call opens its own TCP stream and performs its own client handshake. The connection
/// half is awaited on a spawned task; an error from it is logged when the `tracing` feature is
/// enabled and discarded otherwise.
///
/// # Panics
///
/// Panics if the TCP connect, the handshake, or the request itself fails.
async fn send_request(addr: SocketAddr, request: Request<Full<Bytes>>) -> Response<Incoming> {
    let socket = TcpStream::connect(addr)
        .await
        .expect("failed to connect to server");
    let (mut client, driver) = http1::handshake(TokioIo::new(socket))
        .await
        .expect("client handshake failed");

    // Drive the connection half on its own task so it runs alongside the request.
    tokio::spawn(async move {
        match driver.await {
            Ok(_) => {}
            Err(err) => {
                #[cfg(feature = "tracing")]
                tracing::error!(?err, "client connection error");
                #[cfg(not(feature = "tracing"))]
                let _ = err;
            }
        }
    });

    let outcome = client.send_request(request).await;
    outcome.expect("request should succeed")
}

/// Same as [`send_request`], but gives up after two seconds.
///
/// # Panics
///
/// Panics if the response has not arrived within the time limit, or if [`send_request`] panics.
async fn send_request_with_timeout(
    addr: SocketAddr,
    request: Request<Full<Bytes>>,
) -> Response<Incoming> {
    let limit = Duration::from_secs(2);
    let exchange = send_request(addr, request);
    let outcome = timeout(limit, exchange).await;
    outcome.expect("client request timed out")
}

/// Reads the whole response body and returns it as one contiguous [`Bytes`] buffer.
///
/// # Panics
///
/// Panics if collecting the body fails.
async fn response_bytes(response: Response<Incoming>) -> Bytes {
    let body = response.into_body();
    let collected = body
        .collect()
        .await
        .expect("failed to read response body");
    collected.to_bytes()
}

/// Reads the whole response body and parses it as JSON.
///
/// # Panics
///
/// Panics if the body cannot be read or is not valid JSON.
async fn response_json(response: Response<Incoming>) -> serde_json::Value {
    let raw = response_bytes(response).await;
    let parsed = serde_json::from_slice::<serde_json::Value>(&raw[..]);
    parsed.expect("invalid JSON body")
}

/// JSON request body parsed by the `/json` handler in the body-conversion test.
#[derive(Deserialize)]
struct EchoPayload {
    /// Value the handler returns as the payload of its success response.
    name: String,
}

/// POSTs a JSON document to a running server and checks that the handler can parse it with
/// `parse_json_body` and echo the `name` field back under the response's `data` key.
#[tokio::test]
async fn test_connection_handler_body_conversion() {
    let app = Router::new(()).route(
        "/json",
        Method::POST,
        Arc::new(|req, _| {
            Box::pin(async move {
                let echo = parse_json_body::<EchoPayload>(req).await?;
                Ok(success(StatusCode::OK, echo.name))
            })
        }),
    );

    let addr = available_addr();
    let server = spawn_server(addr, app).await;
    wait_for_server(addr).await;

    let target = format!(
        "http://{ip}:{port}/json",
        ip = addr.ip(),
        port = addr.port()
    );
    let payload = Full::new(Bytes::from_static(br#"{"name":"Alice"}"#));
    let request = Request::post(target)
        .header("content-type", "application/json")
        .body(payload)
        .unwrap();

    let response = send_request_with_timeout(addr, request).await;
    assert_eq!(response.status(), StatusCode::OK);

    let body = response_json(response).await;
    assert_eq!(body["data"], "Alice");

    server.abort();
}

/// Sends one GET request and checks that the routed handler actually ran, observed through the
/// shared `TestState` counter it increments.
#[tokio::test]
async fn test_connection_handler_service_call() {
    let state = TestState::new();
    let app = Router::new(state.clone()).route(
        "/count",
        Method::GET,
        Arc::new(|_, counter| {
            Box::pin(async move {
                counter.increment();
                let current = counter.get();
                Ok(success(StatusCode::OK, current))
            })
        }),
    );

    let addr = available_addr();
    let server = spawn_server(addr, app).await;
    wait_for_server(addr).await;

    let target = format!(
        "http://{ip}:{port}/count",
        ip = addr.ip(),
        port = addr.port()
    );
    let request = Request::get(target)
        .body(Full::new(Bytes::new()))
        .unwrap();

    let response = send_request_with_timeout(addr, request).await;
    assert_eq!(response.status(), StatusCode::OK);
    // The handler ran exactly once, so the shared counter must read 1.
    assert_eq!(state.get(), 1);

    server.abort();
}

/// Checks that a custom request header sent by the client is visible to the handler.
///
/// Despite the test's name, the handler inspects the `x-custom` request header (not request
/// extensions) and echoes its value back under the response's `data` key.
#[tokio::test]
async fn test_connection_handler_extensions_preserved() {
    let app = Router::new(()).route(
        "/headers",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                // A missing or non-text header is reported as an empty string.
                let echoed = match req.headers().get("x-custom").map(|v| v.to_str()) {
                    Some(Ok(text)) => text.to_string(),
                    _ => String::new(),
                };
                Ok(success(StatusCode::OK, echoed))
            })
        }),
    );

    let addr = available_addr();
    let server = spawn_server(addr, app).await;
    wait_for_server(addr).await;

    let target = format!(
        "http://{ip}:{port}/headers",
        ip = addr.ip(),
        port = addr.port()
    );
    let request = Request::get(target)
        .header("x-custom", "value")
        .body(Full::new(Bytes::new()))
        .unwrap();

    let response = send_request_with_timeout(addr, request).await;
    let body = response_json(response).await;
    assert_eq!(body["data"], "value");

    server.abort();
}

/// Issues two sequential requests, each through its own `send_request_with_timeout` call, and
/// checks that both reach the handler: the shared `TestState` counter must end at 2.
#[tokio::test]
async fn test_connection_handler_clone() {
    let state = TestState::new();
    let app = Router::new(state.clone()).route(
        "/clone",
        Method::GET,
        Arc::new(|_, counter| {
            Box::pin(async move {
                counter.increment();
                let current = counter.get();
                Ok(success(StatusCode::OK, current))
            })
        }),
    );

    let addr = available_addr();
    let server = spawn_server(addr, app).await;
    wait_for_server(addr).await;

    // The target never changes between iterations, so build it once.
    let target = format!(
        "http://{ip}:{port}/clone",
        ip = addr.ip(),
        port = addr.port()
    );
    for _ in 0..2 {
        let request = Request::get(target.as_str())
            .body(Full::new(Bytes::new()))
            .unwrap();
        let response = send_request_with_timeout(addr, request).await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    assert_eq!(state.get(), 2);
    server.abort();
}

/// Starts a server on a freshly chosen address and checks that a GET to `/health` on that
/// address is answered with `200 OK`.
#[tokio::test]
async fn test_serve_binds_to_port() {
    let app = Router::new(()).route(
        "/health",
        Method::GET,
        Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "ok")) })),
    );

    let addr = available_addr();
    let server = spawn_server(addr, app).await;
    wait_for_server(addr).await;

    let target = format!(
        "http://{ip}:{port}/health",
        ip = addr.ip(),
        port = addr.port()
    );
    let request = Request::get(target)
        .body(Full::new(Bytes::new()))
        .unwrap();

    let status = send_request_with_timeout(addr, request).await.status();
    assert_eq!(status, StatusCode::OK);

    server.abort();
}

/// Connects to a running server, requests `/ping`, and checks that the JSON response carries
/// `"pong"` under its `data` key.
#[tokio::test]
async fn test_serve_accepts_connections() {
    let app = Router::new(()).route(
        "/ping",
        Method::GET,
        Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "pong")) })),
    );

    let addr = available_addr();
    let server = spawn_server(addr, app).await;
    wait_for_server(addr).await;

    let target = format!(
        "http://{ip}:{port}/ping",
        ip = addr.ip(),
        port = addr.port()
    );
    let request = Request::get(target)
        .body(Full::new(Bytes::new()))
        .unwrap();

    let response = send_request_with_timeout(addr, request).await;
    let body = response_json(response).await;
    assert_eq!(body["data"], "pong");

    server.abort();
}

// Run the ignored integration tests with `cargo test -- --ignored` to exercise full server stacks.

/// Registers two GET routes on one router and checks that each path is dispatched to its own
/// handler by comparing the `data` value of every response.
#[tokio::test]
#[ignore]
async fn test_serve_with_router() {
    let app = Router::new(())
        .route(
            "/one",
            Method::GET,
            Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "one")) })),
        )
        .route(
            "/two",
            Method::GET,
            Arc::new(|_, _| Box::pin(async { Ok(success(StatusCode::OK, "two")) })),
        );

    let addr = available_addr();
    let server = spawn_server(addr, app).await;
    wait_for_server(addr).await;

    // Each route answers with its own name, so the name serves as both path segment and
    // expected payload.
    for name in ["one", "two"] {
        let target = format!("http://{}:{}/{}", addr.ip(), addr.port(), name);
        let request = Request::get(target)
            .body(Full::new(Bytes::new()))
            .unwrap();
        let response = send_request_with_timeout(addr, request).await;
        let body = response_json(response).await;
        assert_eq!(body["data"], name);
    }

    server.abort();
}

#[tokio::test]
#[ignore]
async fn test_serve_with_middleware() {
    let router = Router::new(()).route(
        "/middleware",
        Method::GET,
        Arc::new(|req, _| {
            Box::pin(async move {
                let request_id = req
                    .extensions()
                    .get::<tower_http::request_id::RequestId>()
                    .and_then(|id| id.header_value().to_str().ok())
                    .unwrap_or("")
                    .to_string();
                Ok(success(StatusCode::OK, json!({ "request_id": request_id })))
            })
        }),
    );

    let service = ServiceBuilder::new()
        .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
        .layer(PropagateRequestIdLayer::x_request_id())
        .service(router);

    let addr = available_addr();
    let handle = spawn_server(addr, service).await;
    wait_for_server(addr).await;

    let uri = format!("http://{}:{}/middleware", addr.ip(), addr.port());
    let request = Request::builder()
        .method(Method::GET)
        .uri(uri)
        .body(Full::new(Bytes::new()))
        .unwrap();

    let response = send_request_with_timeout(addr, request).await;
    let header = response
        .headers()
        .get("x-request-id")
        .cloned()
        .unwrap_or_else(|| HeaderValue::from_static(""));
    let bytes = response_bytes(response).await;
    let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
    assert_eq!(
        json["data"]["request_id"].as_str().unwrap(),
        header.to_str().unwrap()
    );

    handle.abort();
}