use futures_util::stream::{ SplitSink, SplitStream };
use futures_util::{ SinkExt, StreamExt };
use log::{ debug, trace };
use core::panic;
use std::error::Error;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::{oneshot, Mutex};
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{ connect_async, tungstenite::Message, WebSocketStream };
/// Write (sink) half of a client WebSocket stream over a `MaybeTlsStream<TcpStream>`,
/// as produced by splitting the stream returned from `connect_async`.
pub type WsWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
/// Read (stream) half of the same split client WebSocket stream.
pub type WsRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// WebSocket client that keeps the write and read halves of one connection behind
/// separate async mutexes, so sending and receiving can proceed independently.
///
/// Both halves are `None` until `connect()` succeeds and are cleared again by
/// `disconnect()`. Cloning the client clones the `Arc`s, so clones share the same
/// underlying connection rather than opening a new one.
#[derive(Clone, Debug)]
pub struct InverseWebSocketClient {
    /// URL handed to `connect_async`.
    ws_url: String,
    /// Sending half of the connection, if connected.
    write: Option<Arc<Mutex<WsWrite>>>,
    /// Receiving half of the connection, if connected.
    read: Option<Arc<Mutex<WsRead>>>,
}
impl InverseWebSocketClient {
pub fn new(ws_url: &str) -> Self {
Self {
ws_url: ws_url.to_string(),
write: None,
read: None,
}
}
pub async fn connect(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
let (ws_stream, _) = connect_async(&self.ws_url).await?;
let (write, read) = ws_stream.split();
self.write = Some(Arc::new(Mutex::new(write)));
self.read = Some(Arc::new(Mutex::new(read)));
debug!("Connected to WebSocket at {}", self.ws_url);
Ok(())
}
pub async fn disconnect(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
let writer_arc_opt = self.write.take();
let reader_arc_opt = self.read.take();
if let (Some(writer_arc), Some(reader_arc)) = (writer_arc_opt, reader_arc_opt) {
if
let (Ok(writer_mutex), Ok(reader_mutex)) = (
Arc::try_unwrap(writer_arc),
Arc::try_unwrap(reader_arc),
)
{
let writer = writer_mutex.into_inner();
let reader = reader_mutex.into_inner();
match writer.reunite(reader) {
Ok(mut ws_stream) => {
ws_stream.close(None).await?;
debug!("WebSocket connection to {} closed (reunited)", self.ws_url);
return Ok(());
}
Err(e) => {
debug!("Failed to reunite WebSocket halves, falling back to best-effort close: {}", e);
}
}
}
}
Ok(())
}
pub async fn listen<F, Fut>(&mut self, callback: F, mut stop_rx: oneshot::Receiver<()>)
where
F: Fn(String) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = ()> + Send
{
if let Some(reader_arc) = &self.read.take() {
let mut reader = reader_arc.lock().await;
loop {
tokio::select! {
biased;
_ = &mut stop_rx => { break; },
msg = reader.next() => {
if let Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) = msg {
let message = {
#[cfg(feature = "tracy")]
let _decode_span = tracy_client::span!("websocket::listen_message_decode");
text.to_string()
};
callback(message).await;
}
}
}
}
} else {
panic!("WebSocket reader not available; ensure connect() was called successfully.");
}
}
pub async fn send_message(&self, message: &str) -> Result<(), Box<dyn Error + Send + Sync>> {
if let Some(writer_arc) = &self.write {
let mut writer = {
#[cfg(feature = "tracy")]
let _lock_span = tracy_client::span!("websocket::send_message_lock_writer");
writer_arc.lock().await
};
{
#[cfg(feature = "tracy")]
let _send_span = tracy_client::span!("websocket::send_message_writer_send");
writer.send(Message::Text(message.to_string().into())).await?;
}
trace!("Sent message: {}", message);
Ok(())
} else {
Err("WebSocket writer not available".into())
}
}
}