use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tracing::{debug, error};
/// Returns `true` when the request asks to switch to the WebSocket protocol.
///
/// The `Upgrade` header carries a comma-separated list of protocols and may appear more
/// than once (RFC 9110 §7.8). Every `Upgrade` value is therefore split into tokens, and each
/// token (with surrounding whitespace trimmed) is compared case-insensitively against
/// `websocket`. Values that `HeaderValue::to_str` rejects (not visible ASCII) are skipped.
///
/// Only the `Upgrade` header is inspected; `Connection: upgrade` is not checked here.
pub(crate) fn is_websocket_upgrade(req: &Request<Incoming>) -> bool {
    req.headers()
        .get_all("upgrade")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("websocket"))
}
pub(crate) async fn handle_websocket_proxy(
req: Request<Incoming>,
backend_addr: &str,
) -> Response<http_body_util::Full<hyper::body::Bytes>> {
let mut backend = match TcpStream::connect(backend_addr).await {
Ok(stream) => stream,
Err(e) => {
error!("WebSocket backend connect failed ({backend_addr}): {e}");
return super::handler::error_response(
StatusCode::BAD_GATEWAY,
&format!("websocket backend error: {e}"),
);
}
};
let (parts, _body) = req.into_parts();
let path = parts
.uri
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/");
let mut raw_req = format!("{} {} HTTP/1.1\r\n", parts.method, path);
for (name, value) in &parts.headers {
if let Ok(val) = value.to_str() {
raw_req.push_str(&format!("{}: {val}\r\n", name.as_str()));
}
}
raw_req.push_str("\r\n");
if let Err(e) = backend.write_all(raw_req.as_bytes()).await {
error!("Failed to send WebSocket upgrade to backend: {e}");
return super::handler::error_response(
StatusCode::BAD_GATEWAY,
&format!("websocket write error: {e}"),
);
}
debug!("WebSocket upgrade forwarded to {backend_addr}");
let mut resp = Response::new(http_body_util::Full::new(hyper::body::Bytes::new()));
*resp.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
resp.headers_mut()
.insert("upgrade", "websocket".parse().expect("valid header value"));
resp.headers_mut()
.insert("connection", "Upgrade".parse().expect("valid header value"));
resp
}
#[cfg(test)]
mod tests {
    /// Case-insensitive match against `websocket`, the comparison applied to `Upgrade`
    /// header values.
    ///
    /// NOTE(review): these tests exercise this comparison rule only, not
    /// `is_websocket_upgrade` itself. Presumably that is because a `Request<Incoming>`
    /// cannot be constructed in a unit test; confirm and consider testing via a
    /// `HeaderMap`-based helper.
    fn is_websocket_token(value: &str) -> bool {
        "websocket".eq_ignore_ascii_case(value)
    }

    #[test]
    fn test_websocket_upgrade_detected() {
        for accepted in ["websocket", "WebSocket", "WEBSOCKET"] {
            assert!(is_websocket_token(accepted));
        }
        for rejected in ["http", ""] {
            assert!(!is_websocket_token(rejected));
        }
    }

    #[test]
    fn test_non_websocket_not_detected() {
        for rejected in ["h2c", "TLS/1.0"] {
            assert!(!is_websocket_token(rejected));
        }
    }
}