// api_scanner/scanner/websocket.rs

1use async_trait::async_trait;
2use base64::Engine as _;
3use rand::{seq::SliceRandom, RngCore};
4use url::Url;
5
6use crate::{
7    config::Config,
8    error::CapturedError,
9    http_client::HttpClient,
10    reports::{Finding, Severity},
11};
12
13use super::Scanner;
14
15pub struct WebSocketScanner;
16
17impl WebSocketScanner {
18    pub fn new(_config: &Config) -> Self {
19        Self
20    }
21}
22
/// Well-known WebSocket mount points probed on every target.
///
/// Each entry is appended verbatim to the target's origin, so every entry
/// starts with `/` and may carry its own query string.
const WS_PATHS: &[&str] = &[
    "/ws",
    "/websocket",
    "/socket",
    // Engine.IO (used by Socket.IO) takes its protocol revision and the
    // requested transport as query parameters.
    "/socket.io/?EIO=4&transport=websocket",
    "/graphql",
];

31fn random_cross_origin_probe() -> &'static str {
32    const ORIGINS: &[&str] = &[
33        "https://app.example.net",
34        "https://cdn.example.net",
35        "https://portal.example.org",
36    ];
37    let mut rng = rand::thread_rng();
38    ORIGINS
39        .choose(&mut rng)
40        .copied()
41        .unwrap_or("https://app.example.net")
42}
43
44#[async_trait]
45impl Scanner for WebSocketScanner {
46    fn name(&self) -> &'static str {
47        "websocket"
48    }
49
50    async fn scan(
51        &self,
52        url: &str,
53        client: &HttpClient,
54        config: &Config,
55    ) -> (Vec<Finding>, Vec<CapturedError>) {
56        if !config.active_checks {
57            return (Vec::new(), Vec::new());
58        }
59
60        let mut findings = Vec::new();
61        let mut errors = Vec::new();
62
63        let Some((same_origin, candidates)) = websocket_candidates(url) else {
64            return (findings, errors);
65        };
66
67        for candidate in candidates {
68            let cross_origin = random_cross_origin_probe();
69            let same_origin_resp = match websocket_probe(client, &candidate, &same_origin).await {
70                Ok(resp) => Some(resp),
71                Err(e) => {
72                    errors.push(e);
73                    None
74                }
75            };
76            let cross_origin_resp = match websocket_probe(client, &candidate, cross_origin).await {
77                Ok(resp) => Some(resp),
78                Err(e) => {
79                    errors.push(e);
80                    None
81                }
82            };
83
84            if let Some(resp) = same_origin_resp.as_ref() {
85                if is_upgrade_success(resp) {
86                    findings.push(
87                        Finding::new(
88                            &candidate,
89                            "websocket/upgrade-endpoint",
90                            "WebSocket endpoint accepts upgrade",
91                            Severity::Info,
92                            "Endpoint accepted a WebSocket upgrade handshake.",
93                            "websocket",
94                        )
95                        .with_evidence(format!(
96                            "GET {candidate}\nOrigin: {same_origin}\nStatus: {}",
97                            resp.status
98                        ))
99                        .with_remediation(
100                            "Ensure this endpoint enforces authentication and strict message-level authorization.",
101                        ),
102                    );
103                }
104            }
105
106            if let Some(resp) = cross_origin_resp.as_ref() {
107                if is_upgrade_success(resp) {
108                    findings.push(
109                        Finding::new(
110                            &candidate,
111                            "websocket/origin-not-validated",
112                            "WebSocket origin validation may be missing",
113                            Severity::Medium,
114                            "Endpoint accepted WebSocket upgrades for a cross-origin request.",
115                            "websocket",
116                        )
117                        .with_evidence(format!(
118                            "GET {candidate}\nOrigin: {cross_origin}\nStatus: {}\nSec-WebSocket-Accept: {}",
119                            resp.status,
120                            resp.header("sec-websocket-accept").unwrap_or("-")
121                        ))
122                        .with_remediation(
123                            "Validate the Origin header against an allowlist and reject untrusted origins.",
124                        ),
125                    );
126                }
127            }
128        }
129
130        (findings, errors)
131    }
132}
133
134fn websocket_candidates(seed: &str) -> Option<(String, Vec<String>)> {
135    let parsed = Url::parse(seed).ok()?;
136    if parsed.scheme() != "http" && parsed.scheme() != "https" {
137        return None;
138    }
139
140    let host = parsed.host_str()?;
141    let mut origin = format!("{}://{}", parsed.scheme(), host);
142    if let Some(port) = parsed.port() {
143        origin.push(':');
144        origin.push_str(&port.to_string());
145    }
146
147    let mut base = origin.clone();
148    if base.ends_with('/') {
149        base.pop();
150    }
151
152    let mut candidates = Vec::new();
153    for path in WS_PATHS {
154        candidates.push(format!("{base}{path}"));
155    }
156
157    let seed_lower = parsed.path().to_ascii_lowercase();
158    if seed_lower.contains("ws") || seed_lower.contains("socket") {
159        candidates.push(seed.to_string());
160    }
161
162    candidates.sort();
163    candidates.dedup();
164    if candidates.len() > 1 {
165        let mut rng = rand::thread_rng();
166        candidates.shuffle(&mut rng);
167    }
168
169    Some((origin, candidates))
170}
171
172async fn websocket_probe(
173    client: &HttpClient,
174    url: &str,
175    origin: &str,
176) -> Result<crate::http_client::HttpResponse, CapturedError> {
177    let mut key_bytes = [0u8; 16];
178    rand::thread_rng().fill_bytes(&mut key_bytes);
179    let ws_key = base64::engine::general_purpose::STANDARD.encode(key_bytes);
180
181    let headers = vec![
182        ("Connection".to_string(), "Upgrade".to_string()),
183        ("Upgrade".to_string(), "websocket".to_string()),
184        ("Sec-WebSocket-Version".to_string(), "13".to_string()),
185        ("Sec-WebSocket-Key".to_string(), ws_key),
186        ("Origin".to_string(), origin.to_string()),
187    ];
188
189    client.get_with_headers(url, &headers).await
190}
191
192fn is_upgrade_success(resp: &crate::http_client::HttpResponse) -> bool {
193    if resp.status == 101 {
194        return true;
195    }
196
197    resp.header("sec-websocket-accept").is_some()
198}