1mod config;
7mod error;
8mod event;
9mod message;
10mod reconnect;
11mod state;
12mod subscription;
13
14pub use config::{
15 BackoffConfig, BackoffStrategy, BackpressureStrategy, DEFAULT_MAX_SUBSCRIPTIONS,
16 DEFAULT_MESSAGE_CHANNEL_CAPACITY, DEFAULT_SHUTDOWN_TIMEOUT, DEFAULT_WRITE_CHANNEL_CAPACITY,
17 WsConfig,
18};
19pub use error::{WsError, WsErrorKind};
20pub use event::{WsEvent, WsEventCallback};
21pub use message::WsMessage;
22pub use reconnect::AutoReconnectCoordinator;
23pub use state::{WsConnectionState, WsStats, WsStatsSnapshot};
24pub use subscription::{Subscription, SubscriptionManager};
25
26use crate::error::{Error, Result};
27use derive_more::Debug;
28use futures_util::{SinkExt, StreamExt, stream::SplitSink};
29use serde_json::Value;
30use std::collections::HashMap;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU8, AtomicU32, Ordering};
33use tokio::net::TcpStream;
34use tokio::sync::{Mutex, RwLock, mpsc};
35use tokio::time::{Duration, interval};
36use tokio_tungstenite::{
37 MaybeTlsStream, WebSocketStream, connect_async_with_config,
38 tungstenite::protocol::{Message, WebSocketConfig},
39};
40use tokio_util::sync::CancellationToken;
41use tracing::{debug, error, info, instrument, warn};
42
/// Write half of a split WebSocket stream running over a (possibly TLS-wrapped) TCP socket.
///
/// Not referenced in the visible part of this module, hence the `dead_code` allowance.
#[allow(dead_code)]
type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
46
/// Asynchronous WebSocket client with subscription tracking, reconnection
/// helpers, keepalive pings and inbound backpressure handling.
///
/// All mutable state sits behind atomics or async locks, so every method takes `&self`.
#[derive(Debug)]
pub struct WsClient {
    /// Connection settings supplied to [`WsClient::new`].
    config: WsConfig,
    /// Current connection state, stored as the `u8` encoding of `WsConnectionState`.
    state: Arc<AtomicU8>,
    /// Registry of requested subscriptions; replayed by `resubscribe_all` after a connect.
    subscription_manager: SubscriptionManager,
    /// Producer side of the inbound queue; cloned into the read task.
    message_tx: mpsc::Sender<Value>,
    /// Consumer side of the inbound queue; drained by `receive`.
    message_rx: Arc<RwLock<mpsc::Receiver<Value>>>,
    /// Handle to the writer task's outbound queue; `None` while no message loop is installed.
    write_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
    /// Reconnect attempts made so far; reset to zero on a successful connect.
    pub(crate) reconnect_count: AtomicU32,
    /// Signal channel telling the writer task to send a Close frame and stop.
    shutdown_tx: Arc<Mutex<Option<mpsc::UnboundedSender<()>>>>,
    /// Connection statistics (bytes received, ping/pong timestamps, ...).
    stats: Arc<WsStats>,
    /// Optional token used as the default for the `*_with_cancel` methods; cancelled by `shutdown`.
    cancel_token: Arc<Mutex<Option<CancellationToken>>>,
    /// Optional event callback invoked by `emit_event`; omitted from `Debug` output.
    #[debug(skip)]
    event_callback: Arc<Mutex<Option<WsEventCallback>>>,
    /// Number of inbound messages discarded because the inbound queue was full.
    dropped_messages: Arc<AtomicU32>,
}
75
76impl WsClient {
77 pub fn new(config: WsConfig) -> Self {
82 let (message_tx, message_rx) = mpsc::channel(config.message_channel_capacity);
83 let max_subscriptions = config.max_subscriptions;
84
85 Self {
86 config,
87 state: Arc::new(AtomicU8::new(WsConnectionState::Disconnected.as_u8())),
88 subscription_manager: SubscriptionManager::new(max_subscriptions),
89 message_tx,
90 message_rx: Arc::new(RwLock::new(message_rx)),
91 write_tx: Arc::new(RwLock::new(None)),
92 reconnect_count: AtomicU32::new(0),
93 shutdown_tx: Arc::new(Mutex::new(None)),
94 stats: Arc::new(WsStats::new()),
95 cancel_token: Arc::new(Mutex::new(None)),
96 event_callback: Arc::new(Mutex::new(None)),
97 dropped_messages: Arc::new(AtomicU32::new(0)),
98 }
99 }
100
101 pub async fn set_event_callback(&self, callback: WsEventCallback) {
103 *self.event_callback.lock().await = Some(callback);
104 debug!("Event callback set");
105 }
106
107 pub async fn clear_event_callback(&self) {
109 *self.event_callback.lock().await = None;
110 debug!("Event callback cleared");
111 }
112
113 async fn emit_event(&self, event: WsEvent) {
114 let callback = self.event_callback.lock().await;
115 if let Some(ref cb) = *callback {
116 let cb = Arc::clone(cb);
117 drop(callback);
118 tokio::spawn(async move {
119 cb(event);
120 });
121 }
122 }
123
124 pub async fn set_cancel_token(&self, token: CancellationToken) {
126 *self.cancel_token.lock().await = Some(token);
127 debug!("Cancellation token set");
128 }
129
130 pub async fn clear_cancel_token(&self) {
132 *self.cancel_token.lock().await = None;
133 debug!("Cancellation token cleared");
134 }
135
136 pub async fn get_cancel_token(&self) -> Option<CancellationToken> {
138 self.cancel_token.lock().await.clone()
139 }
140
141 #[instrument(
143 name = "ws_connect",
144 skip(self),
145 fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
146 )]
147 pub async fn connect(&self) -> Result<()> {
148 if self.state() == WsConnectionState::Connected {
149 info!("WebSocket already connected");
150 return Ok(());
151 }
152
153 self.set_state(WsConnectionState::Connecting);
154
155 let url = self.config.url.clone();
156 info!("Initiating WebSocket connection");
157
158 let mut ws_config = WebSocketConfig::default();
159 if let Some(limit) = self.config.max_message_size {
160 ws_config.max_message_size = Some(limit);
161 }
162 if let Some(limit) = self.config.max_frame_size {
163 ws_config.max_frame_size = Some(limit);
164 }
165
166 match tokio::time::timeout(
167 Duration::from_millis(self.config.connect_timeout),
168 connect_async_with_config(&url, Some(ws_config), false),
169 )
170 .await
171 {
172 Ok(Ok((ws_stream, response))) => {
173 info!(
174 status = response.status().as_u16(),
175 "WebSocket connection established successfully"
176 );
177
178 self.set_state(WsConnectionState::Connected);
179 self.reconnect_count.store(0, Ordering::Release);
180 self.stats.record_connected();
181 self.start_message_loop(ws_stream).await;
182 self.resubscribe_all().await?;
183
184 Ok(())
185 }
186 Ok(Err(e)) => {
187 error!(error = %e, "WebSocket connection failed");
188 self.set_state(WsConnectionState::Error);
189 Err(Error::network(format!("WebSocket connection failed: {e}")))
190 }
191 Err(_) => {
192 error!(
193 timeout_ms = self.config.connect_timeout,
194 "WebSocket connection timeout"
195 );
196 self.set_state(WsConnectionState::Error);
197 Err(Error::timeout("WebSocket connection timeout"))
198 }
199 }
200 }
201
202 #[instrument(
204 name = "ws_connect_with_cancel",
205 skip(self, cancel_token),
206 fields(url = %self.config.url)
207 )]
208 pub async fn connect_with_cancel(&self, cancel_token: Option<CancellationToken>) -> Result<()> {
209 let token = if let Some(t) = cancel_token {
210 t
211 } else {
212 let internal_token = self.cancel_token.lock().await;
213 internal_token
214 .clone()
215 .unwrap_or_else(CancellationToken::new)
216 };
217
218 if self.state() == WsConnectionState::Connected {
219 info!("WebSocket already connected");
220 return Ok(());
221 }
222
223 self.set_state(WsConnectionState::Connecting);
224 let url = self.config.url.clone();
225
226 let mut ws_config = WebSocketConfig::default();
227 if let Some(limit) = self.config.max_message_size {
228 ws_config.max_message_size = Some(limit);
229 }
230 if let Some(limit) = self.config.max_frame_size {
231 ws_config.max_frame_size = Some(limit);
232 }
233
234 tokio::select! {
235 biased;
236 () = token.cancelled() => {
237 warn!("WebSocket connection cancelled");
238 self.set_state(WsConnectionState::Disconnected);
239 Err(Error::cancelled("WebSocket connection cancelled"))
240 }
241 result = tokio::time::timeout(
242 Duration::from_millis(self.config.connect_timeout),
243 connect_async_with_config(&url, Some(ws_config), false),
244 ) => {
245 match result {
246 Ok(Ok((ws_stream, response))) => {
247 info!(status = response.status().as_u16(), "WebSocket connected");
248 self.set_state(WsConnectionState::Connected);
249 self.reconnect_count.store(0, Ordering::Release);
250 self.stats.record_connected();
251 self.start_message_loop(ws_stream).await;
252 self.resubscribe_all().await?;
253 Ok(())
254 }
255 Ok(Err(e)) => {
256 error!(error = %e, "WebSocket connection failed");
257 self.set_state(WsConnectionState::Error);
258 Err(Error::network(format!("WebSocket connection failed: {e}")))
259 }
260 Err(_) => {
261 error!("WebSocket connection timeout");
262 self.set_state(WsConnectionState::Error);
263 Err(Error::timeout("WebSocket connection timeout"))
264 }
265 }
266 }
267 }
268 }
269
270 #[instrument(name = "ws_disconnect", skip(self))]
272 pub async fn disconnect(&self) -> Result<()> {
273 info!("Initiating WebSocket disconnect");
274
275 if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
276 let _ = tx.send(());
277 }
278
279 *self.write_tx.write().await = None;
280 self.set_state(WsConnectionState::Disconnected);
281
282 info!("WebSocket disconnected");
283 Ok(())
284 }
285
286 #[instrument(name = "ws_shutdown", skip(self))]
288 pub async fn shutdown(&self) {
289 info!("Initiating graceful shutdown");
290
291 {
292 let token_guard = self.cancel_token.lock().await;
293 if let Some(ref token) = *token_guard {
294 token.cancel();
295 }
296 }
297
298 self.set_state(WsConnectionState::Disconnected);
299
300 {
301 let write_tx_guard = self.write_tx.read().await;
302 if let Some(ref tx) = *write_tx_guard {
303 drop(tx.send(Message::Close(None)).await);
305 }
306 }
307
308 let shutdown_timeout = Duration::from_millis(self.config.shutdown_timeout);
309 let _ = tokio::time::timeout(shutdown_timeout, async {
310 if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
311 let _ = tx.send(());
312 }
313 tokio::time::sleep(Duration::from_millis(100)).await;
314 })
315 .await;
316
317 {
318 *self.write_tx.write().await = None;
319 *self.shutdown_tx.lock().await = None;
320 self.subscription_manager.clear();
321 self.reconnect_count.store(0, Ordering::Release);
322 self.dropped_messages.store(0, Ordering::Relaxed);
323 self.stats.reset();
324 }
325
326 self.emit_event(WsEvent::Shutdown).await;
327 info!("Graceful shutdown completed");
328 }
329
330 #[instrument(name = "ws_reconnect", skip(self))]
332 pub async fn reconnect(&self) -> Result<()> {
333 let count = self.reconnect_count.fetch_add(1, Ordering::AcqRel) + 1;
334
335 if count > self.config.max_reconnect_attempts {
336 error!(attempts = count, "Max reconnect attempts reached");
337 return Err(Error::network("Max reconnect attempts reached"));
338 }
339
340 warn!(attempt = count, "Attempting WebSocket reconnection");
341 self.set_state(WsConnectionState::Reconnecting);
342
343 tokio::time::sleep(Duration::from_millis(self.config.reconnect_interval)).await;
344 self.connect().await
345 }
346
347 #[instrument(name = "ws_reconnect_with_cancel", skip(self, cancel_token))]
349 pub async fn reconnect_with_cancel(
350 &self,
351 cancel_token: Option<CancellationToken>,
352 ) -> Result<()> {
353 let token = if let Some(t) = cancel_token {
354 t
355 } else {
356 let internal_token = self.cancel_token.lock().await;
357 internal_token
358 .clone()
359 .unwrap_or_else(CancellationToken::new)
360 };
361
362 let backoff = BackoffStrategy::new(self.config.backoff_config.clone());
363 self.set_state(WsConnectionState::Reconnecting);
364
365 loop {
366 if token.is_cancelled() {
367 self.set_state(WsConnectionState::Disconnected);
368 return Err(Error::cancelled("Reconnection cancelled"));
369 }
370
371 let attempt = self.reconnect_count.fetch_add(1, Ordering::AcqRel);
372
373 if attempt >= self.config.max_reconnect_attempts {
374 self.set_state(WsConnectionState::Error);
375 return Err(Error::network(format!(
376 "Max reconnect attempts ({}) reached",
377 self.config.max_reconnect_attempts
378 )));
379 }
380
381 let delay = backoff.calculate_delay(attempt);
382
383 tokio::select! {
384 biased;
385 () = token.cancelled() => {
386 self.set_state(WsConnectionState::Disconnected);
387 return Err(Error::cancelled("Reconnection cancelled during backoff"));
388 }
389 () = tokio::time::sleep(delay) => {}
390 }
391
392 match self.connect_with_cancel(Some(token.clone())).await {
393 Ok(()) => {
394 self.reconnect_count.store(0, Ordering::Release);
395 return Ok(());
396 }
397 Err(e) => {
398 if e.as_cancelled().is_some() {
399 self.set_state(WsConnectionState::Disconnected);
400 return Err(e);
401 }
402
403 let ws_error = WsError::from_error(&e);
404 if ws_error.is_permanent() {
405 self.set_state(WsConnectionState::Error);
406 return Err(e);
407 }
408 }
409 }
410 }
411 }
412
413 #[inline]
415 pub fn reconnect_count(&self) -> u32 {
416 self.reconnect_count.load(Ordering::Acquire)
417 }
418
419 pub fn reset_reconnect_count(&self) {
421 self.reconnect_count.store(0, Ordering::Release);
422 }
423
424 pub(crate) fn increment_reconnect_count(&self) {
426 self.reconnect_count.fetch_add(1, Ordering::AcqRel);
427 }
428
429 pub fn stats(&self) -> WsStatsSnapshot {
431 self.stats.snapshot()
432 }
433
434 pub fn reset_stats(&self) {
436 self.stats.reset();
437 }
438
439 pub fn latency(&self) -> Option<i64> {
441 let last_pong = self.stats.last_pong_time();
442 let last_ping = self.stats.last_ping_time();
443 if last_pong > 0 && last_ping > 0 {
444 Some(last_pong - last_ping)
445 } else {
446 None
447 }
448 }
449
450 pub fn dropped_messages(&self) -> u32 {
455 self.dropped_messages.load(Ordering::Relaxed)
456 }
457
458 pub fn reset_dropped_messages(&self) {
460 self.dropped_messages.store(0, Ordering::Relaxed);
461 }
462
463 pub fn create_auto_reconnect_coordinator(self: Arc<Self>) -> AutoReconnectCoordinator {
465 AutoReconnectCoordinator::new(self)
466 }
467
468 #[instrument(name = "ws_subscribe", skip(self, params), fields(channel = %channel))]
470 pub async fn subscribe(
471 &self,
472 channel: String,
473 symbol: Option<String>,
474 params: Option<HashMap<String, Value>>,
475 ) -> Result<()> {
476 let sub_key = Self::subscription_key(&channel, symbol.as_ref());
477 let subscription = Subscription {
478 channel: channel.clone(),
479 symbol: symbol.clone(),
480 params: params.clone(),
481 };
482
483 self.subscription_manager
484 .try_add(sub_key.clone(), subscription)?;
485
486 if self.state() == WsConnectionState::Connected {
487 self.send_subscribe_message(channel, symbol, params).await?;
488 }
489
490 Ok(())
491 }
492
493 #[instrument(name = "ws_unsubscribe", skip(self), fields(channel = %channel))]
495 pub async fn unsubscribe(&self, channel: String, symbol: Option<String>) -> Result<()> {
496 let sub_key = Self::subscription_key(&channel, symbol.as_ref());
497 self.subscription_manager.remove(&sub_key);
498
499 if self.state() == WsConnectionState::Connected {
500 self.send_unsubscribe_message(channel, symbol).await?;
501 }
502
503 Ok(())
504 }
505
506 pub async fn receive(&self) -> Option<Value> {
508 let mut rx = self.message_rx.write().await;
509 rx.recv().await
510 }
511
512 #[inline]
514 pub fn state(&self) -> WsConnectionState {
515 WsConnectionState::from_u8(self.state.load(Ordering::Acquire))
516 }
517
518 #[inline]
520 pub fn config(&self) -> &WsConfig {
521 &self.config
522 }
523
524 #[inline]
526 pub fn set_state(&self, state: WsConnectionState) {
527 self.state.store(state.as_u8(), Ordering::Release);
528 }
529
530 #[inline]
532 pub fn is_connected(&self) -> bool {
533 self.state() == WsConnectionState::Connected
534 }
535
536 pub fn is_subscribed(&self, channel: &str, symbol: Option<&String>) -> bool {
538 let sub_key = Self::subscription_key(channel, symbol);
539 self.subscription_manager.contains(&sub_key)
540 }
541
542 pub fn subscription_count(&self) -> usize {
544 self.subscription_manager.count()
545 }
546
547 pub fn remaining_capacity(&self) -> usize {
549 self.subscription_manager.remaining_capacity()
550 }
551
552 pub fn subscriptions(&self) -> Vec<String> {
557 self.subscription_manager
558 .iter()
559 .map(|entry| {
560 let sub = entry.value();
561 match &sub.symbol {
562 Some(sym) => format!("{}:{}", sub.channel, sym),
563 None => sub.channel.clone(),
564 }
565 })
566 .collect()
567 }
568
569 #[instrument(name = "ws_send", skip(self, message))]
574 pub async fn send(&self, message: Message) -> Result<()> {
575 let tx = self.write_tx.read().await;
576
577 if let Some(sender) = tx.as_ref() {
578 sender
579 .send(message)
580 .await
581 .map_err(|e| Error::network(format!("Failed to send message: {e}")))?;
582 Ok(())
583 } else {
584 Err(Error::network("WebSocket not connected"))
585 }
586 }
587
588 #[instrument(name = "ws_try_send", skip(self, message))]
592 pub fn try_send(&self, message: Message) -> Result<()> {
593 if let Ok(tx) = self.write_tx.try_read() {
596 if let Some(sender) = tx.as_ref() {
597 sender.try_send(message).map_err(|e| match e {
598 mpsc::error::TrySendError::Full(_) => {
599 Error::network("Write channel full (backpressure)")
600 }
601 mpsc::error::TrySendError::Closed(_) => {
602 Error::network("WebSocket channel closed")
603 }
604 })?;
605 Ok(())
606 } else {
607 Err(Error::network("WebSocket not connected"))
608 }
609 } else {
610 Err(Error::network("Write channel busy"))
611 }
612 }
613
614 #[instrument(name = "ws_send_text", skip(self, text))]
616 pub async fn send_text(&self, text: String) -> Result<()> {
617 self.send(Message::Text(text.into())).await
618 }
619
620 #[instrument(name = "ws_send_json", skip(self, json))]
622 pub async fn send_json(&self, json: &Value) -> Result<()> {
623 let text = serde_json::to_string(json).map_err(Error::from)?;
624 self.send_text(text).await
625 }
626
627 fn subscription_key(channel: &str, symbol: Option<&String>) -> String {
628 match symbol {
629 Some(s) => format!("{channel}:{s}"),
630 None => channel.to_string(),
631 }
632 }
633
634 async fn start_message_loop(&self, ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) {
635 let (write, mut read) = ws_stream.split();
636
637 let (write_tx, mut write_rx) = mpsc::channel::<Message>(self.config.write_channel_capacity);
639 *self.write_tx.write().await = Some(write_tx.clone());
640
641 let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>();
643 *self.shutdown_tx.lock().await = Some(shutdown_tx);
644
645 let state = Arc::clone(&self.state);
646 let message_tx = self.message_tx.clone();
647 let ping_interval_ms = self.config.ping_interval;
648 let backpressure_strategy = self.config.backpressure_strategy;
649 let dropped_messages = Arc::clone(&self.dropped_messages);
650
651 let write_handle = tokio::spawn(async move {
652 let mut write = write;
653 loop {
654 tokio::select! {
655 Some(msg) = write_rx.recv() => {
656 if let Err(e) = write.send(msg).await {
657 error!(error = %e, "Failed to write message");
658 break;
659 }
660 }
661 _ = shutdown_rx.recv() => {
662 let _ = write.send(Message::Close(None)).await;
663 break;
664 }
665 }
666 }
667 });
668
669 let state_clone = Arc::clone(&state);
670 let ws_stats = Arc::clone(&self.stats);
671 let read_handle = tokio::spawn(async move {
672 while let Some(msg_result) = read.next().await {
673 match msg_result {
674 Ok(Message::Text(text)) => {
675 ws_stats.record_received(text.len() as u64);
676 if let Ok(json) = serde_json::from_str::<Value>(&text) {
677 Self::send_with_backpressure(
678 &message_tx,
679 json,
680 backpressure_strategy,
681 &dropped_messages,
682 )
683 .await;
684 }
685 }
686 Ok(Message::Binary(data)) => {
687 ws_stats.record_received(data.len() as u64);
688 if let Some(json) = String::from_utf8(data.to_vec())
689 .ok()
690 .and_then(|text| serde_json::from_str::<Value>(&text).ok())
691 {
692 Self::send_with_backpressure(
693 &message_tx,
694 json,
695 backpressure_strategy,
696 &dropped_messages,
697 )
698 .await;
699 }
700 }
701 Ok(Message::Pong(_)) => {
702 ws_stats.record_pong();
703 }
704 Ok(Message::Close(_)) => {
705 state_clone
706 .store(WsConnectionState::Disconnected.as_u8(), Ordering::Release);
707 break;
708 }
709 Err(_) => {
710 state_clone.store(WsConnectionState::Error.as_u8(), Ordering::Release);
711 break;
712 }
713 _ => {}
714 }
715 }
716 });
717
718 if ping_interval_ms > 0 {
719 let write_tx_clone = write_tx.clone();
720 let ping_stats = Arc::clone(&self.stats);
721 let ping_state = Arc::clone(&state);
722 let pong_timeout_ms = self.config.pong_timeout;
723
724 tokio::spawn(async move {
725 let mut interval = interval(Duration::from_millis(ping_interval_ms));
726
727 loop {
728 interval.tick().await;
729
730 let now = chrono::Utc::now().timestamp_millis();
731 let last_pong = ping_stats.last_pong_time();
732
733 if last_pong > 0 {
734 let elapsed = now - last_pong;
735 #[allow(clippy::cast_possible_wrap)]
736 if elapsed > pong_timeout_ms as i64 {
737 error!(
739 pong_timeout_ms = pong_timeout_ms,
740 elapsed_ms = elapsed,
741 last_pong_time = last_pong,
742 current_time = now,
743 "WebSocket pong timeout detected - connection appears unresponsive (zombie connection)"
744 );
745 ping_state.store(WsConnectionState::Error.as_u8(), Ordering::Release);
746 debug!(
747 "WebSocket state set to Error due to pong timeout - AutoReconnectCoordinator will trigger reconnection if enabled"
748 );
749 break;
750 }
751 }
752
753 ping_stats.record_ping();
754
755 if write_tx_clone
757 .try_send(Message::Ping(vec![].into()))
758 .is_err()
759 {
760 debug!("WebSocket write channel closed, stopping ping loop");
761 break;
762 }
763 }
764 });
765 }
766
767 tokio::spawn(async move {
768 let _ = tokio::join!(write_handle, read_handle);
769 });
770 }
771
772 async fn send_with_backpressure(
777 tx: &mpsc::Sender<Value>,
778 message: Value,
779 strategy: BackpressureStrategy,
780 dropped_counter: &Arc<AtomicU32>,
781 ) {
782 match strategy {
783 BackpressureStrategy::Block => {
784 if tx.send(message).await.is_err() {
786 warn!("Message channel closed");
787 }
788 }
789 BackpressureStrategy::DropNewest => {
790 match tx.try_send(message) {
792 Ok(()) => {}
793 Err(mpsc::error::TrySendError::Full(_)) => {
794 let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
795 if count % 100 == 1 {
796 warn!(
798 dropped_count = count,
799 "Message channel full, dropping newest message (backpressure)"
800 );
801 }
802 }
803 Err(mpsc::error::TrySendError::Closed(_)) => {
804 warn!("Message channel closed");
805 }
806 }
807 }
808 BackpressureStrategy::DropOldest => {
809 match tx.try_send(message) {
811 Ok(()) => {}
812 Err(mpsc::error::TrySendError::Full(msg)) => {
813 let count = dropped_counter.fetch_add(1, Ordering::Relaxed) + 1;
817 if count % 100 == 1 {
818 warn!(
819 dropped_count = count,
820 "Message channel full, dropping oldest message (backpressure)"
821 );
822 }
823 drop(msg);
828 }
829 Err(mpsc::error::TrySendError::Closed(_)) => {
830 warn!("Message channel closed");
831 }
832 }
833 }
834 }
835 }
836
837 async fn send_subscribe_message(
838 &self,
839 channel: String,
840 symbol: Option<String>,
841 params: Option<HashMap<String, Value>>,
842 ) -> Result<()> {
843 let msg = WsMessage::Subscribe {
844 channel,
845 symbol,
846 params,
847 };
848 let json = serde_json::to_value(&msg).map_err(Error::from)?;
849 self.send_json(&json).await
850 }
851
852 async fn send_unsubscribe_message(
853 &self,
854 channel: String,
855 symbol: Option<String>,
856 ) -> Result<()> {
857 let msg = WsMessage::Unsubscribe { channel, symbol };
858 let json = serde_json::to_value(&msg).map_err(Error::from)?;
859 self.send_json(&json).await
860 }
861
862 pub(crate) async fn resubscribe_all(&self) -> Result<()> {
863 let subs = self.subscription_manager.collect_subscriptions();
864 for subscription in subs {
865 self.send_subscribe_message(
866 subscription.channel.clone(),
867 subscription.symbol.clone(),
868 subscription.params.clone(),
869 )
870 .await?;
871 }
872 Ok(())
873 }
874}
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client pointed at a placeholder URL; none of these tests open a connection.
    fn offline_client() -> WsClient {
        WsClient::new(WsConfig {
            url: "wss://example.com/ws".to_string(),
            ..Default::default()
        })
    }

    #[test]
    fn test_backoff_config_default() {
        let defaults = BackoffConfig::default();
        assert_eq!(defaults.base_delay, Duration::from_secs(1));
        assert_eq!(defaults.max_delay, Duration::from_secs(60));
    }

    #[test]
    fn test_backoff_strategy_exponential_growth_no_jitter() {
        let strategy = BackoffStrategy::new(BackoffConfig {
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            jitter_factor: 0.0,
            multiplier: 2.0,
        });

        // (attempt index, expected delay in seconds); the last row hits the cap.
        for (attempt, secs) in [(0, 1), (1, 2), (2, 4), (6, 60)] {
            assert_eq!(strategy.calculate_delay(attempt), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_ws_config_default() {
        let defaults = WsConfig::default();
        assert_eq!(defaults.connect_timeout, 10000);
        assert_eq!(defaults.max_subscriptions, DEFAULT_MAX_SUBSCRIPTIONS);
    }

    #[test]
    fn test_subscription_key() {
        let symbol = "BTC/USDT".to_string();
        assert_eq!(
            WsClient::subscription_key("ticker", Some(&symbol)),
            "ticker:BTC/USDT"
        );
        assert_eq!(WsClient::subscription_key("trades", None), "trades");
    }

    #[tokio::test]
    async fn test_ws_client_creation() {
        let client = offline_client();
        assert_eq!(client.state(), WsConnectionState::Disconnected);
        assert!(!client.is_connected());
    }

    #[tokio::test]
    async fn test_subscribe_adds_subscription() {
        let client = offline_client();
        let symbol = "BTC/USDT".to_string();

        let outcome = client
            .subscribe("ticker".to_string(), Some(symbol.clone()), None)
            .await;

        assert!(outcome.is_ok());
        assert_eq!(client.subscription_count(), 1);
        assert!(client.is_subscribed("ticker", Some(&symbol)));
    }

    #[test]
    fn test_ws_connection_state_from_u8() {
        let expectations = [
            (0, WsConnectionState::Disconnected),
            (1, WsConnectionState::Connecting),
            (2, WsConnectionState::Connected),
            (255, WsConnectionState::Error),
        ];
        for (raw, expected) in expectations {
            assert_eq!(WsConnectionState::from_u8(raw), expected);
        }
    }

    #[test]
    fn test_ws_error_kind() {
        assert!(WsErrorKind::Transient.is_transient());
        assert!(WsErrorKind::Permanent.is_permanent());
    }

    #[test]
    fn test_ws_error_creation() {
        let transient = WsError::transient("Connection timeout");
        assert!(transient.is_transient());
        assert_eq!(transient.message(), "Connection timeout");

        let permanent = WsError::permanent("Invalid API key");
        assert!(permanent.is_permanent());
    }

    #[test]
    fn test_subscription_manager() {
        let manager = SubscriptionManager::new(2);
        assert_eq!(manager.max_subscriptions(), 2);
        assert_eq!(manager.count(), 0);
        assert!(!manager.is_full());

        let ticker = Subscription {
            channel: "ticker".to_string(),
            symbol: Some("BTC/USDT".to_string()),
            params: None,
        };
        let added = manager.try_add("ticker:BTC/USDT".to_string(), ticker);
        assert!(added.is_ok());
        assert_eq!(manager.count(), 1);
    }
}