use super::error::SignalError;
use super::signal::Signal;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::collections::HashMap;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
/// Envelope delivered to WebSocket clients: an event name, its payload, and
/// a creation timestamp.
///
/// Fields are serialized in declaration order by the serde derive, so keep
/// the order stable if clients depend on the JSON key order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketMessage<T> {
    /// Name of the event this message describes.
    pub event_type: String,
    /// Event-specific data.
    pub payload: T,
    /// When built via [`WebSocketMessage::new`]: milliseconds since the Unix
    /// epoch, or `0` if the system clock reports a time before the epoch.
    pub timestamp: u64,
}
impl<T> WebSocketMessage<T> {
    /// Wraps `payload` in an envelope tagged with `event_type` and stamped
    /// with the current wall-clock time in milliseconds since the Unix epoch.
    pub fn new(event_type: impl Into<String>, payload: T) -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};
        // A clock set before the epoch yields a zero duration instead of an error.
        let elapsed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        let millis = elapsed.as_millis() as u64;
        Self {
            timestamp: millis,
            payload,
            event_type: event_type.into(),
        }
    }
}
/// Minimal interface the bridge needs from a single WebSocket connection.
///
/// The `Send + Sync` supertraits let implementors be shared as
/// `Arc<dyn WebSocketClient>` and called from async signal handlers.
pub trait WebSocketClient: Send + Sync {
    /// Identifier under which the bridge registers this client.
    fn client_id(&self) -> &str;
    /// Whether the connection is currently able to accept messages.
    fn is_connected(&self) -> bool;
    /// Delivers `message` (the bridge passes serialized JSON text).
    ///
    /// # Errors
    ///
    /// Returns a [`SignalError`] if the message could not be delivered.
    fn send_message(&self, message: String) -> Result<(), SignalError>;
}
/// In-memory [`WebSocketClient`] that records messages instead of sending
/// them, and can be switched into a disconnected state.
pub struct MockWebSocketClient {
    /// Identifier returned by [`WebSocketClient::client_id`].
    id: String,
    /// Every message accepted by `send_message`, in arrival order.
    messages: Arc<RwLock<Vec<String>>>,
    /// Connection flag; starts `true` and is cleared by `disconnect`.
    connected: Arc<RwLock<bool>>,
}
impl MockWebSocketClient {
    /// Creates a connected mock with the given id and no recorded messages.
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        let messages = Arc::new(RwLock::new(Vec::new()));
        let connected = Arc::new(RwLock::new(true));
        Self {
            id,
            messages,
            connected,
        }
    }
    /// Returns a copy of all messages recorded so far.
    pub fn messages(&self) -> Vec<String> {
        let recorded = self.messages.read();
        recorded.to_vec()
    }
    /// Marks the mock as disconnected; subsequent sends fail.
    pub fn disconnect(&self) {
        let mut flag = self.connected.write();
        *flag = false;
    }
}
impl WebSocketClient for MockWebSocketClient {
    /// Records `message` while connected; errors once disconnected.
    fn send_message(&self, message: String) -> Result<(), SignalError> {
        if self.is_connected() {
            self.messages.write().push(message);
            Ok(())
        } else {
            Err(SignalError::new("Client is disconnected"))
        }
    }
    fn client_id(&self) -> &str {
        self.id.as_str()
    }
    fn is_connected(&self) -> bool {
        let flag = self.connected.read();
        *flag
    }
}
/// Fans messages out to a registry of [`WebSocketClient`]s keyed by client id.
///
/// The registry sits behind an `Arc`, so clones of a bridge share the same
/// set of clients.
pub struct WebSocketSignalBridge {
    clients: Arc<RwLock<HashMap<String, Arc<dyn WebSocketClient>>>>,
}
impl WebSocketSignalBridge {
    /// Creates a bridge with no registered clients.
    pub fn new() -> Self {
        Self {
            clients: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Registers `client` under its [`WebSocketClient::client_id`].
    ///
    /// A client registered earlier with the same id is replaced.
    pub fn add_client(&self, client: Arc<dyn WebSocketClient>) {
        self.clients
            .write()
            .insert(client.client_id().to_string(), client);
    }
    /// Unregisters the client with `client_id`; unknown ids are ignored.
    pub fn remove_client(&self, client_id: &str) {
        self.clients.write().remove(client_id);
    }
    /// Number of registered clients, whether connected or not.
    pub fn client_count(&self) -> usize {
        self.clients.read().len()
    }
    /// Sends `message` to every currently connected client.
    ///
    /// Disconnected clients are skipped. A failure on one client does not
    /// stop delivery to the others.
    ///
    /// # Errors
    ///
    /// Returns a [`SignalError`] if any connected client rejected the
    /// message. The error text lists each failing client id with its error.
    pub fn broadcast(&self, message: String) -> Result<(), SignalError> {
        let failures = Self::send_to_connected(&self.clients, &message);
        if failures.is_empty() {
            return Ok(());
        }
        // Previously the individual errors were collected and then dropped;
        // surface them so callers can tell which clients failed and why.
        let details = failures
            .iter()
            .map(|(client_id, e)| format!("{}: {}", client_id, e))
            .collect::<Vec<_>>()
            .join("; ");
        Err(SignalError::new(format!(
            "Failed to send to {} clients: {}",
            failures.len(),
            details
        )))
    }
    /// Forwards every emission of `signal` to all connected clients.
    ///
    /// Each emitted value is wrapped in a [`WebSocketMessage`] tagged with
    /// `event_type`, serialized to JSON, and sent to each connected client.
    /// Per-client delivery failures are logged to stderr and do not fail the
    /// handler. A serialization failure is returned from the handler as a
    /// [`SignalError`].
    pub async fn connect_signal<T>(&self, signal: Signal<T>, event_type: impl Into<String>)
    where
        T: Serialize + Send + Sync + 'static,
    {
        let clients = Arc::clone(&self.clients);
        let event_type = event_type.into();
        signal.connect(move |instance| {
            let clients = Arc::clone(&clients);
            let event_type = event_type.clone();
            async move {
                let message = WebSocketMessage::new(&event_type, &*instance);
                let json = serde_json::to_string(&message)
                    .map_err(|e| SignalError::new(format!("Serialization error: {}", e)))?;
                // Best-effort fan-out: report failures but keep the handler successful.
                for (_client_id, e) in Self::send_to_connected(&clients, &json) {
                    eprintln!("Failed to send WebSocket message: {}", e);
                }
                Ok(())
            }
        });
    }
    /// Removes every registered client that reports itself as disconnected.
    pub fn cleanup_disconnected(&self) {
        let mut clients = self.clients.write();
        clients.retain(|_, client| client.is_connected());
    }
    /// Sends `message` to each connected client in `clients`.
    ///
    /// Returns `(client_id, error)` for every client whose `send_message`
    /// failed; an empty vector means every connected client accepted it.
    fn send_to_connected(
        clients: &RwLock<HashMap<String, Arc<dyn WebSocketClient>>>,
        message: &str,
    ) -> Vec<(String, SignalError)> {
        // Copy the handles out and drop the read guard BEFORE calling client
        // code: parking_lot's RwLock is not reentrant, so a client whose
        // `send_message` calls back into the bridge (e.g. `remove_client`)
        // would otherwise deadlock.
        let snapshot: Vec<Arc<dyn WebSocketClient>> =
            clients.read().values().cloned().collect();
        snapshot
            .into_iter()
            .filter(|client| client.is_connected())
            .filter_map(|client| {
                client
                    .send_message(message.to_string())
                    .err()
                    .map(|e| (client.client_id().to_string(), e))
            })
            .collect()
    }
}
impl Default for WebSocketSignalBridge {
fn default() -> Self {
Self::new()
}
}
impl Clone for WebSocketSignalBridge {
    /// Returns a handle to the same client registry: only the `Arc` is
    /// cloned, so clients added or removed through either handle are seen by
    /// both.
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.clients);
        Self { clients: shared }
    }
}
impl fmt::Debug for WebSocketSignalBridge {
    /// Prints only the client count; the clients themselves are
    /// `dyn WebSocketClient` trait objects, which carry no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let count = self.client_count();
        let mut out = f.debug_struct("WebSocketSignalBridge");
        out.field("client_count", &count);
        out.finish()
    }
}
/// Broadcasts strongly-typed payloads through a [`WebSocketSignalBridge`].
///
/// Each payload is wrapped in a [`WebSocketMessage`] tagged with this
/// broadcaster's event type and serialized to JSON before being sent.
///
/// Trait bounds live on the `impl` rather than on the struct (Rust API
/// guideline C-STRUCT-BOUNDS), so naming the type does not force callers to
/// repeat them.
pub struct TypedWebSocketBroadcaster<T> {
    bridge: WebSocketSignalBridge,
    event_type: String,
    /// Ties the broadcaster to payload type `T` without storing a value.
    _phantom: PhantomData<T>,
}
impl<T> TypedWebSocketBroadcaster<T>
where
    T: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    /// Creates a broadcaster that sends through `bridge`, tagging every
    /// message with `event_type`.
    pub fn new(bridge: WebSocketSignalBridge, event_type: impl Into<String>) -> Self {
        Self {
            bridge,
            event_type: event_type.into(),
            _phantom: PhantomData,
        }
    }
    /// Serializes `payload` into a [`WebSocketMessage`] and broadcasts it to
    /// every connected client of the bridge.
    ///
    /// # Errors
    ///
    /// Returns a [`SignalError`] if JSON serialization fails, or whatever
    /// error [`WebSocketSignalBridge::broadcast`] reports for failed clients.
    pub fn broadcast(&self, payload: T) -> Result<(), SignalError> {
        let message = WebSocketMessage::new(&self.event_type, payload);
        let json = serde_json::to_string(&message)
            .map_err(|e| SignalError::new(format!("Serialization error: {}", e)))?;
        self.bridge.broadcast(json)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Payload used to check typed JSON round-tripping.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestPayload {
        message: String,
    }

    /// Builds a shareable mock client with the given id.
    fn mock_client(id: &str) -> Arc<MockWebSocketClient> {
        Arc::new(MockWebSocketClient::new(id))
    }

    #[test]
    fn test_websocket_message_creation() {
        let envelope = WebSocketMessage::new("test", "payload");
        assert_eq!(envelope.event_type, "test");
        assert_eq!(envelope.payload, "payload");
        assert!(envelope.timestamp > 0);
    }

    #[test]
    fn test_mock_websocket_client() {
        let mock = MockWebSocketClient::new("test-client");
        assert_eq!(mock.client_id(), "test-client");
        assert!(mock.is_connected());
        mock.send_message("Hello".to_string()).unwrap();
        let recorded = mock.messages();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0], "Hello");
    }

    #[test]
    fn test_mock_client_disconnect() {
        let mock = MockWebSocketClient::new("test");
        mock.disconnect();
        assert!(!mock.is_connected());
        assert!(mock.send_message("test".to_string()).is_err());
    }

    #[test]
    fn test_websocket_bridge_add_remove_client() {
        let bridge = WebSocketSignalBridge::new();
        let first = mock_client("client-1");
        bridge.add_client(first.clone());
        assert_eq!(bridge.client_count(), 1);
        bridge.remove_client("client-1");
        assert_eq!(bridge.client_count(), 0);
    }

    #[test]
    fn test_websocket_bridge_broadcast() {
        let bridge = WebSocketSignalBridge::new();
        let first = mock_client("client-1");
        let second = mock_client("client-2");
        bridge.add_client(first.clone());
        bridge.add_client(second.clone());
        bridge.broadcast("Test message".to_string()).unwrap();
        assert_eq!(first.messages().len(), 1);
        assert_eq!(second.messages().len(), 1);
        assert_eq!(first.messages()[0], "Test message");
    }

    #[tokio::test]
    async fn test_websocket_bridge_connect_signal() {
        let bridge = WebSocketSignalBridge::new();
        let listener = mock_client("client-1");
        bridge.add_client(listener.clone());
        let signal = Signal::new(crate::signals::SignalName::custom("test_signal"));
        bridge.connect_signal(signal.clone(), "test.event").await;
        signal.send("test payload".to_string()).await.unwrap();
        let received = listener.messages();
        assert_eq!(received.len(), 1);
        let decoded: WebSocketMessage<String> = serde_json::from_str(&received[0]).unwrap();
        assert_eq!(decoded.event_type, "test.event");
    }

    #[test]
    fn test_websocket_bridge_cleanup_disconnected() {
        let bridge = WebSocketSignalBridge::new();
        let first = mock_client("client-1");
        let second = mock_client("client-2");
        bridge.add_client(first.clone());
        bridge.add_client(second.clone());
        assert_eq!(bridge.client_count(), 2);
        first.disconnect();
        bridge.cleanup_disconnected();
        assert_eq!(bridge.client_count(), 1);
    }

    #[test]
    fn test_typed_websocket_broadcaster() {
        let bridge = WebSocketSignalBridge::new();
        let listener = mock_client("client-1");
        bridge.add_client(listener.clone());
        let broadcaster = TypedWebSocketBroadcaster::new(bridge, "typed.event");
        let expected = TestPayload {
            message: "Hello".to_string(),
        };
        broadcaster.broadcast(expected.clone()).unwrap();
        let received = listener.messages();
        assert_eq!(received.len(), 1);
        let decoded: WebSocketMessage<TestPayload> = serde_json::from_str(&received[0]).unwrap();
        assert_eq!(decoded.event_type, "typed.event");
        assert_eq!(decoded.payload, expected);
    }
}