1use tokio::sync::mpsc::Sender;
2use hyper::upgrade::Upgraded;
3use hyper::{Request, Response, StatusCode};
4use hyper::header::{HeaderName, HeaderValue};
5use http_body_util::{Full, BodyExt};
6use tokio_tungstenite::WebSocketStream;
7use tokio_tungstenite::tungstenite::protocol::Message;
8use futures_util::{StreamExt, SinkExt};
9use data_encoding::BASE64;
10use relay_core_api::flow::{Flow, FlowUpdate, WebSocketMessage, Direction, BodyData, Layer, HttpResponse};
11use relay_core_api::policy::ProxyPolicy;
12use crate::intercept::types::{Interceptor, InterceptionResult, RequestAction, WebSocketMessageAction, HttpBody, BoxError};
13use crate::proxy::http_utils::{HttpsClient, parse_request_meta, create_initial_flow, mock_to_response, create_error_response};
14use crate::capture::loop_detection::LoopDetector;
15use std::sync::Arc;
16use std::convert::Infallible;
17use std::net::SocketAddr;
18use uuid::Uuid;
19use chrono::Utc;
20use url::Url;
21use hyper::body::Bytes;
22
23use tokio::sync::watch;
24use hyper_util::rt::TokioIo;
25use relay_core_api::flow::ResponseTiming;
26
27fn validate_ws_strict_handshake<B>(
28 req: &Request<B>,
29 policy: &ProxyPolicy,
30) -> Result<(), Box<Response<HttpBody>>> {
31 if !policy.strict_http_semantics {
32 return Ok(());
33 }
34
35 if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
36 return Err(Box::new(create_error_response(
37 StatusCode::BAD_REQUEST,
38 "Missing Sec-WebSocket-Key header in Strict Mode",
39 )));
40 }
41
42 if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
43 if v != "13" {
44 return Err(Box::new(create_error_response(
45 StatusCode::BAD_REQUEST,
46 "Unsupported WebSocket Version in Strict Mode (Expected 13)",
47 )));
48 }
49 } else {
50 return Err(Box::new(create_error_response(
51 StatusCode::BAD_REQUEST,
52 "Missing Sec-WebSocket-Version header in Strict Mode",
53 )));
54 }
55
56 Ok(())
57}
58
59#[allow(clippy::too_many_arguments)]
60pub async fn handle_websocket_handshake<B>(
61 req: Request<B>,
62 client_addr: SocketAddr,
63 on_flow: Sender<FlowUpdate>,
64 client: Arc<HttpsClient>,
65 interceptor: Arc<dyn Interceptor>,
66 is_mitm: bool,
67 policy_rx: watch::Receiver<ProxyPolicy>,
68 target_addr: Option<SocketAddr>,
69 loop_detector: Arc<LoopDetector>,
70) -> Result<Response<HttpBody>, Infallible>
71where
72 B: hyper::body::Body + Send + 'static,
73 B::Data: Send,
74 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
75{
76 let meta = parse_request_meta(&req, is_mitm);
78 let policy = policy_rx.borrow().clone();
79
80 if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
82 return Ok(*resp);
83 }
84
85 let mut flow = create_initial_flow(
87 meta.clone(),
88 None,
89 client_addr,
90 is_mitm,
91 true,
92 );
93
94 match interceptor.on_request_headers(&mut flow).await {
96 InterceptionResult::Drop => {
97 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
98 },
99 InterceptionResult::MockResponse(mock) => {
100 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
101 crate::metrics::inc_flows_dropped();
102 }
103 return Ok(mock_to_response(mock));
104 },
105 InterceptionResult::ModifiedRequest(req) => {
106 if let Layer::WebSocket(ws) = &mut flow.layer {
107 ws.handshake_request = req;
108 }
109 },
110 InterceptionResult::ModifiedResponse(res) => {
111 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
112 crate::metrics::inc_flows_dropped();
113 }
114 return Ok(mock_to_response(res));
115 },
116 _ => {}
117 }
118
119 let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
122
123 match interceptor.on_request(&mut flow, body).await {
124 Ok(RequestAction::Drop) => {
125 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
126 },
127 Ok(RequestAction::MockResponse(res)) => {
128 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
129 crate::metrics::inc_flows_dropped();
130 }
131 let (parts, body) = res.into_parts();
132 return Ok(Response::from_parts(parts, body));
133 },
134 Ok(RequestAction::Continue(_)) => {},
135 Err(e) => {
136 return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Interceptor Error: {}", e)));
137 }
138 }
139
140 if on_flow.try_send(FlowUpdate::Full(Box::new(flow.clone()))).is_err() {
141 crate::metrics::inc_flows_dropped();
142 }
143
144 let (parts, body) = req.into_parts();
146 let req_for_upgrade = Request::from_parts(parts, body);
147
148 let mut target_url_str = meta.url_str.clone();
150
151 if policy.transparent_enabled
152 && let Some(addr) = target_addr {
153 flow.tags.push("transparent".to_string());
154
155 flow.network.server_ip = addr.ip().to_string();
157 flow.network.server_port = addr.port();
158
159 if loop_detector.would_loop(addr) {
161 if let Layer::WebSocket(ws) = &mut flow.layer {
162 ws.handshake_response.status = 508;
163 ws.closed = true;
164 }
165 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
166 crate::metrics::inc_flows_dropped();
167 }
168 return Ok(create_error_response(StatusCode::LOOP_DETECTED, "Loop Detected"));
169 }
170
171 let mut u = if let Layer::WebSocket(ws) = &flow.layer {
173 ws.handshake_request.url.clone()
174 } else {
175 Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
176 };
177
178 if u.set_ip_host(addr.ip()).is_ok() {
179 u.set_port(Some(addr.port())).ok();
180 if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
182 u.set_scheme("wss").ok();
183 } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
184 u.set_scheme("ws").ok();
185 }
186 target_url_str = u.to_string();
187 }
188 }
189
190 let current_req = if let Layer::WebSocket(ws) = &flow.layer {
192 &ws.handshake_request
193 } else {
194 return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Invalid Flow Layer State"));
195 };
196
197 let mut forward_req_builder = Request::builder()
198 .method(current_req.method.as_str())
199 .uri(target_url_str.as_str())
200 .version(hyper::Version::HTTP_11);
201
202 for (k, v) in current_req.headers.iter() {
203 if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
204 forward_req_builder = forward_req_builder.header(name, val);
205 }
206 }
207
208 let forward_req = match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
209 Ok(req) => req,
210 Err(e) => return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to build forward request: {}", e))),
211 };
212
213 match tokio::time::timeout(std::time::Duration::from_secs(30), client.request(forward_req)).await {
214 Ok(Ok(resp)) => {
215 if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
216 let (parts, body) = resp.into_parts();
217 let resp_for_upgrade = Response::from_parts(parts.clone(), body);
218
219 let on_flow_clone = on_flow.clone();
221 let interceptor_clone = interceptor.clone();
222 let flow_clone = flow.clone(); tokio::task::spawn(async move {
225 let upgrade_timeout = std::time::Duration::from_secs(10);
227
228 let client_upgrade = tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
229 let server_upgrade = tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
230
231 match tokio::try_join!(client_upgrade, server_upgrade) {
232 Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
233 if let Err(e) = handle_websocket_tunnel(
235 upgraded_client,
236 upgraded_server,
237 flow_clone,
238 on_flow_clone,
239 interceptor_clone
240 ).await {
241 tracing::error!("WebSocket Tunnel Error: {}", e);
242 }
243 },
244 Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
245 Ok((_, Err(e))) => tracing::error!("Upstream WebSocket Upgrade Error: {}", e),
246 Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
247 }
248 });
249
250 let mut client_resp_builder = Response::builder()
251 .status(StatusCode::SWITCHING_PROTOCOLS)
252 .version(parts.version);
253
254 for (k, v) in parts.headers.iter() {
255 client_resp_builder = client_resp_builder.header(k, v);
256 }
257
258 let client_resp = match client_resp_builder
259 .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
260 {
261 Ok(r) => r,
262 Err(e) => {
263 tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
264 return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Response build failed"));
265 }
266 };
267
268 Ok(client_resp)
275 } else {
276 let (parts, body) = resp.into_parts();
278 let body_bytes = match body.collect().await {
279 Ok(c) => c.to_bytes(),
280 Err(_) => Bytes::new(),
281 };
282
283 let http_resp = HttpResponse {
285 status: parts.status.as_u16(),
286 status_text: parts.status.canonical_reason().unwrap_or("Unknown").to_string(),
287 version: format!("{:?}", parts.version),
288 headers: parts.headers.iter()
289 .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
290 .collect(),
291 cookies: vec![], body: Some(BodyData {
293 encoding: "utf-8".to_string(),
294 content: String::from_utf8_lossy(&body_bytes).to_string(),
295 size: body_bytes.len() as u64,
296 }),
297 timing: ResponseTiming {
298 time_to_first_byte: None,
299 time_to_last_byte: None,
300 connect_time_ms: None,
301 ssl_time_ms: None,
302 },
303 };
304
305 if let Layer::WebSocket(ws) = &mut flow.layer {
306 ws.handshake_response = http_resp.clone();
307 ws.closed = true;
308 }
309
310 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
311 crate::metrics::inc_flows_dropped();
312 }
313
314 Ok(Response::from_parts(parts, Full::new(body_bytes).map_err(|e| e.into()).boxed()))
315 }
316 },
317 Ok(Err(e)) => Ok(create_error_response(StatusCode::BAD_GATEWAY, format!("Upstream Handshake Failed: {}", e))),
318 Err(_) => Ok(create_error_response(StatusCode::GATEWAY_TIMEOUT, "Upstream Handshake Timed Out")),
319 }
320}
321
322async fn handle_websocket_tunnel(
323 client_io: Upgraded,
324 server_io: Upgraded,
325 mut flow: Flow,
326 on_flow: Sender<FlowUpdate>,
327 interceptor: Arc<dyn Interceptor>,
328) -> Result<(), BoxError> {
329 let client_ws = WebSocketStream::from_raw_socket(TokioIo::new(client_io), tokio_tungstenite::tungstenite::protocol::Role::Server, None).await;
330 let server_ws = WebSocketStream::from_raw_socket(TokioIo::new(server_io), tokio_tungstenite::tungstenite::protocol::Role::Client, None).await;
331
332 let (mut client_tx, mut client_rx) = client_ws.split();
333 let (mut server_tx, mut server_rx) = server_ws.split();
334
335 let idle_timeout_duration = std::time::Duration::from_secs(300); loop {
339 let event = tokio::time::timeout(idle_timeout_duration, async {
340 tokio::select! {
341 msg = client_rx.next() => (Direction::ClientToServer, msg),
342 msg = server_rx.next() => (Direction::ServerToClient, msg),
343 }
344 }).await;
345
346 match event {
347 Ok((dir, msg_opt)) => {
348 match msg_opt {
349 Some(Ok(msg)) => {
350 let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer {
352 (&mut server_tx, &mut client_tx, Direction::ClientToServer)
353 } else {
354 (&mut client_tx, &mut server_tx, Direction::ServerToClient)
355 };
356
357 if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
358 match interceptor.on_websocket_message(&mut flow, ws_msg.clone()).await {
359 Ok(WebSocketMessageAction::Drop) => continue,
360 Ok(WebSocketMessageAction::Continue(mod_msg)) => {
361 let t_msg = flow_msg_to_tungstenite(&mod_msg);
362 sender.send(t_msg).await?;
363
364 if on_flow.try_send(FlowUpdate::WebSocketMessage {
365 flow_id: flow.id.to_string(),
366 message: mod_msg,
367 }).is_err() {
368 crate::metrics::inc_flows_dropped();
369 }
370 },
371 Err(e) => {
372 tracing::error!("WebSocket Interception Error: {}", e);
373 sender.send(msg).await?;
374
375 if on_flow.try_send(FlowUpdate::WebSocketMessage {
376 flow_id: flow.id.to_string(),
377 message: ws_msg,
378 }).is_err() {
379 crate::metrics::inc_flows_dropped();
380 }
381 }
382 }
383 } else {
384 sender.send(msg).await?;
386 }
387 },
388 Some(Err(e)) => return Err(e.into()),
389 None => break, }
391 },
392 Err(_) => {
393 tracing::warn!("WebSocket Tunnel Idle Timeout");
394 return Err("WebSocket Idle Timeout".into());
396 }
397 }
398 }
399
400 Ok(())
401}
402
403fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
404 let (opcode, content, encoding, size) = match msg {
405 Message::Text(t) => {
406 let len = t.len();
407 ("Text", t.to_string(), "utf-8", len)
408 },
409 Message::Binary(b) => {
410 let len = b.len();
411 ("Binary", BASE64.encode(&b), "base64", len)
412 },
413 Message::Ping(b) => {
414 let len = b.len();
415 ("Ping", BASE64.encode(&b), "base64", len)
416 },
417 Message::Pong(b) => {
418 let len = b.len();
419 ("Pong", BASE64.encode(&b), "base64", len)
420 },
421 Message::Close(_) => ("Close", String::new(), "none", 0),
422 Message::Frame(_) => return None,
423 };
424
425 Some(WebSocketMessage {
426 id: Uuid::new_v4(),
427 timestamp: Utc::now(),
428 direction: dir,
429 content: BodyData {
430 encoding: encoding.to_string(),
431 content,
432 size: size as u64,
433 },
434 opcode: opcode.to_string(),
435 })
436}
437
438fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
439 match msg.opcode.as_str() {
440 "Text" => Message::Text(msg.content.content.clone().into()),
441 "Binary" => {
442 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
443 Message::Binary(Bytes::from(b))
444 } else {
445 Message::Binary(Bytes::new())
446 }
447 },
448 "Ping" => {
449 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
450 Message::Ping(Bytes::from(b))
451 } else {
452 Message::Ping(Bytes::new())
453 }
454 },
455 "Pong" => {
456 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
457 Message::Pong(Bytes::from(b))
458 } else {
459 Message::Pong(Bytes::new())
460 }
461 },
462 "Close" => Message::Close(None),
463 _ => Message::Text(msg.content.content.clone().into()),
464 }
465}
466
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// Proxy policy with strict HTTP semantics switched on or off.
    fn policy(strict: bool) -> ProxyPolicy {
        ProxyPolicy {
            strict_http_semantics: strict,
            ..Default::default()
        }
    }

    /// A `GET ws://example.com/socket` request carrying the given headers.
    fn ws_request(headers: &[(HeaderName, &'static str)]) -> Request<Empty<Bytes>> {
        let mut builder = Request::builder()
            .method("GET")
            .uri("ws://example.com/socket");
        for (name, value) in headers {
            builder = builder.header(name, *value);
        }
        builder.body(Empty::<Bytes>::new()).expect("request build")
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_VERSION, "13")]);
        let resp = validate_ws_strict_handshake(&req, &policy(true))
            .expect_err("missing Sec-WebSocket-Key must be rejected");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "12"),
        ]);
        let resp = validate_ws_strict_handshake(&req, &policy(true))
            .expect_err("version other than 13 must be rejected");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_version() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_KEY, "test-key")]);
        let resp = validate_ws_strict_handshake(&req, &policy(true))
            .expect_err("missing Sec-WebSocket-Version must be rejected");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(validate_ws_strict_handshake(&req, &policy(true)).is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = ws_request(&[]);
        assert!(validate_ws_strict_handshake(&req, &policy(false)).is_ok());
    }
}