api_scanner/scanner/
websocket.rs1use async_trait::async_trait;
2use base64::Engine as _;
3use rand::{seq::SliceRandom, RngCore};
4use url::Url;
5
6use crate::{
7 config::Config,
8 error::CapturedError,
9 http_client::HttpClient,
10 reports::{Finding, Severity},
11};
12
13use super::Scanner;
14
15pub struct WebSocketScanner;
16
17impl WebSocketScanner {
18 pub fn new(_config: &Config) -> Self {
19 Self
20 }
21}
22
/// Paths where WebSocket endpoints are commonly mounted.
///
/// Each entry is appended verbatim to the target's origin (which has no
/// trailing slash), so every entry starts with `/` and may carry a query string.
static WS_PATHS: &[&str] = &[
    // Generic endpoint names.
    "/ws",
    "/websocket",
    "/socket",
    // Socket.IO endpoint (Engine.IO protocol revision 4), forced onto the
    // websocket transport instead of long-polling.
    "/socket.io/?EIO=4&transport=websocket",
    // GraphQL servers commonly accept subscription traffic over WebSocket here.
    "/graphql",
];
30
31fn random_cross_origin_probe() -> &'static str {
32 const ORIGINS: &[&str] = &[
33 "https://app.example.net",
34 "https://cdn.example.net",
35 "https://portal.example.org",
36 ];
37 let mut rng = rand::thread_rng();
38 ORIGINS
39 .choose(&mut rng)
40 .copied()
41 .unwrap_or("https://app.example.net")
42}
43
44#[async_trait]
45impl Scanner for WebSocketScanner {
46 fn name(&self) -> &'static str {
47 "websocket"
48 }
49
50 async fn scan(
51 &self,
52 url: &str,
53 client: &HttpClient,
54 config: &Config,
55 ) -> (Vec<Finding>, Vec<CapturedError>) {
56 if !config.active_checks {
57 return (Vec::new(), Vec::new());
58 }
59
60 let mut findings = Vec::new();
61 let mut errors = Vec::new();
62
63 let Some((same_origin, candidates)) = websocket_candidates(url) else {
64 return (findings, errors);
65 };
66
67 for candidate in candidates {
68 let cross_origin = random_cross_origin_probe();
69 let same_origin_resp = match websocket_probe(client, &candidate, &same_origin).await {
70 Ok(resp) => Some(resp),
71 Err(e) => {
72 errors.push(e);
73 None
74 }
75 };
76 let cross_origin_resp = match websocket_probe(client, &candidate, cross_origin).await {
77 Ok(resp) => Some(resp),
78 Err(e) => {
79 errors.push(e);
80 None
81 }
82 };
83
84 if let Some(resp) = same_origin_resp.as_ref() {
85 if is_upgrade_success(resp) {
86 findings.push(
87 Finding::new(
88 &candidate,
89 "websocket/upgrade-endpoint",
90 "WebSocket endpoint accepts upgrade",
91 Severity::Info,
92 "Endpoint accepted a WebSocket upgrade handshake.",
93 "websocket",
94 )
95 .with_evidence(format!(
96 "GET {candidate}\nOrigin: {same_origin}\nStatus: {}",
97 resp.status
98 ))
99 .with_remediation(
100 "Ensure this endpoint enforces authentication and strict message-level authorization.",
101 ),
102 );
103 }
104 }
105
106 if let Some(resp) = cross_origin_resp.as_ref() {
107 if is_upgrade_success(resp) {
108 findings.push(
109 Finding::new(
110 &candidate,
111 "websocket/origin-not-validated",
112 "WebSocket origin validation may be missing",
113 Severity::Medium,
114 "Endpoint accepted WebSocket upgrades for a cross-origin request.",
115 "websocket",
116 )
117 .with_evidence(format!(
118 "GET {candidate}\nOrigin: {cross_origin}\nStatus: {}\nSec-WebSocket-Accept: {}",
119 resp.status,
120 resp.header("sec-websocket-accept").unwrap_or("-")
121 ))
122 .with_remediation(
123 "Validate the Origin header against an allowlist and reject untrusted origins.",
124 ),
125 );
126 }
127 }
128 }
129
130 (findings, errors)
131 }
132}
133
134fn websocket_candidates(seed: &str) -> Option<(String, Vec<String>)> {
135 let parsed = Url::parse(seed).ok()?;
136 if parsed.scheme() != "http" && parsed.scheme() != "https" {
137 return None;
138 }
139
140 let host = parsed.host_str()?;
141 let mut origin = format!("{}://{}", parsed.scheme(), host);
142 if let Some(port) = parsed.port() {
143 origin.push(':');
144 origin.push_str(&port.to_string());
145 }
146
147 let mut base = origin.clone();
148 if base.ends_with('/') {
149 base.pop();
150 }
151
152 let mut candidates = Vec::new();
153 for path in WS_PATHS {
154 candidates.push(format!("{base}{path}"));
155 }
156
157 let seed_lower = parsed.path().to_ascii_lowercase();
158 if seed_lower.contains("ws") || seed_lower.contains("socket") {
159 candidates.push(seed.to_string());
160 }
161
162 candidates.sort();
163 candidates.dedup();
164 if candidates.len() > 1 {
165 let mut rng = rand::thread_rng();
166 candidates.shuffle(&mut rng);
167 }
168
169 Some((origin, candidates))
170}
171
172async fn websocket_probe(
173 client: &HttpClient,
174 url: &str,
175 origin: &str,
176) -> Result<crate::http_client::HttpResponse, CapturedError> {
177 let mut key_bytes = [0u8; 16];
178 rand::thread_rng().fill_bytes(&mut key_bytes);
179 let ws_key = base64::engine::general_purpose::STANDARD.encode(key_bytes);
180
181 let headers = vec![
182 ("Connection".to_string(), "Upgrade".to_string()),
183 ("Upgrade".to_string(), "websocket".to_string()),
184 ("Sec-WebSocket-Version".to_string(), "13".to_string()),
185 ("Sec-WebSocket-Key".to_string(), ws_key),
186 ("Origin".to_string(), origin.to_string()),
187 ];
188
189 client.get_with_headers(url, &headers).await
190}
191
192fn is_upgrade_success(resp: &crate::http_client::HttpResponse) -> bool {
193 if resp.status == 101 {
194 return true;
195 }
196
197 resp.header("sec-websocket-accept").is_some()
198}