Skip to main content

relay_core_lib/proxy/
websocket.rs

1use crate::capture::loop_detection::LoopDetector;
2use crate::intercept::types::{
3    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, WebSocketMessageAction,
4};
5use crate::proxy::http_utils::{
6    HttpsClient, create_error_response, create_initial_flow, mock_to_response, parse_request_meta,
7};
8use chrono::Utc;
9use data_encoding::BASE64;
10use futures_util::{SinkExt, StreamExt};
11use http_body_util::{BodyExt, Full};
12use hyper::body::Bytes;
13use hyper::header::{HeaderName, HeaderValue};
14use hyper::upgrade::Upgraded;
15use hyper::{Request, Response, StatusCode};
16use relay_core_api::flow::{
17    BodyData, Direction, Flow, FlowUpdate, HttpResponse, Layer, WebSocketMessage,
18};
19use relay_core_api::policy::ProxyPolicy;
20use std::convert::Infallible;
21use std::net::SocketAddr;
22use std::sync::Arc;
23use tokio::sync::mpsc::Sender;
24use tokio_tungstenite::WebSocketStream;
25use tokio_tungstenite::tungstenite::protocol::Message;
26use url::Url;
27use uuid::Uuid;
28
29use hyper_util::rt::TokioIo;
30use relay_core_api::flow::ResponseTiming;
31use tokio::sync::watch;
32
33fn validate_ws_strict_handshake<B>(
34    req: &Request<B>,
35    policy: &ProxyPolicy,
36) -> Result<(), Box<Response<HttpBody>>> {
37    if !policy.strict_http_semantics {
38        return Ok(());
39    }
40
41    if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
42        return Err(Box::new(create_error_response(
43            StatusCode::BAD_REQUEST,
44            "Missing Sec-WebSocket-Key header in Strict Mode",
45        )));
46    }
47
48    if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
49        if v != "13" {
50            return Err(Box::new(create_error_response(
51                StatusCode::BAD_REQUEST,
52                "Unsupported WebSocket Version in Strict Mode (Expected 13)",
53            )));
54        }
55    } else {
56        return Err(Box::new(create_error_response(
57            StatusCode::BAD_REQUEST,
58            "Missing Sec-WebSocket-Version header in Strict Mode",
59        )));
60    }
61
62    Ok(())
63}
64
65#[allow(clippy::too_many_arguments)]
66pub async fn handle_websocket_handshake<B>(
67    req: Request<B>,
68    client_addr: SocketAddr,
69    on_flow: Sender<FlowUpdate>,
70    client: Arc<HttpsClient>,
71    interceptor: Arc<dyn Interceptor>,
72    is_mitm: bool,
73    policy_rx: watch::Receiver<ProxyPolicy>,
74    target_addr: Option<SocketAddr>,
75    loop_detector: Arc<LoopDetector>,
76) -> Result<Response<HttpBody>, Infallible>
77where
78    B: hyper::body::Body + Send + 'static,
79    B::Data: Send,
80    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
81{
82    // 2. Parse L7 (HTTP) - Re-extract for WS
83    let meta = parse_request_meta(&req, is_mitm);
84    let policy = policy_rx.borrow().clone();
85
86    // STRICT MODE CHECKS
87    if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
88        return Ok(*resp);
89    }
90
91    // Create Initial Flow for WebSocket Handshake
92    let mut flow = create_initial_flow(meta.clone(), None, client_addr, is_mitm, true);
93
94    // INTERCEPT HEADERS (Handshake)
95    match interceptor.on_request_headers(&mut flow).await {
96        InterceptionResult::Drop => {
97            return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
98        }
99        InterceptionResult::MockResponse(mock) => {
100            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
101                crate::metrics::inc_flows_dropped();
102            }
103            return Ok(mock_to_response(mock));
104        }
105        InterceptionResult::ModifiedRequest(req) => {
106            if let Layer::WebSocket(ws) = &mut flow.layer {
107                ws.handshake_request = req;
108            }
109        }
110        InterceptionResult::ModifiedResponse(res) => {
111            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
112                crate::metrics::inc_flows_dropped();
113            }
114            return Ok(mock_to_response(res));
115        }
116        _ => {}
117    }
118
119    // INTERCEPT REQUEST (Handshake Full)
120    // WS Handshake has empty body
121    let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
122
123    match interceptor.on_request(&mut flow, body).await {
124        Ok(RequestAction::Drop) => {
125            return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
126        }
127        Ok(RequestAction::MockResponse(res)) => {
128            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
129                crate::metrics::inc_flows_dropped();
130            }
131            let (parts, body) = res.into_parts();
132            return Ok(Response::from_parts(parts, body));
133        }
134        Ok(RequestAction::Continue(_)) => {}
135        Err(e) => {
136            return Ok(create_error_response(
137                StatusCode::INTERNAL_SERVER_ERROR,
138                format!("Interceptor Error: {}", e),
139            ));
140        }
141    }
142
143    if on_flow
144        .try_send(FlowUpdate::Full(Box::new(flow.clone())))
145        .is_err()
146    {
147        crate::metrics::inc_flows_dropped();
148    }
149
150    // Prepare Upgrade
151    let (parts, body) = req.into_parts();
152    let req_for_upgrade = Request::from_parts(parts, body);
153
154    // Determine Target URL
155    let mut target_url_str = meta.url_str.clone();
156
157    if policy.transparent_enabled
158        && let Some(addr) = target_addr
159    {
160        flow.tags.push("transparent".to_string());
161
162        // Update Flow Network Info
163        flow.network.server_ip = addr.ip().to_string();
164        flow.network.server_port = addr.port();
165
166        // Loop Detection
167        if loop_detector.would_loop(addr) {
168            if let Layer::WebSocket(ws) = &mut flow.layer {
169                ws.handshake_response.status = 508;
170                ws.closed = true;
171            }
172            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
173                crate::metrics::inc_flows_dropped();
174            }
175            return Ok(create_error_response(
176                StatusCode::LOOP_DETECTED,
177                "Loop Detected",
178            ));
179        }
180
181        // Rewrite URI
182        let mut u = if let Layer::WebSocket(ws) = &flow.layer {
183            ws.handshake_request.url.clone()
184        } else {
185            Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
186        };
187
188        if u.set_ip_host(addr.ip()).is_ok() {
189            u.set_port(Some(addr.port())).ok();
190            // Ensure scheme is correct (ws/wss)
191            if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
192                u.set_scheme("wss").ok();
193            } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
194                u.set_scheme("ws").ok();
195            }
196            target_url_str = u.to_string();
197        }
198    }
199
200    // Prepare Forward Request
201    let current_req = if let Layer::WebSocket(ws) = &flow.layer {
202        &ws.handshake_request
203    } else {
204        return Ok(create_error_response(
205            StatusCode::INTERNAL_SERVER_ERROR,
206            "Invalid Flow Layer State",
207        ));
208    };
209
210    let mut forward_req_builder = Request::builder()
211        .method(current_req.method.as_str())
212        .uri(target_url_str.as_str())
213        .version(hyper::Version::HTTP_11);
214
215    for (k, v) in current_req.headers.iter() {
216        if let (Ok(name), Ok(val)) = (
217            HeaderName::from_bytes(k.as_bytes()),
218            HeaderValue::from_str(v),
219        ) {
220            forward_req_builder = forward_req_builder.header(name, val);
221        }
222    }
223
224    let forward_req =
225        match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
226            Ok(req) => req,
227            Err(e) => {
228                return Ok(create_error_response(
229                    StatusCode::INTERNAL_SERVER_ERROR,
230                    format!("Failed to build forward request: {}", e),
231                ));
232            }
233        };
234
235    match tokio::time::timeout(
236        std::time::Duration::from_secs(30),
237        client.request(forward_req),
238    )
239    .await
240    {
241        Ok(Ok(resp)) => {
242            if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
243                let (parts, body) = resp.into_parts();
244                let resp_for_upgrade = Response::from_parts(parts.clone(), body);
245
246                // Spawn upgrade task
247                let on_flow_clone = on_flow.clone();
248                let interceptor_clone = interceptor.clone();
249                let flow_clone = flow.clone(); // Clone flow for the tunnel task
250
251                tokio::task::spawn(async move {
252                    // Add timeout for upgrades
253                    let upgrade_timeout = std::time::Duration::from_secs(10);
254
255                    let client_upgrade =
256                        tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
257                    let server_upgrade =
258                        tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
259
260                    match tokio::try_join!(client_upgrade, server_upgrade) {
261                        Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
262                            // Start WebSocket Tunnel
263                            if let Err(e) = handle_websocket_tunnel(
264                                upgraded_client,
265                                upgraded_server,
266                                flow_clone,
267                                on_flow_clone,
268                                interceptor_clone,
269                            )
270                            .await
271                            {
272                                tracing::error!("WebSocket Tunnel Error: {}", e);
273                            }
274                        }
275                        Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
276                        Ok((_, Err(e))) => {
277                            tracing::error!("Upstream WebSocket Upgrade Error: {}", e)
278                        }
279                        Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
280                    }
281                });
282
283                let mut client_resp_builder = Response::builder()
284                    .status(StatusCode::SWITCHING_PROTOCOLS)
285                    .version(parts.version);
286
287                for (k, v) in parts.headers.iter() {
288                    client_resp_builder = client_resp_builder.header(k, v);
289                }
290
291                let client_resp = match client_resp_builder
292                    .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
293                {
294                    Ok(r) => r,
295                    Err(e) => {
296                        tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
297                        return Ok(create_error_response(
298                            StatusCode::INTERNAL_SERVER_ERROR,
299                            "Response build failed",
300                        ));
301                    }
302                };
303
304                // Update handshake response in flow (but we don't send update here as flow is moved/cloned)
305                // Ideally we should send an update here for the handshake response
306                // But `flow` variable is local.
307                // We can construct a partial update? Or just ignore.
308                // The tunnel will send messages.
309
310                Ok(client_resp)
311            } else {
312                // Normal HTTP Response (Handshake failed)
313                let (parts, body) = resp.into_parts();
314                let body_bytes = match body.collect().await {
315                    Ok(c) => c.to_bytes(),
316                    Err(_) => Bytes::new(),
317                };
318
319                // Create HTTP Response for Flow
320                let http_resp = HttpResponse {
321                    status: parts.status.as_u16(),
322                    status_text: parts
323                        .status
324                        .canonical_reason()
325                        .unwrap_or("Unknown")
326                        .to_string(),
327                    version: format!("{:?}", parts.version),
328                    headers: parts
329                        .headers
330                        .iter()
331                        .map(|(k, v)| {
332                            (
333                                k.as_str().to_string(),
334                                String::from_utf8_lossy(v.as_bytes()).to_string(),
335                            )
336                        })
337                        .collect(),
338                    cookies: vec![], // Todo: parse cookies
339                    body: Some(BodyData {
340                        encoding: "utf-8".to_string(),
341                        content: String::from_utf8_lossy(&body_bytes).to_string(),
342                        size: body_bytes.len() as u64,
343                    }),
344                    timing: ResponseTiming {
345                        time_to_first_byte: None,
346                        time_to_last_byte: None,
347                        connect_time_ms: None,
348                        ssl_time_ms: None,
349                    },
350                };
351
352                if let Layer::WebSocket(ws) = &mut flow.layer {
353                    ws.handshake_response = http_resp.clone();
354                    ws.closed = true;
355                }
356
357                if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
358                    crate::metrics::inc_flows_dropped();
359                }
360
361                Ok(Response::from_parts(
362                    parts,
363                    Full::new(body_bytes).map_err(|e| e.into()).boxed(),
364                ))
365            }
366        }
367        Ok(Err(e)) => Ok(create_error_response(
368            StatusCode::BAD_GATEWAY,
369            format!("Upstream Handshake Failed: {}", e),
370        )),
371        Err(_) => Ok(create_error_response(
372            StatusCode::GATEWAY_TIMEOUT,
373            "Upstream Handshake Timed Out",
374        )),
375    }
376}
377
378async fn handle_websocket_tunnel(
379    client_io: Upgraded,
380    server_io: Upgraded,
381    mut flow: Flow,
382    on_flow: Sender<FlowUpdate>,
383    interceptor: Arc<dyn Interceptor>,
384) -> Result<(), BoxError> {
385    let client_ws = WebSocketStream::from_raw_socket(
386        TokioIo::new(client_io),
387        tokio_tungstenite::tungstenite::protocol::Role::Server,
388        None,
389    )
390    .await;
391    let server_ws = WebSocketStream::from_raw_socket(
392        TokioIo::new(server_io),
393        tokio_tungstenite::tungstenite::protocol::Role::Client,
394        None,
395    )
396    .await;
397
398    let (mut client_tx, mut client_rx) = client_ws.split();
399    let (mut server_tx, mut server_rx) = server_ws.split();
400
401    // Idle timeout for WebSocket
402    let idle_timeout_duration = std::time::Duration::from_secs(300); // 5 minutes
403
404    interceptor.on_websocket_start(&mut flow).await;
405
406    loop {
407        let event = tokio::time::timeout(idle_timeout_duration, async {
408            tokio::select! {
409                msg = client_rx.next() => (Direction::ClientToServer, msg),
410                msg = server_rx.next() => (Direction::ServerToClient, msg),
411            }
412        })
413        .await;
414
415        match event {
416            Ok((dir, msg_opt)) => {
417                match msg_opt {
418                    Some(Ok(msg)) => {
419                        // Handle Message
420                        let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer
421                        {
422                            (&mut server_tx, &mut client_tx, Direction::ClientToServer)
423                        } else {
424                            (&mut client_tx, &mut server_tx, Direction::ServerToClient)
425                        };
426
427                        if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
428                            match interceptor
429                                .on_websocket_message(&mut flow, ws_msg.clone())
430                                .await
431                            {
432                                Ok(WebSocketMessageAction::Drop) => continue,
433                                Ok(WebSocketMessageAction::Continue(mod_msg)) => {
434                                    let t_msg = flow_msg_to_tungstenite(&mod_msg);
435                                    if let Err(e) = sender.send(t_msg).await {
436                                        interceptor
437                                            .on_websocket_error(&mut flow, &e.to_string())
438                                            .await;
439                                        return Err(e.into());
440                                    }
441
442                                    if on_flow
443                                        .try_send(FlowUpdate::WebSocketMessage {
444                                            flow_id: flow.id.to_string(),
445                                            message: mod_msg,
446                                        })
447                                        .is_err()
448                                    {
449                                        crate::metrics::inc_flows_dropped();
450                                    }
451                                }
452                                Err(e) => {
453                                    tracing::error!("WebSocket Interception Error: {}", e);
454                                    sender.send(msg).await?;
455
456                                    if on_flow
457                                        .try_send(FlowUpdate::WebSocketMessage {
458                                            flow_id: flow.id.to_string(),
459                                            message: ws_msg,
460                                        })
461                                        .is_err()
462                                    {
463                                        crate::metrics::inc_flows_dropped();
464                                    }
465                                }
466                            }
467                        } else {
468                            // Non-data message
469                            if let Err(e) = sender.send(msg).await {
470                                interceptor
471                                    .on_websocket_error(&mut flow, &e.to_string())
472                                    .await;
473                                return Err(e.into());
474                            }
475                        }
476                    }
477                    Some(Err(e)) => {
478                        interceptor
479                            .on_websocket_error(&mut flow, &e.to_string())
480                            .await;
481                        return Err(e.into());
482                    }
483                    None => {
484                        interceptor
485                            .on_websocket_end(&mut flow, 1000, "normal")
486                            .await;
487                        break;
488                    }
489                }
490            }
491            Err(_) => {
492                tracing::warn!("WebSocket Tunnel Idle Timeout");
493                interceptor
494                    .on_websocket_error(&mut flow, "WebSocket Idle Timeout")
495                    .await;
496                return Err("WebSocket Idle Timeout".into());
497            }
498        }
499    }
500
501    Ok(())
502}
503
504fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
505    let (opcode, content, encoding, size) = match msg {
506        Message::Text(t) => {
507            let len = t.len();
508            ("Text", t.to_string(), "utf-8", len)
509        }
510        Message::Binary(b) => {
511            let len = b.len();
512            ("Binary", BASE64.encode(&b), "base64", len)
513        }
514        Message::Ping(b) => {
515            let len = b.len();
516            ("Ping", BASE64.encode(&b), "base64", len)
517        }
518        Message::Pong(b) => {
519            let len = b.len();
520            ("Pong", BASE64.encode(&b), "base64", len)
521        }
522        Message::Close(_) => ("Close", String::new(), "none", 0),
523        Message::Frame(_) => return None,
524    };
525
526    Some(WebSocketMessage {
527        id: Uuid::new_v4(),
528        timestamp: Utc::now(),
529        direction: dir,
530        content: BodyData {
531            encoding: encoding.to_string(),
532            content,
533            size: size as u64,
534        },
535        opcode: opcode.to_string(),
536    })
537}
538
539fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
540    match msg.opcode.as_str() {
541        "Text" => Message::Text(msg.content.content.clone().into()),
542        "Binary" => {
543            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
544                Message::Binary(Bytes::from(b))
545            } else {
546                Message::Binary(Bytes::new())
547            }
548        }
549        "Ping" => {
550            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
551                Message::Ping(Bytes::from(b))
552            } else {
553                Message::Ping(Bytes::new())
554            }
555        }
556        "Pong" => {
557            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
558                Message::Pong(Bytes::from(b))
559            } else {
560                Message::Pong(Bytes::new())
561            }
562        }
563        "Close" => Message::Close(None),
564        _ => Message::Text(msg.content.content.clone().into()),
565    }
566}
567
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// Builds a `ProxyPolicy` with strict HTTP semantics switched on or off.
    fn policy(strict: bool) -> ProxyPolicy {
        ProxyPolicy {
            strict_http_semantics: strict,
            ..Default::default()
        }
    }

    /// Builds a bodiless `GET ws://example.com/socket` request carrying only `headers`.
    fn ws_request(headers: &[(HeaderName, &str)]) -> Request<Empty<Bytes>> {
        let mut builder = Request::builder()
            .method("GET")
            .uri("ws://example.com/socket");
        for (name, value) in headers {
            builder = builder.header(name, *value);
        }
        builder.body(Empty::<Bytes>::new()).expect("request build")
    }

    /// Asserts that strict mode rejects `req` with `400 Bad Request`.
    fn assert_rejected_with_bad_request(req: &Request<Empty<Bytes>>) {
        let resp = validate_ws_strict_handshake(req, &policy(true))
            .expect_err("strict mode should reject this handshake");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        assert_rejected_with_bad_request(&ws_request(&[(
            hyper::header::SEC_WEBSOCKET_VERSION,
            "13",
        )]));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        assert_rejected_with_bad_request(&ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "12"),
        ]));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_version() {
        assert_rejected_with_bad_request(&ws_request(&[(
            hyper::header::SEC_WEBSOCKET_KEY,
            "test-key",
        )]));
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = ws_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(validate_ws_strict_handshake(&req, &policy(true)).is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = ws_request(&[]);
        assert!(validate_ws_strict_handshake(&req, &policy(false)).is_ok());
    }
}