use super::error::{DXLinkError, DXLinkResult};
use futures_util::{SinkExt, StreamExt};
use serde::Serialize;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::timeout;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{WebSocketStream, connect_async};
use tracing::{debug, error};
/// Concrete stream type produced by [`connect_async`].
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Outgoing (write) half of a split [`WsStream`].
type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
/// Incoming (read) half of a split [`WsStream`].
type WsSource = futures_util::stream::SplitStream<WsStream>;

/// A WebSocket connection whose read and write halves are guarded by
/// separate async mutexes, so one task can send while another is waiting
/// to receive.
#[derive(Debug)]
pub struct WebSocketConnection {
    /// Shared, lock-protected write half.
    write: Arc<Mutex<WsSink>>,
    /// Shared, lock-protected read half.
    read: Arc<Mutex<WsSource>>,
}
impl WebSocketConnection {
pub async fn connect(url: &str) -> DXLinkResult<Self> {
debug!("Connecting to WebSocket at: {}", url);
let (ws_stream, _) = connect_async(url)
.await
.map_err(|e| DXLinkError::Connection(format!("Failed to connect: {}", e)))?;
debug!("WebSocket connection established");
let (write, read) = ws_stream.split();
Ok(Self {
write: Arc::new(Mutex::new(write)),
read: Arc::new(Mutex::new(read)),
})
}
pub async fn send<T: Serialize>(&self, message: &T) -> DXLinkResult<()> {
let json = serde_json::to_string(message)?;
debug!("Sending message: {}", json);
let mut write = self.write.lock().await;
write.send(Message::Text(json.into())).await?;
Ok(())
}
pub async fn receive(&self) -> DXLinkResult<String> {
let mut read = self.read.lock().await;
match read.next().await {
Some(Ok(Message::Text(text))) => {
debug!("Received message: {}", text);
Ok(text.to_string())
}
Some(Ok(message)) => {
debug!("Received non-text message: {:?}", message);
Err(DXLinkError::UnexpectedMessage(
"Expected text message".to_string(),
))
}
Some(Err(e)) => {
error!("WebSocket error: {}", e);
Err(DXLinkError::WebSocket(Box::new(e)))
}
None => {
error!("WebSocket connection closed unexpectedly");
Err(DXLinkError::Connection(
"Connection closed unexpectedly".to_string(),
))
}
}
}
pub async fn receive_with_timeout(&self, duration: Duration) -> DXLinkResult<Option<String>> {
let read_future = self.receive();
match timeout(duration, read_future).await {
Ok(result) => result.map(Some),
Err(_) => Ok(None), }
}
pub fn create_keepalive_sender(&self) -> KeepAliveSender {
KeepAliveSender {
connection: self.clone(),
}
}
}
impl Clone for WebSocketConnection {
    /// Produces another handle to the same underlying connection.
    ///
    /// Only the `Arc` reference counts are bumped. Both handles share the
    /// same locked read and write halves.
    fn clone(&self) -> Self {
        let Self { write, read } = self;
        Self {
            write: Arc::clone(write),
            read: Arc::clone(read),
        }
    }
}
/// Cheap, cloneable handle for sending keepalive messages over a shared
/// [`WebSocketConnection`].
#[derive(Clone, Debug)]
pub struct KeepAliveSender {
    /// Connection handle sharing the read/write halves of the original.
    connection: WebSocketConnection,
}
impl KeepAliveSender {
    /// Sends a `KEEPALIVE` message on `channel` through the shared connection.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`WebSocketConnection::send`].
    pub async fn send_keepalive(&self, channel: u32) -> DXLinkResult<()> {
        use crate::messages::KeepaliveMessage;

        self.connection
            .send(&KeepaliveMessage {
                channel,
                message_type: String::from("KEEPALIVE"),
            })
            .await
    }
}
#[cfg(test)]
mod tests {
    // These tests open real sockets on 127.0.0.1 and are `#[ignore]`d by
    // default; run them with `cargo test -- --ignored`.
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::net::SocketAddr;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use warp::Filter;
    use warp::ws::{Message as WarpMessage, WebSocket as WarpWebSocket};

    /// Starts a local warp WebSocket test server on an OS-assigned port.
    ///
    /// Returns the bound address, a receiver yielding every text frame the
    /// client sends, and a sender whose strings are pushed to the client as
    /// text frames.
    async fn setup_test_server() -> (SocketAddr, mpsc::Receiver<String>, mpsc::Sender<String>) {
        let (client_tx, client_rx) = mpsc::channel::<String>(10);
        let (server_tx, server_rx) = mpsc::channel::<String>(10);
        let client_tx = Arc::new(Mutex::new(client_tx));
        let server_rx = Arc::new(Mutex::new(server_rx));
        let websocket = warp::path("websocket")
            .and(warp::ws())
            .map(move |ws: warp::ws::Ws| {
                let client_tx = client_tx.clone();
                let server_rx = server_rx.clone();
                ws.on_upgrade(move |websocket| handle_websocket(websocket, client_tx, server_rx))
            });
        // Port 0 lets the OS pick a free port, so tests cannot collide on a
        // fixed port. `bind_ephemeral` binds before returning. With `run`, the
        // bind only happened once the spawned task was first polled, which
        // could be after the client had already tried to connect.
        let (addr, server) = warp::serve(websocket).bind_ephemeral(([127, 0, 0, 1], 0));
        tokio::spawn(server);
        (addr, client_rx, server_tx)
    }

    /// Bridges one upgraded WebSocket to the test channels.
    ///
    /// Strings from `server_rx` are forwarded to the client. Text frames from
    /// the client are pushed into `client_tx`. Forwarding stops at the first
    /// non-text frame or error.
    async fn handle_websocket(
        websocket: WarpWebSocket,
        client_tx: Arc<Mutex<mpsc::Sender<String>>>,
        server_rx: Arc<Mutex<mpsc::Receiver<String>>>,
    ) {
        let (mut ws_tx, mut ws_rx) = websocket.split();
        let server_to_client = tokio::spawn(async move {
            let mut rx = server_rx.lock().await;
            while let Some(msg) = rx.recv().await {
                ws_tx
                    .send(WarpMessage::text(msg))
                    .await
                    .expect("Failed to send message");
            }
        });
        let client_to_server = tokio::spawn(async move {
            let tx = client_tx.lock().await;
            while let Some(result) = ws_rx.next().await {
                match result {
                    Ok(msg) if msg.is_text() => {
                        if let Ok(text) = msg.to_str() {
                            tx.send(text.to_string())
                                .await
                                .expect("Failed to send to channel");
                        }
                    }
                    _ => break,
                }
            }
        });
        let _ = tokio::join!(server_to_client, client_to_server);
    }

    #[tokio::test]
    #[ignore]
    async fn test_websocket_connection() {
        let (addr, mut client_rx, server_tx) = setup_test_server().await;
        let ws_url = format!("ws://{}/websocket", addr);
        let connection = WebSocketConnection::connect(&ws_url)
            .await
            .expect("Failed to connect");

        #[derive(Serialize)]
        struct TestMessage {
            channel: u32,
            #[serde(rename = "type")]
            message_type: String,
            data: String,
        }

        let test_msg = TestMessage {
            channel: 1,
            message_type: "TEST".to_string(),
            data: "Hello, World!".to_string(),
        };
        connection
            .send(&test_msg)
            .await
            .expect("Failed to send message");

        let received = client_rx.recv().await.expect("No message received");
        let parsed: serde_json::Value = serde_json::from_str(&received).unwrap();
        assert_eq!(parsed["channel"], 1);
        assert_eq!(parsed["type"], "TEST");
        assert_eq!(parsed["data"], "Hello, World!");

        server_tx
            .send("test_response".to_string())
            .await
            .expect("Failed to send from server");
        let received = connection
            .receive()
            .await
            .expect("Failed to receive message");
        assert_eq!(received, "test_response");
    }
}
#[cfg(test)]
mod additional_tests {
    // These tests open real sockets on 127.0.0.1 and are `#[ignore]`d by
    // default; run them with `cargo test -- --ignored`.
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::net::SocketAddr;
    use std::sync::Arc;
    use tokio::sync::mpsc;
    use tokio::time::sleep;
    use warp::Filter;
    use warp::ws::{Message as WarpMessage, WebSocket as WarpWebSocket};

    /// Starts a local warp WebSocket test server on an OS-assigned port.
    ///
    /// Returns, in order:
    /// - the bound address;
    /// - a receiver yielding every text frame the client sends;
    /// - a sender whose strings are pushed to the client as text frames;
    /// - a trigger that makes the server send the binary frame `[1, 2, 3]`.
    async fn setup_test_server() -> (
        SocketAddr,
        mpsc::Receiver<String>,
        mpsc::Sender<String>,
        mpsc::Sender<bool>,
    ) {
        let (client_tx, client_rx) = mpsc::channel::<String>(10);
        let (server_tx, server_rx) = mpsc::channel::<String>(10);
        let (binary_tx, binary_rx) = mpsc::channel::<bool>(10);
        let client_tx = Arc::new(tokio::sync::Mutex::new(client_tx));
        let server_rx = Arc::new(tokio::sync::Mutex::new(server_rx));
        let binary_rx = Arc::new(tokio::sync::Mutex::new(binary_rx));
        let websocket = warp::path("websocket")
            .and(warp::ws())
            .map(move |ws: warp::ws::Ws| {
                let client_tx = client_tx.clone();
                let server_rx = server_rx.clone();
                let binary_rx = binary_rx.clone();
                ws.on_upgrade(move |websocket| {
                    handle_websocket(websocket, client_tx, server_rx, binary_rx)
                })
            });
        // Every test gets its own free port. The old fixed port 3031 made
        // tests running in parallel collide and talk to each other's servers.
        // `bind_ephemeral` also binds before returning, so the client cannot
        // connect before the listener exists.
        let (addr, server) = warp::serve(websocket).bind_ephemeral(([127, 0, 0, 1], 0));
        tokio::spawn(server);
        (addr, client_rx, server_tx, binary_tx)
    }

    /// Bridges one upgraded WebSocket to the test channels.
    ///
    /// Text from `server_rx` and binary triggers from `binary_rx` are written to
    /// the client through a shared, locked sink. Text frames from the client are
    /// pushed into `client_tx` until the first non-text frame or error.
    async fn handle_websocket(
        websocket: WarpWebSocket,
        client_tx: Arc<tokio::sync::Mutex<mpsc::Sender<String>>>,
        server_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<String>>>,
        binary_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<bool>>>,
    ) {
        let (ws_tx, mut ws_rx) = websocket.split();
        let ws_tx = Arc::new(tokio::sync::Mutex::new(ws_tx));
        let ws_tx_clone = ws_tx.clone();
        let server_to_client = tokio::spawn(async move {
            let mut rx = server_rx.lock().await;
            while let Some(msg) = rx.recv().await {
                let mut tx = ws_tx_clone.lock().await;
                tx.send(WarpMessage::text(msg))
                    .await
                    .expect("Failed to send message");
            }
        });
        let binary_to_client = tokio::spawn(async move {
            let mut rx = binary_rx.lock().await;
            while rx.recv().await.is_some() {
                let mut tx = ws_tx.lock().await;
                tx.send(WarpMessage::binary(vec![1, 2, 3]))
                    .await
                    .expect("Failed to send binary message");
            }
        });
        let client_to_server = tokio::spawn(async move {
            let tx = client_tx.lock().await;
            while let Some(result) = ws_rx.next().await {
                match result {
                    Ok(msg) if msg.is_text() => {
                        if let Ok(text) = msg.to_str() {
                            tx.send(text.to_string())
                                .await
                                .expect("Failed to send to channel");
                        }
                    }
                    _ => break,
                }
            }
        });
        let _ = tokio::join!(server_to_client, binary_to_client, client_to_server);
    }

    #[tokio::test]
    #[ignore]
    async fn test_receive_with_timeout_success() {
        let (addr, _client_rx, server_tx, _binary_tx) = setup_test_server().await;
        let ws_url = format!("ws://{}/websocket", addr);
        let connection = WebSocketConnection::connect(&ws_url)
            .await
            .expect("Failed to connect");
        let server_tx_clone = server_tx.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            server_tx_clone
                .send("test_response".to_string())
                .await
                .expect("Failed to send from server");
        });
        let result = connection
            .receive_with_timeout(Duration::from_millis(500))
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some("test_response".to_string()));
    }

    #[tokio::test]
    #[ignore]
    async fn test_receive_with_timeout_timeout() {
        let (addr, _client_rx, _server_tx, _binary_tx) = setup_test_server().await;
        let ws_url = format!("ws://{}/websocket", addr);
        let connection = WebSocketConnection::connect(&ws_url)
            .await
            .expect("Failed to connect");
        let result = connection
            .receive_with_timeout(Duration::from_millis(100))
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), None);
    }

    #[tokio::test]
    #[ignore]
    async fn test_receive_non_text_message() {
        let (addr, _client_rx, _server_tx, binary_tx) = setup_test_server().await;
        let ws_url = format!("ws://{}/websocket", addr);
        let connection = WebSocketConnection::connect(&ws_url)
            .await
            .expect("Failed to connect");
        binary_tx
            .send(true)
            .await
            .expect("Failed to trigger binary message");
        let result = connection.receive().await;
        match result {
            Err(DXLinkError::UnexpectedMessage(msg)) => {
                assert!(msg.contains("Expected text message"));
            }
            _ => panic!("Expected UnexpectedMessage error, got: {:?}", result),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_clone() {
        let (addr, _client_rx, server_tx, _binary_tx) = setup_test_server().await;
        let ws_url = format!("ws://{}/websocket", addr);
        let connection = WebSocketConnection::connect(&ws_url)
            .await
            .expect("Failed to connect");
        let connection_clone = connection.clone();

        server_tx
            .send("test_message".to_string())
            .await
            .expect("Failed to send from server");
        let result = connection.receive().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test_message");

        server_tx
            .send("clone_message".to_string())
            .await
            .expect("Failed to send from server");
        let result = connection_clone.receive().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "clone_message");
    }

    #[tokio::test]
    #[ignore]
    async fn test_keepalive_sender_with_clone() {
        let (addr, mut client_rx, _server_tx, _binary_tx) = setup_test_server().await;
        let ws_url = format!("ws://{}/websocket", addr);
        let connection = WebSocketConnection::connect(&ws_url)
            .await
            .expect("Failed to connect");

        let keepalive_sender = connection.create_keepalive_sender();
        keepalive_sender
            .send_keepalive(5)
            .await
            .expect("Failed to send keepalive");
        let received = client_rx
            .recv()
            .await
            .expect("No keepalive message received");
        let parsed: serde_json::Value = serde_json::from_str(&received).unwrap();
        assert_eq!(parsed["channel"], 5);
        assert_eq!(parsed["type"], "KEEPALIVE");

        let connection_clone = connection.clone();
        let keepalive_sender2 = connection_clone.create_keepalive_sender();
        keepalive_sender2
            .send_keepalive(10)
            .await
            .expect("Failed to send keepalive from clone");
        let received = client_rx
            .recv()
            .await
            .expect("No keepalive message received from clone");
        let parsed: serde_json::Value = serde_json::from_str(&received).unwrap();
        assert_eq!(parsed["channel"], 10);
        assert_eq!(parsed["type"], "KEEPALIVE");
    }
}