1use crate::models::{Channel, SubscribeRequest, WebSocketMessage, WebSocketRequest};
7use crate::websocket::connection_event::emit_event;
8use crate::websocket::protocol::{
9 frame_request, frame_subscribe, frame_subscribe_futopt, frame_unsubscribe,
10};
11use crate::websocket::sync::owner_thread::{
12 do_auth_handshake, do_blocking_connect, run_supervisor, OwnerShared, WRITE_QUEUE_CAPACITY,
13};
14use crate::websocket::{
15 ConnectionConfig, ConnectionEvent, ConnectionState, DisconnectIntent, HealthCheckConfig,
16 MessageReceiver, ReconnectionConfig, ReconnectionManager, SubscriptionManager,
17};
18use crate::MarketDataError;
19use std::sync::atomic::{AtomicBool, Ordering};
20use std::sync::{mpsc, Arc, Mutex, RwLock};
21use std::thread;
22use std::time::Duration;
23
24pub struct WebSocketClient {
29 shared: Arc<OwnerShared>,
30 event_rx: Arc<Mutex<mpsc::Receiver<ConnectionEvent>>>,
32 message_rx_slot: Mutex<Option<mpsc::Receiver<WebSocketMessage>>>,
34 message_receiver: Mutex<Option<Arc<MessageReceiver>>>,
36 supervisor_handle: Mutex<Option<thread::JoinHandle<()>>>,
38 supervisor_exit_rx: Mutex<Option<mpsc::Receiver<()>>>,
42}
43
/// How long [`WebSocketClient::disconnect`] waits for the supervisor thread to exit
/// before detaching it (five seconds).
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
47
48impl WebSocketClient {
49 pub fn new(config: ConnectionConfig) -> Self {
51 Self::with_full_config(config, ReconnectionConfig::default(), HealthCheckConfig::default())
52 }
53
54 pub fn with_reconnection_config(
56 config: ConnectionConfig,
57 reconnection_config: ReconnectionConfig,
58 ) -> Self {
59 Self::with_full_config(config, reconnection_config, HealthCheckConfig::default())
60 }
61
62 pub fn with_health_check_config(
64 config: ConnectionConfig,
65 health_check_config: HealthCheckConfig,
66 ) -> Self {
67 Self::with_full_config(config, ReconnectionConfig::default(), health_check_config)
68 }
69
70 pub fn with_full_config(
72 config: ConnectionConfig,
73 reconnection_config: ReconnectionConfig,
74 health_check_config: HealthCheckConfig,
75 ) -> Self {
76 let (event_tx, event_rx) = mpsc::sync_channel::<ConnectionEvent>(config.event_buffer);
77 let (message_tx, message_rx) = mpsc::sync_channel::<WebSocketMessage>(config.message_buffer);
78
79 let tls_config = crate::tls::build_rustls_config(&config.tls)
82 .unwrap_or_else(|e| panic!("Failed to build TLS config: {e}"));
83
84 let (messages_dropped, events_dropped) =
85 crate::metrics_compat::build_drop_counters(&config);
86
87 let shared = Arc::new(OwnerShared {
88 config,
89 tls_config,
90 health: health_check_config,
91 reconnection: Mutex::new(ReconnectionManager::new(reconnection_config)),
92 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
93 subscriptions: Arc::new(SubscriptionManager::new()),
94 event_tx,
95 message_tx,
96 write_tx_slot: Mutex::new(None),
97 should_stop: Arc::new(AtomicBool::new(false)),
98 messages_dropped,
99 events_dropped,
100 });
101
102 Self {
103 shared,
104 event_rx: Arc::new(Mutex::new(event_rx)),
105 message_rx_slot: Mutex::new(Some(message_rx)),
106 message_receiver: Mutex::new(None),
107 supervisor_handle: Mutex::new(None),
108 supervisor_exit_rx: Mutex::new(None),
109 }
110 }
111
112 pub fn state(&self) -> ConnectionState {
114 self.shared.state.read().expect("state lock poisoned").clone()
115 }
116
117 pub fn is_closed(&self) -> bool {
119 matches!(*self.shared.state.read().expect("state lock poisoned"), ConnectionState::Closed { .. })
120 }
121
122 pub fn is_connected(&self) -> bool {
124 matches!(*self.shared.state.read().expect("state lock poisoned"), ConnectionState::Connected)
125 }
126
127 pub fn events(&self) -> &Arc<Mutex<mpsc::Receiver<ConnectionEvent>>> {
130 &self.event_rx
131 }
132
133 pub fn state_events(&self) -> &Arc<Mutex<mpsc::Receiver<ConnectionEvent>>> {
135 &self.event_rx
136 }
137
138 pub fn messages(&self) -> Arc<MessageReceiver> {
141 let mut slot = self.message_receiver.lock().expect("message_receiver lock poisoned");
142 if let Some(rx) = slot.as_ref() {
143 return Arc::clone(rx);
144 }
145 let std_rx = self
146 .message_rx_slot
147 .lock()
148 .expect("message_rx_slot lock poisoned")
149 .take()
150 .expect("message receiver already taken");
151 let receiver = Arc::new(MessageReceiver::new(std_rx));
152 *slot = Some(Arc::clone(&receiver));
153 receiver
154 }
155
156 #[cfg_attr(
163 feature = "tracing",
164 tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.connect", skip(self))
165 )]
166 pub fn connect(&self) -> Result<(), MarketDataError> {
167 if self.is_closed() {
168 return Err(MarketDataError::ClientClosed);
169 }
170 if self.supervisor_handle.lock().expect("supervisor handle lock poisoned").is_some() {
171 return Ok(());
173 }
174
175 self.set_state(ConnectionState::Connecting);
176 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Connecting {
177 });
178
179 let mut ws = match do_blocking_connect(
180 &self.shared.config,
181 Arc::clone(&self.shared.tls_config),
182 ) {
183 Ok(ws) => ws,
184 Err(e) => {
185 self.set_state(ConnectionState::Disconnected);
186 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Error {
187 message: e.to_string(),
188 code: e.to_error_code(),
189 });
190 return Err(e);
191 }
192 };
193 crate::tracing_compat::info!(target: "fugle_marketdata::ws", "ws connected");
194 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Connected {
195 });
196
197 self.set_state(ConnectionState::Authenticating);
198 if let Err(e) = do_auth_handshake(&mut ws, &self.shared.config, &self.shared.message_tx) {
199 self.set_state(ConnectionState::Disconnected);
200 if let MarketDataError::AuthError { msg } = &e {
201 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Unauthenticated {
202 message: msg.clone(),
203 });
204 } else {
205 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Error {
206 message: e.to_string(),
207 code: e.to_error_code(),
208 });
209 }
210 return Err(e);
211 }
212
213 let (write_tx, write_rx) = mpsc::sync_channel::<String>(WRITE_QUEUE_CAPACITY);
215 *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = Some(write_tx);
216
217 self.set_state(ConnectionState::Connected);
218 crate::tracing_compat::info!(target: "fugle_marketdata::ws", "ws authenticated");
219 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Authenticated {
220 });
221
222 let shared = Arc::clone(&self.shared);
226 let (exit_tx, exit_rx) = mpsc::channel::<()>();
227 let handle = thread::Builder::new()
228 .name("fugle-ws-supervisor".to_string())
229 .spawn(move || {
230 run_supervisor(ws, write_rx, shared);
231 let _ = exit_tx.send(());
232 })
233 .map_err(|e| MarketDataError::ConnectionError {
234 msg: format!("Failed to spawn supervisor thread: {e}"),
235 })?;
236 *self.supervisor_handle.lock().expect("supervisor handle lock poisoned") = Some(handle);
237 *self
238 .supervisor_exit_rx
239 .lock()
240 .expect("supervisor_exit_rx lock poisoned") = Some(exit_rx);
241
242 Ok(())
243 }
244
245 #[cfg_attr(
255 feature = "tracing",
256 tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.disconnect", skip(self))
257 )]
258 pub fn disconnect(&self) -> Result<(), MarketDataError> {
259 self.shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)
260 }
261
262 #[cfg_attr(
286 feature = "tracing",
287 tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.shutdown_with_timeout", skip(self))
288 )]
289 pub fn shutdown_with_timeout(
290 &self,
291 timeout_dur: Duration,
292 ) -> Result<(), MarketDataError> {
293 self.shared.should_stop.store(true, Ordering::SeqCst);
294
295 *self
298 .shared
299 .write_tx_slot
300 .lock()
301 .expect("write_tx_slot lock poisoned") = None;
302
303 let exit_rx = self
305 .supervisor_exit_rx
306 .lock()
307 .expect("supervisor_exit_rx lock poisoned")
308 .take();
309 let signaled = match exit_rx {
310 Some(rx) => rx.recv_timeout(timeout_dur).is_ok(),
311 None => true, };
313
314 if let Some(handle) = self
315 .supervisor_handle
316 .lock()
317 .expect("supervisor handle lock poisoned")
318 .take()
319 {
320 if signaled {
321 let _ = handle.join();
324 } else {
325 drop(handle);
330 }
331 }
332
333 self.set_state(ConnectionState::Closed {
334 code: Some(1000),
335 reason: "Normal closure".to_string(),
336 intent: DisconnectIntent::Client,
337 });
338 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Disconnected {
339 code: Some(1000),
340 reason: "Normal closure".to_string(),
341 intent: DisconnectIntent::Client,
342 });
343
344 Ok(())
345 }
346
347 pub fn force_close(&self) -> Result<(), MarketDataError> {
353 self.shared.should_stop.store(true, Ordering::SeqCst);
354 *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
355 let _ = self.supervisor_handle.lock().expect("supervisor handle lock poisoned").take();
357 let _ = self
358 .supervisor_exit_rx
359 .lock()
360 .expect("supervisor_exit_rx lock poisoned")
361 .take();
362
363 self.set_state(ConnectionState::Closed {
364 code: Some(1006),
365 reason: "Force closed".to_string(),
366 intent: DisconnectIntent::Client,
367 });
368 emit_event(&self.shared.event_tx, &self.shared.events_dropped, ConnectionEvent::Disconnected {
369 code: Some(1006),
370 reason: "Force closed".to_string(),
371 intent: DisconnectIntent::Client,
372 });
373
374 Ok(())
375 }
376
377 #[cfg_attr(
383 feature = "tracing",
384 tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.subscribe", skip(self, sub))
385 )]
386 pub fn subscribe(
387 &self,
388 sub: crate::websocket::channels::StockSubscription,
389 ) -> Result<(), MarketDataError> {
390 if self.is_closed() {
391 return Err(MarketDataError::ClientClosed);
392 }
393
394 let (json, expanded) = frame_subscribe(sub)?;
395 for entry in expanded {
396 self.shared.subscriptions.subscribe(entry);
397 }
398
399 if self.is_connected() {
400 self.enqueue_write(json)?;
401 }
402 Ok(())
403 }
404
405 pub fn subscribe_futopt(
411 &self,
412 sub: crate::websocket::channels::FutOptSubscription,
413 ) -> Result<(), MarketDataError> {
414 if self.is_closed() {
415 return Err(MarketDataError::ClientClosed);
416 }
417
418 let (json, expanded) = frame_subscribe_futopt(sub)?;
419 for entry in expanded {
420 self.shared.subscriptions.subscribe(entry);
421 }
422
423 if self.is_connected() {
424 self.enqueue_write(json)?;
425 }
426 Ok(())
427 }
428
429 #[cfg_attr(
435 feature = "tracing",
436 tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.unsubscribe", skip(self, ids))
437 )]
438 pub fn unsubscribe(
439 &self,
440 ids: impl IntoIterator<Item = impl Into<String>>,
441 ) -> Result<(), MarketDataError> {
442 if self.is_closed() {
443 return Err(MarketDataError::ClientClosed);
444 }
445
446 let keys: Vec<String> = ids.into_iter().map(Into::into).collect();
447 if keys.is_empty() {
448 return Ok(());
449 }
450
451 let mut wire_ids = Vec::with_capacity(keys.len());
452 for key in &keys {
453 let id = self
454 .shared
455 .subscriptions
456 .take_server_id(key)
457 .unwrap_or_else(|| key.clone());
458 self.shared.subscriptions.unsubscribe(key);
459 wire_ids.push(id);
460 }
461
462 if !self.is_connected() {
463 return Ok(());
464 }
465
466 let json = frame_unsubscribe(wire_ids)?;
467 self.enqueue_write(json)
468 }
469
470 pub fn subscriptions(&self) -> Vec<SubscribeRequest> {
472 self.shared.subscriptions.get_all()
473 }
474
475 pub fn subscription_keys(&self) -> Vec<String> {
477 self.shared.subscriptions.keys()
478 }
479
480 pub fn subscription_count(&self) -> usize {
482 self.shared.subscriptions.count()
483 }
484
485 pub fn messages_dropped_total(&self) -> u64 {
493 self.shared.messages_dropped.load()
494 }
495
496 #[must_use]
505 pub fn events_dropped_total(&self) -> u64 {
506 self.shared.events_dropped.load()
507 }
508
509 pub fn is_subscribed(&self, channel: &Channel, symbol: &str) -> bool {
513 let base = format!("{}:{}", channel.as_str(), symbol);
514 let modifier_prefix = format!("{}:", base);
515 self.shared
516 .subscriptions
517 .keys()
518 .iter()
519 .any(|k| k == &base || k.starts_with(&modifier_prefix))
520 }
521
522 pub fn reconnect(&self) -> Result<(), MarketDataError> {
529 if self.is_closed() {
530 return Err(MarketDataError::ClientClosed);
531 }
532 self.shared.should_stop.store(true, Ordering::SeqCst);
534 *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
535 if let Some(handle) = self.supervisor_handle.lock().expect("supervisor handle lock poisoned").take() {
536 let _ = handle.join();
537 }
538 self.shared.should_stop.store(false, Ordering::SeqCst);
540 {
541 let mut mgr = self.shared.reconnection.lock().expect("reconnection lock poisoned");
542 mgr.reset();
543 }
544
545 self.connect()
546 }
547
548 pub fn send(&self, request: WebSocketRequest) -> Result<(), MarketDataError> {
554 if self.is_closed() {
555 return Err(MarketDataError::ClientClosed);
556 }
557 let json = frame_request(&request)?;
558 self.enqueue_write(json)
559 }
560
561 fn enqueue_write(&self, json: String) -> Result<(), MarketDataError> {
562 let sender_clone = {
563 let guard = self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned");
564 guard.clone()
565 };
566 match sender_clone {
567 Some(tx) => tx.send(json).map_err(|_| MarketDataError::ConnectionError {
568 msg: "Writer queue closed (supervisor exited)".to_string(),
569 }),
570 None => Err(MarketDataError::ConnectionError {
571 msg: "Not connected".to_string(),
572 }),
573 }
574 }
575
576 fn set_state(&self, new_state: ConnectionState) {
577 let mut st = self.shared.state.write().expect("state lock poisoned");
578 *st = new_state;
579 }
580}
581
582impl Drop for WebSocketClient {
583 fn drop(&mut self) {
584 self.shared.should_stop.store(true, Ordering::SeqCst);
588 *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
589 if let Some(handle) = self.supervisor_handle.lock().expect("supervisor handle lock poisoned").take() {
590 let _ = handle.join();
591 }
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Channel;
    use crate::websocket::channels::StockSubscription;
    use crate::websocket::connection_event::emit_event;
    use crate::AuthRequest;

    /// Builds a stock-endpoint client with a dummy API key; no network I/O happens.
    fn stock_client() -> WebSocketClient {
        WebSocketClient::new(ConnectionConfig::fugle_stock(AuthRequest::with_api_key("test")))
    }

    #[test]
    fn test_new_starts_disconnected() {
        let client = stock_client();
        assert_eq!(client.state(), ConnectionState::Disconnected);
        assert!(!client.is_closed());
        assert!(!client.is_connected());
    }

    #[test]
    fn events_dropped_total_starts_at_zero() {
        assert_eq!(stock_client().events_dropped_total(), 0);
    }

    #[test]
    fn events_dropped_increments_on_saturation() {
        let config = ConnectionConfig::builder("wss://example.com", AuthRequest::with_api_key("k"))
            .event_buffer(1)
            .build();
        let client = WebSocketClient::new(config);

        // A one-slot buffer: the first event fits, the following ones must be dropped.
        for _ in 0..3 {
            emit_event(
                &client.shared.event_tx,
                &client.shared.events_dropped,
                ConnectionEvent::Connecting {},
            );
        }

        let dropped = client.events_dropped_total();
        assert!(
            dropped >= 1,
            "expected events_dropped_total >= 1 after saturation, got {dropped}"
        );

        // Reading the counter again must never go backwards.
        assert!(client.events_dropped_total() >= dropped);
    }

    #[test]
    fn test_subscribe_before_connect_records_subscription() {
        let client = stock_client();

        client.subscribe(StockSubscription::new(Channel::Trades, "2330")).unwrap();

        assert_eq!(client.subscription_keys().len(), 1);
    }

    #[test]
    fn test_unsubscribe_when_disconnected_removes_state() {
        let client = stock_client();

        client.subscribe(StockSubscription::new(Channel::Trades, "2330")).unwrap();
        assert_eq!(client.subscription_keys().len(), 1);

        client.unsubscribe(["trades:2330"]).unwrap();
        assert!(client.subscription_keys().is_empty());
    }

    #[test]
    fn test_subscription_count_zero_on_fresh_client() {
        assert_eq!(stock_client().subscription_count(), 0);
    }

    #[test]
    fn test_subscription_count_tracks_subscribe_unsubscribe() {
        let client = stock_client();

        for channel in [Channel::Trades, Channel::Books] {
            client.subscribe(StockSubscription::new(channel, "2330")).unwrap();
        }
        assert_eq!(client.subscription_count(), 2);

        client.unsubscribe(["trades:2330"]).unwrap();
        assert_eq!(client.subscription_count(), 1);
    }

    #[test]
    fn test_is_subscribed_positive_match() {
        let client = stock_client();

        client.subscribe(StockSubscription::new(Channel::Trades, "2330")).unwrap();
        assert!(client.is_subscribed(&Channel::Trades, "2330"));
    }

    #[test]
    fn test_is_subscribed_negative_match_other_channel() {
        let client = stock_client();

        client.subscribe(StockSubscription::new(Channel::Trades, "2330")).unwrap();
        assert!(!client.is_subscribed(&Channel::Books, "2330"));
        assert!(!client.is_subscribed(&Channel::Trades, "1234"));
    }

    #[test]
    fn test_is_subscribed_false_on_fresh_client() {
        assert!(!stock_client().is_subscribed(&Channel::Trades, "2330"));
    }
}