1mod config;
7mod error;
8mod event;
9mod message;
10mod reconnect;
11mod state;
12mod subscription;
13
14pub use config::{
15 BackoffConfig, BackoffStrategy, BackpressureStrategy, DEFAULT_MAX_SUBSCRIPTIONS,
16 DEFAULT_MESSAGE_CHANNEL_CAPACITY, DEFAULT_SHUTDOWN_TIMEOUT, DEFAULT_WRITE_CHANNEL_CAPACITY,
17 WsConfig,
18};
19pub use error::{WsError, WsErrorKind};
20pub use event::{WsEvent, WsEventCallback};
21pub use message::WsMessage;
22pub use reconnect::AutoReconnectCoordinator;
23pub use state::{WsConnectionState, WsStats, WsStatsSnapshot};
24pub use subscription::{Subscription, SubscriptionManager};
25
26use crate::error::{Error, Result};
27use derive_more::Debug;
28use futures_util::{SinkExt, StreamExt, stream::SplitSink};
29use serde_json::Value;
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
33use tokio::net::TcpStream;
34use tokio::sync::{Mutex, RwLock, mpsc};
35use tokio::time::{Duration, interval};
36use tokio_tungstenite::{
37 MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message,
38};
39use tokio_util::sync::CancellationToken;
40use tracing::{debug, error, info, instrument, warn};
41
42#[allow(dead_code)]
44type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
45
46#[derive(Debug)]
58pub struct WsClient {
59 config: WsConfig,
60 state: Arc<AtomicU8>,
61 subscription_manager: SubscriptionManager,
62 message_tx: mpsc::Sender<Value>,
63 message_rx: Arc<RwLock<mpsc::Receiver<Value>>>,
64 write_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
65 pub(crate) reconnect_count: AtomicU32,
66 shutdown_tx: Arc<Mutex<Option<mpsc::UnboundedSender<()>>>>,
67 stats: Arc<WsStats>,
68 cancel_token: Arc<Mutex<Option<CancellationToken>>>,
69 #[debug(skip)]
70 event_callback: Arc<Mutex<Option<WsEventCallback>>>,
71 dropped_messages: Arc<AtomicU32>,
73}
74
75impl WsClient {
76 pub fn new(config: WsConfig) -> Self {
81 let (message_tx, message_rx) = mpsc::channel(config.message_channel_capacity);
82 let max_subscriptions = config.max_subscriptions;
83
84 Self {
85 config,
86 state: Arc::new(AtomicU8::new(WsConnectionState::Disconnected.as_u8())),
87 subscription_manager: SubscriptionManager::new(max_subscriptions),
88 message_tx,
89 message_rx: Arc::new(RwLock::new(message_rx)),
90 write_tx: Arc::new(RwLock::new(None)),
91 reconnect_count: AtomicU32::new(0),
92 shutdown_tx: Arc::new(Mutex::new(None)),
93 stats: Arc::new(WsStats::new()),
94 cancel_token: Arc::new(Mutex::new(None)),
95 event_callback: Arc::new(Mutex::new(None)),
96 dropped_messages: Arc::new(AtomicU32::new(0)),
97 }
98 }
99
100 pub async fn set_event_callback(&self, callback: WsEventCallback) {
102 *self.event_callback.lock().await = Some(callback);
103 debug!("Event callback set");
104 }
105
106 pub async fn clear_event_callback(&self) {
108 *self.event_callback.lock().await = None;
109 debug!("Event callback cleared");
110 }
111
112 async fn emit_event(&self, event: WsEvent) {
113 let callback = self.event_callback.lock().await;
114 if let Some(ref cb) = *callback {
115 let cb = Arc::clone(cb);
116 drop(callback);
117 tokio::spawn(async move {
118 cb(event);
119 });
120 }
121 }
122
123 pub async fn set_cancel_token(&self, token: CancellationToken) {
125 *self.cancel_token.lock().await = Some(token);
126 debug!("Cancellation token set");
127 }
128
129 pub async fn clear_cancel_token(&self) {
131 *self.cancel_token.lock().await = None;
132 debug!("Cancellation token cleared");
133 }
134
135 pub async fn get_cancel_token(&self) -> Option<CancellationToken> {
137 self.cancel_token.lock().await.clone()
138 }
139
140 #[instrument(
142 name = "ws_connect",
143 skip(self),
144 fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
145 )]
146 pub async fn connect(&self) -> Result<()> {
147 if self.state() == WsConnectionState::Connected {
148 info!("WebSocket already connected");
149 return Ok(());
150 }
151
152 self.set_state(WsConnectionState::Connecting);
153
154 let url = self.config.url.clone();
155 info!("Initiating WebSocket connection");
156
157 match tokio::time::timeout(
158 Duration::from_millis(self.config.connect_timeout),
159 connect_async(&url),
160 )
161 .await
162 {
163 Ok(Ok((ws_stream, response))) => {
164 info!(
165 status = response.status().as_u16(),
166 "WebSocket connection established successfully"
167 );
168
169 self.set_state(WsConnectionState::Connected);
170 self.reconnect_count.store(0, Ordering::Release);
171 self.stats.record_connected();
172 self.start_message_loop(ws_stream).await;
173 self.resubscribe_all().await?;
174
175 Ok(())
176 }
177 Ok(Err(e)) => {
178 error!(error = %e, "WebSocket connection failed");
179 self.set_state(WsConnectionState::Error);
180 Err(Error::network(format!("WebSocket connection failed: {e}")))
181 }
182 Err(_) => {
183 error!(
184 timeout_ms = self.config.connect_timeout,
185 "WebSocket connection timeout"
186 );
187 self.set_state(WsConnectionState::Error);
188 Err(Error::timeout("WebSocket connection timeout"))
189 }
190 }
191 }
192
193 #[instrument(
195 name = "ws_connect_with_cancel",
196 skip(self, cancel_token),
197 fields(url = %self.config.url)
198 )]
199 pub async fn connect_with_cancel(&self, cancel_token: Option<CancellationToken>) -> Result<()> {
200 let token = if let Some(t) = cancel_token {
201 t
202 } else {
203 let internal_token = self.cancel_token.lock().await;
204 internal_token
205 .clone()
206 .unwrap_or_else(CancellationToken::new)
207 };
208
209 if self.state() == WsConnectionState::Connected {
210 info!("WebSocket already connected");
211 return Ok(());
212 }
213
214 self.set_state(WsConnectionState::Connecting);
215 let url = self.config.url.clone();
216
217 tokio::select! {
218 biased;
219 () = token.cancelled() => {
220 warn!("WebSocket connection cancelled");
221 self.set_state(WsConnectionState::Disconnected);
222 Err(Error::cancelled("WebSocket connection cancelled"))
223 }
224 result = tokio::time::timeout(
225 Duration::from_millis(self.config.connect_timeout),
226 connect_async(&url),
227 ) => {
228 match result {
229 Ok(Ok((ws_stream, response))) => {
230 info!(status = response.status().as_u16(), "WebSocket connected");
231 self.set_state(WsConnectionState::Connected);
232 self.reconnect_count.store(0, Ordering::Release);
233 self.stats.record_connected();
234 self.start_message_loop(ws_stream).await;
235 self.resubscribe_all().await?;
236 Ok(())
237 }
238 Ok(Err(e)) => {
239 error!(error = %e, "WebSocket connection failed");
240 self.set_state(WsConnectionState::Error);
241 Err(Error::network(format!("WebSocket connection failed: {e}")))
242 }
243 Err(_) => {
244 error!("WebSocket connection timeout");
245 self.set_state(WsConnectionState::Error);
246 Err(Error::timeout("WebSocket connection timeout"))
247 }
248 }
249 }
250 }
251 }
252
253 #[instrument(name = "ws_disconnect", skip(self))]
255 pub async fn disconnect(&self) -> Result<()> {
256 info!("Initiating WebSocket disconnect");
257
258 if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
259 let _ = tx.send(());
260 }
261
262 *self.write_tx.write().await = None;
263 self.set_state(WsConnectionState::Disconnected);
264
265 info!("WebSocket disconnected");
266 Ok(())
267 }
268
269 #[instrument(name = "ws_shutdown", skip(self))]
271 pub async fn shutdown(&self) {
272 info!("Initiating graceful shutdown");
273
274 {
275 let token_guard = self.cancel_token.lock().await;
276 if let Some(ref token) = *token_guard {
277 token.cancel();
278 }
279 }
280
281 self.set_state(WsConnectionState::Disconnected);
282
283 {
284 let write_tx_guard = self.write_tx.read().await;
285 if let Some(ref tx) = *write_tx_guard {
286 drop(tx.send(Message::Close(None)).await);
288 }
289 }
290
291 let shutdown_timeout = Duration::from_millis(self.config.shutdown_timeout);
292 let _ = tokio::time::timeout(shutdown_timeout, async {
293 if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
294 let _ = tx.send(());
295 }
296 tokio::time::sleep(Duration::from_millis(100)).await;
297 })
298 .await;
299
300 {
301 *self.write_tx.write().await = None;
302 *self.shutdown_tx.lock().await = None;
303 self.subscription_manager.clear();
304 self.reconnect_count.store(0, Ordering::Release);
305 self.dropped_messages.store(0, Ordering::Relaxed);
306 self.stats.reset();
307 }
308
309 self.emit_event(WsEvent::Shutdown).await;
310 info!("Graceful shutdown completed");
311 }
312
313 #[instrument(name = "ws_reconnect", skip(self))]
315 pub async fn reconnect(&self) -> Result<()> {
316 let count = self.reconnect_count.fetch_add(1, Ordering::AcqRel) + 1;
317
318 if count > self.config.max_reconnect_attempts {
319 error!(attempts = count, "Max reconnect attempts reached");
320 return Err(Error::network("Max reconnect attempts reached"));
321 }
322
323 warn!(attempt = count, "Attempting WebSocket reconnection");
324 self.set_state(WsConnectionState::Reconnecting);
325
326 tokio::time::sleep(Duration::from_millis(self.config.reconnect_interval)).await;
327 self.connect().await
328 }
329
330 #[instrument(name = "ws_reconnect_with_cancel", skip(self, cancel_token))]
332 pub async fn reconnect_with_cancel(
333 &self,
334 cancel_token: Option<CancellationToken>,
335 ) -> Result<()> {
336 let token = if let Some(t) = cancel_token {
337 t
338 } else {
339 let internal_token = self.cancel_token.lock().await;
340 internal_token
341 .clone()
342 .unwrap_or_else(CancellationToken::new)
343 };
344
345 let backoff = BackoffStrategy::new(self.config.backoff_config.clone());
346 self.set_state(WsConnectionState::Reconnecting);
347
348 loop {
349 if token.is_cancelled() {
350 self.set_state(WsConnectionState::Disconnected);
351 return Err(Error::cancelled("Reconnection cancelled"));
352 }
353
354 let attempt = self.reconnect_count.fetch_add(1, Ordering::AcqRel);
355
356 if attempt >= self.config.max_reconnect_attempts {
357 self.set_state(WsConnectionState::Error);
358 return Err(Error::network(format!(
359 "Max reconnect attempts ({}) reached",
360 self.config.max_reconnect_attempts
361 )));
362 }
363
364 let delay = backoff.calculate_delay(attempt);
365
366 tokio::select! {
367 biased;
368 () = token.cancelled() => {
369 self.set_state(WsConnectionState::Disconnected);
370 return Err(Error::cancelled("Reconnection cancelled during backoff"));
371 }
372 () = tokio::time::sleep(delay) => {}
373 }
374
375 match self.connect_with_cancel(Some(token.clone())).await {
376 Ok(()) => {
377 self.reconnect_count.store(0, Ordering::Release);
378 return Ok(());
379 }
380 Err(e) => {
381 if e.as_cancelled().is_some() {
382 self.set_state(WsConnectionState::Disconnected);
383 return Err(e);
384 }
385
386 let ws_error = WsError::from_error(&e);
387 if ws_error.is_permanent() {
388 self.set_state(WsConnectionState::Error);
389 return Err(e);
390 }
391 }
392 }
393 }
394 }
395
396 #[inline]
398 pub fn reconnect_count(&self) -> u32 {
399 self.reconnect_count.load(Ordering::Acquire)
400 }
401
402 pub fn reset_reconnect_count(&self) {
404 self.reconnect_count.store(0, Ordering::Release);
405 }
406
407 pub(crate) fn increment_reconnect_count(&self) {
409 self.reconnect_count.fetch_add(1, Ordering::AcqRel);
410 }
411
412 pub fn stats(&self) -> WsStatsSnapshot {
414 self.stats.snapshot()
415 }
416
417 pub fn reset_stats(&self) {
419 self.stats.reset();
420 }
421
422 pub fn latency(&self) -> Option<i64> {
424 let last_pong = self.stats.last_pong_time();
425 let last_ping = self.stats.last_ping_time();
426 if last_pong > 0 && last_ping > 0 {
427 Some(last_pong - last_ping)
428 } else {
429 None
430 }
431 }
432
433 pub fn dropped_messages(&self) -> u32 {
438 self.dropped_messages.load(Ordering::Relaxed)
439 }
440
441 pub fn reset_dropped_messages(&self) {
443 self.dropped_messages.store(0, Ordering::Relaxed);
444 }
445
446 pub fn create_auto_reconnect_coordinator(self: Arc<Self>) -> AutoReconnectCoordinator {
448 AutoReconnectCoordinator::new(self)
449 }
450
451 #[instrument(name = "ws_subscribe", skip(self, params), fields(channel = %channel))]
453 pub async fn subscribe(
454 &self,
455 channel: String,
456 symbol: Option<String>,
457 params: Option<HashMap<String, Value>>,
458 ) -> Result<()> {
459 let sub_key = Self::subscription_key(&channel, symbol.as_ref());
460 let subscription = Subscription {
461 channel: channel.clone(),
462 symbol: symbol.clone(),
463 params: params.clone(),
464 };
465
466 self.subscription_manager
467 .try_add(sub_key.clone(), subscription)?;
468
469 if self.state() == WsConnectionState::Connected {
470 self.send_subscribe_message(channel, symbol, params).await?;
471 }
472
473 Ok(())
474 }
475
476 #[instrument(name = "ws_unsubscribe", skip(self), fields(channel = %channel))]
478 pub async fn unsubscribe(&self, channel: String, symbol: Option<String>) -> Result<()> {
479 let sub_key = Self::subscription_key(&channel, symbol.as_ref());
480 self.subscription_manager.remove(&sub_key);
481
482 if self.state() == WsConnectionState::Connected {
483 self.send_unsubscribe_message(channel, symbol).await?;
484 }
485
486 Ok(())
487 }
488
489 pub async fn receive(&self) -> Option<Value> {
491 let mut rx = self.message_rx.write().await;
492 rx.recv().await
493 }
494
495 #[inline]
497 pub fn state(&self) -> WsConnectionState {
498 WsConnectionState::from_u8(self.state.load(Ordering::Acquire))
499 }
500
501 #[inline]
503 pub fn config(&self) -> &WsConfig {
504 &self.config
505 }
506
507 #[inline]
509 pub fn set_state(&self, state: WsConnectionState) {
510 self.state.store(state.as_u8(), Ordering::Release);
511 }
512
513 #[inline]
515 pub fn is_connected(&self) -> bool {
516 self.state() == WsConnectionState::Connected
517 }
518
519 pub fn is_subscribed(&self, channel: &str, symbol: Option<&String>) -> bool {
521 let sub_key = Self::subscription_key(channel, symbol);
522 self.subscription_manager.contains(&sub_key)
523 }
524
525 pub fn subscription_count(&self) -> usize {
527 self.subscription_manager.count()
528 }
529
530 pub fn remaining_capacity(&self) -> usize {
532 self.subscription_manager.remaining_capacity()
533 }
534
535 pub fn subscriptions(&self) -> Vec<String> {
540 self.subscription_manager
541 .iter()
542 .map(|entry| {
543 let sub = entry.value();
544 match &sub.symbol {
545 Some(sym) => format!("{}:{}", sub.channel, sym),
546 None => sub.channel.clone(),
547 }
548 })
549 .collect()
550 }
551
552 #[instrument(name = "ws_send", skip(self, message))]
557 pub async fn send(&self, message: Message) -> Result<()> {
558 let tx = self.write_tx.read().await;
559
560 if let Some(sender) = tx.as_ref() {
561 sender
562 .send(message)
563 .await
564 .map_err(|e| Error::network(format!("Failed to send message: {e}")))?;
565 Ok(())
566 } else {
567 Err(Error::network("WebSocket not connected"))
568 }
569 }
570
571 #[instrument(name = "ws_try_send", skip(self, message))]
575 pub fn try_send(&self, message: Message) -> Result<()> {
576 if let Ok(tx) = self.write_tx.try_read() {
579 if let Some(sender) = tx.as_ref() {
580 sender.try_send(message).map_err(|e| match e {
581 mpsc::error::TrySendError::Full(_) => {
582 Error::network("Write channel full (backpressure)")
583 }
584 mpsc::error::TrySendError::Closed(_) => {
585 Error::network("WebSocket channel closed")
586 }
587 })?;
588 Ok(())
589 } else {
590 Err(Error::network("WebSocket not connected"))
591 }
592 } else {
593 Err(Error::network("Write channel busy"))
594 }
595 }
596
597 #[instrument(name = "ws_send_text", skip(self, text))]
599 pub async fn send_text(&self, text: String) -> Result<()> {
600 self.send(Message::Text(text.into())).await
601 }
602
603 #[instrument(name = "ws_send_json", skip(self, json))]
605 pub async fn send_json(&self, json: &Value) -> Result<()> {
606 let text = serde_json::to_string(json).map_err(Error::from)?;
607 self.send_text(text).await
608 }
609
610 fn subscription_key(channel: &str, symbol: Option<&String>) -> String {
611 match symbol {
612 Some(s) => format!("{channel}:{s}"),
613 None => channel.to_string(),
614 }
615 }
616
617 async fn start_message_loop(&self, ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) {
618 let (write, mut read) = ws_stream.split();
619
620 let (write_tx, mut write_rx) = mpsc::channel::<Message>(self.config.write_channel_capacity);
622 *self.write_tx.write().await = Some(write_tx.clone());
623
624 let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>();
626 *self.shutdown_tx.lock().await = Some(shutdown_tx);
627
628 let state = Arc::clone(&self.state);
629 let message_tx = self.message_tx.clone();
630 let ping_interval_ms = self.config.ping_interval;
631 let backpressure_strategy = self.config.backpressure_strategy;
632 let dropped_messages = Arc::clone(&self.dropped_messages);
633
634 let write_handle = tokio::spawn(async move {
635 let mut write = write;
636 loop {
637 tokio::select! {
638 Some(msg) = write_rx.recv() => {
639 if let Err(e) = write.send(msg).await {
640 error!(error = %e, "Failed to write message");
641 break;
642 }
643 }
644 _ = shutdown_rx.recv() => {
645 let _ = write.send(Message::Close(None)).await;
646 break;
647 }
648 }
649 }
650 });
651
652 let state_clone = Arc::clone(&state);
653 let ws_stats = Arc::clone(&self.stats);
654 let read_handle = tokio::spawn(async move {
655 while let Some(msg_result) = read.next().await {
656 match msg_result {
657 Ok(Message::Text(text)) => {
658 ws_stats.record_received(text.len() as u64);
659 if let Ok(json) = serde_json::from_str::<Value>(&text) {
660 Self::send_with_backpressure(
661 &message_tx,
662 json,
663 backpressure_strategy,
664 &dropped_messages,
665 )
666 .await;
667 }
668 }
669 Ok(Message::Binary(data)) => {
670 ws_stats.record_received(data.len() as u64);
671 if let Some(json) = String::from_utf8(data.to_vec())
672 .ok()
673 .and_then(|text| serde_json::from_str::<Value>(&text).ok())
674 {
675 Self::send_with_backpressure(
676 &message_tx,
677 json,
678 backpressure_strategy,
679 &dropped_messages,
680 )
681 .await;
682 }
683 }
684 Ok(Message::Pong(_)) => {
685 ws_stats.record_pong();
686 }
687 Ok(Message::Close(_)) => {
688 state_clone
689 .store(WsConnectionState::Disconnected.as_u8(), Ordering::Release);
690 break;
691 }
692 Err(_) => {
693 state_clone.store(WsConnectionState::Error.as_u8(), Ordering::Release);
694 break;
695 }
696 _ => {}
697 }
698 }
699 });
700
701 if ping_interval_ms > 0 {
702 let write_tx_clone = write_tx.clone();
703 let ping_stats = Arc::clone(&self.stats);
704 let ping_state = Arc::clone(&state);
705 let pong_timeout_ms = self.config.pong_timeout;
706
707 tokio::spawn(async move {
708 let mut interval = interval(Duration::from_millis(ping_interval_ms));
709
710 loop {
711 interval.tick().await;
712
713 let now = chrono::Utc::now().timestamp_millis();
714 let last_pong = ping_stats.last_pong_time();
715
716 if last_pong > 0 {
717 let elapsed = now - last_pong;
718 #[allow(clippy::cast_possible_wrap)]
719 if elapsed > pong_timeout_ms as i64 {
720 error!(
722 pong_timeout_ms = pong_timeout_ms,
723 elapsed_ms = elapsed,
724 last_pong_time = last_pong,
725 current_time = now,
726 "WebSocket pong timeout detected - connection appears unresponsive (zombie connection)"
727 );
728 ping_state.store(WsConnectionState::Error.as_u8(), Ordering::Release);
729 debug!(
730 "WebSocket state set to Error due to pong timeout - AutoReconnectCoordinator will trigger reconnection if enabled"
731 );
732 break;
733 }
734 }
735
736 ping_stats.record_ping();
737
738 if write_tx_clone
740 .try_send(Message::Ping(vec![].into()))
741 .is_err()
742 {
743 debug!("WebSocket write channel closed, stopping ping loop");
744 break;
745 }
746 }
747 });
748 }
749
750 tokio::spawn(async move {
751 let _ = tokio::join!(write_handle, read_handle);
752 });
753 }
754
755 async fn send_with_backpressure(
760 tx: &mpsc::Sender<Value>,
761 message: Value,
762 strategy: BackpressureStrategy,
763 dropped_counter: &Arc<AtomicU32>,
764 ) {
765 match strategy {
766 BackpressureStrategy::Block => {
767 if tx.send(message).await.is_err() {
769 warn!("Message channel closed");
770 }
771 }
772 BackpressureStrategy::DropNewest => {
773 match tx.try_send(message) {
775 Ok(()) => {}
776 Err(mpsc::error::TrySendError::Full(_)) => {
777 let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
778 if count % 100 == 1 {
779 warn!(
781 dropped_count = count,
782 "Message channel full, dropping newest message (backpressure)"
783 );
784 }
785 }
786 Err(mpsc::error::TrySendError::Closed(_)) => {
787 warn!("Message channel closed");
788 }
789 }
790 }
791 BackpressureStrategy::DropOldest => {
792 match tx.try_send(message) {
794 Ok(()) => {}
795 Err(mpsc::error::TrySendError::Full(msg)) => {
796 let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
800 if count % 100 == 1 {
801 warn!(
802 dropped_count = count,
803 "Message channel full, dropping oldest message (backpressure)"
804 );
805 }
806 drop(msg);
811 }
812 Err(mpsc::error::TrySendError::Closed(_)) => {
813 warn!("Message channel closed");
814 }
815 }
816 }
817 }
818 }
819
820 async fn send_subscribe_message(
821 &self,
822 channel: String,
823 symbol: Option<String>,
824 params: Option<HashMap<String, Value>>,
825 ) -> Result<()> {
826 let msg = WsMessage::Subscribe {
827 channel,
828 symbol,
829 params,
830 };
831 let json = serde_json::to_value(&msg).map_err(Error::from)?;
832 self.send_json(&json).await
833 }
834
835 async fn send_unsubscribe_message(
836 &self,
837 channel: String,
838 symbol: Option<String>,
839 ) -> Result<()> {
840 let msg = WsMessage::Unsubscribe { channel, symbol };
841 let json = serde_json::to_value(&msg).map_err(Error::from)?;
842 self.send_json(&json).await
843 }
844
845 pub(crate) async fn resubscribe_all(&self) -> Result<()> {
846 let subs = self.subscription_manager.collect_subscriptions();
847 for subscription in subs {
848 self.send_subscribe_message(
849 subscription.channel.clone(),
850 subscription.symbol.clone(),
851 subscription.params.clone(),
852 )
853 .await?;
854 }
855 Ok(())
856 }
857}
858
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder endpoint for client construction tests; never dialled.
    const EXAMPLE_URL: &str = "wss://example.com/ws";

    /// Builds a disconnected client with default settings and `EXAMPLE_URL`.
    fn client_for_example_url() -> WsClient {
        let config = WsConfig {
            url: EXAMPLE_URL.to_string(),
            ..Default::default()
        };
        WsClient::new(config)
    }

    #[test]
    fn test_backoff_config_default() {
        let defaults = BackoffConfig::default();
        assert_eq!(defaults.base_delay, Duration::from_secs(1));
        assert_eq!(defaults.max_delay, Duration::from_secs(60));
    }

    #[test]
    fn test_backoff_strategy_exponential_growth_no_jitter() {
        let strategy = BackoffStrategy::new(BackoffConfig {
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            jitter_factor: 0.0,
            multiplier: 2.0,
        });

        // (attempt, expected delay in seconds); attempt 6 exercises the 60s cap.
        for (attempt, expected_secs) in [(0, 1), (1, 2), (2, 4), (6, 60)] {
            assert_eq!(
                strategy.calculate_delay(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    #[test]
    fn test_ws_config_default() {
        let defaults = WsConfig::default();
        assert_eq!(defaults.connect_timeout, 10000);
        assert_eq!(defaults.max_subscriptions, DEFAULT_MAX_SUBSCRIPTIONS);
    }

    #[test]
    fn test_subscription_key() {
        let symbol = "BTC/USDT".to_string();
        assert_eq!(
            WsClient::subscription_key("ticker", Some(&symbol)),
            "ticker:BTC/USDT"
        );
        assert_eq!(WsClient::subscription_key("trades", None), "trades");
    }

    #[tokio::test]
    async fn test_ws_client_creation() {
        let client = client_for_example_url();
        assert_eq!(client.state(), WsConnectionState::Disconnected);
        assert!(!client.is_connected());
    }

    #[tokio::test]
    async fn test_subscribe_adds_subscription() {
        let client = client_for_example_url();
        let symbol = "BTC/USDT".to_string();

        let outcome = client
            .subscribe("ticker".to_string(), Some(symbol.clone()), None)
            .await;

        assert!(outcome.is_ok());
        assert_eq!(client.subscription_count(), 1);
        assert!(client.is_subscribed("ticker", Some(&symbol)));
    }

    #[test]
    fn test_ws_connection_state_from_u8() {
        let cases = [
            (0, WsConnectionState::Disconnected),
            (1, WsConnectionState::Connecting),
            (2, WsConnectionState::Connected),
            (255, WsConnectionState::Error),
        ];
        for (raw, expected) in cases {
            assert_eq!(WsConnectionState::from_u8(raw), expected);
        }
    }

    #[test]
    fn test_ws_error_kind() {
        assert!(WsErrorKind::Transient.is_transient());
        assert!(WsErrorKind::Permanent.is_permanent());
    }

    #[test]
    fn test_ws_error_creation() {
        let transient = WsError::transient("Connection timeout");
        assert!(transient.is_transient());
        assert_eq!(transient.message(), "Connection timeout");

        let permanent = WsError::permanent("Invalid API key");
        assert!(permanent.is_permanent());
    }

    #[test]
    fn test_subscription_manager() {
        let manager = SubscriptionManager::new(2);
        assert_eq!(manager.max_subscriptions(), 2);
        assert_eq!(manager.count(), 0);
        assert!(!manager.is_full());

        let ticker = Subscription {
            channel: "ticker".to_string(),
            symbol: Some("BTC/USDT".to_string()),
            params: None,
        };
        assert!(manager.try_add("ticker:BTC/USDT".to_string(), ticker).is_ok());
        assert_eq!(manager.count(), 1);
    }
}