use crate::error::WebSocketError;
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async, tungstenite::Message};
use url::Url;
#[derive(Debug)]
pub struct WebSocketConnection {
url: Url,
stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
}
impl WebSocketConnection {
pub fn new(url: Url) -> Self {
Self { url, stream: None }
}
pub async fn connect(&mut self) -> Result<(), WebSocketError> {
match connect_async(self.url.as_str()).await {
Ok((stream, _response)) => {
self.stream = Some(stream);
Ok(())
}
Err(e) => Err(WebSocketError::ConnectionFailed(format!(
"Failed to connect: {}",
e
))),
}
}
pub async fn disconnect(&mut self) -> Result<(), WebSocketError> {
self.stream = None;
Ok(())
}
pub fn is_connected(&self) -> bool {
self.stream.is_some()
}
pub async fn send(&mut self, message: String) -> Result<(), WebSocketError> {
if let Some(stream) = &mut self.stream {
match stream.send(Message::Text(message.into())).await {
Ok(()) => Ok(()),
Err(e) => {
self.stream = None;
Err(WebSocketError::ConnectionFailed(format!(
"Failed to send message: {}",
e
)))
}
}
} else {
Err(WebSocketError::ConnectionClosed)
}
}
pub async fn receive(&mut self) -> Result<String, WebSocketError> {
if let Some(stream) = &mut self.stream {
loop {
match stream.next().await {
Some(Ok(Message::Text(text))) => return Ok(text.to_string()),
Some(Ok(
Message::Binary(_)
| Message::Ping(_)
| Message::Pong(_)
| Message::Frame(_),
)) => {
continue;
}
Some(Ok(Message::Close(_))) => {
self.stream = None;
return Err(WebSocketError::ConnectionClosed);
}
Some(Err(e)) => {
self.stream = None;
return Err(WebSocketError::ConnectionFailed(format!(
"Failed to receive message: {}",
e
)));
}
None => {
self.stream = None;
return Err(WebSocketError::ConnectionClosed);
}
}
}
} else {
Err(WebSocketError::ConnectionClosed)
}
}
pub fn url(&self) -> &Url {
&self.url
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    // These tests run `WebSocketConnection::receive` against an in-process
    // WebSocket server bound to a loopback port chosen by the OS.
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::Message;

    /// Starts a one-shot mock WebSocket server on `127.0.0.1` (port 0, i.e.
    /// OS-assigned).
    ///
    /// The spawned task does the following:
    /// - accepts a single TCP connection and completes the WebSocket handshake;
    /// - passes the write half to `send_frames`, which scripts the frames to send;
    /// - meanwhile drains the read half in a separate task.
    ///
    /// After `send_frames` returns, the task waits for that drain to finish,
    /// which happens when the client's stream ends or errors. If the accept
    /// or the handshake fails, the task returns early without panicking.
    ///
    /// Returns the bound address (for the client to connect to) and the
    /// server task's join handle.
    async fn spawn_mock_server<F, Fut>(send_frames: F) -> (SocketAddr, JoinHandle<()>)
    where
        F: FnOnce(
                futures_util::stream::SplitSink<WebSocketStream<tokio::net::TcpStream>, Message>,
            ) -> Fut
            + Send
            + 'static,
        Fut: std::future::Future<Output = ()> + Send,
    {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind localhost ephemeral port");
        let addr = listener
            .local_addr()
            .expect("read local addr of bound listener");
        let handle = tokio::spawn(async move {
            // Accept exactly one client.
            let (socket, _peer) = match listener.accept().await {
                Ok(pair) => pair,
                Err(_) => return,
            };
            let ws = match accept_async(socket).await {
                Ok(ws) => ws,
                Err(_) => return,
            };
            let (sink, mut stream) = ws.split();
            // Read and discard whatever the client sends, concurrently with
            // the scripted writes, until the stream ends or yields an error.
            let drain = tokio::spawn(async move {
                while let Some(msg) = stream.next().await {
                    if msg.is_err() {
                        break;
                    }
                }
            });
            send_frames(sink).await;
            // Keep the server task alive until the client side has gone away.
            let _ = drain.await;
        });
        (addr, handle)
    }

    /// Builds the `ws://<addr>/` URL for the mock server.
    fn ws_url(addr: SocketAddr) -> Url {
        Url::parse(&format!("ws://{}/", addr)).expect("valid ws url")
    }

    /// Creates a `WebSocketConnection` for `addr` and connects it, panicking
    /// if the connection cannot be established.
    async fn connect_client(addr: SocketAddr) -> WebSocketConnection {
        let mut client = WebSocketConnection::new(ws_url(addr));
        client
            .connect()
            .await
            .expect("client connects to mock server");
        client
    }

    /// The server sends 10,000 Ping frames and then one Text frame;
    /// `receive` must skip every ping and return the text.
    ///
    /// Dropping the client closes the connection. That ends the server's
    /// drain loop, so the server task can be joined.
    #[tokio::test]
    async fn test_receive_skips_ping_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|mut sink| async move {
            for _ in 0..10_000 {
                if sink.send(Message::Ping(Vec::new().into())).await.is_err() {
                    return;
                }
            }
            let _ = sink.send(Message::Text("payload".into())).await;
        })
        .await;
        let mut client = connect_client(addr).await;
        let received = client.receive().await.expect("receive returns the text");
        assert_eq!(received, "payload");
        drop(client);
        server.await.expect("server task did not panic");
    }

    /// The server sends 100 Binary frames and then one Text frame;
    /// `receive` must discard the binary payloads and return the text.
    #[tokio::test]
    async fn test_receive_skips_binary_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|mut sink| async move {
            for _ in 0..100 {
                if sink
                    .send(Message::Binary(vec![1, 2, 3].into()))
                    .await
                    .is_err()
                {
                    return;
                }
            }
            let _ = sink.send(Message::Text("payload".into())).await;
        })
        .await;
        let mut client = connect_client(addr).await;
        let received = client.receive().await.expect("receive returns the text");
        assert_eq!(received, "payload");
        drop(client);
        server.await.expect("server task did not panic");
    }

    /// The server sends 100 unsolicited Pong frames and then one Text frame;
    /// `receive` must skip the pongs and return the text.
    #[tokio::test]
    async fn test_receive_skips_pong_frames_then_returns_text() {
        let (addr, server) = spawn_mock_server(|mut sink| async move {
            for _ in 0..100 {
                if sink.send(Message::Pong(Vec::new().into())).await.is_err() {
                    return;
                }
            }
            let _ = sink.send(Message::Text("payload".into())).await;
        })
        .await;
        let mut client = connect_client(addr).await;
        let received = client.receive().await.expect("receive returns the text");
        assert_eq!(received, "payload");
        drop(client);
        server.await.expect("server task did not panic");
    }

    /// The server sends a Close frame and closes its sink. `receive` must
    /// report `ConnectionClosed`, and the client must then consider itself
    /// disconnected.
    #[tokio::test]
    async fn test_receive_returns_closed_on_close_frame() {
        let (addr, server) = spawn_mock_server(|mut sink| async move {
            let _ = sink.send(Message::Close(None)).await;
            let _ = sink.close().await;
        })
        .await;
        let mut client = connect_client(addr).await;
        let result = client.receive().await;
        assert!(
            matches!(result, Err(WebSocketError::ConnectionClosed)),
            "expected ConnectionClosed, got {:?}",
            result
        );
        assert!(
            !client.is_connected(),
            "stream should be cleared after close frame"
        );
        drop(client);
        server.await.expect("server task did not panic");
    }

    /// The server interleaves Ping, Binary and Pong frames (200 rounds of
    /// each) before one Text frame; `receive` must skip all of them and
    /// return the text.
    #[tokio::test]
    async fn test_receive_skips_mixed_non_text_frames() {
        let (addr, server) = spawn_mock_server(|mut sink| async move {
            for _ in 0..200 {
                if sink.send(Message::Ping(Vec::new().into())).await.is_err() {
                    return;
                }
                if sink
                    .send(Message::Binary(vec![9, 9, 9].into()))
                    .await
                    .is_err()
                {
                    return;
                }
                if sink.send(Message::Pong(Vec::new().into())).await.is_err() {
                    return;
                }
            }
            let _ = sink.send(Message::Text("payload".into())).await;
        })
        .await;
        let mut client = connect_client(addr).await;
        let received = client.receive().await.expect("receive returns the text");
        assert_eq!(received, "payload");
        drop(client);
        server.await.expect("server task did not panic");
    }
}