use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex as AsyncMutex;
use crate::channel_gateway::{Channel, ChannelType, InboundMessage, OutboundMessage};
use crate::types::Layer4Result;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{
connect_async, tungstenite::Message as WsMessage, MaybeTlsStream, WebSocketStream,
};
/// Connection settings for a [`WebSocketChannel`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketChannelConfig {
    /// Server URL handed to `connect_async` when connecting.
    pub url: String,
    /// Maximum number of connection attempts made by `connect_with_retry`.
    ///
    /// At least one attempt is always made, even when this is `0`.
    pub reconnect_attempts: u32,
    /// Delay between two failed connection attempts, in milliseconds.
    pub reconnect_interval_ms: u64,
    /// Ping interval in milliseconds (presumably for keep-alive).
    ///
    /// NOTE(review): this value is not read anywhere in this module; confirm whether
    /// pings are actually sent elsewhere.
    pub ping_interval_ms: u64,
    /// Upper bound for a single connection attempt, in milliseconds.
    pub connect_timeout_ms: u64,
}

impl Default for WebSocketChannelConfig {
    /// Defaults targeting `ws://localhost:8080/ws`: 3 attempts 1000 ms apart,
    /// a 30000 ms ping interval and a 10000 ms connect timeout.
    fn default() -> Self {
        Self {
            url: "ws://localhost:8080/ws".to_string(),
            reconnect_attempts: 3,
            reconnect_interval_ms: 1000,
            ping_interval_ms: 30000,
            connect_timeout_ms: 10000,
        }
    }
}
type WsConnection = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Write half of a split [`WsConnection`].
type WsSink = futures::stream::SplitSink<WsConnection, WsMessage>;

/// A [`Channel`] that sends outbound traffic over one WebSocket connection.
///
/// Outbound frames go through `ws_sender`; inbound messages are queued in
/// `message_queue` by `receive_message` and drained by `try_receive`.
pub struct WebSocketChannel {
    /// Identifier reported by `Channel::id`.
    channel_id: String,
    /// Connection settings (URL, retry policy, timeouts).
    config: WebSocketChannelConfig,
    /// Set to `true` by a successful connect and back to `false` by `close`.
    connected: RwLock<bool>,
    /// FIFO of inbound messages waiting to be returned by `try_receive`.
    message_queue: RwLock<VecDeque<InboundMessage>>,
    /// Session id -> user id table.
    sessions: RwLock<HashMap<String, String>>,
    /// Write half of the socket; `None` before connecting and after `close`.
    ws_sender: Arc<AsyncMutex<Option<WsSink>>>,
}
impl WebSocketChannel {
pub fn new(channel_id: impl Into<String>, config: WebSocketChannelConfig) -> Self {
Self {
channel_id: channel_id.into(),
config,
connected: RwLock::new(false),
message_queue: RwLock::new(VecDeque::new()),
sessions: RwLock::new(HashMap::new()),
ws_sender: Arc::new(AsyncMutex::new(None)),
}
}
pub fn default_channel() -> Self {
Self::new("ws-default", WebSocketChannelConfig::default())
}
pub async fn connect(&self) -> Layer4Result<()> {
let url = self.config.url.clone();
let timeout = Duration::from_millis(self.config.connect_timeout_ms);
let connect_future = async { connect_async(&url).await };
let result = tokio::time::timeout(timeout, connect_future).await;
match result {
Ok(Ok((stream, _))) => {
let (sink, _stream) = stream.split();
*self.ws_sender.lock().await = Some(sink);
*self.connected.write() = true;
tracing::info!("WebSocket connected to {}", url);
Ok(())
}
Ok(Err(e)) => {
tracing::error!("WebSocket connection failed: {}", e);
Err(anyhow::anyhow!("WebSocket connection failed: {}", e))
}
Err(_) => {
tracing::error!("WebSocket connection timeout");
Err(anyhow::anyhow!("WebSocket connection timeout"))
}
}
}
pub async fn connect_with_retry(&self) -> Layer4Result<()> {
let mut attempts = 0;
let max_attempts = self.config.reconnect_attempts;
let interval = Duration::from_millis(self.config.reconnect_interval_ms);
loop {
match self.connect().await {
Ok(_) => return Ok(()),
Err(e) => {
attempts += 1;
if attempts >= max_attempts {
return Err(e);
}
tracing::warn!(
"WebSocket connection attempt {}/{} failed, retrying...",
attempts,
max_attempts
);
tokio::time::sleep(interval).await;
}
}
}
}
pub async fn send_raw(&self, message: WsMessage) -> Layer4Result<()> {
let mut sender = self.ws_sender.lock().await;
if let Some(ref mut sink) = *sender {
sink.send(message).await?;
Ok(())
} else {
Err(anyhow::anyhow!("WebSocket not connected"))
}
}
pub async fn send_text(&self, text: &str) -> Layer4Result<()> {
self.send_raw(WsMessage::Text(text.into())).await
}
pub async fn send_binary(&self, data: Vec<u8>) -> Layer4Result<()> {
self.send_raw(WsMessage::Binary(data.into())).await
}
pub fn register_session(&self, session_id: &str, user_id: &str) {
self.sessions
.write()
.insert(session_id.to_string(), user_id.to_string());
}
pub fn unregister_session(&self, session_id: &str) {
self.sessions.write().remove(session_id);
}
pub fn receive_message(&self, session_id: &str, content: &str) {
let user_id = self
.sessions
.read()
.get(session_id)
.cloned()
.unwrap_or_default();
let message = InboundMessage::new(&self.channel_id, &user_id, content)
.with_session(session_id)
.with_metadata(serde_json::json!({
"source": "websocket",
"session_id": session_id
}));
self.message_queue.write().push_back(message);
}
pub fn active_sessions(&self) -> usize {
self.sessions.read().len()
}
}
#[async_trait]
impl Channel for WebSocketChannel {
    fn id(&self) -> &str {
        &self.channel_id
    }

    fn channel_type(&self) -> ChannelType {
        ChannelType::WebSocket
    }

    /// Serializes `message` to JSON and sends it as a single text frame.
    ///
    /// # Errors
    /// Fails if the channel is not marked connected or the underlying send fails.
    async fn send(&self, message: &OutboundMessage) -> Layer4Result<()> {
        if !*self.connected.read() {
            return Err(anyhow::anyhow!("Channel not connected"));
        }
        let payload = serde_json::json!({
            "message_id": message.message_id,
            "content": message.content,
            "message_type": message.message_type,
            "target": message.target,
            "metadata": message.metadata,
            "timestamp": message.timestamp.to_rfc3339(),
        });
        self.send_text(&payload.to_string()).await?;
        tracing::debug!("WebSocket channel sent message {}", message.message_id);
        Ok(())
    }

    /// Pops the oldest queued inbound message, or `Ok(None)` when the queue is empty.
    ///
    /// # Errors
    /// Fails when the channel is not marked connected, even if messages are queued.
    async fn try_receive(&self) -> Layer4Result<Option<InboundMessage>> {
        if !*self.connected.read() {
            return Err(anyhow::anyhow!("Channel not connected"));
        }
        Ok(self.message_queue.write().pop_front())
    }

    fn is_connected(&self) -> bool {
        *self.connected.read()
    }

    /// Closes the stored sink (if any) and resets the channel.
    ///
    /// The reset marks the channel disconnected, drops queued messages and clears all
    /// sessions. It happens even when the close handshake fails, so a failed close
    /// never leaves the channel reporting "connected" with a stale sink.
    ///
    /// # Errors
    /// Returns the sink's close error, after the state has been reset.
    async fn close(&self) -> Layer4Result<()> {
        let mut sender = self.ws_sender.lock().await;
        // Take the sink out first: the slot is cleared no matter how the close goes.
        let close_result = match sender.take() {
            Some(mut sink) => sink.close().await,
            None => Ok(()),
        };
        *self.connected.write() = false;
        self.message_queue.write().clear();
        self.sessions.write().clear();
        // Keep the sender lock until the reset is complete, as before.
        drop(sender);

        match close_result {
            Ok(()) => {
                tracing::info!("WebSocket channel closed");
                Ok(())
            }
            Err(e) => {
                tracing::warn!("WebSocket close handshake failed: {}", e);
                Err(e.into())
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_channel_creation() {
        let channel = WebSocketChannel::default_channel();
        assert_eq!(channel.id(), "ws-default");
        assert!(matches!(channel.channel_type(), ChannelType::WebSocket));
        assert!(!channel.is_connected());
    }

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketChannelConfig::default();
        assert_eq!(config.url, "ws://localhost:8080/ws");
        assert_eq!(config.reconnect_attempts, 3);
        assert_eq!(config.reconnect_interval_ms, 1000);
        assert_eq!(config.ping_interval_ms, 30000);
        assert_eq!(config.connect_timeout_ms, 10000);
    }

    #[test]
    fn test_websocket_session_management() {
        let channel = WebSocketChannel::default_channel();
        channel.register_session("session-1", "user-1");
        assert_eq!(channel.active_sessions(), 1);
        // Re-registering the same session overwrites rather than duplicates.
        channel.register_session("session-1", "user-2");
        assert_eq!(channel.active_sessions(), 1);
        channel.unregister_session("session-1");
        assert_eq!(channel.active_sessions(), 0);
        // Unregistering an unknown session is a no-op.
        channel.unregister_session("missing");
        assert_eq!(channel.active_sessions(), 0);
    }

    #[test]
    fn test_websocket_receive_message() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.receive_message("session-1", "Hello");
        let count = channel.message_queue.read().len();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_try_receive_drains_queue() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.receive_message("session-1", "first");
        // Messages for unregistered sessions are still queued.
        channel.receive_message("unknown-session", "second");
        assert_eq!(channel.message_queue.read().len(), 2);
        assert!(channel.try_receive().await.unwrap().is_some());
        assert!(channel.try_receive().await.unwrap().is_some());
        assert!(channel.try_receive().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_try_receive_without_connection() {
        let channel = WebSocketChannel::default_channel();
        channel.receive_message("session-1", "queued");
        assert!(channel.try_receive().await.is_err());
    }

    #[tokio::test]
    async fn test_send_frames_without_sink() {
        let channel = WebSocketChannel::default_channel();
        assert!(channel.send_text("hello").await.is_err());
        assert!(channel.send_binary(vec![1, 2, 3]).await.is_err());
    }

    #[tokio::test]
    async fn test_websocket_channel_close() {
        let channel = WebSocketChannel::default_channel();
        *channel.connected.write() = true;
        channel.register_session("session-1", "user-1");
        channel.receive_message("session-1", "pending");
        channel.close().await.unwrap();
        assert!(!channel.is_connected());
        assert_eq!(channel.active_sessions(), 0);
        assert!(channel.message_queue.read().is_empty());
    }

    #[tokio::test]
    async fn test_close_is_idempotent() {
        let channel = WebSocketChannel::default_channel();
        channel.close().await.unwrap();
        channel.close().await.unwrap();
        assert!(!channel.is_connected());
    }

    #[tokio::test]
    async fn test_send_without_connection() {
        let channel = WebSocketChannel::default_channel();
        let msg = OutboundMessage::to_user("test-user", "hello");
        let result = channel.send(&msg).await;
        assert!(result.is_err());
    }
}