//! WebSocket handshake interception and message tunnelling (`relay_core_lib/proxy/websocket.rs`).

1use tokio::sync::mpsc::Sender;
2use hyper::upgrade::Upgraded;
3use hyper::{Request, Response, StatusCode};
4use hyper::header::{HeaderName, HeaderValue};
5use http_body_util::{Full, BodyExt};
6use tokio_tungstenite::WebSocketStream;
7use tokio_tungstenite::tungstenite::protocol::Message;
8use futures_util::{StreamExt, SinkExt};
9use data_encoding::BASE64;
10use relay_core_api::flow::{Flow, FlowUpdate, WebSocketMessage, Direction, BodyData, Layer, HttpResponse};
11use relay_core_api::policy::ProxyPolicy;
12use crate::intercept::types::{Interceptor, InterceptionResult, RequestAction, WebSocketMessageAction, HttpBody, BoxError};
13use crate::proxy::http_utils::{HttpsClient, parse_request_meta, create_initial_flow, mock_to_response, create_error_response};
14use crate::capture::loop_detection::LoopDetector;
15use std::sync::Arc;
16use std::convert::Infallible;
17use std::net::SocketAddr;
18use uuid::Uuid;
19use chrono::Utc;
20use url::Url;
21use hyper::body::Bytes;
22
23use tokio::sync::watch;
24use hyper_util::rt::TokioIo;
25use relay_core_api::flow::ResponseTiming;
26
27fn validate_ws_strict_handshake<B>(
28    req: &Request<B>,
29    policy: &ProxyPolicy,
30) -> Result<(), Box<Response<HttpBody>>> {
31    if !policy.strict_http_semantics {
32        return Ok(());
33    }
34
35    if !req.headers().contains_key(hyper::header::SEC_WEBSOCKET_KEY) {
36        return Err(Box::new(create_error_response(
37            StatusCode::BAD_REQUEST,
38            "Missing Sec-WebSocket-Key header in Strict Mode",
39        )));
40    }
41
42    if let Some(v) = req.headers().get(hyper::header::SEC_WEBSOCKET_VERSION) {
43        if v != "13" {
44            return Err(Box::new(create_error_response(
45                StatusCode::BAD_REQUEST,
46                "Unsupported WebSocket Version in Strict Mode (Expected 13)",
47            )));
48        }
49    } else {
50        return Err(Box::new(create_error_response(
51            StatusCode::BAD_REQUEST,
52            "Missing Sec-WebSocket-Version header in Strict Mode",
53        )));
54    }
55
56    Ok(())
57}
58
59#[allow(clippy::too_many_arguments)]
60pub async fn handle_websocket_handshake<B>(
61    req: Request<B>,
62    client_addr: SocketAddr,
63    on_flow: Sender<FlowUpdate>,
64    client: Arc<HttpsClient>,
65    interceptor: Arc<dyn Interceptor>,
66    is_mitm: bool,
67    policy_rx: watch::Receiver<ProxyPolicy>,
68    target_addr: Option<SocketAddr>,
69    loop_detector: Arc<LoopDetector>,
70) -> Result<Response<HttpBody>, Infallible>
71where
72    B: hyper::body::Body + Send + 'static,
73    B::Data: Send,
74    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
75{
76    // 2. Parse L7 (HTTP) - Re-extract for WS
77    let meta = parse_request_meta(&req, is_mitm);
78    let policy = policy_rx.borrow().clone();
79
80    // STRICT MODE CHECKS
81    if let Err(resp) = validate_ws_strict_handshake(&req, &policy) {
82        return Ok(*resp);
83    }
84
85    // Create Initial Flow for WebSocket Handshake
86    let mut flow = create_initial_flow(
87        meta.clone(),
88        None,
89        client_addr,
90        is_mitm,
91        true,
92    );
93    
94    // INTERCEPT HEADERS (Handshake)
95    match interceptor.on_request_headers(&mut flow).await {
96        InterceptionResult::Drop => {
97             return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
98        },
99        InterceptionResult::MockResponse(mock) => {
100             if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
101                 crate::metrics::inc_flows_dropped();
102             }
103             return Ok(mock_to_response(mock));
104        },
105        InterceptionResult::ModifiedRequest(req) => {
106             if let Layer::WebSocket(ws) = &mut flow.layer {
107                 ws.handshake_request = req;
108             }
109        },
110        InterceptionResult::ModifiedResponse(res) => {
111             if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
112                 crate::metrics::inc_flows_dropped();
113             }
114             return Ok(mock_to_response(res));
115        },
116        _ => {}
117    }
118
119    // INTERCEPT REQUEST (Handshake Full)
120    // WS Handshake has empty body
121    let body = http_body_util::Empty::new().map_err(|e| e.into()).boxed();
122    
123    match interceptor.on_request(&mut flow, body).await {
124        Ok(RequestAction::Drop) => {
125             return Ok(create_error_response(StatusCode::FORBIDDEN, ""));
126        },
127        Ok(RequestAction::MockResponse(res)) => {
128             if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
129                 crate::metrics::inc_flows_dropped();
130             }
131             let (parts, body) = res.into_parts();
132             return Ok(Response::from_parts(parts, body));
133        },
134        Ok(RequestAction::Continue(_)) => {},
135        Err(e) => {
136             return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Interceptor Error: {}", e)));
137        }
138    }
139    
140    if on_flow.try_send(FlowUpdate::Full(Box::new(flow.clone()))).is_err() {
141        crate::metrics::inc_flows_dropped();
142    }
143
144    // Prepare Upgrade
145    let (parts, body) = req.into_parts();
146    let req_for_upgrade = Request::from_parts(parts, body);
147    
148    // Determine Target URL
149    let mut target_url_str = meta.url_str.clone();
150    
151    if policy.transparent_enabled
152        && let Some(addr) = target_addr {
153            flow.tags.push("transparent".to_string());
154            
155            // Update Flow Network Info
156            flow.network.server_ip = addr.ip().to_string();
157            flow.network.server_port = addr.port();
158
159            // Loop Detection
160            if loop_detector.would_loop(addr) {
161                if let Layer::WebSocket(ws) = &mut flow.layer {
162                     ws.handshake_response.status = 508;
163                     ws.closed = true;
164                }
165                if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
166                    crate::metrics::inc_flows_dropped();
167                }
168                return Ok(create_error_response(StatusCode::LOOP_DETECTED, "Loop Detected"));
169            }
170
171            // Rewrite URI
172            let mut u = if let Layer::WebSocket(ws) = &flow.layer {
173                ws.handshake_request.url.clone()
174            } else {
175                Url::parse(&meta.url_str).unwrap_or_else(|_| Url::parse("http://unknown/").unwrap())
176            };
177
178            if u.set_ip_host(addr.ip()).is_ok() {
179                u.set_port(Some(addr.port())).ok();
180                // Ensure scheme is correct (ws/wss)
181                if is_mitm && (u.scheme() == "http" || u.scheme() == "ws") {
182                    u.set_scheme("wss").ok();
183                } else if !is_mitm && (u.scheme() == "https" || u.scheme() == "wss") {
184                    u.set_scheme("ws").ok();
185                }
186                target_url_str = u.to_string();
187            }
188        }
189
190    // Prepare Forward Request
191    let current_req = if let Layer::WebSocket(ws) = &flow.layer {
192        &ws.handshake_request
193    } else {
194        return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Invalid Flow Layer State"));
195    };
196
197    let mut forward_req_builder = Request::builder()
198        .method(current_req.method.as_str())
199        .uri(target_url_str.as_str())
200        .version(hyper::Version::HTTP_11);
201    
202    for (k, v) in current_req.headers.iter() {
203        if let (Ok(name), Ok(val)) = (HeaderName::from_bytes(k.as_bytes()), HeaderValue::from_str(v)) {
204            forward_req_builder = forward_req_builder.header(name, val);
205        }
206    }
207    
208    let forward_req = match forward_req_builder.body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed()) {
209        Ok(req) => req,
210        Err(e) => return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to build forward request: {}", e))),
211    };
212
213    match tokio::time::timeout(std::time::Duration::from_secs(30), client.request(forward_req)).await {
214        Ok(Ok(resp)) => {
215            if resp.status() == StatusCode::SWITCHING_PROTOCOLS {
216                let (parts, body) = resp.into_parts();
217                let resp_for_upgrade = Response::from_parts(parts.clone(), body);
218                
219                // Spawn upgrade task
220                let on_flow_clone = on_flow.clone();
221                let interceptor_clone = interceptor.clone();
222                let flow_clone = flow.clone(); // Clone flow for the tunnel task
223                
224                tokio::task::spawn(async move {
225                    // Add timeout for upgrades
226                    let upgrade_timeout = std::time::Duration::from_secs(10);
227                    
228                    let client_upgrade = tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(req_for_upgrade));
229                    let server_upgrade = tokio::time::timeout(upgrade_timeout, hyper::upgrade::on(resp_for_upgrade));
230
231                    match tokio::try_join!(client_upgrade, server_upgrade) {
232                        Ok((Ok(upgraded_client), Ok(upgraded_server))) => {
233                            // Start WebSocket Tunnel
234                            if let Err(e) = handle_websocket_tunnel(
235                                upgraded_client,
236                                upgraded_server,
237                                flow_clone,
238                                on_flow_clone,
239                                interceptor_clone
240                            ).await {
241                                tracing::error!("WebSocket Tunnel Error: {}", e);
242                            }
243                        },
244                        Ok((Err(e), _)) => tracing::error!("Client WebSocket Upgrade Error: {}", e),
245                        Ok((_, Err(e))) => tracing::error!("Upstream WebSocket Upgrade Error: {}", e),
246                        Err(_) => tracing::error!("WebSocket Upgrade Timed Out"),
247                    }
248                });
249
250                let mut client_resp_builder = Response::builder()
251                    .status(StatusCode::SWITCHING_PROTOCOLS)
252                    .version(parts.version);
253                    
254                for (k, v) in parts.headers.iter() {
255                    client_resp_builder = client_resp_builder.header(k, v);
256                }
257                
258                let client_resp = match client_resp_builder
259                    .body(Full::new(Bytes::new()).map_err(|e| e.into()).boxed())
260                {
261                    Ok(r) => r,
262                    Err(e) => {
263                        tracing::error!("Failed to build 101 Switching Protocols response: {}", e);
264                        return Ok(create_error_response(StatusCode::INTERNAL_SERVER_ERROR, "Response build failed"));
265                    }
266                };
267                
268                // Update handshake response in flow (but we don't send update here as flow is moved/cloned)
269                // Ideally we should send an update here for the handshake response
270                // But `flow` variable is local.
271                // We can construct a partial update? Or just ignore.
272                // The tunnel will send messages.
273                
274                Ok(client_resp)
275            } else {
276                // Normal HTTP Response (Handshake failed)
277                let (parts, body) = resp.into_parts();
278                let body_bytes = match body.collect().await {
279                    Ok(c) => c.to_bytes(),
280                    Err(_) => Bytes::new(),
281                };
282                
283                // Create HTTP Response for Flow
284                let http_resp = HttpResponse {
285                    status: parts.status.as_u16(),
286                    status_text: parts.status.canonical_reason().unwrap_or("Unknown").to_string(),
287                    version: format!("{:?}", parts.version),
288                    headers: parts.headers.iter()
289                        .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
290                        .collect(),
291                    cookies: vec![], // Todo: parse cookies
292                    body: Some(BodyData {
293                        encoding: "utf-8".to_string(),
294                        content: String::from_utf8_lossy(&body_bytes).to_string(),
295                        size: body_bytes.len() as u64,
296                    }),
297                    timing: ResponseTiming {
298                        time_to_first_byte: None,
299                        time_to_last_byte: None,
300                    connect_time_ms: None,
301                    ssl_time_ms: None,
302                    },
303                };
304
305                if let Layer::WebSocket(ws) = &mut flow.layer {
306                    ws.handshake_response = http_resp.clone();
307                    ws.closed = true;
308                }
309                
310                if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
311                    crate::metrics::inc_flows_dropped();
312                }
313
314                Ok(Response::from_parts(parts, Full::new(body_bytes).map_err(|e| e.into()).boxed()))
315            }
316        },
317        Ok(Err(e)) => Ok(create_error_response(StatusCode::BAD_GATEWAY, format!("Upstream Handshake Failed: {}", e))),
318        Err(_) => Ok(create_error_response(StatusCode::GATEWAY_TIMEOUT, "Upstream Handshake Timed Out")),
319    }
320}
321
322async fn handle_websocket_tunnel(
323    client_io: Upgraded,
324    server_io: Upgraded,
325    mut flow: Flow,
326    on_flow: Sender<FlowUpdate>,
327    interceptor: Arc<dyn Interceptor>,
328) -> Result<(), BoxError> {
329    let client_ws = WebSocketStream::from_raw_socket(TokioIo::new(client_io), tokio_tungstenite::tungstenite::protocol::Role::Server, None).await;
330    let server_ws = WebSocketStream::from_raw_socket(TokioIo::new(server_io), tokio_tungstenite::tungstenite::protocol::Role::Client, None).await;
331
332    let (mut client_tx, mut client_rx) = client_ws.split();
333    let (mut server_tx, mut server_rx) = server_ws.split();
334    
335    // Idle timeout for WebSocket
336    let idle_timeout_duration = std::time::Duration::from_secs(300); // 5 minutes
337
338    loop {
339        let event = tokio::time::timeout(idle_timeout_duration, async {
340            tokio::select! {
341                msg = client_rx.next() => (Direction::ClientToServer, msg),
342                msg = server_rx.next() => (Direction::ServerToClient, msg),
343            }
344        }).await;
345
346        match event {
347            Ok((dir, msg_opt)) => {
348                match msg_opt {
349                    Some(Ok(msg)) => {
350                        // Handle Message
351                        let (sender, _receiver, intercept_dir) = if dir == Direction::ClientToServer {
352                            (&mut server_tx, &mut client_tx, Direction::ClientToServer)
353                        } else {
354                            (&mut client_tx, &mut server_tx, Direction::ServerToClient)
355                        };
356
357                        if let Some(ws_msg) = tungstenite_to_flow_msg(msg.clone(), intercept_dir) {
358                             match interceptor.on_websocket_message(&mut flow, ws_msg.clone()).await {
359                                Ok(WebSocketMessageAction::Drop) => continue,
360                                Ok(WebSocketMessageAction::Continue(mod_msg)) => {
361                                    let t_msg = flow_msg_to_tungstenite(&mod_msg);
362                                    sender.send(t_msg).await?;
363                                    
364                                    if on_flow.try_send(FlowUpdate::WebSocketMessage {
365                                        flow_id: flow.id.to_string(),
366                                        message: mod_msg,
367                                    }).is_err() {
368                                        crate::metrics::inc_flows_dropped();
369                                    }
370                                },
371                                Err(e) => {
372                                    tracing::error!("WebSocket Interception Error: {}", e);
373                                    sender.send(msg).await?;
374                                    
375                                    if on_flow.try_send(FlowUpdate::WebSocketMessage {
376                                        flow_id: flow.id.to_string(),
377                                        message: ws_msg,
378                                    }).is_err() {
379                                        crate::metrics::inc_flows_dropped();
380                                    }
381                                }
382                            }
383                        } else {
384                             // Non-data message
385                             sender.send(msg).await?;
386                        }
387                    },
388                    Some(Err(e)) => return Err(e.into()),
389                    None => break, // Connection closed
390                }
391            },
392            Err(_) => {
393                tracing::warn!("WebSocket Tunnel Idle Timeout");
394                // Optional: Send close frame?
395                return Err("WebSocket Idle Timeout".into());
396            }
397        }
398    }
399    
400    Ok(())
401}
402
403fn tungstenite_to_flow_msg(msg: Message, dir: Direction) -> Option<WebSocketMessage> {
404    let (opcode, content, encoding, size) = match msg {
405        Message::Text(t) => {
406            let len = t.len();
407            ("Text", t.to_string(), "utf-8", len)
408        },
409        Message::Binary(b) => {
410            let len = b.len();
411            ("Binary", BASE64.encode(&b), "base64", len)
412        },
413        Message::Ping(b) => {
414            let len = b.len();
415            ("Ping", BASE64.encode(&b), "base64", len)
416        },
417        Message::Pong(b) => {
418            let len = b.len();
419            ("Pong", BASE64.encode(&b), "base64", len)
420        },
421        Message::Close(_) => ("Close", String::new(), "none", 0),
422        Message::Frame(_) => return None,
423    };
424    
425    Some(WebSocketMessage {
426        id: Uuid::new_v4(),
427        timestamp: Utc::now(),
428        direction: dir,
429        content: BodyData {
430            encoding: encoding.to_string(),
431            content,
432            size: size as u64,
433        },
434        opcode: opcode.to_string(),
435    })
436}
437
438fn flow_msg_to_tungstenite(msg: &WebSocketMessage) -> Message {
439    match msg.opcode.as_str() {
440        "Text" => Message::Text(msg.content.content.clone().into()),
441        "Binary" => {
442            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
443                Message::Binary(Bytes::from(b))
444            } else {
445                Message::Binary(Bytes::new())
446            }
447        },
448        "Ping" => {
449            if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
450                Message::Ping(Bytes::from(b))
451            } else {
452                Message::Ping(Bytes::new())
453            }
454        },
455        "Pong" => {
456             if let Ok(b) = BASE64.decode(msg.content.content.as_bytes()) {
457                Message::Pong(Bytes::from(b))
458            } else {
459                Message::Pong(Bytes::new())
460            }
461        },
462        "Close" => Message::Close(None),
463        _ => Message::Text(msg.content.content.clone().into()),
464    }
465}
466
#[cfg(test)]
mod websocket_tests {
    use super::*;
    use http_body_util::Empty;

    /// Policy with strict HTTP semantics switched on or off; everything else default.
    fn policy(strict: bool) -> ProxyPolicy {
        ProxyPolicy { strict_http_semantics: strict, ..Default::default() }
    }

    /// Builds a GET handshake request carrying only the given WebSocket headers.
    fn ws_request(key: Option<&str>, version: Option<&str>) -> Request<Empty<Bytes>> {
        let mut builder = Request::builder().method("GET").uri("ws://example.com/socket");
        if let Some(key) = key {
            builder = builder.header(hyper::header::SEC_WEBSOCKET_KEY, key);
        }
        if let Some(version) = version {
            builder = builder.header(hyper::header::SEC_WEBSOCKET_VERSION, version);
        }
        builder.body(Empty::<Bytes>::new()).expect("request build")
    }

    /// Asserts that the validator rejected the handshake with `400 Bad Request`.
    fn assert_bad_request(result: Result<(), Box<Response<HttpBody>>>) {
        let resp = result.expect_err("strict mode should reject this handshake");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_key() {
        let req = ws_request(None, Some("13"));
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_missing_version() {
        let req = ws_request(Some("test-key"), None);
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_rejects_invalid_version() {
        let req = ws_request(Some("test-key"), Some("12"));
        assert_bad_request(validate_ws_strict_handshake(&req, &policy(true)));
    }

    #[test]
    fn test_validate_ws_strict_handshake_accepts_valid_request() {
        let req = ws_request(Some("test-key"), Some("13"));
        assert!(validate_ws_strict_handshake(&req, &policy(true)).is_ok());
    }

    #[test]
    fn test_validate_ws_strict_handshake_skips_when_disabled() {
        let req = ws_request(None, None);
        assert!(validate_ws_strict_handshake(&req, &policy(false)).is_ok());
    }
}