ccxt_exchanges/binance/
ws_exchange_impl.rs

//! WsExchange trait implementation for Binance
//!
//! This module implements the unified `WsExchange` trait from `ccxt-core` for Binance,
//! providing real-time WebSocket data streaming capabilities.
5
6use async_trait::async_trait;
7use ccxt_core::{
8    error::{Error, Result},
9    types::{
10        Balance, Ohlcv, Order, OrderBook, Ticker, Timeframe, Trade, financial::Amount,
11        financial::Price,
12    },
13    ws_client::WsConnectionState,
14    ws_exchange::{MessageStream, WsExchange},
15};
16
17use rust_decimal::Decimal;
18use std::collections::HashMap;
19use tokio::sync::mpsc;
20
21use super::Binance;
22
23/// A simple stream wrapper that converts an mpsc receiver into a Stream
24struct ReceiverStream<T> {
25    receiver: mpsc::UnboundedReceiver<T>,
26}
27
28impl<T> ReceiverStream<T> {
29    fn new(receiver: mpsc::UnboundedReceiver<T>) -> Self {
30        Self { receiver }
31    }
32}
33
34impl<T> futures::Stream for ReceiverStream<T> {
35    type Item = T;
36
37    fn poll_next(
38        mut self: std::pin::Pin<&mut Self>,
39        cx: &mut std::task::Context<'_>,
40    ) -> std::task::Poll<Option<Self::Item>> {
41        self.receiver.poll_recv(cx)
42    }
43}
44
45#[async_trait]
46impl WsExchange for Binance {
47    // ==================== Connection Management ====================
48
49    async fn ws_connect(&self) -> Result<()> {
50        // Create a public WebSocket connection
51        let ws = self.create_ws();
52        ws.connect().await?;
53        Ok(())
54    }
55
56    async fn ws_disconnect(&self) -> Result<()> {
57        // Create a temporary WS to disconnect
58        let ws = self.create_ws();
59        ws.disconnect().await
60    }
61
62    fn ws_is_connected(&self) -> bool {
63        // Without persistent state tracking, return false
64        // In a full implementation, we'd track the active connection
65        false
66    }
67
68    fn ws_state(&self) -> WsConnectionState {
69        // Return disconnected since we don't have persistent state
70        WsConnectionState::Disconnected
71    }
72
73    // ==================== Public Data Streams ====================
74
75    async fn watch_ticker(&self, symbol: &str) -> Result<MessageStream<Ticker>> {
76        // Load markets to validate symbol
77        self.load_markets(false).await?;
78
79        // Get market info
80        let market = self.base.market(symbol).await?;
81        let binance_symbol = market.id.to_lowercase();
82
83        // Create WebSocket and connect
84        let ws = self.create_ws();
85        ws.connect().await?;
86
87        // Subscribe to ticker stream
88        ws.subscribe_ticker(&binance_symbol).await?;
89
90        // Create a channel for streaming data
91        let (tx, rx) = mpsc::unbounded_channel::<Result<Ticker>>();
92
93        // Spawn a task to receive messages and forward them
94        let market_clone = market.clone();
95        tokio::spawn(async move {
96            loop {
97                if let Some(msg) = ws.receive().await {
98                    // Skip subscription confirmations
99                    if msg.get("result").is_some() || msg.get("id").is_some() {
100                        continue;
101                    }
102
103                    // Parse ticker
104                    match super::parser::parse_ws_ticker(&msg, Some(&market_clone)) {
105                        Ok(ticker) => {
106                            if tx.send(Ok(ticker)).is_err() {
107                                break; // Receiver dropped
108                            }
109                        }
110                        Err(e) => {
111                            let _ = tx.send(Err(e));
112                        }
113                    }
114                }
115            }
116        });
117
118        // Convert receiver to stream
119        let stream = ReceiverStream::new(rx);
120        Ok(Box::pin(stream))
121    }
122
123    async fn watch_tickers(&self, symbols: &[String]) -> Result<MessageStream<Vec<Ticker>>> {
124        // Load markets
125        self.load_markets(false).await?;
126
127        // Create WebSocket and connect
128        let ws = self.create_ws();
129        ws.connect().await?;
130
131        // Subscribe to all requested symbols
132        let mut markets = HashMap::new();
133        for symbol in symbols {
134            let market = self.base.market(symbol).await?;
135            let binance_symbol = market.id.to_lowercase();
136            ws.subscribe_ticker(&binance_symbol).await?;
137            markets.insert(binance_symbol, market);
138        }
139
140        // Create channel for streaming
141        let (tx, rx) = mpsc::unbounded_channel::<Result<Vec<Ticker>>>();
142
143        // Spawn receiver task
144        let markets_clone = markets;
145        tokio::spawn(async move {
146            let mut tickers: HashMap<String, Ticker> = HashMap::new();
147
148            loop {
149                if let Some(msg) = ws.receive().await {
150                    // Skip subscription confirmations
151                    if msg.get("result").is_some() || msg.get("id").is_some() {
152                        continue;
153                    }
154
155                    // Get symbol from message
156                    if let Some(symbol_str) = msg.get("s").and_then(|s| s.as_str()) {
157                        let binance_symbol = symbol_str.to_lowercase();
158                        if let Some(market) = markets_clone.get(&binance_symbol) {
159                            if let Ok(ticker) = super::parser::parse_ws_ticker(&msg, Some(market)) {
160                                tickers.insert(ticker.symbol.clone(), ticker);
161
162                                // Send current state of all tickers
163                                let ticker_vec: Vec<Ticker> = tickers.values().cloned().collect();
164                                if tx.send(Ok(ticker_vec)).is_err() {
165                                    break;
166                                }
167                            }
168                        }
169                    }
170                }
171            }
172        });
173
174        let stream = ReceiverStream::new(rx);
175        Ok(Box::pin(stream))
176    }
177
178    async fn watch_order_book(
179        &self,
180        symbol: &str,
181        limit: Option<u32>,
182    ) -> Result<MessageStream<OrderBook>> {
183        // Load markets
184        self.load_markets(false).await?;
185
186        // Get market info
187        let market = self.base.market(symbol).await?;
188        let binance_symbol = market.id.to_lowercase();
189
190        // Create WebSocket and connect
191        let ws = self.create_ws();
192        ws.connect().await?;
193
194        // Subscribe to orderbook stream
195        let levels = limit.unwrap_or(20);
196        ws.subscribe_orderbook(&binance_symbol, levels, "100ms")
197            .await?;
198
199        // Create channel
200        let (tx, rx) = mpsc::unbounded_channel::<Result<OrderBook>>();
201
202        // Spawn receiver task
203        let symbol_clone = symbol.to_string();
204        tokio::spawn(async move {
205            loop {
206                if let Some(msg) = ws.receive().await {
207                    // Skip subscription confirmations
208                    if msg.get("result").is_some() || msg.get("id").is_some() {
209                        continue;
210                    }
211
212                    // Parse orderbook - pass String as required by the parser
213                    match super::parser::parse_ws_orderbook(&msg, symbol_clone.clone()) {
214                        Ok(orderbook) => {
215                            if tx.send(Ok(orderbook)).is_err() {
216                                break;
217                            }
218                        }
219                        Err(e) => {
220                            let _ = tx.send(Err(e));
221                        }
222                    }
223                }
224            }
225        });
226
227        let stream = ReceiverStream::new(rx);
228        Ok(Box::pin(stream))
229    }
230
231    async fn watch_trades(&self, symbol: &str) -> Result<MessageStream<Vec<Trade>>> {
232        // Load markets
233        self.load_markets(false).await?;
234
235        // Get market info
236        let market = self.base.market(symbol).await?;
237        let binance_symbol = market.id.to_lowercase();
238
239        // Create WebSocket and connect
240        let ws = self.create_ws();
241        ws.connect().await?;
242
243        // Subscribe to trades stream
244        ws.subscribe_trades(&binance_symbol).await?;
245
246        // Create channel
247        let (tx, rx) = mpsc::unbounded_channel::<Result<Vec<Trade>>>();
248
249        // Spawn receiver task
250        let market_clone = market.clone();
251        tokio::spawn(async move {
252            loop {
253                if let Some(msg) = ws.receive().await {
254                    // Skip subscription confirmations
255                    if msg.get("result").is_some() || msg.get("id").is_some() {
256                        continue;
257                    }
258
259                    // Parse trade
260                    match super::parser::parse_ws_trade(&msg, Some(&market_clone)) {
261                        Ok(trade) => {
262                            if tx.send(Ok(vec![trade])).is_err() {
263                                break;
264                            }
265                        }
266                        Err(e) => {
267                            let _ = tx.send(Err(e));
268                        }
269                    }
270                }
271            }
272        });
273
274        let stream = ReceiverStream::new(rx);
275        Ok(Box::pin(stream))
276    }
277
278    async fn watch_ohlcv(
279        &self,
280        symbol: &str,
281        timeframe: Timeframe,
282    ) -> Result<MessageStream<Ohlcv>> {
283        // Load markets
284        self.load_markets(false).await?;
285
286        // Get market info
287        let market = self.base.market(symbol).await?;
288        let binance_symbol = market.id.to_lowercase();
289
290        // Convert timeframe to Binance format
291        let interval = timeframe.to_string();
292
293        // Create WebSocket and connect
294        let ws = self.create_ws();
295        ws.connect().await?;
296
297        // Subscribe to kline stream
298        ws.subscribe_kline(&binance_symbol, &interval).await?;
299
300        // Create channel
301        let (tx, rx) = mpsc::unbounded_channel::<Result<Ohlcv>>();
302
303        // Spawn receiver task
304        tokio::spawn(async move {
305            loop {
306                if let Some(msg) = ws.receive().await {
307                    // Skip subscription confirmations
308                    if msg.get("result").is_some() || msg.get("id").is_some() {
309                        continue;
310                    }
311
312                    // Parse OHLCV from kline message
313                    match super::parser::parse_ws_ohlcv(&msg) {
314                        Ok(ohlcv_f64) => {
315                            // Convert OHLCV (f64) to Ohlcv (Decimal)
316                            let ohlcv = Ohlcv {
317                                timestamp: ohlcv_f64.timestamp,
318                                open: Price::from(
319                                    Decimal::try_from(ohlcv_f64.open).unwrap_or_default(),
320                                ),
321                                high: Price::from(
322                                    Decimal::try_from(ohlcv_f64.high).unwrap_or_default(),
323                                ),
324                                low: Price::from(
325                                    Decimal::try_from(ohlcv_f64.low).unwrap_or_default(),
326                                ),
327                                close: Price::from(
328                                    Decimal::try_from(ohlcv_f64.close).unwrap_or_default(),
329                                ),
330                                volume: Amount::from(
331                                    Decimal::try_from(ohlcv_f64.volume).unwrap_or_default(),
332                                ),
333                            };
334                            if tx.send(Ok(ohlcv)).is_err() {
335                                break;
336                            }
337                        }
338                        Err(e) => {
339                            let _ = tx.send(Err(e));
340                        }
341                    }
342                }
343            }
344        });
345
346        let stream = ReceiverStream::new(rx);
347        Ok(Box::pin(stream))
348    }
349
350    // ==================== Private Data Streams ====================
351
352    async fn watch_balance(&self) -> Result<MessageStream<Balance>> {
353        // Check credentials
354        self.base
355            .check_required_credentials()
356            .map_err(|_| Error::authentication("API credentials required for watch_balance"))?;
357
358        // For user data streams, we need an Arc<Binance> to create authenticated WS
359        // This is a limitation of the current design
360        Err(Error::not_implemented(
361            "watch_balance requires Arc<Binance> for authenticated WebSocket. \
362             Use create_authenticated_ws() directly for now.",
363        ))
364    }
365
366    async fn watch_orders(&self, _symbol: Option<&str>) -> Result<MessageStream<Order>> {
367        // Check credentials
368        self.base
369            .check_required_credentials()
370            .map_err(|_| Error::authentication("API credentials required for watch_orders"))?;
371
372        Err(Error::not_implemented(
373            "watch_orders requires Arc<Binance> for authenticated WebSocket. \
374             Use create_authenticated_ws() directly for now.",
375        ))
376    }
377
378    async fn watch_my_trades(&self, _symbol: Option<&str>) -> Result<MessageStream<Trade>> {
379        // Check credentials
380        self.base
381            .check_required_credentials()
382            .map_err(|_| Error::authentication("API credentials required for watch_my_trades"))?;
383
384        Err(Error::not_implemented(
385            "watch_my_trades requires Arc<Binance> for authenticated WebSocket. \
386             Use create_authenticated_ws() directly for now.",
387        ))
388    }
389
390    // ==================== Subscription Management ====================
391
392    async fn subscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
393        // Create WebSocket and connect
394        let ws = self.create_ws();
395        ws.connect().await?;
396
397        // Use the appropriate subscription method based on channel
398        match channel {
399            "ticker" => {
400                if let Some(sym) = symbol {
401                    let market = self.base.market(sym).await?;
402                    ws.subscribe_ticker(&market.id.to_lowercase()).await
403                } else {
404                    ws.subscribe_all_tickers().await
405                }
406            }
407            "trade" | "trades" => {
408                if let Some(sym) = symbol {
409                    let market = self.base.market(sym).await?;
410                    ws.subscribe_trades(&market.id.to_lowercase()).await
411                } else {
412                    Err(Error::invalid_request(
413                        "Symbol required for trades subscription",
414                    ))
415                }
416            }
417            _ => {
418                // For other channels, try generic subscription
419                Err(Error::invalid_request(format!(
420                    "Unknown channel: {}. Use specific watch_* methods instead.",
421                    channel
422                )))
423            }
424        }
425    }
426
427    async fn unsubscribe(&self, channel: &str, symbol: Option<&str>) -> Result<()> {
428        // Build stream name
429        let stream_name = if let Some(sym) = symbol {
430            // Load markets to get proper symbol format
431            self.load_markets(false).await?;
432            let market = self.base.market(sym).await?;
433            let binance_symbol = market.id.to_lowercase();
434            format!("{}@{}", binance_symbol, channel)
435        } else {
436            channel.to_string()
437        };
438
439        // Create WS to unsubscribe
440        let ws = self.create_ws();
441        ws.unsubscribe(stream_name).await
442    }
443
444    fn subscriptions(&self) -> Vec<String> {
445        // Without persistent state, return empty
446        // In a full implementation, we'd track active subscriptions
447        Vec::new()
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use ccxt_core::ExchangeConfig;

    /// Builds a Binance client from the default exchange configuration.
    fn default_binance() -> Binance {
        Binance::new(ExchangeConfig::default()).unwrap()
    }

    #[test]
    fn test_ws_exchange_trait_object_safety() {
        let exchange = default_binance();

        // This coercion only compiles if `WsExchange` is object-safe.
        let _as_trait_object: &dyn WsExchange = &exchange;

        // A fresh client reports no connection.
        assert!(!exchange.ws_is_connected());
        assert_eq!(exchange.ws_state(), WsConnectionState::Disconnected);
    }

    #[test]
    fn test_subscriptions_empty_by_default() {
        assert!(default_binance().subscriptions().is_empty());
    }
}