use std::collections::HashMap;
use std::sync::Arc;

use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tracing::{debug, info, warn};

use crate::client::{create_client, Client, ClientId, ClientReceiver};
use crate::error::RealtimeError;
use crate::event::{EventKind, RealtimeEvent};
use crate::subscription::Channel;
/// Number of slots in the hub's internal `tokio` broadcast channel.
const BROADCAST_CAPACITY: usize = 1_024;
/// Maximum number of simultaneously registered clients; `EventHub::connect`
/// refuses new clients once this many are connected.
const MAX_CONNECTIONS: usize = 10_000;
/// Registry of connected realtime clients and fan-out point for events.
///
/// Clients are stored in a lock-protected map keyed by their id. Every event passed
/// to `emit` is delivered to matching clients and also forwarded to an internal
/// `tokio` broadcast channel.
#[derive(Debug)]
pub struct EventHub {
    /// Currently connected clients, keyed by client id.
    clients: RwLock<HashMap<ClientId, Arc<Client>>>,
    /// Sending half of the internal broadcast channel that mirrors emitted events.
    event_tx: broadcast::Sender<RealtimeEvent>,
    /// Cumulative counters exposed through `EventHub::stats`.
    stats: RwLock<HubStats>,
}
impl EventHub {
pub fn new() -> Self {
let (event_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
Self {
clients: RwLock::new(HashMap::new()),
event_tx,
stats: RwLock::new(HubStats::default()),
}
}
pub fn connect(&self) -> Result<(Arc<Client>, ClientReceiver), RealtimeError> {
let clients = self.clients.read();
if clients.len() >= MAX_CONNECTIONS {
return Err(RealtimeError::SendFailed(
"maximum connections reached".to_string(),
));
}
drop(clients);
let client_id = uuid::Uuid::new_v4().to_string();
let (client, receiver) = create_client(client_id.clone());
self.clients
.write()
.insert(client_id.clone(), client.clone());
self.stats.write().total_connections += 1;
info!(client_id = %client_id, "Client connected");
Ok((client, receiver))
}
pub fn disconnect(&self, client_id: &str) {
if let Some(client) = self.clients.write().remove(client_id) {
client.clear_subscriptions();
info!(client_id = %client_id, "Client disconnected");
}
}
pub fn get_client(&self, client_id: &str) -> Option<Arc<Client>> {
self.clients.read().get(client_id).cloned()
}
pub fn handle_command(
&self,
client: &Arc<Client>,
command: ClientCommand,
) -> Result<ServerMessage, RealtimeError> {
match command {
ClientCommand::Subscribe { channel } => {
let parsed = Channel::parse(&channel)?;
let is_new = client.subscribe(parsed)?;
if is_new {
debug!(client_id = %client.id, channel = %channel, "Client subscribed");
self.stats.write().total_subscriptions += 1;
}
Ok(ServerMessage::Subscribed { channel })
}
ClientCommand::Unsubscribe { channel } => {
let parsed = Channel::parse(&channel)?;
let was_subscribed = client.unsubscribe(&parsed);
if was_subscribed {
debug!(client_id = %client.id, channel = %channel, "Client unsubscribed");
}
Ok(ServerMessage::Unsubscribed { channel })
}
ClientCommand::Ping => Ok(ServerMessage::Pong),
}
}
pub fn emit(&self, event: RealtimeEvent) {
let channel = event.channel.clone();
let event_kind = event.event;
let mut recipient_count = 0;
let clients = self.clients.read();
for client in clients.values() {
if client.matches_event(&channel) {
if let Ok(json) = serde_json::to_string(&event) {
if client.send(json).is_ok() {
recipient_count += 1;
}
}
}
}
drop(clients);
let _ = self.event_tx.send(event);
self.stats.write().total_events += 1;
debug!(
channel = %channel,
event = %event_kind,
recipients = recipient_count,
"Event broadcast"
);
}
pub fn emit_event(&self, channel: String, event: EventKind, data: serde_json::Value) {
self.emit(RealtimeEvent::new(channel, event, data));
}
pub fn subscribe_events(&self) -> broadcast::Receiver<RealtimeEvent> {
self.event_tx.subscribe()
}
pub fn connection_count(&self) -> usize {
self.clients.read().len()
}
pub fn stats(&self) -> HubStats {
let mut stats = self.stats.read().clone();
stats.current_connections = self.connection_count();
stats
}
pub fn broadcast_all(&self, message: &str) {
let clients = self.clients.read();
for client in clients.values() {
let _ = client.send(message.to_string());
}
}
}
impl Default for EventHub {
fn default() -> Self {
Self::new()
}
}
/// Command sent by a client over the realtime connection.
///
/// Encoded as an internally tagged JSON object whose `"type"` field holds the
/// snake_case variant name, e.g. `{"type":"subscribe","channel":"repo:alice/myrepo"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientCommand {
    /// Start receiving events for `channel`.
    Subscribe { channel: String },
    /// Stop receiving events for `channel`.
    Unsubscribe { channel: String },
    /// Liveness probe; the hub answers with `ServerMessage::Pong`.
    Ping,
}
/// Reply sent from the hub to a client.
///
/// Uses the same internally tagged encoding as client commands: a `"type"` field
/// carrying the snake_case variant name plus the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerMessage {
    /// Acknowledges a subscribe command for `channel`.
    Subscribed { channel: String },
    /// Acknowledges an unsubscribe command for `channel`.
    Unsubscribed { channel: String },
    /// Answer to a ping.
    Pong,
    /// Error report carrying a message for the client.
    Error { message: String },
}
/// Snapshot of hub counters as returned by `EventHub::stats`.
#[derive(Debug, Clone, Default)]
pub struct HubStats {
    /// Clients registered at the moment the snapshot was taken.
    pub current_connections: usize,
    /// Successful `connect` calls since the hub was created.
    pub total_connections: u64,
    /// Subscribe commands that `Client::subscribe` reported as new.
    pub total_subscriptions: u64,
    /// Events passed to `emit` since the hub was created.
    pub total_events: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    const ALICE_REPO: &str = "repo:alice/myrepo";
    const BOB_REPO: &str = "repo:bob/otherrepo";

    /// Issues a `Subscribe` command for `channel` on behalf of `client` and returns the
    /// hub's reply, panicking if the command is rejected.
    fn subscribe(hub: &EventHub, client: &Arc<Client>, channel: &str) -> ServerMessage {
        let command = ClientCommand::Subscribe {
            channel: channel.to_string(),
        };
        hub.handle_command(client, command).unwrap()
    }

    #[tokio::test]
    async fn test_hub_connect() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        assert!(!client.id.is_empty());
        assert_eq!(hub.connection_count(), 1);
    }

    #[tokio::test]
    async fn test_hub_disconnect() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        hub.disconnect(&client.id);

        assert_eq!(hub.connection_count(), 0);
    }

    #[tokio::test]
    async fn test_hub_subscribe_command() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        let response = subscribe(&hub, &client, ALICE_REPO);

        assert!(matches!(response, ServerMessage::Subscribed { .. }));
        assert_eq!(client.subscription_count(), 1);
    }

    #[tokio::test]
    async fn test_hub_unsubscribe_command() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();
        subscribe(&hub, &client, ALICE_REPO);

        let command = ClientCommand::Unsubscribe {
            channel: ALICE_REPO.to_string(),
        };
        let response = hub.handle_command(&client, command).unwrap();

        assert!(matches!(response, ServerMessage::Unsubscribed { .. }));
        assert_eq!(client.subscription_count(), 0);
    }

    #[tokio::test]
    async fn test_hub_ping_pong() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();

        let response = hub.handle_command(&client, ClientCommand::Ping).unwrap();

        assert!(matches!(response, ServerMessage::Pong));
    }

    #[tokio::test]
    async fn test_hub_emit_event() {
        let hub = EventHub::new();
        let (client, mut rx) = hub.connect().unwrap();
        subscribe(&hub, &client, ALICE_REPO);

        hub.emit_event(
            ALICE_REPO.to_string(),
            EventKind::Push,
            serde_json::json!({"ref": "refs/heads/main"}),
        );

        let msg = rx.try_recv().unwrap();
        assert!(msg.contains("push"));
        assert!(msg.contains(ALICE_REPO));
    }

    #[tokio::test]
    async fn test_hub_emit_filtered() {
        let hub = EventHub::new();
        let (client1, mut rx1) = hub.connect().unwrap();
        let (client2, mut rx2) = hub.connect().unwrap();
        subscribe(&hub, &client1, ALICE_REPO);
        subscribe(&hub, &client2, BOB_REPO);

        hub.emit_event(
            ALICE_REPO.to_string(),
            EventKind::Push,
            serde_json::json!({}),
        );

        assert!(rx1.try_recv().is_ok());
        assert!(rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_hub_stats() {
        let hub = EventHub::new();
        let (client, _rx) = hub.connect().unwrap();
        subscribe(&hub, &client, ALICE_REPO);
        hub.emit_event(
            ALICE_REPO.to_string(),
            EventKind::Push,
            serde_json::json!({}),
        );

        let stats = hub.stats();

        assert_eq!(stats.current_connections, 1);
        assert_eq!(stats.total_connections, 1);
        assert_eq!(stats.total_subscriptions, 1);
        assert_eq!(stats.total_events, 1);
    }

    #[test]
    fn test_client_command_serialization() {
        let cmd = ClientCommand::Subscribe {
            channel: ALICE_REPO.to_string(),
        };

        let json = serde_json::to_string(&cmd).unwrap();
        assert!(json.contains("subscribe"));
        assert!(json.contains(ALICE_REPO));

        let parsed: ClientCommand = serde_json::from_str(&json).unwrap();
        assert!(matches!(parsed, ClientCommand::Subscribe { .. }));
    }

    #[test]
    fn test_server_message_serialization() {
        let subscribed = ServerMessage::Subscribed {
            channel: "repo:test/repo".to_string(),
        };
        let json = serde_json::to_string(&subscribed).unwrap();
        assert!(json.contains("subscribed"));

        let json = serde_json::to_string(&ServerMessage::Pong).unwrap();
        assert!(json.contains("pong"));
    }
}