use crate::error::{GatewayError, Result};
use futures_util::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
pub fn is_websocket_upgrade(headers: &http::HeaderMap) -> bool {
headers
.get("Upgrade")
.or_else(|| headers.get("upgrade"))
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
}
/// Builds the upstream WebSocket URL for a proxied request.
///
/// `backend_url` is the configured backend base, e.g. `http://host:8001`. The
/// scheme is mapped as follows:
/// - `http` becomes `ws` and `https` becomes `wss`;
/// - an existing `ws`/`wss` scheme is kept;
/// - a bare `host:port` gets `ws://` prepended.
///
/// Scheme matching is ASCII case-insensitive (RFC 3986 §3.1), so
/// `HTTPS://host` maps to `wss://host` rather than `ws://HTTPS://host`.
/// Trailing slashes on the backend are dropped. The request URI's path and
/// query are appended verbatim, or `/` when the URI carries none.
pub fn build_ws_url(backend_url: &str, uri: &http::Uri) -> String {
    // (accepted backend scheme prefix, WebSocket scheme it maps to)
    const SCHEMES: [(&str, &str); 4] = [
        ("https://", "wss://"),
        ("http://", "ws://"),
        ("wss://", "wss://"),
        ("ws://", "ws://"),
    ];

    let backend = backend_url.trim_end_matches('/');
    let path = uri.path_and_query().map_or("/", |pq| pq.as_str());
    let (scheme, rest) = SCHEMES
        .iter()
        .find_map(|&(from, to)| {
            strip_prefix_ignore_ascii_case(backend, from).map(|rest| (to, rest))
        })
        // No recognised scheme: treat the backend as a bare authority.
        .unwrap_or(("ws://", backend));
    format!("{}{}{}", scheme, rest, path)
}

/// Returns `s` with `prefix` removed if `s` starts with `prefix`, comparing
/// ASCII case-insensitively; otherwise returns `None`.
fn strip_prefix_ignore_ascii_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    // `get` yields `None` (rather than panicking) if the cut is not on a char boundary.
    let head = s.get(..prefix.len())?;
    if head.eq_ignore_ascii_case(prefix) {
        s.get(prefix.len()..)
    } else {
        None
    }
}
pub fn compute_accept_key(key: &str) -> String {
use base64::prelude::BASE64_STANDARD;
use base64::Engine as _;
use ring::digest::{Context, SHA1_FOR_LEGACY_USE_ONLY};
let mut ctx = Context::new(&SHA1_FOR_LEGACY_USE_ONLY);
ctx.update(key.as_bytes());
ctx.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
let digest = ctx.finish();
BASE64_STANDARD.encode(digest.as_ref())
}
/// Relays WebSocket messages in both directions between `client` and
/// `upstream` until either side sends a Close, ends its stream, or errors.
///
/// Every message read from one side is sent to the other side unchanged,
/// including Ping/Pong and Close. Close is forwarded as-is, rather than
/// replaced by an empty close, so the peer's status code and reason reach the
/// other end. Relaying stops right after a Close has been forwarded, or on the
/// first read/send failure. Both streams are then closed on a best-effort
/// basis, and errors are ignored because the connection is being torn down
/// anyway.
///
/// NOTE(review): read/send errors end the relay silently; consider logging
/// them if the gateway has a logger available.
pub async fn relay_websocket<C, U>(mut client: WebSocketStream<C>, mut upstream: WebSocketStream<U>)
where
    C: AsyncRead + AsyncWrite + Unpin,
    U: AsyncRead + AsyncWrite + Unpin,
{
    loop {
        tokio::select! {
            incoming = client.next() => {
                match incoming {
                    Some(Ok(msg)) => {
                        // Must be read before `send` takes ownership of `msg`.
                        let closing = msg.is_close();
                        if upstream.send(msg).await.is_err() || closing {
                            break;
                        }
                    }
                    // Stream ended or read error.
                    _ => break,
                }
            }
            incoming = upstream.next() => {
                match incoming {
                    Some(Ok(msg)) => {
                        let closing = msg.is_close();
                        if client.send(msg).await.is_err() || closing {
                            break;
                        }
                    }
                    _ => break,
                }
            }
        }
    }
    // Best-effort teardown of both sides; a side that is already closed may
    // report an error here, which is deliberately ignored.
    let _ = client.close(None).await;
    let _ = upstream.close(None).await;
}
/// Opens a client WebSocket connection to `url` using `connect_async`.
///
/// # Errors
///
/// Any connection or handshake failure is mapped to
/// [`GatewayError::ServiceUnavailable`], which carries the underlying error text.
pub async fn connect_upstream(url: &str) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
    match connect_async(url).await {
        // Callers only need the stream, not the handshake response.
        Ok((stream, _handshake_response)) => Ok(stream),
        Err(err) => Err(GatewayError::ServiceUnavailable(format!(
            "WebSocket upstream connection failed: {}",
            err
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::Message;

    /// Upper bound for a single network step, so a regression fails the test
    /// after a few seconds instead of hanging it forever.
    const STEP_TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    fn test_is_websocket_upgrade() {
        let mut headers = http::HeaderMap::new();
        assert!(!is_websocket_upgrade(&headers));
        headers.insert("Upgrade", "websocket".parse().unwrap());
        assert!(is_websocket_upgrade(&headers));
    }

    #[test]
    fn test_is_websocket_upgrade_case_insensitive() {
        let mut headers = http::HeaderMap::new();
        headers.insert("upgrade", "WebSocket".parse().unwrap());
        assert!(is_websocket_upgrade(&headers));
    }

    #[test]
    fn test_is_websocket_upgrade_not_websocket() {
        let mut headers = http::HeaderMap::new();
        headers.insert("Upgrade", "h2c".parse().unwrap());
        assert!(!is_websocket_upgrade(&headers));
    }

    #[test]
    fn test_build_ws_url_from_http() {
        let uri: http::Uri = "/ws/chat".parse().unwrap();
        assert_eq!(
            build_ws_url("http://127.0.0.1:8001", &uri),
            "ws://127.0.0.1:8001/ws/chat"
        );
    }

    #[test]
    fn test_build_ws_url_from_https() {
        let uri: http::Uri = "/ws".parse().unwrap();
        assert_eq!(
            build_ws_url("https://backend.example.com", &uri),
            "wss://backend.example.com/ws"
        );
    }

    #[test]
    fn test_build_ws_url_already_ws() {
        let uri: http::Uri = "/chat".parse().unwrap();
        assert_eq!(
            build_ws_url("ws://127.0.0.1:9000", &uri),
            "ws://127.0.0.1:9000/chat"
        );
    }

    #[test]
    fn test_build_ws_url_with_query() {
        let uri: http::Uri = "/ws?token=abc".parse().unwrap();
        assert_eq!(
            build_ws_url("http://127.0.0.1:8001", &uri),
            "ws://127.0.0.1:8001/ws?token=abc"
        );
    }

    #[test]
    fn test_build_ws_url_trailing_slash() {
        let uri: http::Uri = "/ws".parse().unwrap();
        assert_eq!(
            build_ws_url("http://127.0.0.1:8001/", &uri),
            "ws://127.0.0.1:8001/ws"
        );
    }

    #[test]
    fn test_build_ws_url_bare_host() {
        let uri: http::Uri = "/ws".parse().unwrap();
        assert_eq!(
            build_ws_url("127.0.0.1:8001", &uri),
            "ws://127.0.0.1:8001/ws"
        );
    }

    #[test]
    fn test_build_ws_url_with_path_only() {
        let uri: http::Uri = "/".parse().unwrap();
        assert_eq!(
            build_ws_url("http://127.0.0.1:8001", &uri),
            "ws://127.0.0.1:8001/"
        );
    }

    #[test]
    fn test_compute_accept_key() {
        // Sample handshake from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(compute_accept_key(key), expected);
    }

    #[tokio::test]
    async fn test_connect_upstream_invalid_url() {
        let result = connect_upstream("invalid-url").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connect_upstream_connection_refused() {
        let result = connect_upstream("ws://127.0.0.1:1/ws").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connect_upstream_success() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let ws_url = format!("ws://{}/ws", addr);
        let handle = tokio::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await {
                let _ = accept_async(stream).await;
            }
        });
        let result = connect_upstream(&ws_url).await;
        assert!(result.is_ok());
        handle.abort();
    }

    #[tokio::test]
    async fn test_connect_upstream_handshake_dropped() {
        // The listener is bound before connecting, so no sleep is needed: the
        // connection waits in the accept backlog.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            // Accept the TCP connection, then drop it before answering the handshake.
            if let Ok((stream, _)) = listener.accept().await {
                drop(stream);
            }
        });
        let result = connect_upstream(&format!("ws://{}/ws", addr)).await;
        assert!(result.is_err());
        server.abort();
    }

    #[tokio::test]
    async fn test_relay_websocket_round_trip_and_close() {
        // Upstream: echo data messages back until the stream ends.
        let upstream_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let upstream_url = format!("ws://{}/ws", upstream_listener.local_addr().unwrap());
        let upstream = tokio::spawn(async move {
            let (stream, _) = upstream_listener.accept().await.unwrap();
            let mut ws = accept_async(stream).await.unwrap();
            while let Some(Ok(msg)) = ws.next().await {
                if (msg.is_text() || msg.is_binary()) && ws.send(msg).await.is_err() {
                    break;
                }
            }
        });

        // Gateway: accept one client and relay it to the upstream.
        let gateway_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let gateway_url = format!("ws://{}/ws", gateway_listener.local_addr().unwrap());
        let gateway = tokio::spawn(async move {
            let (stream, _) = gateway_listener.accept().await.unwrap();
            let client_side = accept_async(stream).await.unwrap();
            let upstream_side = match connect_upstream(&upstream_url).await {
                Ok(ws) => ws,
                Err(_) => panic!("gateway could not reach the upstream"),
            };
            relay_websocket(client_side, upstream_side).await;
        });

        let (mut ws, _) = connect_async(&gateway_url).await.unwrap();
        ws.send(Message::text("hello")).await.unwrap();
        let reply = tokio::time::timeout(STEP_TIMEOUT, ws.next())
            .await
            .expect("timed out waiting for the echoed message")
            .expect("stream ended before the echo arrived")
            .expect("error while reading the echo");
        assert_eq!(reply.to_text().unwrap(), "hello");

        // Closing the client must terminate the relay.
        ws.close(None).await.unwrap();
        tokio::time::timeout(STEP_TIMEOUT, gateway)
            .await
            .expect("relay did not finish after the client closed")
            .expect("gateway task panicked");
        upstream.abort();
    }
}