gatel_core/proxy/
websocket.rs1use http::{Request, Response, StatusCode};
10use hyper::upgrade::OnUpgrade;
11use tokio::io::copy_bidirectional;
12use tokio::net::TcpStream;
13use tracing::{debug, error, warn};
14
15use crate::{Body, ProxyError, empty_body, websocket};
16
17pub fn is_websocket_upgrade<B>(req: &Request<B>) -> bool {
25 websocket::is_websocket_upgrade(req.headers())
26}
27
28pub async fn proxy_websocket(
38 mut req: Request<Body>,
39 upstream_addr: &str,
40) -> Result<Response<Body>, ProxyError> {
41 let mut upstream_stream = TcpStream::connect(upstream_addr).await.map_err(|e| {
43 ProxyError::Internal(format!(
44 "failed to connect to upstream {upstream_addr}: {e}"
45 ))
46 })?;
47
48 debug!(upstream = %upstream_addr, "connected to upstream for WebSocket upgrade");
49
50 let raw_request = build_raw_upgrade_request(&req, upstream_addr);
52
53 use tokio::io::AsyncWriteExt;
55 upstream_stream
56 .write_all(raw_request.as_bytes())
57 .await
58 .map_err(|e| {
59 ProxyError::Internal(format!("failed to write upgrade request to upstream: {e}"))
60 })?;
61
62 use tokio::io::AsyncReadExt;
66 let mut buf = Vec::with_capacity(4096);
67 let mut tmp = [0u8; 1024];
68 loop {
69 let n = upstream_stream.read(&mut tmp).await.map_err(|e| {
70 ProxyError::Internal(format!("failed to read upstream upgrade response: {e}"))
71 })?;
72 if n == 0 {
73 return Err(ProxyError::Internal(
74 "upstream closed connection before completing WebSocket handshake".into(),
75 ));
76 }
77 buf.extend_from_slice(&tmp[..n]);
78 if buf.len() > 16_384 {
79 return Err(ProxyError::Internal(
80 "upstream upgrade response too large".into(),
81 ));
82 }
83 if buf.windows(4).any(|w| w == b"\r\n\r\n") {
84 break;
85 }
86 }
87
88 let response_str = String::from_utf8_lossy(&buf);
89
90 if !response_str.starts_with("HTTP/1.1 101") {
92 let first_line = response_str.lines().next().unwrap_or("<empty>");
93 warn!(
94 upstream = %upstream_addr,
95 response = %first_line,
96 "upstream did not accept WebSocket upgrade"
97 );
98 return Err(ProxyError::Internal(format!(
99 "upstream did not accept WebSocket upgrade: {first_line}"
100 )));
101 }
102
103 debug!(upstream = %upstream_addr, "upstream accepted WebSocket upgrade");
104
105 let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
112
113 let mut response = Response::builder()
114 .status(StatusCode::SWITCHING_PROTOCOLS)
115 .header(http::header::CONNECTION, "Upgrade")
116 .header(http::header::UPGRADE, "websocket");
117
118 for line in response_str.lines().skip(1) {
121 if line.is_empty() || line == "\r" {
122 break;
123 }
124 if let Some((name, value)) = line.split_once(':') {
125 let name = name.trim();
126 let value = value.trim();
127 let name_lower = name.to_ascii_lowercase();
128 if name_lower == "sec-websocket-accept"
129 || name_lower == "sec-websocket-protocol"
130 || name_lower == "sec-websocket-extensions"
131 {
132 response = response.header(name, value);
133 }
134 }
135 }
136
137 let response = response.body(empty_body())?;
138
139 tokio::spawn(async move {
142 match client_upgrade.await {
143 Ok(client_io) => {
144 let mut client_io = hyper_util::rt::TokioIo::new(client_io);
145 let mut upstream_stream = upstream_stream;
146
147 match copy_bidirectional(&mut client_io, &mut upstream_stream).await {
148 Ok((client_to_upstream, upstream_to_client)) => {
149 debug!(
150 client_to_upstream,
151 upstream_to_client, "WebSocket tunnel closed"
152 );
153 }
154 Err(e) => {
155 debug!("WebSocket tunnel error: {e}");
156 }
157 }
158 }
159 Err(e) => {
160 error!("WebSocket client upgrade failed: {e}");
161 }
162 }
163 });
164
165 Ok(response)
166}
167
168fn build_raw_upgrade_request<B>(req: &Request<B>, upstream_addr: &str) -> String {
171 let method = req.method();
172 let path = req
173 .uri()
174 .path_and_query()
175 .map(|pq| pq.as_str())
176 .unwrap_or("/");
177
178 let mut raw = format!("{method} {path} HTTP/1.1\r\n");
179 raw.push_str(&format!("Host: {upstream_addr}\r\n"));
180
181 for (name, value) in req.headers() {
182 if name == http::header::HOST {
184 continue;
185 }
186 if let Ok(v) = value.to_str() {
187 raw.push_str(&format!("{}: {v}\r\n", name.as_str()));
188 }
189 }
190
191 raw.push_str("\r\n");
192 raw
193}