use crate::domain::entities::Event;
use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::{sink::SinkExt, stream::StreamExt};
use std::sync::Arc;
use tokio::sync::broadcast;
use uuid::Uuid;
/// Fan-out hub that pushes domain [`Event`]s to connected WebSocket clients.
///
/// Every connection subscribes to one shared broadcast channel and has an entry
/// in the client registry that stores its current event filters.
pub struct WebSocketManager {
    /// Registry of live connections, keyed by the id generated per connection.
    clients: Arc<DashMap<Uuid, ClientInfo>>,
    /// Sending half of the broadcast channel that all connections subscribe to.
    event_tx: broadcast::Sender<Arc<Event>>,
}
/// Per-connection state kept in the manager's client registry.
#[derive(Debug, Clone)]
struct ClientInfo {
    /// Connection id, equal to the registry key it is stored under.
    ///
    /// Nothing in this module reads it. Derived `Debug`/`Clone` impls do not
    /// count as uses for dead-code analysis, so the lint is silenced explicitly.
    /// The field is kept so `{:?}` output identifies the connection.
    #[allow(dead_code)]
    id: Uuid,
    /// Filters applied to outgoing events; replaced when the client sends new ones.
    filters: EventFilters,
}
/// Optional per-client event filters; a `None` field matches every event.
///
/// Clients set these by sending a JSON text frame, e.g.
/// `{"entity_id": "<id>", "event_type": "<type>"}`. Fields omitted from the
/// JSON deserialize as `None`, so `{}` clears all filters.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EventFilters {
    /// When set, only events whose entity id equals this value are forwarded.
    pub entity_id: Option<String>,
    /// When set, only events whose event type equals this value are forwarded.
    pub event_type: Option<String>,
}
impl WebSocketManager {
pub fn new() -> Self {
let (event_tx, _) = broadcast::channel(1000);
Self {
event_tx,
clients: Arc::new(DashMap::new()),
}
}
pub fn broadcast_event(&self, event: Arc<Event>) {
let _ = self.event_tx.send(event);
}
pub fn subscribe_events(&self) -> broadcast::Receiver<Arc<Event>> {
self.event_tx.subscribe()
}
pub async fn handle_socket(&self, socket: WebSocket) {
let client_id = Uuid::new_v4();
tracing::info!("🔌 WebSocket client connected: {}", client_id);
let mut event_rx = self.event_tx.subscribe();
let (mut sender, mut receiver) = socket.split();
self.clients.insert(
client_id,
ClientInfo {
id: client_id,
filters: EventFilters::default(),
},
);
let clients = Arc::clone(&self.clients);
let send_task = tokio::spawn(async move {
while let Ok(event) = event_rx.recv().await {
let filters = clients
.get(&client_id)
.map(|entry| entry.value().filters.clone())
.unwrap_or_default();
if let Some(ref entity_id) = filters.entity_id
&& event.entity_id_str() != entity_id
{
continue;
}
if let Some(ref event_type) = filters.event_type
&& event.event_type_str() != event_type
{
continue;
}
match serde_json::to_string(&*event) {
Ok(json) => {
if sender.send(Message::Text(json.into())).await.is_err() {
tracing::warn!("Failed to send event to client {}", client_id);
break;
}
}
Err(e) => {
tracing::error!("Failed to serialize event: {}", e);
}
}
}
});
let clients = Arc::clone(&self.clients);
let recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
if let Message::Text(text) = msg {
if let Ok(filters) = serde_json::from_str::<EventFilters>(text.as_str()) {
tracing::info!("Setting filters for client {}: {:?}", client_id, filters);
if let Some(mut client) = clients.get_mut(&client_id) {
client.filters = filters;
}
}
}
}
});
tokio::select! {
_ = send_task => {
tracing::info!("Send task ended for client {}", client_id);
}
_ = recv_task => {
tracing::info!("Receive task ended for client {}", client_id);
}
}
self.clients.remove(&client_id);
tracing::info!("🔌 WebSocket client disconnected: {}", client_id);
}
pub fn stats(&self) -> WebSocketStats {
WebSocketStats {
connected_clients: self.clients.len(),
total_capacity: self.event_tx.receiver_count(),
}
}
}
impl Default for WebSocketManager {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time counters returned by `WebSocketManager::stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
pub struct WebSocketStats {
    /// Number of connections currently present in the client registry.
    pub connected_clients: usize,
    /// Number of live receivers on the event broadcast channel.
    ///
    /// Despite the name, this is not a capacity: it counts every active
    /// subscription, including ones handed out by `subscribe_events`. The name
    /// is kept so the serialized field stays backward compatible.
    pub total_capacity: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal event fixture with a fresh id and fixed string fields.
    fn create_test_event() -> Event {
        Event::reconstruct_from_strings(
            Uuid::new_v4(),
            "test.event".to_string(),
            "test-entity".to_string(),
            "default".to_string(),
            json!({"test": "data"}),
            chrono::Utc::now(),
            None,
            1,
        )
    }

    #[test]
    fn test_websocket_manager_creation() {
        let manager = WebSocketManager::new();
        let stats = manager.stats();
        assert_eq!(stats.connected_clients, 0);
        // The channel's initial receiver is dropped during construction.
        assert_eq!(stats.total_capacity, 0);
    }

    #[test]
    fn test_broadcast_without_subscribers_is_noop() {
        let manager = WebSocketManager::new();
        manager.broadcast_event(Arc::new(create_test_event()));
        assert_eq!(manager.stats().total_capacity, 0);
    }

    #[test]
    fn test_event_broadcast() {
        let manager = WebSocketManager::new();
        let mut rx = manager.subscribe_events();
        assert_eq!(manager.stats().total_capacity, 1);

        let event = Arc::new(create_test_event());
        manager.broadcast_event(Arc::clone(&event));

        let received = rx
            .try_recv()
            .expect("subscriber should receive the broadcast event");
        assert!(Arc::ptr_eq(&received, &event));
    }
}