use crate::models::{Channel, SubscribeRequest, WebSocketMessage, WebSocketRequest};
use crate::websocket::connection_event::emit_event;
use crate::websocket::protocol::{
frame_request, frame_subscribe, frame_subscribe_futopt, frame_unsubscribe,
};
use crate::websocket::sync::owner_thread::{
do_auth_handshake, do_blocking_connect, run_supervisor, OwnerShared, WRITE_QUEUE_CAPACITY,
};
use crate::websocket::{
ConnectionConfig, ConnectionEvent, ConnectionState, DisconnectIntent, HealthCheckConfig,
MessageReceiver, ReconnectionConfig, ReconnectionManager, SubscriptionManager,
};
use crate::MarketDataError;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{mpsc, Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
/// Blocking, thread-backed WebSocket client for Fugle market data.
///
/// [`connect`](Self::connect) performs the blocking connect (using the shared TLS
/// config) and the auth handshake on the calling thread. It then hands the socket
/// to a background supervisor thread (`run_supervisor`). Outgoing frames reach that
/// thread through a bounded write queue. Incoming messages and connection events
/// are exposed through [`messages`](Self::messages) and [`events`](Self::events).
///
/// Note: field declaration order is also the drop order of these fields.
pub struct WebSocketClient {
/// State shared with the supervisor thread: config, TLS config, connection state,
/// subscriptions, channel senders, write-queue slot, stop flag and drop counters.
shared: Arc<OwnerShared>,
/// Receiving end of the bounded connection-event channel.
event_rx: Arc<Mutex<mpsc::Receiver<ConnectionEvent>>>,
/// Raw message receiver, kept here until the first `messages()` call wraps it.
message_rx_slot: Mutex<Option<mpsc::Receiver<WebSocketMessage>>>,
/// Cached `MessageReceiver` that `messages()` hands out on every call.
message_receiver: Mutex<Option<Arc<MessageReceiver>>>,
/// Join handle of the supervisor thread spawned by `connect`; `None` when none was started or it was taken.
supervisor_handle: Mutex<Option<thread::JoinHandle<()>>>,
/// Receives `()` after `run_supervisor` returns, so shutdown can wait with a timeout.
supervisor_exit_rx: Mutex<Option<mpsc::Receiver<()>>>,
}
/// How long [`WebSocketClient::disconnect`] waits for the supervisor thread to exit
/// before detaching it instead of joining.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
impl WebSocketClient {
    /// Creates a client with default reconnection and health-check settings.
    ///
    /// No network I/O happens until [`connect`](Self::connect) is called. The client
    /// starts in [`ConnectionState::Disconnected`].
    ///
    /// # Panics
    ///
    /// Panics if the TLS settings in `config` cannot be built
    /// (see [`with_full_config`](Self::with_full_config)).
    pub fn new(config: ConnectionConfig) -> Self {
        Self::with_full_config(config, ReconnectionConfig::default(), HealthCheckConfig::default())
    }

    /// Creates a client with a custom reconnection policy and default health checks.
    ///
    /// # Panics
    ///
    /// Panics if the TLS settings in `config` cannot be built.
    pub fn with_reconnection_config(
        config: ConnectionConfig,
        reconnection_config: ReconnectionConfig,
    ) -> Self {
        Self::with_full_config(config, reconnection_config, HealthCheckConfig::default())
    }

    /// Creates a client with a custom health-check policy and default reconnection.
    ///
    /// # Panics
    ///
    /// Panics if the TLS settings in `config` cannot be built.
    pub fn with_health_check_config(
        config: ConnectionConfig,
        health_check_config: HealthCheckConfig,
    ) -> Self {
        Self::with_full_config(config, ReconnectionConfig::default(), health_check_config)
    }

    /// Creates a client from explicit connection, reconnection and health-check settings.
    ///
    /// This allocates:
    /// - the bounded event channel (`config.event_buffer` slots);
    /// - the bounded message channel (`config.message_buffer` slots);
    /// - the TLS config;
    /// - the state shared with the supervisor thread.
    ///
    /// # Panics
    ///
    /// Panics if `crate::tls::build_rustls_config` rejects `config.tls`.
    pub fn with_full_config(
        config: ConnectionConfig,
        reconnection_config: ReconnectionConfig,
        health_check_config: HealthCheckConfig,
    ) -> Self {
        let (event_tx, event_rx) = mpsc::sync_channel::<ConnectionEvent>(config.event_buffer);
        let (message_tx, message_rx) =
            mpsc::sync_channel::<WebSocketMessage>(config.message_buffer);
        let tls_config = crate::tls::build_rustls_config(&config.tls)
            .unwrap_or_else(|e| panic!("Failed to build TLS config: {e}"));
        let shared = Arc::new(OwnerShared {
            config,
            tls_config,
            health: health_check_config,
            reconnection: Mutex::new(ReconnectionManager::new(reconnection_config)),
            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
            subscriptions: Arc::new(SubscriptionManager::new()),
            event_tx,
            message_tx,
            write_tx_slot: Mutex::new(None),
            should_stop: Arc::new(AtomicBool::new(false)),
            messages_dropped: Arc::new(AtomicU64::new(0)),
            events_dropped: Arc::new(AtomicU64::new(0)),
        });
        Self {
            shared,
            event_rx: Arc::new(Mutex::new(event_rx)),
            message_rx_slot: Mutex::new(Some(message_rx)),
            message_receiver: Mutex::new(None),
            supervisor_handle: Mutex::new(None),
            supervisor_exit_rx: Mutex::new(None),
        }
    }

    /// Returns a snapshot of the current connection state.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn state(&self) -> ConnectionState {
        self.shared.state.read().expect("state lock poisoned").clone()
    }

    /// Returns `true` while the state is [`ConnectionState::Closed`].
    ///
    /// This is the case after `disconnect`, `shutdown_with_timeout` or `force_close`.
    pub fn is_closed(&self) -> bool {
        matches!(
            *self.shared.state.read().expect("state lock poisoned"),
            ConnectionState::Closed { .. }
        )
    }

    /// Returns `true` while the state is [`ConnectionState::Connected`].
    pub fn is_connected(&self) -> bool {
        matches!(
            *self.shared.state.read().expect("state lock poisoned"),
            ConnectionState::Connected
        )
    }

    /// Returns the receiver for connection lifecycle events.
    ///
    /// The channel is bounded by `ConnectionConfig::event_buffer`. Events that
    /// cannot be queued are counted in
    /// [`events_dropped_total`](Self::events_dropped_total).
    pub fn events(&self) -> &Arc<Mutex<mpsc::Receiver<ConnectionEvent>>> {
        &self.event_rx
    }

    /// Alias of [`events`](Self::events).
    ///
    /// It returns the *same* receiver, so each event is observed by only one reader.
    pub fn state_events(&self) -> &Arc<Mutex<mpsc::Receiver<ConnectionEvent>>> {
        &self.event_rx
    }

    /// Returns the shared receiver for incoming market-data messages.
    ///
    /// The first call wraps the underlying channel in a [`MessageReceiver`].
    /// Later calls return clones of that same `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if an internal lock is poisoned.
    pub fn messages(&self) -> Arc<MessageReceiver> {
        let mut slot = self.message_receiver.lock().expect("message_receiver lock poisoned");
        if let Some(rx) = slot.as_ref() {
            return Arc::clone(rx);
        }
        // `slot` stays locked, so the raw receiver can only be taken once.
        let std_rx = self
            .message_rx_slot
            .lock()
            .expect("message_rx_slot lock poisoned")
            .take()
            .expect("message receiver already taken");
        let receiver = Arc::new(MessageReceiver::new(std_rx));
        *slot = Some(Arc::clone(&receiver));
        receiver
    }

    /// Opens the socket, authenticates, and starts the supervisor thread.
    ///
    /// Emits `Connecting`, `Connected` and then `Authenticated`. On failure it
    /// emits `Error`, or `Unauthenticated` for an auth rejection.
    ///
    /// If a supervisor thread is still running, this returns `Ok(())` without
    /// doing anything. A supervisor that has already exited is joined first, so
    /// the client can connect again.
    ///
    /// # Errors
    ///
    /// - [`MarketDataError::ClientClosed`] once the client is closed.
    /// - The error produced by the blocking connect or the auth handshake. The
    ///   state is reset to `Disconnected` in that case.
    /// - [`MarketDataError::ConnectionError`] if the supervisor thread cannot be
    ///   spawned. The state is reset to `Disconnected` in that case.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.connect", skip(self))
    )]
    pub fn connect(&self) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        if self.supervisor_running() {
            return Ok(());
        }

        self.set_state(ConnectionState::Connecting);
        self.emit(ConnectionEvent::Connecting {});

        let mut ws = match do_blocking_connect(
            &self.shared.config,
            Arc::clone(&self.shared.tls_config),
        ) {
            Ok(ws) => ws,
            Err(e) => {
                self.set_state(ConnectionState::Disconnected);
                self.emit_error(&e);
                return Err(e);
            }
        };
        crate::tracing_compat::info!(target: "fugle_marketdata::ws", "ws connected");
        self.emit(ConnectionEvent::Connected {});

        self.set_state(ConnectionState::Authenticating);
        if let Err(e) = do_auth_handshake(&mut ws, &self.shared.config, &self.shared.message_tx) {
            self.set_state(ConnectionState::Disconnected);
            match &e {
                MarketDataError::AuthError { msg } => {
                    self.emit(ConnectionEvent::Unauthenticated {
                        message: msg.clone(),
                    });
                }
                other => self.emit_error(other),
            }
            return Err(e);
        }

        let (write_tx, write_rx) = mpsc::sync_channel::<String>(WRITE_QUEUE_CAPACITY);
        *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = Some(write_tx);
        self.set_state(ConnectionState::Connected);
        crate::tracing_compat::info!(target: "fugle_marketdata::ws", "ws authenticated");
        self.emit(ConnectionEvent::Authenticated {});

        let shared = Arc::clone(&self.shared);
        let (exit_tx, exit_rx) = mpsc::channel::<()>();
        let spawned = thread::Builder::new()
            .name("fugle-ws-supervisor".to_string())
            .spawn(move || {
                run_supervisor(ws, write_rx, shared);
                let _ = exit_tx.send(());
            });
        let handle = match spawned {
            Ok(handle) => handle,
            Err(e) => {
                // The rejected closure was dropped, and `ws` and `write_rx` with it.
                // Nothing will ever drain the write queue, so undo the "connected"
                // bookkeeping. Otherwise the client would report a live connection
                // that has no supervisor.
                *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
                self.set_state(ConnectionState::Disconnected);
                let err = MarketDataError::ConnectionError {
                    msg: format!("Failed to spawn supervisor thread: {e}"),
                };
                self.emit_error(&err);
                return Err(err);
            }
        };
        *self.supervisor_handle.lock().expect("supervisor handle lock poisoned") = Some(handle);
        *self
            .supervisor_exit_rx
            .lock()
            .expect("supervisor_exit_rx lock poisoned") = Some(exit_rx);
        Ok(())
    }

    /// Gracefully closes the client, waiting up to [`DEFAULT_SHUTDOWN_TIMEOUT`].
    ///
    /// Equivalent to `shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.disconnect", skip(self))
    )]
    pub fn disconnect(&self) -> Result<(), MarketDataError> {
        self.shutdown_with_timeout(DEFAULT_SHUTDOWN_TIMEOUT)
    }

    /// Stops the supervisor and marks the client permanently closed.
    ///
    /// Steps:
    /// 1. Raises the stop flag and drops the write-queue sender.
    /// 2. Waits up to `timeout_dur` for the supervisor to exit. A supervisor that
    ///    has exited, or unwound, is joined. One still running after the timeout
    ///    is detached.
    /// 3. Moves to `Closed { code: Some(1000), .. }` and emits `Disconnected`.
    ///
    /// Afterwards `connect`, `subscribe`, `send` and `reconnect` return
    /// `ClientClosed`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.shutdown_with_timeout", skip(self))
    )]
    pub fn shutdown_with_timeout(
        &self,
        timeout_dur: Duration,
    ) -> Result<(), MarketDataError> {
        self.shared.should_stop.store(true, Ordering::SeqCst);
        *self
            .shared
            .write_tx_slot
            .lock()
            .expect("write_tx_slot lock poisoned") = None;

        let exit_rx = self
            .supervisor_exit_rx
            .lock()
            .expect("supervisor_exit_rx lock poisoned")
            .take();
        // Only an actual timeout means the thread may still be running.
        // `Ok(())` means `run_supervisor` returned. `Disconnected` means the
        // closure dropped its sender, i.e. it unwound.
        let exited = match exit_rx {
            Some(rx) => !matches!(
                rx.recv_timeout(timeout_dur),
                Err(mpsc::RecvTimeoutError::Timeout)
            ),
            None => true,
        };

        let handle = self
            .supervisor_handle
            .lock()
            .expect("supervisor handle lock poisoned")
            .take();
        if let Some(handle) = handle {
            if exited {
                let _ = handle.join();
            } else {
                // Dropping a `JoinHandle` detaches the thread instead of blocking on it.
                drop(handle);
            }
        }

        self.set_state(ConnectionState::Closed {
            code: Some(1000),
            reason: "Normal closure".to_string(),
            intent: DisconnectIntent::Client,
        });
        self.emit(ConnectionEvent::Disconnected {
            code: Some(1000),
            reason: "Normal closure".to_string(),
            intent: DisconnectIntent::Client,
        });
        Ok(())
    }

    /// Closes the client immediately, without waiting for the supervisor.
    ///
    /// Raises the stop flag, drops the write-queue sender and detaches the
    /// supervisor thread. It then moves to `Closed { code: Some(1006), .. }` and
    /// emits `Disconnected`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn force_close(&self) -> Result<(), MarketDataError> {
        self.shared.should_stop.store(true, Ordering::SeqCst);
        *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
        // Dropping the handle detaches the thread; the exit signal is discarded.
        let _ = self.supervisor_handle.lock().expect("supervisor handle lock poisoned").take();
        let _ = self
            .supervisor_exit_rx
            .lock()
            .expect("supervisor_exit_rx lock poisoned")
            .take();
        self.set_state(ConnectionState::Closed {
            code: Some(1006),
            reason: "Force closed".to_string(),
            intent: DisconnectIntent::Client,
        });
        self.emit(ConnectionEvent::Disconnected {
            code: Some(1006),
            reason: "Force closed".to_string(),
            intent: DisconnectIntent::Client,
        });
        Ok(())
    }

    /// Subscribes to a stock channel.
    ///
    /// The expanded subscription entries are always recorded in the subscription
    /// manager. The subscribe frame is sent only while connected.
    ///
    /// # Errors
    ///
    /// - `ClientClosed` once closed.
    /// - Framing errors from `frame_subscribe`.
    /// - Write-queue errors from sending while connected.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.subscribe", skip(self, sub))
    )]
    pub fn subscribe(
        &self,
        sub: crate::websocket::channels::StockSubscription,
    ) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let (json, expanded) = frame_subscribe(sub)?;
        for entry in expanded {
            self.shared.subscriptions.subscribe(entry);
        }
        if self.is_connected() {
            self.enqueue_write(json)?;
        }
        Ok(())
    }

    /// Subscribes to a futures/options channel.
    ///
    /// Same contract as [`subscribe`](Self::subscribe).
    ///
    /// # Errors
    ///
    /// - `ClientClosed` once closed.
    /// - Framing errors from `frame_subscribe_futopt`.
    /// - Write-queue errors from sending while connected.
    pub fn subscribe_futopt(
        &self,
        sub: crate::websocket::channels::FutOptSubscription,
    ) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let (json, expanded) = frame_subscribe_futopt(sub)?;
        for entry in expanded {
            self.shared.subscriptions.subscribe(entry);
        }
        if self.is_connected() {
            self.enqueue_write(json)?;
        }
        Ok(())
    }

    /// Removes subscriptions by key.
    ///
    /// Each key is removed from the subscription manager. While connected, an
    /// unsubscribe frame is also sent. That frame uses the server-assigned id for
    /// each key when one is known, and the key itself otherwise. An empty `ids`
    /// is a no-op.
    ///
    /// # Errors
    ///
    /// - `ClientClosed` once closed.
    /// - Framing or write-queue errors while connected.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(target = "fugle_marketdata::ws", name = "ws.sync.unsubscribe", skip(self, ids))
    )]
    pub fn unsubscribe(
        &self,
        ids: impl IntoIterator<Item = impl Into<String>>,
    ) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let keys: Vec<String> = ids.into_iter().map(Into::into).collect();
        if keys.is_empty() {
            return Ok(());
        }
        let mut wire_ids = Vec::with_capacity(keys.len());
        for key in &keys {
            let id = self
                .shared
                .subscriptions
                .take_server_id(key)
                .unwrap_or_else(|| key.clone());
            self.shared.subscriptions.unsubscribe(key);
            wire_ids.push(id);
        }
        if !self.is_connected() {
            return Ok(());
        }
        let json = frame_unsubscribe(wire_ids)?;
        self.enqueue_write(json)
    }

    /// Returns every recorded subscription request.
    pub fn subscriptions(&self) -> Vec<SubscribeRequest> {
        self.shared.subscriptions.get_all()
    }

    /// Returns the keys of every recorded subscription.
    pub fn subscription_keys(&self) -> Vec<String> {
        self.shared.subscriptions.keys()
    }

    /// Returns the number of recorded subscriptions.
    pub fn subscription_count(&self) -> usize {
        self.shared.subscriptions.count()
    }

    /// Total number of incoming messages the client counted as dropped.
    #[must_use]
    pub fn messages_dropped_total(&self) -> u64 {
        self.shared.messages_dropped.load(Ordering::Relaxed)
    }

    /// Total number of connection events that could not be queued.
    #[must_use]
    pub fn events_dropped_total(&self) -> u64 {
        self.shared.events_dropped.load(Ordering::Relaxed)
    }

    /// Returns `true` if `channel`/`symbol` is subscribed.
    ///
    /// Matches either the exact key `"<channel>:<symbol>"` or any key that extends
    /// it with a `":"`-separated modifier.
    pub fn is_subscribed(&self, channel: &Channel, symbol: &str) -> bool {
        let base = format!("{}:{}", channel.as_str(), symbol);
        let modifier_prefix = format!("{}:", base);
        self.shared
            .subscriptions
            .keys()
            .iter()
            .any(|k| k == &base || k.starts_with(&modifier_prefix))
    }

    /// Tears down the current supervisor, if any, and connects again.
    ///
    /// Raises the stop flag, drops the write-queue sender and joins the old
    /// supervisor **without a timeout**. It then clears the stop flag, resets the
    /// reconnection manager and calls [`connect`](Self::connect).
    ///
    /// # Errors
    ///
    /// `ClientClosed` if the client is closed; otherwise whatever `connect` returns.
    pub fn reconnect(&self) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        self.shared.should_stop.store(true, Ordering::SeqCst);
        *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
        {
            let mut handle_slot = self
                .supervisor_handle
                .lock()
                .expect("supervisor handle lock poisoned");
            if let Some(handle) = handle_slot.take() {
                // Unbounded join: the old supervisor reads the same `should_stop`
                // flag, so it must be gone before the flag is cleared below.
                let _ = handle.join();
            }
        }
        // The exit signal belonged to the supervisor that was just joined.
        *self
            .supervisor_exit_rx
            .lock()
            .expect("supervisor_exit_rx lock poisoned") = None;
        self.shared.should_stop.store(false, Ordering::SeqCst);
        self.shared
            .reconnection
            .lock()
            .expect("reconnection lock poisoned")
            .reset();
        self.connect()
    }

    /// Frames `request` and queues it for the supervisor to send.
    ///
    /// # Errors
    ///
    /// - `ClientClosed` once closed.
    /// - Framing errors.
    /// - `ConnectionError` if not connected or the supervisor has exited.
    pub fn send(&self, request: WebSocketRequest) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let json = frame_request(&request)?;
        self.enqueue_write(json)
    }

    /// Reports whether a supervisor thread started by `connect` is still alive.
    ///
    /// The supervisor closure sends on `supervisor_exit_rx` right after
    /// `run_supervisor` returns, and drops that sender if it unwinds. A ready or
    /// disconnected channel therefore means the thread has finished. In that case
    /// the thread is joined, both slots are cleared, and `false` is returned. This
    /// lets `connect` start a new session instead of silently doing nothing.
    fn supervisor_running(&self) -> bool {
        let mut handle_slot = self
            .supervisor_handle
            .lock()
            .expect("supervisor handle lock poisoned");
        if handle_slot.is_none() {
            return false;
        }
        {
            let mut exit_slot = self
                .supervisor_exit_rx
                .lock()
                .expect("supervisor_exit_rx lock poisoned");
            let finished = match exit_slot.as_ref() {
                Some(rx) => !matches!(rx.try_recv(), Err(mpsc::TryRecvError::Empty)),
                // Handle present but no exit receiver: the slots are mid-update
                // by another call, so conservatively treat it as running.
                None => false,
            };
            if !finished {
                return true;
            }
            *exit_slot = None;
        }
        let finished_handle = handle_slot.take();
        drop(handle_slot);
        if let Some(handle) = finished_handle {
            // The thread has already left `run_supervisor`; this only reaps it.
            let _ = handle.join();
        }
        false
    }

    /// Publishes `event` through `emit_event`, which updates the dropped-event counter.
    fn emit(&self, event: ConnectionEvent) {
        emit_event(&self.shared.event_tx, &self.shared.events_dropped, event);
    }

    /// Publishes a `ConnectionEvent::Error` describing `error`.
    fn emit_error(&self, error: &MarketDataError) {
        self.emit(ConnectionEvent::Error {
            message: error.to_string(),
            code: error.to_error_code(),
        });
    }

    /// Hands `json` to the supervisor's write queue.
    ///
    /// The sender is cloned and the slot lock released before `send`, because
    /// `send` blocks while the bounded queue is full.
    fn enqueue_write(&self, json: String) -> Result<(), MarketDataError> {
        let sender = self
            .shared
            .write_tx_slot
            .lock()
            .expect("write_tx_slot lock poisoned")
            .clone();
        match sender {
            Some(tx) => tx.send(json).map_err(|_| MarketDataError::ConnectionError {
                msg: "Writer queue closed (supervisor exited)".to_string(),
            }),
            None => Err(MarketDataError::ConnectionError {
                msg: "Not connected".to_string(),
            }),
        }
    }

    /// Overwrites the shared connection state.
    fn set_state(&self, new_state: ConnectionState) {
        *self.shared.state.write().expect("state lock poisoned") = new_state;
    }
}
impl Drop for WebSocketClient {
fn drop(&mut self) {
self.shared.should_stop.store(true, Ordering::SeqCst);
*self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
if let Some(handle) = self.supervisor_handle.lock().expect("supervisor handle lock poisoned").take() {
let _ = handle.join();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Channel;
    use crate::websocket::channels::StockSubscription;
    use crate::websocket::connection_event::emit_event;
    use crate::AuthRequest;

    /// Builds a fresh client for the Fugle stock endpoint with a placeholder key.
    fn stock_client() -> WebSocketClient {
        let config = ConnectionConfig::fugle_stock(AuthRequest::with_api_key("test"));
        WebSocketClient::new(config)
    }

    #[test]
    fn test_new_starts_disconnected() {
        let client = stock_client();
        assert!(!client.is_connected());
        assert!(!client.is_closed());
        assert_eq!(client.state(), ConnectionState::Disconnected);
    }

    #[test]
    fn events_dropped_total_starts_at_zero() {
        assert_eq!(stock_client().events_dropped_total(), 0);
    }

    #[test]
    fn events_dropped_increments_on_saturation() {
        let config = ConnectionConfig::builder("wss://example.com", AuthRequest::with_api_key("k"))
            .event_buffer(1)
            .build();
        let client = WebSocketClient::new(config);
        // One-slot buffer and no reader: repeated emits saturate the channel.
        for _ in 0..3 {
            emit_event(
                &client.shared.event_tx,
                &client.shared.events_dropped,
                ConnectionEvent::Connecting {},
            );
        }
        let dropped = client.events_dropped_total();
        assert!(
            dropped >= 1,
            "expected events_dropped_total >= 1 after saturation, got {dropped}"
        );
        // A second read must not go backwards.
        assert!(client.events_dropped_total() >= dropped);
    }

    #[test]
    fn test_subscribe_before_connect_records_subscription() {
        let client = stock_client();
        client
            .subscribe(StockSubscription::new(Channel::Trades, "2330"))
            .unwrap();
        assert_eq!(client.subscription_keys().len(), 1);
    }

    #[test]
    fn test_unsubscribe_when_disconnected_removes_state() {
        let client = stock_client();
        client
            .subscribe(StockSubscription::new(Channel::Trades, "2330"))
            .unwrap();
        assert_eq!(client.subscription_keys().len(), 1);
        client.unsubscribe(["trades:2330"]).unwrap();
        assert_eq!(client.subscription_keys().len(), 0);
    }

    #[test]
    fn test_subscription_count_zero_on_fresh_client() {
        assert_eq!(stock_client().subscription_count(), 0);
    }

    #[test]
    fn test_subscription_count_tracks_subscribe_unsubscribe() {
        let client = stock_client();
        for channel in [Channel::Trades, Channel::Books] {
            client.subscribe(StockSubscription::new(channel, "2330")).unwrap();
        }
        assert_eq!(client.subscription_count(), 2);
        client.unsubscribe(["trades:2330"]).unwrap();
        assert_eq!(client.subscription_count(), 1);
    }

    #[test]
    fn test_is_subscribed_positive_match() {
        let client = stock_client();
        client
            .subscribe(StockSubscription::new(Channel::Trades, "2330"))
            .unwrap();
        assert!(client.is_subscribed(&Channel::Trades, "2330"));
    }

    #[test]
    fn test_is_subscribed_negative_match_other_channel() {
        let client = stock_client();
        client
            .subscribe(StockSubscription::new(Channel::Trades, "2330"))
            .unwrap();
        assert!(!client.is_subscribed(&Channel::Books, "2330"));
        assert!(!client.is_subscribed(&Channel::Trades, "1234"));
    }

    #[test]
    fn test_is_subscribed_false_on_fresh_client() {
        assert!(!stock_client().is_subscribed(&Channel::Trades, "2330"));
    }
}