use crate::models::SubscribeRequest;
use indexmap::IndexMap;
use std::collections::HashMap;
use std::sync::RwLock;
/// Thread-safe registry of active subscriptions, keyed by [`SubscribeRequest::key`].
///
/// Subscriptions are stored in an [`IndexMap`], so [`get_all`](Self::get_all) and
/// [`keys`](Self::keys) return them in the order they were first subscribed.
/// Alongside each subscription key the manager can remember a server id string
/// (see [`record_server_id`](Self::record_server_id)).
///
/// # Locking
///
/// The two maps live behind separate [`RwLock`]s. Methods that modify both maps
/// ([`unsubscribe`](Self::unsubscribe), [`clear`](Self::clear)) acquire
/// `subscriptions` first, then `server_ids`, and hold both for the whole update so
/// the maps are changed atomically. No method acquires them in the opposite order,
/// so this cannot deadlock.
///
/// # Panics
///
/// Every method panics if a lock it needs has been poisoned (another thread
/// panicked while holding it).
pub struct SubscriptionManager {
    /// Active subscriptions, in insertion order, keyed by `SubscribeRequest::key()`.
    subscriptions: RwLock<IndexMap<String, SubscribeRequest>>,
    /// Server id recorded per subscription key, where one has been recorded.
    server_ids: RwLock<HashMap<String, String>>,
}

impl SubscriptionManager {
    /// Creates an empty manager with no subscriptions and no recorded server ids.
    pub fn new() -> Self {
        Self {
            subscriptions: RwLock::new(IndexMap::new()),
            server_ids: RwLock::new(HashMap::new()),
        }
    }

    /// Stores `req` under its [`SubscribeRequest::key`].
    ///
    /// If a subscription with the same key already exists, its request is replaced
    /// but it keeps its original position in the ordering. Any server id already
    /// recorded for that key is left untouched.
    pub fn subscribe(&self, req: SubscribeRequest) {
        // Compute the key before locking to keep the critical section minimal.
        let key = req.key();
        let mut subs = self.subscriptions.write().unwrap();
        subs.insert(key, req);
    }

    /// Removes the subscription stored under `key` together with its recorded
    /// server id. Keys that are not present are ignored.
    ///
    /// The remaining subscriptions keep their relative order.
    pub fn unsubscribe(&self, key: &str) {
        // Hold both locks (fixed order: subscriptions -> server_ids) so the two maps
        // are updated atomically. Releasing `subscriptions` before touching
        // `server_ids` would let a concurrent `subscribe` + `record_server_id` for the
        // same key slip in between and have its fresh server id wiped out.
        let mut subs = self.subscriptions.write().unwrap();
        let mut ids = self.server_ids.write().unwrap();
        // `shift_remove` (not `swap_remove`) preserves insertion order of the rest.
        subs.shift_remove(key);
        ids.remove(key);
    }

    /// Records `server_id` for `key`, replacing any previously recorded value.
    ///
    /// This does not check that `key` is currently subscribed.
    pub fn record_server_id(&self, key: String, server_id: String) {
        self.server_ids.write().unwrap().insert(key, server_id);
    }

    /// Removes and returns the server id recorded for `key`, if any.
    pub fn take_server_id(&self, key: &str) -> Option<String> {
        self.server_ids.write().unwrap().remove(key)
    }

    /// Forgets every recorded server id while keeping the subscriptions themselves.
    pub fn clear_server_ids(&self) {
        self.server_ids.write().unwrap().clear();
    }

    /// Convenience wrapper around [`unsubscribe`](Self::unsubscribe) that builds the
    /// key as `"{channel}:{symbol}"`.
    ///
    /// This format must stay in sync with [`SubscribeRequest::key`]
    /// (e.g. `"trades:2330"`).
    pub fn unsubscribe_by_channel_symbol(&self, channel: &str, symbol: &str) {
        let key = format!("{}:{}", channel, symbol);
        self.unsubscribe(&key);
    }

    /// Returns clones of all subscriptions, in insertion order.
    pub fn get_all(&self) -> Vec<SubscribeRequest> {
        let subs = self.subscriptions.read().unwrap();
        subs.values().cloned().collect()
    }

    /// Returns `true` if a subscription is stored under `key`.
    pub fn contains(&self, key: &str) -> bool {
        let subs = self.subscriptions.read().unwrap();
        subs.contains_key(key)
    }

    /// Returns the number of stored subscriptions.
    pub fn count(&self) -> usize {
        let subs = self.subscriptions.read().unwrap();
        subs.len()
    }

    /// Removes every subscription and every recorded server id.
    pub fn clear(&self) {
        // Same fixed lock order as `unsubscribe`, held for both updates so no
        // concurrently recorded server id is lost between the two clears.
        let mut subs = self.subscriptions.write().unwrap();
        let mut ids = self.server_ids.write().unwrap();
        subs.clear();
        ids.clear();
    }

    /// Returns clones of all subscription keys, in insertion order.
    pub fn keys(&self) -> Vec<String> {
        let subs = self.subscriptions.read().unwrap();
        subs.keys().cloned().collect()
    }
}
impl Default for SubscriptionManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Channel;

    #[test]
    fn test_subscribe_adds_to_state() {
        let manager = SubscriptionManager::new();
        let req = SubscribeRequest::new(Channel::Trades, "2330");
        manager.subscribe(req.clone());
        assert_eq!(manager.count(), 1);
        assert!(manager.contains("trades:2330"));
        let all = manager.get_all();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0], req);
    }

    #[test]
    fn test_unsubscribe_removes_from_state() {
        let manager = SubscriptionManager::new();
        let req = SubscribeRequest::new(Channel::Trades, "2330");
        manager.subscribe(req.clone());
        assert_eq!(manager.count(), 1);
        manager.unsubscribe("trades:2330");
        assert_eq!(manager.count(), 0);
        assert!(!manager.contains("trades:2330"));
    }

    #[test]
    fn test_insertion_order_preserved() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.subscribe(SubscribeRequest::new(Channel::Candles, "2317"));
        manager.subscribe(SubscribeRequest::new(Channel::Books, "2454"));
        let all = manager.get_all();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].key(), "trades:2330");
        assert_eq!(all[1].key(), "candles:2317");
        assert_eq!(all[2].key(), "books:2454");
    }

    #[test]
    fn test_unsubscribe_during_disconnect_removes() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.subscribe(SubscribeRequest::new(Channel::Candles, "2317"));
        assert_eq!(manager.count(), 2);
        manager.unsubscribe("trades:2330");
        assert_eq!(manager.count(), 1);
        assert!(!manager.contains("trades:2330"));
        assert!(manager.contains("candles:2317"));
        let all = manager.get_all();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].key(), "candles:2317");
    }

    #[test]
    fn test_get_all_returns_in_order() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Aggregates, "2330"));
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2317"));
        manager.subscribe(SubscribeRequest::new(Channel::Books, "2454"));
        manager.subscribe(SubscribeRequest::new(Channel::Candles, "2886"));
        let all = manager.get_all();
        assert_eq!(all.len(), 4);
        assert_eq!(all[0].key(), "aggregates:2330");
        assert_eq!(all[1].key(), "trades:2317");
        assert_eq!(all[2].key(), "books:2454");
        assert_eq!(all[3].key(), "candles:2886");
    }

    #[test]
    fn test_unsubscribe_by_channel_symbol() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        assert!(manager.contains("trades:2330"));
        manager.unsubscribe_by_channel_symbol("trades", "2330");
        assert!(!manager.contains("trades:2330"));
        assert_eq!(manager.count(), 0);
    }

    #[test]
    fn test_clear_removes_all() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.subscribe(SubscribeRequest::new(Channel::Candles, "2317"));
        manager.subscribe(SubscribeRequest::new(Channel::Books, "2454"));
        assert_eq!(manager.count(), 3);
        manager.clear();
        assert_eq!(manager.count(), 0);
        assert!(manager.get_all().is_empty());
    }

    #[test]
    fn test_subscribe_updates_existing() {
        let manager = SubscriptionManager::new();
        let req1 = SubscribeRequest::new(Channel::Trades, "2330");
        manager.subscribe(req1);
        assert_eq!(manager.count(), 1);
        let req2 = SubscribeRequest::new(Channel::Trades, "2330");
        manager.subscribe(req2);
        assert_eq!(manager.count(), 1);
    }

    #[test]
    fn test_server_id_record_and_take() {
        let manager = SubscriptionManager::new();
        assert!(manager.take_server_id("trades:2330").is_none());
        manager.record_server_id("trades:2330".into(), "sub-xyz".into());
        assert_eq!(manager.take_server_id("trades:2330"), Some("sub-xyz".into()));
        assert!(manager.take_server_id("trades:2330").is_none());
    }

    #[test]
    fn test_server_id_overwrites_on_reconnect() {
        let manager = SubscriptionManager::new();
        manager.record_server_id("trades:2330".into(), "sub-old".into());
        manager.record_server_id("trades:2330".into(), "sub-new".into());
        assert_eq!(manager.take_server_id("trades:2330"), Some("sub-new".into()));
    }

    #[test]
    fn test_unsubscribe_drops_server_id() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.record_server_id("trades:2330".into(), "sub-xyz".into());
        manager.unsubscribe("trades:2330");
        assert!(manager.take_server_id("trades:2330").is_none());
    }

    #[test]
    fn test_clear_server_ids() {
        let manager = SubscriptionManager::new();
        manager.record_server_id("trades:2330".into(), "sub-a".into());
        manager.record_server_id("books:2317".into(), "sub-b".into());
        manager.clear_server_ids();
        assert!(manager.take_server_id("trades:2330").is_none());
        assert!(manager.take_server_id("books:2317").is_none());
    }

    #[test]
    fn test_clear_also_clears_server_ids() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.record_server_id("trades:2330".into(), "sub-xyz".into());
        manager.clear();
        assert_eq!(manager.count(), 0);
        assert!(manager.take_server_id("trades:2330").is_none());
    }

    // `keys()` was previously untested; it must follow insertion order like `get_all()`.
    #[test]
    fn test_keys_in_insertion_order() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Books, "2454"));
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        assert_eq!(manager.keys(), vec!["books:2454", "trades:2330"]);
    }

    // Re-subscribing an existing key replaces it in place rather than moving it to the end.
    #[test]
    fn test_resubscribe_keeps_original_position() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.subscribe(SubscribeRequest::new(Channel::Books, "2454"));
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        assert_eq!(manager.count(), 2);
        assert_eq!(manager.keys(), vec!["trades:2330", "books:2454"]);
    }

    #[test]
    fn test_unsubscribe_missing_key_is_noop() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.record_server_id("trades:2330".into(), "sub-xyz".into());
        manager.unsubscribe("books:9999");
        assert_eq!(manager.count(), 1);
        assert!(manager.contains("trades:2330"));
        assert_eq!(manager.take_server_id("trades:2330"), Some("sub-xyz".into()));
    }

    #[test]
    fn test_unsubscribe_by_channel_symbol_leaves_others() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.subscribe(SubscribeRequest::new(Channel::Books, "2330"));
        manager.record_server_id("books:2330".into(), "sub-b".into());
        manager.unsubscribe_by_channel_symbol("trades", "2330");
        assert_eq!(manager.keys(), vec!["books:2330"]);
        assert_eq!(manager.take_server_id("books:2330"), Some("sub-b".into()));
    }

    #[test]
    fn test_clear_server_ids_keeps_subscriptions() {
        let manager = SubscriptionManager::new();
        manager.subscribe(SubscribeRequest::new(Channel::Trades, "2330"));
        manager.record_server_id("trades:2330".into(), "sub-xyz".into());
        manager.clear_server_ids();
        assert_eq!(manager.count(), 1);
        assert!(manager.contains("trades:2330"));
        assert!(manager.take_server_id("trades:2330").is_none());
    }

    #[test]
    fn test_default_is_empty() {
        let manager = SubscriptionManager::default();
        assert_eq!(manager.count(), 0);
        assert!(manager.keys().is_empty());
        assert!(manager.get_all().is_empty());
    }
}