use std::{
num::NonZeroUsize,
sync::{Arc, LazyLock},
};
use ahash::AHashSet;
use dashmap::DashMap;
use ustr::Ustr;
pub(crate) static CHANNEL_LEVEL_MARKER: LazyLock<Ustr> = LazyLock::new(|| Ustr::from(""));
/// Tracks topic subscriptions through three states (pending subscribe, confirmed,
/// pending unsubscribe) together with a reference count per topic string.
///
/// Every map sits behind an `Arc`, so clones of this struct share the same state.
/// Topic maps are keyed by channel and hold the set of symbols for that channel;
/// a channel-level subscription is stored as `CHANNEL_LEVEL_MARKER`.
#[derive(Clone, Debug)]
pub struct SubscriptionState {
/// Topics whose subscription has been confirmed.
confirmed: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
/// Topics marked for subscription and not yet confirmed.
pending_subscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
/// Topics marked for unsubscription and not yet confirmed.
pending_unsubscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
/// Reference count per full topic string; an entry is removed instead of reaching zero.
reference_counts: Arc<DashMap<Ustr, NonZeroUsize>>,
/// Separator between channel and symbol; topics are split on its first occurrence.
delimiter: char,
}
impl SubscriptionState {
    /// Creates an empty state that splits topics into channel and symbol on the
    /// first occurrence of `delimiter`.
    pub fn new(delimiter: char) -> Self {
        Self {
            confirmed: Arc::new(DashMap::new()),
            pending_subscribe: Arc::new(DashMap::new()),
            pending_unsubscribe: Arc::new(DashMap::new()),
            reference_counts: Arc::new(DashMap::new()),
            delimiter,
        }
    }

    /// Returns the channel/symbol delimiter this state was created with.
    pub fn delimiter(&self) -> char {
        self.delimiter
    }

    /// Returns a shared handle to the confirmed map (channel -> symbols).
    ///
    /// The handle aliases this state's storage; mutating it mutates the state.
    pub fn confirmed(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
        Arc::clone(&self.confirmed)
    }

    /// Returns a shared handle to the pending-subscribe map (channel -> symbols).
    pub fn pending_subscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
        Arc::clone(&self.pending_subscribe)
    }

    /// Returns a shared handle to the pending-unsubscribe map (channel -> symbols).
    pub fn pending_unsubscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
        Arc::clone(&self.pending_unsubscribe)
    }

    /// Returns the number of confirmed topics, channel-level entries included.
    ///
    /// Pending topics are not counted, so `len() == 0` does not imply
    /// [`Self::is_empty`].
    pub fn len(&self) -> usize {
        self.confirmed.iter().map(|entry| entry.value().len()).sum()
    }

    /// Returns `true` when no topic is confirmed, pending subscribe, or pending
    /// unsubscribe. Reference counts are not considered.
    pub fn is_empty(&self) -> bool {
        self.confirmed.is_empty()
            && self.pending_subscribe.is_empty()
            && self.pending_unsubscribe.is_empty()
    }

    /// Returns `true` if `symbol` on `channel` is confirmed or pending subscribe.
    ///
    /// Only the exact `symbol` is looked up in the channel's symbol set.
    pub fn is_subscribed(&self, channel: &Ustr, symbol: &Ustr) -> bool {
        // `is_some_and` consumes each map guard, so no lock outlives its check.
        self.confirmed
            .get(channel)
            .is_some_and(|symbols| symbols.contains(symbol))
            || self
                .pending_subscribe
                .get(channel)
                .is_some_and(|symbols| symbols.contains(symbol))
    }

    /// Marks `topic` as pending subscribe.
    ///
    /// Does nothing if `topic` is already confirmed. Otherwise cancels any pending
    /// unsubscribe for it and adds it to pending-subscribe.
    pub fn mark_subscribe(&self, topic: &str) {
        let (channel, symbol) = split_topic(topic, self.delimiter);
        if is_tracked(&self.confirmed, channel, symbol) {
            return;
        }
        untrack_topic(&self.pending_unsubscribe, channel, symbol);
        track_topic(&self.pending_subscribe, channel, symbol);
    }

    /// Like [`Self::mark_subscribe`], but reports whether this call newly added
    /// `topic` to pending-subscribe.
    ///
    /// Returns `false` if `topic` is already confirmed or already pending
    /// subscribe. In that case no state is changed.
    pub fn try_mark_subscribe(&self, topic: &str) -> bool {
        let (channel, symbol) = split_topic(topic, self.delimiter);
        if is_tracked(&self.confirmed, channel, symbol) {
            return false;
        }
        let symbol_ustr = symbol.map_or(*CHANNEL_LEVEL_MARKER, Ustr::from);
        // The pending_subscribe guard is a temporary dropped at the end of this
        // statement, so it is released before pending_unsubscribe is locked below.
        let inserted = self
            .pending_subscribe
            .entry(Ustr::from(channel))
            .or_default()
            .insert(symbol_ustr);
        if inserted {
            untrack_topic(&self.pending_unsubscribe, channel, symbol);
        }
        inserted
    }

    /// Marks `topic` as pending unsubscribe and removes it from confirmed and
    /// pending-subscribe.
    pub fn mark_unsubscribe(&self, topic: &str) {
        let (channel, symbol) = split_topic(topic, self.delimiter);
        track_topic(&self.pending_unsubscribe, channel, symbol);
        untrack_topic(&self.confirmed, channel, symbol);
        untrack_topic(&self.pending_subscribe, channel, symbol);
    }

    /// Records a subscribe acknowledgement for `topic`.
    ///
    /// Ignored while `topic` is pending unsubscribe. Otherwise the topic moves from
    /// pending-subscribe to confirmed; it is confirmed even if it was never marked.
    pub fn confirm_subscribe(&self, topic: &str) {
        let (channel, symbol) = split_topic(topic, self.delimiter);
        if is_tracked(&self.pending_unsubscribe, channel, symbol) {
            return;
        }
        untrack_topic(&self.pending_subscribe, channel, symbol);
        track_topic(&self.confirmed, channel, symbol);
    }

    /// Records an unsubscribe acknowledgement for `topic`.
    ///
    /// Ignored unless `topic` is pending unsubscribe. This means a stale ack
    /// arriving after a re-subscribe does not remove the re-confirmed topic.
    pub fn confirm_unsubscribe(&self, topic: &str) {
        let (channel, symbol) = split_topic(topic, self.delimiter);
        if !is_tracked(&self.pending_unsubscribe, channel, symbol) {
            return;
        }
        untrack_topic(&self.pending_unsubscribe, channel, symbol);
        untrack_topic(&self.confirmed, channel, symbol);
    }

    /// Records a subscription failure for `topic`.
    ///
    /// Ignored while `topic` is pending unsubscribe. Otherwise the topic moves from
    /// confirmed back to pending-subscribe, so it is reported again by
    /// [`Self::pending_subscribe_topics`].
    pub fn mark_failure(&self, topic: &str) {
        let (channel, symbol) = split_topic(topic, self.delimiter);
        if is_tracked(&self.pending_unsubscribe, channel, symbol) {
            return;
        }
        untrack_topic(&self.confirmed, channel, symbol);
        track_topic(&self.pending_subscribe, channel, symbol);
    }

    /// Returns all topics currently pending subscribe, as full topic strings.
    pub fn pending_subscribe_topics(&self) -> Vec<String> {
        self.topics_from_map(&self.pending_subscribe)
    }

    /// Returns all topics currently pending unsubscribe, as full topic strings.
    pub fn pending_unsubscribe_topics(&self) -> Vec<String> {
        self.topics_from_map(&self.pending_unsubscribe)
    }

    /// Returns confirmed topics followed by pending-subscribe topics.
    ///
    /// Topics pending unsubscribe are excluded.
    pub fn all_topics(&self) -> Vec<String> {
        let mut topics = self.topics_from_map(&self.confirmed);
        topics.extend(self.topics_from_map(&self.pending_subscribe));
        topics
    }

    /// Rebuilds full topic strings from a channel -> symbols map.
    ///
    /// For each channel, the channel-level topic (if present) is emitted first.
    /// It is followed by `channel{delimiter}symbol` for every other symbol.
    fn topics_from_map(&self, map: &DashMap<Ustr, AHashSet<Ustr>>) -> Vec<String> {
        let marker = *CHANNEL_LEVEL_MARKER;
        let mut topics = Vec::new();
        for entry in map {
            let channel = entry.key();
            let symbols = entry.value();
            if symbols.contains(&marker) {
                topics.push(channel.to_string());
            }
            topics.extend(
                symbols
                    .iter()
                    .filter(|symbol| **symbol != marker)
                    .map(|symbol| {
                        format!(
                            "{}{}{}",
                            channel.as_str(),
                            self.delimiter,
                            symbol.as_str()
                        )
                    }),
            );
        }
        topics
    }

    /// Increments the reference count for `topic`.
    ///
    /// Returns `true` when this is the first reference to `topic`, and `false`
    /// otherwise.
    ///
    /// # Panics
    ///
    /// Panics with "reference count overflow" if the count would exceed
    /// `usize::MAX`. The behaviour is the same in debug and release builds.
    pub fn add_reference(&self, topic: &str) -> bool {
        match self.reference_counts.entry(Ustr::from(topic)) {
            dashmap::mapref::entry::Entry::Occupied(mut occupied) => {
                let count = occupied.get_mut();
                // `checked_add` avoids a debug-only arithmetic panic / release wrap.
                *count = count.checked_add(1).expect("reference count overflow");
                false
            }
            dashmap::mapref::entry::Entry::Vacant(vacant) => {
                vacant.insert(NonZeroUsize::MIN);
                true
            }
        }
    }

    /// Decrements the reference count for `topic`.
    ///
    /// Returns `true` when the last reference was removed; the entry is deleted in
    /// that case. Returns `false` when references remain or `topic` had none.
    pub fn remove_reference(&self, topic: &str) -> bool {
        let dashmap::mapref::entry::Entry::Occupied(mut occupied) =
            self.reference_counts.entry(Ustr::from(topic))
        else {
            return false;
        };
        // A stored count is >= 1, so subtracting 1 cannot underflow; a result of
        // zero means this was the last reference.
        match NonZeroUsize::new(occupied.get().get() - 1) {
            Some(decremented) => {
                *occupied.get_mut() = decremented;
                false
            }
            None => {
                occupied.remove();
                true
            }
        }
    }

    /// Returns the current reference count for `topic`, or `0` if untracked.
    pub fn get_reference_count(&self, topic: &str) -> usize {
        self.reference_counts
            .get(&Ustr::from(topic))
            .map_or(0, |count| count.get())
    }

    /// Removes every tracked topic and reference count.
    ///
    /// Each map is cleared in turn rather than atomically, so a concurrent reader
    /// may observe a partially cleared state.
    pub fn clear(&self) {
        self.confirmed.clear();
        self.pending_subscribe.clear();
        self.pending_unsubscribe.clear();
        self.reference_counts.clear();
    }
}
/// Splits `topic` into `(channel, symbol)` at the first occurrence of `delimiter`.
///
/// Without a delimiter, the whole topic is the channel and the symbol is `None`.
/// Everything after the first delimiter is the symbol, even if it contains more
/// delimiters or is empty: `"a.b.c"` -> `("a", Some("b.c"))` and
/// `"a."` -> `("a", Some(""))`.
pub fn split_topic(topic: &str, delimiter: char) -> (&str, Option<&str>) {
    match topic.split_once(delimiter) {
        Some((channel, symbol)) => (channel, Some(symbol)),
        None => (topic, None),
    }
}
/// Adds `symbol` (or the channel-level marker when `symbol` is `None`) to the
/// symbol set for `channel` in `map`, creating the set if needed.
fn track_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
    let member = symbol.map_or(*CHANNEL_LEVEL_MARKER, Ustr::from);
    map.entry(Ustr::from(channel)).or_default().insert(member);
}
/// Removes `symbol` (or the channel-level marker when `symbol` is `None`) from
/// the symbol set for `channel` in `map`.
///
/// The channel entry is deleted once its set becomes empty, so the map never
/// keeps empty sets.
fn untrack_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
    let member = symbol.map_or(*CHANNEL_LEVEL_MARKER, Ustr::from);
    let key = Ustr::from(channel);
    let dashmap::mapref::entry::Entry::Occupied(mut occupied) = map.entry(key) else {
        return;
    };
    let members = occupied.get_mut();
    members.remove(&member);
    if members.is_empty() {
        occupied.remove();
    }
}
/// Returns `true` if `map` holds `symbol` (or the channel-level marker when
/// `symbol` is `None`) in the symbol set for `channel`.
fn is_tracked(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) -> bool {
    let member = symbol.map_or(*CHANNEL_LEVEL_MARKER, Ustr::from);
    map.get(&Ustr::from(channel))
        .is_some_and(|members| members.contains(&member))
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use super::*;
#[rstest]
fn test_split_topic_with_symbol() {
let (channel, symbol) = split_topic("tickers.BTCUSDT", '.');
assert_eq!(channel, "tickers");
assert_eq!(symbol, Some("BTCUSDT"));
let (channel, symbol) = split_topic("orderBookL2:XBTUSD", ':');
assert_eq!(channel, "orderBookL2");
assert_eq!(symbol, Some("XBTUSD"));
}
#[rstest]
fn test_split_topic_without_symbol() {
let (channel, symbol) = split_topic("orderbook", '.');
assert_eq!(channel, "orderbook");
assert_eq!(symbol, None);
}
#[rstest]
fn test_new_state_is_empty() {
let state = SubscriptionState::new('.');
assert!(state.is_empty());
assert_eq!(state.len(), 0);
}
#[rstest]
fn test_mark_subscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
assert_eq!(state.len(), 0); }
#[rstest]
fn test_confirm_subscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert!(state.pending_subscribe_topics().is_empty());
assert_eq!(state.len(), 1);
}
#[rstest]
fn test_is_subscribed_empty_state() {
let state = SubscriptionState::new('.');
let channel = Ustr::from("tickers");
let symbol = Ustr::from("BTCUSDT");
assert!(!state.is_subscribed(&channel, &symbol));
}
#[rstest]
fn test_is_subscribed_pending() {
let state = SubscriptionState::new('.');
let channel = Ustr::from("tickers");
let symbol = Ustr::from("BTCUSDT");
state.mark_subscribe("tickers.BTCUSDT");
assert!(state.is_subscribed(&channel, &symbol));
}
#[rstest]
fn test_is_subscribed_confirmed() {
let state = SubscriptionState::new('.');
let channel = Ustr::from("tickers");
let symbol = Ustr::from("BTCUSDT");
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert!(state.is_subscribed(&channel, &symbol));
}
#[rstest]
fn test_is_subscribed_after_unsubscribe() {
let state = SubscriptionState::new('.');
let channel = Ustr::from("tickers");
let symbol = Ustr::from("BTCUSDT");
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_unsubscribe("tickers.BTCUSDT");
assert!(!state.is_subscribed(&channel, &symbol));
}
#[rstest]
fn test_is_subscribed_after_confirm_unsubscribe() {
let state = SubscriptionState::new('.');
let channel = Ustr::from("tickers");
let symbol = Ustr::from("BTCUSDT");
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_unsubscribe("tickers.BTCUSDT");
state.confirm_unsubscribe("tickers.BTCUSDT");
assert!(!state.is_subscribed(&channel, &symbol));
}
#[rstest]
fn test_mark_unsubscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_unsubscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 0); assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
}
#[rstest]
fn test_confirm_unsubscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_unsubscribe("tickers.BTCUSDT");
state.confirm_unsubscribe("tickers.BTCUSDT");
assert!(state.is_empty());
}
#[rstest]
fn test_resubscribe_before_unsubscribe_ack() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
state.mark_unsubscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 0);
assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
state.mark_subscribe("tickers.BTCUSDT");
assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
state.confirm_unsubscribe("tickers.BTCUSDT");
assert!(state.pending_unsubscribe_topics().is_empty());
assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
assert!(state.pending_subscribe_topics().is_empty());
let all = state.all_topics();
assert_eq!(all.len(), 1);
assert!(all.contains(&"tickers.BTCUSDT".to_string()));
}
#[rstest]
fn test_stale_unsubscribe_ack_after_resubscribe_confirmed() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
state.mark_unsubscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 0);
assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
state.mark_subscribe("tickers.BTCUSDT");
assert!(state.pending_unsubscribe_topics().is_empty()); assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1); assert!(state.pending_subscribe_topics().is_empty());
state.confirm_unsubscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1); assert!(state.pending_unsubscribe_topics().is_empty());
assert!(state.pending_subscribe_topics().is_empty());
let all = state.all_topics();
assert_eq!(all.len(), 1);
assert!(all.contains(&"tickers.BTCUSDT".to_string()));
}
#[rstest]
fn test_mark_failure() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_failure("tickers.BTCUSDT");
assert_eq!(state.len(), 0);
assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
}
#[rstest]
fn test_all_topics_includes_confirmed_and_pending_subscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_subscribe("tickers.ETHUSDT");
let topics = state.all_topics();
assert_eq!(topics.len(), 2);
assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
}
#[rstest]
fn test_all_topics_excludes_pending_unsubscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_unsubscribe("tickers.BTCUSDT");
let topics = state.all_topics();
assert!(topics.is_empty());
}
#[rstest]
fn test_reference_counting_single_topic() {
let state = SubscriptionState::new('.');
assert!(state.add_reference("tickers.BTCUSDT"));
assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
assert!(!state.add_reference("tickers.BTCUSDT"));
assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
assert!(!state.remove_reference("tickers.BTCUSDT"));
assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
assert!(state.remove_reference("tickers.BTCUSDT"));
assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
}
#[rstest]
fn test_reference_counting_multiple_topics() {
let state = SubscriptionState::new('.');
assert!(state.add_reference("tickers.BTCUSDT"));
assert!(state.add_reference("tickers.ETHUSDT"));
assert!(!state.add_reference("tickers.BTCUSDT"));
assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
assert_eq!(state.get_reference_count("tickers.ETHUSDT"), 1);
assert!(!state.remove_reference("tickers.BTCUSDT"));
assert!(state.remove_reference("tickers.ETHUSDT"));
}
#[rstest]
fn test_topic_without_symbol() {
let state = SubscriptionState::new('.');
state.mark_subscribe("orderbook");
state.confirm_subscribe("orderbook");
assert_eq!(state.len(), 1);
assert_eq!(state.all_topics(), vec!["orderbook"]);
}
#[rstest]
fn test_different_delimiters() {
let state_dot = SubscriptionState::new('.');
state_dot.mark_subscribe("tickers.BTCUSDT");
assert_eq!(
state_dot.pending_subscribe_topics(),
vec!["tickers.BTCUSDT"]
);
let state_colon = SubscriptionState::new(':');
state_colon.mark_subscribe("orderBookL2:XBTUSD");
assert_eq!(
state_colon.pending_subscribe_topics(),
vec!["orderBookL2:XBTUSD"]
);
}
#[rstest]
fn test_clear() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.add_reference("tickers.BTCUSDT");
state.clear();
assert!(state.is_empty());
assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
}
#[rstest]
fn test_multiple_symbols_same_channel() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.mark_subscribe("tickers.ETHUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.ETHUSDT");
assert_eq!(state.len(), 2);
let topics = state.all_topics();
assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
}
#[rstest]
fn test_mixed_channel_and_symbol_subscriptions() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers");
state.confirm_subscribe("tickers");
assert_eq!(state.len(), 1);
assert_eq!(state.all_topics(), vec!["tickers"]);
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 2);
let topics = state.all_topics();
assert_eq!(topics.len(), 2);
assert!(topics.contains(&"tickers".to_string()));
assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
state.mark_subscribe("tickers.ETHUSDT");
state.confirm_subscribe("tickers.ETHUSDT");
assert_eq!(state.len(), 3);
let topics = state.all_topics();
assert_eq!(topics.len(), 3);
assert!(topics.contains(&"tickers".to_string()));
assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
state.mark_unsubscribe("tickers");
state.confirm_unsubscribe("tickers");
assert_eq!(state.len(), 2);
let topics = state.all_topics();
assert_eq!(topics.len(), 2);
assert!(!topics.contains(&"tickers".to_string()));
assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
}
#[rstest]
fn test_symbol_subscription_before_channel() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
state.mark_subscribe("tickers");
state.confirm_subscribe("tickers");
assert_eq!(state.len(), 2);
let topics = state.all_topics();
assert_eq!(topics.len(), 2);
assert!(topics.contains(&"tickers".to_string()));
assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
}
#[rstest]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_concurrent_subscribe_same_topic() {
let state = Arc::new(SubscriptionState::new('.'));
let mut handles = vec![];
for _ in 0..10 {
let state_clone = Arc::clone(&state);
let handle = tokio::spawn(async move {
state_clone.add_reference("tickers.BTCUSDT");
state_clone.mark_subscribe("tickers.BTCUSDT");
state_clone.confirm_subscribe("tickers.BTCUSDT");
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 10);
assert_eq!(state.len(), 1);
}
#[rstest]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_concurrent_subscribe_unsubscribe() {
let state = Arc::new(SubscriptionState::new('.'));
let mut handles = vec![];
for i in 0..20 {
let state_clone = Arc::clone(&state);
let handle = tokio::spawn(async move {
let topic = format!("tickers.SYMBOL{i}");
state_clone.add_reference(&topic);
state_clone.add_reference(&topic);
state_clone.mark_subscribe(&topic);
state_clone.confirm_subscribe(&topic);
state_clone.remove_reference(&topic);
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
for i in 0..20 {
let topic = format!("tickers.SYMBOL{i}");
assert_eq!(state.get_reference_count(&topic), 1);
}
assert_eq!(state.len(), 20);
assert!(!state.is_empty());
}
#[rstest]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_concurrent_reference_counting_same_topic() {
let state = Arc::new(SubscriptionState::new('.'));
let topic = "tickers.BTCUSDT";
let mut handles = vec![];
for _ in 0..10 {
let state_clone = Arc::clone(&state);
let handle = tokio::spawn(async move {
for _ in 0..10 {
state_clone.add_reference(topic);
}
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(state.get_reference_count(topic), 100);
for _ in 0..50 {
state.remove_reference(topic);
}
assert_eq!(state.get_reference_count(topic), 50);
}
#[rstest]
fn test_reconnection_scenario() {
let state = SubscriptionState::new('.');
state.add_reference("tickers.BTCUSDT");
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.add_reference("tickers.ETHUSDT");
state.mark_subscribe("tickers.ETHUSDT");
state.confirm_subscribe("tickers.ETHUSDT");
state.add_reference("orderbook");
state.mark_subscribe("orderbook");
state.confirm_subscribe("orderbook");
assert_eq!(state.len(), 3);
let topics_to_resubscribe = state.all_topics();
assert_eq!(topics_to_resubscribe.len(), 3);
assert!(topics_to_resubscribe.contains(&"tickers.BTCUSDT".to_string()));
assert!(topics_to_resubscribe.contains(&"tickers.ETHUSDT".to_string()));
assert!(topics_to_resubscribe.contains(&"orderbook".to_string()));
for topic in &topics_to_resubscribe {
state.mark_subscribe(topic);
}
for topic in &topics_to_resubscribe {
state.confirm_subscribe(topic);
}
assert_eq!(state.len(), 3);
assert_eq!(state.all_topics().len(), 3);
}
#[rstest]
fn test_state_machine_invalid_transitions() {
let state = SubscriptionState::new('.');
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
state.confirm_unsubscribe("tickers.ETHUSDT");
assert_eq!(state.len(), 1);
state.mark_subscribe("orderbook");
state.confirm_subscribe("orderbook");
state.confirm_subscribe("orderbook"); assert_eq!(state.len(), 2);
state.mark_unsubscribe("nonexistent");
state.confirm_unsubscribe("nonexistent");
assert_eq!(state.len(), 2); }
#[rstest]
fn test_mark_failure_moves_to_pending() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
assert!(state.pending_subscribe_topics().is_empty());
state.mark_failure("tickers.BTCUSDT");
assert_eq!(state.len(), 0);
assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
assert_eq!(state.all_topics(), vec!["tickers.BTCUSDT"]);
}
#[rstest]
fn test_pending_subscribe_excludes_pending_unsubscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
state.mark_unsubscribe("tickers.BTCUSDT");
assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
assert!(state.all_topics().is_empty());
assert_eq!(state.len(), 0);
}
#[rstest]
fn test_remove_reference_nonexistent_topic() {
let state = SubscriptionState::new('.');
let should_unsubscribe = state.remove_reference("nonexistent");
assert!(!should_unsubscribe);
assert_eq!(state.get_reference_count("nonexistent"), 0);
}
#[rstest]
fn test_edge_case_empty_channel_name() {
let state = SubscriptionState::new('.');
state.mark_subscribe("");
state.confirm_subscribe("");
assert_eq!(state.len(), 1);
assert_eq!(state.all_topics(), vec![""]);
}
#[rstest]
fn test_special_characters_in_topics() {
let state = SubscriptionState::new('.');
let special_topics = vec![
"channel.symbol-with-dash",
"channel.SYMBOL_WITH_UNDERSCORE",
"channel.symbol123",
"channel.symbol@special",
];
for topic in &special_topics {
state.mark_subscribe(topic);
state.confirm_subscribe(topic);
}
assert_eq!(state.len(), special_topics.len());
let all_topics = state.all_topics();
for topic in &special_topics {
assert!(
all_topics.contains(&(*topic).to_string()),
"Missing topic: {topic}"
);
}
}
#[rstest]
fn test_clear_resets_all_state() {
let state = SubscriptionState::new('.');
for i in 0..10 {
let topic = format!("channel{i}.SYMBOL");
state.add_reference(&topic);
state.add_reference(&topic); state.mark_subscribe(&topic);
state.confirm_subscribe(&topic);
}
assert_eq!(state.len(), 10);
assert!(!state.is_empty());
state.clear();
assert_eq!(state.len(), 0);
assert!(state.is_empty());
assert!(state.all_topics().is_empty());
assert!(state.pending_subscribe_topics().is_empty());
assert!(state.pending_unsubscribe_topics().is_empty());
for i in 0..10 {
let topic = format!("channel{i}.SYMBOL");
assert_eq!(state.get_reference_count(&topic), 0);
}
}
#[rstest]
fn test_different_delimiter_does_not_affect_storage() {
let state_dot = SubscriptionState::new('.');
let state_colon = SubscriptionState::new(':');
state_dot.mark_subscribe("channel.SYMBOL");
state_colon.mark_subscribe("channel:SYMBOL");
assert_eq!(state_dot.pending_subscribe_topics(), vec!["channel.SYMBOL"]);
assert_eq!(
state_colon.pending_subscribe_topics(),
vec!["channel:SYMBOL"]
);
}
#[rstest]
fn test_unsubscribe_before_subscribe_confirmed() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
state.mark_unsubscribe("tickers.BTCUSDT");
assert!(state.pending_subscribe_topics().is_empty());
assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
state.confirm_unsubscribe("tickers.BTCUSDT");
assert!(state.is_empty());
assert!(state.all_topics().is_empty());
assert_eq!(state.len(), 0);
}
#[rstest]
fn test_late_subscribe_confirmation_after_unsubscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.mark_unsubscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 0);
assert!(state.pending_subscribe_topics().is_empty());
state.confirm_unsubscribe("tickers.BTCUSDT");
assert!(state.is_empty());
assert!(state.all_topics().is_empty());
}
#[rstest]
fn test_unsubscribe_clears_all_states() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
state.mark_unsubscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 0);
assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
state.confirm_subscribe("tickers.BTCUSDT");
state.confirm_unsubscribe("tickers.BTCUSDT");
assert!(state.is_empty());
assert_eq!(state.len(), 0);
assert!(state.pending_subscribe_topics().is_empty());
assert!(state.pending_unsubscribe_topics().is_empty());
assert!(state.all_topics().is_empty());
}
#[rstest]
fn test_mark_failure_respects_pending_unsubscribe() {
let state = SubscriptionState::new('.');
state.mark_subscribe("tickers.BTCUSDT");
state.confirm_subscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 1);
state.mark_unsubscribe("tickers.BTCUSDT");
assert_eq!(state.len(), 0);
assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
state.mark_failure("tickers.BTCUSDT");
assert!(state.pending_subscribe_topics().is_empty());
assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
assert!(state.all_topics().is_empty());
state.confirm_unsubscribe("tickers.BTCUSDT");
assert!(state.is_empty());
}
#[rstest]
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_concurrent_stress_mixed_operations() {
let state = Arc::new(SubscriptionState::new('.'));
let mut handles = vec![];
for i in 0..50 {
let state_clone = Arc::clone(&state);
let handle = tokio::spawn(async move {
let topic1 = format!("channel.SYMBOL{i}");
let topic2 = format!("channel.SYMBOL{}", i + 100);
state_clone.add_reference(&topic1);
state_clone.add_reference(&topic2);
state_clone.mark_subscribe(&topic1);
state_clone.confirm_subscribe(&topic1);
state_clone.mark_subscribe(&topic2);
if i % 3 == 0 {
state_clone.mark_unsubscribe(&topic1);
state_clone.confirm_unsubscribe(&topic1);
}
state_clone.add_reference(&topic2);
state_clone.remove_reference(&topic2);
state_clone.confirm_subscribe(&topic2);
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
let all = state.all_topics();
let confirmed_count = state.len();
assert!(confirmed_count > 50); assert!(confirmed_count <= 100); assert_eq!(
all.len(),
confirmed_count + state.pending_subscribe_topics().len()
);
}
#[rstest]
fn test_edge_case_malformed_topics() {
let state = SubscriptionState::new('.');
state.mark_subscribe("channel.symbol.extra");
state.confirm_subscribe("channel.symbol.extra");
let topics = state.all_topics();
assert!(topics.contains(&"channel.symbol.extra".to_string()));
state.mark_subscribe(".channel");
state.confirm_subscribe(".channel");
assert_eq!(state.len(), 2);
state.mark_subscribe("channel.");
state.confirm_subscribe("channel.");
assert_eq!(state.len(), 3);
state.mark_subscribe("tickers");
state.confirm_subscribe("tickers");
assert_eq!(state.len(), 4);
let all = state.all_topics();
assert_eq!(all.len(), 4);
assert!(all.contains(&"channel.symbol.extra".to_string()));
assert!(all.contains(&".channel".to_string()));
assert!(all.contains(&"channel".to_string())); assert!(all.contains(&"tickers".to_string()));
}
#[rstest]
fn test_reference_count_underflow_safety() {
let state = SubscriptionState::new('.');
assert!(!state.remove_reference("never.added"));
assert_eq!(state.get_reference_count("never.added"), 0);
state.add_reference("once.added");
assert_eq!(state.get_reference_count("once.added"), 1);
assert!(state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
assert!(!state.remove_reference("once.added")); assert!(!state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
assert!(state.add_reference("once.added"));
assert_eq!(state.get_reference_count("once.added"), 1);
}
#[rstest]
fn test_reconnection_with_partial_state() {
let state = SubscriptionState::new('.');
state.mark_subscribe("confirmed.BTCUSDT");
state.confirm_subscribe("confirmed.BTCUSDT");
state.mark_subscribe("pending.ETHUSDT");
state.mark_subscribe("cancelled.XRPUSDT");
state.confirm_subscribe("cancelled.XRPUSDT");
state.mark_unsubscribe("cancelled.XRPUSDT");
assert_eq!(state.len(), 1); let all = state.all_topics();
assert_eq!(all.len(), 2); assert!(all.contains(&"confirmed.BTCUSDT".to_string()));
assert!(all.contains(&"pending.ETHUSDT".to_string()));
assert!(!all.contains(&"cancelled.XRPUSDT".to_string()));
let topics_to_resubscribe = state.all_topics();
state.confirmed().clear();
for topic in &topics_to_resubscribe {
state.mark_subscribe(topic);
}
for topic in &topics_to_resubscribe {
state.confirm_subscribe(topic);
}
assert_eq!(state.len(), 2); let final_topics = state.all_topics();
assert_eq!(final_topics.len(), 2);
assert!(final_topics.contains(&"confirmed.BTCUSDT".to_string()));
assert!(final_topics.contains(&"pending.ETHUSDT".to_string()));
assert!(!final_topics.contains(&"cancelled.XRPUSDT".to_string()));
}
fn check_invariants(state: &SubscriptionState, label: &str) {
let confirmed_topics: AHashSet<String> = state
.topics_from_map(&state.confirmed)
.into_iter()
.collect();
let pending_sub_topics: AHashSet<String> =
state.pending_subscribe_topics().into_iter().collect();
let pending_unsub_topics: AHashSet<String> =
state.pending_unsubscribe_topics().into_iter().collect();
let confirmed_and_pending_sub: Vec<_> =
confirmed_topics.intersection(&pending_sub_topics).collect();
assert!(
confirmed_and_pending_sub.is_empty(),
"{label}: Topic in both confirmed and pending_subscribe: {confirmed_and_pending_sub:?}"
);
let confirmed_and_pending_unsub: Vec<_> = confirmed_topics
.intersection(&pending_unsub_topics)
.collect();
assert!(
confirmed_and_pending_unsub.is_empty(),
"{label}: Topic in both confirmed and pending_unsubscribe: {confirmed_and_pending_unsub:?}"
);
let pending_sub_and_unsub: Vec<_> = pending_sub_topics
.intersection(&pending_unsub_topics)
.collect();
assert!(
pending_sub_and_unsub.is_empty(),
"{label}: Topic in both pending_subscribe and pending_unsubscribe: {pending_sub_and_unsub:?}"
);
let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
let expected_all: AHashSet<String> = confirmed_topics
.union(&pending_sub_topics)
.cloned()
.collect();
assert_eq!(
all_topics, expected_all,
"{label}: all_topics() doesn't match confirmed ∪ pending_subscribe"
);
for topic in &pending_unsub_topics {
assert!(
!all_topics.contains(topic),
"{label}: pending_unsubscribe topic {topic} incorrectly in all_topics()"
);
}
let expected_len: usize = state
.confirmed
.iter()
.map(|entry| entry.value().len())
.sum();
assert_eq!(
state.len(),
expected_len,
"{label}: len() mismatch. Expected {expected_len}, was {}",
state.len()
);
let should_be_empty = state.confirmed.is_empty()
&& pending_sub_topics.is_empty()
&& pending_unsub_topics.is_empty();
assert_eq!(
state.is_empty(),
should_be_empty,
"{label}: is_empty() inconsistent. Maps empty: {should_be_empty}, is_empty(): {}",
state.is_empty()
);
for entry in state.reference_counts.iter() {
let count = entry.value().get();
assert!(
count > 0,
"{label}: Reference count should be NonZeroUsize (> 0), was {count} for {:?}",
entry.key()
);
}
}
fn check_topic_exclusivity(state: &SubscriptionState, topic: &str, label: &str) {
let (channel, symbol) = split_topic(topic, state.delimiter);
let in_confirmed = is_tracked(&state.confirmed, channel, symbol);
let in_pending_sub = is_tracked(&state.pending_subscribe, channel, symbol);
let in_pending_unsub = is_tracked(&state.pending_unsubscribe, channel, symbol);
let count = [in_confirmed, in_pending_sub, in_pending_unsub]
.iter()
.filter(|&&x| x)
.count();
assert!(
count <= 1,
"{label}: Topic {topic} in {count} states (should be 0 or 1). \
confirmed: {in_confirmed}, pending_sub: {in_pending_sub}, pending_unsub: {in_pending_unsub}"
);
}
/// Property-based tests that drive [`SubscriptionState`] with random operation
/// sequences and check its invariants after every step.
#[cfg(test)]
mod property_tests {
    use proptest::prelude::*;

    use super::*;

    /// One state-machine operation applied to a [`SubscriptionState`] under test.
    #[derive(Debug, Clone)]
    enum Operation {
        MarkSubscribe(String),
        ConfirmSubscribe(String),
        MarkUnsubscribe(String),
        ConfirmUnsubscribe(String),
        MarkFailure(String),
        AddReference(String),
        RemoveReference(String),
        Clear,
    }

    /// Generates topics from a small fixed universe (5 channels x 10 symbols, plus the
    /// 5 bare channel-level topics) so random operations frequently hit the same topic.
    ///
    /// Ranges are sampled directly instead of reducing `any::<u8>()` modulo `n`. This
    /// keeps the distribution uniform and lets shrinking converge on the lowest indices.
    fn topic_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            (0u8..5, 0u8..10).prop_map(|(ch, sym)| format!("channel{ch}.SYMBOL{sym}")),
            (0u8..5).prop_map(|ch| format!("channel{ch}")),
        ]
    }

    /// Generates any [`Operation`], with every variant equally likely.
    ///
    /// Each topic-carrying arm maps its own topic draw. This needs no per-arm clone,
    /// draws no unused topic for `Clear`, and shrinks better than `prop_flat_map`.
    fn operation_strategy() -> impl Strategy<Value = Operation> {
        prop_oneof![
            topic_strategy().prop_map(Operation::MarkSubscribe),
            topic_strategy().prop_map(Operation::ConfirmSubscribe),
            topic_strategy().prop_map(Operation::MarkUnsubscribe),
            topic_strategy().prop_map(Operation::ConfirmUnsubscribe),
            topic_strategy().prop_map(Operation::MarkFailure),
            topic_strategy().prop_map(Operation::AddReference),
            topic_strategy().prop_map(Operation::RemoveReference),
            Just(Operation::Clear),
        ]
    }

    /// Generates only reference-counting operations (`AddReference` / `RemoveReference`).
    fn reference_operation_strategy() -> impl Strategy<Value = Operation> {
        prop_oneof![
            topic_strategy().prop_map(Operation::AddReference),
            topic_strategy().prop_map(Operation::RemoveReference),
        ]
    }

    /// Applies `op` to `state`, discarding any value the underlying method returns.
    fn apply_operation(state: &SubscriptionState, op: &Operation) {
        match op {
            Operation::MarkSubscribe(topic) => state.mark_subscribe(topic),
            Operation::ConfirmSubscribe(topic) => state.confirm_subscribe(topic),
            Operation::MarkUnsubscribe(topic) => state.mark_unsubscribe(topic),
            Operation::ConfirmUnsubscribe(topic) => state.confirm_unsubscribe(topic),
            Operation::MarkFailure(topic) => state.mark_failure(topic),
            Operation::AddReference(topic) => {
                state.add_reference(topic);
            }
            Operation::RemoveReference(topic) => {
                state.remove_reference(topic);
            }
            Operation::Clear => state.clear(),
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(500))]

        /// The full invariant suite holds after every step of an arbitrary sequence.
        #[rstest]
        fn prop_invariants_hold_after_operations(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');
            for (i, op) in operations.iter().enumerate() {
                apply_operation(&state, op);
                check_invariants(&state, &format!("After op {i}: {op:?}"));
            }
            check_invariants(&state, "Final state");
        }

        /// Reference-count operations alone never leave the state inconsistent.
        #[rstest]
        fn prop_reference_counting_consistency(
            ops in prop::collection::vec(reference_operation_strategy(), 1..100)
        ) {
            let state = SubscriptionState::new('.');
            for (i, op) in ops.iter().enumerate() {
                apply_operation(&state, op);
                // `NonZeroUsize` already rules out zero at the type level; this stays
                // only as a cheap sanity check.
                for entry in state.reference_counts.iter() {
                    prop_assert!(entry.value().get() > 0);
                }
                // Also check the subscription-map invariants, so this test checks
                // more than the type guarantee above.
                check_invariants(&state, &format!("After ref op {i}: {op:?}"));
            }
        }

        /// `all_topics()` is exactly confirmed ∪ pending_subscribe and never
        /// contains a topic that is pending unsubscribe.
        #[rstest]
        fn prop_all_topics_is_union(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');
            for op in &operations {
                apply_operation(&state, op);

                let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
                let confirmed: AHashSet<String> =
                    state.topics_from_map(&state.confirmed).into_iter().collect();
                let pending_sub: AHashSet<String> =
                    state.pending_subscribe_topics().into_iter().collect();
                let expected: AHashSet<String> = confirmed.union(&pending_sub).cloned().collect();
                prop_assert_eq!(&all_topics, &expected);

                let pending_unsub: AHashSet<String> =
                    state.pending_unsubscribe_topics().into_iter().collect();
                for topic in pending_unsub {
                    prop_assert!(
                        !all_topics.contains(&topic),
                        "pending_unsubscribe topic {} reported by all_topics()",
                        topic
                    );
                }
            }
        }

        /// `clear()` empties every map and every derived view, whatever came before.
        #[rstest]
        fn prop_clear_resets_completely(
            operations in prop::collection::vec(operation_strategy(), 1..30)
        ) {
            let state = SubscriptionState::new('.');
            for op in &operations {
                apply_operation(&state, op);
            }
            state.clear();
            prop_assert!(state.is_empty());
            prop_assert_eq!(state.len(), 0);
            prop_assert!(state.all_topics().is_empty());
            prop_assert!(state.pending_subscribe_topics().is_empty());
            prop_assert!(state.pending_unsubscribe_topics().is_empty());
            prop_assert!(state.confirmed.is_empty());
            prop_assert!(state.pending_subscribe.is_empty());
            prop_assert!(state.pending_unsubscribe.is_empty());
            prop_assert!(state.reference_counts.is_empty());
        }

        /// A randomly chosen probe topic is in at most one map after every step.
        #[rstest]
        fn prop_topic_mutual_exclusivity(
            operations in prop::collection::vec(operation_strategy(), 1..50),
            topic in topic_strategy()
        ) {
            let state = SubscriptionState::new('.');
            for (i, op) in operations.iter().enumerate() {
                apply_operation(&state, op);
                check_topic_exclusivity(&state, &topic, &format!("After op {i}: {op:?}"));
            }
        }
    }
}
/// Applies every ordered pair of state-machine transitions to a fresh state for a
/// single topic and checks the invariants afterwards.
#[rstest]
fn test_exhaustive_two_step_transitions() {
    let transitions = [
        "mark_subscribe",
        "confirm_subscribe",
        "mark_unsubscribe",
        "confirm_unsubscribe",
        "mark_failure",
    ];
    let topic = "test.TOPIC";

    // Cartesian product in row-major order: (t0, t0), (t0, t1), ..., (t4, t4).
    let ordered_pairs = transitions.iter().flat_map(|&first| {
        transitions
            .iter()
            .map(move |&second| (first, second))
    });

    for (first, second) in ordered_pairs {
        let state = SubscriptionState::new('.');
        apply_op(&state, first, topic);
        apply_op(&state, second, topic);

        let label = format!("{first} → {second}");
        check_invariants(&state, &label);
        check_topic_exclusivity(&state, topic, &label);
    }
}
/// Calls the `SubscriptionState` method named by `op` with `topic`.
///
/// # Panics
///
/// Panics if `op` is not one of `mark_subscribe`, `confirm_subscribe`,
/// `mark_unsubscribe`, `confirm_unsubscribe` or `mark_failure`.
fn apply_op(state: &SubscriptionState, op: &str, topic: &str) {
    match op {
        "mark_subscribe" => state.mark_subscribe(topic),
        "confirm_subscribe" => state.confirm_subscribe(topic),
        "mark_unsubscribe" => state.mark_unsubscribe(topic),
        "confirm_unsubscribe" => state.confirm_unsubscribe(topic),
        "mark_failure" => state.mark_failure(topic),
        unknown => panic!("Unknown operation: {unknown}"),
    }
}
/// Races subscribe/unsubscribe/resubscribe sequences from many tasks over a small
/// set of shared topics, then checks the invariants on the resulting state.
#[rstest]
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_stress_rapid_resubscribe_pattern() {
    let state = Arc::new(SubscriptionState::new('.'));

    // 100 tasks over 10 topics: each topic is driven by 10 tasks concurrently.
    let handles: Vec<_> = (0..100)
        .map(|task_idx| {
            let shared = Arc::clone(&state);
            tokio::spawn(async move {
                let topic = format!("rapid.SYMBOL{}", task_idx % 10);
                shared.mark_subscribe(&topic);
                shared.confirm_subscribe(&topic);
                // Request an unsubscribe, then resubscribe before the unsubscribe
                // confirmation arrives.
                shared.mark_unsubscribe(&topic);
                shared.mark_subscribe(&topic);
                shared.confirm_unsubscribe(&topic);
                shared.confirm_subscribe(&topic);
            })
        })
        .collect();

    for handle in handles {
        handle.await.unwrap();
    }

    check_invariants(&state, "After rapid resubscribe stress test");
}
/// Drives each of 30 distinct topics through repeated failure/re-confirm cycles
/// from concurrent tasks and checks that all of them end up confirmed.
#[rstest]
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn test_stress_failure_recovery_loop() {
    const TASKS: usize = 30;
    const FAILURE_CYCLES: usize = 5;

    let state = Arc::new(SubscriptionState::new('.'));

    let handles: Vec<_> = (0..TASKS)
        .map(|task_idx| {
            let shared = Arc::clone(&state);
            tokio::spawn(async move {
                let topic = format!("failure.SYMBOL{task_idx}");
                shared.mark_subscribe(&topic);
                shared.confirm_subscribe(&topic);
                // Repeatedly mark a failure, then re-confirm the subscription.
                for _ in 0..FAILURE_CYCLES {
                    shared.mark_failure(&topic);
                    shared.confirm_subscribe(&topic);
                }
            })
        })
        .collect();

    for handle in handles {
        handle.await.unwrap();
    }

    check_invariants(&state, "After failure recovery loops");
    // Each task owns a distinct topic and finishes on a confirm.
    assert_eq!(state.len(), TASKS);
}
}