rjango 0.1.1

A full-stack Rust backend framework inspired by Django
Documentation
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use axum::{
    body::Body,
    http::{
        Method, Request, StatusCode,
        header::{COOKIE, HeaderMap, HeaderName},
    },
    response::Response,
};
use tower::{Layer, Service};

/// Boxed, `Send` future resolving to an HTTP [`Response`] or the wrapped
/// service's error `E`. Used as `CsrfMiddleware`'s `Service::Future` so that the
/// pass-through path and the `403 Forbidden` path can return one concrete type.
type BoxResponseFuture<E> = Pin<Box<dyn Future<Output = Result<Response, E>> + Send>>;

/// Tower [`Layer`] that wraps a service in `CsrfMiddleware`.
///
/// Requests whose method is not `GET`, `HEAD`, `OPTIONS` or `TRACE` are answered
/// with `403 Forbidden` unless the cookie named `cookie_name` and the request
/// header named `header_name` are both present, non-empty, and equal
/// (a double-submit cookie check).
#[derive(Clone)]
pub struct CsrfMiddlewareLayer {
    /// Name of the cookie carrying the CSRF token (the tests use `"csrftoken"`).
    pub cookie_name: String,
    /// Name of the request header that must repeat the cookie's token (the tests
    /// use `"x-csrftoken"`). If this is not a valid HTTP header name, no request
    /// can supply it, so every unsafe request is rejected.
    pub header_name: String,
}

impl<S> Layer<S> for CsrfMiddlewareLayer {
    type Service = CsrfMiddleware<S>;

    /// Wraps `inner`, handing the middleware its own copies of the configured
    /// cookie and header names.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            cookie_name,
            header_name,
        } = self;

        CsrfMiddleware {
            cookie_name: cookie_name.clone(),
            header_name: header_name.clone(),
            inner,
        }
    }
}

/// Tower service produced by `CsrfMiddlewareLayer`.
///
/// Forwards safe-method requests and requests whose CSRF cookie and header match
/// to `inner`. Every other request is answered with `403 Forbidden` without
/// reaching `inner`.
#[derive(Clone)]
pub struct CsrfMiddleware<S> {
    /// Wrapped service that receives requests which pass the CSRF check.
    inner: S,
    /// Name of the cookie carrying the expected token.
    cookie_name: String,
    /// Name of the request header that must repeat the cookie's token.
    header_name: String,
}

impl<S> Service<Request<Body>> for CsrfMiddleware<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = BoxResponseFuture<Self::Error>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Answers `403 Forbidden` for an unsafe-method request that lacks a
    /// matching CSRF cookie/header pair. Otherwise forwards the request to the
    /// inner service.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        // The token check only runs for unsafe methods (`&&` short-circuits).
        let rejected = !is_safe_method(request.method())
            && !has_matching_token(request.headers(), &self.cookie_name, &self.header_name);

        if rejected {
            // An async block (not `future::ready`) keeps this future `Send`
            // even though `S::Error` itself carries no `Send` bound.
            return Box::pin(async {
                let mut forbidden = Response::new(Body::from("CSRF verification failed"));
                *forbidden.status_mut() = StatusCode::FORBIDDEN;
                Ok(forbidden)
            });
        }

        Box::pin(self.inner.call(request))
    }
}

fn is_safe_method(method: &Method) -> bool {
    matches!(method.as_str(), "GET" | "HEAD" | "OPTIONS" | "TRACE")
}

/// Double-submit check: returns `true` only when the CSRF cookie and the CSRF
/// request header are both present, non-empty, and byte-for-byte equal.
///
/// Every `Cookie` header field is searched, not just the first, because HTTP/2
/// clients may split cookies across several `cookie` fields (RFC 9113 §8.2.3).
/// The first field containing `cookie_name` supplies the cookie value.
///
/// The token comparison does not stop at the first differing byte, so response
/// timing does not reveal how much of a guessed token was correct.
///
/// Returns `false` in these cases:
/// - `header_name` is not a valid HTTP header name.
/// - The relevant header value cannot be read as a string (`to_str` fails).
fn has_matching_token(headers: &HeaderMap, cookie_name: &str, header_name: &str) -> bool {
    let cookie = headers
        .get_all(COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .find_map(|value| cookie_value(value, cookie_name));
    let header = HeaderName::from_bytes(header_name.as_bytes())
        .ok()
        .and_then(|name| headers.get(name))
        .and_then(|value| value.to_str().ok());

    match (cookie, header) {
        // An empty cookie must never validate, even against an empty header.
        (Some(cookie), Some(header)) => {
            !cookie.is_empty() && constant_time_eq(cookie.as_bytes(), header.as_bytes())
        }
        _ => false,
    }
}

/// Compares two byte strings without short-circuiting on the first mismatch.
/// Only a length difference ends the comparison early; the token length is not
/// treated as secret here.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    // `black_box` discourages the optimizer from turning the fold into an early exit.
    std::hint::black_box(diff) == 0
}

/// Looks up `cookie_name` in a single `Cookie` header value of the form
/// `name=value; name2=value2` and returns the first matching value.
///
/// - Pairs are split on `;` and surrounding whitespace is trimmed.
/// - Names are matched exactly (case-sensitive).
/// - Pairs without `=` are skipped.
/// - The value is returned verbatim, so it may be empty or may itself contain `=`.
fn cookie_value<'a>(cookie_header: &'a str, cookie_name: &str) -> Option<&'a str> {
    for pair in cookie_header.split(';') {
        let Some((key, val)) = pair.trim().split_once('=') else {
            continue;
        };
        if key == cookie_name {
            return Some(val);
        }
    }
    None
}

#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;
    use tower::{ServiceExt, service_fn};

    /// Layer configured with the cookie/header names used throughout these tests.
    fn csrf_layer() -> CsrfMiddlewareLayer {
        CsrfMiddlewareLayer {
            cookie_name: "csrftoken".to_string(),
            header_name: "x-csrftoken".to_string(),
        }
    }

    /// Runs `request` through the CSRF layer wrapped around a service that
    /// always answers `200 OK`, and returns the final status code.
    async fn status_for(request: Request<Body>) -> StatusCode {
        let service = csrf_layer().layer(service_fn(|_request: Request<Body>| async move {
            Ok::<_, Infallible>(Response::new(Body::from("ok")))
        }));

        service
            .oneshot(request)
            .await
            .expect("service should respond")
            .status()
    }

    #[tokio::test]
    async fn csrf_layer_allows_safe_methods_without_tokens() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .expect("request should build");

        assert_eq!(status_for(request).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn csrf_layer_rejects_unsafe_methods_without_tokens() {
        let request = Request::builder()
            .method(Method::POST)
            .uri("/")
            .body(Body::empty())
            .expect("request should build");

        assert_eq!(status_for(request).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn csrf_layer_rejects_unsafe_methods_with_mismatched_tokens() {
        let request = Request::builder()
            .method(Method::POST)
            .uri("/")
            .header(COOKIE, "csrftoken=expected")
            .header("x-csrftoken", "actual")
            .body(Body::empty())
            .expect("request should build");

        assert_eq!(status_for(request).await, StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn csrf_layer_allows_unsafe_methods_with_matching_tokens() {
        let request = Request::builder()
            .method(Method::POST)
            .uri("/")
            .header(COOKIE, "csrftoken=match")
            .header("x-csrftoken", "match")
            .body(Body::empty())
            .expect("request should build");

        assert_eq!(status_for(request).await, StatusCode::OK);
    }

    #[test]
    fn cookie_value_extracts_named_cookie_from_multi_cookie_header() {
        assert_eq!(
            cookie_value("sessionid=abc; csrftoken=expected; theme=dark", "csrftoken"),
            Some("expected")
        );
    }

    #[test]
    fn has_matching_token_requires_non_empty_cookie_and_header_values() {
        let headers = HeaderMap::from_iter([
            (COOKIE, http::HeaderValue::from_static("csrftoken=expected")),
            (
                HeaderName::from_static("x-csrftoken"),
                http::HeaderValue::from_static(""),
            ),
        ]);

        assert!(!has_matching_token(&headers, "csrftoken", "x-csrftoken"));
    }

    /// Both values are present and equal, so only the non-empty guard can
    /// reject this pair. The mismatched-pair test above cannot catch a missing guard.
    #[test]
    fn has_matching_token_rejects_equal_but_empty_tokens() {
        let headers = HeaderMap::from_iter([
            (COOKIE, http::HeaderValue::from_static("csrftoken=")),
            (
                HeaderName::from_static("x-csrftoken"),
                http::HeaderValue::from_static(""),
            ),
        ]);

        assert!(!has_matching_token(&headers, "csrftoken", "x-csrftoken"));
    }

    #[test]
    fn csrf_layer_is_cloneable() {
        let layer = csrf_layer();

        let _ = layer.clone();
    }
}