use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
use tokio::sync::RwLock;
/// Kind of data carried by a WebSocket stream.
///
/// Market-data variants can be inferred from a stream name with
/// [`SubscriptionType::from_stream`]; `Balance`, `Orders`, `Positions` and
/// `MyTrades` are never produced by that parser.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubscriptionType {
    /// Streams containing `@ticker`.
    Ticker,
    /// Streams containing `@depth`.
    OrderBook,
    /// Streams containing `@trade` or `@aggTrade`.
    Trades,
    /// Streams containing `@kline_`; holds the text that follows the marker.
    Kline(String),
    Balance,
    Orders,
    Positions,
    MyTrades,
    /// Streams containing `@markPrice`.
    MarkPrice,
    /// Streams containing `@bookTicker`.
    BookTicker,
}

impl SubscriptionType {
    /// Classifies a stream name by the first recognised marker it contains.
    ///
    /// Markers are tested as plain (case-sensitive) substrings in this fixed
    /// order, and the first hit wins: `@ticker`, `@depth`, `@trade`/`@aggTrade`,
    /// `@kline_`, `@markPrice`, `@bookTicker`.
    ///
    /// For klines the interval is everything after `@kline_`; a name that
    /// contains `@kline_` more than once yields `None`. Names with no
    /// recognised marker also yield `None`.
    pub fn from_stream(stream: &str) -> Option<Self> {
        let has = |marker: &str| stream.contains(marker);

        if has("@ticker") {
            return Some(Self::Ticker);
        }
        if has("@depth") {
            return Some(Self::OrderBook);
        }
        if has("@trade") || has("@aggTrade") {
            return Some(Self::Trades);
        }
        if has("@kline_") {
            // Exactly one segment must follow the prefix, i.e. the marker
            // occurs exactly once.
            let mut segments = stream.split("@kline_").skip(1);
            return match (segments.next(), segments.next()) {
                (Some(interval), None) => Some(Self::Kline(interval.to_owned())),
                _ => None,
            };
        }
        if has("@markPrice") {
            return Some(Self::MarkPrice);
        }
        if has("@bookTicker") {
            return Some(Self::BookTicker);
        }
        None
    }
}
#[derive(Clone)]
pub struct Subscription {
pub stream: String,
pub symbol: String,
pub sub_type: SubscriptionType,
pub subscribed_at: Instant,
senders: Arc<std::sync::Mutex<Vec<tokio::sync::mpsc::Sender<Value>>>>,
ref_count: Arc<AtomicUsize>,
}
impl Subscription {
pub fn new(
stream: String,
symbol: String,
sub_type: SubscriptionType,
sender: tokio::sync::mpsc::Sender<Value>,
) -> Self {
Self {
stream,
symbol,
sub_type,
subscribed_at: Instant::now(),
senders: Arc::new(std::sync::Mutex::new(vec![sender])),
ref_count: Arc::new(AtomicUsize::new(1)),
}
}
pub fn add_sender(&self, sender: tokio::sync::mpsc::Sender<Value>) {
if let Ok(mut senders) = self.senders.lock() {
senders.push(sender);
}
}
pub fn send(&self, message: Value) -> bool {
if let Ok(mut senders) = self.senders.lock() {
let mut any_sent = false;
senders.retain(|sender| {
match sender.try_send(message.clone()) {
Ok(()) => {
any_sent = true;
true }
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
true
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
false
}
}
});
any_sent || !senders.is_empty()
} else {
false
}
}
pub fn add_ref(&self) -> usize {
self.ref_count.fetch_add(1, Ordering::SeqCst) + 1
}
pub fn remove_ref(&self) -> usize {
let prev = self.ref_count.fetch_sub(1, Ordering::SeqCst);
prev.saturating_sub(1)
}
pub fn ref_count(&self) -> usize {
self.ref_count.load(Ordering::SeqCst)
}
}
pub struct SubscriptionManager {
subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
symbol_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
active_count: Arc<std::sync::atomic::AtomicUsize>,
}
impl SubscriptionManager {
pub fn new() -> Self {
Self {
subscriptions: Arc::new(RwLock::new(HashMap::new())),
symbol_index: Arc::new(RwLock::new(HashMap::new())),
active_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
}
}
pub async fn add_subscription(
&self,
stream: String,
symbol: String,
sub_type: SubscriptionType,
sender: tokio::sync::mpsc::Sender<Value>,
) -> ccxt_core::error::Result<bool> {
let mut subs = self.subscriptions.write().await;
if let Some(existing) = subs.get(&stream) {
existing.add_sender(sender);
existing.add_ref();
return Ok(false);
}
let subscription = Subscription::new(stream.clone(), symbol.clone(), sub_type, sender);
subs.insert(stream.clone(), subscription);
let mut index = self.symbol_index.write().await;
index.entry(symbol).or_insert_with(Vec::new).push(stream);
self.active_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Ok(true)
}
pub async fn remove_subscription(&self, stream: &str) -> ccxt_core::error::Result<bool> {
let mut subs = self.subscriptions.write().await;
if let Some(subscription) = subs.get(stream) {
let remaining = subscription.remove_ref();
if remaining > 0 {
return Ok(false);
}
let Some(subscription) = subs.remove(stream) else {
return Ok(false);
};
let mut index = self.symbol_index.write().await;
if let Some(streams) = index.get_mut(&subscription.symbol) {
streams.retain(|s| s != stream);
if streams.is_empty() {
index.remove(&subscription.symbol);
}
}
self.active_count
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
Ok(true)
} else {
Ok(false)
}
}
pub async fn get_subscription(&self, stream: &str) -> Option<Subscription> {
let subs = self.subscriptions.read().await;
subs.get(stream).cloned()
}
pub async fn has_subscription(&self, stream: &str) -> bool {
let subs = self.subscriptions.read().await;
subs.contains_key(stream)
}
pub async fn get_all_subscriptions(&self) -> Vec<Subscription> {
let subs = self.subscriptions.read().await;
subs.values().cloned().collect()
}
pub fn get_all_subscriptions_sync(&self) -> Vec<Subscription> {
if let Ok(subs) = self.subscriptions.try_read() {
subs.values().cloned().collect()
} else {
Vec::new()
}
}
pub async fn get_subscriptions_by_symbol(&self, symbol: &str) -> Vec<Subscription> {
let index = self.symbol_index.read().await;
let subs = self.subscriptions.read().await;
if let Some(streams) = index.get(symbol) {
streams
.iter()
.filter_map(|stream| subs.get(stream).cloned())
.collect()
} else {
Vec::new()
}
}
pub fn active_count(&self) -> usize {
self.active_count.load(std::sync::atomic::Ordering::SeqCst)
}
pub async fn clear(&self) {
let mut subs = self.subscriptions.write().await;
let mut index = self.symbol_index.write().await;
subs.clear();
index.clear();
self.active_count
.store(0, std::sync::atomic::Ordering::SeqCst);
}
pub async fn send_to_stream(&self, stream: &str, message: Value) -> bool {
let subs = self.subscriptions.read().await;
if let Some(subscription) = subs.get(stream) {
if subscription.send(message) {
return true;
}
} else {
return false;
}
drop(subs);
let _ = self.remove_subscription(stream).await;
false
}
pub async fn send_to_symbol(&self, symbol: &str, message: &Value) -> usize {
let index = self.symbol_index.read().await;
let subs = self.subscriptions.read().await;
let mut sent_count = 0;
let mut streams_to_remove = Vec::new();
if let Some(streams) = index.get(symbol) {
for stream in streams {
if let Some(subscription) = subs.get(stream) {
if subscription.send(message.clone()) {
sent_count += 1;
} else {
streams_to_remove.push(stream.clone());
}
}
}
}
drop(subs);
drop(index);
for stream in streams_to_remove {
let _ = self.remove_subscription(&stream).await;
}
sent_count
}
pub async fn get_active_streams(&self) -> Vec<String> {
let subs = self.subscriptions.read().await;
subs.keys().cloned().collect()
}
}
impl Default for SubscriptionManager {
fn default() -> Self {
Self::new()
}
}
/// Settings for reconnecting with exponential backoff.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Whether reconnection is attempted at all.
    pub enabled: bool,
    /// Delay returned for attempt 0, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound applied to every computed delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by for each further attempt.
    pub backoff_multiplier: f64,
    /// Maximum number of attempts; `0` means unlimited.
    pub max_attempts: usize,
}

impl Default for ReconnectConfig {
    /// Enabled, 1 s initial delay doubling per attempt, capped at 30 s, unlimited attempts.
    fn default() -> Self {
        Self {
            enabled: true,
            initial_delay_ms: 1000,
            max_delay_ms: 30000,
            backoff_multiplier: 2.0,
            max_attempts: 0,
        }
    }
}

impl ReconnectConfig {
    /// Returns the delay in milliseconds for the zero-based `attempt`.
    ///
    /// The delay is `initial_delay_ms * backoff_multiplier^attempt`, capped at
    /// `max_delay_ms`.
    pub fn calculate_delay(&self, attempt: usize) -> u64 {
        // A zero base delay stays zero for every attempt. Without this check,
        // overflowed growth gives 0 * inf = NaN, and `NaN.min(max)` would
        // return the maximum delay instead.
        if self.initial_delay_ms == 0 {
            return 0;
        }
        // Clamp rather than `as i32`: the plain cast wraps large attempt counts
        // into negative exponents, collapsing the delay to ~0.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let delay = (self.initial_delay_ms as f64) * self.backoff_multiplier.powi(exponent);
        // Float-to-int `as` casts saturate (negatives and NaN become 0), so this cannot wrap.
        delay.min(self.max_delay_ms as f64) as u64
    }

    /// Returns `true` if another attempt may be made.
    ///
    /// Reconnection must be enabled, and either attempts are unlimited
    /// (`max_attempts == 0`) or `attempt < max_attempts`.
    pub fn should_retry(&self, attempt: usize) -> bool {
        self.enabled && (self.max_attempts == 0 || attempt < self.max_attempts)
    }
}
/// Owns one reference to `stream` inside a [`SubscriptionManager`].
///
/// Call [`SubscriptionHandle::release`] to drop the reference and await the
/// cleanup. Otherwise, dropping the handle attempts the same cleanup
/// asynchronously.
pub struct SubscriptionHandle {
    stream: String,
    subscription_manager: Arc<SubscriptionManager>,
    message_router: Option<Arc<crate::binance::ws::handlers::MessageRouter>>,
    /// Set by `release` so the `Drop` impl does not release a second time.
    released: bool,
}

impl SubscriptionHandle {
    /// Creates a handle for `stream`.
    ///
    /// When `message_router` is given, it is told to unsubscribe once the last
    /// reference to the stream is released.
    pub fn new(
        stream: String,
        subscription_manager: Arc<SubscriptionManager>,
        message_router: Option<Arc<crate::binance::ws::handlers::MessageRouter>>,
    ) -> Self {
        let released = false;
        Self {
            released,
            message_router,
            subscription_manager,
            stream,
        }
    }

    /// Name of the stream this handle refers to.
    pub fn stream(&self) -> &str {
        self.stream.as_str()
    }

    /// Releases the reference now and waits for the cleanup to finish.
    pub async fn release(mut self) -> ccxt_core::error::Result<()> {
        // Flag first: `self` is dropped at the end of this call, and the Drop
        // impl skips its own cleanup when `released` is set.
        self.released = true;
        self.do_release().await
    }

    /// Drops one reference in the manager. If that was the last one, asks the
    /// router (if any) to unsubscribe.
    async fn do_release(&self) -> ccxt_core::error::Result<()> {
        let last_reference = self
            .subscription_manager
            .remove_subscription(&self.stream)
            .await?;
        if !last_reference {
            return Ok(());
        }
        if let Some(router) = self.message_router.as_ref() {
            let streams = vec![self.stream.clone()];
            router.unsubscribe(streams).await?;
        }
        Ok(())
    }
}
impl Drop for SubscriptionHandle {
fn drop(&mut self) {
if self.released {
return;
}
let stream = self.stream.clone();
let subscription_manager = self.subscription_manager.clone();
let message_router = self.message_router.clone();
tokio::spawn(async move {
let fully_removed = subscription_manager
.remove_subscription(&stream)
.await
.unwrap_or(false);
if fully_removed {
if let Some(router) = &message_router {
let _ = router.unsubscribe(vec![stream]).await;
}
}
});
}
}