1use crate::capture::loop_detection::LoopDetector;
2use crate::intercept::types::{
3 BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, WebSocketMessageAction,
4};
5use crate::proxy::http_utils::{
6 HttpsClient, create_error_response, create_initial_flow, mock_to_response, parse_request_meta,
7};
8use chrono::Utc;
9use data_encoding::BASE64;
10use futures_util::{SinkExt, StreamExt};
11use http_body_util::{BodyExt, Full};
12use hyper::body::Bytes;
13use hyper::header::{HeaderName, HeaderValue};
14use hyper::upgrade::Upgraded;
15use hyper::{Request, Response, StatusCode};
16use relay_core_api::flow::{
17 BodyData, Direction, Flow, FlowUpdate, HttpResponse, Layer, WebSocketMessage,
18};
19use relay_core_api::policy::ProxyPolicy;
20use std::convert::Infallible;
21use std::net::SocketAddr;
22use std::sync::Arc;
23use tokio::sync::mpsc::Sender;
24use tokio_tungstenite::WebSocketStream;
25use tokio_tungstenite::tungstenite::protocol::Message;
26use url::Url;
27use uuid::Uuid;
28
29use hyper_util::rt::TokioIo;
30use relay_core_api::flow::ResponseTiming;
31use tokio::sync::watch;
32
33fn validate_ws_strict_handshake<B>(
34 req: &Request<B>,
35 policy: &ProxyPolicy,
36) -> Result<(), Box<Response<HttpBody>>> {
37 if !policy.strict_http_semantics {
38 return Ok(());
39 }
40
41 if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
42 return Err(Box::new(create_error_response(
43 StatusCode::BAD_REQUEST,
44 "Missing Sec-WebSocket-Key header in Strict Mode",
45 )));
46 }
47
48 if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
49 if v != "13" {
50 return Err(Box::new(create_error_response(
51 StatusCode::BAD_REQUEST,
52 "Unsupported WebSocket Version in Strict Mode (Expected 13)",
53 )));
54 }
55 } else {
56 return Err(Box::new(create_error_response(
57 StatusCode::BAD_REQUEST,
58 "Missing Sec-WebSocket-Version header in Strict Mode",
59 )));
60 }
61
62 Ok(())
63}
64
65#[allow(clippy::too_many_arguments)]
66pub async fn handle_websocket_handshake<B>(
67 req: Request<B>,
68 client_addr: SocketAddr,
69 on_flow: Sender<FlowUpdate>,
70 client: Arc<HttpsClient>,
71 interceptor: Arc<dyn Interceptor>,
72 is_mitm: bool,
73 policy_rx: watch::Receiver<ProxyPolicy>,
74 target_addr: Option<SocketAddr>,
75 loop_detector: Arc<LoopDetector>,
76) -> Result<Response<HttpBody>, Infallible>
77where
78 B: hyper::body::Body + Send + 'static,
79 B::Data: Send,
80 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
81{
82 let meta = parse_request_meta(&req, is_mitm);
84 let policy = policy_rx.borrow().clone();
85
86 if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
88 return Ok(*resp);
89 }
90
91 let mut flow = create_initial_flow(meta.clone(), None, client_addr, is_mitm, true);
93
94 match interceptor.on_request_headers(&mut flow).await {
96 InterceptionResult::Drop => {
97 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
98 }
99 InterceptionResult::MockResponse(mock) => {
100 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
101 crate::metrics::inc_flows_dropped();
102 }
103 return Ok(mock_to_response(mock));
104 }
105 InterceptionResult::ModifiedRequest(req) => {
106 if let Layer::WebSocket(ws) = &mut flow.layer {
107 ws.handshake_request = req;
108 }
109 }
110 InterceptionResult::ModifiedResponse(res) => {
111 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
112 crate::metrics::inc_flows_dropped();
113 }
114 return Ok(mock_to_response(res));
115 }
116 _ => {}
117 }
118
119 let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
122
123 match interceptor.on_request(&mut flow, body).await {
124 Ok(RequestAction::Drop) => {
125 return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
126 }
127 Ok(RequestAction::MockResponse(res)) => {
128 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
129 crate::metrics::inc_flows_dropped();
130 }
131 let (parts, body) = res.into_parts();
132 return Ok(Response::from_parts(parts, body));
133 }
134 Ok(RequestAction::Continue(_)) => {}
135 Err(e) => {
136 return Ok(create_error_response(
137 StatusCode::INTERNAL_SERVER_ERROR,
138 format!("Interceptor Error: {}", e),
139 ));
140 }
141 }
142
143 if on_flow
144 .try_send(FlowUpdate::Full(Box::new(flow.clone())))
145 .is_err()
146 {
147 crate::metrics::inc_flows_dropped();
148 }
149
150 let (parts, body) = req.into_parts();
152 let req_for_upgrade = Request::from_parts(parts, body);
153
154 let mut target_url_str = meta.url_str.clone();
156
157 if policy.transparent_enabled
158 && let Some(addr) = target_addr
159 {
160 flow.tags.push("transparent".to_string());
161
162 flow.network.server_ip = addr.ip().to_string();
164 flow.network.server_port = addr.port();
165
166 if loop_detector.would_loop(addr) {
168 if let Layer::WebSocket(ws) = &mut flow.layer {
169 ws.handshake_response.status = 508;
170 ws.closed = true;
171 }
172 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
173 crate::metrics::inc_flows_dropped();
174 }
175 return Ok(create_error_response(
176 StatusCode::LOOP_DETECTED,
177 "Loop Detected",
178 ));
179 }
180
181 let mut u = if let Layer::WebSocket(ws) = &flow.layer {
183 ws.handshake_request.url.clone()
184 } else {
185 Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
186 };
187
188 if u.set_ip_host(addr.ip()).is_ok() {
189 u.set_port(Some(addr.port())).ok();
190 if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
192 u.set_scheme("wss").ok();
193 } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
194 u.set_scheme("ws").ok();
195 }
196 target_url_str = u.to_string();
197 }
198 }
199
200 let current_req = if let Layer::WebSocket(ws) = &flow.layer {
202 &ws.handshake_request
203 } else {
204 return Ok(create_error_response(
205 StatusCode::INTERNAL_SERVER_ERROR,
206 "Invalid Flow Layer State",
207 ));
208 };
209
210 let mut forward_req_builder = Request::builder()
211 .method(current_req.method.as_str())
212 .uri(target_url_str.as_str())
213 .version(hyper::Version::HTTP_11);
214
215 for (k, v) in current_req.headers.iter() {
216 if let (Ok(name), Ok(val)) = (
217 HeaderName::from_bytes(k.as_bytes()),
218 HeaderValue::from_str(v),
219 ) {
220 forward_req_builder = forward_req_builder.header(name, val);
221 }
222 }
223
224 let forward_req =
225 match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
226 Ok(req) => req,
227 Err(e) => {
228 return Ok(create_error_response(
229 StatusCode::INTERNAL_SERVER_ERROR,
230 format!("Failed to build forward request: {}", e),
231 ));
232 }
233 };
234
235 match tokio::time::timeout(
236 std::time::Duration::from_secs(30),
237 client.request(forward_req),
238 )
239 .await
240 {
241 Ok(Ok(resp)) => {
242 if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
243 let (parts, body) = resp.into_parts();
244 let resp_for_upgrade = Response::from_parts(parts.clone(), body);
245
246 let on_flow_clone = on_flow.clone();
248 let interceptor_clone = interceptor.clone();
249 let flow_clone = flow.clone(); tokio::task::spawn(async move {
252 let upgrade_timeout = std::time::Duration::from_secs(10);
254
255 let client_upgrade =
256 tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
257 let server_upgrade =
258 tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
259
260 match tokio::try_join!(client_upgrade, server_upgrade) {
261 Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
262 if let Err(e) = handle_websocket_tunnel(
264 upgraded_client,
265 upgraded_server,
266 flow_clone,
267 on_flow_clone,
268 interceptor_clone,
269 )
270 .await
271 {
272 tracing::error!("WebSocket Tunnel Error: {}", e);
273 }
274 }
275 Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
276 Ok((_, Err(e))) => {
277 tracing::error!("Upstream WebSocket Upgrade Error: {}", e)
278 }
279 Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
280 }
281 });
282
283 let mut client_resp_builder = Response::builder()
284 .status(StatusCode::SWITCHING_PROTOCOLS)
285 .version(parts.version);
286
287 for (k, v) in parts.headers.iter() {
288 client_resp_builder = client_resp_builder.header(k, v);
289 }
290
291 let client_resp = match client_resp_builder
292 .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
293 {
294 Ok(r) => r,
295 Err(e) => {
296 tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
297 return Ok(create_error_response(
298 StatusCode::INTERNAL_SERVER_ERROR,
299 "Response build failed",
300 ));
301 }
302 };
303
304 Ok(client_resp)
311 } else {
312 let (parts, body) = resp.into_parts();
314 let body_bytes = match body.collect().await {
315 Ok(c) => c.to_bytes(),
316 Err(_) => Bytes::new(),
317 };
318
319 let http_resp = HttpResponse {
321 status: parts.status.as_u16(),
322 status_text: parts
323 .status
324 .canonical_reason()
325 .unwrap_or("Unknown")
326 .to_string(),
327 version: format!("{:?}", parts.version),
328 headers: parts
329 .headers
330 .iter()
331 .map(|(k, v)| {
332 (k.as_str().to_string(), v.to_str().unwrap_or("").to_string())
333 })
334 .collect(),
335 cookies: vec![], body: Some(BodyData {
337 encoding: "utf-8".to_string(),
338 content: String::from_utf8_lossy(&body_bytes).to_string(),
339 size: body_bytes.len() as u64,
340 }),
341 timing: ResponseTiming {
342 time_to_first_byte: None,
343 time_to_last_byte: None,
344 connect_time_ms: None,
345 ssl_time_ms: None,
346 },
347 };
348
349 if let Layer::WebSocket(ws) = &mut flow.layer {
350 ws.handshake_response = http_resp.clone();
351 ws.closed = true;
352 }
353
354 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
355 crate::metrics::inc_flows_dropped();
356 }
357
358 Ok(Response::from_parts(
359 parts,
360 Full::new(body_bytes).map_err(|e| e.into()).boxed(),
361 ))
362 }
363 }
364 Ok(Err(e)) => Ok(create_error_response(
365 StatusCode::BAD_GATEWAY,
366 format!("Upstream Handshake Failed: {}", e),
367 )),
368 Err(_) => Ok(create_error_response(
369 StatusCode::GATEWAY_TIMEOUT,
370 "Upstream Handshake Timed Out",
371 )),
372 }
373}
374
375async fn handle_websocket_tunnel(
376 client_io: Upgraded,
377 server_io: Upgraded,
378 mut flow: Flow,
379 on_flow: Sender<FlowUpdate>,
380 interceptor: Arc<dyn Interceptor>,
381) -> Result<(), BoxError> {
382 let client_ws = WebSocketStream::from_raw_socket(
383 TokioIo::new(client_io),
384 tokio_tungstenite::tungstenite::protocol::Role::Server,
385 None,
386 )
387 .await;
388 let server_ws = WebSocketStream::from_raw_socket(
389 TokioIo::new(server_io),
390 tokio_tungstenite::tungstenite::protocol::Role::Client,
391 None,
392 )
393 .await;
394
395 let (mut client_tx, mut client_rx) = client_ws.split();
396 let (mut server_tx, mut server_rx) = server_ws.split();
397
398 let idle_timeout_duration = std::time::Duration::from_secs(300); loop {
402 let event = tokio::time::timeout(idle_timeout_duration, async {
403 tokio::select! {
404 msg = client_rx.next() => (Direction::ClientToServer, msg),
405 msg = server_rx.next() => (Direction::ServerToClient, msg),
406 }
407 })
408 .await;
409
410 match event {
411 Ok((dir, msg_opt)) => {
412 match msg_opt {
413 Some(Ok(msg)) => {
414 let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer
416 {
417 (&mut server_tx, &mut client_tx, Direction::ClientToServer)
418 } else {
419 (&mut client_tx, &mut server_tx, Direction::ServerToClient)
420 };
421
422 if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
423 match interceptor
424 .on_websocket_message(&mut flow, ws_msg.clone())
425 .await
426 {
427 Ok(WebSocketMessageAction::Drop) => continue,
428 Ok(WebSocketMessageAction::Continue(mod_msg)) => {
429 let t_msg = flow_msg_to_tungstenite(&mod_msg);
430 sender.send(t_msg).await?;
431
432 if on_flow
433 .try_send(FlowUpdate::WebSocketMessage {
434 flow_id: flow.id.to_string(),
435 message: mod_msg,
436 })
437 .is_err()
438 {
439 crate::metrics::inc_flows_dropped();
440 }
441 }
442 Err(e) => {
443 tracing::error!("WebSocket Interception Error: {}", e);
444 sender.send(msg).await?;
445
446 if on_flow
447 .try_send(FlowUpdate::WebSocketMessage {
448 flow_id: flow.id.to_string(),
449 message: ws_msg,
450 })
451 .is_err()
452 {
453 crate::metrics::inc_flows_dropped();
454 }
455 }
456 }
457 } else {
458 sender.send(msg).await?;
460 }
461 }
462 Some(Err(e)) => return Err(e.into()),
463 None => break, }
465 }
466 Err(_) => {
467 tracing::warn!("WebSocket Tunnel Idle Timeout");
468 return Err("WebSocket Idle Timeout".into());
470 }
471 }
472 }
473
474 Ok(())
475}
476
477fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
478 let (opcode, content, encoding, size) = match msg {
479 Message::Text(t) => {
480 let len = t.len();
481 ("Text", t.to_string(), "utf-8", len)
482 }
483 Message::Binary(b) => {
484 let len = b.len();
485 ("Binary", BASE64.encode(&b), "base64", len)
486 }
487 Message::Ping(b) => {
488 let len = b.len();
489 ("Ping", BASE64.encode(&b), "base64", len)
490 }
491 Message::Pong(b) => {
492 let len = b.len();
493 ("Pong", BASE64.encode(&b), "base64", len)
494 }
495 Message::Close(_) => ("Close", String::new(), "none", 0),
496 Message::Frame(_) => return None,
497 };
498
499 Some(WebSocketMessage {
500 id: Uuid::new_v4(),
501 timestamp: Utc::now(),
502 direction: dir,
503 content: BodyData {
504 encoding: encoding.to_string(),
505 content,
506 size: size as u64,
507 },
508 opcode: opcode.to_string(),
509 })
510}
511
512fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
513 match msg.opcode.as_str() {
514 "Text" => Message::Text(msg.content.content.clone().into()),
515 "Binary" => {
516 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
517 Message::Binary(Bytes::from(b))
518 } else {
519 Message::Binary(Bytes::new())
520 }
521 }
522 "Ping" => {
523 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
524 Message::Ping(Bytes::from(b))
525 } else {
526 Message::Ping(Bytes::new())
527 }
528 }
529 "Pong" => {
530 if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
531 Message::Pong(Bytes::from(b))
532 } else {
533 Message::Pong(Bytes::new())
534 }
535 }
536 "Close" => Message::Close(None),
537 _ => Message::Text(msg.content.content.clone().into()),
538 }
539}
540
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// Builds a policy with strict HTTP semantics switched on or off.
    fn policy(strict: bool) -> ProxyPolicy {
        ProxyPolicy {
            strict_http_semantics: strict,
            ..Default::default()
        }
    }

    /// Builds a `GET ws://example.com/socket` request carrying exactly `headers`.
    fn ws_request(headers: &[(HeaderName, &'static str)]) -> Request<Empty<Bytes>> {
        let mut builder = Request::builder()
            .method("GET")
            .uri("ws://example.com/socket");
        for (name, value) in headers {
            builder = builder.header(name, *value);
        }
        builder.body(Empty::<Bytes>::new()).expect("request build")
    }

    /// Asserts that validation failed with a `400 Bad Request` response.
    fn assert_rejected(result: Result<(), Box<Response<HttpBody>>>) {
        let resp = result.expect_err("strict mode should reject the handshake");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_VERSION, "13")]);
        assert_rejected(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_version() {
        let req = ws_request(&[(hyper::header::SEC_WEBSOCKET_KEY, "test-key")]);
        assert_rejected(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "12"),
        ]);
        assert_rejected(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(validate_ws_strict_handshake(&req, &policy(true)).is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = ws_request(&[]);
        assert!(validate_ws_strict_handshake(&req, &policy(false)).is_ok());
    }
}