ccxt_exchanges/binance/ws/
handlers.rs

1//! WebSocket message handlers
2
3#![allow(dead_code)]
4
5use crate::binance::Binance;
6use ccxt_core::error::{Error, Result};
7use ccxt_core::types::financial::{Amount, Price};
8use ccxt_core::types::orderbook::{OrderBookDelta, OrderBookEntry};
9use ccxt_core::types::{OrderBook, Ticker};
10use rust_decimal::Decimal;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::Mutex;
16
17/// Message router for handling WebSocket connections and message routing
18pub struct MessageRouter {
19    /// WebSocket client instance
20    ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
21
22    /// Subscription manager registry
23    subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
24
25    /// Handle to the background routing task
26    router_task: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
27
28    /// Connection state flag
29    is_connected: Arc<std::sync::atomic::AtomicBool>,
30
31    /// Configuration for reconnection behavior
32    reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
33
34    /// WebSocket endpoint URL
35    ws_url: String,
36
37    /// Request ID counter (used for subscribe/unsubscribe)
38    request_id: Arc<std::sync::atomic::AtomicU64>,
39}
40
41impl MessageRouter {
42    /// Creates a new message router
43    pub fn new(
44        ws_url: String,
45        subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
46    ) -> Self {
47        Self {
48            ws_client: Arc::new(tokio::sync::RwLock::new(None)),
49            subscription_manager,
50            router_task: Arc::new(Mutex::new(None)),
51            is_connected: Arc::new(std::sync::atomic::AtomicBool::new(false)),
52            reconnect_config: Arc::new(tokio::sync::RwLock::new(
53                super::subscriptions::ReconnectConfig::default(),
54            )),
55            ws_url,
56            request_id: Arc::new(std::sync::atomic::AtomicU64::new(1)),
57        }
58    }
59
60    /// Starts the message router
61    pub async fn start(&self) -> Result<()> {
62        if self.is_connected() {
63            self.stop().await?;
64        }
65
66        let config = ccxt_core::ws_client::WsConfig {
67            url: self.ws_url.clone(),
68            ..Default::default()
69        };
70        let client = ccxt_core::ws_client::WsClient::new(config);
71        client.connect().await?;
72
73        *self.ws_client.write().await = Some(client);
74
75        self.is_connected
76            .store(true, std::sync::atomic::Ordering::SeqCst);
77
78        let ws_client = self.ws_client.clone();
79        let subscription_manager = self.subscription_manager.clone();
80        let is_connected = self.is_connected.clone();
81        let reconnect_config = self.reconnect_config.clone();
82        let ws_url = self.ws_url.clone();
83
84        let handle = tokio::spawn(async move {
85            Self::message_loop(
86                ws_client,
87                subscription_manager,
88                is_connected,
89                reconnect_config,
90                ws_url,
91            )
92            .await;
93        });
94
95        *self.router_task.lock().await = Some(handle);
96
97        Ok(())
98    }
99
100    /// Stops the message router
101    pub async fn stop(&self) -> Result<()> {
102        self.is_connected
103            .store(false, std::sync::atomic::Ordering::SeqCst);
104
105        let mut task_opt = self.router_task.lock().await;
106        if let Some(handle) = task_opt.take() {
107            handle.abort();
108        }
109
110        let mut client_opt = self.ws_client.write().await;
111        if let Some(client) = client_opt.take() {
112            let _ = client.disconnect().await;
113        }
114
115        Ok(())
116    }
117
118    /// Restarts the message router
119    pub async fn restart(&self) -> Result<()> {
120        self.stop().await?;
121        tokio::time::sleep(Duration::from_millis(100)).await;
122        self.start().await
123    }
124
125    /// Returns the current connection state
126    pub fn is_connected(&self) -> bool {
127        self.is_connected.load(std::sync::atomic::Ordering::SeqCst)
128    }
129
130    /// Applies a new reconnection configuration
131    pub async fn set_reconnect_config(&self, config: super::subscriptions::ReconnectConfig) {
132        *self.reconnect_config.write().await = config;
133    }
134
135    /// Retrieves the current reconnection configuration
136    pub async fn get_reconnect_config(&self) -> super::subscriptions::ReconnectConfig {
137        self.reconnect_config.read().await.clone()
138    }
139
140    /// Subscribes to the provided streams
141    pub async fn subscribe(&self, streams: Vec<String>) -> Result<()> {
142        if streams.is_empty() {
143            return Ok(());
144        }
145
146        let client_opt = self.ws_client.read().await;
147        let client = client_opt
148            .as_ref()
149            .ok_or_else(|| Error::network("WebSocket not connected"))?;
150
151        let id = self
152            .request_id
153            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
154
155        #[allow(clippy::disallowed_methods)]
156        let request = serde_json::json!({
157            "method": "SUBSCRIBE",
158            "params": streams,
159            "id": id
160        });
161
162        client
163            .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
164                request.to_string().into(),
165            ))
166            .await?;
167
168        Ok(())
169    }
170
171    /// Unsubscribes from the provided streams
172    pub async fn unsubscribe(&self, streams: Vec<String>) -> Result<()> {
173        if streams.is_empty() {
174            return Ok(());
175        }
176
177        let client_opt = self.ws_client.read().await;
178        let client = client_opt
179            .as_ref()
180            .ok_or_else(|| Error::network("WebSocket not connected"))?;
181
182        let id = self
183            .request_id
184            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
185
186        #[allow(clippy::disallowed_methods)]
187        let request = serde_json::json!({
188            "method": "UNSUBSCRIBE",
189            "params": streams,
190            "id": id
191        });
192
193        client
194            .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
195                request.to_string().into(),
196            ))
197            .await?;
198
199        Ok(())
200    }
201
202    /// Message reception loop
203    async fn message_loop(
204        ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
205        subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
206        is_connected: Arc<std::sync::atomic::AtomicBool>,
207        reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
208        ws_url: String,
209    ) {
210        let mut reconnect_attempt = 0;
211
212        loop {
213            if !is_connected.load(std::sync::atomic::Ordering::SeqCst) {
214                break;
215            }
216
217            let has_client = ws_client.read().await.is_some();
218
219            if !has_client {
220                let config = reconnect_config.read().await;
221                if config.should_retry(reconnect_attempt) {
222                    let delay = config.calculate_delay(reconnect_attempt);
223                    drop(config);
224
225                    tokio::time::sleep(Duration::from_millis(delay)).await;
226
227                    if let Ok(()) = Self::reconnect(&ws_url, ws_client.clone()).await {
228                        reconnect_attempt = 0;
229                        continue;
230                    }
231                    reconnect_attempt += 1;
232                    continue;
233                }
234                is_connected.store(false, std::sync::atomic::Ordering::SeqCst);
235                break;
236            }
237
238            let message_opt = {
239                let guard = ws_client.read().await;
240                if let Some(client) = guard.as_ref() {
241                    client.receive().await
242                } else {
243                    None
244                }
245            };
246
247            if let Some(value) = message_opt {
248                if let Err(_e) = Self::handle_message(value, subscription_manager.clone()).await {
249                    continue;
250                }
251
252                reconnect_attempt = 0;
253            } else {
254                let config = reconnect_config.read().await;
255                if config.should_retry(reconnect_attempt) {
256                    let delay = config.calculate_delay(reconnect_attempt);
257                    drop(config);
258
259                    tokio::time::sleep(Duration::from_millis(delay)).await;
260
261                    if let Ok(()) = Self::reconnect(&ws_url, ws_client.clone()).await {
262                        reconnect_attempt = 0;
263                        continue;
264                    }
265                    reconnect_attempt += 1;
266                    continue;
267                }
268                is_connected.store(false, std::sync::atomic::Ordering::SeqCst);
269                break;
270            }
271        }
272    }
273
274    /// Processes a WebSocket message
275    async fn handle_message(
276        message: Value,
277        subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
278    ) -> Result<()> {
279        let stream_name = Self::extract_stream_name(&message)?;
280
281        let sent = subscription_manager
282            .send_to_stream(&stream_name, message)
283            .await;
284
285        if sent {
286            Ok(())
287        } else {
288            Err(Error::generic("No subscribers for stream"))
289        }
290    }
291
292    /// Extracts the stream name from an incoming message
293    pub fn extract_stream_name(message: &Value) -> Result<String> {
294        if let Some(stream) = message.get("stream").and_then(|s| s.as_str()) {
295            return Ok(stream.to_string());
296        }
297
298        if let Some(event_type) = message.get("e").and_then(|e| e.as_str()) {
299            if let Some(symbol) = message.get("s").and_then(|s| s.as_str()) {
300                let stream = match event_type {
301                    "24hrTicker" => format!("{}@ticker", symbol.to_lowercase()),
302                    "depthUpdate" => format!("{}@depth", symbol.to_lowercase()),
303                    "aggTrade" => format!("{}@aggTrade", symbol.to_lowercase()),
304                    "trade" => format!("{}@trade", symbol.to_lowercase()),
305                    "kline" => {
306                        if let Some(kline) = message.get("k") {
307                            if let Some(interval) = kline.get("i").and_then(|i| i.as_str()) {
308                                format!("{}@kline_{}", symbol.to_lowercase(), interval)
309                            } else {
310                                return Err(Error::generic("Missing kline interval"));
311                            }
312                        } else {
313                            return Err(Error::generic("Missing kline data"));
314                        }
315                    }
316                    "markPriceUpdate" => format!("{}@markPrice", symbol.to_lowercase()),
317                    "bookTicker" => format!("{}@bookTicker", symbol.to_lowercase()),
318                    _ => {
319                        return Err(Error::generic(format!(
320                            "Unknown event type: {}",
321                            event_type
322                        )));
323                    }
324                };
325                return Ok(stream);
326            }
327        }
328
329        if message.get("result").is_some() || message.get("error").is_some() {
330            return Err(Error::generic("Subscription response, skip routing"));
331        }
332
333        Err(Error::generic("Cannot extract stream name from message"))
334    }
335
336    /// Reconnects the WebSocket client
337    async fn reconnect(
338        ws_url: &str,
339        ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
340    ) -> Result<()> {
341        {
342            let mut client_opt = ws_client.write().await;
343            if let Some(client) = client_opt.take() {
344                let _ = client.disconnect().await;
345            }
346        }
347
348        let config = ccxt_core::ws_client::WsConfig {
349            url: ws_url.to_string(),
350            ..Default::default()
351        };
352        let new_client = ccxt_core::ws_client::WsClient::new(config);
353
354        new_client.connect().await?;
355
356        *ws_client.write().await = Some(new_client);
357
358        Ok(())
359    }
360}
361
362impl Drop for MessageRouter {
363    fn drop(&mut self) {
364        // Note: Drop is synchronous, so we cannot await asynchronous operations here.
365        // Callers should explicitly invoke `stop()` to release resources.
366    }
367}
368
369/// Processes an order book delta update
370pub async fn handle_orderbook_delta(
371    symbol: &str,
372    delta_message: &Value,
373    is_futures: bool,
374    orderbooks: &Mutex<HashMap<String, OrderBook>>,
375) -> Result<()> {
376    let first_update_id = delta_message["U"]
377        .as_i64()
378        .ok_or_else(|| Error::invalid_request("Missing first update ID in delta message"))?;
379
380    let final_update_id = delta_message["u"]
381        .as_i64()
382        .ok_or_else(|| Error::invalid_request("Missing final update ID in delta message"))?;
383
384    let prev_final_update_id = if is_futures {
385        delta_message["pu"].as_i64()
386    } else {
387        None
388    };
389
390    let timestamp = delta_message["E"]
391        .as_i64()
392        .unwrap_or_else(|| chrono::Utc::now().timestamp_millis());
393
394    let mut bids = Vec::new();
395    if let Some(bids_arr) = delta_message["b"].as_array() {
396        for bid in bids_arr {
397            if let (Some(price_str), Some(amount_str)) = (bid[0].as_str(), bid[1].as_str()) {
398                if let (Ok(price), Ok(amount)) =
399                    (price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
400                {
401                    bids.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
402                }
403            }
404        }
405    }
406
407    let mut asks = Vec::new();
408    if let Some(asks_arr) = delta_message["a"].as_array() {
409        for ask in asks_arr {
410            if let (Some(price_str), Some(amount_str)) = (ask[0].as_str(), ask[1].as_str()) {
411                if let (Ok(price), Ok(amount)) =
412                    (price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
413                {
414                    asks.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
415                }
416            }
417        }
418    }
419
420    let delta = OrderBookDelta {
421        symbol: symbol.to_string(),
422        first_update_id,
423        final_update_id,
424        prev_final_update_id,
425        timestamp,
426        bids,
427        asks,
428    };
429
430    let mut orderbooks_map = orderbooks.lock().await;
431    let orderbook = orderbooks_map
432        .entry(symbol.to_string())
433        .or_insert_with(|| OrderBook::new(symbol.to_string(), timestamp));
434
435    if !orderbook.is_synced {
436        orderbook.buffer_delta(delta);
437        return Ok(());
438    }
439
440    if let Err(e) = orderbook.apply_delta(&delta, is_futures) {
441        if orderbook.needs_resync {
442            tracing::warn!("Orderbook {} needs resync due to: {}", symbol, e);
443            orderbook.buffer_delta(delta);
444            return Err(Error::invalid_request(format!("RESYNC_NEEDED: {}", e)));
445        }
446        return Err(Error::invalid_request(e));
447    }
448
449    Ok(())
450}
451
452/// Retrieves an order book snapshot and initializes cached state
453pub async fn fetch_orderbook_snapshot(
454    exchange: &Binance,
455    symbol: &str,
456    limit: Option<i64>,
457    is_futures: bool,
458    orderbooks: &Mutex<HashMap<String, OrderBook>>,
459) -> Result<OrderBook> {
460    let mut params = HashMap::new();
461    if let Some(l) = limit {
462        #[allow(clippy::disallowed_methods)]
463        let limit_value = serde_json::json!(l);
464        params.insert("limit".to_string(), limit_value);
465    }
466
467    let mut snapshot = exchange.fetch_order_book(symbol, None).await?;
468
469    snapshot.is_synced = true;
470
471    let mut orderbooks_map = orderbooks.lock().await;
472    if let Some(cached_ob) = orderbooks_map.get_mut(symbol) {
473        snapshot
474            .buffered_deltas
475            .clone_from(&cached_ob.buffered_deltas);
476
477        if let Ok(processed) = snapshot.process_buffered_deltas(is_futures) {
478            tracing::debug!("Processed {} buffered deltas for {}", processed, symbol);
479        }
480    }
481
482    orderbooks_map.insert(symbol.to_string(), snapshot.clone());
483
484    Ok(snapshot)
485}
486
487/// Watches a single ticker stream
488pub async fn watch_ticker_internal(
489    ws_client: &ccxt_core::ws_client::WsClient,
490    symbol: &str,
491    channel_name: &str,
492    tickers: &Mutex<HashMap<String, Ticker>>,
493    parser: &dyn Fn(&Value, Option<&ccxt_core::types::Market>) -> Result<Ticker>,
494) -> Result<Ticker> {
495    let stream = format!("{}@{}", symbol.to_lowercase(), channel_name);
496
497    ws_client
498        .subscribe(stream.clone(), Some(symbol.to_string()), None)
499        .await?;
500
501    loop {
502        if let Some(message) = ws_client.receive().await {
503            if message.get("result").is_some() {
504                continue;
505            }
506
507            if let Ok(ticker) = parser(&message, None) {
508                let mut tickers_map = tickers.lock().await;
509                tickers_map.insert(ticker.symbol.clone(), ticker.clone());
510
511                return Ok(ticker);
512            }
513        }
514    }
515}