1use crate::error::{Error, Result};
16use futures_util::{SinkExt, StreamExt, stream::SplitSink};
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::sync::atomic::{AtomicBool, Ordering};
22use tokio::net::TcpStream;
23use tokio::sync::{Mutex, RwLock, mpsc};
24use tokio::task::JoinHandle;
25use tokio::time::{Duration, interval};
26use tokio_tungstenite::{
27 MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message,
28};
29use tracing::{debug, error, info, instrument, warn};
30
/// Lifecycle state of a [`WsClient`] connection.
///
/// The client stores this behind a lock and updates it from `connect`, `disconnect`,
/// `reconnect` and its background read/ping tasks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsConnectionState {
    /// No connection; the initial state and the state after a close frame or `disconnect`.
    Disconnected,
    /// A handshake is in progress.
    Connecting,
    /// The handshake succeeded and the background tasks are running.
    Connected,
    /// Waiting out the reconnect delay before calling `connect` again.
    Reconnecting,
    /// The last connect attempt failed, a read error occurred, or a pong timeout was detected.
    Error,
}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(tag = "type", rename_all = "lowercase")]
49pub enum WsMessage {
50 Subscribe {
52 channel: String,
54 symbol: Option<String>,
56 params: Option<HashMap<String, Value>>,
58 },
59 Unsubscribe {
61 channel: String,
63 symbol: Option<String>,
65 },
66 Ping {
68 timestamp: i64,
70 },
71 Pong {
73 timestamp: i64,
75 },
76 Auth {
78 api_key: String,
80 signature: String,
82 timestamp: i64,
84 },
85 Custom(Value),
87}
88
/// Connection settings for a [`WsClient`]. All durations are in milliseconds.
#[derive(Clone, Debug)]
pub struct WsConfig {
    /// WebSocket endpoint (`ws://` or `wss://`).
    pub url: String,
    /// Maximum time allowed for the connect handshake, in ms.
    pub connect_timeout: u64,
    /// Interval between protocol-level pings, in ms; `0` disables the ping task.
    pub ping_interval: u64,
    /// Delay before each reconnect attempt, in ms.
    pub reconnect_interval: u64,
    /// Number of consecutive reconnect attempts allowed before `reconnect` gives up.
    pub max_reconnect_attempts: u32,
    /// Auto-reconnect preference flag (not consulted by `WsClient` itself in this file).
    pub auto_reconnect: bool,
    /// Compression preference flag (not consulted by `WsClient` itself in this file).
    pub enable_compression: bool,
    /// How long after the last pong the connection is considered dead, in ms.
    pub pong_timeout: u64,
}
111
112impl Default for WsConfig {
113 fn default() -> Self {
114 Self {
115 url: String::new(),
116 connect_timeout: 10000,
117 ping_interval: 30000,
118 reconnect_interval: 5000,
119 max_reconnect_attempts: 5,
120 auto_reconnect: true,
121 enable_compression: false,
122 pong_timeout: 90000,
123 }
124 }
125}
126
127#[derive(Debug, Clone)]
129pub struct Subscription {
130 channel: String,
131 symbol: Option<String>,
132 params: Option<HashMap<String, Value>>,
133}
134
135#[allow(dead_code)]
137type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
138
139pub struct WsClient {
141 config: WsConfig,
142 state: Arc<RwLock<WsConnectionState>>,
143 subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
144
145 message_tx: mpsc::UnboundedSender<Value>,
146 message_rx: Arc<RwLock<mpsc::UnboundedReceiver<Value>>>,
147
148 write_tx: Arc<Mutex<Option<mpsc::UnboundedSender<Message>>>>,
149
150 reconnect_count: Arc<RwLock<u32>>,
151
152 shutdown_tx: Arc<Mutex<Option<mpsc::UnboundedSender<()>>>>,
153
154 stats: Arc<RwLock<WsStats>>,
155}
156
/// Connection statistics. Timestamps are Unix epoch milliseconds; `0` means "not yet set".
#[derive(Clone, Debug, Default)]
pub struct WsStats {
    /// Number of incoming text/binary messages.
    pub messages_received: u64,
    /// Number of outgoing messages.
    pub messages_sent: u64,
    /// Payload bytes of incoming text/binary messages.
    pub bytes_received: u64,
    /// Payload bytes of outgoing messages.
    pub bytes_sent: u64,
    /// Time the last text/binary message arrived.
    pub last_message_time: i64,
    /// Time the last keep-alive ping was sent.
    pub last_ping_time: i64,
    /// Time the last pong was received.
    pub last_pong_time: i64,
    /// Time the current connection was established.
    pub connected_at: i64,
    /// Reconnect attempt counter.
    pub reconnect_attempts: u32,
}
179
180impl WsClient {
181 pub fn new(config: WsConfig) -> Self {
191 let (message_tx, message_rx) = mpsc::unbounded_channel();
192
193 Self {
194 config,
195 state: Arc::new(RwLock::new(WsConnectionState::Disconnected)),
196 subscriptions: Arc::new(RwLock::new(HashMap::new())),
197 message_tx,
198 message_rx: Arc::new(RwLock::new(message_rx)),
199 write_tx: Arc::new(Mutex::new(None)),
200 reconnect_count: Arc::new(RwLock::new(0)),
201 shutdown_tx: Arc::new(Mutex::new(None)),
202 stats: Arc::new(RwLock::new(WsStats::default())),
203 }
204 }
205
206 #[instrument(
218 name = "ws_connect",
219 skip(self),
220 fields(url = %self.config.url, timeout_ms = self.config.connect_timeout)
221 )]
222 pub async fn connect(&self) -> Result<()> {
223 {
224 let state = self.state.read().await;
225 if *state == WsConnectionState::Connected {
226 info!("WebSocket already connected");
227 return Ok(());
228 }
229 }
230
231 {
232 let mut state = self.state.write().await;
233 *state = WsConnectionState::Connecting;
234 }
235
236 let url = self.config.url.clone();
237 info!("Initiating WebSocket connection");
238
239 match tokio::time::timeout(
240 Duration::from_millis(self.config.connect_timeout),
241 connect_async(&url),
242 )
243 .await
244 {
245 Ok(Ok((ws_stream, response))) => {
246 info!(
247 status = response.status().as_u16(),
248 "WebSocket connection established successfully"
249 );
250
251 *self.state.write().await = WsConnectionState::Connected;
252 *self.reconnect_count.write().await = 0;
253
254 {
255 let mut stats = self.stats.write().await;
256 stats.connected_at = chrono::Utc::now().timestamp_millis();
257 }
258
259 self.start_message_loop(ws_stream).await;
260
261 self.resubscribe_all().await?;
262
263 Ok(())
264 }
265 Ok(Err(e)) => {
266 error!(
267 error = %e,
268 error_debug = ?e,
269 "WebSocket connection failed"
270 );
271 *self.state.write().await = WsConnectionState::Error;
272 Err(Error::network(format!(
273 "WebSocket connection failed: {}",
274 e
275 )))
276 }
277 Err(_) => {
278 error!(
279 timeout_ms = self.config.connect_timeout,
280 "WebSocket connection timeout exceeded"
281 );
282 *self.state.write().await = WsConnectionState::Error;
283 Err(Error::timeout("WebSocket connection timeout"))
284 }
285 }
286 }
287
288 #[instrument(name = "ws_disconnect", skip(self))]
292 pub async fn disconnect(&self) -> Result<()> {
293 info!("Initiating WebSocket disconnect");
294
295 if let Some(tx) = self.shutdown_tx.lock().await.as_ref() {
296 let _ = tx.send(());
297 debug!("Shutdown signal sent to background tasks");
298 }
299
300 *self.write_tx.lock().await = None;
301
302 let mut state = self.state.write().await;
303 *state = WsConnectionState::Disconnected;
304
305 info!("WebSocket disconnected successfully");
306 Ok(())
307 }
308
309 #[instrument(
318 name = "ws_reconnect",
319 skip(self),
320 fields(
321 max_attempts = self.config.max_reconnect_attempts,
322 reconnect_interval_ms = self.config.reconnect_interval
323 )
324 )]
325 pub async fn reconnect(&self) -> Result<()> {
326 let mut count = self.reconnect_count.write().await;
327
328 if *count >= self.config.max_reconnect_attempts {
329 error!(
330 attempts = *count,
331 max = self.config.max_reconnect_attempts,
332 "Max reconnect attempts reached, giving up"
333 );
334 return Err(Error::network("Max reconnect attempts reached"));
335 }
336
337 *count += 1;
338
339 warn!(
340 attempt = *count,
341 max = self.config.max_reconnect_attempts,
342 delay_ms = self.config.reconnect_interval,
343 "Attempting WebSocket reconnection"
344 );
345
346 *self.state.write().await = WsConnectionState::Reconnecting;
347
348 tokio::time::sleep(Duration::from_millis(self.config.reconnect_interval)).await;
349
350 self.connect().await
351 }
352
353 pub async fn reconnect_count(&self) -> u32 {
355 *self.reconnect_count.read().await
356 }
357
358 pub async fn reset_reconnect_count(&self) {
360 *self.reconnect_count.write().await = 0;
361 debug!("Reconnect count reset");
362 }
363
364 pub async fn stats(&self) -> WsStats {
366 self.stats.read().await.clone()
367 }
368
369 pub async fn reset_stats(&self) {
371 *self.stats.write().await = WsStats::default();
372 debug!("Stats reset");
373 }
374
375 pub async fn latency(&self) -> Option<i64> {
381 let stats = self.stats.read().await;
382 if stats.last_pong_time > 0 && stats.last_ping_time > 0 {
383 Some(stats.last_pong_time - stats.last_ping_time)
384 } else {
385 None
386 }
387 }
388
389 pub fn create_auto_reconnect_coordinator(self: Arc<Self>) -> AutoReconnectCoordinator {
395 AutoReconnectCoordinator::new(self)
396 }
397
398 #[instrument(
412 name = "ws_subscribe",
413 skip(self, params),
414 fields(channel = %channel, symbol = ?symbol)
415 )]
416 pub async fn subscribe(
417 &self,
418 channel: String,
419 symbol: Option<String>,
420 params: Option<HashMap<String, Value>>,
421 ) -> Result<()> {
422 let sub_key = Self::subscription_key(&channel, &symbol);
423 let subscription = Subscription {
424 channel: channel.clone(),
425 symbol: symbol.clone(),
426 params: params.clone(),
427 };
428
429 {
430 let mut subs = self.subscriptions.write().await;
431 subs.insert(sub_key.clone(), subscription);
432 }
433
434 info!(subscription_key = %sub_key, "Subscription registered");
435
436 let state = *self.state.read().await;
437 if state == WsConnectionState::Connected {
438 self.send_subscribe_message(channel, symbol, params).await?;
439 info!(subscription_key = %sub_key, "Subscription message sent");
440 } else {
441 debug!(
442 subscription_key = %sub_key,
443 state = ?state,
444 "Subscription queued (not connected)"
445 );
446 }
447
448 Ok(())
449 }
450
451 #[instrument(
464 name = "ws_unsubscribe",
465 skip(self),
466 fields(channel = %channel, symbol = ?symbol)
467 )]
468 pub async fn unsubscribe(&self, channel: String, symbol: Option<String>) -> Result<()> {
469 let sub_key = Self::subscription_key(&channel, &symbol);
470
471 {
472 let mut subs = self.subscriptions.write().await;
473 subs.remove(&sub_key);
474 }
475
476 info!(subscription_key = %sub_key, "Subscription removed");
477
478 let state = *self.state.read().await;
479 if state == WsConnectionState::Connected {
480 self.send_unsubscribe_message(channel, symbol).await?;
481 info!(subscription_key = %sub_key, "Unsubscribe message sent");
482 }
483
484 Ok(())
485 }
486
487 pub async fn receive(&self) -> Option<Value> {
493 let mut rx = self.message_rx.write().await;
494 rx.recv().await
495 }
496
497 pub async fn state(&self) -> WsConnectionState {
499 *self.state.read().await
500 }
501
502 pub async fn is_connected(&self) -> bool {
504 *self.state.read().await == WsConnectionState::Connected
505 }
506
507 #[instrument(name = "ws_send", skip(self, message))]
517 pub async fn send(&self, message: Message) -> Result<()> {
518 let tx = self.write_tx.lock().await;
519
520 if let Some(sender) = tx.as_ref() {
521 sender.send(message).map_err(|e| {
522 error!(
523 error = %e,
524 "Failed to send WebSocket message"
525 );
526 Error::network(format!("Failed to send message: {}", e))
527 })?;
528 debug!("WebSocket message sent successfully");
529 Ok(())
530 } else {
531 warn!("WebSocket not connected, cannot send message");
532 Err(Error::network("WebSocket not connected"))
533 }
534 }
535
536 #[instrument(name = "ws_send_text", skip(self, text), fields(text_len = text.len()))]
546 pub async fn send_text(&self, text: String) -> Result<()> {
547 self.send(Message::Text(text.into())).await
548 }
549
550 #[instrument(name = "ws_send_json", skip(self, json))]
560 pub async fn send_json(&self, json: &Value) -> Result<()> {
561 let text = serde_json::to_string(json).map_err(|e| {
562 error!(error = %e, "Failed to serialize JSON for WebSocket");
563 Error::from(e)
564 })?;
565 self.send_text(text).await
566 }
567
568 fn subscription_key(channel: &str, symbol: &Option<String>) -> String {
570 match symbol {
571 Some(s) => format!("{}:{}", channel, s),
572 None => channel.to_string(),
573 }
574 }
575
576 async fn start_message_loop(&self, ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>) {
580 let (write, mut read) = ws_stream.split();
581
582 let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Message>();
583 *self.write_tx.lock().await = Some(write_tx.clone());
584
585 let (shutdown_tx, mut shutdown_rx) = mpsc::unbounded_channel::<()>();
586 *self.shutdown_tx.lock().await = Some(shutdown_tx);
587
588 let state = Arc::clone(&self.state);
589 let message_tx = self.message_tx.clone();
590 let ping_interval_ms = self.config.ping_interval;
591
592 info!("Starting WebSocket message loop");
593
594 let write_handle = tokio::spawn(async move {
595 let mut write = write;
596 loop {
597 tokio::select! {
598 Some(msg) = write_rx.recv() => {
599 if let Err(e) = write.send(msg).await {
600 error!(error = %e, "Failed to write message");
601 break;
602 }
603 }
604 _ = shutdown_rx.recv() => {
605 debug!("Write task received shutdown signal");
606 let _ = write.send(Message::Close(None)).await;
607 break;
608 }
609 }
610 }
611 debug!("Write task terminated");
612 });
613
614 let state_clone = Arc::clone(&state);
615 let ws_stats = Arc::clone(&self.stats);
616 let read_handle = tokio::spawn(async move {
617 debug!("Starting WebSocket read task");
618 while let Some(msg_result) = read.next().await {
619 match msg_result {
620 Ok(Message::Text(text)) => {
621 debug!(len = text.len(), "Received text message");
622
623 {
624 let mut stats_guard = ws_stats.write().await;
625 stats_guard.messages_received += 1;
626 stats_guard.bytes_received += text.len() as u64;
627 stats_guard.last_message_time = chrono::Utc::now().timestamp_millis();
628 }
629
630 match serde_json::from_str::<Value>(&text) {
631 Ok(json) => {
632 let _ = message_tx.send(json);
633 }
634 Err(e) => {
635 let raw_preview: String = text.chars().take(200).collect();
637 warn!(
638 error = %e,
639 raw_message_preview = %raw_preview,
640 raw_message_len = text.len(),
641 "Failed to parse WebSocket text message as JSON"
642 );
643 }
644 }
645 }
646 Ok(Message::Binary(data)) => {
647 debug!(len = data.len(), "Received binary message");
648
649 {
650 let mut stats_guard = ws_stats.write().await;
651 stats_guard.messages_received += 1;
652 stats_guard.bytes_received += data.len() as u64;
653 stats_guard.last_message_time = chrono::Utc::now().timestamp_millis();
654 }
655
656 match String::from_utf8(data.to_vec()) {
657 Ok(text) => {
658 match serde_json::from_str::<Value>(&text) {
659 Ok(json) => {
660 let _ = message_tx.send(json);
661 }
662 Err(e) => {
663 let raw_preview: String = text.chars().take(200).collect();
665 warn!(
666 error = %e,
667 raw_message_preview = %raw_preview,
668 raw_message_len = text.len(),
669 "Failed to parse WebSocket binary message as JSON"
670 );
671 }
672 }
673 }
674 Err(e) => {
675 let hex_preview: String = data
677 .iter()
678 .take(50)
679 .map(|b| format!("{:02x}", b))
680 .collect::<Vec<_>>()
681 .join(" ");
682 warn!(
683 error = %e,
684 hex_preview = %hex_preview,
685 data_len = data.len(),
686 "Failed to decode WebSocket binary message as UTF-8"
687 );
688 }
689 }
690 }
691 Ok(Message::Ping(_)) => {
692 debug!("Received ping, auto-responding with pong");
693 }
694 Ok(Message::Pong(_)) => {
695 debug!("Received pong");
696
697 {
698 let mut stats_guard = ws_stats.write().await;
699 stats_guard.last_pong_time = chrono::Utc::now().timestamp_millis();
700 }
701 }
702 Ok(Message::Close(frame)) => {
703 info!(
704 close_frame = ?frame,
705 "Received WebSocket close frame"
706 );
707 *state_clone.write().await = WsConnectionState::Disconnected;
708 break;
709 }
710 Err(e) => {
711 error!(
712 error = %e,
713 error_debug = ?e,
714 "WebSocket read error"
715 );
716 *state_clone.write().await = WsConnectionState::Error;
717 break;
718 }
719 _ => {
720 debug!("Received other WebSocket message type");
721 }
722 }
723 }
724 debug!("WebSocket read task terminated");
725 });
726
727 if ping_interval_ms > 0 {
728 let write_tx_clone = write_tx.clone();
729 let ping_stats = Arc::clone(&self.stats);
730 let ping_state = Arc::clone(&state);
731 let pong_timeout_ms = self.config.pong_timeout;
732
733 tokio::spawn(async move {
734 let mut interval = interval(Duration::from_millis(ping_interval_ms));
735 debug!(
736 interval_ms = ping_interval_ms,
737 timeout_ms = pong_timeout_ms,
738 "Starting ping task with timeout detection"
739 );
740
741 loop {
742 interval.tick().await;
743
744 let now = chrono::Utc::now().timestamp_millis();
745 let last_pong = {
746 let stats_guard = ping_stats.read().await;
747 stats_guard.last_pong_time
748 };
749
750 if last_pong > 0 {
751 let elapsed = now - last_pong;
752 #[allow(clippy::cast_possible_wrap)]
753 if elapsed > pong_timeout_ms as i64 {
754 warn!(
755 elapsed_ms = elapsed,
756 timeout_ms = pong_timeout_ms,
757 "Pong timeout detected, marking connection as error"
758 );
759 *ping_state.write().await = WsConnectionState::Error;
760 break;
761 }
762 }
763
764 {
765 let mut stats_guard = ping_stats.write().await;
766 stats_guard.last_ping_time = now;
767 }
768
769 if write_tx_clone.send(Message::Ping(vec![].into())).is_err() {
770 debug!("Ping task: write channel closed");
771 break;
772 }
773 debug!("Sent ping");
774 }
775 debug!("Ping task terminated");
776 });
777 }
778
779 tokio::spawn(async move {
780 let _ = tokio::join!(write_handle, read_handle);
781 info!("All WebSocket tasks completed");
782 });
783 }
784
785 #[instrument(
787 name = "ws_send_subscribe",
788 skip(self, params),
789 fields(channel = %channel, symbol = ?symbol)
790 )]
791 async fn send_subscribe_message(
792 &self,
793 channel: String,
794 symbol: Option<String>,
795 params: Option<HashMap<String, Value>>,
796 ) -> Result<()> {
797 let msg = WsMessage::Subscribe {
798 channel: channel.clone(),
799 symbol: symbol.clone(),
800 params,
801 };
802
803 let json = serde_json::to_value(&msg).map_err(|e| {
804 error!(error = %e, "Failed to serialize subscribe message");
805 Error::from(e)
806 })?;
807
808 debug!("Sending subscribe message to server");
809
810 self.send_json(&json).await?;
811 info!("Subscribe message sent successfully");
812 Ok(())
813 }
814
815 #[instrument(
817 name = "ws_send_unsubscribe",
818 skip(self),
819 fields(channel = %channel, symbol = ?symbol)
820 )]
821 async fn send_unsubscribe_message(
822 &self,
823 channel: String,
824 symbol: Option<String>,
825 ) -> Result<()> {
826 let msg = WsMessage::Unsubscribe {
827 channel: channel.clone(),
828 symbol: symbol.clone(),
829 };
830
831 let json = serde_json::to_value(&msg).map_err(|e| {
832 error!(error = %e, "Failed to serialize unsubscribe message");
833 Error::from(e)
834 })?;
835
836 debug!("Sending unsubscribe message to server");
837
838 self.send_json(&json).await?;
839 info!("Unsubscribe message sent successfully");
840 Ok(())
841 }
842
843 async fn resubscribe_all(&self) -> Result<()> {
845 let subs = self.subscriptions.read().await;
846 for subscription in subs.values() {
847 self.send_subscribe_message(
848 subscription.channel.clone(),
849 subscription.symbol.clone(),
850 subscription.params.clone(),
851 )
852 .await?;
853 }
854 Ok(())
855 }
856}
/// Notifications emitted by [`AutoReconnectCoordinator`] through its optional callback.
#[derive(Clone, Debug)]
pub enum WsEvent {
    /// Connection established.
    Connected,
    /// Connection closed.
    Disconnected,
    /// A reconnect attempt is about to start.
    Reconnecting {
        /// 1-based number of the attempt.
        attempt: u32,
    },
    /// A reconnect attempt succeeded.
    ReconnectSuccess,
    /// A reconnect attempt failed.
    ReconnectFailed {
        /// Display text of the failure.
        error: String,
    },
    /// Subscriptions are active again after a reconnect.
    SubscriptionRestored,
}
879
880pub type WsEventCallback = Arc<dyn Fn(WsEvent) + Send + Sync>;
882
883pub struct AutoReconnectCoordinator {
887 client: Arc<WsClient>,
888 enabled: Arc<AtomicBool>,
889 reconnect_task: Arc<Mutex<Option<JoinHandle<()>>>>,
890 event_callback: Option<WsEventCallback>,
891}
892
893impl AutoReconnectCoordinator {
894 pub fn new(client: Arc<WsClient>) -> Self {
900 Self {
901 client,
902 enabled: Arc::new(AtomicBool::new(false)),
903 reconnect_task: Arc::new(Mutex::new(None)),
904 event_callback: None,
905 }
906 }
907
908 pub fn with_callback(mut self, callback: WsEventCallback) -> Self {
916 self.event_callback = Some(callback);
917 self
918 }
919
920 pub async fn start(&self) {
924 if self.enabled.swap(true, Ordering::SeqCst) {
925 info!("Auto-reconnect already started");
926 return;
927 }
928
929 info!("Starting auto-reconnect coordinator");
930
931 let client = Arc::clone(&self.client);
932 let enabled = Arc::clone(&self.enabled);
933 let callback = self.event_callback.clone();
934
935 let handle = tokio::spawn(async move {
936 Self::reconnect_loop(client, enabled, callback).await;
937 });
938
939 *self.reconnect_task.lock().await = Some(handle);
940 }
941
942 pub async fn stop(&self) {
946 if !self.enabled.swap(false, Ordering::SeqCst) {
947 info!("Auto-reconnect already stopped");
948 return;
949 }
950
951 info!("Stopping auto-reconnect coordinator");
952
953 let mut task = self.reconnect_task.lock().await;
954 if let Some(handle) = task.take() {
955 handle.abort();
956 }
957 }
958
959 async fn reconnect_loop(
964 client: Arc<WsClient>,
965 enabled: Arc<AtomicBool>,
966 callback: Option<WsEventCallback>,
967 ) {
968 let mut check_interval = interval(Duration::from_secs(1));
969
970 loop {
971 check_interval.tick().await;
972
973 if !enabled.load(Ordering::SeqCst) {
974 debug!("Auto-reconnect disabled, exiting loop");
975 break;
976 }
977
978 let state = client.state().await;
979
980 if matches!(
981 state,
982 WsConnectionState::Disconnected | WsConnectionState::Error
983 ) {
984 let attempt = client.reconnect_count().await;
985
986 info!(
987 attempt = attempt + 1,
988 state = ?state,
989 "Connection lost, attempting reconnect"
990 );
991
992 if let Some(ref cb) = callback {
993 cb(WsEvent::Reconnecting {
994 attempt: attempt + 1,
995 });
996 }
997
998 match client.reconnect().await {
999 Ok(_) => {
1000 info!("Reconnection successful");
1001
1002 if let Some(ref cb) = callback {
1003 cb(WsEvent::ReconnectSuccess);
1004 }
1005
1006 match client.resubscribe_all().await {
1007 Ok(_) => {
1008 info!("Subscriptions restored");
1009 if let Some(ref cb) = callback {
1010 cb(WsEvent::SubscriptionRestored);
1011 }
1012 }
1013 Err(e) => {
1014 error!(error = %e, "Failed to restore subscriptions");
1015 }
1016 }
1017 }
1018 Err(e) => {
1019 error!(error = %e, "Reconnection failed");
1020
1021 if let Some(ref cb) = callback {
1022 cb(WsEvent::ReconnectFailed {
1023 error: e.to_string(),
1024 });
1025 }
1026
1027 tokio::time::sleep(Duration::from_secs(5)).await;
1028 }
1029 }
1030 }
1031 }
1032
1033 info!("Auto-reconnect loop terminated");
1034 }
1035}
1036
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client aimed at a dummy endpoint; no connection is ever attempted.
    fn example_client() -> WsClient {
        WsClient::new(WsConfig {
            url: "wss://example.com/ws".to_string(),
            ..Default::default()
        })
    }

    #[test]
    fn test_ws_config_default() {
        let WsConfig {
            connect_timeout,
            ping_interval,
            reconnect_interval,
            max_reconnect_attempts,
            auto_reconnect,
            enable_compression,
            pong_timeout,
            ..
        } = WsConfig::default();

        assert_eq!(connect_timeout, 10000);
        assert_eq!(ping_interval, 30000);
        assert_eq!(reconnect_interval, 5000);
        assert_eq!(max_reconnect_attempts, 5);
        assert!(auto_reconnect);
        assert!(!enable_compression);
        assert_eq!(pong_timeout, 90000);
    }

    #[test]
    fn test_subscription_key() {
        let with_symbol = Some("BTC/USDT".to_string());
        assert_eq!(
            WsClient::subscription_key("ticker", &with_symbol),
            "ticker:BTC/USDT"
        );
        assert_eq!(WsClient::subscription_key("trades", &None), "trades");
    }

    #[tokio::test]
    async fn test_ws_client_creation() {
        let client = example_client();
        assert_eq!(client.state().await, WsConnectionState::Disconnected);
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_subscribe_adds_subscription() {
        let client = example_client();

        let outcome = client
            .subscribe("ticker".to_string(), Some("BTC/USDT".to_string()), None)
            .await;
        assert!(outcome.is_ok());

        let registered = client.subscriptions.read().await;
        assert_eq!(registered.len(), 1);
        assert!(registered.contains_key("ticker:BTC/USDT"));
    }

    #[tokio::test]
    async fn test_unsubscribe_removes_subscription() {
        let client = example_client();
        let symbol = || Some("BTC/USDT".to_string());

        client
            .subscribe("ticker".to_string(), symbol(), None)
            .await
            .unwrap();

        let outcome = client.unsubscribe("ticker".to_string(), symbol()).await;
        assert!(outcome.is_ok());

        assert!(client.subscriptions.read().await.is_empty());
    }

    #[test]
    fn test_ws_message_serialization() {
        let encoded = serde_json::to_string(&WsMessage::Subscribe {
            channel: "ticker".to_string(),
            symbol: Some("BTC/USDT".to_string()),
            params: None,
        })
        .unwrap();

        assert!(encoded.contains("\"type\":\"subscribe\""));
        assert!(encoded.contains("\"channel\":\"ticker\""));
    }
}