Skip to main content

ccxt_exchanges/binance/ws/
subscriptions.rs

1//! Subscription management for Binance WebSocket streams
2
3use serde_json::Value;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::time::Instant;
8use tokio::sync::RwLock;
9
/// Kind of data carried by a Binance WebSocket stream.
///
/// Market-data variants can be recovered from a stream name with
/// [`SubscriptionType::from_stream`]. The account-level variants (`Balance`, `Orders`,
/// `Positions`, `MyTrades`) are never produced by that function.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubscriptionType {
    /// 24-hour ticker stream
    Ticker,
    /// Order book depth stream
    OrderBook,
    /// Real-time trade stream (plain or aggregated trades)
    Trades,
    /// Kline (candlestick) stream; holds the interval text that follows `@kline_` (e.g. "1m")
    Kline(String),
    /// Account balance stream
    Balance,
    /// Order update stream
    Orders,
    /// Position update stream
    Positions,
    /// Personal trade execution stream
    MyTrades,
    /// Mark price stream
    MarkPrice,
    /// Book ticker (best bid/ask) stream
    BookTicker,
}

impl SubscriptionType {
    /// Guesses the subscription type from a stream name such as `"btcusdt@kline_1m"`.
    ///
    /// Matching is by substring and the first rule that matches wins, checked in this order:
    /// `@ticker`, `@depth`, `@trade`/`@aggTrade`, `@kline_`, `@markPrice`, `@bookTicker`.
    /// A name containing `@kline_` yields `Kline(interval)` only when the marker occurs
    /// exactly once; otherwise it yields `None` without trying the remaining rules.
    /// Unrecognized names also yield `None`.
    pub fn from_stream(stream: &str) -> Option<Self> {
        let has = |needle: &str| stream.contains(needle);

        if has("@ticker") {
            return Some(Self::Ticker);
        }
        if has("@depth") {
            return Some(Self::OrderBook);
        }
        if has("@trade") || has("@aggTrade") {
            return Some(Self::Trades);
        }
        if has("@kline_") {
            // Skip the symbol prefix; exactly one remaining piece means a single marker.
            let mut pieces = stream.split("@kline_").skip(1);
            return match (pieces.next(), pieces.next()) {
                (Some(interval), None) => Some(Self::Kline(interval.to_owned())),
                _ => None,
            };
        }
        if has("@markPrice") {
            return Some(Self::MarkPrice);
        }
        has("@bookTicker").then_some(Self::BookTicker)
    }
}
60
61/// Subscription metadata
62#[derive(Clone)]
63pub struct Subscription {
64    /// Stream name (e.g. "btcusdt@ticker")
65    pub stream: String,
66    /// Normalized trading symbol (e.g. "BTCUSDT")
67    pub symbol: String,
68    /// Subscription type descriptor
69    pub sub_type: SubscriptionType,
70    /// Timestamp when the subscription was created
71    pub subscribed_at: Instant,
72    /// Senders for forwarding WebSocket messages to consumers (supports multiple subscribers)
73    senders: Arc<std::sync::Mutex<Vec<tokio::sync::mpsc::Sender<Value>>>>,
74    /// Reference count for this subscription (how many handles are active)
75    ref_count: Arc<AtomicUsize>,
76}
77
78impl Subscription {
79    /// Creates a new subscription with the provided parameters
80    pub fn new(
81        stream: String,
82        symbol: String,
83        sub_type: SubscriptionType,
84        sender: tokio::sync::mpsc::Sender<Value>,
85    ) -> Self {
86        Self {
87            stream,
88            symbol,
89            sub_type,
90            subscribed_at: Instant::now(),
91            senders: Arc::new(std::sync::Mutex::new(vec![sender])),
92            ref_count: Arc::new(AtomicUsize::new(1)),
93        }
94    }
95
96    /// Adds a new sender to this subscription for multi-subscriber support.
97    pub fn add_sender(&self, sender: tokio::sync::mpsc::Sender<Value>) {
98        if let Ok(mut senders) = self.senders.lock() {
99            senders.push(sender);
100        }
101    }
102
103    /// Sends a message to all subscribers, removing closed senders.
104    ///
105    /// Returns `true` if at least one sender successfully received the message.
106    pub fn send(&self, message: Value) -> bool {
107        if let Ok(mut senders) = self.senders.lock() {
108            let mut any_sent = false;
109            senders.retain(|sender| {
110                match sender.try_send(message.clone()) {
111                    Ok(()) => {
112                        any_sent = true;
113                        true // keep this sender
114                    }
115                    Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
116                        // Channel full - keep sender but count as not sent for backpressure
117                        true
118                    }
119                    Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
120                        // Receiver dropped - remove this sender
121                        false
122                    }
123                }
124            });
125            any_sent || !senders.is_empty()
126        } else {
127            false
128        }
129    }
130
131    /// Increments the reference count and returns the new value
132    pub fn add_ref(&self) -> usize {
133        self.ref_count.fetch_add(1, Ordering::SeqCst) + 1
134    }
135
136    /// Decrements the reference count and returns the new value
137    pub fn remove_ref(&self) -> usize {
138        let prev = self.ref_count.fetch_sub(1, Ordering::SeqCst);
139        prev.saturating_sub(1)
140    }
141
142    /// Returns the current reference count
143    pub fn ref_count(&self) -> usize {
144        self.ref_count.load(Ordering::SeqCst)
145    }
146}
147
148/// Subscription manager
149pub struct SubscriptionManager {
150    /// Mapping of `stream_name -> Subscription`
151    subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
152    /// Index by symbol: `symbol -> Vec<stream_name>`
153    symbol_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
154    /// Counter of active subscriptions
155    active_count: Arc<std::sync::atomic::AtomicUsize>,
156}
157
158impl SubscriptionManager {
159    /// Creates a new `SubscriptionManager`
160    pub fn new() -> Self {
161        Self {
162            subscriptions: Arc::new(RwLock::new(HashMap::new())),
163            symbol_index: Arc::new(RwLock::new(HashMap::new())),
164            active_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
165        }
166    }
167
168    /// Adds a subscription to the manager.
169    ///
170    /// If a subscription for the same stream already exists, increments the reference count
171    /// instead of creating a duplicate subscription.
172    ///
173    /// Returns `true` if a new subscription was created, `false` if an existing one was reused.
174    pub async fn add_subscription(
175        &self,
176        stream: String,
177        symbol: String,
178        sub_type: SubscriptionType,
179        sender: tokio::sync::mpsc::Sender<Value>,
180    ) -> ccxt_core::error::Result<bool> {
181        let mut subs = self.subscriptions.write().await;
182
183        // Check if subscription already exists - add sender and increment ref count
184        if let Some(existing) = subs.get(&stream) {
185            existing.add_sender(sender);
186            existing.add_ref();
187            return Ok(false);
188        }
189
190        let subscription = Subscription::new(stream.clone(), symbol.clone(), sub_type, sender);
191
192        subs.insert(stream.clone(), subscription);
193
194        let mut index = self.symbol_index.write().await;
195        index.entry(symbol).or_insert_with(Vec::new).push(stream);
196
197        self.active_count
198            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
199
200        Ok(true)
201    }
202
203    /// Removes a subscription by stream name.
204    ///
205    /// Only actually removes the subscription when the reference count reaches zero.
206    /// Returns `true` if the subscription was fully removed (ref count hit zero).
207    pub async fn remove_subscription(&self, stream: &str) -> ccxt_core::error::Result<bool> {
208        let mut subs = self.subscriptions.write().await;
209
210        if let Some(subscription) = subs.get(stream) {
211            let remaining = subscription.remove_ref();
212            if remaining > 0 {
213                // Still has active references, don't remove
214                return Ok(false);
215            }
216
217            // Ref count is zero, remove the subscription
218            let Some(subscription) = subs.remove(stream) else {
219                return Ok(false);
220            };
221            let mut index = self.symbol_index.write().await;
222            if let Some(streams) = index.get_mut(&subscription.symbol) {
223                streams.retain(|s| s != stream);
224                if streams.is_empty() {
225                    index.remove(&subscription.symbol);
226                }
227            }
228
229            self.active_count
230                .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
231            Ok(true)
232        } else {
233            Ok(false)
234        }
235    }
236
237    /// Retrieves a subscription by stream name
238    pub async fn get_subscription(&self, stream: &str) -> Option<Subscription> {
239        let subs = self.subscriptions.read().await;
240        subs.get(stream).cloned()
241    }
242
243    /// Checks whether a subscription exists for the given stream
244    pub async fn has_subscription(&self, stream: &str) -> bool {
245        let subs = self.subscriptions.read().await;
246        subs.contains_key(stream)
247    }
248
249    /// Returns all registered subscriptions
250    pub async fn get_all_subscriptions(&self) -> Vec<Subscription> {
251        let subs = self.subscriptions.read().await;
252        subs.values().cloned().collect()
253    }
254
255    /// Returns all registered subscriptions synchronously (non-blocking)
256    pub fn get_all_subscriptions_sync(&self) -> Vec<Subscription> {
257        if let Ok(subs) = self.subscriptions.try_read() {
258            subs.values().cloned().collect()
259        } else {
260            Vec::new()
261        }
262    }
263
264    /// Returns all subscriptions associated with a symbol
265    pub async fn get_subscriptions_by_symbol(&self, symbol: &str) -> Vec<Subscription> {
266        let index = self.symbol_index.read().await;
267        let subs = self.subscriptions.read().await;
268
269        if let Some(streams) = index.get(symbol) {
270            streams
271                .iter()
272                .filter_map(|stream| subs.get(stream).cloned())
273                .collect()
274        } else {
275            Vec::new()
276        }
277    }
278
279    /// Returns the number of active subscriptions
280    pub fn active_count(&self) -> usize {
281        self.active_count.load(std::sync::atomic::Ordering::SeqCst)
282    }
283
284    /// Removes all subscriptions and clears indexes
285    pub async fn clear(&self) {
286        let mut subs = self.subscriptions.write().await;
287        let mut index = self.symbol_index.write().await;
288
289        subs.clear();
290        index.clear();
291        self.active_count
292            .store(0, std::sync::atomic::Ordering::SeqCst);
293    }
294
295    /// Sends a message to subscribers of a specific stream
296    pub async fn send_to_stream(&self, stream: &str, message: Value) -> bool {
297        let subs = self.subscriptions.read().await;
298        if let Some(subscription) = subs.get(stream) {
299            if subscription.send(message) {
300                return true;
301            }
302        } else {
303            return false;
304        }
305        drop(subs);
306
307        let _ = self.remove_subscription(stream).await;
308        false
309    }
310
311    /// Sends a message to all subscribers of a symbol
312    pub async fn send_to_symbol(&self, symbol: &str, message: &Value) -> usize {
313        let index = self.symbol_index.read().await;
314        let subs = self.subscriptions.read().await;
315
316        let mut sent_count = 0;
317        let mut streams_to_remove = Vec::new();
318
319        if let Some(streams) = index.get(symbol) {
320            for stream in streams {
321                if let Some(subscription) = subs.get(stream) {
322                    if subscription.send(message.clone()) {
323                        sent_count += 1;
324                    } else {
325                        streams_to_remove.push(stream.clone());
326                    }
327                }
328            }
329        }
330        drop(subs);
331        drop(index);
332
333        for stream in streams_to_remove {
334            let _ = self.remove_subscription(&stream).await;
335        }
336
337        sent_count
338    }
339
340    /// Returns a list of all active stream names for resubscription
341    pub async fn get_active_streams(&self) -> Vec<String> {
342        let subs = self.subscriptions.read().await;
343        subs.keys().cloned().collect()
344    }
345}
346
347impl Default for SubscriptionManager {
348    fn default() -> Self {
349        Self::new()
350    }
351}
352
/// Settings for automatic WebSocket reconnection with exponential backoff.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Enables or disables automatic reconnection
    pub enabled: bool,

    /// Delay before the first reconnection attempt (attempt `0`), in milliseconds
    pub initial_delay_ms: u64,

    /// Upper bound applied to every computed delay, in milliseconds
    pub max_delay_ms: u64,

    /// Factor the delay is multiplied by for each further attempt
    pub backoff_multiplier: f64,

    /// Maximum number of reconnection attempts (0 means unlimited)
    pub max_attempts: usize,
}

impl Default for ReconnectConfig {
    /// Enabled; 1 s initial delay, doubling per attempt, capped at 30 s; unlimited attempts.
    fn default() -> Self {
        Self {
            enabled: true,
            initial_delay_ms: 1000,
            max_delay_ms: 30000,
            backoff_multiplier: 2.0,
            max_attempts: 0,
        }
    }
}

impl ReconnectConfig {
    /// Returns the delay in milliseconds to wait before reconnection attempt `attempt`.
    ///
    /// Computes `initial_delay_ms * backoff_multiplier^attempt`, capped at `max_delay_ms`.
    /// `attempt` is zero-based, so attempt `0` waits `initial_delay_ms` (subject to the cap).
    pub fn calculate_delay(&self, attempt: usize) -> u64 {
        // `attempt as i32` wraps large attempt numbers to negative exponents, collapsing the
        // delay toward zero. Saturating instead keeps the backoff growing toward the cap
        // (for multipliers above 1).
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let delay = (self.initial_delay_ms as f64) * self.backoff_multiplier.powi(exponent);
        delay.min(self.max_delay_ms as f64) as u64
    }

    /// Returns whether another reconnection attempt should be made.
    ///
    /// True when reconnection is enabled and either attempts are unlimited
    /// (`max_attempts == 0`) or `attempt < max_attempts`.
    pub fn should_retry(&self, attempt: usize) -> bool {
        self.enabled && (self.max_attempts == 0 || attempt < self.max_attempts)
    }
}
396
397/// A handle to an active subscription that automatically unsubscribes when dropped.
398///
399/// `SubscriptionHandle` implements a RAII pattern for WebSocket subscriptions.
400/// When the handle is dropped, it decrements the reference count for the stream
401/// and triggers an UNSUBSCRIBE command if no more handles reference the stream.
402///
403/// # Example
404///
405/// ```rust,ignore
406/// let handle = subscription_manager.subscribe("btcusdt@ticker", ...).await?;
407/// // ... use the subscription ...
408/// drop(handle); // Automatically unsubscribes when dropped
409/// ```
410pub struct SubscriptionHandle {
411    /// The stream name this handle is associated with
412    stream: String,
413    /// Reference to the subscription manager for cleanup
414    subscription_manager: Arc<SubscriptionManager>,
415    /// Reference to the message router for sending UNSUBSCRIBE
416    message_router: Option<Arc<crate::binance::ws::handlers::MessageRouter>>,
417    /// Whether this handle has already been released
418    released: bool,
419}
420
421impl SubscriptionHandle {
422    /// Creates a new subscription handle.
423    pub fn new(
424        stream: String,
425        subscription_manager: Arc<SubscriptionManager>,
426        message_router: Option<Arc<crate::binance::ws::handlers::MessageRouter>>,
427    ) -> Self {
428        Self {
429            stream,
430            subscription_manager,
431            message_router,
432            released: false,
433        }
434    }
435
436    /// Returns the stream name associated with this handle.
437    pub fn stream(&self) -> &str {
438        &self.stream
439    }
440
441    /// Manually releases the subscription handle.
442    ///
443    /// This is equivalent to dropping the handle, but allows for async cleanup.
444    /// After calling this method, the Drop implementation will be a no-op.
445    pub async fn release(mut self) -> ccxt_core::error::Result<()> {
446        self.released = true;
447        self.do_release().await
448    }
449
450    /// Internal release logic
451    async fn do_release(&self) -> ccxt_core::error::Result<()> {
452        let fully_removed = self
453            .subscription_manager
454            .remove_subscription(&self.stream)
455            .await?;
456
457        if fully_removed {
458            if let Some(router) = &self.message_router {
459                router.unsubscribe(vec![self.stream.clone()]).await?;
460            }
461        }
462
463        Ok(())
464    }
465}
466
467impl Drop for SubscriptionHandle {
468    fn drop(&mut self) {
469        if self.released {
470            return;
471        }
472
473        // We can't do async work in Drop, so we spawn a task to handle cleanup
474        let stream = self.stream.clone();
475        let subscription_manager = self.subscription_manager.clone();
476        let message_router = self.message_router.clone();
477
478        tokio::spawn(async move {
479            let fully_removed = subscription_manager
480                .remove_subscription(&stream)
481                .await
482                .unwrap_or(false);
483
484            if fully_removed {
485                if let Some(router) = &message_router {
486                    let _ = router.unsubscribe(vec![stream]).await;
487                }
488            }
489        });
490    }
491}