use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_tungstenite::{
connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
};
use tracing::{debug, instrument};
#[derive(Error, Debug)]
pub enum TransportError {
#[error("connection failed: {0}")]
ConnectionFailed(String),
#[error("connection closed")]
ConnectionClosed,
#[error("send failed: {0}")]
SendFailed(String),
#[error("receive failed: {0}")]
ReceiveFailed(String),
#[error("connection timeout after {0:?}")]
Timeout(Duration),
#[error("not connected")]
NotConnected,
#[error("protocol error: {0}")]
Protocol(String),
}
/// A bidirectional, text-message transport driven asynchronously.
///
/// Implementors must be `Send + Sync`. The lifecycle is
/// `connect` → any number of `send`/`recv` → `close`.
#[async_trait]
pub trait Transport: Send + Sync {
    /// The address this transport targets.
    fn endpoint(&self) -> &str;

    /// Whether the transport currently holds an open connection.
    fn is_connected(&self) -> bool;

    /// Opens the connection to [`Transport::endpoint`].
    async fn connect(&mut self) -> Result<(), TransportError>;

    /// Sends one text message over the open connection.
    async fn send(&mut self, message: &str) -> Result<(), TransportError>;

    /// Waits for the next text message.
    ///
    /// `Ok(None)` signals that the peer closed the connection cleanly.
    async fn recv(&mut self) -> Result<Option<String>, TransportError>;

    /// Shuts the connection down.
    async fn close(&mut self) -> Result<(), TransportError>;
}
/// [`Transport`] backed by a (possibly TLS-wrapped) WebSocket over TCP.
pub struct WsTransport {
    /// WebSocket URL passed to `connect_async`.
    url: String,
    /// Live connection; `None` until connected and again after close.
    stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    /// Upper bound on how long `connect` may take.
    connect_timeout: Duration,
}
impl WsTransport {
    /// Connect timeout used until [`WsTransport::with_timeout`] overrides it.
    const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates a disconnected transport targeting `url`.
    ///
    /// No network I/O happens here; call [`Transport::connect`] to open the
    /// socket. The connect timeout starts at 10 seconds.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            stream: None,
            connect_timeout: Self::DEFAULT_CONNECT_TIMEOUT,
        }
    }

    /// Builder-style override of the timeout applied to `connect`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }
}
#[async_trait]
impl Transport for WsTransport {
    /// Performs the WebSocket handshake with `self.url`.
    ///
    /// The `connect_async` call is bounded by `self.connect_timeout`. On
    /// success the new stream replaces any stream held before.
    ///
    /// # Errors
    ///
    /// - [`TransportError::Timeout`] if the handshake does not finish in time.
    /// - [`TransportError::ConnectionFailed`] if `connect_async` returns an error.
    #[instrument(skip(self), fields(url = %self.url))]
    async fn connect(&mut self) -> Result<(), TransportError> {
        debug!("Connecting to WebSocket");
        let connect_future = connect_async(&self.url);
        // First `?` handles the elapsed timer, second `?` the handshake error.
        let (ws_stream, _response) = timeout(self.connect_timeout, connect_future)
            .await
            .map_err(|_| TransportError::Timeout(self.connect_timeout))?
            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
        self.stream = Some(ws_stream);
        debug!("WebSocket connected");
        Ok(())
    }

    /// Sends `message` as a single text frame.
    ///
    /// # Errors
    ///
    /// - [`TransportError::NotConnected`] if `connect` has not succeeded.
    /// - [`TransportError::SendFailed`] if the write fails.
    #[instrument(skip(self, message), fields(len = message.len()))]
    async fn send(&mut self, message: &str) -> Result<(), TransportError> {
        let stream = self.stream.as_mut().ok_or(TransportError::NotConnected)?;
        stream
            .send(Message::Text(message.to_string()))
            .await
            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
        Ok(())
    }

    /// Waits for the next data message.
    ///
    /// Text frames are returned as-is. Binary frames are decoded as UTF-8.
    /// Ping, Pong and raw frames are skipped.
    ///
    /// Returns `Ok(None)` after a peer Close frame. In that case the
    /// transport becomes disconnected.
    ///
    /// # Errors
    ///
    /// - [`TransportError::NotConnected`] if there is no open stream.
    /// - [`TransportError::Protocol`] for a binary frame that is not UTF-8.
    /// - [`TransportError::ReceiveFailed`] if reading fails.
    /// - [`TransportError::ConnectionClosed`] if the stream ends without a
    ///   Close frame; the transport becomes disconnected.
    #[instrument(skip(self))]
    async fn recv(&mut self) -> Result<Option<String>, TransportError> {
        let stream = self.stream.as_mut().ok_or(TransportError::NotConnected)?;
        // Iterate rather than recurse over control frames: the previous
        // `Box::pin(self.recv())` approach built one nested boxed future per
        // skipped frame, with no bound on depth.
        loop {
            match stream.next().await {
                Some(Ok(Message::Text(text))) => return Ok(Some(text)),
                Some(Ok(Message::Binary(data))) => {
                    return String::from_utf8(data)
                        .map(Some)
                        .map_err(|e| TransportError::Protocol(e.to_string()));
                }
                Some(Ok(Message::Close(_))) => {
                    // tungstenite queues the Close reply required by RFC 6455
                    // instead of writing it immediately. Flush it before
                    // dropping the stream so the peer sees a clean handshake.
                    // This is best-effort: errors are ignored, and the flush is
                    // bounded so a stalled peer cannot hang `recv`.
                    let _ = timeout(self.connect_timeout, stream.close(None)).await;
                    self.stream = None;
                    return Ok(None);
                }
                // Control/raw frames carry no payload for callers. tungstenite
                // queues any Pong reply to a Ping itself and sends it on the
                // next read.
                Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue,
                Some(Err(e)) => return Err(TransportError::ReceiveFailed(e.to_string())),
                None => {
                    self.stream = None;
                    return Err(TransportError::ConnectionClosed);
                }
            }
        }
    }

    /// Sends a Close frame and drops the stream.
    ///
    /// The stream is taken before closing, so the transport reports
    /// disconnected afterwards even if the close write fails. Calling this
    /// while disconnected returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// [`TransportError::SendFailed`] if writing the Close frame fails.
    #[instrument(skip(self))]
    async fn close(&mut self) -> Result<(), TransportError> {
        if let Some(mut stream) = self.stream.take() {
            stream
                .close(None)
                .await
                .map_err(|e| TransportError::SendFailed(e.to_string()))?;
        }
        Ok(())
    }

    /// `true` while a stream is held.
    fn is_connected(&self) -> bool {
        self.stream.is_some()
    }

    /// The URL given to [`WsTransport::new`].
    fn endpoint(&self) -> &str {
        &self.url
    }
}
/// In-memory [`Transport`] double for tests.
///
/// `recv` replays entries from `responses` front-to-back. Every message
/// passed to `send` is appended to `sent_messages`. The `fail_*` switches
/// force the corresponding operation to error.
#[cfg(any(test, feature = "test-utils"))]
pub struct MockTransport {
    /// Endpoint reported by `endpoint()`.
    url: String,
    /// Set by `connect`, cleared by `close`.
    connected: bool,
    /// Scripted results returned by successive `recv` calls.
    pub responses: std::collections::VecDeque<Result<Option<String>, TransportError>>,
    /// Messages recorded by successful `send` calls, oldest first.
    pub sent_messages: Vec<String>,
    /// When `true`, `connect` fails.
    pub fail_connect: bool,
    /// When `true`, `send` fails without recording the message.
    pub fail_send: bool,
}
#[cfg(any(test, feature = "test-utils"))]
impl MockTransport {
    /// Creates a disconnected mock for `url` with an empty script, nothing
    /// recorded, and both failure switches off.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            connected: false,
            responses: Default::default(),
            sent_messages: Default::default(),
            fail_connect: false,
            fail_send: false,
        }
    }

    /// Queues one text message for a later `recv`.
    pub fn push_response(&mut self, msg: impl Into<String>) {
        let text: String = msg.into();
        self.responses.push_back(Ok(Some(text)));
    }

    /// Queues several text messages, preserving iteration order.
    pub fn push_responses(&mut self, msgs: impl IntoIterator<Item = impl Into<String>>) {
        for text in msgs {
            self.responses.push_back(Ok(Some(text.into())));
        }
    }

    /// Queues a clean close, which `recv` reports as `Ok(None)`.
    pub fn push_close(&mut self) {
        let close: Result<Option<String>, TransportError> = Ok(None);
        self.responses.push_back(close);
    }

    /// Queues an error for `recv` to return.
    pub fn push_error(&mut self, error: TransportError) {
        let failure = Err(error);
        self.responses.push_back(failure);
    }

    /// Returns every recorded message and leaves the record empty.
    pub fn take_sent(&mut self) -> Vec<String> {
        std::mem::replace(&mut self.sent_messages, Vec::new())
    }
}
#[cfg(any(test, feature = "test-utils"))]
#[async_trait]
impl Transport for MockTransport {
    /// Marks the mock connected.
    ///
    /// # Errors
    ///
    /// [`TransportError::ConnectionFailed`] when `fail_connect` is set. The
    /// connection state is left unchanged in that case.
    async fn connect(&mut self) -> Result<(), TransportError> {
        if self.fail_connect {
            return Err(TransportError::ConnectionFailed("mock connection failure".into()));
        }
        self.connected = true;
        Ok(())
    }

    /// Records `message` in `sent_messages`.
    ///
    /// # Errors
    ///
    /// - [`TransportError::NotConnected`] before a successful `connect`.
    /// - [`TransportError::SendFailed`] when `fail_send` is set. Nothing is
    ///   recorded in that case.
    async fn send(&mut self, message: &str) -> Result<(), TransportError> {
        if !self.connected {
            return Err(TransportError::NotConnected);
        }
        if self.fail_send {
            return Err(TransportError::SendFailed("mock send failure".into()));
        }
        self.sent_messages.push(message.to_string());
        Ok(())
    }

    /// Pops and returns the next scripted response.
    ///
    /// An exhausted script yields `Err(ConnectionClosed)`.
    ///
    /// # Errors
    ///
    /// [`TransportError::NotConnected`] before a successful `connect`.
    /// Otherwise, whatever error was scripted via `push_error`.
    async fn recv(&mut self) -> Result<Option<String>, TransportError> {
        if !self.connected {
            return Err(TransportError::NotConnected);
        }
        let next = self
            .responses
            .pop_front()
            .unwrap_or(Err(TransportError::ConnectionClosed));
        // Mirror WsTransport: a peer Close (`Ok(None)`) or end of stream
        // (`ConnectionClosed`) drops its stream, so `is_connected()` turns
        // false. The mock previously stayed "connected", which hid
        // reconnect bugs in code under test.
        if matches!(next, Ok(None) | Err(TransportError::ConnectionClosed)) {
            self.connected = false;
        }
        next
    }

    /// Marks the mock disconnected. Always succeeds.
    async fn close(&mut self) -> Result<(), TransportError> {
        self.connected = false;
        Ok(())
    }

    /// Current connection flag.
    fn is_connected(&self) -> bool {
        self.connected
    }

    /// The URL given to [`MockTransport::new`].
    fn endpoint(&self) -> &str {
        &self.url
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_transport_send_recv() {
        let mut transport = MockTransport::new("wss://mock.test");
        transport.push_response(r#"{"type":"pong"}"#);
        transport.connect().await.unwrap();
        assert!(transport.is_connected());
        transport.send(r#"{"type":"ping"}"#).await.unwrap();
        assert_eq!(transport.sent_messages.len(), 1);
        assert!(transport.sent_messages[0].contains("ping"));
        let response = transport.recv().await.unwrap();
        assert!(response.unwrap().contains("pong"));
    }

    #[tokio::test]
    async fn test_mock_transport_connection_failure() {
        let mut transport = MockTransport::new("wss://mock.test");
        transport.fail_connect = true;
        let result = transport.connect().await;
        assert!(result.is_err());
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_mock_transport_close() {
        let mut transport = MockTransport::new("wss://mock.test");
        transport.push_close();
        transport.connect().await.unwrap();
        let response = transport.recv().await.unwrap();
        assert!(response.is_none());
    }

    /// send/recv must refuse to run before `connect`.
    #[tokio::test]
    async fn test_mock_transport_requires_connection() {
        let mut transport = MockTransport::new("wss://mock.test");
        assert!(matches!(
            transport.send("x").await,
            Err(TransportError::NotConnected)
        ));
        assert!(matches!(
            transport.recv().await,
            Err(TransportError::NotConnected)
        ));
        assert!(transport.sent_messages.is_empty());
    }

    /// A forced send failure must not record the message.
    #[tokio::test]
    async fn test_mock_transport_send_failure_records_nothing() {
        let mut transport = MockTransport::new("wss://mock.test");
        transport.fail_send = true;
        transport.connect().await.unwrap();
        assert!(matches!(
            transport.send("x").await,
            Err(TransportError::SendFailed(_))
        ));
        assert!(transport.sent_messages.is_empty());
    }

    /// Running out of scripted responses looks like the stream ending.
    #[tokio::test]
    async fn test_mock_transport_exhausted_queue_reports_closed() {
        let mut transport = MockTransport::new("wss://mock.test");
        transport.connect().await.unwrap();
        assert!(matches!(
            transport.recv().await,
            Err(TransportError::ConnectionClosed)
        ));
    }

    /// Scripted messages and errors come back in push order.
    #[tokio::test]
    async fn test_mock_transport_replays_in_order() {
        let mut transport = MockTransport::new("wss://mock.test");
        transport.push_responses(["a", "b"]);
        transport.push_error(TransportError::Protocol("bad".into()));
        transport.connect().await.unwrap();
        assert_eq!(transport.recv().await.unwrap().as_deref(), Some("a"));
        assert_eq!(transport.recv().await.unwrap().as_deref(), Some("b"));
        assert!(matches!(
            transport.recv().await,
            Err(TransportError::Protocol(m)) if m == "bad"
        ));
    }

    /// `take_sent` returns everything recorded and empties the record.
    #[tokio::test]
    async fn test_mock_transport_take_sent_drains() {
        let mut transport = MockTransport::new("wss://mock.test");
        transport.connect().await.unwrap();
        transport.send("one").await.unwrap();
        transport.send("two").await.unwrap();
        assert_eq!(
            transport.take_sent(),
            vec!["one".to_string(), "two".to_string()]
        );
        assert!(transport.take_sent().is_empty());
    }

    /// A fresh WsTransport performs no I/O and rejects send/recv.
    #[tokio::test]
    async fn test_ws_transport_starts_disconnected() {
        let mut transport = WsTransport::new("ws://127.0.0.1:1");
        assert!(!transport.is_connected());
        assert_eq!(transport.endpoint(), "ws://127.0.0.1:1");
        assert!(matches!(
            transport.send("x").await,
            Err(TransportError::NotConnected)
        ));
        assert!(matches!(
            transport.recv().await,
            Err(TransportError::NotConnected)
        ));
        assert!(transport.close().await.is_ok());
    }
}