1#![allow(dead_code)]
4
5use crate::binance::Binance;
6use ccxt_core::error::{Error, Result};
7use ccxt_core::types::OrderBook;
8use ccxt_core::types::financial::{Amount, Price};
9use ccxt_core::types::orderbook::{OrderBookDelta, OrderBookEntry};
10use rust_decimal::Decimal;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::Mutex;
16
17struct MessageLoopContext {
18 ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
19 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
20 is_connected: Arc<std::sync::atomic::AtomicBool>,
21 reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
22 request_id: Arc<std::sync::atomic::AtomicU64>,
23 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
24 base_url: String,
25 current_url: String,
26}
27
28pub struct MessageRouter {
30 ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
32
33 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
35
36 router_task: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
38
39 is_connected: Arc<std::sync::atomic::AtomicBool>,
41
42 connection_lock: Arc<Mutex<()>>,
44
45 reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
47
48 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
50
51 ws_url: String,
53
54 request_id: Arc<std::sync::atomic::AtomicU64>,
56}
57
58impl MessageRouter {
59 pub fn new(
61 ws_url: String,
62 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
63 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
64 ) -> Self {
65 Self {
66 ws_client: Arc::new(tokio::sync::RwLock::new(None)),
67 subscription_manager,
68 router_task: Arc::new(Mutex::new(None)),
69 is_connected: Arc::new(std::sync::atomic::AtomicBool::new(false)),
70 connection_lock: Arc::new(Mutex::new(())),
71 reconnect_config: Arc::new(tokio::sync::RwLock::new(
72 super::subscriptions::ReconnectConfig::default(),
73 )),
74 listen_key_manager,
75 ws_url,
76 request_id: Arc::new(std::sync::atomic::AtomicU64::new(1)),
77 }
78 }
79
80 pub async fn start(&self, url_override: Option<String>) -> Result<()> {
82 let _lock = self.connection_lock.lock().await;
83
84 if self.is_connected() {
85 if url_override.is_some() {
86 self.stop().await?;
87 } else {
88 return Ok(());
89 }
90 }
91
92 let url = url_override.unwrap_or_else(|| self.ws_url.clone());
93 let config = ccxt_core::ws_client::WsConfig {
94 url: url.clone(),
95 ..Default::default()
96 };
97 let client = ccxt_core::ws_client::WsClient::new(config);
98 client.connect().await?;
99
100 *self.ws_client.write().await = Some(client);
101
102 self.is_connected
103 .store(true, std::sync::atomic::Ordering::SeqCst);
104
105 let ws_client = self.ws_client.clone();
106 let subscription_manager = self.subscription_manager.clone();
107 let is_connected = self.is_connected.clone();
108 let reconnect_config = self.reconnect_config.clone();
109 let request_id = self.request_id.clone();
110 let listen_key_manager = self.listen_key_manager.clone();
111
112 let ws_url = self.ws_url.clone();
113 let current_url = url;
114
115 let ctx = MessageLoopContext {
116 ws_client,
117 subscription_manager,
118 is_connected,
119 reconnect_config,
120 request_id,
121 listen_key_manager,
122 base_url: ws_url,
123 current_url,
124 };
125
126 let handle = tokio::spawn(async move {
127 Self::message_loop(ctx).await;
128 });
129
130 *self.router_task.lock().await = Some(handle);
131
132 Ok(())
133 }
134
135 pub async fn stop(&self) -> Result<()> {
137 self.is_connected
138 .store(false, std::sync::atomic::Ordering::SeqCst);
139
140 let mut task_opt = self.router_task.lock().await;
141 if let Some(handle) = task_opt.take() {
142 handle.abort();
143 }
144
145 let mut client_opt = self.ws_client.write().await;
146 if let Some(client) = client_opt.take() {
147 let _ = client.disconnect().await;
148 }
149
150 Ok(())
151 }
152
153 pub async fn restart(&self) -> Result<()> {
155 self.stop().await?;
156 tokio::time::sleep(Duration::from_millis(100)).await;
157 self.start(None).await
158 }
159
160 pub fn get_url(&self) -> String {
162 self.ws_url.clone()
163 }
164
165 pub fn is_connected(&self) -> bool {
167 self.is_connected.load(std::sync::atomic::Ordering::SeqCst)
168 }
169
170 pub fn latency(&self) -> Option<i64> {
175 if let Ok(guard) = self.ws_client.try_read() {
178 if let Some(ref client) = *guard {
179 return client.latency();
180 }
181 }
182 None
183 }
184
185 pub fn reconnect_count(&self) -> u32 {
187 if let Ok(guard) = self.ws_client.try_read() {
188 if let Some(ref client) = *guard {
189 return client.reconnect_count();
190 }
191 }
192 0
193 }
194
195 pub async fn set_reconnect_config(&self, config: super::subscriptions::ReconnectConfig) {
197 *self.reconnect_config.write().await = config;
198 }
199
200 pub async fn get_reconnect_config(&self) -> super::subscriptions::ReconnectConfig {
202 self.reconnect_config.read().await.clone()
203 }
204
205 pub async fn subscribe(&self, streams: Vec<String>) -> Result<()> {
207 if streams.is_empty() {
208 return Ok(());
209 }
210
211 let client_opt = self.ws_client.read().await;
212 let client = client_opt
213 .as_ref()
214 .ok_or_else(|| Error::network("WebSocket not connected"))?;
215
216 let id = self
217 .request_id
218 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
219
220 #[allow(clippy::disallowed_methods)]
221 let request = serde_json::json!({
222 "method": "SUBSCRIBE",
223 "params": streams,
224 "id": id
225 });
226
227 client
228 .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
229 request.to_string().into(),
230 ))
231 .await?;
232
233 Ok(())
234 }
235
236 pub async fn unsubscribe(&self, streams: Vec<String>) -> Result<()> {
238 if streams.is_empty() {
239 return Ok(());
240 }
241
242 let client_opt = self.ws_client.read().await;
243 let client = client_opt
244 .as_ref()
245 .ok_or_else(|| Error::network("WebSocket not connected"))?;
246
247 let id = self
248 .request_id
249 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
250
251 #[allow(clippy::disallowed_methods)]
252 let request = serde_json::json!({
253 "method": "UNSUBSCRIBE",
254 "params": streams,
255 "id": id
256 });
257
258 client
259 .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
260 request.to_string().into(),
261 ))
262 .await?;
263
264 Ok(())
265 }
266
267 async fn message_loop(ctx: MessageLoopContext) {
269 let mut reconnect_attempt = 0;
270
271 Self::resubscribe_all(&ctx.ws_client, &ctx.subscription_manager, &ctx.request_id).await;
272
273 loop {
274 if !ctx.is_connected.load(std::sync::atomic::Ordering::SeqCst) {
275 break;
276 }
277
278 let has_client = ctx.ws_client.read().await.is_some();
279
280 if !has_client {
281 let config = ctx.reconnect_config.read().await;
282 if config.should_retry(reconnect_attempt) {
283 let delay = config.calculate_delay(reconnect_attempt);
284 drop(config);
285
286 tokio::time::sleep(Duration::from_millis(delay)).await;
287
288 if let Ok(()) = Self::reconnect(
289 &ctx.base_url,
290 &ctx.current_url,
291 ctx.ws_client.clone(),
292 ctx.listen_key_manager.clone(),
293 )
294 .await
295 {
296 Self::resubscribe_all(
297 &ctx.ws_client,
298 &ctx.subscription_manager,
299 &ctx.request_id,
300 )
301 .await;
302 reconnect_attempt = 0;
303 continue;
304 }
305 reconnect_attempt += 1;
306 continue;
307 }
308 ctx.is_connected
309 .store(false, std::sync::atomic::Ordering::SeqCst);
310 break;
311 }
312
313 let message_opt = {
314 let guard = ctx.ws_client.read().await;
315 if let Some(client) = guard.as_ref() {
316 client.receive().await
317 } else {
318 None
319 }
320 };
321
322 if let Some(value) = message_opt {
323 if let Err(_e) = Self::handle_message(
324 value,
325 ctx.subscription_manager.clone(),
326 ctx.listen_key_manager.clone(),
327 )
328 .await
329 {
330 continue;
331 }
332
333 reconnect_attempt = 0;
334 } else {
335 let config = ctx.reconnect_config.read().await;
336 if config.should_retry(reconnect_attempt) {
337 let delay = config.calculate_delay(reconnect_attempt);
338 drop(config);
339
340 tokio::time::sleep(Duration::from_millis(delay)).await;
341
342 if let Ok(()) = Self::reconnect(
343 &ctx.base_url,
344 &ctx.current_url,
345 ctx.ws_client.clone(),
346 ctx.listen_key_manager.clone(),
347 )
348 .await
349 {
350 Self::resubscribe_all(
351 &ctx.ws_client,
352 &ctx.subscription_manager,
353 &ctx.request_id,
354 )
355 .await;
356 reconnect_attempt = 0;
357 continue;
358 }
359 reconnect_attempt += 1;
360 continue;
361 }
362 ctx.is_connected
363 .store(false, std::sync::atomic::Ordering::SeqCst);
364 break;
365 }
366 }
367 }
368
369 async fn resubscribe_all(
371 ws_client: &Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
372 subscription_manager: &Arc<super::subscriptions::SubscriptionManager>,
373 request_id: &Arc<std::sync::atomic::AtomicU64>,
374 ) {
375 let streams = subscription_manager.get_active_streams().await;
376 if streams.is_empty() {
377 return;
378 }
379
380 let client_opt = ws_client.read().await;
381 if let Some(client) = client_opt.as_ref() {
382 for chunk in streams.chunks(10) {
384 let id = request_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
385 #[allow(clippy::disallowed_methods)]
386 let request = serde_json::json!({
387 "method": "SUBSCRIBE",
388 "params": chunk,
389 "id": id
390 });
391
392 if let Err(e) = client
393 .send(tokio_tungstenite::tungstenite::protocol::Message::Text(
394 request.to_string().into(),
395 ))
396 .await
397 {
398 tracing::error!("Failed to resubscribe: {}", e);
399 }
400 }
401 }
402 }
403
404 async fn handle_message(
406 message: Value,
407 subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
408 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
409 ) -> Result<()> {
410 let stream_name = Self::extract_stream_name(&message)?;
411
412 let payload = if message.get("stream").is_some() && message.get("data").is_some() {
414 message.get("data").cloned().unwrap_or(message.clone())
415 } else {
416 message.clone()
417 };
418
419 if stream_name == "!userData" {
421 if let Some(event) = payload.get("e").and_then(|e| e.as_str()) {
422 if event == "listenKeyExpired" {
423 if let Some(manager) = listen_key_manager {
424 tracing::warn!(
425 "Listen key expired, regenerating and triggering reconnect..."
426 );
427 let _ = manager.regenerate().await;
428 return Err(Error::network("Listen key expired, reconnecting"));
431 }
432 }
433 }
434 }
435
436 let sent = subscription_manager
437 .send_to_stream(&stream_name, payload.clone())
438 .await;
439
440 if sent {
441 return Ok(());
442 }
443
444 let symbol_opt = payload.get("s").and_then(|s| s.as_str());
447
448 if let Some(symbol) = symbol_opt {
449 let normalized_symbol = symbol.to_lowercase();
451 let active_streams = subscription_manager
452 .get_subscriptions_by_symbol(&normalized_symbol)
453 .await;
454
455 tracing::debug!(
456 "Routing message for symbol {} (normalized: {}): stream_name={}, active_subscriptions={}",
457 symbol,
458 normalized_symbol,
459 stream_name,
460 active_streams.len()
461 );
462
463 let mut fallback_sent = false;
464
465 for sub in active_streams {
466 tracing::debug!(
470 "Checking subscription: stream={}, expected_starts_with={}",
471 sub.stream,
472 stream_name
473 );
474
475 if sub.stream.starts_with(&stream_name) {
476 if subscription_manager
477 .send_to_stream(&sub.stream, payload.clone())
478 .await
479 {
480 fallback_sent = true;
481 tracing::debug!("Successfully routed to fallback stream: {}", sub.stream);
482 }
483 }
484 }
485
486 if fallback_sent {
487 return Ok(());
488 }
489 }
490
491 Err(Error::generic("No subscribers for stream"))
492 }
493
494 pub fn extract_stream_name(message: &Value) -> Result<String> {
496 if let Some(stream) = message.get("stream").and_then(|s| s.as_str()) {
497 return Ok(stream.to_string());
498 }
499
500 if let Some(arr) = message.as_array() {
502 if let Some(first) = arr.first() {
503 if let Some(event_type) = first.get("e").and_then(|e| e.as_str()) {
504 match event_type {
505 "24hrTicker" => return Ok("!ticker@arr".to_string()),
506 "24hrMiniTicker" => return Ok("!miniTicker@arr".to_string()),
507 _ => {}
508 }
509 }
510 }
511 }
512
513 if let Some(event_type) = message.get("e").and_then(|e| e.as_str()) {
514 match event_type {
515 "outboundAccountPosition"
516 | "balanceUpdate"
517 | "executionReport"
518 | "listStatus"
519 | "ACCOUNT_UPDATE"
520 | "ORDER_TRADE_UPDATE"
521 | "listenKeyExpired" => {
522 return Ok("!userData".to_string());
523 }
524 _ => {}
525 }
526
527 if let Some(symbol) = message.get("s").and_then(|s| s.as_str()) {
528 let stream = match event_type {
529 "24hrTicker" => format!("{}@ticker", symbol.to_lowercase()),
530 "24hrMiniTicker" => format!("{}@miniTicker", symbol.to_lowercase()),
531 "depthUpdate" => format!("{}@depth", symbol.to_lowercase()),
532 "aggTrade" => format!("{}@aggTrade", symbol.to_lowercase()),
533 "trade" => format!("{}@trade", symbol.to_lowercase()),
534 "kline" => {
535 if let Some(kline) = message.get("k") {
536 if let Some(interval) = kline.get("i").and_then(|i| i.as_str()) {
537 format!("{}@kline_{}", symbol.to_lowercase(), interval)
538 } else {
539 return Err(Error::generic("Missing kline interval"));
540 }
541 } else {
542 return Err(Error::generic("Missing kline data"));
543 }
544 }
545 "markPriceUpdate" => format!("{}@markPrice", symbol.to_lowercase()),
546 "bookTicker" => format!("{}@bookTicker", symbol.to_lowercase()),
547 _ => {
548 return Err(Error::generic(format!(
549 "Unknown event type: {}",
550 event_type
551 )));
552 }
553 };
554 return Ok(stream);
555 }
556 }
557
558 if message.get("result").is_some() || message.get("error").is_some() {
559 return Err(Error::generic("Subscription response, skip routing"));
560 }
561
562 Err(Error::generic("Cannot extract stream name from message"))
563 }
564
565 async fn reconnect(
567 base_url: &str,
568 current_url: &str,
569 ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
570 listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
571 ) -> Result<()> {
572 {
573 let mut client_opt = ws_client.write().await;
574 if let Some(client) = client_opt.take() {
575 let _ = client.disconnect().await;
576 }
577 }
578
579 let mut final_url = current_url.to_string();
581
582 if let Some(manager) = listen_key_manager {
583 if current_url != base_url {
586 if let Ok(key) = manager.get_or_create().await {
587 let base = if let Some(stripped) = base_url.strip_suffix('/') {
588 stripped
589 } else {
590 base_url
591 };
592 final_url = format!("{}/{}", base, key);
593 }
594 }
595 }
596
597 let config = ccxt_core::ws_client::WsConfig {
598 url: final_url,
599 ..Default::default()
600 };
601 let new_client = ccxt_core::ws_client::WsClient::new(config);
602
603 new_client.connect().await?;
604
605 *ws_client.write().await = Some(new_client);
606
607 Ok(())
608 }
609}
610
611impl Drop for MessageRouter {
612 fn drop(&mut self) {
613 }
616}
617
618pub async fn handle_orderbook_delta(
620 symbol: &str,
621 delta_message: &Value,
622 is_futures: bool,
623 orderbooks: &Mutex<HashMap<String, OrderBook>>,
624) -> Result<()> {
625 let first_update_id = delta_message["U"]
626 .as_i64()
627 .ok_or_else(|| Error::invalid_request("Missing first update ID in delta message"))?;
628
629 let final_update_id = delta_message["u"]
630 .as_i64()
631 .ok_or_else(|| Error::invalid_request("Missing final update ID in delta message"))?;
632
633 let prev_final_update_id = if is_futures {
634 delta_message["pu"].as_i64()
635 } else {
636 None
637 };
638
639 let timestamp = delta_message["E"]
640 .as_i64()
641 .unwrap_or_else(|| chrono::Utc::now().timestamp_millis());
642
643 let mut bids = Vec::new();
644 if let Some(bids_arr) = delta_message["b"].as_array() {
645 for bid in bids_arr {
646 if let (Some(price_str), Some(amount_str)) = (bid[0].as_str(), bid[1].as_str()) {
647 if let (Ok(price), Ok(amount)) =
648 (price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
649 {
650 bids.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
651 }
652 }
653 }
654 }
655
656 let mut asks = Vec::new();
657 if let Some(asks_arr) = delta_message["a"].as_array() {
658 for ask in asks_arr {
659 if let (Some(price_str), Some(amount_str)) = (ask[0].as_str(), ask[1].as_str()) {
660 if let (Ok(price), Ok(amount)) =
661 (price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
662 {
663 asks.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
664 }
665 }
666 }
667 }
668
669 let delta = OrderBookDelta {
670 symbol: symbol.to_string(),
671 first_update_id,
672 final_update_id,
673 prev_final_update_id,
674 timestamp,
675 bids,
676 asks,
677 };
678
679 let mut orderbooks_map = orderbooks.lock().await;
680 let orderbook = orderbooks_map
681 .entry(symbol.to_string())
682 .or_insert_with(|| OrderBook::new(symbol.to_string(), timestamp));
683
684 if !orderbook.is_synced {
685 orderbook.buffer_delta(delta);
686 return Ok(());
687 }
688
689 if let Err(e) = orderbook.apply_delta(&delta, is_futures) {
690 if orderbook.needs_resync {
691 tracing::warn!("Orderbook {} needs resync due to: {}", symbol, e);
692 orderbook.buffer_delta(delta);
693 return Err(Error::invalid_request(format!("RESYNC_NEEDED: {}", e)));
694 }
695 return Err(Error::invalid_request(e));
696 }
697
698 Ok(())
699}
700
701pub async fn fetch_orderbook_snapshot(
703 exchange: &Binance,
704 symbol: &str,
705 limit: Option<i64>,
706 is_futures: bool,
707 orderbooks: &Mutex<HashMap<String, OrderBook>>,
708) -> Result<OrderBook> {
709 let snapshot = exchange
710 .fetch_order_book(symbol, limit.map(|l| l as u32))
711 .await?;
712
713 let mut orderbooks_map = orderbooks.lock().await;
714 let cached_ob = orderbooks_map
715 .entry(symbol.to_string())
716 .or_insert_with(|| OrderBook::new(symbol.to_string(), snapshot.timestamp));
717
718 cached_ob.reset_from_snapshot(
720 snapshot.bids,
721 snapshot.asks,
722 snapshot.timestamp,
723 snapshot.nonce,
724 );
725
726 if let Ok(processed) = cached_ob.process_buffered_deltas(is_futures) {
728 tracing::debug!("Processed {} buffered deltas for {}", processed, symbol);
729 }
730
731 Ok(cached_ob.clone())
732}
733
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;
    use serde_json::json;
    use std::sync::Arc;

    #[test]
    fn test_extract_stream_name_combined() {
        let message = json!({
            "stream": "btcusdt@ticker",
            "data": {
                "e": "24hrTicker",
                "s": "BTCUSDT"
            }
        });

        let stream = MessageRouter::extract_stream_name(&message).unwrap();
        assert_eq!(stream, "btcusdt@ticker");
    }

    #[test]
    fn test_extract_stream_name_raw() {
        let message = json!({
            "e": "24hrTicker",
            "s": "BTCUSDT"
        });

        let stream = MessageRouter::extract_stream_name(&message).unwrap();
        assert_eq!(stream, "btcusdt@ticker");
    }

    #[test]
    fn test_extract_stream_name_user_data_events() {
        for event in [
            "outboundAccountPosition",
            "balanceUpdate",
            "executionReport",
            "listStatus",
            "ACCOUNT_UPDATE",
            "ORDER_TRADE_UPDATE",
            "listenKeyExpired",
        ] {
            let message = json!({ "e": event, "E": 1 });
            assert_eq!(
                MessageRouter::extract_stream_name(&message).unwrap(),
                "!userData",
                "event {}",
                event
            );
        }
    }

    #[test]
    fn test_extract_stream_name_all_market_arrays() {
        let tickers = json!([{ "e": "24hrTicker", "s": "BTCUSDT" }]);
        assert_eq!(
            MessageRouter::extract_stream_name(&tickers).unwrap(),
            "!ticker@arr"
        );

        let minis = json!([{ "e": "24hrMiniTicker", "s": "BTCUSDT" }]);
        assert_eq!(
            MessageRouter::extract_stream_name(&minis).unwrap(),
            "!miniTicker@arr"
        );
    }

    #[test]
    fn test_extract_stream_name_symbol_events() {
        let depth = json!({ "e": "depthUpdate", "s": "BTCUSDT" });
        assert_eq!(
            MessageRouter::extract_stream_name(&depth).unwrap(),
            "btcusdt@depth"
        );

        let mark = json!({ "e": "markPriceUpdate", "s": "BTCUSDT" });
        assert_eq!(
            MessageRouter::extract_stream_name(&mark).unwrap(),
            "btcusdt@markPrice"
        );
    }

    #[test]
    fn test_extract_stream_name_kline() {
        let message = json!({ "e": "kline", "s": "ETHUSDT", "k": { "i": "15m" } });
        assert_eq!(
            MessageRouter::extract_stream_name(&message).unwrap(),
            "ethusdt@kline_15m"
        );

        let no_interval = json!({ "e": "kline", "s": "ETHUSDT", "k": {} });
        assert!(MessageRouter::extract_stream_name(&no_interval).is_err());

        let no_data = json!({ "e": "kline", "s": "ETHUSDT" });
        assert!(MessageRouter::extract_stream_name(&no_data).is_err());
    }

    #[test]
    fn test_extract_stream_name_rejects_unroutable_messages() {
        assert!(MessageRouter::extract_stream_name(&json!({ "result": null, "id": 1 })).is_err());
        assert!(
            MessageRouter::extract_stream_name(&json!({ "e": "somethingNew", "s": "BTCUSDT" }))
                .is_err()
        );
        assert!(MessageRouter::extract_stream_name(&json!({ "foo": "bar" })).is_err());
        assert!(MessageRouter::extract_stream_name(&json!([])).is_err());
    }

    #[tokio::test]
    async fn test_handle_message_unwrapping() {
        let manager = Arc::new(crate::binance::ws::subscriptions::SubscriptionManager::new());
        let (tx, mut rx) = tokio::sync::mpsc::channel(100);

        manager
            .add_subscription(
                "btcusdt@ticker".to_string(),
                "BTCUSDT".to_string(),
                crate::binance::ws::subscriptions::SubscriptionType::Ticker,
                tx,
            )
            .await
            .unwrap();

        let message = json!({
            "stream": "btcusdt@ticker",
            "data": {
                "e": "24hrTicker",
                "s": "BTCUSDT",
                "c": "50000.00"
            }
        });

        MessageRouter::handle_message(message, manager, None)
            .await
            .unwrap();

        let received = rx.recv().await.unwrap();
        assert!(received.get("stream").is_none());
        assert_eq!(received["e"], "24hrTicker");
        assert_eq!(received["c"], "50000.00");
    }

    #[tokio::test]
    async fn test_handle_message_mark_price_fallback() {
        let manager = Arc::new(crate::binance::ws::subscriptions::SubscriptionManager::new());
        let (tx, mut rx) = tokio::sync::mpsc::channel(100);

        // Subscribed under the "@1s" variant; raw messages only identify the
        // base `markPrice` stream, so delivery relies on the prefix fallback.
        manager
            .add_subscription(
                "btcusdt@markPrice@1s".to_string(),
                "btcusdt".to_string(),
                crate::binance::ws::subscriptions::SubscriptionType::MarkPrice,
                tx,
            )
            .await
            .unwrap();

        let message = json!({
            "e": "markPriceUpdate",
            "s": "BTCUSDT",
            "p": "50000.00",
            "E": 123456789
        });

        MessageRouter::handle_message(message, manager, None)
            .await
            .unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received["e"], "markPriceUpdate");
        assert_eq!(received["p"], "50000.00");
    }

    #[tokio::test]
    async fn test_handle_message_without_subscribers_errors() {
        let manager = Arc::new(crate::binance::ws::subscriptions::SubscriptionManager::new());

        let message = json!({
            "stream": "ethusdt@trade",
            "data": { "e": "trade", "s": "ETHUSDT" }
        });

        assert!(MessageRouter::handle_message(message, manager, None)
            .await
            .is_err());
    }
}