use super::frame::Message;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use uuid::Uuid;
pub struct ChannelManager {
subscriptions: Arc<RwLock<HashMap<String, HashSet<Uuid>>>>,
connections: Arc<RwLock<HashMap<Uuid, mpsc::Sender<Message>>>>,
}
impl ChannelManager {
pub fn new() -> Self {
Self {
subscriptions: Arc::new(RwLock::new(HashMap::new())),
connections: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn subscribe(
&self,
connection_id: Uuid,
topic: &str,
sender: mpsc::Sender<Message>,
) -> Result<(), std::io::Error> {
{
let mut connections = self.connections.write().await;
connections.entry(connection_id).or_insert(sender);
}
{
let mut subscriptions = self.subscriptions.write().await;
subscriptions
.entry(topic.to_string())
.or_insert_with(HashSet::new)
.insert(connection_id);
}
tracing::debug!(
"Connection {} subscribed to topic: {}",
connection_id,
topic
);
Ok(())
}
pub async fn unsubscribe(
&self,
connection_id: Uuid,
topic: &str,
) -> Result<(), std::io::Error> {
let mut subscriptions = self.subscriptions.write().await;
if let Some(subscribers) = subscriptions.get_mut(topic) {
subscribers.remove(&connection_id);
if subscribers.is_empty() {
subscriptions.remove(topic);
}
}
tracing::debug!(
"Connection {} unsubscribed from topic: {}",
connection_id,
topic
);
Ok(())
}
pub async fn publish(&self, topic: &str, message: Message) -> Result<usize, std::io::Error> {
let subscriptions = self.subscriptions.read().await;
let connections = self.connections.read().await;
let subscribers = match subscriptions.get(topic) {
Some(subs) => subs,
None => return Ok(0),
};
let mut sent_count = 0;
let mut failed_connections = Vec::new();
for connection_id in subscribers {
if let Some(sender) = connections.get(connection_id) {
match sender.try_send(message.clone()) {
Ok(_) => sent_count += 1,
Err(mpsc::error::TrySendError::Full(_)) => {
tracing::warn!(
"Connection {} backpressured, skipping message",
connection_id
);
}
Err(mpsc::error::TrySendError::Closed(_)) => {
failed_connections.push(*connection_id);
}
}
}
}
if !failed_connections.is_empty() {
drop(subscriptions);
drop(connections);
for conn_id in failed_connections {
self.disconnect(conn_id).await;
}
}
tracing::debug!("Published to topic '{}': {} subscribers", topic, sent_count);
Ok(sent_count)
}
pub async fn disconnect(&self, connection_id: Uuid) {
{
let mut subscriptions = self.subscriptions.write().await;
let topics_to_clean: Vec<String> = subscriptions
.iter()
.filter_map(|(topic, subscribers)| {
if subscribers.contains(&connection_id) {
Some(topic.clone())
} else {
None
}
})
.collect();
for topic in topics_to_clean {
if let Some(subscribers) = subscriptions.get_mut(&topic) {
subscribers.remove(&connection_id);
if subscribers.is_empty() {
subscriptions.remove(&topic);
}
}
}
}
{
let mut connections = self.connections.write().await;
connections.remove(&connection_id);
}
tracing::debug!("Connection {} disconnected and cleaned up", connection_id);
}
pub async fn connection_count(&self) -> usize {
self.connections.read().await.len()
}
pub async fn topic_count(&self) -> usize {
self.subscriptions.read().await.len()
}
pub async fn subscriber_count(&self, topic: &str) -> usize {
self.subscriptions
.read()
.await
.get(topic)
.map(|s| s.len())
.unwrap_or(0)
}
pub async fn broadcast_all(&self, message: Message) -> usize {
let connections = self.connections.read().await;
let mut count = 0;
for sender in connections.values() {
if sender.try_send(message.clone()).is_ok() {
count += 1;
}
}
count
}
pub async fn all_connection_ids(&self) -> Vec<Uuid> {
self.connections.read().await.keys().copied().collect()
}
}
impl Default for ChannelManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_subscribe_unsubscribe() {
let manager = ChannelManager::new();
let (tx, _rx) = mpsc::channel(100);
let conn_id = Uuid::new_v4();
manager
.subscribe(conn_id, "test-topic", tx.clone())
.await
.unwrap();
assert_eq!(manager.subscriber_count("test-topic").await, 1);
manager.unsubscribe(conn_id, "test-topic").await.unwrap();
assert_eq!(manager.subscriber_count("test-topic").await, 0);
}
#[tokio::test]
async fn test_publish() {
let manager = ChannelManager::new();
let (tx1, mut rx1) = mpsc::channel(100);
let (tx2, mut rx2) = mpsc::channel(100);
let conn1 = Uuid::new_v4();
let conn2 = Uuid::new_v4();
manager.subscribe(conn1, "test-topic", tx1).await.unwrap();
manager.subscribe(conn2, "test-topic", tx2).await.unwrap();
let sent = manager
.publish("test-topic", Message::Text("Hello".to_string()))
.await
.unwrap();
assert_eq!(sent, 2);
let msg1 = rx1.recv().await.unwrap();
let msg2 = rx2.recv().await.unwrap();
match (msg1, msg2) {
(Message::Text(t1), Message::Text(t2)) => {
assert_eq!(t1, "Hello");
assert_eq!(t2, "Hello");
}
_ => panic!("Expected text messages"),
}
}
#[tokio::test]
async fn test_disconnect_cleanup() {
let manager = ChannelManager::new();
let (tx, _rx) = mpsc::channel(100);
let conn_id = Uuid::new_v4();
manager
.subscribe(conn_id, "topic1", tx.clone())
.await
.unwrap();
manager.subscribe(conn_id, "topic2", tx).await.unwrap();
assert_eq!(manager.connection_count().await, 1);
assert_eq!(manager.topic_count().await, 2);
manager.disconnect(conn_id).await;
assert_eq!(manager.connection_count().await, 0);
assert_eq!(manager.topic_count().await, 0);
}
}