use std::hash::{DefaultHasher, Hash, Hasher};
use futures::stream::{BoxStream, SplitSink};
use futures::{SinkExt as _, StreamExt as _, stream};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use super::{SubscriptionId, SubscriptionSource};
/// Commands a consumer can issue to an open WebSocket connection.
///
/// They are delivered through the `sender` carried by
/// [`WebSocketMessage::Connected`].
#[derive(Debug, Clone)]
pub enum WebSocketCommand {
/// Send the given string as a text message.
SendText(String),
/// Send the given bytes as a binary message.
SendBinary(Vec<u8>),
/// Send a close message carrying the optional close frame. The subscription
/// loop then reports [`WebSocketMessage::Disconnected`] and stops.
Close(Option<CloseFrame>),
}
/// Events produced by the [`WebSocket`] subscription stream.
#[derive(Debug, Clone)]
pub enum WebSocketMessage {
/// The connection was established.
Connected {
/// Channel through which [`WebSocketCommand`]s can be sent to the
/// connection.
sender: mpsc::UnboundedSender<WebSocketCommand>,
},
/// The connection ended: the peer sent a close message, the socket stream
/// finished, a [`WebSocketCommand::Close`] was processed, or every command
/// sender was dropped.
Disconnected,
/// A message other than a close message was read from the socket.
Received(Message),
/// Connecting, reading, or writing failed; `error` holds the error rendered
/// as text.
Error { error: String },
}
/// A WebSocket endpoint, described solely by its URL.
///
/// Two values compare equal (and hash identically) exactly when their URLs
/// are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WebSocket {
    /// Address of the endpoint, stored exactly as supplied.
    url: String,
}

impl WebSocket {
    /// Creates a `WebSocket` for `url`.
    ///
    /// The URL is only stored here; no validation or connection attempt
    /// takes place.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self { url }
    }
}
impl WebSocket {
    /// Writes a single outgoing command to the socket.
    ///
    /// A failed write is reported to the consumer as a
    /// [`WebSocketMessage::Error`]; it does not by itself stop the caller.
    async fn handle_command(
        cmd: WebSocketCommand,
        write: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
        msg_tx: &mpsc::UnboundedSender<WebSocketMessage>,
    ) {
        let result = match cmd {
            WebSocketCommand::SendText(text) => write.send(Message::Text(text.into())).await,
            WebSocketCommand::SendBinary(data) => write.send(Message::Binary(data.into())).await,
            WebSocketCommand::Close(frame) => write.send(Message::Close(frame)).await,
        };
        if let Err(e) = result {
            // The consumer may already be gone; there is nobody left to tell.
            let _ = msg_tx.send(WebSocketMessage::Error {
                error: e.to_string(),
            });
        }
    }

    /// Connects to `url` and shuttles traffic until the connection ends.
    ///
    /// * `msg_tx` – receives every [`WebSocketMessage`] produced by the
    ///   connection.
    /// * `cmd_rx` – source of [`WebSocketCommand`]s to write to the socket.
    /// * `cmd_tx` – sending half matching `cmd_rx`; handed to the consumer
    ///   inside [`WebSocketMessage::Connected`].
    ///
    /// The function returns when the connection attempt fails, the peer
    /// closes, a read error occurs, a `Close` command is processed, all
    /// command senders are dropped, or the receiver of `msg_tx` goes away.
    async fn run_subscription_loop(
        url: String,
        msg_tx: mpsc::UnboundedSender<WebSocketMessage>,
        mut cmd_rx: mpsc::UnboundedReceiver<WebSocketCommand>,
        cmd_tx: mpsc::UnboundedSender<WebSocketCommand>,
    ) {
        // Abandon the connection attempt if the consumer disappears while it
        // is still pending; nobody would observe the outcome.
        let connected = tokio::select! {
            result = connect_async(&url) => result,
            () = msg_tx.closed() => return,
        };
        let ws_stream = match connected {
            Ok((stream, _)) => stream,
            Err(e) => {
                let _ = msg_tx.send(WebSocketMessage::Error {
                    error: format!("Connection failed: {e}"),
                });
                return;
            }
        };

        if msg_tx
            .send(WebSocketMessage::Connected { sender: cmd_tx })
            .is_err()
        {
            return;
        }

        let (mut write, mut read) = ws_stream.split();
        loop {
            tokio::select! {
                msg = read.next() => {
                    match msg {
                        Some(Ok(Message::Close(_))) => {
                            let _ = msg_tx.send(WebSocketMessage::Disconnected);
                            break;
                        }
                        Some(Ok(message)) => {
                            if msg_tx.send(WebSocketMessage::Received(message)).is_err() {
                                break;
                            }
                        }
                        Some(Err(e)) => {
                            let _ = msg_tx.send(WebSocketMessage::Error {
                                error: e.to_string(),
                            });
                            break;
                        }
                        None => {
                            let _ = msg_tx.send(WebSocketMessage::Disconnected);
                            break;
                        }
                    }
                }
                cmd = cmd_rx.recv() => {
                    match cmd {
                        Some(WebSocketCommand::Close(frame)) => {
                            let _ = write.send(Message::Close(frame)).await;
                            let _ = msg_tx.send(WebSocketMessage::Disconnected);
                            break;
                        }
                        Some(cmd) => {
                            Self::handle_command(cmd, &mut write, &msg_tx).await;
                        }
                        None => {
                            let _ = msg_tx.send(WebSocketMessage::Disconnected);
                            break;
                        }
                    }
                }
                // The consumer dropped the message stream. Without this arm an
                // idle socket would stay open for as long as a clone of the
                // command sender is kept alive somewhere.
                () = msg_tx.closed() => break,
            }
        }

        // Best-effort shutdown of the write half; errors are irrelevant now.
        let _ = write.close().await;
    }
}
impl SubscriptionSource for WebSocket {
    type Output = WebSocketMessage;

    /// Starts the connection on a spawned Tokio task and returns the stream
    /// of events it produces.
    ///
    /// The stream yields every [`WebSocketMessage`] the task sends and ends
    /// once the task has dropped its sending half.
    fn stream(&self) -> BoxStream<'static, WebSocketMessage> {
        let url = self.url.clone();
        let (msg_tx, msg_rx) = mpsc::unbounded_channel();
        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();

        tokio::spawn(Self::run_subscription_loop(url, msg_tx, cmd_rx, cmd_tx));

        let events = stream::unfold(msg_rx, |mut rx| async move {
            let next = rx.recv().await;
            next.map(|msg| (msg, rx))
        });
        events.boxed()
    }

    /// Derives the subscription identity from the hash of this value, i.e.
    /// from its URL.
    fn id(&self) -> SubscriptionId {
        let mut state = DefaultHasher::new();
        Hash::hash(self, &mut state);
        let digest = state.finish();
        SubscriptionId::of::<Self>(digest)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// `new` stores the URL unchanged.
    #[test]
    fn test_ws_new() {
        let ws = WebSocket::new("wss://example.com");
        assert_eq!(ws.url, "wss://example.com");
    }

    /// Equal URLs yield equal subscription ids.
    #[test]
    fn test_ws_id_consistency() {
        let ws1 = WebSocket::new("wss://example.com");
        let ws2 = WebSocket::new("wss://example.com");
        assert_eq!(ws1.id(), ws2.id());
    }

    /// Different URLs yield different subscription ids.
    #[test]
    fn test_ws_id_different_urls() {
        let ws1 = WebSocket::new("wss://example.com");
        let ws2 = WebSocket::new("wss://different.com");
        assert_ne!(ws1.id(), ws2.id());
    }

    /// A failed connection attempt surfaces as an `Error` event.
    #[tokio::test]
    async fn test_stream_emits_error_on_connection_failure() {
        let ws = WebSocket::new("ws://localhost:1");
        let mut stream = ws.stream();
        assert!(matches!(
            stream.next().await,
            Some(WebSocketMessage::Error { .. }),
        ));
    }

    /// Every message variant can be constructed and matched.
    #[test]
    fn test_message_variants() {
        let (tx, _rx) = mpsc::unbounded_channel();
        // `matches!` only evaluates to a bool; it has to be asserted to check
        // anything.
        assert!(matches!(
            WebSocketMessage::Connected { sender: tx },
            WebSocketMessage::Connected { .. }
        ));
        assert!(matches!(
            WebSocketMessage::Disconnected,
            WebSocketMessage::Disconnected
        ));
        assert!(matches!(
            WebSocketMessage::Received(Message::Text("test".into())),
            WebSocketMessage::Received(_)
        ));
        assert!(matches!(
            WebSocketMessage::Error {
                error: "test".to_string()
            },
            WebSocketMessage::Error { .. }
        ));
    }
}