apihunter 0.1.2

Async API security scanner with passive and active checks for CORS, CSP, GraphQL, JWT, OpenAPI, and API posture.
Documentation
use async_trait::async_trait;
use base64::Engine as _;
use rand::{seq::SliceRandom, RngCore};
use url::Url;

use crate::{
    config::Config,
    error::CapturedError,
    http_client::HttpClient,
    reports::{Finding, Severity},
};

use super::Scanner;

pub struct WebSocketScanner;

impl WebSocketScanner {
    pub fn new(_config: &Config) -> Self {
        Self
    }
}

static WS_PATHS: &[&str] = &[
    "/ws",
    "/websocket",
    "/socket",
    "/socket.io/?EIO=4&transport=websocket",
    "/graphql",
];

fn random_cross_origin_probe() -> &'static str {
    const ORIGINS: &[&str] = &[
        "https://app.example.net",
        "https://cdn.example.net",
        "https://portal.example.org",
    ];
    let mut rng = rand::thread_rng();
    ORIGINS
        .choose(&mut rng)
        .copied()
        .unwrap_or("https://app.example.net")
}

#[async_trait]
impl Scanner for WebSocketScanner {
    fn name(&self) -> &'static str {
        "websocket"
    }

    async fn scan(
        &self,
        url: &str,
        client: &HttpClient,
        config: &Config,
    ) -> (Vec<Finding>, Vec<CapturedError>) {
        if !config.active_checks {
            return (Vec::new(), Vec::new());
        }

        let mut findings = Vec::new();
        let mut errors = Vec::new();

        let Some((same_origin, candidates)) = websocket_candidates(url) else {
            return (findings, errors);
        };

        for candidate in candidates {
            let cross_origin = random_cross_origin_probe();
            let same_origin_resp = match websocket_probe(client, &candidate, &same_origin).await {
                Ok(resp) => Some(resp),
                Err(e) => {
                    errors.push(e);
                    None
                }
            };
            let cross_origin_resp = match websocket_probe(client, &candidate, cross_origin).await {
                Ok(resp) => Some(resp),
                Err(e) => {
                    errors.push(e);
                    None
                }
            };

            if let Some(resp) = same_origin_resp.as_ref() {
                if is_upgrade_success(resp) {
                    findings.push(
                        Finding::new(
                            &candidate,
                            "websocket/upgrade-endpoint",
                            "WebSocket endpoint accepts upgrade",
                            Severity::Info,
                            "Endpoint accepted a WebSocket upgrade handshake.",
                            "websocket",
                        )
                        .with_evidence(format!(
                            "GET {candidate}\nOrigin: {same_origin}\nStatus: {}",
                            resp.status
                        ))
                        .with_remediation(
                            "Ensure this endpoint enforces authentication and strict message-level authorization.",
                        ),
                    );
                }
            }

            if let Some(resp) = cross_origin_resp.as_ref() {
                if is_upgrade_success(resp) {
                    findings.push(
                        Finding::new(
                            &candidate,
                            "websocket/origin-not-validated",
                            "WebSocket origin validation may be missing",
                            Severity::Medium,
                            "Endpoint accepted WebSocket upgrades for a cross-origin request.",
                            "websocket",
                        )
                        .with_evidence(format!(
                            "GET {candidate}\nOrigin: {cross_origin}\nStatus: {}\nSec-WebSocket-Accept: {}",
                            resp.status,
                            resp.header("sec-websocket-accept").unwrap_or("-")
                        ))
                        .with_remediation(
                            "Validate the Origin header against an allowlist and reject untrusted origins.",
                        ),
                    );
                }
            }
        }

        (findings, errors)
    }
}

fn websocket_candidates(seed: &str) -> Option<(String, Vec<String>)> {
    let parsed = Url::parse(seed).ok()?;
    if parsed.scheme() != "http" && parsed.scheme() != "https" {
        return None;
    }

    let host = parsed.host_str()?;
    let mut origin = format!("{}://{}", parsed.scheme(), host);
    if let Some(port) = parsed.port() {
        origin.push(':');
        origin.push_str(&port.to_string());
    }

    let mut base = origin.clone();
    if base.ends_with('/') {
        base.pop();
    }

    let mut candidates = Vec::new();
    for path in WS_PATHS {
        candidates.push(format!("{base}{path}"));
    }

    let seed_lower = parsed.path().to_ascii_lowercase();
    if seed_lower.contains("ws") || seed_lower.contains("socket") {
        candidates.push(seed.to_string());
    }

    candidates.sort();
    candidates.dedup();
    if candidates.len() > 1 {
        let mut rng = rand::thread_rng();
        candidates.shuffle(&mut rng);
    }

    Some((origin, candidates))
}

async fn websocket_probe(
    client: &HttpClient,
    url: &str,
    origin: &str,
) -> Result<crate::http_client::HttpResponse, CapturedError> {
    let mut key_bytes = [0u8; 16];
    rand::thread_rng().fill_bytes(&mut key_bytes);
    let ws_key = base64::engine::general_purpose::STANDARD.encode(key_bytes);

    let headers = vec![
        ("Connection".to_string(), "Upgrade".to_string()),
        ("Upgrade".to_string(), "websocket".to_string()),
        ("Sec-WebSocket-Version".to_string(), "13".to_string()),
        ("Sec-WebSocket-Key".to_string(), ws_key),
        ("Origin".to_string(), origin.to_string()),
    ];

    client.get_with_headers(url, &headers).await
}

fn is_upgrade_success(resp: &crate::http_client::HttpResponse) -> bool {
    if resp.status == 101 {
        return true;
    }

    resp.header("sec-websocket-accept").is_some()
}