Skip to main content

ccxt_exchanges/binance/ws/
subscriptions.rs

//! Subscription management for Binance WebSocket streams
2
3use serde_json::Value;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Instant;
7use tokio::sync::RwLock;
8
/// Kinds of Binance WebSocket stream a client can subscribe to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubscriptionType {
    /// Rolling 24-hour ticker statistics
    Ticker,
    /// Order book depth updates
    OrderBook,
    /// Real-time public trades
    Trades,
    /// Candlesticks; the payload is the interval string such as "1m", "5m" or "1h"
    Kline(String),
    /// Account balance updates
    Balance,
    /// Order updates
    Orders,
    /// Position updates
    Positions,
    /// Executions of the account's own trades
    MyTrades,
    /// Mark price updates
    MarkPrice,
    /// Best bid/ask (book ticker) updates
    BookTicker,
}

impl SubscriptionType {
    /// Maps a raw stream name onto its [`SubscriptionType`].
    ///
    /// Detection is substring-based and tried in a fixed priority order:
    /// `@ticker`, `@depth`, `@trade`/`@aggTrade`, `@kline_`, `@markPrice`,
    /// `@bookTicker`; the first marker found wins.
    ///
    /// For kline streams, the text after `@kline_` becomes the interval. A
    /// name containing `@kline_` more than once yields `None`. Also returns
    /// `None` when no marker matches.
    pub fn from_stream(stream: &str) -> Option<Self> {
        let has = |marker: &str| stream.contains(marker);

        if has("@ticker") {
            return Some(Self::Ticker);
        }
        if has("@depth") {
            return Some(Self::OrderBook);
        }
        if has("@trade") || has("@aggTrade") {
            return Some(Self::Trades);
        }
        if has("@kline_") {
            let mut pieces = stream.split("@kline_");
            // Skip the symbol prefix; exactly one interval piece must follow.
            pieces.next();
            return match (pieces.next(), pieces.next()) {
                (Some(interval), None) => Some(Self::Kline(interval.to_owned())),
                _ => None,
            };
        }
        if has("@markPrice") {
            return Some(Self::MarkPrice);
        }
        if has("@bookTicker") {
            return Some(Self::BookTicker);
        }
        None
    }
}
59
60/// Subscription metadata
61#[derive(Clone)]
62pub struct Subscription {
63    /// Stream name (e.g. "btcusdt@ticker")
64    pub stream: String,
65    /// Normalized trading symbol (e.g. "BTCUSDT")
66    pub symbol: String,
67    /// Subscription type descriptor
68    pub sub_type: SubscriptionType,
69    /// Timestamp when the subscription was created
70    pub subscribed_at: Instant,
71    /// Sender for forwarding WebSocket messages to consumers
72    pub sender: tokio::sync::mpsc::Sender<Value>,
73}
74
75impl Subscription {
76    /// Creates a new subscription with the provided parameters
77    pub fn new(
78        stream: String,
79        symbol: String,
80        sub_type: SubscriptionType,
81        sender: tokio::sync::mpsc::Sender<Value>,
82    ) -> Self {
83        Self {
84            stream,
85            symbol,
86            sub_type,
87            subscribed_at: Instant::now(),
88            sender,
89        }
90    }
91
92    /// Sends a message to the subscriber
93    pub fn send(&self, message: Value) -> bool {
94        // Use try_send to avoid blocking if channel is full (drop strategy)
95        self.sender.try_send(message).is_ok()
96    }
97}
98
99/// Subscription manager
100pub struct SubscriptionManager {
101    /// Mapping of `stream_name -> Subscription`
102    subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
103    /// Index by symbol: `symbol -> Vec<stream_name>`
104    symbol_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
105    /// Counter of active subscriptions
106    active_count: Arc<std::sync::atomic::AtomicUsize>,
107}
108
109impl SubscriptionManager {
110    /// Creates a new `SubscriptionManager`
111    pub fn new() -> Self {
112        Self {
113            subscriptions: Arc::new(RwLock::new(HashMap::new())),
114            symbol_index: Arc::new(RwLock::new(HashMap::new())),
115            active_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
116        }
117    }
118
119    /// Adds a subscription to the manager
120    pub async fn add_subscription(
121        &self,
122        stream: String,
123        symbol: String,
124        sub_type: SubscriptionType,
125        sender: tokio::sync::mpsc::Sender<Value>,
126    ) -> ccxt_core::error::Result<()> {
127        let subscription = Subscription::new(stream.clone(), symbol.clone(), sub_type, sender);
128
129        let mut subs = self.subscriptions.write().await;
130        subs.insert(stream.clone(), subscription);
131
132        let mut index = self.symbol_index.write().await;
133        index.entry(symbol).or_insert_with(Vec::new).push(stream);
134
135        self.active_count
136            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
137
138        Ok(())
139    }
140
141    /// Removes a subscription by stream name
142    pub async fn remove_subscription(&self, stream: &str) -> ccxt_core::error::Result<()> {
143        let mut subs = self.subscriptions.write().await;
144
145        if let Some(subscription) = subs.remove(stream) {
146            let mut index = self.symbol_index.write().await;
147            if let Some(streams) = index.get_mut(&subscription.symbol) {
148                streams.retain(|s| s != stream);
149                if streams.is_empty() {
150                    index.remove(&subscription.symbol);
151                }
152            }
153
154            self.active_count
155                .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
156        }
157
158        Ok(())
159    }
160
161    /// Retrieves a subscription by stream name
162    pub async fn get_subscription(&self, stream: &str) -> Option<Subscription> {
163        let subs = self.subscriptions.read().await;
164        subs.get(stream).cloned()
165    }
166
167    /// Checks whether a subscription exists for the given stream
168    pub async fn has_subscription(&self, stream: &str) -> bool {
169        let subs = self.subscriptions.read().await;
170        subs.contains_key(stream)
171    }
172
173    /// Returns all registered subscriptions
174    pub async fn get_all_subscriptions(&self) -> Vec<Subscription> {
175        let subs = self.subscriptions.read().await;
176        subs.values().cloned().collect()
177    }
178
179    /// Returns all registered subscriptions synchronously (non-blocking)
180    pub fn get_all_subscriptions_sync(&self) -> Vec<Subscription> {
181        if let Ok(subs) = self.subscriptions.try_read() {
182            subs.values().cloned().collect()
183        } else {
184            Vec::new()
185        }
186    }
187
188    /// Returns all subscriptions associated with a symbol
189    pub async fn get_subscriptions_by_symbol(&self, symbol: &str) -> Vec<Subscription> {
190        let index = self.symbol_index.read().await;
191        let subs = self.subscriptions.read().await;
192
193        if let Some(streams) = index.get(symbol) {
194            streams
195                .iter()
196                .filter_map(|stream| subs.get(stream).cloned())
197                .collect()
198        } else {
199            Vec::new()
200        }
201    }
202
203    /// Returns the number of active subscriptions
204    pub fn active_count(&self) -> usize {
205        self.active_count.load(std::sync::atomic::Ordering::SeqCst)
206    }
207
208    /// Removes all subscriptions and clears indexes
209    pub async fn clear(&self) {
210        let mut subs = self.subscriptions.write().await;
211        let mut index = self.symbol_index.write().await;
212
213        subs.clear();
214        index.clear();
215        self.active_count
216            .store(0, std::sync::atomic::Ordering::SeqCst);
217    }
218
219    /// Sends a message to subscribers of a specific stream
220    pub async fn send_to_stream(&self, stream: &str, message: Value) -> bool {
221        let subs = self.subscriptions.read().await;
222        if let Some(subscription) = subs.get(stream) {
223            if subscription.send(message) {
224                return true;
225            }
226        } else {
227            return false;
228        }
229        drop(subs);
230
231        let _ = self.remove_subscription(stream).await;
232        false
233    }
234
235    /// Sends a message to all subscribers of a symbol
236    pub async fn send_to_symbol(&self, symbol: &str, message: &Value) -> usize {
237        let index = self.symbol_index.read().await;
238        let subs = self.subscriptions.read().await;
239
240        let mut sent_count = 0;
241        let mut streams_to_remove = Vec::new();
242
243        if let Some(streams) = index.get(symbol) {
244            for stream in streams {
245                if let Some(subscription) = subs.get(stream) {
246                    if subscription.send(message.clone()) {
247                        sent_count += 1;
248                    } else {
249                        streams_to_remove.push(stream.clone());
250                    }
251                }
252            }
253        }
254        drop(subs);
255        drop(index);
256
257        for stream in streams_to_remove {
258            let _ = self.remove_subscription(&stream).await;
259        }
260
261        sent_count
262    }
263
264    /// Returns a list of all active stream names for resubscription
265    pub async fn get_active_streams(&self) -> Vec<String> {
266        let subs = self.subscriptions.read().await;
267        subs.keys().cloned().collect()
268    }
269}
270
271impl Default for SubscriptionManager {
272    fn default() -> Self {
273        Self::new()
274    }
275}
276
/// Reconnect configuration
///
/// Exponential-backoff parameters used when a WebSocket connection drops.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Enables or disables automatic reconnection
    pub enabled: bool,

    /// Initial reconnection delay in milliseconds
    pub initial_delay_ms: u64,

    /// Maximum reconnection delay in milliseconds
    pub max_delay_ms: u64,

    /// Exponential backoff multiplier
    pub backoff_multiplier: f64,

    /// Maximum number of reconnection attempts (0 means unlimited)
    pub max_attempts: usize,
}

impl Default for ReconnectConfig {
    /// Enabled, 1 s initial delay doubling up to 30 s, unlimited attempts.
    fn default() -> Self {
        Self {
            enabled: true,
            initial_delay_ms: 1000,
            max_delay_ms: 30000,
            backoff_multiplier: 2.0,
            max_attempts: 0,
        }
    }
}

impl ReconnectConfig {
    /// Calculates the delay in milliseconds before reconnection attempt `attempt` (0-based).
    ///
    /// Computed as `initial_delay_ms * backoff_multiplier^attempt`, capped at
    /// `max_delay_ms`. The final float-to-integer `as` cast saturates, so a
    /// negative or NaN intermediate becomes 0.
    pub fn calculate_delay(&self, attempt: usize) -> u64 {
        // `attempt as i32` would wrap for huge counts (e.g. usize::MAX -> -1),
        // yielding a negative exponent and a tiny delay; saturate instead so
        // the `max_delay_ms` cap applies.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let delay = (self.initial_delay_ms as f64) * self.backoff_multiplier.powi(exponent);
        delay.min(self.max_delay_ms as f64) as u64
    }

    /// Determines whether another reconnection attempt should be made.
    ///
    /// Returns `false` when reconnection is disabled. Otherwise returns
    /// `true` if attempts are unlimited (`max_attempts == 0`) or
    /// `attempt < max_attempts`.
    pub fn should_retry(&self, attempt: usize) -> bool {
        self.enabled && (self.max_attempts == 0 || attempt < self.max_attempts)
    }
}