Skip to main content

relay_core_lib/proxy/
websocket.rs

1use crate::capture::loop_detection::LoopDetector;
2use crate::intercept::types::{
3    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, WebSocketMessageAction,
4};
5use crate::proxy::http_utils::{
6    create_error_response, create_initial_flow, mock_to_response, parse_request_meta,
7};
8use crate::proxy::outbound::OutboundConnector;
9use chrono::Utc;
10use data_encoding::BASE64;
11use futures_util::{SinkExt, StreamExt};
12use http_body_util::{BodyExt, Full};
13use hyper::body::Bytes;
14use hyper::header::{HeaderName, HeaderValue};
15use hyper::upgrade::Upgraded;
16use hyper::{Request, Response, StatusCode};
17use relay_core_api::flow::{
18    BodyData, Direction, Flow, FlowUpdate, HttpResponse, Layer, WebSocketMessage,
19};
20use relay_core_api::policy::ProxyPolicy;
21use std::convert::Infallible;
22use std::net::SocketAddr;
23use std::sync::Arc;
24use tokio::sync::mpsc::Sender;
25use tokio_tungstenite::WebSocketStream;
26use tokio_tungstenite::tungstenite::protocol::Message;
27use url::Url;
28use uuid::Uuid;
29
30use hyper_util::rt::TokioIo;
31use relay_core_api::flow::ResponseTiming;
32use tokio::sync::watch;
33
34fn validate_ws_strict_handshake<B>(
35    req: &Request<B>,
36    policy: &ProxyPolicy,
37) -> Result<(), Box<Response<HttpBody>>> {
38    if !policy.strict_http_semantics {
39        return Ok(());
40    }
41
42    if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
43        return Err(Box::new(create_error_response(
44            StatusCode::BAD_REQUEST,
45            "Missing Sec-WebSocket-Key header in Strict Mode",
46        )));
47    }
48
49    if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
50        if v != "13" {
51            return Err(Box::new(create_error_response(
52                StatusCode::BAD_REQUEST,
53                "Unsupported WebSocket Version in Strict Mode (Expected 13)",
54            )));
55        }
56    } else {
57        return Err(Box::new(create_error_response(
58            StatusCode::BAD_REQUEST,
59            "Missing Sec-WebSocket-Version header in Strict Mode",
60        )));
61    }
62
63    Ok(())
64}
65
66#[allow(clippy::too_many_arguments)]
67pub async fn handle_websocket_handshake<B>(
68    req: Request<B>,
69    client_addr: SocketAddr,
70    on_flow: Sender<FlowUpdate>,
71    connector: Arc<dyn OutboundConnector>,
72    interceptor: Arc<dyn Interceptor>,
73    is_mitm: bool,
74    policy_rx: watch::Receiver<ProxyPolicy>,
75    target_addr: Option<SocketAddr>,
76    loop_detector: Arc<LoopDetector>,
77) -> Result<Response<HttpBody>, Infallible>
78where
79    B: hyper::body::Body + Send + 'static,
80    B::Data: Send,
81    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
82{
83    // 2. Parse L7 (HTTP) - Re-extract for WS
84    let meta = parse_request_meta(&req, is_mitm);
85    let policy = policy_rx.borrow().clone();
86
87    // STRICT MODE CHECKS
88    if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
89        return Ok(*resp);
90    }
91
92    // Create Initial Flow for WebSocket Handshake
93    let mut flow = create_initial_flow(meta.clone(), None, client_addr, is_mitm, true);
94
95    // INTERCEPT HEADERS (Handshake)
96    match interceptor.on_request_headers(&mut flow).await {
97        InterceptionResult::Drop => {
98            return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
99        }
100        InterceptionResult::MockResponse(mock) => {
101            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
102                crate::metrics::inc_flows_dropped();
103            }
104            return Ok(mock_to_response(mock));
105        }
106        InterceptionResult::ModifiedRequest(req) => {
107            if let Layer::WebSocket(ws) = &mut flow.layer {
108                ws.handshake_request = req;
109            }
110        }
111        InterceptionResult::ModifiedResponse(res) => {
112            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
113                crate::metrics::inc_flows_dropped();
114            }
115            return Ok(mock_to_response(res));
116        }
117        _ => {}
118    }
119
120    // INTERCEPT REQUEST (Handshake Full)
121    // WS Handshake has empty body
122    let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
123
124    match interceptor.on_request(&mut flow, body).await {
125        Ok(RequestAction::Drop) => {
126            return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
127        }
128        Ok(RequestAction::MockResponse(res)) => {
129            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
130                crate::metrics::inc_flows_dropped();
131            }
132            let (parts, body) = res.into_parts();
133            return Ok(Response::from_parts(parts, body));
134        }
135        Ok(RequestAction::Continue(_)) => {}
136        Err(e) => {
137            return Ok(create_error_response(
138                StatusCode::INTERNAL_SERVER_ERROR,
139                format!("Interceptor Error: {}", e),
140            ));
141        }
142    }
143
144    if on_flow
145        .try_send(FlowUpdate::Full(Box::new(flow.clone())))
146        .is_err()
147    {
148        crate::metrics::inc_flows_dropped();
149    }
150
151    // Prepare Upgrade
152    let (parts, body) = req.into_parts();
153    let req_for_upgrade = Request::from_parts(parts, body);
154
155    // Determine Target URL
156    let mut target_url_str = meta.url_str.clone();
157
158    if policy.transparent_enabled
159        && let Some(addr) = target_addr
160    {
161        flow.tags.push("transparent".to_string());
162
163        // Update Flow Network Info
164        flow.network.server_ip = addr.ip().to_string();
165        flow.network.server_port = addr.port();
166
167        // Loop Detection
168        if loop_detector.would_loop(addr) {
169            if let Layer::WebSocket(ws) = &mut flow.layer {
170                ws.handshake_response.status = 508;
171                ws.closed = true;
172            }
173            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
174                crate::metrics::inc_flows_dropped();
175            }
176            return Ok(create_error_response(
177                StatusCode::LOOP_DETECTED,
178                "Loop Detected",
179            ));
180        }
181
182        // Rewrite URI
183        let mut u = if let Layer::WebSocket(ws) = &flow.layer {
184            ws.handshake_request.url.clone()
185        } else {
186            Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
187        };
188
189        if u.set_ip_host(addr.ip()).is_ok() {
190            u.set_port(Some(addr.port())).ok();
191            // Ensure scheme is correct (ws/wss)
192            if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
193                u.set_scheme("wss").ok();
194            } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
195                u.set_scheme("ws").ok();
196            }
197            target_url_str = u.to_string();
198        }
199    }
200
201    // Prepare Forward Request
202    let current_req = if let Layer::WebSocket(ws) = &flow.layer {
203        &ws.handshake_request
204    } else {
205        return Ok(create_error_response(
206            StatusCode::INTERNAL_SERVER_ERROR,
207            "Invalid Flow Layer State",
208        ));
209    };
210
211    let mut forward_req_builder = Request::builder()
212        .method(current_req.method.as_str())
213        .uri(target_url_str.as_str())
214        .version(hyper::Version::HTTP_11);
215
216    for (k, v) in current_req.headers.iter() {
217        if let (Ok(name), Ok(val)) = (
218            HeaderName::from_bytes(k.as_bytes()),
219            HeaderValue::from_str(v),
220        ) {
221            forward_req_builder = forward_req_builder.header(name, val);
222        }
223    }
224
225    let forward_req =
226        match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
227            Ok(req) => req,
228            Err(e) => {
229                return Ok(create_error_response(
230                    StatusCode::INTERNAL_SERVER_ERROR,
231                    format!("Failed to build forward request: {}", e),
232                ));
233            }
234        };
235
236    let target_host = forward_req.uri().host().unwrap_or("unknown").to_string();
237    let port = forward_req.uri().port_u16().unwrap_or(
238        if forward_req.uri().scheme_str() == Some("https") {
239            443
240        } else {
241            80
242        },
243    );
244
245    match tokio::time::timeout(
246        std::time::Duration::from_secs(30),
247        connector.send_request(forward_req, &target_host, port, &mut flow),
248    )
249    .await
250    {
251        Ok(Ok(resp)) => {
252            if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
253                let (parts, body) = resp.into_parts();
254                let resp_for_upgrade = Response::from_parts(parts.clone(), body);
255
256                // Spawn upgrade task
257                let on_flow_clone = on_flow.clone();
258                let interceptor_clone = interceptor.clone();
259                let flow_clone = flow.clone(); // Clone flow for the tunnel task
260
261                tokio::task::spawn(async move {
262                    // Add timeout for upgrades
263                    let upgrade_timeout = std::time::Duration::from_secs(10);
264
265                    let client_upgrade =
266                        tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
267                    let server_upgrade =
268                        tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
269
270                    match tokio::try_join!(client_upgrade, server_upgrade) {
271                        Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
272                            // Start WebSocket Tunnel
273                            if let Err(e) = handle_websocket_tunnel(
274                                upgraded_client,
275                                upgraded_server,
276                                flow_clone,
277                                on_flow_clone,
278                                interceptor_clone,
279                            )
280                            .await
281                            {
282                                tracing::error!("WebSocket Tunnel Error: {}", e);
283                            }
284                        }
285                        Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
286                        Ok((_, Err(e))) => {
287                            tracing::error!("Upstream WebSocket Upgrade Error: {}", e)
288                        }
289                        Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
290                    }
291                });
292
293                let mut client_resp_builder = Response::builder()
294                    .status(StatusCode::SWITCHING_PROTOCOLS)
295                    .version(parts.version);
296
297                for (k, v) in parts.headers.iter() {
298                    client_resp_builder = client_resp_builder.header(k, v);
299                }
300
301                let client_resp = match client_resp_builder
302                    .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
303                {
304                    Ok(r) => r,
305                    Err(e) => {
306                        tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
307                        return Ok(create_error_response(
308                            StatusCode::INTERNAL_SERVER_ERROR,
309                            "Response build failed",
310                        ));
311                    }
312                };
313
314                // Update handshake response in flow (but we don't send update here as flow is moved/cloned)
315                // Ideally we should send an update here for the handshake response
316                // But `flow` variable is local.
317                // We can construct a partial update? Or just ignore.
318                // The tunnel will send messages.
319
320                Ok(client_resp)
321            } else {
322                // Normal HTTP Response (Handshake failed)
323                let (parts, body) = resp.into_parts();
324                let body_bytes = match body.collect().await {
325                    Ok(c) => c.to_bytes(),
326                    Err(_) => Bytes::new(),
327                };
328
329                // Create HTTP Response for Flow
330                let http_resp = HttpResponse {
331                    status: parts.status.as_u16(),
332                    status_text: parts
333                        .status
334                        .canonical_reason()
335                        .unwrap_or("Unknown")
336                        .to_string(),
337                    version: format!("{:?}", parts.version),
338                    headers: parts
339                        .headers
340                        .iter()
341                        .map(|(k, v)| {
342                            (
343                                k.as_str().to_string(),
344                                String::from_utf8_lossy(v.as_bytes()).to_string(),
345                            )
346                        })
347                        .collect(),
348                    cookies: vec![], // Todo: parse cookies
349                    body: Some(BodyData {
350                        encoding: "utf-8".to_string(),
351                        content: String::from_utf8_lossy(&body_bytes).to_string(),
352                        size: body_bytes.len() as u64,
353                    }),
354                    timing: ResponseTiming {
355                        time_to_first_byte: None,
356                        time_to_last_byte: None,
357                        connect_time_ms: None,
358                        ssl_time_ms: None,
359                    },
360                };
361
362                if let Layer::WebSocket(ws) = &mut flow.layer {
363                    ws.handshake_response = http_resp.clone();
364                    ws.closed = true;
365                }
366
367                if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
368                    crate::metrics::inc_flows_dropped();
369                }
370
371                Ok(Response::from_parts(
372                    parts,
373                    Full::new(body_bytes).map_err(|e| e.into()).boxed(),
374                ))
375            }
376        }
377        Ok(Err(e)) => Ok(create_error_response(
378            StatusCode::BAD_GATEWAY,
379            format!("Upstream Handshake Failed: {}", e),
380        )),
381        Err(_) => Ok(create_error_response(
382            StatusCode::GATEWAY_TIMEOUT,
383            "Upstream Handshake Timed Out",
384        )),
385    }
386}
387
388async fn handle_websocket_tunnel(
389    client_io: Upgraded,
390    server_io: Upgraded,
391    mut flow: Flow,
392    on_flow: Sender<FlowUpdate>,
393    interceptor: Arc<dyn Interceptor>,
394) -> Result<(), BoxError> {
395    let client_ws = WebSocketStream::from_raw_socket(
396        TokioIo::new(client_io),
397        tokio_tungstenite::tungstenite::protocol::Role::Server,
398        None,
399    )
400    .await;
401    let server_ws = WebSocketStream::from_raw_socket(
402        TokioIo::new(server_io),
403        tokio_tungstenite::tungstenite::protocol::Role::Client,
404        None,
405    )
406    .await;
407
408    let (mut client_tx, mut client_rx) = client_ws.split();
409    let (mut server_tx, mut server_rx) = server_ws.split();
410
411    // Idle timeout for WebSocket
412    let idle_timeout_duration = std::time::Duration::from_secs(300); // 5 minutes
413
414    interceptor.on_websocket_start(&mut flow).await;
415
416    loop {
417        let event = tokio::time::timeout(idle_timeout_duration, async {
418            tokio::select! {
419                msg = client_rx.next() => (Direction::ClientToServer, msg),
420                msg = server_rx.next() => (Direction::ServerToClient, msg),
421            }
422        })
423        .await;
424
425        match event {
426            Ok((dir, msg_opt)) => {
427                match msg_opt {
428                    Some(Ok(msg)) => {
429                        // Handle Message
430                        let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer
431                        {
432                            (&mut server_tx, &mut client_tx, Direction::ClientToServer)
433                        } else {
434                            (&mut client_tx, &mut server_tx, Direction::ServerToClient)
435                        };
436
437                        if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
438                            match interceptor
439                                .on_websocket_message(&mut flow, ws_msg.clone())
440                                .await
441                            {
442                                Ok(WebSocketMessageAction::Drop) => continue,
443                                Ok(WebSocketMessageAction::Continue(mod_msg)) => {
444                                    let t_msg = flow_msg_to_tungstenite(&mod_msg);
445                                    if let Err(e) = sender.send(t_msg).await {
446                                        interceptor
447                                            .on_websocket_error(&mut flow, &e.to_string())
448                                            .await;
449                                        return Err(e.into());
450                                    }
451
452                                    if on_flow
453                                        .try_send(FlowUpdate::WebSocketMessage {
454                                            flow_id: flow.id.to_string(),
455                                            message: mod_msg,
456                                        })
457                                        .is_err()
458                                    {
459                                        crate::metrics::inc_flows_dropped();
460                                    }
461                                }
462                                Err(e) => {
463                                    tracing::error!("WebSocket Interception Error: {}", e);
464                                    sender.send(msg).await?;
465
466                                    if on_flow
467                                        .try_send(FlowUpdate::WebSocketMessage {
468                                            flow_id: flow.id.to_string(),
469                                            message: ws_msg,
470                                        })
471                                        .is_err()
472                                    {
473                                        crate::metrics::inc_flows_dropped();
474                                    }
475                                }
476                            }
477                        } else {
478                            // Non-data message
479                            if let Err(e) = sender.send(msg).await {
480                                interceptor
481                                    .on_websocket_error(&mut flow, &e.to_string())
482                                    .await;
483                                return Err(e.into());
484                            }
485                        }
486                    }
487                    Some(Err(e)) => {
488                        interceptor
489                            .on_websocket_error(&mut flow, &e.to_string())
490                            .await;
491                        return Err(e.into());
492                    }
493                    None => {
494                        interceptor
495                            .on_websocket_end(&mut flow, 1000, "normal")
496                            .await;
497                        break;
498                    }
499                }
500            }
501            Err(_) => {
502                tracing::warn!("WebSocket Tunnel Idle Timeout");
503                interceptor
504                    .on_websocket_error(&mut flow, "WebSocket Idle Timeout")
505                    .await;
506                return Err("WebSocket Idle Timeout".into());
507            }
508        }
509    }
510
511    Ok(())
512}
513
514fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
515    let (opcode, content, encoding, size) = match msg {
516        Message::Text(t) => {
517            let len = t.len();
518            ("Text", t.to_string(), "utf-8", len)
519        }
520        Message::Binary(b) => {
521            let len = b.len();
522            ("Binary", BASE64.encode(&b), "base64", len)
523        }
524        Message::Ping(b) => {
525            let len = b.len();
526            ("Ping", BASE64.encode(&b), "base64", len)
527        }
528        Message::Pong(b) => {
529            let len = b.len();
530            ("Pong", BASE64.encode(&b), "base64", len)
531        }
532        Message::Close(_) => ("Close", String::new(), "none", 0),
533        Message::Frame(_) => return None,
534    };
535
536    Some(WebSocketMessage {
537        id: Uuid::new_v4(),
538        timestamp: Utc::now(),
539        direction: dir,
540        content: BodyData {
541            encoding: encoding.to_string(),
542            content,
543            size: size as u64,
544        },
545        opcode: opcode.to_string(),
546    })
547}
548
549fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
550    match msg.opcode.as_str() {
551        "Text" => Message::Text(msg.content.content.clone().into()),
552        "Binary" => {
553            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
554                Message::Binary(Bytes::from(b))
555            } else {
556                Message::Binary(Bytes::new())
557            }
558        }
559        "Ping" => {
560            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
561                Message::Ping(Bytes::from(b))
562            } else {
563                Message::Ping(Bytes::new())
564            }
565        }
566        "Pong" => {
567            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
568                Message::Pong(Bytes::from(b))
569            } else {
570                Message::Pong(Bytes::new())
571            }
572        }
573        "Close" => Message::Close(None),
574        _ => Message::Text(msg.content.content.clone().into()),
575    }
576}
577
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// Policy with `strict_http_semantics` set as requested and all else default.
    fn policy(strict: bool) -> ProxyPolicy {
        ProxyPolicy {
            strict_http_semantics: strict,
            ..Default::default()
        }
    }

    /// `GET ws://example.com/socket` carrying exactly the given headers.
    fn ws_request(headers: &[(HeaderName, &str)]) -> Request<Empty<Bytes>> {
        let mut builder = Request::builder()
            .method("GET")
            .uri("ws://example.com/socket");
        for (name, value) in headers {
            builder = builder.header(name, *value);
        }
        builder.body(Empty::<Bytes>::new()).expect("request build")
    }

    /// Asserts that strict-mode validation rejects `req` with `400 Bad Request`.
    fn assert_rejected(req: &Request<Empty<Bytes>>) {
        let resp = validate_ws_strict_handshake(req, &policy(true))
            .expect_err("strict mode should reject this handshake");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        assert_rejected(&ws_request(&[(
            hyper::header::SEC_WEBSOCKET_VERSION,
            "13",
        )]));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_version() {
        assert_rejected(&ws_request(&[(
            hyper::header::SEC_WEBSOCKET_KEY,
            "test-key",
        )]));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        assert_rejected(&ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "12"),
        ]));
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(validate_ws_strict_handshake(&req, &policy(true)).is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = ws_request(&[]);
        assert!(validate_ws_strict_handshake(&req, &policy(false)).is_ok());
    }
}