Skip to main content

relay_core_lib/proxy/
websocket.rs

1use crate::capture::loop_detection::LoopDetector;
2use crate::intercept::types::{
3    BoxError, HttpBody, InterceptionResult, Interceptor, RequestAction, WebSocketMessageAction,
4};
5use crate::proxy::http_utils::{
6    HttpsClient, create_error_response, create_initial_flow, mock_to_response, parse_request_meta,
7};
8use chrono::Utc;
9use data_encoding::BASE64;
10use futures_util::{SinkExt, StreamExt};
11use http_body_util::{BodyExt, Full};
12use hyper::body::Bytes;
13use hyper::header::{HeaderName, HeaderValue};
14use hyper::upgrade::Upgraded;
15use hyper::{Request, Response, StatusCode};
16use relay_core_api::flow::{
17    BodyData, Direction, Flow, FlowUpdate, HttpResponse, Layer, WebSocketMessage,
18};
19use relay_core_api::policy::ProxyPolicy;
20use std::convert::Infallible;
21use std::net::SocketAddr;
22use std::sync::Arc;
23use tokio::sync::mpsc::Sender;
24use tokio_tungstenite::WebSocketStream;
25use tokio_tungstenite::tungstenite::protocol::Message;
26use url::Url;
27use uuid::Uuid;
28
29use hyper_util::rt::TokioIo;
30use relay_core_api::flow::ResponseTiming;
31use tokio::sync::watch;
32
33fn validate_ws_strict_handshake<B>(
34    req: &Request<B>,
35    policy: &ProxyPolicy,
36) -> Result<(), Box<Response<HttpBody>>> {
37    if !policy.strict_http_semantics {
38        return Ok(());
39    }
40
41    if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
42        return Err(Box::new(create_error_response(
43            StatusCode::BAD_REQUEST,
44            "Missing Sec-WebSocket-Key header in Strict Mode",
45        )));
46    }
47
48    if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
49        if v != "13" {
50            return Err(Box::new(create_error_response(
51                StatusCode::BAD_REQUEST,
52                "Unsupported WebSocket Version in Strict Mode (Expected 13)",
53            )));
54        }
55    } else {
56        return Err(Box::new(create_error_response(
57            StatusCode::BAD_REQUEST,
58            "Missing Sec-WebSocket-Version header in Strict Mode",
59        )));
60    }
61
62    Ok(())
63}
64
65#[allow(clippy::too_many_arguments)]
66pub async fn handle_websocket_handshake<B>(
67    req: Request<B>,
68    client_addr: SocketAddr,
69    on_flow: Sender<FlowUpdate>,
70    client: Arc<HttpsClient>,
71    interceptor: Arc<dyn Interceptor>,
72    is_mitm: bool,
73    policy_rx: watch::Receiver<ProxyPolicy>,
74    target_addr: Option<SocketAddr>,
75    loop_detector: Arc<LoopDetector>,
76) -> Result<Response<HttpBody>, Infallible>
77where
78    B: hyper::body::Body + Send + 'static,
79    B::Data: Send,
80    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
81{
82    // 2. Parse L7 (HTTP) - Re-extract for WS
83    let meta = parse_request_meta(&req, is_mitm);
84    let policy = policy_rx.borrow().clone();
85
86    // STRICT MODE CHECKS
87    if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
88        return Ok(*resp);
89    }
90
91    // Create Initial Flow for WebSocket Handshake
92    let mut flow = create_initial_flow(meta.clone(), None, client_addr, is_mitm, true);
93
94    // INTERCEPT HEADERS (Handshake)
95    match interceptor.on_request_headers(&mut flow).await {
96        InterceptionResult::Drop => {
97            return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
98        }
99        InterceptionResult::MockResponse(mock) => {
100            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
101                crate::metrics::inc_flows_dropped();
102            }
103            return Ok(mock_to_response(mock));
104        }
105        InterceptionResult::ModifiedRequest(req) => {
106            if let Layer::WebSocket(ws) = &mut flow.layer {
107                ws.handshake_request = req;
108            }
109        }
110        InterceptionResult::ModifiedResponse(res) => {
111            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
112                crate::metrics::inc_flows_dropped();
113            }
114            return Ok(mock_to_response(res));
115        }
116        _ => {}
117    }
118
119    // INTERCEPT REQUEST (Handshake Full)
120    // WS Handshake has empty body
121    let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
122
123    match interceptor.on_request(&mut flow, body).await {
124        Ok(RequestAction::Drop) => {
125            return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
126        }
127        Ok(RequestAction::MockResponse(res)) => {
128            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
129                crate::metrics::inc_flows_dropped();
130            }
131            let (parts, body) = res.into_parts();
132            return Ok(Response::from_parts(parts, body));
133        }
134        Ok(RequestAction::Continue(_)) => {}
135        Err(e) => {
136            return Ok(create_error_response(
137                StatusCode::INTERNAL_SERVER_ERROR,
138                format!("Interceptor Error: {}", e),
139            ));
140        }
141    }
142
143    if on_flow
144        .try_send(FlowUpdate::Full(Box::new(flow.clone())))
145        .is_err()
146    {
147        crate::metrics::inc_flows_dropped();
148    }
149
150    // Prepare Upgrade
151    let (parts, body) = req.into_parts();
152    let req_for_upgrade = Request::from_parts(parts, body);
153
154    // Determine Target URL
155    let mut target_url_str = meta.url_str.clone();
156
157    if policy.transparent_enabled
158        && let Some(addr) = target_addr
159    {
160        flow.tags.push("transparent".to_string());
161
162        // Update Flow Network Info
163        flow.network.server_ip = addr.ip().to_string();
164        flow.network.server_port = addr.port();
165
166        // Loop Detection
167        if loop_detector.would_loop(addr) {
168            if let Layer::WebSocket(ws) = &mut flow.layer {
169                ws.handshake_response.status = 508;
170                ws.closed = true;
171            }
172            if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
173                crate::metrics::inc_flows_dropped();
174            }
175            return Ok(create_error_response(
176                StatusCode::LOOP_DETECTED,
177                "Loop Detected",
178            ));
179        }
180
181        // Rewrite URI
182        let mut u = if let Layer::WebSocket(ws) = &flow.layer {
183            ws.handshake_request.url.clone()
184        } else {
185            Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
186        };
187
188        if u.set_ip_host(addr.ip()).is_ok() {
189            u.set_port(Some(addr.port())).ok();
190            // Ensure scheme is correct (ws/wss)
191            if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
192                u.set_scheme("wss").ok();
193            } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
194                u.set_scheme("ws").ok();
195            }
196            target_url_str = u.to_string();
197        }
198    }
199
200    // Prepare Forward Request
201    let current_req = if let Layer::WebSocket(ws) = &flow.layer {
202        &ws.handshake_request
203    } else {
204        return Ok(create_error_response(
205            StatusCode::INTERNAL_SERVER_ERROR,
206            "Invalid Flow Layer State",
207        ));
208    };
209
210    let mut forward_req_builder = Request::builder()
211        .method(current_req.method.as_str())
212        .uri(target_url_str.as_str())
213        .version(hyper::Version::HTTP_11);
214
215    for (k, v) in current_req.headers.iter() {
216        if let (Ok(name), Ok(val)) = (
217            HeaderName::from_bytes(k.as_bytes()),
218            HeaderValue::from_str(v),
219        ) {
220            forward_req_builder = forward_req_builder.header(name, val);
221        }
222    }
223
224    let forward_req =
225        match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
226            Ok(req) => req,
227            Err(e) => {
228                return Ok(create_error_response(
229                    StatusCode::INTERNAL_SERVER_ERROR,
230                    format!("Failed to build forward request: {}", e),
231                ));
232            }
233        };
234
235    match tokio::time::timeout(
236        std::time::Duration::from_secs(30),
237        client.request(forward_req),
238    )
239    .await
240    {
241        Ok(Ok(resp)) => {
242            if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
243                let (parts, body) = resp.into_parts();
244                let resp_for_upgrade = Response::from_parts(parts.clone(), body);
245
246                // Spawn upgrade task
247                let on_flow_clone = on_flow.clone();
248                let interceptor_clone = interceptor.clone();
249                let flow_clone = flow.clone(); // Clone flow for the tunnel task
250
251                tokio::task::spawn(async move {
252                    // Add timeout for upgrades
253                    let upgrade_timeout = std::time::Duration::from_secs(10);
254
255                    let client_upgrade =
256                        tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
257                    let server_upgrade =
258                        tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
259
260                    match tokio::try_join!(client_upgrade, server_upgrade) {
261                        Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
262                            // Start WebSocket Tunnel
263                            if let Err(e) = handle_websocket_tunnel(
264                                upgraded_client,
265                                upgraded_server,
266                                flow_clone,
267                                on_flow_clone,
268                                interceptor_clone,
269                            )
270                            .await
271                            {
272                                tracing::error!("WebSocket Tunnel Error: {}", e);
273                            }
274                        }
275                        Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
276                        Ok((_, Err(e))) => {
277                            tracing::error!("Upstream WebSocket Upgrade Error: {}", e)
278                        }
279                        Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
280                    }
281                });
282
283                let mut client_resp_builder = Response::builder()
284                    .status(StatusCode::SWITCHING_PROTOCOLS)
285                    .version(parts.version);
286
287                for (k, v) in parts.headers.iter() {
288                    client_resp_builder = client_resp_builder.header(k, v);
289                }
290
291                let client_resp = match client_resp_builder
292                    .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
293                {
294                    Ok(r) => r,
295                    Err(e) => {
296                        tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
297                        return Ok(create_error_response(
298                            StatusCode::INTERNAL_SERVER_ERROR,
299                            "Response build failed",
300                        ));
301                    }
302                };
303
304                // Update handshake response in flow (but we don't send update here as flow is moved/cloned)
305                // Ideally we should send an update here for the handshake response
306                // But `flow` variable is local.
307                // We can construct a partial update? Or just ignore.
308                // The tunnel will send messages.
309
310                Ok(client_resp)
311            } else {
312                // Normal HTTP Response (Handshake failed)
313                let (parts, body) = resp.into_parts();
314                let body_bytes = match body.collect().await {
315                    Ok(c) => c.to_bytes(),
316                    Err(_) => Bytes::new(),
317                };
318
319                // Create HTTP Response for Flow
320                let http_resp = HttpResponse {
321                    status: parts.status.as_u16(),
322                    status_text: parts
323                        .status
324                        .canonical_reason()
325                        .unwrap_or("Unknown")
326                        .to_string(),
327                    version: format!("{:?}", parts.version),
328                    headers: parts
329                        .headers
330                        .iter()
331                        .map(|(k, v)| {
332                            (k.as_str().to_string(), v.to_str().unwrap_or("").to_string())
333                        })
334                        .collect(),
335                    cookies: vec![], // Todo: parse cookies
336                    body: Some(BodyData {
337                        encoding: "utf-8".to_string(),
338                        content: String::from_utf8_lossy(&body_bytes).to_string(),
339                        size: body_bytes.len() as u64,
340                    }),
341                    timing: ResponseTiming {
342                        time_to_first_byte: None,
343                        time_to_last_byte: None,
344                        connect_time_ms: None,
345                        ssl_time_ms: None,
346                    },
347                };
348
349                if let Layer::WebSocket(ws) = &mut flow.layer {
350                    ws.handshake_response = http_resp.clone();
351                    ws.closed = true;
352                }
353
354                if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
355                    crate::metrics::inc_flows_dropped();
356                }
357
358                Ok(Response::from_parts(
359                    parts,
360                    Full::new(body_bytes).map_err(|e| e.into()).boxed(),
361                ))
362            }
363        }
364        Ok(Err(e)) => Ok(create_error_response(
365            StatusCode::BAD_GATEWAY,
366            format!("Upstream Handshake Failed: {}", e),
367        )),
368        Err(_) => Ok(create_error_response(
369            StatusCode::GATEWAY_TIMEOUT,
370            "Upstream Handshake Timed Out",
371        )),
372    }
373}
374
375async fn handle_websocket_tunnel(
376    client_io: Upgraded,
377    server_io: Upgraded,
378    mut flow: Flow,
379    on_flow: Sender<FlowUpdate>,
380    interceptor: Arc<dyn Interceptor>,
381) -> Result<(), BoxError> {
382    let client_ws = WebSocketStream::from_raw_socket(
383        TokioIo::new(client_io),
384        tokio_tungstenite::tungstenite::protocol::Role::Server,
385        None,
386    )
387    .await;
388    let server_ws = WebSocketStream::from_raw_socket(
389        TokioIo::new(server_io),
390        tokio_tungstenite::tungstenite::protocol::Role::Client,
391        None,
392    )
393    .await;
394
395    let (mut client_tx, mut client_rx) = client_ws.split();
396    let (mut server_tx, mut server_rx) = server_ws.split();
397
398    // Idle timeout for WebSocket
399    let idle_timeout_duration = std::time::Duration::from_secs(300); // 5 minutes
400
401    loop {
402        let event = tokio::time::timeout(idle_timeout_duration, async {
403            tokio::select! {
404                msg = client_rx.next() => (Direction::ClientToServer, msg),
405                msg = server_rx.next() => (Direction::ServerToClient, msg),
406            }
407        })
408        .await;
409
410        match event {
411            Ok((dir, msg_opt)) => {
412                match msg_opt {
413                    Some(Ok(msg)) => {
414                        // Handle Message
415                        let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer
416                        {
417                            (&mut server_tx, &mut client_tx, Direction::ClientToServer)
418                        } else {
419                            (&mut client_tx, &mut server_tx, Direction::ServerToClient)
420                        };
421
422                        if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
423                            match interceptor
424                                .on_websocket_message(&mut flow, ws_msg.clone())
425                                .await
426                            {
427                                Ok(WebSocketMessageAction::Drop) => continue,
428                                Ok(WebSocketMessageAction::Continue(mod_msg)) => {
429                                    let t_msg = flow_msg_to_tungstenite(&mod_msg);
430                                    sender.send(t_msg).await?;
431
432                                    if on_flow
433                                        .try_send(FlowUpdate::WebSocketMessage {
434                                            flow_id: flow.id.to_string(),
435                                            message: mod_msg,
436                                        })
437                                        .is_err()
438                                    {
439                                        crate::metrics::inc_flows_dropped();
440                                    }
441                                }
442                                Err(e) => {
443                                    tracing::error!("WebSocket Interception Error: {}", e);
444                                    sender.send(msg).await?;
445
446                                    if on_flow
447                                        .try_send(FlowUpdate::WebSocketMessage {
448                                            flow_id: flow.id.to_string(),
449                                            message: ws_msg,
450                                        })
451                                        .is_err()
452                                    {
453                                        crate::metrics::inc_flows_dropped();
454                                    }
455                                }
456                            }
457                        } else {
458                            // Non-data message
459                            sender.send(msg).await?;
460                        }
461                    }
462                    Some(Err(e)) => return Err(e.into()),
463                    None => break, // Connection closed
464                }
465            }
466            Err(_) => {
467                tracing::warn!("WebSocket Tunnel Idle Timeout");
468                // Optional: Send close frame?
469                return Err("WebSocket Idle Timeout".into());
470            }
471        }
472    }
473
474    Ok(())
475}
476
477fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
478    let (opcode, content, encoding, size) = match msg {
479        Message::Text(t) => {
480            let len = t.len();
481            ("Text", t.to_string(), "utf-8", len)
482        }
483        Message::Binary(b) => {
484            let len = b.len();
485            ("Binary", BASE64.encode(&b), "base64", len)
486        }
487        Message::Ping(b) => {
488            let len = b.len();
489            ("Ping", BASE64.encode(&b), "base64", len)
490        }
491        Message::Pong(b) => {
492            let len = b.len();
493            ("Pong", BASE64.encode(&b), "base64", len)
494        }
495        Message::Close(_) => ("Close", String::new(), "none", 0),
496        Message::Frame(_) => return None,
497    };
498
499    Some(WebSocketMessage {
500        id: Uuid::new_v4(),
501        timestamp: Utc::now(),
502        direction: dir,
503        content: BodyData {
504            encoding: encoding.to_string(),
505            content,
506            size: size as u64,
507        },
508        opcode: opcode.to_string(),
509    })
510}
511
512fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
513    match msg.opcode.as_str() {
514        "Text" => Message::Text(msg.content.content.clone().into()),
515        "Binary" => {
516            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
517                Message::Binary(Bytes::from(b))
518            } else {
519                Message::Binary(Bytes::new())
520            }
521        }
522        "Ping" => {
523            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
524                Message::Ping(Bytes::from(b))
525            } else {
526                Message::Ping(Bytes::new())
527            }
528        }
529        "Pong" => {
530            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
531                Message::Pong(Bytes::from(b))
532            } else {
533                Message::Pong(Bytes::new())
534            }
535        }
536        "Close" => Message::Close(None),
537        _ => Message::Text(msg.content.content.clone().into()),
538    }
539}
540
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// A default policy with strict HTTP semantics switched on or off.
    fn policy_with_strict(strict: bool) -> ProxyPolicy {
        ProxyPolicy {
            strict_http_semantics: strict,
            ..Default::default()
        }
    }

    /// A `GET ws://example.com/socket` request carrying `headers` in the given order.
    fn socket_request(headers: &[(HeaderName, &str)]) -> Request<Empty<Bytes>> {
        headers
            .iter()
            .fold(
                Request::builder()
                    .method("GET")
                    .uri("ws://example.com/socket"),
                |builder, (name, value)| builder.header(name, *value),
            )
            .body(Empty::<Bytes>::new())
            .expect("request build")
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        let req = socket_request(&[(hyper::header::SEC_WEBSOCKET_VERSION, "13")]);
        let outcome = validate_ws_strict_handshake(&req, &policy_with_strict(true));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        let req = socket_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "12"),
        ]);
        let outcome = validate_ws_strict_handshake(&req, &policy_with_strict(true));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = socket_request(&[
            (hyper::header::SEC_WEBSOCKET_KEY, "test-key"),
            (hyper::header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        let outcome = validate_ws_strict_handshake(&req, &policy_with_strict(true));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = socket_request(&[]);
        let outcome = validate_ws_strict_handshake(&req, &policy_with_strict(false));
        assert!(outcome.is_ok());
    }
}