Skip to main content

ccxt_exchanges/binance/
ws_exchange_impl.rs

//! `WsExchange` trait implementation for Binance.
//!
//! This module implements the unified `WsExchange` trait from `ccxt-core` for Binance,
//! providing real-time WebSocket streams for public market data (tickers, order books,
//! trades, OHLCV) and private account data (balances, orders, own trades).
5
6use async_trait::async_trait;
7use ccxt_core::{
8    error::{Error, Result},
9    types::{
10        Balance, Ohlcv, Order, OrderBook, Ticker, Timeframe, Trade, financial::Amount,
11        financial::Price,
12    },
13    ws_client::WsConnectionState,
14    ws_exchange::{MessageStream, WsExchange},
15};
16
17use rust_decimal::Decimal;
18use std::collections::HashMap;
19use std::sync::Arc;
20use tokio::sync::mpsc;
21
22use super::Binance;
23
24/// A simple stream wrapper that converts an mpsc receiver into a Stream
25struct ReceiverStream<T> {
26    receiver: mpsc::Receiver<T>,
27}
28
29impl<T> ReceiverStream<T> {
30    fn new(receiver: mpsc::Receiver<T>) -> Self {
31        Self { receiver }
32    }
33}
34
35impl<T> futures::Stream for ReceiverStream<T> {
36    type Item = T;
37
38    fn poll_next(
39        mut self: std::pin::Pin<&mut Self>,
40        cx: &mut std::task::Context<'_>,
41    ) -> std::task::Poll<Option<Self::Item>> {
42        self.receiver.poll_recv(cx)
43    }
44}
45
46#[async_trait]
47impl WsExchange for Binance {
48    // ==================== Connection Management ====================
49
50    async fn ws_connect(&self) -> Result<()> {
51        // Ensure at least one public connection is active
52        let _ = self.connection_manager.get_public_connection().await?;
53        Ok(())
54    }
55
56    async fn ws_disconnect(&self) -> Result<()> {
57        self.connection_manager.disconnect_all().await
58    }
59
60    fn ws_is_connected(&self) -> bool {
61        self.connection_manager.is_connected()
62    }
63
64    fn ws_state(&self) -> WsConnectionState {
65        if self.connection_manager.is_connected() {
66            WsConnectionState::Connected
67        } else {
68            WsConnectionState::Disconnected
69        }
70    }
71
72    // ==================== Public Data Streams ====================
73
74    async fn watch_ticker(&self, symbol: &str) -> Result<MessageStream<Ticker>> {
75        // Load markets to validate symbol
76        self.load_markets(false).await?;
77
78        // Get market info
79        let market = self.base.market(symbol).await?;
80        let binance_symbol = market.id.to_lowercase();
81        let stream = format!("{}@ticker", binance_symbol);
82
83        // Get shared public connection
84        let ws = self.connection_manager.get_public_connection().await?;
85
86        // Create subscription channel
87        let (tx, mut rx) = mpsc::channel(1024);
88
89        // Register with manager
90        ws.subscription_manager
91            .add_subscription(
92                stream.clone(),
93                symbol.to_string(),
94                super::ws::SubscriptionType::Ticker,
95                tx,
96            )
97            .await?;
98
99        // Send subscribe command
100        ws.message_router.subscribe(vec![stream]).await?;
101
102        // Create user channel
103        let (user_tx, user_rx) = mpsc::channel::<Result<Ticker>>(1024);
104
105        // Spawn parser task
106        let market_clone = market.clone();
107        tokio::spawn(async move {
108            while let Some(msg) = rx.recv().await {
109                match super::parser::parse_ws_ticker(&msg, Some(&market_clone)) {
110                    Ok(ticker) => {
111                        if user_tx.send(Ok(ticker)).await.is_err() {
112                            break;
113                        }
114                    }
115                    Err(e) => {
116                        let _ = user_tx.send(Err(e)).await;
117                    }
118                }
119            }
120        });
121
122        // Convert receiver to stream
123        let stream = ReceiverStream::new(user_rx);
124        Ok(Box::pin(stream))
125    }
126
127    async fn watch_tickers(&self, symbols: &[String]) -> Result<MessageStream<Vec<Ticker>>> {
128        // Load markets
129        self.load_markets(false).await?;
130
131        // Create aggregator channel
132        let (agg_tx, mut agg_rx) = mpsc::channel::<Ticker>(1024);
133
134        // Subscribe to all requested symbols (potentially across shards)
135        let mut markets = HashMap::new();
136        for symbol in symbols {
137            let market = self.base.market(symbol).await?;
138            let binance_symbol = market.id.to_lowercase();
139            let stream = format!("{}@ticker", binance_symbol);
140
141            markets.insert(binance_symbol.clone(), market);
142
143            // Get connection (might be different for each symbol if sharding active)
144            let ws = self.connection_manager.get_public_connection().await?;
145            let (tx, mut rx) = mpsc::channel(1024);
146
147            ws.subscription_manager
148                .add_subscription(
149                    stream.clone(),
150                    symbol.clone(),
151                    super::ws::SubscriptionType::Ticker,
152                    tx,
153                )
154                .await?;
155
156            ws.message_router.subscribe(vec![stream]).await?;
157
158            // Spawn parser task for this symbol
159            let agg_tx_clone = agg_tx.clone();
160            let market_clone = self.base.market(symbol).await?;
161
162            tokio::spawn(async move {
163                while let Some(msg) = rx.recv().await {
164                    if let Ok(ticker) = super::parser::parse_ws_ticker(&msg, Some(&market_clone)) {
165                        let _ = agg_tx_clone.send(ticker).await;
166                    }
167                }
168            });
169        }
170
171        drop(agg_tx);
172
173        // Create user channel
174        let (user_tx, user_rx) = mpsc::channel::<Result<Vec<Ticker>>>(1024);
175
176        // Spawn aggregator task
177        tokio::spawn(async move {
178            let mut tickers: HashMap<String, Ticker> = HashMap::new();
179
180            while let Some(ticker) = agg_rx.recv().await {
181                tickers.insert(ticker.symbol.clone(), ticker);
182
183                let ticker_vec: Vec<Ticker> = tickers.values().cloned().collect();
184                if user_tx.send(Ok(ticker_vec)).await.is_err() {
185                    break;
186                }
187            }
188        });
189
190        let stream = ReceiverStream::new(user_rx);
191        Ok(Box::pin(stream))
192    }
193
194    async fn watch_order_book(
195        &self,
196        symbol: &str,
197        limit: Option<u32>,
198    ) -> Result<MessageStream<OrderBook>> {
199        // Load markets
200        self.load_markets(false).await?;
201
202        // Get market info
203        let market = self.base.market(symbol).await?;
204        let binance_symbol = market.id.to_lowercase();
205
206        // Subscribe to orderbook stream
207        let levels = limit.unwrap_or(20);
208        let stream = format!("{}@depth{}@100ms", binance_symbol, levels);
209
210        // Get shared public connection
211        let ws = self.connection_manager.get_public_connection().await?;
212
213        // Create subscription channel
214        let (tx, mut rx) = mpsc::channel(1024);
215
216        // Register with manager
217        ws.subscription_manager
218            .add_subscription(
219                stream.clone(),
220                symbol.to_string(),
221                super::ws::SubscriptionType::OrderBook,
222                tx,
223            )
224            .await?;
225
226        // Send subscribe command
227        ws.message_router.subscribe(vec![stream]).await?;
228
229        // Create user channel
230        let (user_tx, user_rx) = mpsc::channel::<Result<OrderBook>>(1024);
231
232        // Spawn parser task
233        let symbol_clone = symbol.to_string();
234        tokio::spawn(async move {
235            while let Some(msg) = rx.recv().await {
236                match super::parser::parse_ws_orderbook(&msg, symbol_clone.clone()) {
237                    Ok(orderbook) => {
238                        if user_tx.send(Ok(orderbook)).await.is_err() {
239                            break;
240                        }
241                    }
242                    Err(e) => {
243                        let _ = user_tx.send(Err(e)).await;
244                    }
245                }
246            }
247        });
248
249        let stream = ReceiverStream::new(user_rx);
250        Ok(Box::pin(stream))
251    }
252
253    async fn watch_trades(&self, symbol: &str) -> Result<MessageStream<Vec<Trade>>> {
254        // Load markets
255        self.load_markets(false).await?;
256
257        // Get market info
258        let market = self.base.market(symbol).await?;
259        let binance_symbol = market.id.to_lowercase();
260        let stream = format!("{}@trade", binance_symbol);
261
262        // Get shared public connection
263        let ws = self.connection_manager.get_public_connection().await?;
264
265        // Create subscription channel
266        let (tx, mut rx) = mpsc::channel(1024);
267
268        // Register with manager
269        ws.subscription_manager
270            .add_subscription(
271                stream.clone(),
272                symbol.to_string(),
273                super::ws::SubscriptionType::Trades,
274                tx,
275            )
276            .await?;
277
278        // Send subscribe command
279        ws.message_router.subscribe(vec![stream]).await?;
280
281        // Create user channel
282        let (user_tx, user_rx) = mpsc::channel::<Result<Vec<Trade>>>(1024);
283
284        // Spawn parser task
285        let market_clone = market.clone();
286        tokio::spawn(async move {
287            while let Some(msg) = rx.recv().await {
288                match super::parser::parse_ws_trade(&msg, Some(&market_clone)) {
289                    Ok(trade) => {
290                        if user_tx.send(Ok(vec![trade])).await.is_err() {
291                            break;
292                        }
293                    }
294                    Err(e) => {
295                        let _ = user_tx.send(Err(e)).await;
296                    }
297                }
298            }
299        });
300
301        // Convert receiver to stream
302        let stream = ReceiverStream::new(user_rx);
303        Ok(Box::pin(stream))
304    }
305
306    async fn watch_ohlcv(
307        &self,
308        symbol: &str,
309        timeframe: Timeframe,
310    ) -> Result<MessageStream<Ohlcv>> {
311        // Load markets
312        self.load_markets(false).await?;
313
314        // Get market info
315        let market = self.base.market(symbol).await?;
316        let binance_symbol = market.id.to_lowercase();
317
318        // Convert timeframe to Binance format
319        let interval = timeframe.to_string();
320        let stream = format!("{}@kline_{}", binance_symbol, interval);
321
322        // Get shared public connection
323        let ws = self.connection_manager.get_public_connection().await?;
324
325        // Create subscription channel
326        let (tx, mut rx) = mpsc::channel(1024);
327
328        // Register with manager
329        ws.subscription_manager
330            .add_subscription(
331                stream.clone(),
332                symbol.to_string(),
333                super::ws::SubscriptionType::Kline(interval),
334                tx,
335            )
336            .await?;
337
338        // Send subscribe command
339        ws.message_router.subscribe(vec![stream]).await?;
340
341        // Create user channel
342        let (user_tx, user_rx) = mpsc::channel::<Result<Ohlcv>>(1024);
343
344        // Spawn parser task
345        tokio::spawn(async move {
346            while let Some(msg) = rx.recv().await {
347                // Parse OHLCV from kline message
348                match super::parser::parse_ws_ohlcv(&msg) {
349                    Ok(ohlcv_f64) => {
350                        // Convert OHLCV (f64) to Ohlcv (Decimal)
351                        let ohlcv = Ohlcv {
352                            timestamp: ohlcv_f64.timestamp,
353                            open: Price::from(
354                                Decimal::try_from(ohlcv_f64.open).unwrap_or_default(),
355                            ),
356                            high: Price::from(
357                                Decimal::try_from(ohlcv_f64.high).unwrap_or_default(),
358                            ),
359                            low: Price::from(Decimal::try_from(ohlcv_f64.low).unwrap_or_default()),
360                            close: Price::from(
361                                Decimal::try_from(ohlcv_f64.close).unwrap_or_default(),
362                            ),
363                            volume: Amount::from(
364                                Decimal::try_from(ohlcv_f64.volume).unwrap_or_default(),
365                            ),
366                        };
367                        if user_tx.send(Ok(ohlcv)).await.is_err() {
368                            break;
369                        }
370                    }
371                    Err(e) => {
372                        let _ = user_tx.send(Err(e)).await;
373                    }
374                }
375            }
376        });
377
378        let stream = ReceiverStream::new(user_rx);
379        Ok(Box::pin(stream))
380    }
381
382    // ==================== Private Data Streams ====================
383
384    async fn watch_balance(&self) -> Result<MessageStream<Balance>> {
385        self.base
386            .check_required_credentials()
387            .map_err(|_| Error::authentication("API credentials required for watch_balance"))?;
388
389        let binance_arc = Arc::new(self.clone());
390        let ws = self
391            .connection_manager
392            .get_private_connection(&binance_arc)
393            .await?;
394
395        let (tx, mut rx) = mpsc::channel(1024);
396
397        ws.subscription_manager
398            .add_subscription(
399                "!userData".to_string(),
400                "user".to_string(),
401                super::ws::SubscriptionType::Balance,
402                tx,
403            )
404            .await?;
405
406        let (user_tx, user_rx) = mpsc::channel::<Result<Balance>>(1024);
407        let account_type = self.options.default_type.to_string();
408        let balances_cache = ws.balances.clone();
409
410        tokio::spawn(async move {
411            while let Some(msg) = rx.recv().await {
412                if let Some(event_type) = msg.get("e").and_then(|e| e.as_str()) {
413                    if matches!(
414                        event_type,
415                        "balanceUpdate" | "outboundAccountPosition" | "ACCOUNT_UPDATE"
416                    ) {
417                        if let Ok(()) = super::ws::user_data::handle_balance_message(
418                            &msg,
419                            &account_type,
420                            &balances_cache,
421                        )
422                        .await
423                        {
424                            let balances = balances_cache.read().await;
425                            if let Some(balance) = balances.get(&account_type) {
426                                if user_tx.send(Ok(balance.clone())).await.is_err() {
427                                    break;
428                                }
429                            }
430                        }
431                    }
432                }
433            }
434        });
435
436        let stream = ReceiverStream::new(user_rx);
437        Ok(Box::pin(stream))
438    }
439
440    async fn watch_orders(&self, symbol: Option<&str>) -> Result<MessageStream<Order>> {
441        self.base
442            .check_required_credentials()
443            .map_err(|_| Error::authentication("API credentials required for watch_orders"))?;
444
445        let binance_arc = Arc::new(self.clone());
446        let ws = self
447            .connection_manager
448            .get_private_connection(&binance_arc)
449            .await?;
450
451        let (tx, mut rx) = mpsc::channel(1024);
452
453        ws.subscription_manager
454            .add_subscription(
455                "!userData".to_string(),
456                "user".to_string(),
457                super::ws::SubscriptionType::Orders,
458                tx,
459            )
460            .await?;
461
462        let (user_tx, user_rx) = mpsc::channel::<Result<Order>>(1024);
463        let symbol_filter = symbol.map(ToString::to_string);
464        let orders_cache = ws.orders.clone();
465
466        tokio::spawn(async move {
467            while let Some(msg) = rx.recv().await {
468                if let Some(data) = msg.as_object() {
469                    if let Some(event_type) = data.get("e").and_then(|e| e.as_str()) {
470                        if event_type == "executionReport" {
471                            let order = super::ws::user_data::parse_ws_order(data);
472
473                            {
474                                let mut orders = orders_cache.write().await;
475                                let symbol_orders = orders
476                                    .entry(order.symbol.clone())
477                                    .or_insert_with(HashMap::new);
478                                symbol_orders.insert(order.id.clone(), order.clone());
479                            }
480
481                            // Filter and send
482                            if let Some(s) = &symbol_filter {
483                                if &order.symbol != s {
484                                    continue;
485                                }
486                            }
487
488                            if user_tx.send(Ok(order)).await.is_err() {
489                                break;
490                            }
491                        }
492                    }
493                }
494            }
495        });
496
497        let stream = ReceiverStream::new(user_rx);
498        Ok(Box::pin(stream))
499    }
500
501    async fn watch_my_trades(&self, symbol: Option<&str>) -> Result<MessageStream<Trade>> {
502        self.base
503            .check_required_credentials()
504            .map_err(|_| Error::authentication("API credentials required for watch_my_trades"))?;
505
506        let binance_arc = Arc::new(self.clone());
507        let ws = self
508            .connection_manager
509            .get_private_connection(&binance_arc)
510            .await?;
511
512        let (tx, mut rx) = mpsc::channel(1024);
513
514        ws.subscription_manager
515            .add_subscription(
516                "!userData".to_string(),
517                "user".to_string(),
518                super::ws::SubscriptionType::MyTrades,
519                tx,
520            )
521            .await?;
522
523        let (user_tx, user_rx) = mpsc::channel::<Result<Trade>>(1024);
524        let symbol_filter = symbol.map(ToString::to_string);
525        let trades_cache = ws.my_trades.clone();
526
527        tokio::spawn(async move {
528            while let Some(msg) = rx.recv().await {
529                if let Some(event_type) = msg.get("e").and_then(|e| e.as_str()) {
530                    if event_type == "executionReport" {
531                        if let Ok(trade) = super::ws::user_data::parse_ws_trade(&msg) {
532                            {
533                                let mut trades = trades_cache.write().await;
534
535                                let symbol_trades = trades
536                                    .entry(trade.symbol.clone())
537                                    .or_insert_with(std::collections::VecDeque::new);
538                                symbol_trades.push_front(trade.clone());
539                                if symbol_trades.len() > 1000 {
540                                    symbol_trades.pop_back();
541                                }
542                            }
543
544                            // Filter and send
545                            if let Some(s) = &symbol_filter {
546                                if &trade.symbol != s {
547                                    continue;
548                                }
549                            }
550
551                            if user_tx.send(Ok(trade)).await.is_err() {
552                                break;
553                            }
554                        }
555                    }
556                }
557            }
558        });
559
560        let stream = ReceiverStream::new(user_rx);
561        Ok(Box::pin(stream))
562    }
563
564    // ==================== Subscription Management ====================
565
566    async fn subscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
567        // Create WebSocket and connect
568        let ws = self.connection_manager.get_public_connection().await?;
569
570        // Use the appropriate subscription method based on channel
571        match channel {
572            "ticker" => {
573                if let Some(sym) = symbol {
574                    let market = self.base.market(sym).await?;
575                    ws.subscribe_ticker(&market.id.to_lowercase()).await
576                } else {
577                    ws.subscribe_all_tickers().await
578                }
579            }
580            "trade" | "trades" => {
581                if let Some(sym) = symbol {
582                    let market = self.base.market(sym).await?;
583                    ws.subscribe_trades(&market.id.to_lowercase()).await
584                } else {
585                    Err(Error::invalid_request(
586                        "Symbol required for trades subscription",
587                    ))
588                }
589            }
590            _ => {
591                // For other channels, try generic subscription
592                Err(Error::invalid_request(format!(
593                    "Unknown channel: {}. Use specific watch_* methods instead.",
594                    channel
595                )))
596            }
597        }
598    }
599
600    async fn unsubscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
601        // Build stream name
602        let stream_name = if let Some(sym) = symbol {
603            // Load markets to get proper symbol format
604            self.load_markets(false).await?;
605            let market = self.base.market(sym).await?;
606            let binance_symbol = market.id.to_lowercase();
607            format!("{}@{}", binance_symbol, channel)
608        } else {
609            channel.to_string()
610        };
611
612        // Create WS to unsubscribe
613        let ws = self.connection_manager.get_public_connection().await?;
614        ws.unsubscribe(stream_name).await
615    }
616
617    fn subscriptions(&self) -> Vec<String> {
618        self.connection_manager.get_all_subscriptions()
619    }
620}
621
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;
    use ccxt_core::ExchangeConfig;

    /// Builds a `Binance` client from the default exchange configuration.
    fn default_binance() -> Binance {
        Binance::new(ExchangeConfig::default()).unwrap()
    }

    #[test]
    fn test_ws_exchange_trait_object_safety() {
        let binance = default_binance();

        // `WsExchange` must stay usable behind a `dyn` reference.
        let _as_trait_object: &dyn WsExchange = &binance;

        // A freshly constructed client reports no live connection.
        assert!(!binance.ws_is_connected());
        assert_eq!(binance.ws_state(), WsConnectionState::Disconnected);
    }

    #[test]
    fn test_subscriptions_empty_by_default() {
        assert!(default_binance().subscriptions().is_empty());
    }
}