use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use super::connection::ConnectionId;
/// A single message to be fanned out by [`SseBroadcaster`].
#[derive(Clone, Debug)]
pub struct BroadcastMessage {
    /// Optional event name; `None` for an unnamed message.
    pub event_type: Option<String>,
    /// Message payload.
    pub data: String,
    /// Optional event identifier, normally set through [`BroadcastMessage::with_id`].
    pub id: Option<String>,
}

impl BroadcastMessage {
    /// Builds an unnamed message carrying `data`, with no id.
    #[must_use]
    pub fn new(data: impl Into<String>) -> Self {
        let data = data.into();
        Self {
            event_type: None,
            data,
            id: None,
        }
    }

    /// Builds a message tagged with `event_type` and carrying `data`, with no id.
    #[must_use]
    pub fn named(event_type: impl Into<String>, data: impl Into<String>) -> Self {
        let event_type = Some(event_type.into());
        let data = data.into();
        Self {
            event_type,
            data,
            id: None,
        }
    }

    /// Builds an unnamed message whose payload is `data` serialized to a JSON string.
    ///
    /// # Errors
    ///
    /// Returns the [`serde_json::Error`] reported when `data` cannot be serialized.
    pub fn json<T: serde::Serialize>(data: &T) -> Result<Self, serde_json::Error> {
        serde_json::to_string(data).map(Self::new)
    }

    /// Builds a message tagged with `event_type` whose payload is `data`
    /// serialized to a JSON string.
    ///
    /// # Errors
    ///
    /// Returns the [`serde_json::Error`] reported when `data` cannot be serialized.
    pub fn json_named<T: serde::Serialize>(
        event_type: impl Into<String>,
        data: &T,
    ) -> Result<Self, serde_json::Error> {
        let event_type = event_type.into();
        let payload = serde_json::to_string(data)?;
        Ok(Self::named(event_type, payload))
    }

    /// Returns this message with its id set to `id`, replacing any previous id.
    #[must_use]
    pub fn with_id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }
}
/// Selects the recipients of a broadcast.
///
/// NOTE(review): nothing in this module consumes this enum yet; the variant
/// descriptions below follow the variant names — confirm against the caller.
#[derive(Clone, Debug)]
pub enum BroadcastTarget {
    /// Presumably every registered connection.
    All,
    /// Presumably only the listed connections.
    Connections(Vec<ConnectionId>),
    /// Presumably every connection except the listed ones.
    AllExcept(Vec<ConnectionId>),
    /// Presumably the subscribers of the named channel.
    Channel(String),
}
/// Per-connection bookkeeping stored in the broadcaster's connection registry.
#[derive(Debug)]
struct ConnectionInfo {
    /// Channel names recorded for the connection when it was registered.
    subscribed_channels: Vec<String>,
}
/// Fans [`BroadcastMessage`]s out to SSE subscribers, either on one global
/// channel or on named channels, and tracks registered connections.
///
/// The channel map and connection registry sit behind `Arc`, so clones of a
/// broadcaster share them.
#[derive(Clone, Debug)]
pub struct SseBroadcaster {
    /// Sender for the global channel.
    sender: broadcast::Sender<BroadcastMessage>,
    /// Named channels, keyed by channel name.
    channels: Arc<RwLock<HashMap<String, broadcast::Sender<BroadcastMessage>>>>,
    /// Registered connections and the channel names recorded for each.
    connections: Arc<RwLock<HashMap<ConnectionId, ConnectionInfo>>>,
    /// Buffer size used for the global channel and for every named channel.
    capacity: usize,
}
impl SseBroadcaster {
    /// Buffer size used by [`SseBroadcaster::new`] for the global channel and
    /// for every named channel.
    pub const DEFAULT_CAPACITY: usize = 256;

    /// Creates a broadcaster with [`Self::DEFAULT_CAPACITY`] buffered messages
    /// per channel.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }

    /// Creates a broadcaster whose global channel, and every named channel it
    /// later creates, buffers up to `capacity` messages.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is `0` or exceeds `usize::MAX / 2`; both limits are
    /// enforced by [`tokio::sync::broadcast::channel`].
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        // The initial receiver is discarded: the global channel has no
        // subscribers until `subscribe` is called.
        let (sender, _) = broadcast::channel(capacity);
        Self {
            sender,
            channels: Arc::new(RwLock::new(HashMap::new())),
            connections: Arc::new(RwLock::new(HashMap::new())),
            capacity,
        }
    }

    /// Returns a new receiver on the global channel.
    ///
    /// The receiver sees messages passed to [`Self::broadcast`] after this call;
    /// messages sent to named channels are not delivered to it.
    #[must_use]
    pub fn subscribe(&self) -> broadcast::Receiver<BroadcastMessage> {
        self.sender.subscribe()
    }

    /// Returns a receiver on the named `channel`, creating the channel (with
    /// this broadcaster's capacity) if it does not exist yet.
    ///
    /// Nothing in this type ever removes a channel, so each distinct name passed
    /// here keeps a sender alive for the lifetime of the broadcaster.
    pub async fn subscribe_channel(&self, channel: &str) -> broadcast::Receiver<BroadcastMessage> {
        // Fast path: joining an existing channel only needs the shared lock, so
        // it does not block concurrent `broadcast_to_channel` / `has_channel`.
        // The inner scope drops the read guard before the write lock is requested.
        {
            let channels = self.channels.read().await;
            if let Some(sender) = channels.get(channel) {
                return sender.subscribe();
            }
        }
        let mut channels = self.channels.write().await;
        // Re-check: another task may have created the channel between releasing
        // the read lock and acquiring the write lock.
        if let Some(sender) = channels.get(channel) {
            return sender.subscribe();
        }
        let (sender, receiver) = broadcast::channel(self.capacity);
        channels.insert(channel.to_owned(), sender);
        receiver
    }

    /// Records connection `id` with no channels, replacing any existing entry
    /// for the same id.
    pub async fn register(&self, id: ConnectionId) {
        self.connections.write().await.insert(
            id,
            ConnectionInfo {
                subscribed_channels: Vec::new(),
            },
        );
        tracing::debug!(connection_id = %id, "SSE connection registered");
    }

    /// Records connection `id` with the given channel names, replacing any
    /// existing entry for the same id.
    ///
    /// Only the names are stored; the channels are neither created nor
    /// subscribed to here (see [`Self::subscribe_channel`]).
    pub async fn register_with_channels(&self, id: ConnectionId, channels: Vec<String>) {
        self.connections.write().await.insert(
            id,
            ConnectionInfo {
                subscribed_channels: channels,
            },
        );
        tracing::debug!(connection_id = %id, "SSE connection registered with channels");
    }

    /// Removes connection `id` from the registry; a no-op if it was not
    /// registered. Named channels are left untouched.
    pub async fn unregister(&self, id: &ConnectionId) {
        self.connections.write().await.remove(id);
        tracing::debug!(connection_id = %id, "SSE connection unregistered");
    }

    /// Sends `message` on the global channel and returns the number of
    /// receivers it was sent to.
    ///
    /// # Errors
    ///
    /// Returns [`broadcast::error::SendError`] (carrying `message` back) when
    /// the global channel has no active receivers.
    pub fn broadcast(
        &self,
        message: BroadcastMessage,
    ) -> Result<usize, broadcast::error::SendError<BroadcastMessage>> {
        self.sender.send(message)
    }

    /// Sends `message` on the named `channel` and returns the number of
    /// receivers it was sent to.
    ///
    /// If the channel has never been created, the message is dropped and
    /// `Ok(0)` is returned.
    ///
    /// # Errors
    ///
    /// Returns [`broadcast::error::SendError`] (carrying `message` back) when
    /// the channel exists but all of its receivers have been dropped.
    pub async fn broadcast_to_channel(
        &self,
        channel: &str,
        message: BroadcastMessage,
    ) -> Result<usize, broadcast::error::SendError<BroadcastMessage>> {
        let channels = self.channels.read().await;
        match channels.get(channel) {
            Some(sender) => sender.send(message),
            None => Ok(0),
        }
    }

    /// Number of connections currently in the registry.
    pub async fn connection_count(&self) -> usize {
        self.connections.read().await.len()
    }

    /// Number of named channels created so far.
    pub async fn channel_count(&self) -> usize {
        self.channels.read().await.len()
    }

    /// Whether a named channel called `channel` has been created.
    pub async fn has_channel(&self, channel: &str) -> bool {
        self.channels.read().await.contains_key(channel)
    }

    /// Returns a copy of the channel names recorded for connection `id`, or
    /// `None` if it is not registered.
    pub async fn connection_channels(&self, id: &ConnectionId) -> Option<Vec<String>> {
        self.connections
            .read()
            .await
            .get(id)
            .map(|info| info.subscribed_channels.clone())
    }
}
impl Default for SseBroadcaster {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_broadcast_message() {
        let msg = BroadcastMessage::new("Hello");
        assert_eq!(msg.data, "Hello");
        assert!(msg.event_type.is_none());
        assert!(msg.id.is_none());
    }

    #[test]
    fn test_broadcast_message_named() {
        let msg = BroadcastMessage::named("notification", "Hello");
        assert_eq!(msg.data, "Hello");
        assert_eq!(msg.event_type, Some("notification".to_string()));
    }

    #[test]
    fn test_broadcast_message_with_id() {
        let msg = BroadcastMessage::new("Hello").with_id("event-123");
        assert_eq!(msg.id, Some("event-123".to_string()));
    }

    #[test]
    fn test_broadcast_message_json() {
        let msg = BroadcastMessage::json(&[1, 2, 3]).unwrap();
        assert_eq!(msg.data, "[1,2,3]");
        assert!(msg.event_type.is_none());

        let named = BroadcastMessage::json_named("update", &("a", 1)).unwrap();
        assert_eq!(named.event_type.as_deref(), Some("update"));
        assert_eq!(named.data, r#"["a",1]"#);
        assert!(named.id.is_none());
    }

    #[tokio::test]
    async fn test_broadcaster() {
        let broadcaster = SseBroadcaster::new();
        let mut receiver = broadcaster.subscribe();
        broadcaster
            .broadcast(BroadcastMessage::new("Test"))
            .unwrap();
        let msg = receiver.recv().await.unwrap();
        assert_eq!(msg.data, "Test");
    }

    #[tokio::test]
    async fn test_broadcast_without_subscribers_errors() {
        // The receiver created in `with_capacity` is dropped immediately.
        let broadcaster = SseBroadcaster::new();
        assert!(broadcaster
            .broadcast(BroadcastMessage::new("Nobody"))
            .is_err());
    }

    #[tokio::test]
    async fn test_channel_broadcast() {
        let broadcaster = SseBroadcaster::new();
        let mut global = broadcaster.subscribe();

        assert!(!broadcaster.has_channel("news").await);
        // Sending to a channel that was never created is a silent no-op.
        let sent = broadcaster
            .broadcast_to_channel("news", BroadcastMessage::new("early"))
            .await
            .unwrap();
        assert_eq!(sent, 0);

        let mut first = broadcaster.subscribe_channel("news").await;
        let mut second = broadcaster.subscribe_channel("news").await;
        assert!(broadcaster.has_channel("news").await);
        assert_eq!(broadcaster.channel_count().await, 1);

        let delivered = broadcaster
            .broadcast_to_channel("news", BroadcastMessage::named("headline", "Hi"))
            .await
            .unwrap();
        assert_eq!(delivered, 2);
        assert_eq!(first.recv().await.unwrap().data, "Hi");
        assert_eq!(
            second.recv().await.unwrap().event_type.as_deref(),
            Some("headline")
        );
        // Named-channel traffic does not leak onto the global channel.
        assert!(global.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_channel_broadcast_after_receivers_dropped() {
        let broadcaster = SseBroadcaster::new();
        drop(broadcaster.subscribe_channel("tmp").await);
        // The channel is kept even though nobody listens any more...
        assert!(broadcaster.has_channel("tmp").await);
        // ...and sending on it reports the missing receivers as an error.
        assert!(broadcaster
            .broadcast_to_channel("tmp", BroadcastMessage::new("x"))
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_connection_tracking() {
        let broadcaster = SseBroadcaster::new();
        let id = ConnectionId::new();
        assert_eq!(broadcaster.connection_count().await, 0);
        broadcaster.register(id).await;
        assert_eq!(broadcaster.connection_count().await, 1);
        broadcaster.unregister(&id).await;
        assert_eq!(broadcaster.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_register_with_channels() {
        let broadcaster = SseBroadcaster::new();
        let id = ConnectionId::new();
        assert!(broadcaster.connection_channels(&id).await.is_none());

        let names = vec!["a".to_string(), "b".to_string()];
        broadcaster.register_with_channels(id, names.clone()).await;
        assert_eq!(broadcaster.connection_channels(&id).await, Some(names));
        // Recording channel names does not create the channels themselves.
        assert_eq!(broadcaster.channel_count().await, 0);

        broadcaster.register(id).await;
        assert_eq!(broadcaster.connection_channels(&id).await, Some(Vec::new()));
        assert_eq!(broadcaster.connection_count().await, 1);
    }
}