ccxt_exchanges/binance/
ws_exchange_impl.rs

1//! WsExchange trait implementation for Binance
2//!
3//! This module implements the unified `WsExchange` trait from `ccxt-core` for Binance,
4//! providing real-time WebSocket data streaming capabilities.
5
6use async_trait::async_trait;
7use ccxt_core::{
8    error::{Error, Result},
9    types::{
10        Balance, Ohlcv, Order, OrderBook, Ticker, Timeframe, Trade, financial::Amount,
11        financial::Price,
12    },
13    ws_client::WsConnectionState,
14    ws_exchange::{MessageStream, WsExchange},
15};
16
17use rust_decimal::Decimal;
18use std::collections::HashMap;
19use tokio::sync::mpsc;
20
21use super::Binance;
22
23/// A simple stream wrapper that converts an mpsc receiver into a Stream
24struct ReceiverStream<T> {
25    receiver: mpsc::UnboundedReceiver<T>,
26}
27
28impl<T> ReceiverStream<T> {
29    fn new(receiver: mpsc::UnboundedReceiver<T>) -> Self {
30        Self { receiver }
31    }
32}
33
34impl<T> futures::Stream for ReceiverStream<T> {
35    type Item = T;
36
37    fn poll_next(
38        mut self: std::pin::Pin<&mut Self>,
39        cx: &mut std::task::Context<'_>,
40    ) -> std::task::Poll<Option<Self::Item>> {
41        self.receiver.poll_recv(cx)
42    }
43}
44
45#[async_trait]
46impl WsExchange for Binance {
47    // ==================== Connection Management ====================
48
49    async fn ws_connect(&self) -> Result<()> {
50        // Create a public WebSocket connection
51        let ws = self.create_ws();
52        ws.connect().await?;
53
54        // Store the connected WebSocket instance for state tracking
55        let mut conn = self.ws_connection.write().await;
56        *conn = Some(ws);
57
58        Ok(())
59    }
60
61    async fn ws_disconnect(&self) -> Result<()> {
62        // Get and clear the stored WebSocket connection
63        let mut conn = self.ws_connection.write().await;
64        if let Some(ws) = conn.take() {
65            ws.disconnect().await?;
66        }
67        Ok(())
68    }
69
70    fn ws_is_connected(&self) -> bool {
71        // Use try_read to avoid blocking
72        if let Ok(conn) = self.ws_connection.try_read() {
73            if let Some(ref ws) = *conn {
74                return ws.is_connected();
75            }
76        }
77        false
78    }
79
80    fn ws_state(&self) -> WsConnectionState {
81        // Use try_read to avoid blocking
82        if let Ok(conn) = self.ws_connection.try_read() {
83            if let Some(ref ws) = *conn {
84                return ws.state();
85            }
86        }
87        WsConnectionState::Disconnected
88    }
89
90    // ==================== Public Data Streams ====================
91
92    async fn watch_ticker(&self, symbol: &str) -> Result<MessageStream<Ticker>> {
93        // Load markets to validate symbol
94        self.load_markets(false).await?;
95
96        // Get market info
97        let market = self.base.market(symbol).await?;
98        let binance_symbol = market.id.to_lowercase();
99
100        // Create WebSocket and connect
101        let ws = self.create_ws();
102        ws.connect().await?;
103
104        // Subscribe to ticker stream
105        ws.subscribe_ticker(&binance_symbol).await?;
106
107        // Create a channel for streaming data
108        let (tx, rx) = mpsc::unbounded_channel::<Result<Ticker>>();
109
110        // Spawn a task to receive messages and forward them
111        let market_clone = market.clone();
112        tokio::spawn(async move {
113            loop {
114                if let Some(msg) = ws.receive().await {
115                    // Skip subscription confirmations
116                    if msg.get("result").is_some() || msg.get("id").is_some() {
117                        continue;
118                    }
119
120                    // Parse ticker
121                    match super::parser::parse_ws_ticker(&msg, Some(&market_clone)) {
122                        Ok(ticker) => {
123                            if tx.send(Ok(ticker)).is_err() {
124                                break; // Receiver dropped
125                            }
126                        }
127                        Err(e) => {
128                            let _ = tx.send(Err(e));
129                        }
130                    }
131                }
132            }
133        });
134
135        // Convert receiver to stream
136        let stream = ReceiverStream::new(rx);
137        Ok(Box::pin(stream))
138    }
139
140    async fn watch_tickers(&self, symbols: &[String]) -> Result<MessageStream<Vec<Ticker>>> {
141        // Load markets
142        self.load_markets(false).await?;
143
144        // Create WebSocket and connect
145        let ws = self.create_ws();
146        ws.connect().await?;
147
148        // Subscribe to all requested symbols
149        let mut markets = HashMap::new();
150        for symbol in symbols {
151            let market = self.base.market(symbol).await?;
152            let binance_symbol = market.id.to_lowercase();
153            ws.subscribe_ticker(&binance_symbol).await?;
154            markets.insert(binance_symbol, market);
155        }
156
157        // Create channel for streaming
158        let (tx, rx) = mpsc::unbounded_channel::<Result<Vec<Ticker>>>();
159
160        // Spawn receiver task
161        let markets_clone = markets;
162        tokio::spawn(async move {
163            let mut tickers: HashMap<String, Ticker> = HashMap::new();
164
165            loop {
166                if let Some(msg) = ws.receive().await {
167                    // Skip subscription confirmations
168                    if msg.get("result").is_some() || msg.get("id").is_some() {
169                        continue;
170                    }
171
172                    // Get symbol from message
173                    if let Some(symbol_str) = msg.get("s").and_then(|s| s.as_str()) {
174                        let binance_symbol = symbol_str.to_lowercase();
175                        if let Some(market) = markets_clone.get(&binance_symbol) {
176                            if let Ok(ticker) = super::parser::parse_ws_ticker(&msg, Some(market)) {
177                                tickers.insert(ticker.symbol.clone(), ticker);
178
179                                // Send current state of all tickers
180                                let ticker_vec: Vec<Ticker> = tickers.values().cloned().collect();
181                                if tx.send(Ok(ticker_vec)).is_err() {
182                                    break;
183                                }
184                            }
185                        }
186                    }
187                }
188            }
189        });
190
191        let stream = ReceiverStream::new(rx);
192        Ok(Box::pin(stream))
193    }
194
195    async fn watch_order_book(
196        &self,
197        symbol: &str,
198        limit: Option<u32>,
199    ) -> Result<MessageStream<OrderBook>> {
200        // Load markets
201        self.load_markets(false).await?;
202
203        // Get market info
204        let market = self.base.market(symbol).await?;
205        let binance_symbol = market.id.to_lowercase();
206
207        // Create WebSocket and connect
208        let ws = self.create_ws();
209        ws.connect().await?;
210
211        // Subscribe to orderbook stream
212        let levels = limit.unwrap_or(20);
213        ws.subscribe_orderbook(&binance_symbol, levels, "100ms")
214            .await?;
215
216        // Create channel
217        let (tx, rx) = mpsc::unbounded_channel::<Result<OrderBook>>();
218
219        // Spawn receiver task
220        let symbol_clone = symbol.to_string();
221        tokio::spawn(async move {
222            loop {
223                if let Some(msg) = ws.receive().await {
224                    // Skip subscription confirmations
225                    if msg.get("result").is_some() || msg.get("id").is_some() {
226                        continue;
227                    }
228
229                    // Parse orderbook - pass String as required by the parser
230                    match super::parser::parse_ws_orderbook(&msg, symbol_clone.clone()) {
231                        Ok(orderbook) => {
232                            if tx.send(Ok(orderbook)).is_err() {
233                                break;
234                            }
235                        }
236                        Err(e) => {
237                            let _ = tx.send(Err(e));
238                        }
239                    }
240                }
241            }
242        });
243
244        let stream = ReceiverStream::new(rx);
245        Ok(Box::pin(stream))
246    }
247
248    async fn watch_trades(&self, symbol: &str) -> Result<MessageStream<Vec<Trade>>> {
249        // Load markets
250        self.load_markets(false).await?;
251
252        // Get market info
253        let market = self.base.market(symbol).await?;
254        let binance_symbol = market.id.to_lowercase();
255
256        // Create WebSocket and connect
257        let ws = self.create_ws();
258        ws.connect().await?;
259
260        // Subscribe to trades stream
261        ws.subscribe_trades(&binance_symbol).await?;
262
263        // Create channel
264        let (tx, rx) = mpsc::unbounded_channel::<Result<Vec<Trade>>>();
265
266        // Spawn receiver task
267        let market_clone = market.clone();
268        tokio::spawn(async move {
269            loop {
270                if let Some(msg) = ws.receive().await {
271                    // Skip subscription confirmations
272                    if msg.get("result").is_some() || msg.get("id").is_some() {
273                        continue;
274                    }
275
276                    // Parse trade
277                    match super::parser::parse_ws_trade(&msg, Some(&market_clone)) {
278                        Ok(trade) => {
279                            if tx.send(Ok(vec![trade])).is_err() {
280                                break;
281                            }
282                        }
283                        Err(e) => {
284                            let _ = tx.send(Err(e));
285                        }
286                    }
287                }
288            }
289        });
290
291        let stream = ReceiverStream::new(rx);
292        Ok(Box::pin(stream))
293    }
294
295    async fn watch_ohlcv(
296        &self,
297        symbol: &str,
298        timeframe: Timeframe,
299    ) -> Result<MessageStream<Ohlcv>> {
300        // Load markets
301        self.load_markets(false).await?;
302
303        // Get market info
304        let market = self.base.market(symbol).await?;
305        let binance_symbol = market.id.to_lowercase();
306
307        // Convert timeframe to Binance format
308        let interval = timeframe.to_string();
309
310        // Create WebSocket and connect
311        let ws = self.create_ws();
312        ws.connect().await?;
313
314        // Subscribe to kline stream
315        ws.subscribe_kline(&binance_symbol, &interval).await?;
316
317        // Create channel
318        let (tx, rx) = mpsc::unbounded_channel::<Result<Ohlcv>>();
319
320        // Spawn receiver task
321        tokio::spawn(async move {
322            loop {
323                if let Some(msg) = ws.receive().await {
324                    // Skip subscription confirmations
325                    if msg.get("result").is_some() || msg.get("id").is_some() {
326                        continue;
327                    }
328
329                    // Parse OHLCV from kline message
330                    match super::parser::parse_ws_ohlcv(&msg) {
331                        Ok(ohlcv_f64) => {
332                            // Convert OHLCV (f64) to Ohlcv (Decimal)
333                            let ohlcv = Ohlcv {
334                                timestamp: ohlcv_f64.timestamp,
335                                open: Price::from(
336                                    Decimal::try_from(ohlcv_f64.open).unwrap_or_default(),
337                                ),
338                                high: Price::from(
339                                    Decimal::try_from(ohlcv_f64.high).unwrap_or_default(),
340                                ),
341                                low: Price::from(
342                                    Decimal::try_from(ohlcv_f64.low).unwrap_or_default(),
343                                ),
344                                close: Price::from(
345                                    Decimal::try_from(ohlcv_f64.close).unwrap_or_default(),
346                                ),
347                                volume: Amount::from(
348                                    Decimal::try_from(ohlcv_f64.volume).unwrap_or_default(),
349                                ),
350                            };
351                            if tx.send(Ok(ohlcv)).is_err() {
352                                break;
353                            }
354                        }
355                        Err(e) => {
356                            let _ = tx.send(Err(e));
357                        }
358                    }
359                }
360            }
361        });
362
363        let stream = ReceiverStream::new(rx);
364        Ok(Box::pin(stream))
365    }
366
367    // ==================== Private Data Streams ====================
368
369    async fn watch_balance(&self) -> Result<MessageStream<Balance>> {
370        // Check credentials
371        self.base
372            .check_required_credentials()
373            .map_err(|_| Error::authentication("API credentials required for watch_balance"))?;
374
375        // For user data streams, we need an Arc<Binance> to create authenticated WS
376        // This is a limitation of the current design
377        Err(Error::not_implemented(
378            "watch_balance requires Arc<Binance> for authenticated WebSocket. \
379             Use create_authenticated_ws() directly for now.",
380        ))
381    }
382
383    async fn watch_orders(&self, _symbol: Option<&str>) -> Result<MessageStream<Order>> {
384        // Check credentials
385        self.base
386            .check_required_credentials()
387            .map_err(|_| Error::authentication("API credentials required for watch_orders"))?;
388
389        Err(Error::not_implemented(
390            "watch_orders requires Arc<Binance> for authenticated WebSocket. \
391             Use create_authenticated_ws() directly for now.",
392        ))
393    }
394
395    async fn watch_my_trades(&self, _symbol: Option<&str>) -> Result<MessageStream<Trade>> {
396        // Check credentials
397        self.base
398            .check_required_credentials()
399            .map_err(|_| Error::authentication("API credentials required for watch_my_trades"))?;
400
401        Err(Error::not_implemented(
402            "watch_my_trades requires Arc<Binance> for authenticated WebSocket. \
403             Use create_authenticated_ws() directly for now.",
404        ))
405    }
406
407    // ==================== Subscription Management ====================
408
409    async fn subscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
410        // Create WebSocket and connect
411        let ws = self.create_ws();
412        ws.connect().await?;
413
414        // Use the appropriate subscription method based on channel
415        match channel {
416            "ticker" => {
417                if let Some(sym) = symbol {
418                    let market = self.base.market(sym).await?;
419                    ws.subscribe_ticker(&market.id.to_lowercase()).await
420                } else {
421                    ws.subscribe_all_tickers().await
422                }
423            }
424            "trade" | "trades" => {
425                if let Some(sym) = symbol {
426                    let market = self.base.market(sym).await?;
427                    ws.subscribe_trades(&market.id.to_lowercase()).await
428                } else {
429                    Err(Error::invalid_request(
430                        "Symbol required for trades subscription",
431                    ))
432                }
433            }
434            _ => {
435                // For other channels, try generic subscription
436                Err(Error::invalid_request(format!(
437                    "Unknown channel: {}. Use specific watch_* methods instead.",
438                    channel
439                )))
440            }
441        }
442    }
443
444    async fn unsubscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
445        // Build stream name
446        let stream_name = if let Some(sym) = symbol {
447            // Load markets to get proper symbol format
448            self.load_markets(false).await?;
449            let market = self.base.market(sym).await?;
450            let binance_symbol = market.id.to_lowercase();
451            format!("{}@{}", binance_symbol, channel)
452        } else {
453            channel.to_string()
454        };
455
456        // Create WS to unsubscribe
457        let ws = self.create_ws();
458        ws.unsubscribe(stream_name).await
459    }
460
461    fn subscriptions(&self) -> Vec<String> {
462        // Use try_read to avoid blocking
463        if let Ok(conn) = self.ws_connection.try_read() {
464            if let Some(ref ws) = *conn {
465                return ws.subscriptions();
466            }
467        }
468        Vec::new()
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use ccxt_core::ExchangeConfig;

    /// Builds a `Binance` instance from the default exchange configuration.
    fn default_binance() -> Binance {
        Binance::new(ExchangeConfig::default()).unwrap()
    }

    /// `Binance` must be usable as a `dyn WsExchange`, and a fresh instance
    /// must report a disconnected state.
    #[test]
    fn test_ws_exchange_trait_object_safety() {
        let binance = default_binance();

        // Coercing to a trait object only compiles if the trait is object safe.
        let _ws_exchange: &dyn WsExchange = &binance;

        assert!(!binance.ws_is_connected());
        assert_eq!(binance.ws_state(), WsConnectionState::Disconnected);
    }

    /// A fresh instance has no tracked connection and therefore no subscriptions.
    #[test]
    fn test_subscriptions_empty_by_default() {
        let binance = default_binance();

        assert!(binance.subscriptions().is_empty());
    }
}
498}