use std::sync::Arc;
use tokio::sync::broadcast;
use super::{config::WsConfig, protocol::WsMessage, types::Topic};
/// Per-topic state kept inside [`SubscriptionStore`].
///
/// Couples the topic's broadcast sender (the fan-out point every receiver is
/// created from) with a logical count of how many subscriptions reference it.
struct SubscriptionEntry {
    /// Sending half of the topic's broadcast channel.
    sender: broadcast::Sender<WsMessage>,
    /// Outstanding subscriptions; the store drops the entry when this hits zero.
    ref_count: usize,
}
/// Registry mapping each [`Topic`] to a shared broadcast channel and a
/// subscriber reference count.
///
/// Every operation takes `&self`; mutation goes through the concurrent
/// `scc::HashMap`, so no external lock is taken by the store itself.
pub struct SubscriptionStore {
    /// Topic -> (broadcast sender, reference count).
    subscriptions: scc::HashMap<Topic, SubscriptionEntry>,
    /// Shared configuration; supplies the capacity used for new topic channels.
    config: Arc<WsConfig>,
}
impl SubscriptionStore {
    /// Creates an empty store whose new topic channels are sized from `config`.
    pub fn new(config: Arc<WsConfig>) -> Self {
        Self {
            subscriptions: scc::HashMap::new(),
            config,
        }
    }

    /// Registers one subscriber for `topic`, creating the topic's channel if needed.
    ///
    /// Returns a receiver for the topic's broadcast channel, plus `true` when this
    /// call created the topic entry (first subscriber) or `false` when it joined an
    /// existing one. In both cases the topic's reference count grows by one, and the
    /// returned receiver is always attached to the sender stored in the map.
    ///
    /// # Panics
    ///
    /// When a new channel has to be created, panics if
    /// `subscription_channel_capacity` is 0 or larger than `usize::MAX / 2`
    /// (tokio's `broadcast::channel` rejects those capacities).
    pub fn subscribe(&self, topic: Topic) -> (broadcast::Receiver<WsMessage>, bool) {
        // Fast path: the topic is already tracked, so just join its channel.
        if let Some(receiver) = self.add_subscriber(&topic) {
            return (receiver, false);
        }

        let (sender, receiver) = broadcast::channel(self.config.subscription_channel_capacity);
        let mut key = topic;
        let mut entry = SubscriptionEntry {
            sender,
            ref_count: 1,
        };
        loop {
            match self.subscriptions.insert_sync(key, entry) {
                Ok(()) => return (receiver, true),
                // Another caller inserted the topic first. `insert_sync` hands back our
                // key and unused entry, so `receiver` stays tied to `entry.sender`.
                Err((rejected_key, rejected_entry)) => {
                    key = rejected_key;
                    entry = rejected_entry;
                }
            }
            // Join the winner's channel. If that entry was removed again in the
            // meantime (its last subscriber left), retry the insert instead of
            // returning a receiver whose sender has been dropped.
            if let Some(existing) = self.add_subscriber(&key) {
                return (existing, false);
            }
        }
    }

    /// Adds a subscriber to an already-tracked `topic` without ever creating it.
    ///
    /// Increments the reference count and returns a new receiver, or returns
    /// `None` (leaving the store untouched) when `topic` is not tracked.
    pub fn add_subscriber(&self, topic: &Topic) -> Option<broadcast::Receiver<WsMessage>> {
        self.subscriptions.update_sync(topic, |_, entry| {
            entry.ref_count += 1;
            entry.sender.subscribe()
        })
    }

    /// Releases one reference to `topic`.
    ///
    /// Returns `true` only when that was the last reference and the topic entry
    /// was removed. Returns `false` when references remain or the topic is unknown.
    pub fn unsubscribe(&self, topic: &Topic) -> bool {
        // `decrement_ref` removes the entry exactly when the count reaches zero.
        self.decrement_ref(topic) == Some(0)
    }

    /// Releases one reference to `topic` and reports how many remain.
    ///
    /// Returns `Some(remaining)` (the count after decrementing) when the topic is
    /// tracked. When `remaining` is 0, the entry is removed, which drops its
    /// broadcast sender. Returns `None` when the topic is not tracked.
    pub fn decrement_ref(&self, topic: &Topic) -> Option<usize> {
        let mut remaining = None;
        // The removed `(Topic, SubscriptionEntry)` pair, if any, is dropped right here.
        let _ = self.subscriptions.remove_if_sync(topic, |entry| {
            entry.ref_count = entry.ref_count.saturating_sub(1);
            remaining = Some(entry.ref_count);
            entry.ref_count == 0
        });
        // The closure runs iff the topic exists, so `remaining` alone carries the answer.
        remaining
    }

    /// Broadcasts `message` to every live receiver of `topic`.
    ///
    /// Returns `true` if the topic is tracked, even if no receiver is currently
    /// alive to get the message. Returns `false` if the topic is unknown.
    pub fn publish(&self, topic: &Topic, message: WsMessage) -> bool {
        self.subscriptions
            .read_sync(topic, |_, entry| {
                // `send` fails only when no receiver is alive; delivery is best-effort,
                // and the message is moved in because the closure runs at most once.
                let _ = entry.sender.send(message);
            })
            .is_some()
    }

    /// Returns the keys of all tracked topics, in unspecified order.
    ///
    /// Iteration piggybacks on `retain_sync` with an always-`true` predicate, so no
    /// entry is removed. NOTE(review): concurrent subscribe/unsubscribe calls may or
    /// may not be reflected in the result; treat it as a best-effort snapshot.
    pub fn get_all_topics(&self) -> Vec<Topic> {
        let mut topics = Vec::new();
        self.subscriptions.retain_sync(|topic, _| {
            topics.push(topic.clone());
            true
        });
        topics
    }

    /// Returns the reference count recorded for `topic`, or 0 if it is not tracked.
    ///
    /// This counts `subscribe`/`add_subscriber` registrations not yet released, not
    /// live receivers. Dropping a receiver does not change it.
    pub fn subscriber_count(&self, topic: &Topic) -> usize {
        self.subscriptions
            .read_sync(topic, |_, entry| entry.ref_count)
            .unwrap_or(0)
    }

    /// Returns the number of tracked topics.
    pub fn len(&self) -> usize {
        self.subscriptions.len()
    }

    /// Returns `true` when no topic is tracked.
    pub fn is_empty(&self) -> bool {
        self.subscriptions.is_empty()
    }

    /// Removes every topic regardless of its reference count, dropping each
    /// topic's broadcast sender along with its entry.
    pub fn clear(&self) {
        self.subscriptions.clear_sync();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config for constructing a store in tests.
    fn test_config() -> Arc<WsConfig> {
        Arc::new(WsConfig::new("wss://test.com"))
    }

    #[test]
    fn test_subscribe_new_topic() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("orderbook.BTC");
        let (_, is_new) = store.subscribe(topic.clone());
        assert!(is_new);
        assert_eq!(store.len(), 1);
        assert_eq!(store.subscriber_count(&topic), 1);
    }

    #[test]
    fn test_subscribe_existing_topic() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("orderbook.BTC");
        let (_, is_new1) = store.subscribe(topic.clone());
        assert!(is_new1);
        let (_, is_new2) = store.subscribe(topic.clone());
        assert!(!is_new2);
        assert_eq!(store.len(), 1);
        assert_eq!(store.subscriber_count(&topic), 2);
    }

    #[test]
    fn test_unsubscribe_decrements_count() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("orderbook.BTC");
        store.subscribe(topic.clone());
        store.subscribe(topic.clone());
        let removed = store.unsubscribe(&topic);
        assert!(!removed);
        assert_eq!(store.subscriber_count(&topic), 1);
        let removed = store.unsubscribe(&topic);
        assert!(removed);
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_unsubscribe_missing_returns_false() {
        let store = SubscriptionStore::new(test_config());
        assert!(!store.unsubscribe(&Topic::new("orderbook.BTC")));
        assert!(store.is_empty());
    }

    #[test]
    fn test_decrement_ref_removes_on_zero() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("orderbook.BTC");
        store.subscribe(topic.clone());
        store.subscribe(topic.clone());
        assert_eq!(store.decrement_ref(&topic), Some(1));
        assert_eq!(store.subscriber_count(&topic), 1);
        assert_eq!(store.decrement_ref(&topic), Some(0));
        assert_eq!(store.subscriber_count(&topic), 0);
        assert!(store.is_empty());
    }

    #[test]
    fn test_decrement_ref_missing_returns_none() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("orderbook.BTC");
        assert_eq!(store.decrement_ref(&topic), None);
    }

    #[test]
    fn test_add_subscriber_requires_existing_topic() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("orderbook.BTC");
        // Unknown topics are never created implicitly.
        assert!(store.add_subscriber(&topic).is_none());
        assert!(store.is_empty());
        store.subscribe(topic.clone());
        assert!(store.add_subscriber(&topic).is_some());
        assert_eq!(store.subscriber_count(&topic), 2);
    }

    #[test]
    fn test_publish() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("trades.ETH");
        let (mut rx, _) = store.subscribe(topic.clone());
        let published = store.publish(&topic, WsMessage::text("test message"));
        assert!(published);
        let received = rx.try_recv();
        assert!(received.is_ok());
    }

    #[test]
    fn test_publish_reaches_every_subscriber() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("trades.ETH");
        let (mut rx1, _) = store.subscribe(topic.clone());
        let (mut rx2, _) = store.subscribe(topic.clone());
        assert!(store.publish(&topic, WsMessage::text("test message")));
        assert!(rx1.try_recv().is_ok());
        assert!(rx2.try_recv().is_ok());
    }

    #[test]
    fn test_publish_missing_topic_returns_false() {
        let store = SubscriptionStore::new(test_config());
        let published = store.publish(&Topic::new("trades.ETH"), WsMessage::text("test message"));
        assert!(!published);
    }

    #[test]
    fn test_get_all_topics() {
        let store = SubscriptionStore::new(test_config());
        store.subscribe(Topic::new("topic1"));
        store.subscribe(Topic::new("topic2"));
        store.subscribe(Topic::new("topic3"));
        let topics = store.get_all_topics();
        assert_eq!(topics.len(), 3);
        for name in ["topic1", "topic2", "topic3"] {
            assert!(topics.contains(&Topic::new(name)));
        }
    }

    #[test]
    fn test_clear_removes_everything() {
        let store = SubscriptionStore::new(test_config());
        let topic = Topic::new("topic1");
        store.subscribe(topic.clone());
        store.subscribe(topic.clone());
        store.subscribe(Topic::new("topic2"));
        store.clear();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        assert_eq!(store.subscriber_count(&topic), 0);
        assert!(store.get_all_topics().is_empty());
    }
}