1#![allow(dead_code)]
4
5use crate::binance::Binance;
6use ccxt_core::error::{Error, Result};
7use ccxt_core::types::OrderBook;
8use ccxt_core::types::financial::{Amount, Price};
9use ccxt_core::types::orderbook::{OrderBookDelta, OrderBookEntry};
10use rust_decimal::Decimal;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::Mutex;
16
17struct MessageLoopContext {
18 ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
19 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
20 is_connected: Arc<std::sync::atomic::AtomicBool>,
21 reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
22 request_id: Arc<std::sync::atomic::AtomicU64>,
23 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
24 base_url: String,
25 current_url: String,
26}
27
28pub struct MessageRouter {
30 ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
32
33 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
35
36 router_task: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
38
39 is_connected: Arc<std::sync::atomic::AtomicBool>,
41
42 connection_lock: Arc<Mutex<()>>,
44
45 reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
47
48 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
50
51 ws_url: String,
53
54 request_id: Arc<std::sync::atomic::AtomicU64>,
56}
57
58impl MessageRouter {
59 pub fn new(
61 ws_url: String,
62 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
63 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
64 ) -> Self {
65 Self {
66 ws_client: Arc::new(tokio::sync::RwLock::new(None)),
67 subscription_manager,
68 router_task: Arc::new(Mutex::new(None)),
69 is_connected: Arc::new(std::sync::atomic::AtomicBool::new(false)),
70 connection_lock: Arc::new(Mutex::new(())),
71 reconnect_config: Arc::new(tokio::sync::RwLock::new(
72 super::subscriptions::ReconnectConfig::default(),
73 )),
74 listen_key_manager,
75 ws_url,
76 request_id: Arc::new(std::sync::atomic::AtomicU64::new(1)),
77 }
78 }
79
80 pub async fn start(&self, url_override: Option<String>) -> Result<()> {
82 let _lock = self.connection_lock.lock().await;
83
84 if self.is_connected() {
85 if url_override.is_some() {
86 self.stop().await?;
87 } else {
88 return Ok(());
89 }
90 }
91
92 let url = url_override.unwrap_or_else(|| self.ws_url.clone());
93 let config = ccxt_core::ws_client::WsConfig {
94 url: url.clone(),
95 ..Default::default()
96 };
97 let client = ccxt_core::ws_client::WsClient::new(config);
98 client.connect().await?;
99
100 *self.ws_client.write().await = Some(client);
101
102 self.is_connected
103 .store(true, std::sync::atomic::Ordering::SeqCst);
104
105 let ws_client = self.ws_client.clone();
106 let subscription_manager = self.subscription_manager.clone();
107 let is_connected = self.is_connected.clone();
108 let reconnect_config = self.reconnect_config.clone();
109 let request_id = self.request_id.clone();
110 let listen_key_manager = self.listen_key_manager.clone();
111
112 let ws_url = self.ws_url.clone();
113 let current_url = url;
114
115 let ctx = MessageLoopContext {
116 ws_client,
117 subscription_manager,
118 is_connected,
119 reconnect_config,
120 request_id,
121 listen_key_manager,
122 base_url: ws_url,
123 current_url,
124 };
125
126 let handle = tokio::spawn(async move {
127 Self::message_loop(ctx).await;
128 });
129
130 *self.router_task.lock().await = Some(handle);
131
132 Ok(())
133 }
134
135 pub async fn stop(&self) -> Result<()> {
137 self.is_connected
138 .store(false, std::sync::atomic::Ordering::SeqCst);
139
140 let mut task_opt = self.router_task.lock().await;
141 if let Some(handle) = task_opt.take() {
142 handle.abort();
143 }
144
145 let mut client_opt = self.ws_client.write().await;
146 if let Some(client) = client_opt.take() {
147 let _ = client.disconnect().await;
148 }
149
150 Ok(())
151 }
152
153 pub async fn restart(&self) -> Result<()> {
155 self.stop().await?;
156 tokio::time::sleep(Duration::from_millis(100)).await;
157 self.start(None).await
158 }
159
160 pub fn get_url(&self) -> String {
162 self.ws_url.clone()
163 }
164
165 pub fn is_connected(&self) -> bool {
167 self.is_connected.load(std::sync::atomic::Ordering::SeqCst)
168 }
169
170 pub async fn set_reconnect_config(&self, config: super::subscriptions::ReconnectConfig) {
172 *self.reconnect_config.write().await = config;
173 }
174
175 pub async fn get_reconnect_config(&self) -> super::subscriptions::ReconnectConfig {
177 self.reconnect_config.read().await.clone()
178 }
179
180 pub async fn subscribe(&self, streams: Vec<String>) -> Result<()> {
182 if streams.is_empty() {
183 return Ok(());
184 }
185
186 let client_opt = self.ws_client.read().await;
187 let client = client_opt
188 .as_ref()
189 .ok_or_else(|| Error::network("WebSocket not connected"))?;
190
191 let id = self
192 .request_id
193 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
194
195 #[allow(clippy::disallowed_methods)]
196 let request = serde_json::json!({
197 "method": "SUBSCRIBE",
198 "params": streams,
199 "id": id
200 });
201
202 client
203 .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
204 request.to_string().into(),
205 ))
206 .await?;
207
208 Ok(())
209 }
210
211 pub async fn unsubscribe(&self, streams: Vec<String>) -> Result<()> {
213 if streams.is_empty() {
214 return Ok(());
215 }
216
217 let client_opt = self.ws_client.read().await;
218 let client = client_opt
219 .as_ref()
220 .ok_or_else(|| Error::network("WebSocket not connected"))?;
221
222 let id = self
223 .request_id
224 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
225
226 #[allow(clippy::disallowed_methods)]
227 let request = serde_json::json!({
228 "method": "UNSUBSCRIBE",
229 "params": streams,
230 "id": id
231 });
232
233 client
234 .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
235 request.to_string().into(),
236 ))
237 .await?;
238
239 Ok(())
240 }
241
242 async fn message_loop(ctx: MessageLoopContext) {
244 let mut reconnect_attempt = 0;
245
246 Self::resubscribe_all(&ctx.ws_client, &ctx.subscription_manager, &ctx.request_id).await;
247
248 loop {
249 if !ctx.is_connected.load(std::sync::atomic::Ordering::SeqCst) {
250 break;
251 }
252
253 let has_client = ctx.ws_client.read().await.is_some();
254
255 if !has_client {
256 let config = ctx.reconnect_config.read().await;
257 if config.should_retry(reconnect_attempt) {
258 let delay = config.calculate_delay(reconnect_attempt);
259 drop(config);
260
261 tokio::time::sleep(Duration::from_millis(delay)).await;
262
263 if let Ok(()) = Self::reconnect(
264 &ctx.base_url,
265 &ctx.current_url,
266 ctx.ws_client.clone(),
267 ctx.listen_key_manager.clone(),
268 )
269 .await
270 {
271 Self::resubscribe_all(
272 &ctx.ws_client,
273 &ctx.subscription_manager,
274 &ctx.request_id,
275 )
276 .await;
277 reconnect_attempt = 0;
278 continue;
279 }
280 reconnect_attempt += 1;
281 continue;
282 }
283 ctx.is_connected
284 .store(false, std::sync::atomic::Ordering::SeqCst);
285 break;
286 }
287
288 let message_opt = {
289 let guard = ctx.ws_client.read().await;
290 if let Some(client) = guard.as_ref() {
291 client.receive().await
292 } else {
293 None
294 }
295 };
296
297 if let Some(value) = message_opt {
298 if let Err(_e) = Self::handle_message(
299 value,
300 ctx.subscription_manager.clone(),
301 ctx.listen_key_manager.clone(),
302 )
303 .await
304 {
305 continue;
306 }
307
308 reconnect_attempt = 0;
309 } else {
310 let config = ctx.reconnect_config.read().await;
311 if config.should_retry(reconnect_attempt) {
312 let delay = config.calculate_delay(reconnect_attempt);
313 drop(config);
314
315 tokio::time::sleep(Duration::from_millis(delay)).await;
316
317 if let Ok(()) = Self::reconnect(
318 &ctx.base_url,
319 &ctx.current_url,
320 ctx.ws_client.clone(),
321 ctx.listen_key_manager.clone(),
322 )
323 .await
324 {
325 Self::resubscribe_all(
326 &ctx.ws_client,
327 &ctx.subscription_manager,
328 &ctx.request_id,
329 )
330 .await;
331 reconnect_attempt = 0;
332 continue;
333 }
334 reconnect_attempt += 1;
335 continue;
336 }
337 ctx.is_connected
338 .store(false, std::sync::atomic::Ordering::SeqCst);
339 break;
340 }
341 }
342 }
343
344 async fn resubscribe_all(
346 ws_client: &Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
347 subscription_manager: &Arc<super::subscriptions::SubscriptionManager>,
348 request_id: &Arc<std::sync::atomic::AtomicU64>,
349 ) {
350 let streams = subscription_manager.get_active_streams().await;
351 if streams.is_empty() {
352 return;
353 }
354
355 let client_opt = ws_client.read().await;
356 if let Some(client) = client_opt.as_ref() {
357 for chunk in streams.chunks(10) {
359 let id = request_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
360 #[allow(clippy::disallowed_methods)]
361 let request = serde_json::json!({
362 "method": "SUBSCRIBE",
363 "params": chunk,
364 "id": id
365 });
366
367 if let Err(e) = client
368 .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
369 request.to_string().into(),
370 ))
371 .await
372 {
373 tracing::error!("Failed to resubscribe: {}", e);
374 }
375 }
376 }
377 }
378
379 async fn handle_message(
381 message: Value,
382 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
383 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
384 ) -> Result<()> {
385 let stream_name = Self::extract_stream_name(&message)?;
386
387 let payload = if message.get("stream").is_some() && message.get("data").is_some() {
389 message.get("data").cloned().unwrap_or(message.clone())
390 } else {
391 message.clone()
392 };
393
394 if stream_name == "!userData" {
396 if let Some(event) = payload.get("e").and_then(|e| e.as_str()) {
397 if event == "listenKeyExpired" {
398 if let Some(manager) = listen_key_manager {
399 tracing::info!("Received listenKeyExpired event, regenerating key...");
400 let _ = manager.regenerate().await;
401 }
402 }
403 }
404 }
405
406 let sent = subscription_manager
407 .send_to_stream(&stream_name, payload.clone())
408 .await;
409
410 if sent {
411 return Ok(());
412 }
413
414 let symbol_opt = payload.get("s").and_then(|s| s.as_str());
417
418 if let Some(symbol) = symbol_opt {
419 let mut active_streams = subscription_manager
420 .get_subscriptions_by_symbol(symbol)
421 .await;
422
423 tracing::debug!(
424 "Routing message for symbol {}: stream_name={}, active_subscriptions={}",
425 symbol,
426 stream_name,
427 active_streams.len()
428 );
429
430 if active_streams.is_empty() {
433 let lower_symbol = symbol.to_lowercase();
434 if lower_symbol != symbol {
435 let lower_streams = subscription_manager
436 .get_subscriptions_by_symbol(&lower_symbol)
437 .await;
438 if !lower_streams.is_empty() {
439 active_streams = lower_streams;
440 tracing::debug!(
441 "Found subscriptions for lowercased symbol {}: count={}",
442 lower_symbol,
443 active_streams.len()
444 );
445 }
446 }
447 }
448
449 let mut fallback_sent = false;
450
451 for sub in active_streams {
452 tracing::debug!(
456 "Checking subscription: stream={}, expected_starts_with={}",
457 sub.stream,
458 stream_name
459 );
460
461 if sub.stream.starts_with(&stream_name) {
462 if subscription_manager
463 .send_to_stream(&sub.stream, payload.clone())
464 .await
465 {
466 fallback_sent = true;
467 tracing::debug!("Successfully routed to fallback stream: {}", sub.stream);
468 }
469 }
470 }
471
472 if fallback_sent {
473 return Ok(());
474 }
475 }
476
477 Err(Error::generic("No subscribers for stream"))
478 }
479
480 pub fn extract_stream_name(message: &Value) -> Result<String> {
482 if let Some(stream) = message.get("stream").and_then(|s| s.as_str()) {
483 return Ok(stream.to_string());
484 }
485
486 if let Some(arr) = message.as_array() {
488 if let Some(first) = arr.first() {
489 if let Some(event_type) = first.get("e").and_then(|e| e.as_str()) {
490 match event_type {
491 "24hrTicker" => return Ok("!ticker@arr".to_string()),
492 "24hrMiniTicker" => return Ok("!miniTicker@arr".to_string()),
493 _ => {}
494 }
495 }
496 }
497 }
498
499 if let Some(event_type) = message.get("e").and_then(|e| e.as_str()) {
500 match event_type {
501 "outboundAccountPosition"
502 | "balanceUpdate"
503 | "executionReport"
504 | "listStatus"
505 | "ACCOUNT_UPDATE"
506 | "ORDER_TRADE_UPDATE"
507 | "listenKeyExpired" => {
508 return Ok("!userData".to_string());
509 }
510 _ => {}
511 }
512
513 if let Some(symbol) = message.get("s").and_then(|s| s.as_str()) {
514 let stream = match event_type {
515 "24hrTicker" => format!("{}@ticker", symbol.to_lowercase()),
516 "24hrMiniTicker" => format!("{}@miniTicker", symbol.to_lowercase()),
517 "depthUpdate" => format!("{}@depth", symbol.to_lowercase()),
518 "aggTrade" => format!("{}@aggTrade", symbol.to_lowercase()),
519 "trade" => format!("{}@trade", symbol.to_lowercase()),
520 "kline" => {
521 if let Some(kline) = message.get("k") {
522 if let Some(interval) = kline.get("i").and_then(|i| i.as_str()) {
523 format!("{}@kline_{}", symbol.to_lowercase(), interval)
524 } else {
525 return Err(Error::generic("Missing kline interval"));
526 }
527 } else {
528 return Err(Error::generic("Missing kline data"));
529 }
530 }
531 "markPriceUpdate" => format!("{}@markPrice", symbol.to_lowercase()),
532 "bookTicker" => format!("{}@bookTicker", symbol.to_lowercase()),
533 _ => {
534 return Err(Error::generic(format!(
535 "Unknown event type: {}",
536 event_type
537 )));
538 }
539 };
540 return Ok(stream);
541 }
542 }
543
544 if message.get("result").is_some() || message.get("error").is_some() {
545 return Err(Error::generic("Subscription response, skip routing"));
546 }
547
548 Err(Error::generic("Cannot extract stream name from message"))
549 }
550
551 async fn reconnect(
553 base_url: &str,
554 current_url: &str,
555 ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
556 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
557 ) -> Result<()> {
558 {
559 let mut client_opt = ws_client.write().await;
560 if let Some(client) = client_opt.take() {
561 let _ = client.disconnect().await;
562 }
563 }
564
565 let mut final_url = current_url.to_string();
567
568 if let Some(manager) = listen_key_manager {
569 if current_url != base_url {
572 if let Ok(key) = manager.get_or_create().await {
573 let base = if let Some(stripped) = base_url.strip_suffix('/') {
574 stripped
575 } else {
576 base_url
577 };
578 final_url = format!("{}/{}", base, key);
579 }
580 }
581 }
582
583 let config = ccxt_core::ws_client::WsConfig {
584 url: final_url,
585 ..Default::default()
586 };
587 let new_client = ccxt_core::ws_client::WsClient::new(config);
588
589 new_client.connect().await?;
590
591 *ws_client.write().await = Some(new_client);
592
593 Ok(())
594 }
595}
596
597impl Drop for MessageRouter {
598 fn drop(&mut self) {
599 }
602}
603
604pub async fn handle_orderbook_delta(
606 symbol: &str,
607 delta_message: &Value,
608 is_futures: bool,
609 orderbooks: &Mutex<HashMap<String, OrderBook>>,
610) -> Result<()> {
611 let first_update_id = delta_message["U"]
612 .as_i64()
613 .ok_or_else(|| Error::invalid_request("Missing first update ID in delta message"))?;
614
615 let final_update_id = delta_message["u"]
616 .as_i64()
617 .ok_or_else(|| Error::invalid_request("Missing final update ID in delta message"))?;
618
619 let prev_final_update_id = if is_futures {
620 delta_message["pu"].as_i64()
621 } else {
622 None
623 };
624
625 let timestamp = delta_message["E"]
626 .as_i64()
627 .unwrap_or_else(|| chrono::Utc::now().timestamp_millis());
628
629 let mut bids = Vec::new();
630 if let Some(bids_arr) = delta_message["b"].as_array() {
631 for bid in bids_arr {
632 if let (Some(price_str), Some(amount_str)) = (bid[0].as_str(), bid[1].as_str()) {
633 if let (Ok(price), Ok(amount)) =
634 (price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
635 {
636 bids.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
637 }
638 }
639 }
640 }
641
642 let mut asks = Vec::new();
643 if let Some(asks_arr) = delta_message["a"].as_array() {
644 for ask in asks_arr {
645 if let (Some(price_str), Some(amount_str)) = (ask[0].as_str(), ask[1].as_str()) {
646 if let (Ok(price), Ok(amount)) =
647 (price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
648 {
649 asks.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
650 }
651 }
652 }
653 }
654
655 let delta = OrderBookDelta {
656 symbol: symbol.to_string(),
657 first_update_id,
658 final_update_id,
659 prev_final_update_id,
660 timestamp,
661 bids,
662 asks,
663 };
664
665 let mut orderbooks_map = orderbooks.lock().await;
666 let orderbook = orderbooks_map
667 .entry(symbol.to_string())
668 .or_insert_with(|| OrderBook::new(symbol.to_string(), timestamp));
669
670 if !orderbook.is_synced {
671 orderbook.buffer_delta(delta);
672 return Ok(());
673 }
674
675 if let Err(e) = orderbook.apply_delta(&delta, is_futures) {
676 if orderbook.needs_resync {
677 tracing::warn!("Orderbook {} needs resync due to: {}", symbol, e);
678 orderbook.buffer_delta(delta);
679 return Err(Error::invalid_request(format!("RESYNC_NEEDED: {}", e)));
680 }
681 return Err(Error::invalid_request(e));
682 }
683
684 Ok(())
685}
686
687pub async fn fetch_orderbook_snapshot(
689 exchange: &Binance,
690 symbol: &str,
691 limit: Option<i64>,
692 is_futures: bool,
693 orderbooks: &Mutex<HashMap<String, OrderBook>>,
694) -> Result<OrderBook> {
695 let mut snapshot = exchange
696 .fetch_order_book(symbol, limit.map(|l| l as u32))
697 .await?;
698
699 snapshot.is_synced = true;
700
701 let mut orderbooks_map = orderbooks.lock().await;
702 if let Some(cached_ob) = orderbooks_map.get_mut(symbol) {
703 snapshot
704 .buffered_deltas
705 .clone_from(&cached_ob.buffered_deltas);
706
707 if let Ok(processed) = snapshot.process_buffered_deltas(is_futures) {
708 tracing::debug!("Processed {} buffered deltas for {}", processed, symbol);
709 }
710 }
711
712 orderbooks_map.insert(symbol.to_string(), snapshot.clone());
713
714 Ok(snapshot)
715}
716
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;
    use serde_json::json;
    use std::sync::Arc;

    #[test]
    fn test_extract_stream_name_combined() {
        let message = json!({
            "stream": "btcusdt@ticker",
            "data": {
                "e": "24hrTicker",
                "s": "BTCUSDT"
            }
        });

        let stream = MessageRouter::extract_stream_name(&message).unwrap();
        assert_eq!(stream, "btcusdt@ticker");
    }

    #[test]
    fn test_extract_stream_name_raw() {
        let message = json!({
            "e": "24hrTicker",
            "s": "BTCUSDT"
        });

        let stream = MessageRouter::extract_stream_name(&message).unwrap();
        assert_eq!(stream, "btcusdt@ticker");
    }

    #[test]
    fn test_extract_stream_name_kline_includes_interval() {
        let message = json!({
            "e": "kline",
            "s": "ETHUSDT",
            "k": { "i": "1m" }
        });

        let stream = MessageRouter::extract_stream_name(&message).unwrap();
        assert_eq!(stream, "ethusdt@kline_1m");
    }

    #[test]
    fn test_extract_stream_name_kline_missing_interval_is_error() {
        let message = json!({
            "e": "kline",
            "s": "ETHUSDT",
            "k": {}
        });

        assert!(MessageRouter::extract_stream_name(&message).is_err());
    }

    #[test]
    fn test_extract_stream_name_user_data() {
        for event in ["executionReport", "ACCOUNT_UPDATE", "listenKeyExpired"] {
            let message = json!({ "e": event, "s": "BTCUSDT" });
            let stream = MessageRouter::extract_stream_name(&message).unwrap();
            assert_eq!(stream, "!userData", "event {}", event);
        }
    }

    #[test]
    fn test_extract_stream_name_all_market_array() {
        let message = json!([{ "e": "24hrMiniTicker", "s": "BTCUSDT" }]);

        let stream = MessageRouter::extract_stream_name(&message).unwrap();
        assert_eq!(stream, "!miniTicker@arr");
    }

    #[test]
    fn test_extract_stream_name_subscription_response_is_error() {
        let message = json!({ "result": null, "id": 1 });

        assert!(MessageRouter::extract_stream_name(&message).is_err());
    }

    #[tokio::test]
    async fn test_handle_message_unwrapping() {
        let manager = Arc::new(crate::binance::ws::subscriptions::SubscriptionManager::new());
        let (tx, mut rx) = tokio::sync::mpsc::channel(100);

        manager
            .add_subscription(
                "btcusdt@ticker".to_string(),
                "BTCUSDT".to_string(),
                crate::binance::ws::subscriptions::SubscriptionType::Ticker,
                tx,
            )
            .await
            .unwrap();

        let message = json!({
            "stream": "btcusdt@ticker",
            "data": {
                "e": "24hrTicker",
                "s": "BTCUSDT",
                "c": "50000.00"
            }
        });

        MessageRouter::handle_message(message, manager, None)
            .await
            .unwrap();

        let received = rx.recv().await.unwrap();
        assert!(received.get("stream").is_none());
        assert_eq!(received["e"], "24hrTicker");
        assert_eq!(received["c"], "50000.00");
    }

    #[tokio::test]
    async fn test_handle_message_mark_price_fallback() {
        let manager = Arc::new(crate::binance::ws::subscriptions::SubscriptionManager::new());
        let (tx, mut rx) = tokio::sync::mpsc::channel(100);

        // Registered under a lowercase symbol and a suffixed stream name, so
        // delivery must go through the prefix/lowercase fallback.
        manager
            .add_subscription(
                "btcusdt@markPrice@1s".to_string(),
                "btcusdt".to_string(),
                crate::binance::ws::subscriptions::SubscriptionType::MarkPrice,
                tx,
            )
            .await
            .unwrap();

        let message = json!({
            "e": "markPriceUpdate",
            "s": "BTCUSDT",
            "p": "50000.00",
            "E": 123456789
        });

        MessageRouter::handle_message(message, manager, None)
            .await
            .unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received["e"], "markPriceUpdate");
        assert_eq!(received["p"], "50000.00");
    }

    #[tokio::test]
    async fn test_handle_message_without_subscribers_is_error() {
        let manager = Arc::new(crate::binance::ws::subscriptions::SubscriptionManager::new());

        let message = json!({
            "e": "trade",
            "s": "ETHUSDT",
            "p": "3000.00"
        });

        assert!(MessageRouter::handle_message(message, manager, None)
            .await
            .is_err());
    }
}