1use crate::capture::loop_detection::LoopDetector;
2use crate::intercept::types::{
3 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, WebSocketMessageAction,
4};
5use crate::proxy::http_utils::{
6 HttpsClient, create_error_response, create_initial_flow, mock_to_response, parse_request_meta,
7};
8use chrono::Utc;
9use data_encoding::BASE64;
10use futures_util::{SinkExt, StreamExt};
11use http_body_util::{BodyExt, Full};
12use hyper::body::Bytes;
13use hyper::header::{HeaderName, HeaderValue};
14use hyper::upgrade::Upgraded;
15use hyper::{Request, Response, StatusCode};
16use relay_core_api::flow::{
17 BodyData, Direction, Flow, FlowUpdate, HttpResponse, Layer, WebSocketMessage,
18};
19use relay_core_api::policy::ProxyPolicy;
20use std::convert::Infallible;
21use std::net::SocketAddr;
22use std::sync::Arc;
23use tokio::sync::mpsc::Sender;
24use tokio_tungstenite::WebSocketStream;
25use tokio_tungstenite::tungstenite::protocol::Message;
26use url::Url;
27use uuid::Uuid;
28
29use hyper_util::rt::TokioIo;
30use relay_core_api::flow::ResponseTiming;
31use tokio::sync::watch;
32
33fn validate_ws_strict_handshake<B>(
34 req: &Request<B>,
35 policy: &ProxyPolicy,
36) -> Result<(), Box<Response<HttpBody>>> {
37 if !policy.strict_http_semantics {
38 return Ok(());
39 }
40
41 if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
42 return Err(Box::new(create_error_response(
43 StatusCode::BAD_REQUEST,
44 "Missing Sec-WebSocket-Key header in Strict Mode",
45 )));
46 }
47
48 if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
49 if v != "13" {
50 return Err(Box::new(create_error_response(
51 StatusCode::BAD_REQUEST,
52 "Unsupported WebSocket Version in Strict Mode (Expected 13)",
53 )));
54 }
55 } else {
56 return Err(Box::new(create_error_response(
57 StatusCode::BAD_REQUEST,
58 "Missing Sec-WebSocket-Version header in Strict Mode",
59 )));
60 }
61
62 Ok(())
63}
64
65#[allow(clippy::too_many_arguments)]
66pub async fn handle_websocket_handshake<B>(
67 req: Request<B>,
68 client_addr: SocketAddr,
69 on_flow: Sender<FlowUpdate>,
70 client: Arc<HttpsClient>,
71 interceptor: Arc<dyn Interceptor>,
72 is_mitm: bool,
73 policy_rx: watch::Receiver<ProxyPolicy>,
74 target_addr: Option<SocketAddr>,
75 loop_detector: Arc<LoopDetector>,
76) -> Result<Response<HttpBody>, Infallible>
77where
78 B: hyper::body::Body + Send + 'static,
79 B::Data: Send,
80 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
81{
82 let meta = parse_request_meta(&req, is_mitm);
84 let policy = policy_rx.borrow().clone();
85
86 if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
88 return Ok(*resp);
89 }
90
91 let mut flow = create_initial_flow(meta.clone(), None, client_addr, is_mitm, true);
93
94 match interceptor.on_request_headers(&mut flow).await {
96 InterceptionResult::Drop => {
97 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
98 }
99 InterceptionResult::MockResponse(mock) => {
100 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
101 crate::metrics::inc_flows_dropped();
102 }
103 return Ok(mock_to_response(mock));
104 }
105 InterceptionResult::ModifiedRequest(req) => {
106 if let Layer::WebSocket(ws) = &mut flow.layer {
107 ws.handshake_request = req;
108 }
109 }
110 InterceptionResult::ModifiedResponse(res) => {
111 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
112 crate::metrics::inc_flows_dropped();
113 }
114 return Ok(mock_to_response(res));
115 }
116 _ => {}
117 }
118
119 let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
122
123 match interceptor.on_request(&mut flow, body).await {
124 Ok(RequestAction::Drop) => {
125 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
126 }
127 Ok(RequestAction::MockResponse(res)) => {
128 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
129 crate::metrics::inc_flows_dropped();
130 }
131 let (parts, body) = res.into_parts();
132 return Ok(Response::from_parts(parts, body));
133 }
134 Ok(RequestAction::Continue(_)) => {}
135 Err(e) => {
136 return Ok(create_error_response(
137 StatusCode::INTERNAL_SERVER_ERROR,
138 format!("Interceptor Error: {}", e),
139 ));
140 }
141 }
142
143 if on_flow
144 .try_send(FlowUpdate::Full(Box::new(flow.clone())))
145 .is_err()
146 {
147 crate::metrics::inc_flows_dropped();
148 }
149
150 let (parts, body) = req.into_parts();
152 let req_for_upgrade = Request::from_parts(parts, body);
153
154 let mut target_url_str = meta.url_str.clone();
156
157 if policy.transparent_enabled
158 && let Some(addr) = target_addr
159 {
160 flow.tags.push("transparent".to_string());
161
162 flow.network.server_ip = addr.ip().to_string();
164 flow.network.server_port = addr.port();
165
166 if loop_detector.would_loop(addr) {
168 if let Layer::WebSocket(ws) = &mut flow.layer {
169 ws.handshake_response.status = 508;
170 ws.closed = true;
171 }
172 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
173 crate::metrics::inc_flows_dropped();
174 }
175 return Ok(create_error_response(
176 StatusCode::LOOP_DETECTED,
177 "Loop Detected",
178 ));
179 }
180
181 let mut u = if let Layer::WebSocket(ws) = &flow.layer {
183 ws.handshake_request.url.clone()
184 } else {
185 Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
186 };
187
188 if u.set_ip_host(addr.ip()).is_ok() {
189 u.set_port(Some(addr.port())).ok();
190 if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
192 u.set_scheme("wss").ok();
193 } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
194 u.set_scheme("ws").ok();
195 }
196 target_url_str = u.to_string();
197 }
198 }
199
200 let current_req = if let Layer::WebSocket(ws) = &flow.layer {
202 &ws.handshake_request
203 } else {
204 return Ok(create_error_response(
205 StatusCode::INTERNAL_SERVER_ERROR,
206 "Invalid Flow Layer State",
207 ));
208 };
209
210 let mut forward_req_builder = Request::builder()
211 .method(current_req.method.as_str())
212 .uri(target_url_str.as_str())
213 .version(hyper::Version::HTTP_11);
214
215 for (k, v) in current_req.headers.iter() {
216 if let (Ok(name), Ok(val)) = (
217 HeaderName::from_bytes(k.as_bytes()),
218 HeaderValue::from_str(v),
219 ) {
220 forward_req_builder = forward_req_builder.header(name, val);
221 }
222 }
223
224 let forward_req =
225 match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
226 Ok(req) => req,
227 Err(e) => {
228 return Ok(create_error_response(
229 StatusCode::INTERNAL_SERVER_ERROR,
230 format!("Failed to build forward request: {}", e),
231 ));
232 }
233 };
234
235 match tokio::time::timeout(
236 std::time::Duration::from_secs(30),
237 client.request(forward_req),
238 )
239 .await
240 {
241 Ok(Ok(resp)) => {
242 if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
243 let (parts, body) = resp.into_parts();
244 let resp_for_upgrade = Response::from_parts(parts.clone(), body);
245
246 let on_flow_clone = on_flow.clone();
248 let interceptor_clone = interceptor.clone();
249 let flow_clone = flow.clone(); tokio::task::spawn(async move {
252 let upgrade_timeout = std::time::Duration::from_secs(10);
254
255 let client_upgrade =
256 tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
257 let server_upgrade =
258 tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
259
260 match tokio::try_join!(client_upgrade, server_upgrade) {
261 Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
262 if let Err(e) = handle_websocket_tunnel(
264 upgraded_client,
265 upgraded_server,
266 flow_clone,
267 on_flow_clone,
268 interceptor_clone,
269 )
270 .await
271 {
272 tracing::error!("WebSocket Tunnel Error: {}", e);
273 }
274 }
275 Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
276 Ok((_, Err(e))) => {
277 tracing::error!("Upstream WebSocket Upgrade Error: {}", e)
278 }
279 Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
280 }
281 });
282
283 let mut client_resp_builder = Response::builder()
284 .status(StatusCode::SWITCHING_PROTOCOLS)
285 .version(parts.version);
286
287 for (k, v) in parts.headers.iter() {
288 client_resp_builder = client_resp_builder.header(k, v);
289 }
290
291 let client_resp = match client_resp_builder
292 .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
293 {
294 Ok(r) => r,
295 Err(e) => {
296 tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
297 return Ok(create_error_response(
298 StatusCode::INTERNAL_SERVER_ERROR,
299 "Response build failed",
300 ));
301 }
302 };
303
304 Ok(client_resp)
311 } else {
312 let (parts, body) = resp.into_parts();
314 let body_bytes = match body.collect().await {
315 Ok(c) => c.to_bytes(),
316 Err(_) => Bytes::new(),
317 };
318
319 let http_resp = HttpResponse {
321 status: parts.status.as_u16(),
322 status_text: parts
323 .status
324 .canonical_reason()
325 .unwrap_or("Unknown")
326 .to_string(),
327 version: format!("{:?}", parts.version),
328 headers: parts
329 .headers
330 .iter()
331 .map(|(k, v)| {
332 (
333 k.as_str().to_string(),
334 String::from_utf8_lossy(v.as_bytes()).to_string(),
335 )
336 })
337 .collect(),
338 cookies: vec![], body: Some(BodyData {
340 encoding: "utf-8".to_string(),
341 content: String::from_utf8_lossy(&body_bytes).to_string(),
342 size: body_bytes.len() as u64,
343 }),
344 timing: ResponseTiming {
345 time_to_first_byte: None,
346 time_to_last_byte: None,
347 connect_time_ms: None,
348 ssl_time_ms: None,
349 },
350 };
351
352 if let Layer::WebSocket(ws) = &mut flow.layer {
353 ws.handshake_response = http_resp.clone();
354 ws.closed = true;
355 }
356
357 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
358 crate::metrics::inc_flows_dropped();
359 }
360
361 Ok(Response::from_parts(
362 parts,
363 Full::new(body_bytes).map_err(|e| e.into()).boxed(),
364 ))
365 }
366 }
367 Ok(Err(e)) => Ok(create_error_response(
368 StatusCode::BAD_GATEWAY,
369 format!("Upstream Handshake Failed: {}", e),
370 )),
371 Err(_) => Ok(create_error_response(
372 StatusCode::GATEWAY_TIMEOUT,
373 "Upstream Handshake Timed Out",
374 )),
375 }
376}
377
378async fn handle_websocket_tunnel(
379 client_io: Upgraded,
380 server_io: Upgraded,
381 mut flow: Flow,
382 on_flow: Sender<FlowUpdate>,
383 interceptor: Arc<dyn Interceptor>,
384) -> Result<(), BoxError> {
385 let client_ws = WebSocketStream::from_raw_socket(
386 TokioIo::new(client_io),
387 tokio_tungstenite::tungstenite::protocol::Role::Server,
388 None,
389 )
390 .await;
391 let server_ws = WebSocketStream::from_raw_socket(
392 TokioIo::new(server_io),
393 tokio_tungstenite::tungstenite::protocol::Role::Client,
394 None,
395 )
396 .await;
397
398 let (mut client_tx, mut client_rx) = client_ws.split();
399 let (mut server_tx, mut server_rx) = server_ws.split();
400
401 let idle_timeout_duration = std::time::Duration::from_secs(300); interceptor.on_websocket_start(&mut flow).await;
405
406 loop {
407 let event = tokio::time::timeout(idle_timeout_duration, async {
408 tokio::select! {
409 msg = client_rx.next() => (Direction::ClientToServer, msg),
410 msg = server_rx.next() => (Direction::ServerToClient, msg),
411 }
412 })
413 .await;
414
415 match event {
416 Ok((dir, msg_opt)) => {
417 match msg_opt {
418 Some(Ok(msg)) => {
419 let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer
421 {
422 (&mut server_tx, &mut client_tx, Direction::ClientToServer)
423 } else {
424 (&mut client_tx, &mut server_tx, Direction::ServerToClient)
425 };
426
427 if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
428 match interceptor
429 .on_websocket_message(&mut flow, ws_msg.clone())
430 .await
431 {
432 Ok(WebSocketMessageAction::Drop) => continue,
433 Ok(WebSocketMessageAction::Continue(mod_msg)) => {
434 let t_msg = flow_msg_to_tungstenite(&mod_msg);
435 if let Err(e) = sender.send(t_msg).await {
436 interceptor
437 .on_websocket_error(&mut flow, &e.to_string())
438 .await;
439 return Err(e.into());
440 }
441
442 if on_flow
443 .try_send(FlowUpdate::WebSocketMessage {
444 flow_id: flow.id.to_string(),
445 message: mod_msg,
446 })
447 .is_err()
448 {
449 crate::metrics::inc_flows_dropped();
450 }
451 }
452 Err(e) => {
453 tracing::error!("WebSocket Interception Error: {}", e);
454 sender.send(msg).await?;
455
456 if on_flow
457 .try_send(FlowUpdate::WebSocketMessage {
458 flow_id: flow.id.to_string(),
459 message: ws_msg,
460 })
461 .is_err()
462 {
463 crate::metrics::inc_flows_dropped();
464 }
465 }
466 }
467 } else {
468 if let Err(e) = sender.send(msg).await {
470 interceptor
471 .on_websocket_error(&mut flow, &e.to_string())
472 .await;
473 return Err(e.into());
474 }
475 }
476 }
477 Some(Err(e)) => {
478 interceptor
479 .on_websocket_error(&mut flow, &e.to_string())
480 .await;
481 return Err(e.into());
482 }
483 None => {
484 interceptor
485 .on_websocket_end(&mut flow, 1000, "normal")
486 .await;
487 break;
488 }
489 }
490 }
491 Err(_) => {
492 tracing::warn!("WebSocket Tunnel Idle Timeout");
493 interceptor
494 .on_websocket_error(&mut flow, "WebSocket Idle Timeout")
495 .await;
496 return Err("WebSocket Idle Timeout".into());
497 }
498 }
499 }
500
501 Ok(())
502}
503
504fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
505 let (opcode, content, encoding, size) = match msg {
506 Message::Text(t) => {
507 let len = t.len();
508 ("Text", t.to_string(), "utf-8", len)
509 }
510 Message::Binary(b) => {
511 let len = b.len();
512 ("Binary", BASE64.encode(&b), "base64", len)
513 }
514 Message::Ping(b) => {
515 let len = b.len();
516 ("Ping", BASE64.encode(&b), "base64", len)
517 }
518 Message::Pong(b) => {
519 let len = b.len();
520 ("Pong", BASE64.encode(&b), "base64", len)
521 }
522 Message::Close(_) => ("Close", String::new(), "none", 0),
523 Message::Frame(_) => return None,
524 };
525
526 Some(WebSocketMessage {
527 id: Uuid::new_v4(),
528 timestamp: Utc::now(),
529 direction: dir,
530 content: BodyData {
531 encoding: encoding.to_string(),
532 content,
533 size: size as u64,
534 },
535 opcode: opcode.to_string(),
536 })
537}
538
539fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
540 match msg.opcode.as_str() {
541 "Text" => Message::Text(msg.content.content.clone().into()),
542 "Binary" => {
543 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
544 Message::Binary(Bytes::from(b))
545 } else {
546 Message::Binary(Bytes::new())
547 }
548 }
549 "Ping" => {
550 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
551 Message::Ping(Bytes::from(b))
552 } else {
553 Message::Ping(Bytes::new())
554 }
555 }
556 "Pong" => {
557 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
558 Message::Pong(Bytes::from(b))
559 } else {
560 Message::Pong(Bytes::new())
561 }
562 }
563 "Close" => Message::Close(None),
564 _ => Message::Text(msg.content.content.clone().into()),
565 }
566}
567
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// Policy with strict HTTP semantics switched on or off; everything else defaulted.
    fn policy(strict: bool) -> ProxyPolicy {
        ProxyPolicy {
            strict_http_semantics: strict,
            ..Default::default()
        }
    }

    /// Builds a `GET ws://example.com/socket` request carrying exactly `headers`.
    fn ws_request(headers: &[(HeaderName, &str)]) -> Request<Empty<Bytes>> {
        headers
            .iter()
            .fold(
                Request::builder()
                    .method("GET")
                    .uri("ws://example.com/socket"),
                |builder, (name, value)| builder.header(name, *value),
            )
            .body(Empty::<Bytes>::new())
            .expect("request build")
    }

    /// Asserts that the validator rejected the handshake with `400 Bad Request`.
    fn assert_bad_request(result: Result<(), Box<Response<HttpBody>>>) {
        let resp = result.expect_err("handshake should be rejected");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_VERSION, "13")]);
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_version() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_KEY, "test-key")]);
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "12"),
        ]);
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(validate_ws_strict_handshake(&req, &policy(true)).is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = ws_request(&[]);
        assert!(validate_ws_strict_handshake(&req, &policy(false)).is_ok());
    }
}