use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_stream::wrappers::ReceiverStream;
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async,
tungstenite::protocol::Message as WsMessage,
WebSocketStream,
MaybeTlsStream,
};
use url::Url;
use log::{debug, error, info, warn};
use crate::error::Error;
use crate::protocol::JSONRPCMessage;
use super::Transport;
/// JSON-RPC transport over a single WebSocket connection.
///
/// The channel slots are empty until `connect` runs and are cleared again by
/// `disconnect`; the `connected` flag is shared with the background task that
/// `connect` spawns.
pub struct WebSocketTransport {
    /// Target endpoint; validated by `new` and parsed again on every `connect`.
    url: String,
    /// Connection flag: set by `connect`, cleared by `disconnect` and by the
    /// background I/O tasks when the socket stops.
    connected: Arc<AtomicBool>,
    /// Queue of outgoing frames consumed by the background writer task;
    /// `None` while disconnected.
    sender: Arc<RwLock<Option<mpsc::Sender<WsMessage>>>>,
    /// Decoded inbound messages (or per-frame errors) produced by the
    /// background reader task; `None` before the first `connect` and after
    /// `disconnect`.
    receiver: Arc<RwLock<Option<mpsc::Receiver<Result<JSONRPCMessage, Error>>>>>,
}
impl WebSocketTransport {
pub fn new(url: &str) -> Result<Self, Error> {
let _ = Url::parse(url).map_err(Error::from)?;
Ok(Self {
url: url.to_string(),
connected: Arc::new(AtomicBool::new(false)),
sender: Arc::new(RwLock::new(None)),
receiver: Arc::new(RwLock::new(None)),
})
}
}
#[async_trait]
impl Transport for WebSocketTransport {
async fn connect(&self) -> Result<(), Error> {
if self.connected.load(Ordering::SeqCst) {
return Ok(());
}
let url = Url::parse(&self.url).map_err(Error::from)?;
let (ws_stream, _) = connect_async(url.as_str())
.await
.map_err(|e| Error::TransportError(format!("WebSocket connection failed: {}", e)))?;
info!("WebSocket connected");
let (outgoing_tx, outgoing_rx) = mpsc::channel::<WsMessage>(100);
let (incoming_tx, incoming_rx) = mpsc::channel::<Result<JSONRPCMessage, Error>>(100);
{
let mut sender = self.sender.write().await;
*sender = Some(outgoing_tx);
}
{
let mut receiver = self.receiver.write().await;
*receiver = Some(incoming_rx);
}
self.connected.store(true, Ordering::SeqCst);
let connected = self.connected.clone();
tokio::spawn(async move {
if let Err(e) = handle_websocket(ws_stream, outgoing_rx, incoming_tx, connected).await {
error!("WebSocket error: {}", e);
}
});
Ok(())
}
async fn disconnect(&self) -> Result<(), Error> {
if !self.connected.load(Ordering::SeqCst) {
return Ok(());
}
self.connected.store(false, Ordering::SeqCst);
let close_frame = WsMessage::Close(None);
if let Some(sender) = &*self.sender.read().await {
let _ = sender.send(close_frame).await;
}
{
let mut sender = self.sender.write().await;
*sender = None;
}
{
let mut receiver = self.receiver.write().await;
*receiver = None;
}
Ok(())
}
async fn send(&self, message: JSONRPCMessage) -> Result<(), Error> {
if !self.connected.load(Ordering::SeqCst) {
return Err(Error::ConnectionClosed("Not connected".to_string()));
}
let json = serde_json::to_string(&message)
.map_err(|e| Error::JsonError(e.to_string()))?;
let sender = self.sender.read().await;
let sender = sender.as_ref().ok_or_else(|| {
Error::ConnectionClosed("Connection not initialized".to_string())
})?;
sender
.send(WsMessage::Text(json.into()))
.await
.map_err(|_| Error::ConnectionClosed("Failed to send message".to_string()))?;
Ok(())
}
async fn receive(&self) -> Option<Result<JSONRPCMessage, Error>> {
if !self.connected.load(Ordering::SeqCst) {
return Some(Err(Error::ConnectionClosed("Not connected".to_string())));
}
let mut receiver = self.receiver.write().await;
let receiver = receiver.as_mut()?;
receiver.recv().await
}
async fn is_connected(&self) -> bool {
self.connected.load(Ordering::SeqCst)
}
}
async fn handle_websocket(
ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
outgoing_rx: mpsc::Receiver<WsMessage>,
incoming_tx: mpsc::Sender<Result<JSONRPCMessage, Error>>,
connected: Arc<AtomicBool>,
) -> Result<(), Error> {
let (ws_sender, ws_receiver) = ws_stream.split();
let outgoing_stream = ReceiverStream::new(outgoing_rx);
let outgoing_task = tokio::spawn(async move {
let mut outgoing_stream = outgoing_stream;
let mut ws_sender = ws_sender;
while let Some(msg) = outgoing_stream.next().await {
if let Err(e) = ws_sender.send(msg).await {
error!("WebSocket send error: {}", e);
break;
}
}
let _ = ws_sender.close().await;
});
let connected_clone = connected.clone();
let incoming_task = tokio::spawn(async move {
let mut ws_receiver = ws_receiver;
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
let json_msg: Result<JSONRPCMessage, _> = serde_json::from_str(&text);
match json_msg {
Ok(msg) => {
if let Err(_) = incoming_tx.send(Ok(msg)).await {
break;
}
}
Err(e) => {
warn!("Failed to parse JSON message: {}", e);
if let Err(_) = incoming_tx.send(Err(Error::ParseError(e.to_string()))).await {
break;
}
}
}
}
Ok(WsMessage::Binary(_)) => {
warn!("Received binary WebSocket message, ignoring");
}
Ok(WsMessage::Ping(_)) => {
}
Ok(WsMessage::Pong(_)) => {
}
Ok(WsMessage::Close(_)) => {
debug!("WebSocket connection closed by server");
break;
}
Ok(WsMessage::Frame(_)) => {
warn!("Received frame WebSocket message, ignoring");
}
Err(e) => {
error!("WebSocket receive error: {}", e);
if let Err(_) = incoming_tx
.send(Err(Error::TransportError(format!("WebSocket error: {}", e))))
.await
{
break;
}
break;
}
}
}
connected_clone.store(false, Ordering::SeqCst);
});
tokio::select! {
_ = outgoing_task => {
debug!("Outgoing WebSocket task completed");
}
_ = incoming_task => {
debug!("Incoming WebSocket task completed");
}
}
connected.store(false, Ordering::SeqCst);
Ok(())
}