use crate::auth::types::AuthMessage;
use crate::{Error, Result};
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tokio_tungstenite::tungstenite::Message;
use super::http::{Transport, TransportCallback};
/// Configuration accepted by [`WebSocketTransport::new`].
#[derive(Debug, Clone)]
pub struct WebSocketTransportOptions {
/// Endpoint URL; `new` rejects it unless it is non-empty, starts with `ws://`
/// or `wss://`, and passes a basic host check.
pub base_url: String,
/// Read deadline in seconds; `None` or `Some(0)` is replaced by 30 in `new`.
pub read_deadline_secs: Option<u64>,
}
/// Write half of a split WebSocket connection. It is kept in a shared
/// `Option` so `send` can write through it and the receive loop can clear it.
type WsSink = futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
Message,
>;
/// Auth transport that exchanges JSON-encoded `AuthMessage`s over a WebSocket.
///
/// The connection is opened lazily by `send`; incoming frames are decoded by a
/// spawned receive loop and handed to every registered callback.
pub struct WebSocketTransport {
/// Endpoint URL, validated by `new`.
base_url: String,
/// Write half of the current connection; `None` while disconnected, which
/// makes the next `send` reconnect.
sink: Arc<Mutex<Option<WsSink>>>,
/// Handlers invoked, in registration order, for each decoded message.
on_data_callbacks: Arc<RwLock<Vec<Box<TransportCallback>>>>,
/// Read deadline in seconds; `new` guarantees it is non-zero.
read_deadline_secs: u64,
}
impl std::fmt::Debug for WebSocketTransport {
    /// Shows only the transport's configuration (`base_url` and
    /// `read_deadline_secs`); the connection sink and callback list are not
    /// included in the output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("WebSocketTransport");
        builder.field("base_url", &self.base_url);
        builder.field("read_deadline_secs", &self.read_deadline_secs);
        builder.finish()
    }
}
impl WebSocketTransport {
    /// Read deadline used when the options leave it unset or set it to zero.
    const DEFAULT_READ_DEADLINE_SECS: u64 = 30;

    /// Builds a transport for `options.base_url` without opening a connection;
    /// the socket is opened lazily by the first [`Transport::send`].
    ///
    /// A missing or zero `read_deadline_secs` is replaced by 30 seconds.
    ///
    /// # Errors
    ///
    /// Returns [`Error::TransportError`] if `base_url` is empty, does not start
    /// with `ws://` or `wss://`, or is rejected by the URL structure check.
    pub fn new(options: WebSocketTransportOptions) -> Result<Self> {
        if options.base_url.is_empty() {
            return Err(Error::TransportError(
                "base_url is required for WebSocket transport".into(),
            ));
        }
        if !options.base_url.starts_with("ws://") && !options.base_url.starts_with("wss://") {
            return Err(Error::TransportError(
                "WebSocket URL must start with ws:// or wss://".into(),
            ));
        }
        url_parse_check(&options.base_url)?;
        let read_deadline_secs = options
            .read_deadline_secs
            .filter(|&secs| secs > 0)
            .unwrap_or(Self::DEFAULT_READ_DEADLINE_SECS);
        Ok(Self {
            base_url: options.base_url,
            sink: Arc::new(Mutex::new(None)),
            on_data_callbacks: Arc::new(RwLock::new(Vec::new())),
            read_deadline_secs,
        })
    }

    /// Returns the endpoint URL this transport connects to.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Returns the effective read deadline in seconds (never zero).
    pub fn read_deadline_secs(&self) -> u64 {
        self.read_deadline_secs
    }

    /// Opens the WebSocket connection, stores its write half and spawns the
    /// receive loop that reads from it.
    ///
    /// The sink lock is held for the whole handshake so concurrent callers are
    /// serialized. If another caller already installed a connection by the time
    /// the lock is acquired, this returns `Ok(())` without dialing again. That
    /// prevents duplicate connections and orphaned receive loops.
    ///
    /// # Errors
    ///
    /// Returns [`Error::TransportError`] if the WebSocket handshake fails.
    async fn connect(&self) -> Result<()> {
        let mut sink_guard = self.sink.lock().await;
        if sink_guard.is_some() {
            return Ok(());
        }
        let (ws_stream, _response) = tokio_tungstenite::connect_async(&self.base_url)
            .await
            .map_err(|e| Error::TransportError(format!("failed to connect to WebSocket: {}", e)))?;
        let (sink, stream) = ws_stream.split();
        *sink_guard = Some(sink);
        // Release the lock before spawning so the receive loop can take it on teardown.
        drop(sink_guard);

        let callbacks = Arc::clone(&self.on_data_callbacks);
        let sink_ref = Arc::clone(&self.sink);
        let deadline = self.read_deadline_secs;
        tokio::spawn(async move {
            receive_loop(stream, callbacks, sink_ref, deadline).await;
        });
        Ok(())
    }
}
/// Read half of a split WebSocket connection, as consumed by [`receive_loop`].
type WsStream = futures_util::stream::SplitStream<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
>;

/// Reads frames from `stream` and hands every decoded [`AuthMessage`] to each
/// registered callback, in registration order.
///
/// Text frames, and binary frames containing valid UTF-8, are parsed as JSON
/// `AuthMessage`s. Frames that fail to decode, and all other frame kinds, are
/// skipped. Callback errors are ignored so one failing handler neither stops
/// delivery to the others nor ends the loop.
///
/// The loop ends when any of the following happens:
/// - the peer sends a close frame;
/// - the stream errors or ends;
/// - no frame arrives within `deadline_secs` seconds (0 disables the deadline).
///
/// On exit the shared `sink` is cleared, but only if it still holds this
/// connection's write half. A newer connection installed meanwhile is left intact.
async fn receive_loop(
    mut stream: WsStream,
    callbacks: Arc<RwLock<Vec<Box<TransportCallback>>>>,
    sink: Arc<Mutex<Option<WsSink>>>,
    deadline_secs: u64,
) {
    let read_timeout =
        (deadline_secs > 0).then(|| std::time::Duration::from_secs(deadline_secs));
    loop {
        let next = match read_timeout {
            Some(limit) => match tokio::time::timeout(limit, stream.next()).await {
                Ok(next) => next,
                // Nothing arrived within the read deadline: treat the peer as gone.
                Err(_elapsed) => break,
            },
            None => stream.next().await,
        };
        let msg = match next {
            Some(Ok(msg)) => msg,
            // Transport error or end of stream.
            Some(Err(_)) | None => break,
        };
        let text = match msg {
            Message::Text(t) => t,
            Message::Binary(b) => match String::from_utf8(b.to_vec()) {
                Ok(s) => s,
                Err(_) => continue,
            },
            Message::Close(_) => break,
            _ => continue,
        };
        let auth_message: AuthMessage = match serde_json::from_str(&text) {
            Ok(m) => m,
            Err(_) => continue,
        };
        let cbs = callbacks.read().await;
        for cb in cbs.iter() {
            let _ = cb(auth_message.clone()).await;
        }
    }
    release_sink(stream, &sink).await;
}

/// Clears `sink` if it still holds the write half paired with `stream`.
///
/// `reunite` only succeeds for the two halves of the same split, which makes it
/// an ownership check. On success the reunited connection is dropped. On failure
/// the sink belongs to a newer connection and is put back untouched.
async fn release_sink(stream: WsStream, sink: &Mutex<Option<WsSink>>) {
    let mut sink_guard = sink.lock().await;
    if let Some(current) = sink_guard.take() {
        if let Err(futures_util::stream::ReuniteError(foreign, _)) = stream.reunite(current) {
            *sink_guard = Some(foreign);
        }
    }
}
fn url_parse_check(url: &str) -> Result<()> {
let after_scheme = if let Some(rest) = url.strip_prefix("wss://") {
rest
} else if let Some(rest) = url.strip_prefix("ws://") {
rest
} else {
return Err(Error::TransportError("invalid WebSocket URL".into()));
};
if after_scheme.is_empty() {
return Err(Error::TransportError(
"WebSocket URL must include a host".into(),
));
}
Ok(())
}
#[async_trait]
impl Transport for WebSocketTransport {
async fn send(&self, message: &AuthMessage) -> Result<()> {
{
let cbs = self.on_data_callbacks.read().await;
if cbs.is_empty() {
return Err(Error::TransportError("no handler registered".into()));
}
}
{
let sink_guard = self.sink.lock().await;
if sink_guard.is_none() {
drop(sink_guard);
self.connect().await?;
}
}
let json_data = serde_json::to_string(message)
.map_err(|e| Error::TransportError(format!("failed to marshal auth message: {}", e)))?;
let mut sink_guard = self.sink.lock().await;
if let Some(ref mut sink) = *sink_guard {
if let Err(e) = sink.send(Message::Text(json_data)).await {
*sink_guard = None;
return Err(Error::TransportError(format!(
"failed to send WebSocket message: {}",
e
)));
}
Ok(())
} else {
Err(Error::TransportError(
"WebSocket connection not available".into(),
))
}
}
fn set_callback(&self, callback: Box<TransportCallback>) {
let callbacks = self.on_data_callbacks.clone();
tokio::spawn(async move {
let mut cbs = callbacks.write().await;
cbs.push(callback);
});
}
fn clear_callback(&self) {
let callbacks = self.on_data_callbacks.clone();
tokio::spawn(async move {
let mut cbs = callbacks.write().await;
cbs.clear();
});
}
}
// Unit tests for the WebSocket transport; compiled only for test builds with
// the "websocket" feature enabled. Apart from the connect-failure test, none
// of them open a real connection.
#[cfg(all(test, feature = "websocket"))]
mod tests {
use super::*;
// --- Constructor: URL validation and accessors ---
#[test]
fn test_new_with_valid_ws_url() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:8080".to_string(),
read_deadline_secs: None,
});
assert!(transport.is_ok());
let t = transport.unwrap();
assert_eq!(t.base_url(), "ws://localhost:8080");
}
#[test]
fn test_new_with_valid_wss_url() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "wss://example.com/ws".to_string(),
read_deadline_secs: Some(60),
});
assert!(transport.is_ok());
let t = transport.unwrap();
assert_eq!(t.base_url(), "wss://example.com/ws");
assert_eq!(t.read_deadline_secs(), 60);
}
#[test]
fn test_new_with_empty_url_returns_error() {
let result = WebSocketTransport::new(WebSocketTransportOptions {
base_url: String::new(),
read_deadline_secs: None,
});
assert!(result.is_err());
match result.unwrap_err() {
Error::TransportError(msg) => {
assert!(msg.contains("base_url is required"));
}
other => panic!("expected TransportError, got: {:?}", other),
}
}
// Non-WebSocket schemes are rejected by the scheme-prefix check in `new`.
#[test]
fn test_new_with_http_url_returns_error() {
let result = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "http://example.com".to_string(),
read_deadline_secs: None,
});
assert!(result.is_err());
match result.unwrap_err() {
Error::TransportError(msg) => {
assert!(msg.contains("ws://") || msg.contains("wss://"));
}
other => panic!("expected TransportError, got: {:?}", other),
}
}
#[test]
fn test_new_with_https_url_returns_error() {
let result = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "https://example.com".to_string(),
read_deadline_secs: None,
});
assert!(result.is_err());
}
#[test]
fn test_new_with_invalid_scheme_returns_error() {
let result = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ftp://example.com".to_string(),
read_deadline_secs: None,
});
assert!(result.is_err());
}
// --- Read deadline normalization: None and Some(0) fall back to 30 ---
#[test]
fn test_default_read_deadline() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:8080".to_string(),
read_deadline_secs: None,
})
.unwrap();
assert_eq!(transport.read_deadline_secs(), 30);
}
#[test]
fn test_zero_read_deadline_defaults_to_30() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:8080".to_string(),
read_deadline_secs: Some(0),
})
.unwrap();
assert_eq!(transport.read_deadline_secs(), 30);
}
#[test]
fn test_custom_read_deadline() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:8080".to_string(),
read_deadline_secs: Some(120),
})
.unwrap();
assert_eq!(transport.read_deadline_secs(), 120);
}
#[test]
fn test_ws_url_without_host_returns_error() {
let result = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://".to_string(),
read_deadline_secs: None,
});
assert!(result.is_err());
}
// --- Callback registration ---
// `set_callback` / `clear_callback` may apply their change from a spawned
// task, so these tests sleep briefly before inspecting the callback list.
#[tokio::test]
async fn test_on_data_callback_registration() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:9999".to_string(),
read_deadline_secs: None,
})
.unwrap();
transport.set_callback(Box::new(|_msg| Box::pin(async move { Ok(()) })));
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let cbs = transport.on_data_callbacks.read().await;
assert_eq!(cbs.len(), 1);
}
#[tokio::test]
async fn test_clear_callback() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:9999".to_string(),
read_deadline_secs: None,
})
.unwrap();
transport.set_callback(Box::new(|_msg| Box::pin(async move { Ok(()) })));
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
transport.clear_callback();
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let cbs = transport.on_data_callbacks.read().await;
assert_eq!(cbs.len(), 0);
}
// `send` checks for a registered callback before attempting to connect, so
// this fails without touching the network.
#[tokio::test]
async fn test_send_without_callback_returns_error() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:9999".to_string(),
read_deadline_secs: None,
})
.unwrap();
let msg = AuthMessage::new(
crate::auth::types::MessageType::InitialRequest,
crate::primitives::PrivateKey::random().public_key(),
);
let result = transport.send(&msg).await;
assert!(result.is_err());
match result.unwrap_err() {
Error::TransportError(msg) => {
assert!(msg.contains("no handler registered"));
}
other => panic!("expected TransportError, got: {:?}", other),
}
}
// The Debug impl must expose the configuration fields.
#[test]
fn test_debug_format() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:8080".to_string(),
read_deadline_secs: Some(45),
})
.unwrap();
let debug = format!("{:?}", transport);
assert!(debug.contains("WebSocketTransport"));
assert!(debug.contains("ws://localhost:8080"));
assert!(debug.contains("45"));
}
// --- url_parse_check directly ---
#[test]
fn test_url_parse_check_valid() {
assert!(url_parse_check("ws://localhost:8080").is_ok());
assert!(url_parse_check("wss://example.com/path").is_ok());
assert!(url_parse_check("ws://192.168.1.1:3000/ws").is_ok());
}
#[test]
fn test_url_parse_check_empty_host() {
assert!(url_parse_check("ws://").is_err());
assert!(url_parse_check("wss://").is_err());
}
#[test]
fn test_url_parse_check_invalid_scheme() {
assert!(url_parse_check("http://example.com").is_err());
}
// The base URL is stored verbatim; trailing slashes are not normalized.
#[test]
fn test_url_with_trailing_slash() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:8080/".to_string(),
read_deadline_secs: None,
})
.unwrap();
assert_eq!(transport.base_url(), "ws://localhost:8080/");
}
#[test]
fn test_url_with_path_and_trailing_slash() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "wss://example.com/ws/auth/".to_string(),
read_deadline_secs: Some(45),
})
.unwrap();
assert_eq!(transport.base_url(), "wss://example.com/ws/auth/");
assert_eq!(transport.read_deadline_secs(), 45);
}
// Port 1 on loopback is assumed to have no listener, so `send` should surface
// the connect error rather than a send error.
#[tokio::test]
async fn test_send_with_callback_but_no_connection_triggers_connect_error() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://127.0.0.1:1".to_string(), read_deadline_secs: Some(5),
})
.unwrap();
transport.set_callback(Box::new(|_msg| Box::pin(async move { Ok(()) })));
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let msg = AuthMessage::new(
crate::auth::types::MessageType::InitialRequest,
crate::primitives::PrivateKey::random().public_key(),
);
let result = transport.send(&msg).await;
assert!(result.is_err(), "Sending to unreachable server should fail");
match result.unwrap_err() {
Error::TransportError(msg) => {
assert!(
msg.contains("failed to connect"),
"Error should mention connection failure, got: {}",
msg
);
}
other => panic!("Expected TransportError, got: {:?}", other),
}
}
#[test]
fn test_read_deadline_one_second() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:8080".to_string(),
read_deadline_secs: Some(1),
})
.unwrap();
assert_eq!(transport.read_deadline_secs(), 1);
}
#[test]
fn test_constructor_with_port_and_path() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://192.168.1.100:3000/v1/ws".to_string(),
read_deadline_secs: Some(60),
})
.unwrap();
assert_eq!(transport.base_url(), "ws://192.168.1.100:3000/v1/ws");
assert_eq!(transport.read_deadline_secs(), 60);
}
// Callbacks accumulate rather than replace one another.
#[tokio::test]
async fn test_multiple_callbacks_registered() {
let transport = WebSocketTransport::new(WebSocketTransportOptions {
base_url: "ws://localhost:9999".to_string(),
read_deadline_secs: None,
})
.unwrap();
transport.set_callback(Box::new(|_msg| Box::pin(async move { Ok(()) })));
transport.set_callback(Box::new(|_msg| Box::pin(async move { Ok(()) })));
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let cbs = transport.on_data_callbacks.read().await;
assert_eq!(cbs.len(), 2, "Should support multiple concurrent callbacks");
}
}