use std::sync::{Arc, Mutex};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::oneshot;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::error::Error as WsClientError;
use tokio_tungstenite::tungstenite::Message;
use tork::{api_model, websocket, App, BearerToken, Router, WebSocket, WebSocketConfig, WsMessage};
// Echo endpoint: accepts the upgrade, then mirrors every text and binary
// message back to the client. A close message or the end of the stream ends
// the session; any other message kind is ignored.
#[websocket("/ws")]
async fn echo(socket: WebSocket) -> tork::Result<()> {
    let mut conn = socket.accept().await?;
    loop {
        match conn.recv().await? {
            Some(WsMessage::Text(body)) => conn.send_text(body).await?,
            Some(WsMessage::Binary(body)) => conn.send_binary(body).await?,
            // Both an explicit close and end-of-stream terminate the loop.
            Some(WsMessage::Close(_)) | None => break,
            Some(_) => {}
        }
    }
    Ok(())
}
// Route that declares a `BearerToken` dependency; the token value itself is
// never read. Once the handshake completes, the connection is dropped.
#[websocket("/guarded")]
async fn guarded(socket: WebSocket, _token: BearerToken) -> tork::Result<()> {
    drop(socket.accept().await?);
    Ok(())
}
/// Registering the `/ws` echo route on a router must produce an app whose
/// `build()` succeeds.
#[test]
fn websocket_route_builds() {
    let app = App::new();
    let router = Router::new().route(__tork_route_echo());
    let built = app.include_router(router).build();
    assert!(built.is_ok(), "the websocket route should register");
}
// Message model used as the `incoming` schema of the `/typed` websocket route.
#[api_model]
struct WsIn {
    // String payload of the model.
    text: String,
}
// Message model used as the `outgoing` schema of the `/typed` websocket route.
#[api_model]
struct WsOut {
    // String payload of the model.
    text: String,
}
// Route that declares `WsIn` as its incoming and `WsOut` as its outgoing
// message schema, so the route metadata can be inspected. At runtime it only
// completes the handshake and then drops the connection.
#[websocket("/typed", incoming = WsIn, outgoing = WsOut)]
async fn typed(socket: WebSocket) -> tork::Result<()> {
    drop(socket.accept().await?);
    Ok(())
}
/// The `/typed` route must be flagged as a websocket route and carry both the
/// incoming and the outgoing message schema in its metadata.
#[test]
fn websocket_records_asyncapi_metadata() {
    let route = __tork_route_typed();
    let meta = route.meta();
    assert!(meta.websocket, "should be marked as a websocket route");
    assert!(meta.ws_incoming.is_some(), "incoming schema recorded");
    assert!(meta.ws_outgoing.is_some(), "outgoing schema recorded");
}
/// Spawns `app` on an ephemeral loopback port and waits until it is ready.
///
/// The server runs via `serve_with_shutdown` on `127.0.0.1:0`. An `on_ready`
/// hook sends the bound address (`ctx.addr()`) back through a oneshot channel.
///
/// Returns that address together with a shutdown trigger. Either sending on
/// the trigger or dropping it resolves the shutdown future passed to the server.
///
/// # Panics
///
/// Panics if no address is ever delivered because the server task ended
/// first, for example when startup failed.
async fn start_with_app(app: App) -> (std::net::SocketAddr, oneshot::Sender<()>) {
    let (addr_tx, addr_rx) = oneshot::channel();
    let (shutdown_tx, shutdown_rx) = oneshot::channel();
    // The hook clones this handle on every invocation, so it does not have to be
    // `FnOnce`; `take()` ensures the single-use sender fires at most once.
    let sender = Arc::new(Mutex::new(Some(addr_tx)));
    let app = app.on_ready(move |ctx| {
        let sender = sender.clone();
        async move {
            if let Some(tx) = sender.lock().unwrap().take() {
                // The receiver may already have given up; nothing to do then.
                let _ = tx.send(ctx.addr());
            }
            Ok(())
        }
    });
    // Detached on purpose: the server runs until `shutdown_tx` is used or
    // dropped, or until the test runtime is torn down.
    tokio::spawn(app.serve_with_shutdown("127.0.0.1:0", async move {
        // A send and a dropped sender both end the wait.
        let _ = shutdown_rx.await;
    }));
    // If the server task finishes before the hook runs, the sender is dropped
    // and this receive fails; say so instead of panicking on a bare RecvError.
    let addr = addr_rx
        .await
        .expect("server exited before its on_ready hook reported a bound address");
    (addr, shutdown_tx)
}
/// Starts a server exposing the `/ws` echo route and the token-guarded
/// `/guarded` route, without calling `websocket_config`.
///
/// Returns the bound address and the shutdown trigger produced by
/// `start_with_app`.
async fn start() -> (std::net::SocketAddr, oneshot::Sender<()>) {
    let app = App::new();
    let router = Router::new()
        .route(__tork_route_echo())
        .route(__tork_route_guarded());
    start_with_app(app.include_router(router)).await
}
/// Round-trips one text and one binary message through the `/ws` echo route
/// over a real client connection.
#[tokio::test]
async fn echoes_text_and_binary_over_a_real_connection() {
    let (addr, shutdown) = start().await;
    let (mut socket, _response) = connect_async(format!("ws://{addr}/ws")).await.unwrap();
    let frames = [Message::Text("hello".into()), Message::Binary(vec![1, 2, 3])];
    for frame in frames {
        socket.send(frame.clone()).await.unwrap();
        let echoed = socket.next().await.unwrap().unwrap();
        assert_eq!(echoed, frame);
    }
    socket.close(None).await.unwrap();
    let _ = shutdown.send(());
}
/// Connecting to `/guarded` without supplying a bearer token must fail the
/// upgrade with an HTTP `401` response.
#[tokio::test]
async fn upgrade_is_rejected_when_a_dependency_fails() {
    let (addr, shutdown) = start().await;
    let response = match connect_async(format!("ws://{addr}/guarded")).await {
        Err(WsClientError::Http(response)) => response,
        other => panic!("expected an HTTP rejection, got {other:?}"),
    };
    assert_eq!(response.status(), 401, "expected an unauthorized rejection");
    let _ = shutdown.send(());
}
/// A browser-style handshake from a foreign `Origin` must be refused with an
/// HTTP `403` when no origin allow-list has been configured.
#[tokio::test]
async fn websocket_rejects_cross_origin_browser_handshakes_by_default() {
    let (addr, shutdown) = start().await;
    // Hand-built upgrade request so the `Origin` header can be set explicitly.
    // The key is the RFC 6455 sample nonce (base64 of "the sample nonce").
    let request = http::Request::builder()
        .method("GET")
        .uri(format!("ws://{addr}/ws"))
        .header("Host", addr.to_string())
        .header("Origin", "https://evil.example.com")
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
        .body(())
        .unwrap();
    let result = connect_async(request).await;
    match result {
        Err(WsClientError::Http(response)) => {
            assert_eq!(
                response.status(),
                403,
                "expected a forbidden rejection for a cross-origin handshake"
            );
        }
        other => panic!("expected an HTTP rejection, got {other:?}"),
    }
    let _ = shutdown.send(());
}
/// A browser-style handshake whose `Origin` is the server's own
/// `http://{addr}` must be accepted without any allow-list configuration.
#[tokio::test]
async fn websocket_accepts_same_origin_browser_handshakes() {
    let (addr, shutdown) = start().await;
    // Header table in the exact order the handshake sends them.
    let handshake_headers = [
        ("Host", addr.to_string()),
        ("Origin", format!("http://{addr}")),
        ("Connection", String::from("Upgrade")),
        ("Upgrade", String::from("websocket")),
        ("Sec-WebSocket-Version", String::from("13")),
        ("Sec-WebSocket-Key", String::from("dGhlIHNhbXBsZSBub25jZQ==")),
    ];
    let mut builder = http::Request::builder()
        .method("GET")
        .uri(format!("ws://{addr}/ws"));
    for (name, value) in handshake_headers {
        builder = builder.header(name, value);
    }
    let request = builder.body(()).unwrap();
    let (mut socket, _response) = connect_async(request).await.unwrap();
    socket.close(None).await.unwrap();
    let _ = shutdown.send(());
}
/// An origin explicitly listed via `WebSocketConfig::allow_origin` must be
/// accepted even though it differs from the server's own host.
#[tokio::test]
async fn websocket_origin_allowlist_can_opt_in_to_cross_origin_clients() {
    // Foreign origin that this test opts in to.
    const ALLOWED_ORIGIN: &str = "https://evil.example.com";

    let app = App::new().websocket_config(WebSocketConfig::new().allow_origin(ALLOWED_ORIGIN));
    let router = Router::new()
        .route(__tork_route_echo())
        .route(__tork_route_guarded());
    let (addr, shutdown) = start_with_app(app.include_router(router)).await;

    let uri = format!("ws://{addr}/ws");
    let host = addr.to_string();
    let request = http::Request::builder()
        .method("GET")
        .uri(uri)
        .header("Host", host)
        .header("Origin", ALLOWED_ORIGIN)
        .header("Connection", "Upgrade")
        .header("Upgrade", "websocket")
        .header("Sec-WebSocket-Version", "13")
        .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
        .body(())
        .unwrap();
    let (mut socket, _response) = connect_async(request).await.unwrap();
    socket.close(None).await.unwrap();
    let _ = shutdown.send(());
}