use super::channel::{ChannelId, ChannelManager};
use super::connection::WebSocketConnection;
use super::types::{ConnectionId, ConnectionState, WebSocketMessage, WebSocketResult};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info};
/// Notification delivered by [`ConnectionRegistry`] to every registered event
/// handler.
#[derive(Clone, Debug)]
pub enum ConnectionEvent {
    /// A connection was added to the registry.
    Connected(ConnectionId),
    /// A connection was removed from the registry, together with the state it
    /// reported at the moment of removal.
    Disconnected(ConnectionId, ConnectionState),
    /// A broadcast to all registered connections was attempted (emitted once
    /// per broadcast, regardless of individual send outcomes).
    Broadcast(WebSocketMessage),
    /// A message was successfully sent to a single connection.
    MessageSent(ConnectionId, WebSocketMessage),
}
/// Boxed callback invoked with a clone of every [`ConnectionEvent`] the
/// registry emits.
type EventHandler = Box<dyn Fn(ConnectionEvent) + Send + Sync>;

/// Registry of live WebSocket connections keyed by [`ConnectionId`].
///
/// Connections, the channel manager and the handler list are each held behind
/// an `Arc`; the mutable collections are guarded by tokio `RwLock`s.
pub struct ConnectionRegistry {
    /// Registered connections, keyed by their id.
    connections: Arc<RwLock<HashMap<ConnectionId, Arc<WebSocketConnection>>>>,
    /// Channel membership and channel fan-out.
    channel_manager: Arc<ChannelManager>,
    /// Callbacks notified of registry events, in registration order.
    event_handlers: Arc<RwLock<Vec<EventHandler>>>,
}
impl ConnectionRegistry {
    /// Creates an empty registry with its own, freshly constructed
    /// [`ChannelManager`].
    pub fn new() -> Self {
        Self::with_channel_manager(Arc::new(ChannelManager::new()))
    }

    /// Creates an empty registry that uses the given [`ChannelManager`].
    ///
    /// Because the manager is passed as an `Arc`, the caller may keep its own
    /// handle to the same channel state.
    pub fn with_channel_manager(channel_manager: Arc<ChannelManager>) -> Self {
        Self {
            connections: Arc::new(RwLock::new(HashMap::new())),
            channel_manager,
            event_handlers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Returns the channel manager used for channel membership and routing.
    pub fn channel_manager(&self) -> &Arc<ChannelManager> {
        &self.channel_manager
    }

    /// Registers `connection` under its own id and emits
    /// [`ConnectionEvent::Connected`].
    ///
    /// If a connection with the same id is already registered, the old entry
    /// is replaced without channel cleanup or a `Disconnected` event.
    pub async fn add_connection(&self, connection: WebSocketConnection) -> ConnectionId {
        let id = connection.id;
        {
            // Keep the write lock scoped to the insert only.
            let mut connections = self.connections.write().await;
            connections.insert(id, Arc::new(connection));
        }
        info!("Added connection to registry: {}", id);
        self.emit_event(ConnectionEvent::Connected(id)).await;
        id
    }

    /// Removes the connection registered under `id`.
    ///
    /// On removal the connection is detached from every channel and
    /// [`ConnectionEvent::Disconnected`] is emitted with the state it reported.
    /// The connection itself is not closed.
    ///
    /// Returns the removed connection, or `None` (with no other effect) if `id`
    /// was not registered.
    pub async fn remove_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
        let connection = {
            let mut connections = self.connections.write().await;
            connections.remove(&id)
        }?;
        let state = connection.state().await;
        self.channel_manager.leave_all_channels(id).await;
        info!(
            "Removed connection from registry: {} (state: {:?})",
            id, state
        );
        self.emit_event(ConnectionEvent::Disconnected(id, state))
            .await;
        Some(connection)
    }

    /// Returns the connection registered under `id`, if any.
    pub async fn get_connection(&self, id: ConnectionId) -> Option<Arc<WebSocketConnection>> {
        let connections = self.connections.read().await;
        connections.get(&id).cloned()
    }

    /// Returns a snapshot of all registered connections (in map order).
    pub async fn get_all_connections(&self) -> Vec<Arc<WebSocketConnection>> {
        let connections = self.connections.read().await;
        connections.values().cloned().collect()
    }

    /// Returns a snapshot of all registered connection ids (in map order).
    pub async fn get_connection_ids(&self) -> Vec<ConnectionId> {
        let connections = self.connections.read().await;
        connections.keys().copied().collect()
    }

    /// Returns the number of registered connections.
    pub async fn connection_count(&self) -> usize {
        let connections = self.connections.read().await;
        connections.len()
    }

    /// Sends `message` to the connection registered under `id`.
    ///
    /// Emits [`ConnectionEvent::MessageSent`] only when the send succeeds.
    ///
    /// # Errors
    ///
    /// Returns `WebSocketError::ConnectionNotFound` if `id` is not registered,
    /// or the error produced by the connection's `send`.
    pub async fn send_to_connection(
        &self,
        id: ConnectionId,
        message: WebSocketMessage,
    ) -> WebSocketResult<()> {
        let connection = self
            .get_connection(id)
            .await
            .ok_or(WebSocketError::ConnectionNotFound(id))?;
        let result = connection.send(message.clone()).await;
        if result.is_ok() {
            self.emit_event(ConnectionEvent::MessageSent(id, message))
                .await;
        }
        result
    }

    /// Sends a text message to one connection; see [`Self::send_to_connection`].
    pub async fn send_text_to_connection<T: Into<String>>(
        &self,
        id: ConnectionId,
        text: T,
    ) -> WebSocketResult<()> {
        self.send_to_connection(id, WebSocketMessage::text(text))
            .await
    }

    /// Sends a binary message to one connection; see [`Self::send_to_connection`].
    pub async fn send_binary_to_connection<T: Into<Vec<u8>>>(
        &self,
        id: ConnectionId,
        data: T,
    ) -> WebSocketResult<()> {
        self.send_to_connection(id, WebSocketMessage::binary(data))
            .await
    }

    /// Sends `message` to every registered connection that is currently active.
    ///
    /// Inactive connections are skipped and listed in
    /// [`BroadcastResult::inactive_connections`]; send errors are collected in
    /// [`BroadcastResult::failed_connections`]. A single
    /// [`ConnectionEvent::Broadcast`] is emitted afterwards regardless of the
    /// individual outcomes.
    pub async fn broadcast(&self, message: WebSocketMessage) -> BroadcastResult {
        let mut results = BroadcastResult::new();
        for connection in self.get_all_connections().await {
            Self::deliver_to(connection.id, &connection, &message, &mut results).await;
        }
        self.emit_event(ConnectionEvent::Broadcast(message)).await;
        results
    }

    /// Broadcasts a text message; see [`Self::broadcast`].
    pub async fn broadcast_text<T: Into<String>>(&self, text: T) -> BroadcastResult {
        self.broadcast(WebSocketMessage::text(text)).await
    }

    /// Broadcasts a binary message; see [`Self::broadcast`].
    pub async fn broadcast_binary<T: Into<Vec<u8>>>(&self, data: T) -> BroadcastResult {
        self.broadcast(WebSocketMessage::binary(data)).await
    }

    /// Sends `message` from `sender_id` to the members of `channel_id`.
    ///
    /// The channel manager decides which member ids receive the message. Each
    /// of those that is registered here and active is sent a copy. Member ids
    /// that are not registered in this registry are pruned from the channel.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the channel manager's
    /// `send_to_channel`; per-member send failures are reported in the
    /// returned [`BroadcastResult`] instead.
    pub async fn send_to_channel(
        &self,
        channel_id: ChannelId,
        sender_id: ConnectionId,
        message: WebSocketMessage,
    ) -> WebSocketResult<BroadcastResult> {
        let member_ids = self
            .channel_manager
            .send_to_channel(channel_id, sender_id, message.clone())
            .await?;
        let mut results = BroadcastResult::new();
        for member_id in member_ids {
            match self.get_connection(member_id).await {
                Some(connection) => {
                    Self::deliver_to(member_id, &connection, &message, &mut results).await;
                }
                None => {
                    // Stale membership: this id is not registered here. Prune it
                    // on a best-effort basis; a failure is deliberately ignored.
                    let _ = self
                        .channel_manager
                        .leave_channel(channel_id, member_id)
                        .await;
                }
            }
        }
        Ok(results)
    }

    /// Sends a text message to a channel; see [`Self::send_to_channel`].
    pub async fn send_text_to_channel<T: Into<String>>(
        &self,
        channel_id: ChannelId,
        sender_id: ConnectionId,
        text: T,
    ) -> WebSocketResult<BroadcastResult> {
        self.send_to_channel(channel_id, sender_id, WebSocketMessage::text(text))
            .await
    }

    /// Sends a binary message to a channel; see [`Self::send_to_channel`].
    pub async fn send_binary_to_channel<T: Into<Vec<u8>>>(
        &self,
        channel_id: ChannelId,
        sender_id: ConnectionId,
        data: T,
    ) -> WebSocketResult<BroadcastResult> {
        self.send_to_channel(channel_id, sender_id, WebSocketMessage::binary(data))
            .await
    }

    /// Closes the connection registered under `id` and then removes it via
    /// [`Self::remove_connection`].
    ///
    /// # Errors
    ///
    /// Returns `WebSocketError::ConnectionNotFound` if `id` is not registered,
    /// or the error from closing the connection. In the latter case the
    /// connection stays registered.
    pub async fn close_connection(&self, id: ConnectionId) -> WebSocketResult<()> {
        let connection = self
            .get_connection(id)
            .await
            .ok_or(WebSocketError::ConnectionNotFound(id))?;
        connection.close().await?;
        self.remove_connection(id).await;
        Ok(())
    }

    /// Closes every registered connection.
    ///
    /// Each connection that closes successfully is removed through
    /// [`Self::remove_connection`], exactly as [`Self::close_connection`]
    /// does. That leaves its channels and emits `Disconnected`. Connections
    /// whose close fails stay registered and are reported in
    /// [`CloseAllResult::failed_connections`].
    pub async fn close_all_connections(&self) -> CloseAllResult {
        let mut results = CloseAllResult::new();
        for connection in self.get_all_connections().await {
            let id = connection.id;
            match connection.close().await {
                Ok(_) => {
                    results.closed_count += 1;
                    // Going through remove_connection (rather than editing the
                    // map directly) keeps channel membership and event
                    // handlers consistent, and avoids awaiting under the write lock.
                    self.remove_connection(id).await;
                }
                Err(e) => results.failed_connections.push((id, e)),
            }
        }
        results
    }

    /// Removes every registered connection that reports itself closed.
    ///
    /// Removal goes through [`Self::remove_connection`], so channel membership
    /// is cleaned up and `Disconnected` is emitted for each. Returns the number
    /// of connections actually removed by this call. Entries that were removed
    /// concurrently by someone else are not counted.
    pub async fn cleanup_inactive_connections(&self) -> usize {
        let mut cleaned_up = 0;
        for connection in self.get_all_connections().await {
            if !connection.is_closed().await {
                continue;
            }
            let id = connection.id;
            if self.remove_connection(id).await.is_some() {
                debug!("Cleaned up inactive connection: {}", id);
                cleaned_up += 1;
            }
        }
        if cleaned_up > 0 {
            info!("Cleaned up {} inactive connections", cleaned_up);
        }
        cleaned_up
    }

    /// Aggregates per-state counts and message/byte totals over all registered
    /// connections.
    ///
    /// Values are read connection by connection from a snapshot of the
    /// registry, so the result is not atomic with respect to concurrent
    /// changes.
    pub async fn stats(&self) -> RegistryStats {
        let connections = self.get_all_connections().await;
        let mut stats = RegistryStats {
            total_connections: connections.len(),
            ..RegistryStats::default()
        };
        for connection in connections {
            match connection.state().await {
                ConnectionState::Connected => stats.active_connections += 1,
                ConnectionState::Connecting => stats.connecting_connections += 1,
                ConnectionState::Closing => stats.closing_connections += 1,
                ConnectionState::Closed => stats.closed_connections += 1,
                ConnectionState::Failed(_) => stats.failed_connections += 1,
            }
            let conn_stats = connection.stats().await;
            stats.total_messages_sent += conn_stats.messages_sent;
            stats.total_messages_received += conn_stats.messages_received;
            stats.total_bytes_sent += conn_stats.bytes_sent;
            stats.total_bytes_received += conn_stats.bytes_received;
        }
        stats
    }

    /// Registers `handler` to be called for every subsequent [`ConnectionEvent`].
    pub async fn add_event_handler<F>(&self, handler: F)
    where
        F: Fn(ConnectionEvent) + Send + Sync + 'static,
    {
        let mut handlers = self.event_handlers.write().await;
        handlers.push(Box::new(handler));
    }

    /// Sends a copy of `message` to `connection` if it is active, and records
    /// the outcome under `id` in `results`.
    async fn deliver_to(
        id: ConnectionId,
        connection: &WebSocketConnection,
        message: &WebSocketMessage,
        results: &mut BroadcastResult,
    ) {
        if !connection.is_active().await {
            results.inactive_connections.push(id);
            return;
        }
        match connection.send(message.clone()).await {
            Ok(_) => results.success_count += 1,
            Err(e) => results.failed_connections.push((id, e)),
        }
    }

    /// Invokes every registered handler with its own clone of `event`.
    ///
    /// Handlers run synchronously, in registration order, while the handler
    /// list is read-locked. A slow handler therefore delays the caller.
    async fn emit_event(&self, event: ConnectionEvent) {
        let handlers = self.event_handlers.read().await;
        for handler in handlers.iter() {
            handler(event.clone());
        }
    }
}
impl Default for ConnectionRegistry {
fn default() -> Self {
Self::new()
}
}
/// Per-recipient outcome of a fan-out send such as
/// [`ConnectionRegistry::broadcast`] or [`ConnectionRegistry::send_to_channel`].
#[derive(Debug, Default)]
pub struct BroadcastResult {
    /// Number of recipients the message was sent to successfully.
    pub success_count: usize,
    /// Recipients whose send returned an error, paired with that error.
    pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
    /// Recipients that were skipped because they were not active.
    pub inactive_connections: Vec<ConnectionId>,
}

impl BroadcastResult {
    /// Creates an empty result: zero successes, no failures, nothing skipped.
    fn new() -> Self {
        Self::default()
    }

    /// Total recipients considered: successes plus failures plus skipped
    /// inactive connections.
    pub fn total_attempted(&self) -> usize {
        self.success_count + self.failed_connections.len() + self.inactive_connections.len()
    }

    /// Returns `true` if at least one send returned an error.
    ///
    /// Skipped (inactive) connections are not counted as failures.
    pub fn has_failures(&self) -> bool {
        !self.failed_connections.is_empty()
    }
}
/// Outcome of [`ConnectionRegistry::close_all_connections`].
#[derive(Debug, Default)]
pub struct CloseAllResult {
    /// Number of connections whose close succeeded; these are removed from
    /// the registry.
    pub closed_count: usize,
    /// Connections whose close returned an error, paired with that error;
    /// these remain registered.
    pub failed_connections: Vec<(ConnectionId, WebSocketError)>,
}

impl CloseAllResult {
    /// Creates an empty result with nothing closed and no failures.
    fn new() -> Self {
        Self::default()
    }
}
/// Aggregate snapshot produced by [`ConnectionRegistry::stats`].
///
/// Derives `Clone`/`PartialEq`/`Eq` so snapshots can be stored and compared
/// (e.g. to detect change between two polls).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct RegistryStats {
    /// Number of connections in the registry when the snapshot was taken.
    pub total_connections: usize,
    /// Connections in the `Connected` state.
    pub active_connections: usize,
    /// Connections in the `Connecting` state.
    pub connecting_connections: usize,
    /// Connections in the `Closing` state.
    pub closing_connections: usize,
    /// Connections in the `Closed` state.
    pub closed_connections: usize,
    /// Connections in a `Failed(_)` state.
    pub failed_connections: usize,
    /// Sum of `messages_sent` over all connections.
    pub total_messages_sent: u64,
    /// Sum of `messages_received` over all connections.
    pub total_messages_received: u64,
    /// Sum of `bytes_sent` over all connections.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` over all connections.
    pub total_bytes_received: u64,
}
use super::types::WebSocketError;