rjango 0.1.1

A full-stack Rust backend framework inspired by Django
Documentation
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use axum::{
    body::Body,
    http::{
        Request, StatusCode,
        header::{HOST, HeaderName, HeaderValue, LOCATION, STRICT_TRANSPORT_SECURITY},
    },
    response::Response,
};
use tower::{Layer, Service};

// Heap-allocated, type-erased, `Send` future resolving to either a `Response`
// or the error type `E`; used as the `Future` associated type of the
// middleware's `Service` implementation so both of its code paths (early
// redirect and wrapped inner call) can share one return type.
type BoxResponseFuture<E> = Pin<Box<dyn Future<Output = Result<Response, E>> + Send>>;

/// Reports whether the request is marked as having arrived over HTTPS.
///
/// The only signal consulted is the `x-forwarded-proto` header: the request is
/// considered secure when that header is present, readable as text, and equal
/// to `https` ignoring ASCII case. Any other situation yields `false`.
fn request_is_secure(request: &Request<Body>) -> bool {
    let forwarded_proto = HeaderName::from_static("x-forwarded-proto");

    match request.headers().get(forwarded_proto) {
        Some(raw) => matches!(raw.to_str(), Ok(proto) if proto.eq_ignore_ascii_case("https")),
        None => false,
    }
}

/// Builds the absolute `https://` URL an insecure request should be sent to.
///
/// The host is taken from the `Host` header when it is present, readable as
/// text, and non-empty. Otherwise the authority component of the request URI
/// is used (host plus explicit port, without any userinfo), which covers
/// requests whose target is in absolute form and therefore may not carry a
/// usable `Host` header.
///
/// The path and query of the original request are preserved; a request
/// without either is redirected to `/`.
///
/// Returns `None` when no host can be determined from either source.
fn redirect_target(request: &Request<Body>) -> Option<String> {
    let header_host = request
        .headers()
        .get(HOST)
        .and_then(|value| value.to_str().ok())
        // An empty `Host` value would otherwise produce `https:///...`.
        .filter(|host| !host.is_empty());

    let host = match header_host {
        Some(host) => host.to_owned(),
        None => {
            let authority = request.uri().authority()?;
            // Rebuild from host/port so credentials embedded in the authority
            // are never echoed back in the `Location` header.
            match authority.port_u16() {
                Some(port) => format!("{}:{port}", authority.host()),
                None => authority.host().to_owned(),
            }
        }
    };

    let path_and_query = request
        .uri()
        .path_and_query()
        .map_or("/", |value| value.as_str());

    Some(format!("https://{host}{path_and_query}"))
}

/// Configuration for the security middleware, applied to a service through
/// the `Layer` implementation for this type.
#[derive(Debug, Clone)]
pub struct SecurityMiddlewareLayer {
    /// `max-age` value, in seconds, of the `Strict-Transport-Security`
    /// response header. The header is not added when this is `0`.
    pub hsts_seconds: u64,
    /// When `true`, responses get an `x-content-type-options: nosniff` header.
    pub content_type_nosniff: bool,
    /// When `true`, requests that are not marked as HTTPS are answered with a
    /// permanent redirect to the `https://` form of the same URL instead of
    /// being passed to the wrapped service.
    pub ssl_redirect: bool,
}

/// Wraps any service `S` in a [`SecurityMiddleware`] carrying a copy of this
/// layer's settings.
impl<S> Layer<S> for SecurityMiddlewareLayer {
    type Service = SecurityMiddleware<S>;

    /// Produces the middleware around `inner`; the layer itself is left
    /// untouched so it can be applied to further services.
    fn layer(&self, inner: S) -> Self::Service {
        let &Self {
            hsts_seconds,
            content_type_nosniff,
            ssl_redirect,
        } = self;

        SecurityMiddleware {
            inner,
            hsts_seconds,
            content_type_nosniff,
            ssl_redirect,
        }
    }
}

/// Service wrapper that enforces the security settings around an inner
/// service `S`.
///
/// Values are produced by [`SecurityMiddlewareLayer`]; the fields mirror that
/// layer's configuration.
#[derive(Debug, Clone)]
pub struct SecurityMiddleware<S> {
    /// The wrapped service that handles requests which are not redirected.
    inner: S,
    /// `max-age` for the `Strict-Transport-Security` header; `0` disables it.
    hsts_seconds: u64,
    /// Whether to add `x-content-type-options: nosniff` to responses.
    content_type_nosniff: bool,
    /// Whether requests not marked as HTTPS are redirected to `https://`.
    ssl_redirect: bool,
}

/// Applies the configured security policy around the wrapped service.
///
/// * With `ssl_redirect` enabled, a request that `request_is_secure` rejects
///   is answered directly — the inner service is not called — with a
///   `308 Permanent Redirect` to the URL from `redirect_target`. If no usable
///   target can be derived, the answer is `400 Bad Request` rather than a
///   redirect to a malformed URL.
/// * Otherwise the request is forwarded and the response is decorated with
///   `Strict-Transport-Security` (when `hsts_seconds > 0`) and
///   `x-content-type-options: nosniff` (when `content_type_nosniff` is set).
///   Existing values of those headers are replaced.
///
/// Errors from the inner service are passed through unchanged.
impl<S> Service<Request<Body>> for SecurityMiddleware<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = BoxResponseFuture<Self::Error>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<Body>) -> Self::Future {
        // Copy the settings out so the returned future does not borrow `self`.
        let hsts_seconds = self.hsts_seconds;
        let content_type_nosniff = self.content_type_nosniff;
        let ssl_redirect = self.ssl_redirect;

        if ssl_redirect && !request_is_secure(&request) {
            // The target is derived from client-supplied data, so validate it
            // as a header value instead of panicking on a bad one.
            let location = redirect_target(&request)
                .and_then(|target| HeaderValue::from_str(&target).ok());

            let mut response = Response::new(Body::empty());
            match location {
                Some(location) => {
                    *response.status_mut() = StatusCode::PERMANENT_REDIRECT;
                    response.headers_mut().insert(LOCATION, location);
                }
                // Without a host there is nowhere meaningful to redirect to.
                None => *response.status_mut() = StatusCode::BAD_REQUEST,
            }

            return Box::pin(async move { Ok(response) });
        }

        let future = self.inner.call(request);

        Box::pin(async move {
            let mut response = future.await?;

            if hsts_seconds > 0 {
                // A decimal integer after a fixed ASCII prefix is always a
                // valid header value, so this cannot fail.
                let value = HeaderValue::from_str(&format!("max-age={hsts_seconds}"))
                    .expect("HSTS header value should be valid");
                response
                    .headers_mut()
                    .insert(STRICT_TRANSPORT_SECURITY, value);
            }

            if content_type_nosniff {
                response.headers_mut().insert(
                    HeaderName::from_static("x-content-type-options"),
                    HeaderValue::from_static("nosniff"),
                );
            }

            Ok(response)
        })
    }
}

#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use axum::http::{Request, StatusCode, header};
    use tower::{ServiceExt, service_fn};

    /// Inner handler shared by the async tests: always answers `200 OK` with
    /// the body `ok`.
    async fn ok_handler(_request: Request<Body>) -> Result<Response, Infallible> {
        Ok(Response::new(Body::from("ok")))
    }

    /// With redirects off, the configured HSTS and nosniff headers are added
    /// to the inner service's response.
    #[tokio::test]
    async fn security_layer_adds_configured_headers() {
        let request = Request::builder()
            .uri("/")
            .body(Body::empty())
            .expect("request should build");

        let service = SecurityMiddlewareLayer {
            hsts_seconds: 31_536_000,
            content_type_nosniff: true,
            ssl_redirect: false,
        }
        .layer(service_fn(ok_handler));

        let response = service
            .oneshot(request)
            .await
            .expect("service should respond");
        let headers = response.headers();

        assert_eq!(response.status(), StatusCode::OK);

        let hsts = headers
            .get(STRICT_TRANSPORT_SECURITY)
            .expect("HSTS header should be present");
        assert_eq!(hsts, "max-age=31536000");

        let nosniff = headers
            .get("x-content-type-options")
            .expect("nosniff header should be present");
        assert_eq!(nosniff, "nosniff");
    }

    /// With redirects on, a request lacking the HTTPS marker is answered with
    /// a permanent redirect that keeps host, path and query.
    #[tokio::test]
    async fn security_layer_redirects_insecure_requests_when_enabled() {
        let request = Request::builder()
            .uri("/dashboard?tab=security")
            .header(header::HOST, "example.com")
            .body(Body::empty())
            .expect("request should build");

        let service = SecurityMiddlewareLayer {
            hsts_seconds: 0,
            content_type_nosniff: false,
            ssl_redirect: true,
        }
        .layer(service_fn(ok_handler));

        let response = service
            .oneshot(request)
            .await
            .expect("service should respond");

        assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);

        let location = response
            .headers()
            .get(header::LOCATION)
            .expect("redirect location should be present");
        assert_eq!(location, "https://example.com/dashboard?tab=security");
    }

    /// Compile-time style check that the layer can be cloned.
    #[test]
    fn security_layer_is_cloneable() {
        let original = SecurityMiddlewareLayer {
            hsts_seconds: 0,
            content_type_nosniff: false,
            ssl_redirect: true,
        };

        let _ = Clone::clone(&original);
    }
}