nyx-scanner 0.6.1

A multi-language static analysis tool for detecting security vulnerabilities
Documentation
use axum::extract::{Request, State};
use axum::http::header::{HOST, ORIGIN};
use axum::http::{Method, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use std::sync::Arc;
use uuid::Uuid;

const CSRF_HEADER: &str = "x-nyx-csrf";

#[derive(Debug)]
pub struct LocalServerSecurity {
    port: u16,
    csrf_token: String,
}

impl LocalServerSecurity {
    pub fn new(port: u16) -> Arc<Self> {
        Arc::new(Self {
            port,
            csrf_token: Uuid::new_v4().as_simple().to_string(),
        })
    }

    pub fn csrf_token(&self) -> &str {
        &self.csrf_token
    }

    fn host_allowed(&self, authority: &str) -> bool {
        let Some((host, port)) = parse_host_like(authority) else {
            return false;
        };

        matches!(host.as_str(), "localhost" | "127.0.0.1" | "::1")
            && port.is_none_or(|value| value == self.port)
    }

    fn origin_allowed(&self, origin: &str) -> bool {
        let Some(rest) = origin.strip_prefix("http://") else {
            return false;
        };

        let authority = rest.split('/').next().unwrap_or(rest);
        self.host_allowed(authority)
    }
}

pub async fn guard_requests(
    State(security): State<Arc<LocalServerSecurity>>,
    request: Request,
    next: Next,
) -> Response {
    let Some(host) = request
        .headers()
        .get(HOST)
        .and_then(|value| value.to_str().ok())
    else {
        return (StatusCode::BAD_REQUEST, "missing Host header").into_response();
    };
    if !security.host_allowed(host) {
        return (StatusCode::BAD_REQUEST, "invalid Host header").into_response();
    }

    if is_mutating_method(request.method()) {
        if let Some(origin) = request
            .headers()
            .get(ORIGIN)
            .and_then(|value| value.to_str().ok())
            && !security.origin_allowed(origin)
        {
            return (StatusCode::FORBIDDEN, "cross-origin mutation blocked").into_response();
        }

        let Some(token) = request
            .headers()
            .get(CSRF_HEADER)
            .and_then(|value| value.to_str().ok())
        else {
            return (StatusCode::FORBIDDEN, "missing CSRF token").into_response();
        };

        if token != security.csrf_token() {
            return (StatusCode::FORBIDDEN, "invalid CSRF token").into_response();
        }
    }

    next.run(request).await
}

fn is_mutating_method(method: &Method) -> bool {
    matches!(
        *method,
        Method::POST | Method::PUT | Method::PATCH | Method::DELETE
    )
}

fn parse_host_like(value: &str) -> Option<(String, Option<u16>)> {
    if value.is_empty() {
        return None;
    }

    if let Some(rest) = value.strip_prefix('[') {
        let end = rest.find(']')?;
        let host = &rest[..end];
        let port = rest[end + 1..]
            .strip_prefix(':')
            .and_then(|p| p.parse().ok());
        return Some((host.to_ascii_lowercase(), port));
    }

    if value.matches(':').count() == 1 {
        let (host, port) = value.rsplit_once(':')?;
        return Some((host.to_ascii_lowercase(), port.parse().ok()));
    }

    Some((value.to_ascii_lowercase(), None))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn host_check_accepts_loopback_names() {
        let security = LocalServerSecurity::new(9700);
        assert!(security.host_allowed("localhost:9700"));
        assert!(security.host_allowed("127.0.0.1:9700"));
        assert!(security.host_allowed("[::1]:9700"));
    }

    #[test]
    fn host_check_rejects_non_loopback_names() {
        let security = LocalServerSecurity::new(9700);
        assert!(!security.host_allowed("evil.example:9700"));
        assert!(!security.host_allowed("192.168.1.10:9700"));
    }

    #[test]
    fn origin_check_requires_loopback_http_origin() {
        let security = LocalServerSecurity::new(9700);
        assert!(security.origin_allowed("http://localhost:9700"));
        assert!(!security.origin_allowed("https://localhost:9700"));
        assert!(!security.origin_allowed("http://evil.example:9700"));
    }
}