cheers 0.1.0-alpha.1

Fullstack hypermedia framework for Rust.
use axum::{
    body::Body,
    middleware::Next,
    response::{IntoResponse, Redirect},
};

/// Middleware that redirects requests whose path ends in a trailing slash to the same
/// path without it, keeping the query string intact.
///
/// The request is forwarded to `next` untouched when:
///
/// - the path is exactly `/` or has no trailing slash;
/// - the path consists solely of slashes (e.g. `//`), so nothing is left after trimming;
/// - the trimmed path begins with `//` or `/\`. Browsers resolve such a `Location` value
///   as a scheme-relative URL (a backslash is treated like a forward slash in http(s)
///   URLs), which would turn this middleware into an open redirect to another host.
///
/// In every other case the response is a permanent (308) redirect to the trimmed path,
/// followed by `?query` when the original request carried a query string.
pub async fn redirect_trailing_slash(
    req: axum::http::Request<Body>,
    next: Next,
) -> axum::response::Response {
    let uri = req.uri();
    let path = uri.path();

    // Nothing to strip: the root itself, or a path that is already normalized.
    if path == "/" || !path.ends_with('/') {
        return next.run(req).await;
    }

    // Collapse any number of trailing slashes in one go (`/a///` -> `/a`).
    let path = path.trim_end_matches('/');

    // Refuse to emit a redirect target that a browser could read as pointing at a
    // different origin; let the inner service decide what to do with the request.
    let looks_cross_origin = path.starts_with("//") || path.starts_with("/\\");
    if path.is_empty() || looks_cross_origin {
        return next.run(req).await;
    }

    let uri = if let Some(query) = uri.query() {
        format!("{}?{}", path, query)
    } else {
        path.to_owned()
    };
    Redirect::permanent(&uri).into_response()
}

#[cfg(test)]
mod tests {
    use axum::{Router, http::StatusCode, routing::get};
    use tower::ServiceExt;

    use super::*;

    const ROUTE: &str = "/api/v1/data";
    const RESPONSE: &str = "data";

    const ROOT_RESPONSE: &str = "root";

    /// Builds a router with a root route and one nested route, wrapped in the
    /// middleware under test.
    fn app() -> Router {
        let trailing_slash_layer = axum::middleware::from_fn(redirect_trailing_slash);

        Router::new()
            .route("/", get(async || ROOT_RESPONSE))
            .route(ROUTE, get(async || RESPONSE))
            .layer(trailing_slash_layer)
    }

    /// Finishes a request builder with an empty body.
    fn empty_request(builder: axum::http::request::Builder) -> axum::http::Request<Body> {
        builder
            .body(Body::empty())
            .expect("request should build")
    }

    /// Sends a single request through a fresh instance of the test router.
    async fn dispatch(request: axum::http::Request<Body>) -> axum::response::Response {
        app()
            .oneshot(request)
            .await
            .expect("router should return a response")
    }

    /// Reads the whole response body and decodes it as UTF-8 text.
    async fn body_text(response: axum::response::Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("response body should be readable");

        String::from_utf8(bytes.into()).expect("response body should be valid UTF-8")
    }

    #[tokio::test]
    async fn no_trailing_slash_passes_through() {
        let request = empty_request(axum::http::Request::builder().uri(ROUTE));
        let response = dispatch(request).await;

        assert_eq!(response.status(), StatusCode::OK);

        let got = body_text(response).await;
        assert_eq!(got, RESPONSE);
    }

    #[tokio::test]
    async fn trailing_slash_redirects_permanently() {
        let request = empty_request(axum::http::Request::builder().uri(format!("{ROUTE}/")));
        let response = dispatch(request).await;

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);

        let location = response
            .headers()
            .get("location")
            .expect("redirect response should set location header");
        assert_eq!(location, ROUTE);
    }

    #[tokio::test]
    async fn trailing_slash_redirect_does_not_emit_scheme_relative_location() {
        let request = empty_request(axum::http::Request::builder().uri("//evil.example/"));
        let response = dispatch(request).await;

        // A missing or non-textual header is treated as an empty location here.
        let location = response
            .headers()
            .get("location")
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();

        assert!(
            !location.starts_with("//"),
            "redirect Location must not be browser-interpreted as cross-origin: {location}"
        );
    }

    #[tokio::test]
    async fn can_access_root() {
        // No URI is set on the builder, so the request uses the builder's default target.
        let request = empty_request(axum::http::Request::builder());
        let response = dispatch(request).await;

        assert_eq!(response.status(), StatusCode::OK);

        let got = body_text(response).await;
        assert_eq!(got, ROOT_RESPONSE);
    }
}