use axum::{
body::Body,
http::{Request, Response, Uri},
};
use futures_util::{SinkExt, stream::StreamExt};
use http::{HeaderMap, HeaderValue, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time::{Duration, timeout};
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, connect_async,
tungstenite::{Error, Message, handshake::derive_accept_key},
};
use tracing::{error, trace};
use url::{Host, Url};
/// Reports whether `headers` describe a WebSocket upgrade request.
///
/// All four of the following must hold:
/// - `Upgrade` is exactly `websocket` (ASCII case-insensitive),
/// - `Connection` contains an `upgrade` token in its comma-separated list
///   (ASCII case-insensitive, surrounding whitespace ignored),
/// - `Sec-WebSocket-Key` is present,
/// - `Sec-WebSocket-Version` is present.
///
/// Only the first value of `Upgrade` / `Connection` is inspected, and a value
/// that is not visible ASCII counts as absent.
pub(crate) fn is_websocket_upgrade(headers: &HeaderMap<HeaderValue>) -> bool {
    /// First value of `name` as text, or `None` if missing or not valid text.
    fn header_text<'h>(headers: &'h HeaderMap<HeaderValue>, name: &str) -> Option<&'h str> {
        headers.get(name).and_then(|value| value.to_str().ok())
    }

    let has_upgrade = matches!(
        header_text(headers, "upgrade"),
        Some(value) if value.eq_ignore_ascii_case("websocket")
    );

    let has_connection = match header_text(headers, "connection") {
        Some(value) => value
            .split(',')
            .map(str::trim)
            .any(|token| token.eq_ignore_ascii_case("upgrade")),
        None => false,
    };

    let has_websocket_key = headers.contains_key("sec-websocket-key");
    let has_websocket_version = headers.contains_key("sec-websocket-version");

    trace!(
        "is_websocket_upgrade - upgrade: {has_upgrade}, connection: {has_connection}, websocket key: {has_websocket_key}, websocket version: {has_websocket_version}"
    );

    // Written as a chain of negations so a single missing piece short-circuits to `false`.
    !(!has_upgrade || !has_connection || !has_websocket_key || !has_websocket_version)
}
/// Builds the `Host` header value for `url` and returns it with the effective port.
///
/// The host part is the domain or IPv4 address as-is, an IPv6 address wrapped
/// in brackets, or `localhost` when the URL carries no host. When the URL has
/// no explicit port, `wss` falls back to 443 and every other scheme to 80.
/// The port is left out of the header only for `wss` on 443 and `ws` on 80.
fn compute_host_header_from_url(url: &Url) -> (String, u16) {
    let is_wss = url.scheme() == "wss";
    let is_ws = url.scheme() == "ws";

    let host = url.host().map_or_else(
        || "localhost".to_string(),
        |host| match host {
            Host::Domain(name) => name.to_string(),
            Host::Ipv4(addr) => addr.to_string(),
            // IPv6 literals must be bracketed so the port separator stays unambiguous.
            Host::Ipv6(addr) => format!("[{addr}]"),
        },
    );

    let port = url.port().unwrap_or(if is_wss { 443 } else { 80 });

    let on_default_port = (is_wss && port == 443) || (is_ws && port == 80);
    if on_default_port {
        (host, port)
    } else {
        (format!("{host}:{port}"), port)
    }
}
/// Test-only convenience wrapper: parses `url` and delegates to
/// `compute_host_header_from_url`. Panics if `url` does not parse.
#[cfg(test)]
fn compute_host_header(url: &str) -> (String, u16) {
    compute_host_header_from_url(&Url::parse(url).unwrap())
}
pub(crate) async fn handle_websocket_with_upstream_uri(
req: Request<Body>,
upstream_http_uri: Uri,
) -> Result<Response<Body>, Box<dyn std::error::Error + Send + Sync>> {
trace!("Handling WebSocket upgrade request");
let ws_key = req
.headers()
.get("sec-websocket-key")
.and_then(|key| key.to_str().ok())
.ok_or("Missing or invalid Sec-WebSocket-Key header")?;
let ws_accept = derive_accept_key(ws_key.as_bytes());
let scheme = upstream_http_uri.scheme_str().unwrap_or("http");
let ws_scheme = match scheme {
"wss" | "ws" => scheme,
"https" => "wss",
_ => "ws",
};
let authority = upstream_http_uri
.authority()
.ok_or("Upstream URI missing authority")?
.as_str();
let path_q = upstream_http_uri
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/");
let upstream_url = format!("{ws_scheme}://{authority}{path_q}");
trace!("Connecting to upstream WebSocket at {}", upstream_url);
let url = Url::parse(&upstream_url)?;
let (host_header, _port) = compute_host_header_from_url(&url);
let mut request = tokio_tungstenite::tungstenite::handshake::client::Request::builder()
.uri(upstream_url)
.header("host", host_header);
for (key, value) in req.headers() {
if key != "host" && key != "sec-websocket-extensions" {
request = request.header(key.as_str(), value);
}
}
let request = request.body(())?;
let (upstream_ws, upstream_response) = timeout(Duration::from_secs(5), connect_async(request))
.await
.map_err(|_| "Upstream WebSocket connection timed out")??;
trace!("Upstream WebSocket connected successfully");
let mut response_builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Accept", ws_accept);
if let Some(protocol) = upstream_response.headers().get("sec-websocket-protocol") {
response_builder = response_builder.header("Sec-WebSocket-Protocol", protocol);
}
trace!("Returning upgrade response to client");
let response = response_builder.body(Body::empty())?;
let (parts, body) = req.into_parts();
let req = Request::from_parts(parts, body);
tokio::spawn(async move {
match handle_websocket_bridge(req, upstream_ws).await {
Ok(_) => trace!("WebSocket connection closed gracefully"),
Err(e) => error!("WebSocket connection error: {}", e),
}
});
Ok(response)
}
async fn handle_websocket_bridge(
req: Request<Body>,
upstream_ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let upgraded = match timeout(Duration::from_secs(5), hyper::upgrade::on(req)).await {
Ok(Ok(upgraded)) => upgraded,
Ok(Err(e)) => return Err(Box::new(e)),
Err(e) => return Err(Box::new(e)),
};
let io = TokioIo::new(upgraded);
let client_ws = tokio_tungstenite::WebSocketStream::from_raw_socket(
io,
tokio_tungstenite::tungstenite::protocol::Role::Server,
None,
)
.await;
let (mut client_sender, mut client_receiver) = client_ws.split();
let (mut upstream_sender, mut upstream_receiver) = upstream_ws.split();
let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
let close_tx_upstream = close_tx.clone();
let client_to_upstream = tokio::spawn(async move {
let mut client_closed = false;
while let Some(msg) = client_receiver.next().await {
let msg = msg?;
match msg {
Message::Close(_) => {
if !client_closed {
upstream_sender.send(Message::Close(None)).await?;
close_tx.send(()).await.ok();
client_closed = true;
break;
}
}
msg @ Message::Binary(_)
| msg @ Message::Text(_)
| msg @ Message::Ping(_)
| msg @ Message::Pong(_) => {
if !client_closed {
upstream_sender.send(msg).await?;
}
}
Message::Frame(_) => {}
}
}
if !client_closed {
upstream_sender.send(Message::Close(None)).await?;
close_tx.send(()).await.ok();
}
Ok::<_, Error>(())
});
let upstream_to_client = tokio::spawn(async move {
let mut upstream_closed = false;
while let Some(msg) = upstream_receiver.next().await {
let msg = msg?;
match msg {
Message::Close(_) => {
if !upstream_closed {
client_sender.send(Message::Close(None)).await?;
close_tx_upstream.send(()).await.ok();
upstream_closed = true;
break;
}
}
msg @ Message::Binary(_)
| msg @ Message::Text(_)
| msg @ Message::Ping(_)
| msg @ Message::Pong(_) => {
if !upstream_closed {
client_sender.send(msg).await?;
}
}
Message::Frame(_) => {}
}
}
if !upstream_closed {
client_sender.send(Message::Close(None)).await?;
close_tx_upstream.send(()).await.ok();
}
Ok::<_, Error>(())
});
tokio::select! {
_ = close_rx.recv() => {
trace!("WebSocket connection closed gracefully");
}
res = client_to_upstream => {
if let Err(e) = res {
error!("Client to upstream task failed: {:?}", e);
}
}
res = upstream_to_client => {
if let Err(e) = res {
error!("Upstream to client task failed: {:?}", e);
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::{compute_host_header, is_websocket_upgrade};
    use http::{HeaderMap, HeaderValue};

    /// Sample `Sec-WebSocket-Key` value shared by the header tests.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";

    /// Builds a header map from static `(name, value)` pairs.
    fn header_map(entries: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in entries {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    #[test]
    fn host_header_ws_default_port() {
        assert_eq!(
            compute_host_header("ws://example.com/path"),
            ("example.com".to_string(), 80)
        );
    }

    #[test]
    fn host_header_wss_default_port() {
        assert_eq!(
            compute_host_header("wss://example.com/path"),
            ("example.com".to_string(), 443)
        );
    }

    #[test]
    fn host_header_wss_custom_port() {
        assert_eq!(
            compute_host_header("wss://example.com:8443/path"),
            ("example.com:8443".to_string(), 8443)
        );
    }

    #[test]
    fn websocket_upgrade_valid_headers() {
        let headers = header_map(&[
            ("Upgrade", "websocket"),
            ("Connection", "keep-alive, Upgrade"),
            ("Sec-WebSocket-Key", SAMPLE_KEY),
            ("Sec-WebSocket-Version", "13"),
        ]);
        assert!(is_websocket_upgrade(&headers));
    }

    #[test]
    fn websocket_upgrade_missing_upgrade_header() {
        let headers = header_map(&[
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Key", SAMPLE_KEY),
            ("Sec-WebSocket-Version", "13"),
        ]);
        assert!(!is_websocket_upgrade(&headers));
    }

    #[test]
    fn websocket_upgrade_invalid_connection_header() {
        let headers = header_map(&[
            ("Upgrade", "websocket"),
            ("Connection", "keep-alive"),
            ("Sec-WebSocket-Key", SAMPLE_KEY),
            ("Sec-WebSocket-Version", "13"),
        ]);
        assert!(!is_websocket_upgrade(&headers));
    }

    #[test]
    fn websocket_upgrade_missing_key_or_version() {
        // Version present, key absent.
        let without_key = header_map(&[
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Version", "13"),
        ]);
        assert!(!is_websocket_upgrade(&without_key));

        // Key present, version absent.
        let without_version = header_map(&[
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Key", SAMPLE_KEY),
        ]);
        assert!(!is_websocket_upgrade(&without_version));
    }

    #[test]
    fn host_header_ipv6_with_port() {
        assert_eq!(
            compute_host_header("ws://[::1]:9000/path"),
            ("[::1]:9000".to_string(), 9000)
        );
    }
}