Skip to main content

ccxt_exchanges/binance/
ws_exchange_impl.rs

1//! WsExchange trait implementation for Binance
2//!
3//! This module implements the unified `WsExchange` trait from `ccxt-core` for Binance,
4//! providing real-time WebSocket data streaming capabilities.
5
6use async_trait::async_trait;
7use ccxt_core::{
8    error::{Error, Result},
9    types::{
10        Balance, MarketType, Ohlcv, Order, OrderBook, Ticker, Timeframe, Trade, financial::Amount,
11        financial::Price,
12    },
13    ws_client::WsConnectionState,
14    ws_exchange::{MessageStream, WsExchange},
15};
16
17use rust_decimal::Decimal;
18use std::collections::HashMap;
19use std::sync::Arc;
20use tokio::sync::mpsc;
21
22use super::Binance;
23
24/// A simple stream wrapper that converts an mpsc receiver into a Stream
25struct ReceiverStream<T> {
26    receiver: mpsc::Receiver<T>,
27}
28
29impl<T> ReceiverStream<T> {
30    fn new(receiver: mpsc::Receiver<T>) -> Self {
31        Self { receiver }
32    }
33}
34
35impl<T> futures::Stream for ReceiverStream<T> {
36    type Item = T;
37
38    fn poll_next(
39        mut self: std::pin::Pin<&mut Self>,
40        cx: &mut std::task::Context<'_>,
41    ) -> std::task::Poll<Option<Self::Item>> {
42        self.receiver.poll_recv(cx)
43    }
44}
45
46#[async_trait]
47impl WsExchange for Binance {
48    // ==================== Connection Management ====================
49
50    async fn ws_connect(&self) -> Result<()> {
51        // Ensure at least one public connection is active for the default market type
52        let default_market_type = MarketType::from(self.options.default_type);
53        let _ = self
54            .connection_manager
55            .get_public_connection(default_market_type)
56            .await?;
57        Ok(())
58    }
59
60    async fn ws_disconnect(&self) -> Result<()> {
61        self.connection_manager.disconnect_all().await
62    }
63
64    fn ws_is_connected(&self) -> bool {
65        self.connection_manager.is_connected()
66    }
67
68    fn ws_state(&self) -> WsConnectionState {
69        if self.connection_manager.is_connected() {
70            WsConnectionState::Connected
71        } else {
72            WsConnectionState::Disconnected
73        }
74    }
75
76    // ==================== Public Data Streams ====================
77
78    async fn watch_ticker(&self, symbol: &str) -> Result<MessageStream<Ticker>> {
79        // Load markets to validate symbol
80        self.load_markets(false).await?;
81
82        // Get market info
83        let market = self.base.market(symbol).await?;
84        let binance_symbol = market.id.to_lowercase();
85        let stream = format!("{}@ticker", binance_symbol);
86
87        // Get shared public connection for this market type
88        let ws = self
89            .connection_manager
90            .get_public_connection(market.market_type)
91            .await?;
92
93        // Create subscription channel
94        let (tx, mut rx) = mpsc::channel(1024);
95
96        // Register with manager
97        let is_new = ws
98            .subscription_manager
99            .add_subscription(
100                stream.clone(),
101                symbol.to_string(),
102                super::ws::SubscriptionType::Ticker,
103                tx,
104            )
105            .await?;
106
107        // Only send subscribe command if this is a new subscription
108        if is_new {
109            ws.message_router.subscribe(vec![stream]).await?;
110        }
111
112        // Create user channel
113        let (user_tx, user_rx) = mpsc::channel::<Result<Ticker>>(1024);
114
115        // Spawn parser task
116        let market_clone = market.clone();
117        tokio::spawn(async move {
118            while let Some(msg) = rx.recv().await {
119                match super::parser::parse_ws_ticker(&msg, Some(&market_clone)) {
120                    Ok(ticker) => {
121                        if user_tx.send(Ok(ticker)).await.is_err() {
122                            break;
123                        }
124                    }
125                    Err(e) => {
126                        let _ = user_tx.send(Err(e)).await;
127                    }
128                }
129            }
130        });
131
132        // Convert receiver to stream
133        let stream = ReceiverStream::new(user_rx);
134        Ok(Box::pin(stream))
135    }
136
137    async fn watch_tickers(&self, symbols: &[String]) -> Result<MessageStream<Vec<Ticker>>> {
138        // Load markets
139        self.load_markets(false).await?;
140
141        // Create aggregator channel
142        let (agg_tx, mut agg_rx) = mpsc::channel::<Ticker>(1024);
143
144        // Subscribe to all requested symbols (potentially across shards)
145        let mut markets = HashMap::new();
146        for symbol in symbols {
147            let market = self.base.market(symbol).await?;
148            let binance_symbol = market.id.to_lowercase();
149            let stream = format!("{}@ticker", binance_symbol);
150
151            markets.insert(binance_symbol.clone(), market);
152
153            // Get connection (might be different for each symbol if sharding active)
154            let Some(market_ref) = markets.get(&binance_symbol) else {
155                continue;
156            };
157            let ws = self
158                .connection_manager
159                .get_public_connection(market_ref.market_type)
160                .await?;
161            let (tx, mut rx) = mpsc::channel(1024);
162
163            let is_new = ws
164                .subscription_manager
165                .add_subscription(
166                    stream.clone(),
167                    symbol.clone(),
168                    super::ws::SubscriptionType::Ticker,
169                    tx,
170                )
171                .await?;
172
173            if is_new {
174                ws.message_router.subscribe(vec![stream]).await?;
175            }
176
177            // Spawn parser task for this symbol
178            let agg_tx_clone = agg_tx.clone();
179            let market_clone = self.base.market(symbol).await?;
180
181            tokio::spawn(async move {
182                while let Some(msg) = rx.recv().await {
183                    if let Ok(ticker) = super::parser::parse_ws_ticker(&msg, Some(&market_clone)) {
184                        let _ = agg_tx_clone.send(ticker).await;
185                    }
186                }
187            });
188        }
189
190        drop(agg_tx);
191
192        // Create user channel
193        let (user_tx, user_rx) = mpsc::channel::<Result<Vec<Ticker>>>(1024);
194
195        // Spawn aggregator task
196        tokio::spawn(async move {
197            let mut tickers: HashMap<String, Ticker> = HashMap::new();
198
199            while let Some(ticker) = agg_rx.recv().await {
200                tickers.insert(ticker.symbol.clone(), ticker);
201
202                let ticker_vec: Vec<Ticker> = tickers.values().cloned().collect();
203                if user_tx.send(Ok(ticker_vec)).await.is_err() {
204                    break;
205                }
206            }
207        });
208
209        let stream = ReceiverStream::new(user_rx);
210        Ok(Box::pin(stream))
211    }
212
213    async fn watch_order_book(
214        &self,
215        symbol: &str,
216        limit: Option<u32>,
217    ) -> Result<MessageStream<OrderBook>> {
218        // Validate depth limit if provided
219        const VALID_WS_DEPTH_LIMITS: &[u32] = &[5, 10, 20];
220        if let Some(l) = limit {
221            if !VALID_WS_DEPTH_LIMITS.contains(&l) {
222                return Err(Error::invalid_request(format!(
223                    "Invalid WebSocket depth limit: {}. Valid values: {:?}",
224                    l, VALID_WS_DEPTH_LIMITS
225                )));
226            }
227        }
228
229        // Load markets
230        self.load_markets(false).await?;
231
232        // Get market info
233        let market = self.base.market(symbol).await?;
234        let binance_symbol = market.id.to_lowercase();
235
236        // Choose stream based on limit:
237        // - With limit (5/10/20): use @depth{N}@100ms for partial book snapshots
238        // - Without limit: use @depth@100ms for incremental diff updates
239        // When no limit is specified, default to @depth20@100ms (partial book snapshots)
240        // instead of @depth@100ms (diff stream) which requires snapshot+delta management.
241        // This is consistent with CCXT Python behavior.
242        let stream = if let Some(l) = limit {
243            format!("{}@depth{}@100ms", binance_symbol, l)
244        } else {
245            format!("{}@depth20@100ms", binance_symbol)
246        };
247
248        // Get shared public connection for this market type
249        let ws = self
250            .connection_manager
251            .get_public_connection(market.market_type)
252            .await?;
253
254        // Create subscription channel
255        let (tx, mut rx) = mpsc::channel(1024);
256
257        // Register with manager
258        let is_new = ws
259            .subscription_manager
260            .add_subscription(
261                stream.clone(),
262                symbol.to_string(),
263                super::ws::SubscriptionType::OrderBook,
264                tx,
265            )
266            .await?;
267
268        // Only send subscribe command if this is a new subscription
269        if is_new {
270            ws.message_router.subscribe(vec![stream]).await?;
271        }
272
273        // Create user channel
274        let (user_tx, user_rx) = mpsc::channel::<Result<OrderBook>>(1024);
275
276        // Spawn parser task
277        let symbol_clone = symbol.to_string();
278        tokio::spawn(async move {
279            while let Some(msg) = rx.recv().await {
280                match super::parser::parse_ws_orderbook(&msg, symbol_clone.clone()) {
281                    Ok(orderbook) => {
282                        if user_tx.send(Ok(orderbook)).await.is_err() {
283                            break;
284                        }
285                    }
286                    Err(e) => {
287                        let _ = user_tx.send(Err(e)).await;
288                    }
289                }
290            }
291        });
292
293        let stream = ReceiverStream::new(user_rx);
294        Ok(Box::pin(stream))
295    }
296
297    async fn watch_trades(&self, symbol: &str) -> Result<MessageStream<Vec<Trade>>> {
298        // Load markets
299        self.load_markets(false).await?;
300
301        // Get market info
302        let market = self.base.market(symbol).await?;
303        let binance_symbol = market.id.to_lowercase();
304        let stream = format!("{}@trade", binance_symbol);
305
306        // Get shared public connection for this market type
307        let ws = self
308            .connection_manager
309            .get_public_connection(market.market_type)
310            .await?;
311
312        // Create subscription channel
313        let (tx, mut rx) = mpsc::channel(1024);
314
315        // Register with manager
316        let is_new = ws
317            .subscription_manager
318            .add_subscription(
319                stream.clone(),
320                symbol.to_string(),
321                super::ws::SubscriptionType::Trades,
322                tx,
323            )
324            .await?;
325
326        // Only send subscribe command if this is a new subscription
327        if is_new {
328            ws.message_router.subscribe(vec![stream]).await?;
329        }
330
331        // Create user channel
332        let (user_tx, user_rx) = mpsc::channel::<Result<Vec<Trade>>>(1024);
333
334        // Spawn parser task
335        let market_clone = market.clone();
336        tokio::spawn(async move {
337            while let Some(msg) = rx.recv().await {
338                match super::parser::parse_ws_trade(&msg, Some(&market_clone)) {
339                    Ok(trade) => {
340                        if user_tx.send(Ok(vec![trade])).await.is_err() {
341                            break;
342                        }
343                    }
344                    Err(e) => {
345                        let _ = user_tx.send(Err(e)).await;
346                    }
347                }
348            }
349        });
350
351        // Convert receiver to stream
352        let stream = ReceiverStream::new(user_rx);
353        Ok(Box::pin(stream))
354    }
355
356    async fn watch_ohlcv(
357        &self,
358        symbol: &str,
359        timeframe: Timeframe,
360    ) -> Result<MessageStream<Ohlcv>> {
361        // Load markets
362        self.load_markets(false).await?;
363
364        // Get market info
365        let market = self.base.market(symbol).await?;
366        let binance_symbol = market.id.to_lowercase();
367
368        // Convert timeframe to Binance format
369        let interval = timeframe.to_string();
370        let stream = format!("{}@kline_{}", binance_symbol, interval);
371
372        // Get shared public connection for this market type
373        let ws = self
374            .connection_manager
375            .get_public_connection(market.market_type)
376            .await?;
377
378        // Create subscription channel
379        let (tx, mut rx) = mpsc::channel(1024);
380
381        // Register with manager
382        let is_new = ws
383            .subscription_manager
384            .add_subscription(
385                stream.clone(),
386                symbol.to_string(),
387                super::ws::SubscriptionType::Kline(interval),
388                tx,
389            )
390            .await?;
391
392        // Only send subscribe command if this is a new subscription
393        if is_new {
394            ws.message_router.subscribe(vec![stream]).await?;
395        }
396
397        // Create user channel
398        let (user_tx, user_rx) = mpsc::channel::<Result<Ohlcv>>(1024);
399
400        // Spawn parser task
401        tokio::spawn(async move {
402            while let Some(msg) = rx.recv().await {
403                // Parse OHLCV from kline message
404                match super::parser::parse_ws_ohlcv(&msg) {
405                    Ok(ohlcv_f64) => {
406                        // Convert OHLCV (f64) to Ohlcv (Decimal)
407                        let ohlcv = Ohlcv {
408                            timestamp: ohlcv_f64.timestamp,
409                            open: Price::from(
410                                Decimal::try_from(ohlcv_f64.open).unwrap_or_default(),
411                            ),
412                            high: Price::from(
413                                Decimal::try_from(ohlcv_f64.high).unwrap_or_default(),
414                            ),
415                            low: Price::from(Decimal::try_from(ohlcv_f64.low).unwrap_or_default()),
416                            close: Price::from(
417                                Decimal::try_from(ohlcv_f64.close).unwrap_or_default(),
418                            ),
419                            volume: Amount::from(
420                                Decimal::try_from(ohlcv_f64.volume).unwrap_or_default(),
421                            ),
422                        };
423                        if user_tx.send(Ok(ohlcv)).await.is_err() {
424                            break;
425                        }
426                    }
427                    Err(e) => {
428                        let _ = user_tx.send(Err(e)).await;
429                    }
430                }
431            }
432        });
433
434        let stream = ReceiverStream::new(user_rx);
435        Ok(Box::pin(stream))
436    }
437
438    // ==================== Private Data Streams ====================
439
440    async fn watch_balance(&self) -> Result<MessageStream<Balance>> {
441        self.base
442            .check_required_credentials()
443            .map_err(|_| Error::authentication("API credentials required for watch_balance"))?;
444
445        let binance_arc = Arc::new(self.clone());
446        let default_market_type = MarketType::from(self.options.default_type);
447        let ws = self
448            .connection_manager
449            .get_private_connection(default_market_type, &binance_arc)
450            .await?;
451
452        let (tx, mut rx) = mpsc::channel(1024);
453
454        ws.subscription_manager
455            .add_subscription(
456                "!userData".to_string(),
457                "user".to_string(),
458                super::ws::SubscriptionType::Balance,
459                tx,
460            )
461            .await?;
462
463        let (user_tx, user_rx) = mpsc::channel::<Result<Balance>>(1024);
464        let account_type = self.options.default_type.to_string();
465        let balances_cache = ws.balances.clone();
466
467        tokio::spawn(async move {
468            while let Some(msg) = rx.recv().await {
469                if let Some(event_type) = msg.get("e").and_then(|e| e.as_str()) {
470                    if matches!(
471                        event_type,
472                        "balanceUpdate" | "outboundAccountPosition" | "ACCOUNT_UPDATE"
473                    ) {
474                        if let Ok(()) = super::ws::user_data::handle_balance_message(
475                            &msg,
476                            &account_type,
477                            &balances_cache,
478                        )
479                        .await
480                        {
481                            let balances = balances_cache.read().await;
482                            if let Some(balance) = balances.get(&account_type) {
483                                if user_tx.send(Ok(balance.clone())).await.is_err() {
484                                    break;
485                                }
486                            }
487                        }
488                    }
489                }
490            }
491        });
492
493        let stream = ReceiverStream::new(user_rx);
494        Ok(Box::pin(stream))
495    }
496
497    async fn watch_orders(&self, symbol: Option<&str>) -> Result<MessageStream<Order>> {
498        self.base
499            .check_required_credentials()
500            .map_err(|_| Error::authentication("API credentials required for watch_orders"))?;
501
502        let binance_arc = Arc::new(self.clone());
503        let default_market_type = MarketType::from(self.options.default_type);
504        let ws = self
505            .connection_manager
506            .get_private_connection(default_market_type, &binance_arc)
507            .await?;
508
509        let (tx, mut rx) = mpsc::channel(1024);
510
511        ws.subscription_manager
512            .add_subscription(
513                "!userData".to_string(),
514                "user".to_string(),
515                super::ws::SubscriptionType::Orders,
516                tx,
517            )
518            .await?;
519
520        let (user_tx, user_rx) = mpsc::channel::<Result<Order>>(1024);
521        let symbol_filter = symbol.map(ToString::to_string);
522        let orders_cache = ws.orders.clone();
523
524        tokio::spawn(async move {
525            while let Some(msg) = rx.recv().await {
526                if let Some(data) = msg.as_object() {
527                    if let Some(event_type) = data.get("e").and_then(|e| e.as_str()) {
528                        let order = match event_type {
529                            // Spot market order updates
530                            "executionReport" => Some(super::ws::user_data::parse_ws_order(data)),
531                            // Futures market order updates
532                            "ORDER_TRADE_UPDATE" => {
533                                super::ws::user_data::parse_order_trade_update_to_order(&msg).ok()
534                            }
535                            _ => None,
536                        };
537
538                        if let Some(order) = order {
539                            {
540                                let mut orders = orders_cache.write().await;
541                                let symbol_orders = orders
542                                    .entry(order.symbol.clone())
543                                    .or_insert_with(HashMap::new);
544                                symbol_orders.insert(order.id.clone(), order.clone());
545                            }
546
547                            // Filter and send
548                            if let Some(s) = &symbol_filter {
549                                if &order.symbol != s {
550                                    continue;
551                                }
552                            }
553
554                            if user_tx.send(Ok(order)).await.is_err() {
555                                break;
556                            }
557                        }
558                    }
559                }
560            }
561        });
562
563        let stream = ReceiverStream::new(user_rx);
564        Ok(Box::pin(stream))
565    }
566
567    async fn watch_my_trades(&self, symbol: Option<&str>) -> Result<MessageStream<Trade>> {
568        self.base
569            .check_required_credentials()
570            .map_err(|_| Error::authentication("API credentials required for watch_my_trades"))?;
571
572        let binance_arc = Arc::new(self.clone());
573        let default_market_type = MarketType::from(self.options.default_type);
574        let ws = self
575            .connection_manager
576            .get_private_connection(default_market_type, &binance_arc)
577            .await?;
578
579        let (tx, mut rx) = mpsc::channel(1024);
580
581        ws.subscription_manager
582            .add_subscription(
583                "!userData".to_string(),
584                "user".to_string(),
585                super::ws::SubscriptionType::MyTrades,
586                tx,
587            )
588            .await?;
589
590        let (user_tx, user_rx) = mpsc::channel::<Result<Trade>>(1024);
591        let symbol_filter = symbol.map(ToString::to_string);
592        let trades_cache = ws.my_trades.clone();
593
594        tokio::spawn(async move {
595            while let Some(msg) = rx.recv().await {
596                if let Some(event_type) = msg.get("e").and_then(|e| e.as_str()) {
597                    let trade_result = match event_type {
598                        // Spot market trade executions
599                        "executionReport" => super::ws::user_data::parse_ws_trade(&msg).ok(),
600                        // Futures market trade executions
601                        "ORDER_TRADE_UPDATE" => {
602                            super::ws::user_data::parse_order_trade_update_to_trade(&msg).ok()
603                        }
604                        _ => None,
605                    };
606
607                    if let Some(trade) = trade_result {
608                        {
609                            let mut trades = trades_cache.write().await;
610
611                            let symbol_trades = trades
612                                .entry(trade.symbol.clone())
613                                .or_insert_with(std::collections::VecDeque::new);
614                            symbol_trades.push_front(trade.clone());
615                            if symbol_trades.len() > 1000 {
616                                symbol_trades.pop_back();
617                            }
618                        }
619
620                        // Filter and send
621                        if let Some(s) = &symbol_filter {
622                            if &trade.symbol != s {
623                                continue;
624                            }
625                        }
626
627                        if user_tx.send(Ok(trade)).await.is_err() {
628                            break;
629                        }
630                    }
631                }
632            }
633        });
634
635        let stream = ReceiverStream::new(user_rx);
636        Ok(Box::pin(stream))
637    }
638
639    // ==================== Subscription Management ====================
640
641    async fn subscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
642        // Determine market type from symbol or use default
643        let market_type = if let Some(sym) = symbol {
644            let market = self.base.market(sym).await?;
645            market.market_type
646        } else {
647            MarketType::from(self.options.default_type)
648        };
649
650        // Create WebSocket and connect
651        let ws = self
652            .connection_manager
653            .get_public_connection(market_type)
654            .await?;
655
656        // Use the appropriate subscription method based on channel
657        // Note: The receiver is intentionally dropped here as this is a fire-and-forget subscription.
658        // For proper message handling, use the specific watch_* methods instead.
659        match channel {
660            "ticker" => {
661                if let Some(sym) = symbol {
662                    let market = self.base.market(sym).await?;
663                    ws.subscribe_ticker(&market.id.to_lowercase())
664                        .await
665                        .map(|_| ())
666                } else {
667                    ws.subscribe_all_tickers().await.map(|_| ())
668                }
669            }
670            "trade" | "trades" => {
671                if let Some(sym) = symbol {
672                    let market = self.base.market(sym).await?;
673                    ws.subscribe_trades(&market.id.to_lowercase())
674                        .await
675                        .map(|_| ())
676                } else {
677                    Err(Error::invalid_request(
678                        "Symbol required for trades subscription",
679                    ))
680                }
681            }
682            _ => {
683                // For other channels, try generic subscription
684                Err(Error::invalid_request(format!(
685                    "Unknown channel: {}. Use specific watch_* methods instead.",
686                    channel
687                )))
688            }
689        }
690    }
691
692    async fn unsubscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
693        // Build stream name
694        let stream_name = if let Some(sym) = symbol {
695            // Load markets to get proper symbol format
696            self.load_markets(false).await?;
697            let market = self.base.market(sym).await?;
698            let binance_symbol = market.id.to_lowercase();
699            format!("{}@{}", binance_symbol, channel)
700        } else {
701            channel.to_string()
702        };
703
704        // Determine market type from symbol or use default
705        let market_type = if let Some(sym) = symbol {
706            let market = self.base.market(sym).await?;
707            market.market_type
708        } else {
709            MarketType::from(self.options.default_type)
710        };
711
712        // Create WS to unsubscribe
713        let ws = self
714            .connection_manager
715            .get_public_connection(market_type)
716            .await?;
717        ws.unsubscribe(stream_name).await
718    }
719
720    fn subscriptions(&self) -> Vec<String> {
721        self.connection_manager.get_all_subscriptions()
722    }
723}
724
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;
    use ccxt_core::ExchangeConfig;

    /// Builds a Binance client from the default configuration.
    fn default_binance() -> Binance {
        Binance::new(ExchangeConfig::default()).unwrap()
    }

    #[test]
    fn test_ws_exchange_trait_object_safety() {
        let binance = default_binance();

        // `WsExchange` must stay usable as a trait object.
        let _as_trait_object: &dyn WsExchange = &binance;

        // A freshly built client has not opened any WebSocket yet.
        assert!(!binance.ws_is_connected());
        assert_eq!(binance.ws_state(), WsConnectionState::Disconnected);
    }

    #[test]
    fn test_subscriptions_empty_by_default() {
        assert!(default_binance().subscriptions().is_empty());
    }
}