use super::{Transport, TransportReceiver, TransportSender};
use crate::{Error, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::protocol::Message as WsMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use url::Url;
/// A connected WebSocket transport that exchanges JSON values.
///
/// Built by [`WebSocketTransport::connect`]. Outgoing messages are written
/// through `sender`; the transport can also be split into independent send and
/// receive halves with [`WebSocketTransport::into_parts`].
pub struct WebSocketTransport {
/// Sending side of the unbounded channel created in `connect`; the matching
/// receiver is handed back to the caller of `connect`.
message_tx: mpsc::UnboundedSender<JsonValue>,
/// Write half of the split WebSocket stream.
sender: futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>,
/// Read half of the split WebSocket stream. Stored as an `Option` so that
/// `into_parts` can `take()` it; `connect` always initialises it to `Some`.
receiver: Option<futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
}
/// Receive half of a [`WebSocketTransport`], produced by `into_parts`.
///
/// Its `TransportReceiver::run` implementation reads frames from `receiver`,
/// parses text frames as JSON and forwards the resulting values on `message_tx`.
pub struct WebSocketTransportReceiver {
/// Read half of the split WebSocket stream.
receiver: futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
/// Channel on which parsed incoming JSON messages are published.
message_tx: mpsc::UnboundedSender<JsonValue>,
}
impl WebSocketTransport {
pub async fn connect(
url: &str,
headers: Option<HashMap<String, String>>,
) -> Result<(Self, mpsc::UnboundedReceiver<JsonValue>)> {
let (message_tx, message_rx) = mpsc::unbounded_channel();
let _parsed_url =
Url::parse(url).map_err(|e| Error::TransportError(format!("Invalid URL: {}", e)))?;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
let mut request = url
.into_client_request()
.map_err(|e| Error::TransportError(format!("Failed to build request: {}", e)))?;
if let Some(headers_map) = headers {
use std::str::FromStr;
use tokio_tungstenite::tungstenite::http::header::{HeaderName, HeaderValue};
let headers = request.headers_mut();
for (k, v) in headers_map {
let name = HeaderName::from_str(&k)
.map_err(|e| Error::TransportError(format!("Invalid header name: {}", e)))?;
let value = HeaderValue::from_str(&v)
.map_err(|e| Error::TransportError(format!("Invalid header value: {}", e)))?;
headers.insert(name, value);
}
}
let (ws_stream, _) = tokio_tungstenite::connect_async(request)
.await
.map_err(|e| Error::TransportError(format!("WebSocket connection failed: {}", e)))?;
let (sender, receiver) = ws_stream.split();
Ok((
Self {
message_tx,
sender,
receiver: Some(receiver),
},
message_rx,
))
}
pub fn into_parts(mut self) -> (WebSocketTransportSender, WebSocketTransportReceiver) {
let receiver = self.receiver.take().expect("Receiver already taken");
let sender = WebSocketTransportSender {
sender: self.sender,
};
let receiver = WebSocketTransportReceiver {
receiver,
message_tx: self.message_tx,
};
(sender, receiver)
}
}
/// Send half of a [`WebSocketTransport`], produced by `into_parts`.
///
/// Its `TransportSender::send` implementation serializes JSON values and
/// writes them to the socket as text frames.
pub struct WebSocketTransportSender {
/// Write half of the split WebSocket stream.
sender: futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>,
}
impl TransportSender for WebSocketTransportSender {
fn send(
&mut self,
message: JsonValue,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
Box::pin(async move {
let json_str = serde_json::to_string(&message)
.map_err(|e| Error::TransportError(format!("Failed to serialize JSON: {}", e)))?;
self.sender
.send(WsMessage::Text(json_str.into()))
.await
.map_err(|e| {
Error::TransportError(format!("Failed to send WebSocket message: {}", e))
})
})
}
}
impl TransportReceiver for WebSocketTransportReceiver {
fn run(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
Box::pin(async move {
while let Some(msg_result) = self.receiver.next().await {
match msg_result {
Ok(msg) => {
match msg {
WsMessage::Text(text) => {
let message: JsonValue =
serde_json::from_str(&text).map_err(|e| {
Error::ProtocolError(format!("Failed to parse JSON: {}", e))
})?;
if self.message_tx.send(message).is_err() {
break;
}
}
WsMessage::Binary(_) => {
}
WsMessage::Close(_) => break,
_ => {}
}
}
Err(e) => {
return Err(Error::TransportError(format!(
"WebSocket read error: {}",
e
)));
}
}
}
Ok(())
})
}
}
impl Transport for WebSocketTransport {
async fn send(&mut self, message: JsonValue) -> Result<()> {
let json_str = serde_json::to_string(&message)
.map_err(|e| Error::TransportError(format!("Failed to serialize JSON: {}", e)))?;
self.sender
.send(WsMessage::Text(json_str.into()))
.await
.map_err(|e| Error::TransportError(format!("Failed to send WebSocket message: {}", e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a `ws://` URL with an explicit port parses and reports that
    /// port. This exercises `Url::parse` only; it does not open a connection.
    #[test]
    fn test_url_parsing_in_connect() {
        let parsed = Url::parse("ws://127.0.0.1:9000").unwrap();
        let port = parsed.port();
        assert_eq!(port, Some(9000));
    }
}