1use futures::{Future, SinkExt, Stream, StreamExt};
40use std::collections::BTreeMap;
41use std::pin::Pin;
42use std::sync::Arc;
43use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
44use std::task::{Context, Poll};
45use std::time::{Duration, Instant};
46use tokio::net::TcpStream;
47use tokio::sync::{Mutex, RwLock, mpsc};
48use tokio::time::{interval, sleep, timeout};
49use tokio_tungstenite::{
50 MaybeTlsStream, WebSocketStream as TungsteniteStream, connect_async,
51 tungstenite::{Bytes, Message},
52};
53
54use crate::config::Config;
55use crate::models::OrderBook;
56use crate::models::websocket::{DepthEvent, WebSocketEvent};
57use crate::types::KlineInterval;
58use crate::{Error, Result};
59
/// Default cap on reconnection attempts; used by `ReconnectConfig::default`.
const MAX_RECONNECTS: u32 = 5;

/// Default upper bound, in seconds, for the reconnect backoff delay.
const MAX_RECONNECT_DELAY_SECS: u64 = 60;

/// Default base delay, in milliseconds, for the reconnect backoff.
const BASE_RECONNECT_DELAY_MS: u64 = 100;

/// How long, in seconds, a read waits for the next WebSocket message before it is
/// treated as timed out.
const WS_TIMEOUT_SECS: u64 = 30;

/// Default interval, in seconds, between health checks; used by `ReconnectConfig::default`.
const HEALTH_CHECK_INTERVAL_SECS: u64 = 30;
76
/// Period, in seconds (30 minutes), between user-data-stream keepalive requests.
const USER_STREAM_KEEPALIVE_SECS: u64 = 30 * 60;

/// Entry point for opening WebSocket connections and building stream names.
///
/// All URLs are derived from `config.ws_endpoint`.
#[derive(Clone)]
pub struct WebSocketClient {
    /// Client configuration; only `ws_endpoint` is read by this type.
    config: Config,
}
88
89impl WebSocketClient {
90 pub(crate) fn new(config: Config) -> Self {
92 Self { config }
93 }
94
95 pub fn endpoint(&self) -> &str {
97 &self.config.ws_endpoint
98 }
99
100 pub async fn connect(&self, stream: &str) -> Result<WebSocketConnection> {
114 let url = format!("{}/ws/{}", self.config.ws_endpoint, stream);
115 self.connect_url(&url).await
116 }
117
118 pub async fn connect_combined(&self, streams: &[String]) -> Result<WebSocketConnection> {
135 let streams_param = streams.join("/");
136 let url = format!(
137 "{}/stream?streams={}",
138 self.config.ws_endpoint, streams_param
139 );
140 self.connect_url(&url).await
141 }
142
143 pub async fn connect_user_stream(&self, listen_key: &str) -> Result<WebSocketConnection> {
156 let url = format!("{}/ws/{}", self.config.ws_endpoint, listen_key);
157 self.connect_url(&url).await
158 }
159
160 pub async fn connect_with_reconnect(&self, stream: &str) -> Result<ReconnectingWebSocket> {
182 let url = format!("{}/ws/{}", self.config.ws_endpoint, stream);
183 ReconnectingWebSocket::new(url, ReconnectConfig::default()).await
184 }
185
186 pub async fn connect_combined_with_reconnect(
188 &self,
189 streams: &[String],
190 ) -> Result<ReconnectingWebSocket> {
191 let streams_param = streams.join("/");
192 let url = format!(
193 "{}/stream?streams={}",
194 self.config.ws_endpoint, streams_param
195 );
196 ReconnectingWebSocket::new(url, ReconnectConfig::default()).await
197 }
198
199 async fn connect_url(&self, url: &str) -> Result<WebSocketConnection> {
200 let (ws_stream, _) = connect_async(url).await.map_err(Error::WebSocket)?;
201 Ok(WebSocketConnection::new(ws_stream))
202 }
203
204 pub fn agg_trade_stream(&self, symbol: &str) -> String {
210 format!("{}@aggTrade", symbol.to_lowercase())
211 }
212
213 pub fn trade_stream(&self, symbol: &str) -> String {
217 format!("{}@trade", symbol.to_lowercase())
218 }
219
220 pub fn kline_stream(&self, symbol: &str, interval: KlineInterval) -> String {
224 format!("{}@kline_{}", symbol.to_lowercase(), interval)
225 }
226
227 pub fn mini_ticker_stream(&self, symbol: &str) -> String {
231 format!("{}@miniTicker", symbol.to_lowercase())
232 }
233
234 pub fn all_mini_ticker_stream(&self) -> String {
238 "!miniTicker@arr".to_string()
239 }
240
241 pub fn ticker_stream(&self, symbol: &str) -> String {
245 format!("{}@ticker", symbol.to_lowercase())
246 }
247
248 pub fn all_ticker_stream(&self) -> String {
252 "!ticker@arr".to_string()
253 }
254
255 pub fn book_ticker_stream(&self, symbol: &str) -> String {
259 format!("{}@bookTicker", symbol.to_lowercase())
260 }
261
262 pub fn all_book_ticker_stream(&self) -> String {
266 "!bookTicker".to_string()
267 }
268
269 pub fn partial_depth_stream(&self, symbol: &str, levels: u8, fast: bool) -> String {
279 let base = format!("{}@depth{}", symbol.to_lowercase(), levels);
280 if fast {
281 format!("{}@100ms", base)
282 } else {
283 base
284 }
285 }
286
287 pub fn diff_depth_stream(&self, symbol: &str, fast: bool) -> String {
296 let base = format!("{}@depth", symbol.to_lowercase());
297 if fast {
298 format!("{}@100ms", base)
299 } else {
300 base
301 }
302 }
303}
304
/// A single established WebSocket connection that yields decoded events.
pub struct WebSocketConnection {
    /// Underlying tungstenite stream (plain or TLS over TCP).
    inner: TungsteniteStream<MaybeTlsStream<TcpStream>>,
    /// When the connection was created or, later, when the most recent Ping frame
    /// was received from the peer.
    last_ping: Instant,
}
314
315impl WebSocketConnection {
316 fn new(stream: TungsteniteStream<MaybeTlsStream<TcpStream>>) -> Self {
317 Self {
318 inner: stream,
319 last_ping: Instant::now(),
320 }
321 }
322
323 pub async fn next(&mut self) -> Option<Result<WebSocketEvent>> {
327 loop {
328 match self.inner.next().await? {
329 Ok(Message::Text(text)) => {
330 if let Ok(combined) = serde_json::from_str::<CombinedStreamMessage>(&text) {
332 return Some(Ok(combined.data));
333 }
334 return Some(serde_json::from_str(&text).map_err(Error::Serialization));
336 }
337 Ok(Message::Binary(data)) => {
338 if let Ok(combined) = serde_json::from_slice::<CombinedStreamMessage>(&data) {
339 return Some(Ok(combined.data));
340 }
341 return Some(serde_json::from_slice(&data).map_err(Error::Serialization));
342 }
343 Ok(Message::Ping(data)) => {
344 self.last_ping = Instant::now();
345 if let Err(e) = self.inner.send(Message::Pong(data)).await {
347 return Some(Err(Error::WebSocket(e)));
348 }
349 }
350 Ok(Message::Pong(_)) => {
351 continue;
353 }
354 Ok(Message::Close(_)) => {
355 return None;
356 }
357 Ok(Message::Frame(_)) => {
358 continue;
360 }
361 Err(e) => {
362 return Some(Err(Error::WebSocket(e)));
363 }
364 }
365 }
366 }
367
368 pub(crate) async fn next_raw(&mut self) -> Option<Result<serde_json::Value>> {
370 loop {
371 match self.inner.next().await? {
372 Ok(Message::Text(text)) => {
373 return Some(serde_json::from_str(&text).map_err(Error::Serialization));
374 }
375 Ok(Message::Binary(data)) => {
376 return Some(serde_json::from_slice(&data).map_err(Error::Serialization));
377 }
378 Ok(Message::Ping(data)) => {
379 self.last_ping = Instant::now();
380 if let Err(e) = self.inner.send(Message::Pong(data)).await {
381 return Some(Err(Error::WebSocket(e)));
382 }
383 }
384 Ok(Message::Pong(_)) | Ok(Message::Frame(_)) => continue,
385 Ok(Message::Close(_)) => return None,
386 Err(e) => return Some(Err(Error::WebSocket(e))),
387 }
388 }
389 }
390
391 pub async fn ping(&mut self) -> Result<()> {
393 self.inner
394 .send(Message::Ping(Bytes::new()))
395 .await
396 .map_err(Error::WebSocket)
397 }
398
399 pub async fn close(&mut self) -> Result<()> {
401 self.inner.close(None).await.map_err(Error::WebSocket)
402 }
403
404 pub fn time_since_last_ping(&self) -> Duration {
406 self.last_ping.elapsed()
407 }
408
409 pub fn into_stream(self) -> WebSocketEventStream {
411 WebSocketEventStream { inner: self }
412 }
413}
414
415pub struct WebSocketEventStream {
417 inner: WebSocketConnection,
418}
419
420impl Stream for WebSocketEventStream {
421 type Item = Result<WebSocketEvent>;
422
423 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
424 let future = self.inner.next();
425 tokio::pin!(future);
426 future.poll(cx)
427 }
428}
429
430#[derive(Debug, Clone)]
434pub struct ReconnectConfig {
435 pub max_reconnects: u32,
437 pub max_reconnect_delay: Duration,
439 pub base_delay: Duration,
441 pub health_check_enabled: bool,
443 pub health_check_interval: Duration,
445}
446
447impl Default for ReconnectConfig {
448 fn default() -> Self {
449 Self {
450 max_reconnects: MAX_RECONNECTS,
451 max_reconnect_delay: Duration::from_secs(MAX_RECONNECT_DELAY_SECS),
452 base_delay: Duration::from_millis(BASE_RECONNECT_DELAY_MS),
453 health_check_enabled: true,
454 health_check_interval: Duration::from_secs(HEALTH_CHECK_INTERVAL_SECS),
455 }
456 }
457}
458
/// Lifecycle state reported by [`ReconnectingWebSocket::state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Initial state while the first connection is being established.
    Connecting,
    /// A connection is installed and being read.
    Connected,
    /// A reconnection attempt is in progress.
    Reconnecting,
    /// The socket was closed, the reconnect limit was exceeded, or the read loop exited.
    Closed,
}
471
/// WebSocket wrapper whose background task reads events and reconnects on failure.
pub struct ReconnectingWebSocket {
    /// Current connection, shared with the background read loop; `None` after `close()`.
    connection: Arc<Mutex<Option<WebSocketConnection>>>,
    /// Lifecycle state, shared with the background read loop.
    state: Arc<RwLock<ConnectionState>>,
    /// Reconnection attempts since the last successful (re)connect.
    reconnect_count: Arc<AtomicU64>,
    /// Set by `close()` or when the reconnect limit is exceeded; stops the read loop.
    is_closed: Arc<AtomicBool>,
    /// Receiving end of the channel fed by the background read loop.
    event_rx: mpsc::Receiver<Result<WebSocketEvent>>,
}
483
484impl ReconnectingWebSocket {
485 pub async fn new(url: String, config: ReconnectConfig) -> Result<Self> {
487 let (event_tx, event_rx) = mpsc::channel(1000);
488 let connection = Arc::new(Mutex::new(None));
489 let state = Arc::new(RwLock::new(ConnectionState::Connecting));
490 let reconnect_count = Arc::new(AtomicU64::new(0));
491 let is_closed = Arc::new(AtomicBool::new(false));
492
493 let (ws_stream, _) = connect_async(&url).await.map_err(Error::WebSocket)?;
495 {
496 let mut conn = connection.lock().await;
497 *conn = Some(WebSocketConnection::new(ws_stream));
498 }
499 *state.write().await = ConnectionState::Connected;
500
501 let ws = Self {
502 connection: connection.clone(),
503 state: state.clone(),
504 reconnect_count: reconnect_count.clone(),
505 is_closed: is_closed.clone(),
506 event_rx,
507 };
508
509 tokio::spawn(async move {
511 Self::read_loop(
512 url,
513 config,
514 connection,
515 state,
516 reconnect_count,
517 is_closed,
518 event_tx,
519 )
520 .await;
521 });
522
523 Ok(ws)
524 }
525
526 async fn read_loop(
527 url: String,
528 config: ReconnectConfig,
529 connection: Arc<Mutex<Option<WebSocketConnection>>>,
530 state: Arc<RwLock<ConnectionState>>,
531 reconnect_count: Arc<AtomicU64>,
532 is_closed: Arc<AtomicBool>,
533 event_tx: mpsc::Sender<Result<WebSocketEvent>>,
534 ) {
535 loop {
536 if is_closed.load(Ordering::SeqCst) {
537 break;
538 }
539
540 let event = {
542 let mut conn_guard = connection.lock().await;
543 if let Some(ref mut conn) = *conn_guard {
544 match timeout(Duration::from_secs(WS_TIMEOUT_SECS), conn.next()).await {
545 Ok(Some(event)) => Some(event),
546 Ok(None) => None, Err(_) => {
548 None
550 }
551 }
552 } else {
553 None
554 }
555 };
556
557 match event {
558 Some(Ok(ev)) => {
559 if event_tx.send(Ok(ev)).await.is_err() {
560 break;
562 }
563 }
564 Some(Err(e)) => {
565 let _ = event_tx.send(Err(e)).await;
567 Self::attempt_reconnect(
568 &url,
569 &config,
570 &connection,
571 &state,
572 &reconnect_count,
573 &is_closed,
574 )
575 .await;
576 }
577 None => {
578 Self::attempt_reconnect(
580 &url,
581 &config,
582 &connection,
583 &state,
584 &reconnect_count,
585 &is_closed,
586 )
587 .await;
588 }
589 }
590 }
591
592 *state.write().await = ConnectionState::Closed;
593 }
594
595 async fn attempt_reconnect(
596 url: &str,
597 config: &ReconnectConfig,
598 connection: &Arc<Mutex<Option<WebSocketConnection>>>,
599 state: &Arc<RwLock<ConnectionState>>,
600 reconnect_count: &Arc<AtomicU64>,
601 is_closed: &Arc<AtomicBool>,
602 ) {
603 if is_closed.load(Ordering::SeqCst) {
604 return;
605 }
606
607 *state.write().await = ConnectionState::Reconnecting;
608
609 let count = reconnect_count.fetch_add(1, Ordering::SeqCst) + 1;
610
611 if count > config.max_reconnects as u64 {
612 is_closed.store(true, Ordering::SeqCst);
613 *state.write().await = ConnectionState::Closed;
614 return;
615 }
616
617 let delay = Self::calculate_backoff_delay(count, config);
619 sleep(delay).await;
620
621 match connect_async(url).await {
623 Ok((ws_stream, _)) => {
624 let mut conn = connection.lock().await;
625 *conn = Some(WebSocketConnection::new(ws_stream));
626 *state.write().await = ConnectionState::Connected;
627 reconnect_count.store(0, Ordering::SeqCst);
628 }
629 Err(_) => {
630 }
632 }
633 }
634
635 fn calculate_backoff_delay(attempt: u64, config: &ReconnectConfig) -> Duration {
636 let base_ms = config.base_delay.as_millis() as u64;
637 let exp_delay = base_ms.saturating_mul(2u64.saturating_pow(attempt as u32));
638 let max_delay_ms = config.max_reconnect_delay.as_millis() as u64;
639 let delay_ms = exp_delay.min(max_delay_ms);
640
641 let jitter = (delay_ms as f64 * 0.25 * (rand_simple() * 2.0 - 1.0)) as i64;
643 let final_delay = (delay_ms as i64 + jitter).max(0) as u64;
644
645 Duration::from_millis(final_delay)
646 }
647
648 pub async fn next(&mut self) -> Option<Result<WebSocketEvent>> {
650 self.event_rx.recv().await
651 }
652
653 pub async fn state(&self) -> ConnectionState {
655 *self.state.read().await
656 }
657
658 pub fn reconnect_count(&self) -> u64 {
660 self.reconnect_count.load(Ordering::SeqCst)
661 }
662
663 pub fn is_closed(&self) -> bool {
665 self.is_closed.load(Ordering::SeqCst)
666 }
667
668 pub async fn close(&self) {
670 self.is_closed.store(true, Ordering::SeqCst);
671 let mut conn = self.connection.lock().await;
672 if let Some(ref mut c) = *conn {
673 let _ = c.close().await;
674 }
675 *conn = None;
676 *self.state.write().await = ConnectionState::Closed;
677 }
678}
679
/// Maps a sub-second nanosecond count (`0..1_000_000_000`) onto `[0.0, 1.0)`.
fn unit_fraction_from_nanos(nanos: u32) -> f64 {
    // `Duration::subsec_nanos` is always below one billion, so dividing by the
    // number of nanoseconds per second (not by `u32::MAX`) covers the full unit
    // interval.
    const NANOS_PER_SEC: f64 = 1_000_000_000.0;
    f64::from(nanos) / NANOS_PER_SEC
}

/// Returns a cheap pseudo-random value in `[0.0, 1.0)` taken from the sub-second
/// part of the system clock.
///
/// This is not a real random number generator; it is only meant to de-correlate
/// reconnect delays. If the system clock is before the Unix epoch the result is 0.0.
fn rand_simple() -> f64 {
    use std::time::SystemTime;
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    unit_fraction_from_nanos(nanos)
}
689
/// In-memory order book for one symbol, built from a snapshot plus depth updates.
#[derive(Debug, Clone)]
pub struct DepthCache {
    /// Symbol this cache belongs to, stored as passed to `DepthCache::new`.
    pub symbol: String,
    /// Bid levels keyed by price (ascending); the best bid is the last entry.
    bids: BTreeMap<OrderedFloat, f64>,
    /// Ask levels keyed by price (ascending); the best ask is the first entry.
    asks: BTreeMap<OrderedFloat, f64>,
    /// Update id of the snapshot or of the last applied depth event; 0 until initialised.
    pub last_update_id: u64,
    /// `event_time` of the last applied depth event, if any.
    pub update_time: Option<u64>,
}
709
/// `f64` wrapper with a total order, usable as a `BTreeMap` key.
///
/// Ordinary values compare exactly like `f64` (so `-0.0` and `0.0` are equal).
/// NaN, for which `f64::partial_cmp` has no answer, is ordered with
/// `f64::total_cmp`: it is never equal to a non-NaN value, which keeps the ordering
/// transitive and therefore safe for ordered collections. Equality is defined
/// through the same comparison so that `Eq` and `Ord` always agree.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f64);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `partial_cmp` only fails when a NaN is involved; fall back to the IEEE
        // total order there instead of pretending the values are equal.
        self.0
            .partial_cmp(&other.0)
            .unwrap_or_else(|| self.0.total_cmp(&other.0))
    }
}
729
730impl DepthCache {
731 pub fn new(symbol: &str) -> Self {
733 Self {
734 symbol: symbol.to_string(),
735 bids: BTreeMap::new(),
736 asks: BTreeMap::new(),
737 last_update_id: 0,
738 update_time: None,
739 }
740 }
741
742 pub fn initialize_from_snapshot(&mut self, order_book: &OrderBook) {
744 self.bids.clear();
745 self.asks.clear();
746
747 for bid in &order_book.bids {
748 if bid.quantity > 0.0 {
749 self.bids.insert(OrderedFloat(bid.price), bid.quantity);
750 }
751 }
752
753 for ask in &order_book.asks {
754 if ask.quantity > 0.0 {
755 self.asks.insert(OrderedFloat(ask.price), ask.quantity);
756 }
757 }
758
759 self.last_update_id = order_book.last_update_id;
760 }
761
762 pub fn apply_update(&mut self, event: &DepthEvent) -> bool {
767 if event.final_update_id <= self.last_update_id {
769 return false;
770 }
771
772 if event.first_update_id > self.last_update_id + 1 {
774 return false;
775 }
776
777 for bid in &event.bids {
779 if bid.quantity == 0.0 {
780 self.bids.remove(&OrderedFloat(bid.price));
781 } else {
782 self.bids.insert(OrderedFloat(bid.price), bid.quantity);
783 }
784 }
785
786 for ask in &event.asks {
788 if ask.quantity == 0.0 {
789 self.asks.remove(&OrderedFloat(ask.price));
790 } else {
791 self.asks.insert(OrderedFloat(ask.price), ask.quantity);
792 }
793 }
794
795 self.last_update_id = event.final_update_id;
796 self.update_time = Some(event.event_time);
797
798 true
799 }
800
801 pub fn best_bid(&self) -> Option<(f64, f64)> {
803 self.bids.iter().next_back().map(|(p, q)| (p.0, *q))
804 }
805
806 pub fn best_ask(&self) -> Option<(f64, f64)> {
808 self.asks.iter().next().map(|(p, q)| (p.0, *q))
809 }
810
811 pub fn spread(&self) -> Option<f64> {
813 match (self.best_bid(), self.best_ask()) {
814 (Some((bid, _)), Some((ask, _))) => Some(ask - bid),
815 _ => None,
816 }
817 }
818
819 pub fn mid_price(&self) -> Option<f64> {
821 match (self.best_bid(), self.best_ask()) {
822 (Some((bid, _)), Some((ask, _))) => Some((bid + ask) / 2.0),
823 _ => None,
824 }
825 }
826
827 pub fn get_bids(&self) -> Vec<(f64, f64)> {
829 self.bids.iter().rev().map(|(p, q)| (p.0, *q)).collect()
830 }
831
832 pub fn get_asks(&self) -> Vec<(f64, f64)> {
834 self.asks.iter().map(|(p, q)| (p.0, *q)).collect()
835 }
836
837 pub fn get_top_bids(&self, n: usize) -> Vec<(f64, f64)> {
839 self.bids
840 .iter()
841 .rev()
842 .take(n)
843 .map(|(p, q)| (p.0, *q))
844 .collect()
845 }
846
847 pub fn get_top_asks(&self, n: usize) -> Vec<(f64, f64)> {
849 self.asks.iter().take(n).map(|(p, q)| (p.0, *q)).collect()
850 }
851
852 pub fn total_bid_volume(&self) -> f64 {
854 self.bids.values().sum()
855 }
856
857 pub fn total_ask_volume(&self) -> f64 {
859 self.asks.values().sum()
860 }
861}
862
/// Options for [`DepthCacheManager`].
#[derive(Debug, Clone)]
pub struct DepthCacheConfig {
    /// Number of levels requested when fetching the order-book snapshot.
    pub depth_limit: u32,
    /// Selects the `@100ms` variant of the diff-depth stream when `true`.
    pub fast_updates: bool,
    /// When set, the snapshot is re-fetched once this much time has passed since the
    /// previous refresh; `None` disables periodic refreshes.
    pub refresh_interval: Option<Duration>,
}

impl Default for DepthCacheConfig {
    /// Defaults: 1000 levels, regular-speed updates, no periodic refresh.
    fn default() -> Self {
        let depth_limit = 1000;
        let fast_updates = false;
        let refresh_interval = None;

        DepthCacheConfig {
            depth_limit,
            fast_updates,
            refresh_interval,
        }
    }
}
885
/// Synchronisation state reported by [`DepthCacheManager::state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DepthCacheState {
    /// Connecting, buffering events and loading the snapshot.
    Initializing,
    /// Snapshot loaded and updates are being applied.
    Synced,
    /// An update could not be applied or the stream failed; a resync follows.
    OutOfSync,
    /// The background sync task has exited.
    Stopped,
}
898
/// Keeps a [`DepthCache`] for one symbol in sync via a background task.
pub struct DepthCacheManager {
    /// Upper-cased symbol the manager was created for.
    symbol: String,
    /// Order book shared with the background sync task.
    cache: Arc<RwLock<DepthCache>>,
    /// Synchronisation state shared with the background sync task.
    state: Arc<RwLock<DepthCacheState>>,
    /// Set by `stop()`; checked by the background sync task.
    is_stopped: Arc<AtomicBool>,
    /// Receives a copy of the cache after the initial sync and after each applied update.
    cache_rx: mpsc::Receiver<DepthCache>,
}
940
941impl DepthCacheManager {
942 pub async fn new(
946 client: crate::Binance,
947 symbol: &str,
948 config: DepthCacheConfig,
949 ) -> Result<Self> {
950 let symbol = symbol.to_uppercase();
951 let cache = Arc::new(RwLock::new(DepthCache::new(&symbol)));
952 let state = Arc::new(RwLock::new(DepthCacheState::Initializing));
953 let is_stopped = Arc::new(AtomicBool::new(false));
954 let (cache_tx, cache_rx) = mpsc::channel(100);
955
956 let symbol_clone = symbol.clone();
958 let cache_clone = cache.clone();
959 let state_clone = state.clone();
960 let is_stopped_clone = is_stopped.clone();
961
962 tokio::spawn(async move {
964 Self::sync_loop(
965 client,
966 symbol_clone,
967 config,
968 cache_clone,
969 state_clone,
970 is_stopped_clone,
971 cache_tx,
972 )
973 .await;
974 });
975
976 Ok(Self {
977 symbol,
978 cache,
979 state,
980 is_stopped,
981 cache_rx,
982 })
983 }
984
985 async fn sync_loop(
986 client: crate::Binance,
987 symbol: String,
988 config: DepthCacheConfig,
989 cache: Arc<RwLock<DepthCache>>,
990 state: Arc<RwLock<DepthCacheState>>,
991 is_stopped: Arc<AtomicBool>,
992 cache_tx: mpsc::Sender<DepthCache>,
993 ) {
994 let ws = client.websocket();
995 let stream = ws.diff_depth_stream(&symbol, config.fast_updates);
996
997 loop {
998 if is_stopped.load(Ordering::SeqCst) {
999 break;
1000 }
1001
1002 *state.write().await = DepthCacheState::Initializing;
1004
1005 let mut conn = match ws.connect(&stream).await {
1007 Ok(c) => c,
1008 Err(_) => {
1009 sleep(Duration::from_secs(1)).await;
1010 continue;
1011 }
1012 };
1013
1014 let mut initial_events = Vec::new();
1016 let buffer_timeout = Duration::from_secs(2);
1017 let start = Instant::now();
1018
1019 while start.elapsed() < buffer_timeout {
1020 match timeout(Duration::from_millis(500), conn.next_raw()).await {
1021 Ok(Some(Ok(raw))) => {
1022 if let Ok(event) = serde_json::from_value::<DepthEvent>(raw) {
1023 initial_events.push(event);
1024 }
1025 }
1026 _ => break,
1027 }
1028 }
1029
1030 let snapshot = match client
1032 .market()
1033 .depth(&symbol, Some(config.depth_limit as u16))
1034 .await
1035 {
1036 Ok(s) => s,
1037 Err(_) => {
1038 sleep(Duration::from_secs(1)).await;
1039 continue;
1040 }
1041 };
1042
1043 {
1045 let mut cache_guard = cache.write().await;
1046 cache_guard.initialize_from_snapshot(&snapshot);
1047
1048 for event in &initial_events {
1050 cache_guard.apply_update(event);
1051 }
1052 }
1053
1054 *state.write().await = DepthCacheState::Synced;
1055
1056 {
1058 let cache_guard = cache.read().await;
1059 let _ = cache_tx.send(cache_guard.clone()).await;
1060 }
1061
1062 let mut last_refresh = Instant::now();
1064 loop {
1065 if is_stopped.load(Ordering::SeqCst) {
1066 break;
1067 }
1068
1069 if let Some(refresh_interval) = config.refresh_interval {
1071 if last_refresh.elapsed() >= refresh_interval {
1072 if let Ok(snapshot) = client
1074 .market()
1075 .depth(&symbol, Some(config.depth_limit as u16))
1076 .await
1077 {
1078 let mut cache_guard = cache.write().await;
1079 cache_guard.initialize_from_snapshot(&snapshot);
1080 }
1081 last_refresh = Instant::now();
1082 }
1083 }
1084
1085 match timeout(Duration::from_secs(WS_TIMEOUT_SECS), conn.next_raw()).await {
1086 Ok(Some(Ok(raw))) => {
1087 if let Ok(event) = serde_json::from_value::<DepthEvent>(raw) {
1088 let mut cache_guard = cache.write().await;
1089 if cache_guard.apply_update(&event) {
1090 let _ = cache_tx.send(cache_guard.clone()).await;
1092 } else {
1093 drop(cache_guard);
1095 *state.write().await = DepthCacheState::OutOfSync;
1096 break;
1097 }
1098 }
1099 }
1100 Ok(Some(Err(_))) | Ok(None) | Err(_) => {
1101 *state.write().await = DepthCacheState::OutOfSync;
1103 break;
1104 }
1105 }
1106 }
1107
1108 sleep(Duration::from_millis(100)).await;
1110 }
1111
1112 *state.write().await = DepthCacheState::Stopped;
1113 }
1114
1115 pub async fn wait_for_sync(&self) -> Result<()> {
1117 let timeout_duration = Duration::from_secs(30);
1118 let start = Instant::now();
1119
1120 loop {
1121 let state = *self.state.read().await;
1122 match state {
1123 DepthCacheState::Synced => return Ok(()),
1124 DepthCacheState::Stopped => {
1125 return Err(Error::InvalidCredentials(
1126 "Depth cache manager stopped".to_string(),
1127 ));
1128 }
1129 _ => {
1130 if start.elapsed() > timeout_duration {
1131 return Err(Error::InvalidCredentials(
1132 "Timeout waiting for depth cache sync".to_string(),
1133 ));
1134 }
1135 sleep(Duration::from_millis(100)).await;
1136 }
1137 }
1138 }
1139 }
1140
1141 pub async fn get_cache(&self) -> DepthCache {
1143 self.cache.read().await.clone()
1144 }
1145
1146 pub async fn state(&self) -> DepthCacheState {
1148 *self.state.read().await
1149 }
1150
1151 pub async fn next(&mut self) -> Option<DepthCache> {
1153 self.cache_rx.recv().await
1154 }
1155
1156 pub fn stop(&self) {
1158 self.is_stopped.store(true, Ordering::SeqCst);
1159 }
1160
1161 pub fn symbol(&self) -> &str {
1163 &self.symbol
1164 }
1165}
1166
/// Manages a user-data stream: owns the listen key, keeps it alive and forwards events.
pub struct UserDataStreamManager {
    /// Current listen key; replaced by the keepalive task when a keepalive fails and
    /// a new key can be obtained.
    listen_key: Arc<RwLock<String>>,
    /// Set by `stop()`; checked by both background tasks.
    is_stopped: Arc<AtomicBool>,
    /// Receives events forwarded by the connection task.
    event_rx: mpsc::Receiver<Result<WebSocketEvent>>,
}
1200
1201impl UserDataStreamManager {
1202 pub async fn new(client: crate::Binance) -> Result<Self> {
1206 let listen_key = client.user_stream().start().await?;
1208 let listen_key = Arc::new(RwLock::new(listen_key));
1209 let is_stopped = Arc::new(AtomicBool::new(false));
1210 let (event_tx, event_rx) = mpsc::channel(1000);
1211
1212 let listen_key_clone = listen_key.clone();
1214 let is_stopped_clone = is_stopped.clone();
1215 let client_clone = client.clone();
1216
1217 tokio::spawn(async move {
1219 Self::keepalive_loop(
1220 client_clone.clone(),
1221 listen_key_clone.clone(),
1222 is_stopped_clone.clone(),
1223 )
1224 .await;
1225 });
1226
1227 let listen_key_ws = listen_key.clone();
1229 let is_stopped_ws = is_stopped.clone();
1230
1231 tokio::spawn(async move {
1232 Self::connection_loop(client, listen_key_ws, is_stopped_ws, event_tx).await;
1233 });
1234
1235 Ok(Self {
1236 listen_key,
1237 is_stopped,
1238 event_rx,
1239 })
1240 }
1241
1242 async fn keepalive_loop(
1243 client: crate::Binance,
1244 listen_key: Arc<RwLock<String>>,
1245 is_stopped: Arc<AtomicBool>,
1246 ) {
1247 let mut interval_timer = interval(Duration::from_secs(USER_STREAM_KEEPALIVE_SECS));
1248
1249 loop {
1250 interval_timer.tick().await;
1251
1252 if is_stopped.load(Ordering::SeqCst) {
1253 break;
1254 }
1255
1256 let key = listen_key.read().await.clone();
1257 if client.user_stream().keepalive(&key).await.is_err() {
1258 if let Ok(new_key) = client.user_stream().start().await {
1260 *listen_key.write().await = new_key;
1261 }
1262 }
1263 }
1264
1265 let key = listen_key.read().await.clone();
1267 let _ = client.user_stream().close(&key).await;
1268 }
1269
1270 async fn connection_loop(
1271 client: crate::Binance,
1272 listen_key: Arc<RwLock<String>>,
1273 is_stopped: Arc<AtomicBool>,
1274 event_tx: mpsc::Sender<Result<WebSocketEvent>>,
1275 ) {
1276 let reconnect_config = ReconnectConfig::default();
1277
1278 loop {
1279 if is_stopped.load(Ordering::SeqCst) {
1280 break;
1281 }
1282
1283 let key = listen_key.read().await.clone();
1284 let ws = client.websocket();
1285
1286 match ws.connect_user_stream(&key).await {
1287 Ok(mut conn) => {
1288 loop {
1289 if is_stopped.load(Ordering::SeqCst) {
1290 break;
1291 }
1292
1293 match timeout(Duration::from_secs(WS_TIMEOUT_SECS), conn.next()).await {
1294 Ok(Some(event)) => {
1295 if event_tx.send(event).await.is_err() {
1296 return;
1298 }
1299 }
1300 Ok(None) => {
1301 break;
1303 }
1304 Err(_) => {
1305 continue;
1307 }
1308 }
1309 }
1310 }
1311 Err(_) => {
1312 sleep(reconnect_config.base_delay).await;
1314 }
1315 }
1316
1317 sleep(Duration::from_millis(100)).await;
1319 }
1320 }
1321
1322 pub async fn next(&mut self) -> Option<Result<WebSocketEvent>> {
1324 self.event_rx.recv().await
1325 }
1326
1327 pub async fn listen_key(&self) -> String {
1329 self.listen_key.read().await.clone()
1330 }
1331
1332 pub fn stop(&self) {
1334 self.is_stopped.store(true, Ordering::SeqCst);
1335 }
1336
1337 pub fn is_stopped(&self) -> bool {
1339 self.is_stopped.load(Ordering::SeqCst)
1340 }
1341}
1342
/// Tracks the time of the last recorded activity and derives a healthy/unhealthy flag.
pub struct ConnectionHealthMonitor {
    /// Instant of the most recent `record_activity` call (or of construction).
    last_activity: Arc<RwLock<Instant>>,
    /// Result of the most recent health evaluation; starts as `true`.
    is_healthy: Arc<AtomicBool>,
    /// Idle time at or beyond which the connection counts as unhealthy.
    max_idle_duration: Duration,
}
1354
1355impl ConnectionHealthMonitor {
1356 pub fn new(max_idle_duration: Duration) -> Self {
1362 Self {
1363 last_activity: Arc::new(RwLock::new(Instant::now())),
1364 is_healthy: Arc::new(AtomicBool::new(true)),
1365 max_idle_duration,
1366 }
1367 }
1368
1369 pub async fn record_activity(&self) {
1371 *self.last_activity.write().await = Instant::now();
1372 self.is_healthy.store(true, Ordering::SeqCst);
1373 }
1374
1375 pub async fn is_healthy(&self) -> bool {
1377 let last = *self.last_activity.read().await;
1378 let healthy = last.elapsed() < self.max_idle_duration;
1379 self.is_healthy.store(healthy, Ordering::SeqCst);
1380 healthy
1381 }
1382
1383 pub async fn time_since_last_activity(&self) -> Duration {
1385 self.last_activity.read().await.elapsed()
1386 }
1387
1388 pub fn start_background_check(
1390 self: Arc<Self>,
1391 check_interval: Duration,
1392 ) -> tokio::task::JoinHandle<()> {
1393 let monitor = self;
1394 tokio::spawn(async move {
1395 let mut interval_timer = interval(check_interval);
1396 loop {
1397 interval_timer.tick().await;
1398 monitor.is_healthy().await;
1399 }
1400 })
1401 }
1402}
1403
/// Envelope used by combined-stream connections: the stream name plus the event.
#[derive(serde::Deserialize)]
struct CombinedStreamMessage {
    /// Name of the stream the event came from; required when deserialising but never
    /// read afterwards, hence the `dead_code` allowance.
    #[allow(dead_code)]
    stream: String,
    /// The wrapped event.
    data: WebSocketEvent,
}
1413
// Unit tests for the pure (non-networking) parts of this module.
#[cfg(test)]
mod tests {
    use super::*;

    // Stream-name builders lower-case the symbol and append the expected suffix.
    #[test]
    fn test_stream_names() {
        let config = Config::default();
        let ws = WebSocketClient::new(config);

        assert_eq!(ws.agg_trade_stream("BTCUSDT"), "btcusdt@aggTrade");
        assert_eq!(ws.trade_stream("BTCUSDT"), "btcusdt@trade");
        assert_eq!(
            ws.kline_stream("BTCUSDT", KlineInterval::Hours1),
            "btcusdt@kline_1h"
        );
        assert_eq!(ws.ticker_stream("BTCUSDT"), "btcusdt@ticker");
        assert_eq!(ws.book_ticker_stream("BTCUSDT"), "btcusdt@bookTicker");
        assert_eq!(ws.all_mini_ticker_stream(), "!miniTicker@arr");
        assert_eq!(ws.all_ticker_stream(), "!ticker@arr");
        assert_eq!(ws.all_book_ticker_stream(), "!bookTicker");
    }

    // Depth stream names, with and without the `@100ms` suffix.
    #[test]
    fn test_depth_stream_names() {
        let config = Config::default();
        let ws = WebSocketClient::new(config);

        assert_eq!(
            ws.partial_depth_stream("BTCUSDT", 10, false),
            "btcusdt@depth10"
        );
        assert_eq!(
            ws.partial_depth_stream("BTCUSDT", 10, true),
            "btcusdt@depth10@100ms"
        );
        assert_eq!(ws.diff_depth_stream("BTCUSDT", false), "btcusdt@depth");
        assert_eq!(ws.diff_depth_stream("BTCUSDT", true), "btcusdt@depth@100ms");
    }

    // Best bid/ask, spread and mid price on a hand-filled book.
    #[test]
    fn test_depth_cache() {
        let mut cache = DepthCache::new("BTCUSDT");

        // Fill the private maps directly; no snapshot type is needed for this test.
        cache.bids.insert(OrderedFloat(50000.0), 1.0);
        cache.bids.insert(OrderedFloat(49999.0), 2.0);
        cache.asks.insert(OrderedFloat(50001.0), 1.5);
        cache.asks.insert(OrderedFloat(50002.0), 2.5);

        assert_eq!(cache.best_bid(), Some((50000.0, 1.0)));
        assert_eq!(cache.best_ask(), Some((50001.0, 1.5)));
        assert_eq!(cache.spread(), Some(1.0));
        assert_eq!(cache.mid_price(), Some(50000.5));
    }

    // `ReconnectConfig::default` mirrors the module-level constants.
    #[test]
    fn test_reconnect_config_default() {
        let config = ReconnectConfig::default();
        assert_eq!(config.max_reconnects, MAX_RECONNECTS);
        assert_eq!(
            config.max_reconnect_delay,
            Duration::from_secs(MAX_RECONNECT_DELAY_SECS)
        );
        assert!(config.health_check_enabled);
    }

    // `DepthCacheConfig::default`: 1000 levels, slow updates, no refresh.
    #[test]
    fn test_depth_cache_config_default() {
        let config = DepthCacheConfig::default();
        assert_eq!(config.depth_limit, 1000);
        assert!(!config.fast_updates);
        assert!(config.refresh_interval.is_none());
    }

    // `ConnectionState` equality behaves as derived.
    #[test]
    fn test_connection_state() {
        assert_eq!(ConnectionState::Connecting, ConnectionState::Connecting);
        assert_ne!(ConnectionState::Connected, ConnectionState::Closed);
    }

    // `DepthCacheState` equality behaves as derived.
    #[test]
    fn test_depth_cache_state() {
        assert_eq!(DepthCacheState::Initializing, DepthCacheState::Initializing);
        assert_ne!(DepthCacheState::Synced, DepthCacheState::OutOfSync);
    }

    // `OrderedFloat` orders and compares ordinary values like `f64`.
    #[test]
    fn test_ordered_float() {
        let a = OrderedFloat(1.0);
        let b = OrderedFloat(2.0);
        assert!(a < b);
        assert_eq!(a, OrderedFloat(1.0));
    }

    // Backoff delays are positive and stay within the configured maximum.
    #[test]
    fn test_backoff_delay() {
        let config = ReconnectConfig::default();

        let delay1 = ReconnectingWebSocket::calculate_backoff_delay(1, &config);
        assert!(delay1.as_millis() > 0);
        assert!(delay1 <= config.max_reconnect_delay);

        let delay5 = ReconnectingWebSocket::calculate_backoff_delay(5, &config);
        assert!(delay5 <= config.max_reconnect_delay);
    }
}