camber 0.1.6

Opinionated async Rust for IO-bound services on top of Tokio
Documentation
mod common;

use camber::http::{Request, Response, Router, cors};
use camber::runtime;
use std::io::{Read, Write};

/// Sends `request` verbatim over a fresh TCP connection to `addr` and returns
/// everything the peer writes back, as a `String`.
///
/// The read runs until end-of-stream, so the request should make the server
/// close the connection (the tests in this file send `Connection: close`).
///
/// # Panics
///
/// Panics if connecting, writing, or reading fails, or if the response is not
/// valid UTF-8.
fn send_raw(addr: std::net::SocketAddr, request: &str) -> String {
    let mut conn = std::net::TcpStream::connect(addr).unwrap();
    Write::write_all(&mut conn, request.as_bytes()).unwrap();

    let mut reply = String::new();
    Read::read_to_string(&mut conn, &mut reply).unwrap();
    reply
}

/// Looks up a header value in a raw HTTP/1.x message.
///
/// Only the header section (the text before the first `\r\n\r\n`) is
/// searched, so body content that happens to look like a header line is never
/// matched. Header names are compared ASCII case-insensitively and the first
/// matching line wins.
///
/// Whitespace around a field value is optional in HTTP, so the line is split
/// at the first `:` and spaces/tabs surrounding the value are stripped:
/// `Name: value`, `Name:value` and `Name:  value ` all yield `value`.
///
/// Returns `None` when no header line carries the requested name.
fn find_header<'a>(raw: &'a str, name: &str) -> Option<&'a str> {
    let header_section = raw.split("\r\n\r\n").next()?;
    header_section
        .split("\r\n")
        // Lines without a colon (e.g. the status line) are not headers.
        .filter_map(|line| line.split_once(':'))
        .find(|(key, _)| key.eq_ignore_ascii_case(name))
        // Strip only SP / HTAB: that is the optional whitespace HTTP permits here.
        .map(|(_, value)| value.trim_matches(|c: char| c == ' ' || c == '\t'))
}

/// A GET whose `Origin` is on the allow-list must return 200 and carry that
/// origin in `Access-Control-Allow-Origin`.
#[test]
fn cors_adds_origin_header_for_allowed_origin() {
    let outcome = common::test_runtime().run(|| {
        let mut app = Router::new();
        app.use_middleware(cors::allow_origins(&["https://example.com"]));
        app.get("/hello", |_req: &Request| async { Response::text(200, "ok") });
        let server_addr = common::spawn_server(app);

        let request = "GET /hello HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, request);

        assert!(raw.starts_with("HTTP/1.1 200"));
        let allow_origin = find_header(&raw, "access-control-allow-origin");
        assert_eq!(allow_origin, Some("https://example.com"));

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}

/// A GET from an origin that is not on the allow-list is still served (200),
/// but the response must not contain `Access-Control-Allow-Origin`.
#[test]
fn cors_rejects_disallowed_origin() {
    let outcome = common::test_runtime().run(|| {
        let mut app = Router::new();
        app.use_middleware(cors::allow_origins(&["https://example.com"]));
        app.get("/hello", |_req: &Request| async { Response::text(200, "ok") });
        let server_addr = common::spawn_server(app);

        let request = "GET /hello HTTP/1.1\r\nHost: localhost\r\nOrigin: https://evil.com\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, request);

        assert!(raw.starts_with("HTTP/1.1 200"));
        let allow_origin = find_header(&raw, "access-control-allow-origin");
        assert!(
            allow_origin.is_none(),
            "should not have ACAO header for disallowed origin",
        );

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}

/// An `OPTIONS` preflight from an allowed origin must be answered with 204,
/// the matching `Access-Control-Allow-Origin`, and some
/// `Access-Control-Allow-Methods` header, even though only a GET route is
/// registered for the path.
#[test]
fn cors_handles_preflight_options() {
    let outcome = common::test_runtime().run(|| {
        let mut app = Router::new();
        app.use_middleware(cors::allow_origins(&["https://example.com"]));
        app.get("/api", |_req: &Request| async { Response::text(200, "data") });
        let server_addr = common::spawn_server(app);

        let preflight = "OPTIONS /api HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, preflight);

        assert!(
            raw.starts_with("HTTP/1.1 204"),
            "preflight should return 204, got: {raw}",
        );

        let allow_origin = find_header(&raw, "access-control-allow-origin");
        assert_eq!(allow_origin, Some("https://example.com"));

        let allow_methods = find_header(&raw, "access-control-allow-methods");
        assert!(
            allow_methods.is_some(),
            "preflight should include allow-methods header",
        );

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}

/// Middleware built through `cors::builder()` with explicit methods and a
/// max-age must report exactly those values on a preflight response.
#[test]
fn cors_builder_customizes_methods_and_max_age() {
    let outcome = common::test_runtime().run(|| {
        let policy = cors::builder()
            .origins(&["https://example.com"])
            .methods(&["GET", "POST"])
            .max_age(7200)
            .build();

        let mut app = Router::new();
        app.use_middleware(policy);
        app.get("/api", |_req: &Request| async { Response::text(200, "data") });
        let server_addr = common::spawn_server(app);

        let preflight = "OPTIONS /api HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, preflight);

        assert!(raw.starts_with("HTTP/1.1 204"));

        let allow_methods = find_header(&raw, "access-control-allow-methods");
        assert_eq!(allow_methods, Some("GET, POST"));

        let max_age = find_header(&raw, "access-control-max-age");
        assert_eq!(max_age, Some("7200"));

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}

/// When the origin list holds both an exact origin and `"*"` (and no
/// credentials option is set on the builder), a request from the exact origin
/// must be answered with `Access-Control-Allow-Origin: *`.
#[test]
fn cors_wildcard_takes_precedence_over_exact_origin_when_credentials_disabled() {
    let outcome = common::test_runtime().run(|| {
        let policy = cors::builder()
            .origins(&["https://example.com", "*"])
            .build();

        let mut app = Router::new();
        app.use_middleware(policy);
        app.get("/hello", |_req: &Request| async { Response::text(200, "ok") });
        let server_addr = common::spawn_server(app);

        let request = "GET /hello HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, request);

        assert!(raw.starts_with("HTTP/1.1 200"));
        let allow_origin = find_header(&raw, "access-control-allow-origin");
        assert_eq!(allow_origin, Some("*"));

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}

/// CORS headers must also be attached to responses that come from a proxied
/// backend rather than from a local handler.
///
/// Two servers are started: a backend serving `/data`, and a front server
/// that applies the CORS middleware and proxies `/api` to the backend. The
/// request goes to the front server at `/api/data`.
#[test]
fn cors_applies_to_proxy_response() {
    common::test_runtime().run(|| {
        let mut backend = Router::new();
        backend.get("/data", |_req: &Request| async { Response::text(200, "proxied-data") });
        let backend_addr = common::spawn_server(backend);

        let mut router = Router::new();
        router.use_middleware(cors::allow_origins(&["https://example.com"]));
        router.proxy("/api", &format!("http://{backend_addr}"));

        let addr = common::spawn_server(router);
        // Plain literal: the request has nothing to interpolate, so wrapping it
        // in `format!` would only allocate a throwaway `String`.
        let raw = send_raw(
            addr,
            "GET /api/data HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nConnection: close\r\n\r\n",
        );

        assert!(
            raw.starts_with("HTTP/1.1 200"),
            "expected 200 proxied response, got: {raw}"
        );
        assert_eq!(
            find_header(&raw, "access-control-allow-origin"),
            Some("https://example.com"),
            "CORS header should be present on proxied response"
        );

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    }).unwrap();
}

/// A CORS response to a normal GET must carry a `Vary` header whose value
/// mentions `Origin`.
#[test]
fn cors_response_includes_vary_origin() {
    let outcome = common::test_runtime().run(|| {
        let mut app = Router::new();
        app.use_middleware(cors::allow_origins(&["https://example.com"]));
        app.get("/hello", |_req: &Request| async { Response::text(200, "ok") });
        let server_addr = common::spawn_server(app);

        let request = "GET /hello HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, request);

        assert!(raw.starts_with("HTTP/1.1 200"));

        let vary = find_header(&raw, "vary").expect("Vary header must be present on CORS response");
        let mentions_origin = vary.contains("Origin");
        assert!(mentions_origin, "Vary must contain Origin, got: {vary}",);

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}

/// A preflight response must carry a `Vary` header that mentions `Origin`,
/// `Access-Control-Request-Method` and `Access-Control-Request-Headers`.
#[test]
fn cors_preflight_includes_vary_headers() {
    let outcome = common::test_runtime().run(|| {
        let mut app = Router::new();
        app.use_middleware(cors::allow_origins(&["https://example.com"]));
        app.get("/api", |_req: &Request| async { Response::text(200, "data") });
        let server_addr = common::spawn_server(app);

        let preflight = "OPTIONS /api HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, preflight);

        assert!(raw.starts_with("HTTP/1.1 204"));

        let vary = find_header(&raw, "vary").expect("Vary header must be present on preflight");

        // Checked one by one, in this order, so the first missing token is the one reported.
        let mentions_origin = vary.contains("Origin");
        assert!(mentions_origin, "Vary must contain Origin, got: {vary}",);

        let mentions_method = vary.contains("Access-Control-Request-Method");
        assert!(
            mentions_method,
            "Vary must contain Access-Control-Request-Method, got: {vary}",
        );

        let mentions_headers = vary.contains("Access-Control-Request-Headers");
        assert!(
            mentions_headers,
            "Vary must contain Access-Control-Request-Headers, got: {vary}",
        );

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}

/// The CORS middleware must coexist with a second, user-supplied middleware:
/// the response has to carry both `Access-Control-Allow-Origin` and the
/// `X-Custom` header added by the second layer.
#[test]
fn cors_composes_with_other_middleware() {
    let outcome = common::test_runtime().run(|| {
        let mut app = Router::new();
        app.use_middleware(cors::allow_origins(&["https://example.com"]));
        // Second layer: run the rest of the chain, then tag the response.
        app.use_middleware(|req, next| {
            let downstream = next.call(req);
            Box::pin(async move { downstream.await.with_header("X-Custom", "present") })
        });
        app.get("/hello", |_req: &Request| async { Response::text(200, "ok") });
        let server_addr = common::spawn_server(app);

        let request = "GET /hello HTTP/1.1\r\nHost: localhost\r\nOrigin: https://example.com\r\nConnection: close\r\n\r\n";
        let raw = send_raw(server_addr, request);

        assert!(raw.starts_with("HTTP/1.1 200"));

        let allow_origin = find_header(&raw, "access-control-allow-origin");
        assert_eq!(allow_origin, Some("https://example.com"));

        let custom = find_header(&raw, "x-custom");
        assert_eq!(custom, Some("present"));

        // Signal the runtime to shut down once every assertion has passed.
        runtime::request_shutdown();
    });
    outcome.unwrap();
}