Skip to main content

gatel_core/proxy/
websocket.rs

1//! WebSocket upgrade detection and bidirectional proxying.
2//!
3//! When a client sends a WebSocket upgrade request (`Connection: Upgrade` +
4//! `Upgrade: websocket`), this module forwards the upgrade to the upstream
5//! backend and then establishes a bidirectional byte-copy between the client
6//! and upstream, effectively tunnelling the WebSocket frames without
7//! interpreting them.
8
9use http::{Request, Response, StatusCode};
10use hyper::upgrade::OnUpgrade;
11use tokio::io::copy_bidirectional;
12use tokio::net::TcpStream;
13use tracing::{debug, error, warn};
14
15use crate::{Body, ProxyError, empty_body, websocket};
16
17/// Check whether an incoming request is a WebSocket upgrade.
18///
19/// Returns `true` when the request contains both `Connection: Upgrade` and
20/// `Upgrade: websocket` headers (case-insensitive value comparison).
21///
22/// Thin wrapper around [`crate::websocket::is_websocket_upgrade`] which takes
23/// a `&HeaderMap`; this adapter accepts the full `&Request<B>`.
24pub fn is_websocket_upgrade<B>(req: &Request<B>) -> bool {
25    websocket::is_websocket_upgrade(req.headers())
26}
27
28/// Proxy a WebSocket upgrade request to the given upstream address.
29///
30/// The flow:
31/// 1. Open a raw TCP connection to the upstream.
32/// 2. Write the HTTP upgrade request to the upstream.
33/// 3. Read the upstream's 101 response.
34/// 4. Return a 101 response to the client (with the upgrade extension on the hyper side).
35/// 5. Once both sides have upgraded, spawn a task that copies bytes bidirectionally until either
36///    side closes.
37pub async fn proxy_websocket(
38    mut req: Request<Body>,
39    upstream_addr: &str,
40) -> Result<Response<Body>, ProxyError> {
41    // Connect to the upstream over TCP.
42    let mut upstream_stream = TcpStream::connect(upstream_addr).await.map_err(|e| {
43        ProxyError::Internal(format!(
44            "failed to connect to upstream {upstream_addr}: {e}"
45        ))
46    })?;
47
48    debug!(upstream = %upstream_addr, "connected to upstream for WebSocket upgrade");
49
50    // Build the raw HTTP/1.1 upgrade request to send over the TCP connection.
51    let raw_request = build_raw_upgrade_request(&req, upstream_addr);
52
53    // Write the upgrade request to the upstream.
54    use tokio::io::AsyncWriteExt;
55    upstream_stream
56        .write_all(raw_request.as_bytes())
57        .await
58        .map_err(|e| {
59            ProxyError::Internal(format!("failed to write upgrade request to upstream: {e}"))
60        })?;
61
62    // Read the upstream's response (enough to see the 101 status line and
63    // headers). We read into a buffer until we see the end-of-headers marker
64    // (\r\n\r\n).
65    use tokio::io::AsyncReadExt;
66    let mut buf = Vec::with_capacity(4096);
67    let mut tmp = [0u8; 1024];
68    loop {
69        let n = upstream_stream.read(&mut tmp).await.map_err(|e| {
70            ProxyError::Internal(format!("failed to read upstream upgrade response: {e}"))
71        })?;
72        if n == 0 {
73            return Err(ProxyError::Internal(
74                "upstream closed connection before completing WebSocket handshake".into(),
75            ));
76        }
77        buf.extend_from_slice(&tmp[..n]);
78        if buf.len() > 16_384 {
79            return Err(ProxyError::Internal(
80                "upstream upgrade response too large".into(),
81            ));
82        }
83        if buf.windows(4).any(|w| w == b"\r\n\r\n") {
84            break;
85        }
86    }
87
88    let response_str = String::from_utf8_lossy(&buf);
89
90    // Verify the upstream responded with 101 Switching Protocols.
91    if !response_str.starts_with("HTTP/1.1 101") {
92        let first_line = response_str.lines().next().unwrap_or("<empty>");
93        warn!(
94            upstream = %upstream_addr,
95            response = %first_line,
96            "upstream did not accept WebSocket upgrade"
97        );
98        return Err(ProxyError::Internal(format!(
99            "upstream did not accept WebSocket upgrade: {first_line}"
100        )));
101    }
102
103    debug!(upstream = %upstream_addr, "upstream accepted WebSocket upgrade");
104
105    // Build the 101 response to send back to the client. We need to set up
106    // the hyper upgrade machinery so we can get the raw IO after sending the
107    // response.
108
109    // Capture the client's OnUpgrade before we consume the request.
110    // hyper stores the upgrade future in the request extensions.
111    let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
112
113    let mut response = Response::builder()
114        .status(StatusCode::SWITCHING_PROTOCOLS)
115        .header(http::header::CONNECTION, "Upgrade")
116        .header(http::header::UPGRADE, "websocket");
117
118    // Forward Sec-WebSocket-Accept and Sec-WebSocket-Protocol from upstream
119    // response to the client response.
120    for line in response_str.lines().skip(1) {
121        if line.is_empty() || line == "\r" {
122            break;
123        }
124        if let Some((name, value)) = line.split_once(':') {
125            let name = name.trim();
126            let value = value.trim();
127            let name_lower = name.to_ascii_lowercase();
128            if name_lower == "sec-websocket-accept"
129                || name_lower == "sec-websocket-protocol"
130                || name_lower == "sec-websocket-extensions"
131            {
132                response = response.header(name, value);
133            }
134        }
135    }
136
137    let response = response.body(empty_body())?;
138
139    // Spawn the bidirectional copy task. It will run once the client side
140    // completes its upgrade (i.e., after the 101 response is sent).
141    tokio::spawn(async move {
142        match client_upgrade.await {
143            Ok(client_io) => {
144                let mut client_io = hyper_util::rt::TokioIo::new(client_io);
145                let mut upstream_stream = upstream_stream;
146
147                match copy_bidirectional(&mut client_io, &mut upstream_stream).await {
148                    Ok((client_to_upstream, upstream_to_client)) => {
149                        debug!(
150                            client_to_upstream,
151                            upstream_to_client, "WebSocket tunnel closed"
152                        );
153                    }
154                    Err(e) => {
155                        debug!("WebSocket tunnel error: {e}");
156                    }
157                }
158            }
159            Err(e) => {
160                error!("WebSocket client upgrade failed: {e}");
161            }
162        }
163    });
164
165    Ok(response)
166}
167
168/// Build a raw HTTP/1.1 request string for the WebSocket upgrade, suitable
169/// for writing directly to a TCP stream.
170fn build_raw_upgrade_request<B>(req: &Request<B>, upstream_addr: &str) -> String {
171    let method = req.method();
172    let path = req
173        .uri()
174        .path_and_query()
175        .map(|pq| pq.as_str())
176        .unwrap_or("/");
177
178    let mut raw = format!("{method} {path} HTTP/1.1\r\n");
179    raw.push_str(&format!("Host: {upstream_addr}\r\n"));
180
181    for (name, value) in req.headers() {
182        // Skip the Host header since we already set it to the upstream.
183        if name == http::header::HOST {
184            continue;
185        }
186        if let Ok(v) = value.to_str() {
187            raw.push_str(&format!("{}: {v}\r\n", name.as_str()));
188        }
189    }
190
191    raw.push_str("\r\n");
192    raw
193}