1use crate::capture::loop_detection::LoopDetector;
2use crate::intercept::types::{
3 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, WebSocketMessageAction,
4};
5use crate::proxy::http_utils::{
6 create_error_response, create_initial_flow, mock_to_response, parse_request_meta,
7};
8use crate::proxy::outbound::OutboundConnector;
9use chrono::Utc;
10use data_encoding::BASE64;
11use futures_util::{SinkExt, StreamExt};
12use http_body_util::{BodyExt, Full};
13use hyper::body::Bytes;
14use hyper::header::{HeaderName, HeaderValue};
15use hyper::upgrade::Upgraded;
16use hyper::{Request, Response, StatusCode};
17use relay_core_api::flow::{
18 BodyData, Direction, Flow, FlowUpdate, HttpResponse, Layer, WebSocketMessage,
19};
20use relay_core_api::policy::ProxyPolicy;
21use std::convert::Infallible;
22use std::net::SocketAddr;
23use std::sync::Arc;
24use tokio::sync::mpsc::Sender;
25use tokio_tungstenite::WebSocketStream;
26use tokio_tungstenite::tungstenite::protocol::Message;
27use url::Url;
28use uuid::Uuid;
29
30use hyper_util::rt::TokioIo;
31use relay_core_api::flow::ResponseTiming;
32use tokio::sync::watch;
33
34fn validate_ws_strict_handshake<B>(
35 req: &Request<B>,
36 policy: &ProxyPolicy,
37) -> Result<(), Box<Response<HttpBody>>> {
38 if !policy.strict_http_semantics {
39 return Ok(());
40 }
41
42 if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
43 return Err(Box::new(create_error_response(
44 StatusCode::BAD_REQUEST,
45 "Missing Sec-WebSocket-Key header in Strict Mode",
46 )));
47 }
48
49 if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
50 if v != "13" {
51 return Err(Box::new(create_error_response(
52 StatusCode::BAD_REQUEST,
53 "Unsupported WebSocket Version in Strict Mode (Expected 13)",
54 )));
55 }
56 } else {
57 return Err(Box::new(create_error_response(
58 StatusCode::BAD_REQUEST,
59 "Missing Sec-WebSocket-Version header in Strict Mode",
60 )));
61 }
62
63 Ok(())
64}
65
66#[allow(clippy::too_many_arguments)]
67pub async fn handle_websocket_handshake<B>(
68 req: Request<B>,
69 client_addr: SocketAddr,
70 on_flow: Sender<FlowUpdate>,
71 connector: Arc<dyn OutboundConnector>,
72 interceptor: Arc<dyn Interceptor>,
73 is_mitm: bool,
74 policy_rx: watch::Receiver<ProxyPolicy>,
75 target_addr: Option<SocketAddr>,
76 loop_detector: Arc<LoopDetector>,
77) -> Result<Response<HttpBody>, Infallible>
78where
79 B: hyper::body::Body + Send + 'static,
80 B::Data: Send,
81 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
82{
83 let meta = parse_request_meta(&req, is_mitm);
85 let policy = policy_rx.borrow().clone();
86
87 if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
89 return Ok(*resp);
90 }
91
92 let mut flow = create_initial_flow(meta.clone(), None, client_addr, is_mitm, true);
94
95 match interceptor.on_request_headers(&mut flow).await {
97 InterceptionResult::Drop => {
98 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
99 }
100 InterceptionResult::MockResponse(mock) => {
101 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
102 crate::metrics::inc_flows_dropped();
103 }
104 return Ok(mock_to_response(mock));
105 }
106 InterceptionResult::ModifiedRequest(req) => {
107 if let Layer::WebSocket(ws) = &mut flow.layer {
108 ws.handshake_request = req;
109 }
110 }
111 InterceptionResult::ModifiedResponse(res) => {
112 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
113 crate::metrics::inc_flows_dropped();
114 }
115 return Ok(mock_to_response(res));
116 }
117 _ => {}
118 }
119
120 let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
123
124 match interceptor.on_request(&mut flow, body).await {
125 Ok(RequestAction::Drop) => {
126 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
127 }
128 Ok(RequestAction::MockResponse(res)) => {
129 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
130 crate::metrics::inc_flows_dropped();
131 }
132 let (parts, body) = res.into_parts();
133 return Ok(Response::from_parts(parts, body));
134 }
135 Ok(RequestAction::Continue(_)) => {}
136 Err(e) => {
137 return Ok(create_error_response(
138 StatusCode::INTERNAL_SERVER_ERROR,
139 format!("Interceptor Error: {}", e),
140 ));
141 }
142 }
143
144 if on_flow
145 .try_send(FlowUpdate::Full(Box::new(flow.clone())))
146 .is_err()
147 {
148 crate::metrics::inc_flows_dropped();
149 }
150
151 let (parts, body) = req.into_parts();
153 let req_for_upgrade = Request::from_parts(parts, body);
154
155 let mut target_url_str = meta.url_str.clone();
157
158 if policy.transparent_enabled
159 && let Some(addr) = target_addr
160 {
161 flow.tags.push("transparent".to_string());
162
163 flow.network.server_ip = addr.ip().to_string();
165 flow.network.server_port = addr.port();
166
167 if loop_detector.would_loop(addr) {
169 if let Layer::WebSocket(ws) = &mut flow.layer {
170 ws.handshake_response.status = 508;
171 ws.closed = true;
172 }
173 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
174 crate::metrics::inc_flows_dropped();
175 }
176 return Ok(create_error_response(
177 StatusCode::LOOP_DETECTED,
178 "Loop Detected",
179 ));
180 }
181
182 let mut u = if let Layer::WebSocket(ws) = &flow.layer {
184 ws.handshake_request.url.clone()
185 } else {
186 Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
187 };
188
189 if u.set_ip_host(addr.ip()).is_ok() {
190 u.set_port(Some(addr.port())).ok();
191 if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
193 u.set_scheme("wss").ok();
194 } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
195 u.set_scheme("ws").ok();
196 }
197 target_url_str = u.to_string();
198 }
199 }
200
201 let current_req = if let Layer::WebSocket(ws) = &flow.layer {
203 &ws.handshake_request
204 } else {
205 return Ok(create_error_response(
206 StatusCode::INTERNAL_SERVER_ERROR,
207 "Invalid Flow Layer State",
208 ));
209 };
210
211 let mut forward_req_builder = Request::builder()
212 .method(current_req.method.as_str())
213 .uri(target_url_str.as_str())
214 .version(hyper::Version::HTTP_11);
215
216 for (k, v) in current_req.headers.iter() {
217 if let (Ok(name), Ok(val)) = (
218 HeaderName::from_bytes(k.as_bytes()),
219 HeaderValue::from_str(v),
220 ) {
221 forward_req_builder = forward_req_builder.header(name, val);
222 }
223 }
224
225 let forward_req =
226 match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
227 Ok(req) => req,
228 Err(e) => {
229 return Ok(create_error_response(
230 StatusCode::INTERNAL_SERVER_ERROR,
231 format!("Failed to build forward request: {}", e),
232 ));
233 }
234 };
235
236 let target_host = forward_req.uri().host().unwrap_or("unknown").to_string();
237 let port = forward_req.uri().port_u16().unwrap_or(
238 if forward_req.uri().scheme_str() == Some("https") {
239 443
240 } else {
241 80
242 },
243 );
244
245 match tokio::time::timeout(
246 std::time::Duration::from_secs(30),
247 connector.send_request(forward_req, &target_host, port, &mut flow),
248 )
249 .await
250 {
251 Ok(Ok(resp)) => {
252 if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
253 let (parts, body) = resp.into_parts();
254 let resp_for_upgrade = Response::from_parts(parts.clone(), body);
255
256 let on_flow_clone = on_flow.clone();
258 let interceptor_clone = interceptor.clone();
259 let flow_clone = flow.clone(); tokio::task::spawn(async move {
262 let upgrade_timeout = std::time::Duration::from_secs(10);
264
265 let client_upgrade =
266 tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
267 let server_upgrade =
268 tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
269
270 match tokio::try_join!(client_upgrade, server_upgrade) {
271 Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
272 if let Err(e) = handle_websocket_tunnel(
274 upgraded_client,
275 upgraded_server,
276 flow_clone,
277 on_flow_clone,
278 interceptor_clone,
279 )
280 .await
281 {
282 tracing::error!("WebSocket Tunnel Error: {}", e);
283 }
284 }
285 Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
286 Ok((_, Err(e))) => {
287 tracing::error!("Upstream WebSocket Upgrade Error: {}", e)
288 }
289 Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
290 }
291 });
292
293 let mut client_resp_builder = Response::builder()
294 .status(StatusCode::SWITCHING_PROTOCOLS)
295 .version(parts.version);
296
297 for (k, v) in parts.headers.iter() {
298 client_resp_builder = client_resp_builder.header(k, v);
299 }
300
301 let client_resp = match client_resp_builder
302 .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
303 {
304 Ok(r) => r,
305 Err(e) => {
306 tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
307 return Ok(create_error_response(
308 StatusCode::INTERNAL_SERVER_ERROR,
309 "Response build failed",
310 ));
311 }
312 };
313
314 Ok(client_resp)
321 } else {
322 let (parts, body) = resp.into_parts();
324 let body_bytes = match body.collect().await {
325 Ok(c) => c.to_bytes(),
326 Err(_) => Bytes::new(),
327 };
328
329 let http_resp = HttpResponse {
331 status: parts.status.as_u16(),
332 status_text: parts
333 .status
334 .canonical_reason()
335 .unwrap_or("Unknown")
336 .to_string(),
337 version: format!("{:?}", parts.version),
338 headers: parts
339 .headers
340 .iter()
341 .map(|(k, v)| {
342 (
343 k.as_str().to_string(),
344 String::from_utf8_lossy(v.as_bytes()).to_string(),
345 )
346 })
347 .collect(),
348 cookies: vec![], body: Some(BodyData {
350 encoding: "utf-8".to_string(),
351 content: String::from_utf8_lossy(&body_bytes).to_string(),
352 size: body_bytes.len() as u64,
353 }),
354 timing: ResponseTiming {
355 time_to_first_byte: None,
356 time_to_last_byte: None,
357 connect_time_ms: None,
358 ssl_time_ms: None,
359 },
360 };
361
362 if let Layer::WebSocket(ws) = &mut flow.layer {
363 ws.handshake_response = http_resp.clone();
364 ws.closed = true;
365 }
366
367 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
368 crate::metrics::inc_flows_dropped();
369 }
370
371 Ok(Response::from_parts(
372 parts,
373 Full::new(body_bytes).map_err(|e| e.into()).boxed(),
374 ))
375 }
376 }
377 Ok(Err(e)) => Ok(create_error_response(
378 StatusCode::BAD_GATEWAY,
379 format!("Upstream Handshake Failed: {}", e),
380 )),
381 Err(_) => Ok(create_error_response(
382 StatusCode::GATEWAY_TIMEOUT,
383 "Upstream Handshake Timed Out",
384 )),
385 }
386}
387
388async fn handle_websocket_tunnel(
389 client_io: Upgraded,
390 server_io: Upgraded,
391 mut flow: Flow,
392 on_flow: Sender<FlowUpdate>,
393 interceptor: Arc<dyn Interceptor>,
394) -> Result<(), BoxError> {
395 let client_ws = WebSocketStream::from_raw_socket(
396 TokioIo::new(client_io),
397 tokio_tungstenite::tungstenite::protocol::Role::Server,
398 None,
399 )
400 .await;
401 let server_ws = WebSocketStream::from_raw_socket(
402 TokioIo::new(server_io),
403 tokio_tungstenite::tungstenite::protocol::Role::Client,
404 None,
405 )
406 .await;
407
408 let (mut client_tx, mut client_rx) = client_ws.split();
409 let (mut server_tx, mut server_rx) = server_ws.split();
410
411 let idle_timeout_duration = std::time::Duration::from_secs(300); interceptor.on_websocket_start(&mut flow).await;
415
416 loop {
417 let event = tokio::time::timeout(idle_timeout_duration, async {
418 tokio::select! {
419 msg = client_rx.next() => (Direction::ClientToServer, msg),
420 msg = server_rx.next() => (Direction::ServerToClient, msg),
421 }
422 })
423 .await;
424
425 match event {
426 Ok((dir, msg_opt)) => {
427 match msg_opt {
428 Some(Ok(msg)) => {
429 let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer
431 {
432 (&mut server_tx, &mut client_tx, Direction::ClientToServer)
433 } else {
434 (&mut client_tx, &mut server_tx, Direction::ServerToClient)
435 };
436
437 if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
438 match interceptor
439 .on_websocket_message(&mut flow, ws_msg.clone())
440 .await
441 {
442 Ok(WebSocketMessageAction::Drop) => continue,
443 Ok(WebSocketMessageAction::Continue(mod_msg)) => {
444 let t_msg = flow_msg_to_tungstenite(&mod_msg);
445 if let Err(e) = sender.send(t_msg).await {
446 interceptor
447 .on_websocket_error(&mut flow, &e.to_string())
448 .await;
449 return Err(e.into());
450 }
451
452 if on_flow
453 .try_send(FlowUpdate::WebSocketMessage {
454 flow_id: flow.id.to_string(),
455 message: mod_msg,
456 })
457 .is_err()
458 {
459 crate::metrics::inc_flows_dropped();
460 }
461 }
462 Err(e) => {
463 tracing::error!("WebSocket Interception Error: {}", e);
464 sender.send(msg).await?;
465
466 if on_flow
467 .try_send(FlowUpdate::WebSocketMessage {
468 flow_id: flow.id.to_string(),
469 message: ws_msg,
470 })
471 .is_err()
472 {
473 crate::metrics::inc_flows_dropped();
474 }
475 }
476 }
477 } else {
478 if let Err(e) = sender.send(msg).await {
480 interceptor
481 .on_websocket_error(&mut flow, &e.to_string())
482 .await;
483 return Err(e.into());
484 }
485 }
486 }
487 Some(Err(e)) => {
488 interceptor
489 .on_websocket_error(&mut flow, &e.to_string())
490 .await;
491 return Err(e.into());
492 }
493 None => {
494 interceptor
495 .on_websocket_end(&mut flow, 1000, "normal")
496 .await;
497 break;
498 }
499 }
500 }
501 Err(_) => {
502 tracing::warn!("WebSocket Tunnel Idle Timeout");
503 interceptor
504 .on_websocket_error(&mut flow, "WebSocket Idle Timeout")
505 .await;
506 return Err("WebSocket Idle Timeout".into());
507 }
508 }
509 }
510
511 Ok(())
512}
513
514fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
515 let (opcode, content, encoding, size) = match msg {
516 Message::Text(t) => {
517 let len = t.len();
518 ("Text", t.to_string(), "utf-8", len)
519 }
520 Message::Binary(b) => {
521 let len = b.len();
522 ("Binary", BASE64.encode(&b), "base64", len)
523 }
524 Message::Ping(b) => {
525 let len = b.len();
526 ("Ping", BASE64.encode(&b), "base64", len)
527 }
528 Message::Pong(b) => {
529 let len = b.len();
530 ("Pong", BASE64.encode(&b), "base64", len)
531 }
532 Message::Close(_) => ("Close", String::new(), "none", 0),
533 Message::Frame(_) => return None,
534 };
535
536 Some(WebSocketMessage {
537 id: Uuid::new_v4(),
538 timestamp: Utc::now(),
539 direction: dir,
540 content: BodyData {
541 encoding: encoding.to_string(),
542 content,
543 size: size as u64,
544 },
545 opcode: opcode.to_string(),
546 })
547}
548
549fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
550 match msg.opcode.as_str() {
551 "Text" => Message::Text(msg.content.content.clone().into()),
552 "Binary" => {
553 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
554 Message::Binary(Bytes::from(b))
555 } else {
556 Message::Binary(Bytes::new())
557 }
558 }
559 "Ping" => {
560 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
561 Message::Ping(Bytes::from(b))
562 } else {
563 Message::Ping(Bytes::new())
564 }
565 }
566 "Pong" => {
567 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
568 Message::Pong(Bytes::from(b))
569 } else {
570 Message::Pong(Bytes::new())
571 }
572 }
573 "Close" => Message::Close(None),
574 _ => Message::Text(msg.content.content.clone().into()),
575 }
576}
577
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// Builds a policy with strict HTTP semantics switched on or off.
    fn policy(strict_http_semantics: bool) -> ProxyPolicy {
        ProxyPolicy {
            strict_http_semantics,
            ..Default::default()
        }
    }

    /// Builds a `GET ws://example.com/socket` request carrying the given headers.
    fn ws_request(headers: &[(HeaderName, &str)]) -> Request<Empty<Bytes>> {
        let mut builder = Request::builder()
            .method("GET")
            .uri("ws://example.com/socket");
        for (name, value) in headers {
            builder = builder.header(name, *value);
        }
        builder.body(Empty::<Bytes>::new()).expect("request build")
    }

    /// Asserts that the validator rejected the request with `400 Bad Request`.
    fn assert_bad_request(result: Result<(), Box<Response<HttpBody>>>) {
        let resp = result.expect_err("handshake should be rejected");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_VERSION, "13")]);
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_version() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_KEY, "test-key")]);
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "12"),
        ]);
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(validate_ws_strict_handshake(&req, &policy(true)).is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = ws_request(&[]);
        assert!(validate_ws_strict_handshake(&req, &policy(false)).is_ok());
    }
}