use crate::error::Result;
use crate::types::{protocol::ResourceUpdatedParams, ServerNotification};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::RwLock;
/// Tracks which subscribers are subscribed to which resource URIs, and optionally
/// forwards resource-update notifications through a user-supplied callback.
///
/// Cloning is cheap. Clones share the same subscription table because it lives
/// behind an `Arc`. The notification sender is copied as it exists at clone time,
/// so calling `set_notification_sender` on one value does not affect other clones.
#[derive(Clone)]
pub struct SubscriptionManager {
    /// Maps a resource URI to the set of subscriber IDs registered for it.
    subscriptions: Arc<RwLock<HashMap<String, HashSet<String>>>>,
    /// Optional callback that receives outgoing notifications. When `None`,
    /// notifications are counted but not delivered.
    notification_sender: Option<Arc<dyn Fn(ServerNotification) + Send + Sync + 'static>>,
}
impl Default for SubscriptionManager {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for SubscriptionManager {
    /// Formats the manager, showing only the number of resource URIs in the table.
    ///
    /// Uses the non-blocking `try_read`, so formatting never waits on the lock.
    /// If the lock cannot be acquired immediately, the count is reported as `0`.
    /// The notification sender is not shown because a closure has no `Debug` impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let resource_count = match self.subscriptions.try_read() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        };
        let mut builder = f.debug_struct("SubscriptionManager");
        builder.field("subscriptions", &resource_count);
        builder.finish()
    }
}
impl SubscriptionManager {
pub fn new() -> Self {
Self {
subscriptions: Arc::new(RwLock::new(HashMap::new())),
notification_sender: None,
}
}
pub fn set_notification_sender<F>(&mut self, sender: F)
where
F: Fn(ServerNotification) + Send + Sync + 'static,
{
self.notification_sender = Some(Arc::new(sender));
}
pub async fn subscribe(&self, uri: String, subscriber_id: String) -> Result<()> {
self.subscriptions
.write()
.await
.entry(uri)
.or_default()
.insert(subscriber_id);
Ok(())
}
pub async fn unsubscribe(&self, uri: String, subscriber_id: String) -> Result<()> {
let mut subs = self.subscriptions.write().await;
if let Some(subscribers) = subs.get_mut(&uri) {
subscribers.remove(&subscriber_id);
if subscribers.is_empty() {
subs.remove(&uri);
drop(subs);
}
}
Ok(())
}
pub async fn unsubscribe_all(&self, subscriber_id: &str) -> Result<()> {
let mut subs = self.subscriptions.write().await;
let mut empty_uris = Vec::new();
for (uri, subscribers) in subs.iter_mut() {
subscribers.remove(subscriber_id);
if subscribers.is_empty() {
empty_uris.push(uri.clone());
}
}
for uri in empty_uris {
subs.remove(&uri);
}
drop(subs);
Ok(())
}
pub async fn has_subscribers(&self, uri: &str) -> bool {
let subs = self.subscriptions.read().await;
subs.get(uri).is_some_and(|s| !s.is_empty())
}
pub async fn get_subscriptions(&self, subscriber_id: &str) -> Vec<String> {
let subs = self.subscriptions.read().await;
subs.iter()
.filter_map(|(uri, subscribers)| {
if subscribers.contains(subscriber_id) {
Some(uri.clone())
} else {
None
}
})
.collect()
}
pub async fn get_subscribers(&self, uri: &str) -> Vec<String> {
let subs = self.subscriptions.read().await;
subs.get(uri)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default()
}
pub async fn notify_resource_updated(&self, uri: String) -> Result<usize> {
let subs = self.subscriptions.read().await;
if let Some(subscribers) = subs.get(&uri) {
let subscriber_count = subscribers.len();
drop(subs);
if subscriber_count > 0 {
if let Some(sender) = &self.notification_sender {
let notification =
ServerNotification::ResourceUpdated(ResourceUpdatedParams::new(&*uri));
sender(notification);
}
return Ok(subscriber_count);
}
}
Ok(0)
}
pub async fn get_stats(&self) -> SubscriptionStats {
let subs = self.subscriptions.read().await;
let total_resources = subs.len();
let total_subscriptions = subs.values().map(std::collections::HashSet::len).sum();
let mut subscriber_counts = HashMap::new();
for subscribers in subs.values() {
for subscriber in subscribers {
*subscriber_counts.entry(subscriber.clone()).or_insert(0) += 1;
}
}
drop(subs);
SubscriptionStats {
total_resources,
total_subscriptions,
unique_subscribers: subscriber_counts.len(),
subscriptions_per_resource: if total_resources > 0 {
#[allow(clippy::cast_precision_loss)]
{
total_subscriptions as f64 / total_resources as f64
}
} else {
0.0
},
}
}
}
/// A point-in-time summary of the subscription table, as produced by
/// `SubscriptionManager::get_stats`.
#[derive(Clone, Debug)]
pub struct SubscriptionStats {
    /// Number of resource URIs that have an entry in the subscription table.
    pub total_resources: usize,
    /// Sum of subscriber-set sizes across all resources (each URI/subscriber pair counted once).
    pub total_subscriptions: usize,
    /// Number of distinct subscriber IDs across all resources.
    pub unique_subscribers: usize,
    /// `total_subscriptions / total_resources`, or `0.0` when there are no resources.
    pub subscriptions_per_resource: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single subscribe/unsubscribe round-trip leaves no trace in the table.
    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let manager = SubscriptionManager::new();
        manager
            .subscribe("file://test.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        assert!(manager.has_subscribers("file://test.txt").await);
        let subs = manager.get_subscriptions("client1").await;
        assert_eq!(subs.len(), 1);
        assert_eq!(subs[0], "file://test.txt");
        manager
            .unsubscribe("file://test.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        assert!(!manager.has_subscribers("file://test.txt").await);
        let subs = manager.get_subscriptions("client1").await;
        assert_eq!(subs.len(), 0);
    }

    /// Removing one of several subscribers keeps the resource entry alive.
    #[tokio::test]
    async fn test_multiple_subscribers() {
        let manager = SubscriptionManager::new();
        manager
            .subscribe("file://shared.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        manager
            .subscribe("file://shared.txt".to_string(), "client2".to_string())
            .await
            .unwrap();
        let subscribers = manager.get_subscribers("file://shared.txt").await;
        assert_eq!(subscribers.len(), 2);
        assert!(subscribers.contains(&"client1".to_string()));
        assert!(subscribers.contains(&"client2".to_string()));
        manager
            .unsubscribe("file://shared.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        assert!(manager.has_subscribers("file://shared.txt").await);
        let subscribers = manager.get_subscribers("file://shared.txt").await;
        assert_eq!(subscribers.len(), 1);
        assert_eq!(subscribers[0], "client2");
    }

    /// `unsubscribe_all` prunes only the resources that become empty.
    #[tokio::test]
    async fn test_unsubscribe_all() {
        let manager = SubscriptionManager::new();
        manager
            .subscribe("file://test1.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        manager
            .subscribe("file://test2.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        manager
            .subscribe("file://test3.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        manager
            .subscribe("file://test2.txt".to_string(), "client2".to_string())
            .await
            .unwrap();
        manager.unsubscribe_all("client1").await.unwrap();
        let subs = manager.get_subscriptions("client1").await;
        assert_eq!(subs.len(), 0);
        assert!(manager.has_subscribers("file://test2.txt").await);
        assert!(!manager.has_subscribers("file://test1.txt").await);
        assert!(!manager.has_subscribers("file://test3.txt").await);
    }

    /// Aggregate counts over a small mixed table.
    #[tokio::test]
    async fn test_stats() {
        let manager = SubscriptionManager::new();
        manager
            .subscribe("file://test1.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        manager
            .subscribe("file://test1.txt".to_string(), "client2".to_string())
            .await
            .unwrap();
        manager
            .subscribe("file://test2.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        manager
            .subscribe("file://test3.txt".to_string(), "client3".to_string())
            .await
            .unwrap();
        let stats = manager.get_stats().await;
        assert_eq!(stats.total_resources, 3);
        assert_eq!(stats.total_subscriptions, 4);
        assert_eq!(stats.unique_subscribers, 3);
        assert!((stats.subscriptions_per_resource - 1.33).abs() < 0.01);
    }

    /// An empty manager reports all-zero stats and does not divide by zero.
    #[tokio::test]
    async fn test_stats_empty() {
        let stats = SubscriptionManager::new().get_stats().await;
        assert_eq!(stats.total_resources, 0);
        assert_eq!(stats.total_subscriptions, 0);
        assert_eq!(stats.unique_subscribers, 0);
        assert!(stats.subscriptions_per_resource.abs() < f64::EPSILON);
    }

    /// A subscribed resource produces exactly one `ResourceUpdated` notification.
    #[tokio::test]
    async fn test_notify_resource_updated() {
        use std::sync::Mutex;
        let manager = SubscriptionManager::new();
        let notifications = Arc::new(Mutex::new(Vec::new()));
        let notifications_clone = notifications.clone();
        let mut manager_mut = manager.clone();
        manager_mut.set_notification_sender(move |notif| {
            notifications_clone.lock().unwrap().push(notif);
        });
        manager_mut
            .subscribe("file://test.txt".to_string(), "client1".to_string())
            .await
            .unwrap();
        let count = manager_mut
            .notify_resource_updated("file://test.txt".to_string())
            .await
            .unwrap();
        assert_eq!(count, 1);
        let notifs = notifications.lock().unwrap();
        assert_eq!(notifs.len(), 1);
        match &notifs[0] {
            ServerNotification::ResourceUpdated(n) => assert_eq!(n.uri, "file://test.txt"),
            _ => panic!("Wrong notification type"),
        }
    }

    /// Notifying a URI nobody subscribed to returns 0 and never calls the sender.
    #[tokio::test]
    async fn test_notify_without_subscribers() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);
        let mut manager = SubscriptionManager::new();
        manager.set_notification_sender(move |_| {
            calls_clone.fetch_add(1, Ordering::SeqCst);
        });
        let count = manager
            .notify_resource_updated("file://nobody.txt".to_string())
            .await
            .unwrap();
        assert_eq!(count, 0);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }
}