use http::{Request, Response, StatusCode};
use hyper::upgrade::OnUpgrade;
use tokio::io::copy_bidirectional;
use tokio::net::TcpStream;
use tracing::{debug, error, warn};
use crate::{Body, ProxyError, empty_body, websocket};
/// Reports whether `req` is asking to be upgraded to a WebSocket connection.
///
/// Only the request's headers are inspected; the actual decision is delegated to
/// [`websocket::is_websocket_upgrade`]. The body type `B` is irrelevant here.
pub fn is_websocket_upgrade<B>(req: &Request<B>) -> bool {
    let headers = req.headers();
    websocket::is_websocket_upgrade(headers)
}
pub async fn proxy_websocket(
mut req: Request<Body>,
upstream_addr: &str,
) -> Result<Response<Body>, ProxyError> {
let mut upstream_stream = TcpStream::connect(upstream_addr).await.map_err(|e| {
ProxyError::Internal(format!(
"failed to connect to upstream {upstream_addr}: {e}"
))
})?;
debug!(upstream = %upstream_addr, "connected to upstream for WebSocket upgrade");
let raw_request = build_raw_upgrade_request(&req, upstream_addr);
use tokio::io::AsyncWriteExt;
upstream_stream
.write_all(raw_request.as_bytes())
.await
.map_err(|e| {
ProxyError::Internal(format!("failed to write upgrade request to upstream: {e}"))
})?;
use tokio::io::AsyncReadExt;
let mut buf = Vec::with_capacity(4096);
let mut tmp = [0u8; 1024];
loop {
let n = upstream_stream.read(&mut tmp).await.map_err(|e| {
ProxyError::Internal(format!("failed to read upstream upgrade response: {e}"))
})?;
if n == 0 {
return Err(ProxyError::Internal(
"upstream closed connection before completing WebSocket handshake".into(),
));
}
buf.extend_from_slice(&tmp[..n]);
if buf.len() > 16_384 {
return Err(ProxyError::Internal(
"upstream upgrade response too large".into(),
));
}
if buf.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
let response_str = String::from_utf8_lossy(&buf);
if !response_str.starts_with("HTTP/1.1 101") {
let first_line = response_str.lines().next().unwrap_or("<empty>");
warn!(
upstream = %upstream_addr,
response = %first_line,
"upstream did not accept WebSocket upgrade"
);
return Err(ProxyError::Internal(format!(
"upstream did not accept WebSocket upgrade: {first_line}"
)));
}
debug!(upstream = %upstream_addr, "upstream accepted WebSocket upgrade");
let client_upgrade: OnUpgrade = hyper::upgrade::on(&mut req);
let mut response = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(http::header::CONNECTION, "Upgrade")
.header(http::header::UPGRADE, "websocket");
for line in response_str.lines().skip(1) {
if line.is_empty() || line == "\r" {
break;
}
if let Some((name, value)) = line.split_once(':') {
let name = name.trim();
let value = value.trim();
let name_lower = name.to_ascii_lowercase();
if name_lower == "sec-websocket-accept"
|| name_lower == "sec-websocket-protocol"
|| name_lower == "sec-websocket-extensions"
{
response = response.header(name, value);
}
}
}
let response = response.body(empty_body())?;
tokio::spawn(async move {
match client_upgrade.await {
Ok(client_io) => {
let mut client_io = hyper_util::rt::TokioIo::new(client_io);
let mut upstream_stream = upstream_stream;
match copy_bidirectional(&mut client_io, &mut upstream_stream).await {
Ok((client_to_upstream, upstream_to_client)) => {
debug!(
client_to_upstream,
upstream_to_client, "WebSocket tunnel closed"
);
}
Err(e) => {
debug!("WebSocket tunnel error: {e}");
}
}
}
Err(e) => {
error!("WebSocket client upgrade failed: {e}");
}
}
});
Ok(response)
}
/// Serialises `req`'s request line and headers as a raw HTTP/1.1 request head.
///
/// The request target is the URI's path-and-query (`/` when absent). The client's `Host`
/// header is replaced by `Host: {upstream_addr}`, which is emitted first. Every other
/// header is copied in iteration order. Headers whose values are not visible ASCII
/// (`HeaderValue::to_str` fails) are skipped. The head is terminated by an empty line;
/// no body is included.
fn build_raw_upgrade_request<B>(req: &Request<B>, upstream_addr: &str) -> String {
    let target = match req.uri().path_and_query() {
        Some(pq) => pq.as_str(),
        None => "/",
    };

    let header_lines = req
        .headers()
        .iter()
        .filter(|&(name, _)| name != http::header::HOST)
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| format!("{}: {v}\r\n", name.as_str()))
        });

    let mut head = format!(
        "{} {target} HTTP/1.1\r\nHost: {upstream_addr}\r\n",
        req.method()
    );
    head.extend(header_lines);
    head.push_str("\r\n");
    head
}