use super::protocol::{ServerMessage, Subscription, SubscriptionFilter};
use crate::core::events::EventEnvelope;
use crate::server::host::ServerHost;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, broadcast, mpsc};
use uuid::Uuid;
/// Per-connection state stored in `ConnectionManager::connections`.
struct ConnectionHandle {
    /// Sending half of this connection's outbound channel; the receiving half is
    /// handed back to the caller of `ConnectionManager::connect`.
    tx: mpsc::UnboundedSender<ServerMessage>,
    /// Active subscriptions, kept in the order they were added (`push` on
    /// subscribe, order-preserving `retain` on unsubscribe).
    subscriptions: Vec<Subscription>,
    /// User bound via `ConnectionManager::associate_user`; `None` until then.
    user_id: Option<String>,
}
/// Registry of live WebSocket connections: each connection's outbound channel,
/// its event subscriptions and an optional associated user id.
pub struct ConnectionManager {
    /// Server host supplied at construction. No method in this file reads it
    /// (hence the leading underscore).
    _host: Arc<ServerHost>,
    /// Connection id (`conn_<uuid>`) -> per-connection state, behind an async
    /// `RwLock` so reads (dispatch, sends) can proceed concurrently.
    connections: RwLock<HashMap<String, ConnectionHandle>>,
}
impl std::fmt::Debug for ConnectionManager {
    /// Prints `ConnectionManager { .. }`; no fields are rendered, so the
    /// connection table behind the async lock is never touched.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("ConnectionManager");
        builder.finish_non_exhaustive()
    }
}
impl ConnectionManager {
pub fn new(host: Arc<ServerHost>) -> Self {
Self {
_host: host,
connections: RwLock::new(HashMap::new()),
}
}
pub async fn connect(&self) -> (String, mpsc::UnboundedReceiver<ServerMessage>) {
let connection_id = format!("conn_{}", Uuid::new_v4().simple());
let (tx, rx) = mpsc::unbounded_channel();
let handle = ConnectionHandle {
tx,
subscriptions: Vec::new(),
user_id: None,
};
self.connections
.write()
.await
.insert(connection_id.clone(), handle);
tracing::debug!(connection_id = %connection_id, "WebSocket client connected");
(connection_id, rx)
}
pub async fn disconnect(&self, connection_id: &str) {
self.connections.write().await.remove(connection_id);
tracing::debug!(connection_id = %connection_id, "WebSocket client disconnected");
}
pub async fn subscribe(
&self,
connection_id: &str,
filter: SubscriptionFilter,
) -> Result<String, String> {
let mut connections = self.connections.write().await;
let conn = connections
.get_mut(connection_id)
.ok_or_else(|| format!("Connection {} not found", connection_id))?;
let subscription = Subscription::new(filter);
let sub_id = subscription.id.clone();
conn.subscriptions.push(subscription);
tracing::debug!(
connection_id = %connection_id,
subscription_id = %sub_id,
"Subscription added"
);
Ok(sub_id)
}
pub async fn unsubscribe(
&self,
connection_id: &str,
subscription_id: &str,
) -> Result<bool, String> {
let mut connections = self.connections.write().await;
let conn = connections
.get_mut(connection_id)
.ok_or_else(|| format!("Connection {} not found", connection_id))?;
let before = conn.subscriptions.len();
conn.subscriptions.retain(|s| s.id != subscription_id);
let removed = conn.subscriptions.len() < before;
if removed {
tracing::debug!(
connection_id = %connection_id,
subscription_id = %subscription_id,
"Subscription removed"
);
}
Ok(removed)
}
pub async fn send_to(&self, connection_id: &str, message: ServerMessage) {
let connections = self.connections.read().await;
if let Some(conn) = connections.get(connection_id) {
let _ = conn.tx.send(message);
}
}
#[allow(dead_code)] pub async fn associate_user(&self, connection_id: &str, user_id: String) -> Result<(), String> {
let mut connections = self.connections.write().await;
let conn = connections
.get_mut(connection_id)
.ok_or_else(|| format!("Connection {} not found", connection_id))?;
tracing::debug!(
connection_id = %connection_id,
user_id = %user_id,
"User associated with WebSocket connection"
);
conn.user_id = Some(user_id);
Ok(())
}
pub async fn send_to_user(&self, user_id: &str, payload: serde_json::Value) -> usize {
let connections = self.connections.read().await;
let mut count = 0;
for handle in connections.values() {
if handle.user_id.as_deref() == Some(user_id) {
let msg = ServerMessage::Notification {
data: payload.clone(),
};
if handle.tx.send(msg).is_ok() {
count += 1;
}
}
}
if count > 0 {
tracing::debug!(
user_id = %user_id,
connections = count,
"Dispatched notification to user connections"
);
}
count
}
pub async fn broadcast_payload(&self, payload: serde_json::Value) -> usize {
let connections = self.connections.read().await;
let mut count = 0;
for handle in connections.values() {
let msg = ServerMessage::Notification {
data: payload.clone(),
};
if handle.tx.send(msg).is_ok() {
count += 1;
}
}
tracing::debug!(
connections = count,
"Broadcast notification to all connections"
);
count
}
async fn dispatch_event(&self, envelope: &EventEnvelope) {
let connections = self.connections.read().await;
for (connection_id, handle) in connections.iter() {
for subscription in &handle.subscriptions {
if subscription.filter.matches(&envelope.event) {
let message = ServerMessage::Event {
subscription_id: subscription.id.clone(),
data: envelope.clone(),
};
if handle.tx.send(message).is_err() {
tracing::debug!(
connection_id = %connection_id,
"Failed to send event to connection (likely disconnected)"
);
break; }
}
}
}
}
pub async fn run_dispatch_loop(&self, mut rx: broadcast::Receiver<EventEnvelope>) {
tracing::info!("WebSocket dispatch loop started");
loop {
match rx.recv().await {
Ok(envelope) => {
self.dispatch_event(&envelope).await;
}
Err(broadcast::error::RecvError::Lagged(count)) => {
tracing::warn!(
count = count,
"WebSocket dispatch loop lagged, {} events skipped",
count
);
}
Err(broadcast::error::RecvError::Closed) => {
tracing::info!("EventBus closed, stopping WebSocket dispatch loop");
break;
}
}
}
}
#[allow(dead_code)]
pub async fn connection_count(&self) -> usize {
self.connections.read().await.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::events::{EntityEvent, EventBus, FrameworkEvent};
    use serde_json::json;

    /// Builds a `ServerHost` from an in-memory link service, the default links
    /// config, an empty entity registry and two empty component maps.
    fn test_host() -> Arc<ServerHost> {
        use crate::config::LinksConfig;
        use crate::server::entity_registry::EntityRegistry;
        use crate::storage::InMemoryLinkService;
        use std::collections::HashMap;
        let host = ServerHost::from_builder_components(
            Arc::new(InMemoryLinkService::new()),
            LinksConfig::default_config(),
            EntityRegistry::new(),
            HashMap::new(),
            HashMap::new(),
        )
        .unwrap();
        Arc::new(host)
    }

    /// A manager with no connections, backed by `test_host()`.
    fn manager() -> ConnectionManager {
        ConnectionManager::new(test_host())
    }

    /// Subscription filter restricted to a single entity type.
    fn entity_filter(entity_type: &str) -> SubscriptionFilter {
        SubscriptionFilter {
            entity_type: Some(entity_type.to_string()),
            ..Default::default()
        }
    }

    /// `EntityEvent::Created` framework event for the given entity.
    fn created_event(
        entity_type: &str,
        entity_id: Uuid,
        data: serde_json::Value,
    ) -> FrameworkEvent {
        FrameworkEvent::Entity(EntityEvent::Created {
            entity_type: entity_type.to_string(),
            entity_id,
            data,
        })
    }

    /// Envelope around a `Created` event for a fresh random entity id.
    fn created_envelope(entity_type: &str, data: serde_json::Value) -> EventEnvelope {
        EventEnvelope::new(created_event(entity_type, Uuid::new_v4(), data))
    }

    #[tokio::test]
    async fn test_connect_and_disconnect() {
        let cm = manager();
        let (conn_id, _rx) = cm.connect().await;
        assert!(conn_id.starts_with("conn_"));
        assert_eq!(cm.connection_count().await, 1);
        cm.disconnect(&conn_id).await;
        assert_eq!(cm.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_subscribe_and_unsubscribe() {
        let cm = manager();
        let (conn_id, _rx) = cm.connect().await;
        let sub_id = cm.subscribe(&conn_id, entity_filter("order")).await.unwrap();
        assert!(sub_id.starts_with("sub_"));
        assert!(cm.unsubscribe(&conn_id, &sub_id).await.unwrap());
        // Removing the same id a second time is a no-op.
        assert!(!cm.unsubscribe(&conn_id, &sub_id).await.unwrap());
    }

    #[tokio::test]
    async fn test_subscribe_nonexistent_connection() {
        let cm = manager();
        let result = cm
            .subscribe("nonexistent", SubscriptionFilter::default())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unsubscribe_nonexistent_connection() {
        let cm = manager();
        let result = cm.unsubscribe("nonexistent", "sub_missing").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dispatch_event_matches() {
        let cm = manager();
        let (conn_id, mut rx) = cm.connect().await;
        let sub_id = cm.subscribe(&conn_id, entity_filter("order")).await.unwrap();
        let envelope = created_envelope("order", json!({"amount": 100}));
        cm.dispatch_event(&envelope).await;
        match rx.try_recv().unwrap() {
            ServerMessage::Event {
                subscription_id,
                data,
            } => {
                assert_eq!(subscription_id, sub_id);
                assert_eq!(data.id, envelope.id);
            }
            _ => panic!("Expected Event message"),
        }
    }

    #[tokio::test]
    async fn test_dispatch_event_no_match() {
        let cm = manager();
        let (conn_id, mut rx) = cm.connect().await;
        cm.subscribe(&conn_id, entity_filter("order")).await.unwrap();
        cm.dispatch_event(&created_envelope("invoice", json!({})))
            .await;
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_dispatch_with_event_bus() {
        let cm = Arc::new(manager());
        let (conn_id, mut rx) = cm.connect().await;
        cm.subscribe(&conn_id, SubscriptionFilter::default())
            .await
            .unwrap();
        let event_bus = EventBus::new(16);
        let bus_rx = event_bus.subscribe();
        let cm_clone = cm.clone();
        let handle = tokio::spawn(async move {
            cm_clone.run_dispatch_loop(bus_rx).await;
        });
        let entity_id = Uuid::new_v4();
        event_bus.publish(created_event("order", entity_id, json!({"test": true})));
        let msg = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("Timeout waiting for event")
            .expect("Channel closed");
        match msg {
            ServerMessage::Event { data, .. } => {
                assert_eq!(data.event.entity_id(), Some(entity_id));
            }
            _ => panic!("Expected Event message"),
        }
        drop(event_bus);
        let _ = tokio::time::timeout(std::time::Duration::from_secs(1), handle).await;
    }

    #[tokio::test]
    async fn test_multiple_subscriptions_same_connection() {
        let cm = manager();
        let (conn_id, mut rx) = cm.connect().await;
        cm.subscribe(&conn_id, entity_filter("order")).await.unwrap();
        cm.subscribe(&conn_id, entity_filter("invoice")).await.unwrap();

        cm.dispatch_event(&created_envelope("order", json!({}))).await;
        assert!(rx.try_recv().is_ok());

        cm.dispatch_event(&created_envelope("invoice", json!({})))
            .await;
        assert!(rx.try_recv().is_ok());

        cm.dispatch_event(&created_envelope("user", json!({}))).await;
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_concurrent_subscriptions_same_event_different_connections() {
        let cm = manager();
        let (conn1_id, mut rx1) = cm.connect().await;
        let (conn2_id, mut rx2) = cm.connect().await;
        let filter = SubscriptionFilter {
            entity_type: Some("order".to_string()),
            event_type: Some("created".to_string()),
            ..Default::default()
        };
        cm.subscribe(&conn1_id, filter.clone())
            .await
            .expect("conn1 subscribe should succeed");
        cm.subscribe(&conn2_id, filter)
            .await
            .expect("conn2 subscribe should succeed");
        let envelope = created_envelope("order", json!({"total": 50}));
        cm.dispatch_event(&envelope).await;
        let msg1 = rx1.try_recv().expect("conn1 should receive event");
        let msg2 = rx2.try_recv().expect("conn2 should receive event");
        match (&msg1, &msg2) {
            (ServerMessage::Event { data: d1, .. }, ServerMessage::Event { data: d2, .. }) => {
                assert_eq!(d1.id, envelope.id);
                assert_eq!(d2.id, envelope.id);
            }
            _ => panic!("Expected Event messages for both connections"),
        }
    }

    #[tokio::test]
    async fn test_send_to_nonexistent_connection() {
        let cm = manager();
        cm.send_to("conn_does_not_exist", ServerMessage::Pong).await;
        assert_eq!(cm.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_send_to_existing_connection() {
        let cm = manager();
        let (conn_id, mut rx) = cm.connect().await;
        cm.send_to(&conn_id, ServerMessage::Pong).await;
        assert!(matches!(rx.try_recv(), Ok(ServerMessage::Pong)));
    }

    #[tokio::test]
    async fn test_dead_connection_handling() {
        let cm = manager();
        let (conn_id, rx) = cm.connect().await;
        cm.subscribe(&conn_id, SubscriptionFilter::default())
            .await
            .expect("subscribe should succeed");
        drop(rx);
        cm.dispatch_event(&created_envelope("order", json!({}))).await;
        // A failed send must not panic and does not unregister the connection.
        assert_eq!(cm.connection_count().await, 1);
    }

    #[tokio::test]
    async fn test_dispatch_event_with_multiple_matching_subscriptions() {
        let cm = manager();
        let (conn_id, mut rx) = cm.connect().await;
        cm.subscribe(&conn_id, entity_filter("order"))
            .await
            .expect("first subscribe should succeed");
        cm.subscribe(
            &conn_id,
            SubscriptionFilter {
                event_type: Some("created".to_string()),
                ..Default::default()
            },
        )
        .await
        .expect("second subscribe should succeed");
        cm.dispatch_event(&created_envelope("order", json!({}))).await;
        let msg1 = rx.try_recv().expect("should receive first matching event");
        let msg2 = rx.try_recv().expect("should receive second matching event");
        match (&msg1, &msg2) {
            (
                ServerMessage::Event {
                    subscription_id: sub1,
                    data: d1,
                },
                ServerMessage::Event {
                    subscription_id: sub2,
                    data: d2,
                },
            ) => {
                assert_ne!(sub1, sub2, "subscription IDs should differ");
                assert_eq!(d1.id, d2.id, "both should carry the same event envelope");
            }
            _ => panic!("Expected two Event messages"),
        }
    }

    #[tokio::test]
    async fn test_associate_user() {
        let cm = manager();
        let (conn_id, mut rx) = cm.connect().await;
        cm.associate_user(&conn_id, "user-A".to_string())
            .await
            .expect("associate should succeed");
        // The binding must actually be visible to user-targeted delivery.
        assert_eq!(cm.send_to_user("user-A", json!({})).await, 1);
        assert!(matches!(
            rx.try_recv(),
            Ok(ServerMessage::Notification { .. })
        ));
        let result = cm
            .associate_user("conn_nonexistent", "user-B".to_string())
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_send_to_user() {
        let cm = manager();
        let (conn1_id, mut rx1) = cm.connect().await;
        let (conn2_id, mut rx2) = cm.connect().await;
        let (conn3_id, mut rx3) = cm.connect().await;
        for (conn_id, user) in [
            (&conn1_id, "user-A"),
            (&conn2_id, "user-A"),
            (&conn3_id, "user-B"),
        ] {
            cm.associate_user(conn_id, user.to_string()).await.unwrap();
        }
        let count = cm
            .send_to_user("user-A", json!({"title": "Hello user-A"}))
            .await;
        assert_eq!(count, 2);
        let msg1 = rx1.try_recv().expect("conn1 should receive");
        let msg2 = rx2.try_recv().expect("conn2 should receive");
        assert!(matches!(msg1, ServerMessage::Notification { .. }));
        assert!(matches!(msg2, ServerMessage::Notification { .. }));
        assert!(rx3.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_send_to_user_no_match() {
        let cm = manager();
        let (_conn_id, _rx) = cm.connect().await;
        assert_eq!(cm.send_to_user("user-X", json!({})).await, 0);
    }

    #[tokio::test]
    async fn test_broadcast_payload() {
        let cm = manager();
        let (_conn1_id, mut rx1) = cm.connect().await;
        let (_conn2_id, mut rx2) = cm.connect().await;
        let payload = json!({"type": "system_announcement", "message": "Server restarting"});
        let count = cm.broadcast_payload(payload.clone()).await;
        assert_eq!(count, 2);
        let msg1 = rx1.try_recv().expect("conn1 should receive broadcast");
        let msg2 = rx2.try_recv().expect("conn2 should receive broadcast");
        match (&msg1, &msg2) {
            (
                ServerMessage::Notification { data: d1 },
                ServerMessage::Notification { data: d2 },
            ) => {
                assert_eq!(d1["message"], "Server restarting");
                assert_eq!(d2["message"], "Server restarting");
            }
            _ => panic!("Expected Notification messages"),
        }
    }

    #[tokio::test]
    async fn test_broadcast_payload_empty() {
        let cm = manager();
        assert_eq!(cm.broadcast_payload(json!({})).await, 0);
    }
}