use crate::models::{SubscribeRequest, WebSocketMessage, WebSocketRequest};
use crate::websocket::connection_event::emit_event;
use crate::websocket::protocol::{
frame_request, frame_subscribe, frame_subscribe_futopt, frame_unsubscribe,
};
use crate::websocket::sync::owner_thread::{
do_auth_handshake, do_blocking_connect, run_supervisor, OwnerShared, WRITE_QUEUE_CAPACITY,
};
use crate::websocket::{
ConnectionConfig, ConnectionEvent, ConnectionState, HealthCheckConfig, MessageReceiver,
ReconnectionConfig, ReconnectionManager, SubscriptionManager,
};
use crate::MarketDataError;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc, Mutex, RwLock};
use std::thread;
/// Blocking WebSocket market-data client whose socket is driven by a supervisor thread.
///
/// The blocking connect and the auth handshake run on the caller's thread inside
/// `connect`; after that, a supervisor thread takes ownership of the socket and outbound
/// frames reach it through a bounded writer queue stored in the shared state.
pub struct WebSocketClient {
/// State shared with the supervisor thread (config, TLS config, connection state,
/// subscriptions, event/message senders, writer-queue slot, stop flag).
shared: Arc<OwnerShared>,
/// Receiving end of the connection-event channel; handed out by `events()` and
/// `state_events()` (both return this same receiver).
event_rx: Arc<Mutex<mpsc::Receiver<ConnectionEvent>>>,
/// Raw inbound-message receiver; taken exactly once by the first `messages()` call.
message_rx_slot: Mutex<Option<mpsc::Receiver<WebSocketMessage>>>,
/// Wrapper built by the first `messages()` call and shared by every later call.
message_receiver: Mutex<Option<Arc<MessageReceiver>>>,
/// Join handle of the spawned supervisor thread, or `None` when none is stored.
supervisor_handle: Mutex<Option<thread::JoinHandle<()>>>,
}
impl WebSocketClient {
    /// Capacity of the bounded connection-event channel.
    const EVENT_CHANNEL_CAPACITY: usize = 1024;
    /// Capacity of the bounded inbound-message channel.
    const MESSAGE_CHANNEL_CAPACITY: usize = 1024;

    /// Creates a client with the default reconnection policy and health-check settings.
    ///
    /// No connection is opened here; call [`connect`](Self::connect) for that.
    ///
    /// # Panics
    ///
    /// Panics if the TLS configuration in `config.tls` cannot be built.
    pub fn new(config: ConnectionConfig) -> Self {
        Self::with_full_config(config, ReconnectionConfig::default(), HealthCheckConfig::default())
    }

    /// Creates a client with a custom reconnection policy and default health-check settings.
    ///
    /// # Panics
    ///
    /// Panics if the TLS configuration in `config.tls` cannot be built.
    pub fn with_reconnection_config(
        config: ConnectionConfig,
        reconnection_config: ReconnectionConfig,
    ) -> Self {
        Self::with_full_config(config, reconnection_config, HealthCheckConfig::default())
    }

    /// Creates a client with custom health-check settings and the default reconnection policy.
    ///
    /// # Panics
    ///
    /// Panics if the TLS configuration in `config.tls` cannot be built.
    pub fn with_health_check_config(
        config: ConnectionConfig,
        health_check_config: HealthCheckConfig,
    ) -> Self {
        Self::with_full_config(config, ReconnectionConfig::default(), health_check_config)
    }

    /// Creates a client with every configuration knob given explicitly.
    ///
    /// Allocates the bounded event and message channels and builds the rustls
    /// configuration. It also initializes the shared state as `Disconnected`, with an empty
    /// subscription manager, no writer queue and the stop flag cleared. No thread is
    /// spawned and no socket is opened.
    ///
    /// # Panics
    ///
    /// Panics if `crate::tls::build_rustls_config` fails for `config.tls`.
    pub fn with_full_config(
        config: ConnectionConfig,
        reconnection_config: ReconnectionConfig,
        health_check_config: HealthCheckConfig,
    ) -> Self {
        let (event_tx, event_rx) =
            mpsc::sync_channel::<ConnectionEvent>(Self::EVENT_CHANNEL_CAPACITY);
        let (message_tx, message_rx) =
            mpsc::sync_channel::<WebSocketMessage>(Self::MESSAGE_CHANNEL_CAPACITY);
        let tls_config = crate::tls::build_rustls_config(&config.tls)
            .unwrap_or_else(|e| panic!("Failed to build TLS config: {e}"));
        let shared = Arc::new(OwnerShared {
            config,
            tls_config,
            health: health_check_config,
            reconnection: Mutex::new(ReconnectionManager::new(reconnection_config)),
            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
            subscriptions: Arc::new(SubscriptionManager::new()),
            event_tx,
            message_tx,
            write_tx_slot: Mutex::new(None),
            should_stop: Arc::new(AtomicBool::new(false)),
        });
        Self {
            shared,
            event_rx: Arc::new(Mutex::new(event_rx)),
            message_rx_slot: Mutex::new(Some(message_rx)),
            message_receiver: Mutex::new(None),
            supervisor_handle: Mutex::new(None),
        }
    }

    /// Returns a snapshot (clone) of the current connection state.
    pub fn state(&self) -> ConnectionState {
        self.shared.state.read().expect("state lock poisoned").clone()
    }

    /// Returns `true` once the client has been closed by `disconnect` or `force_close`.
    pub fn is_closed(&self) -> bool {
        matches!(*self.shared.state.read().expect("state lock poisoned"), ConnectionState::Closed { .. })
    }

    /// Returns `true` while the state is exactly `Connected`.
    pub fn is_connected(&self) -> bool {
        matches!(*self.shared.state.read().expect("state lock poisoned"), ConnectionState::Connected)
    }

    /// Returns the shared receiver of [`ConnectionEvent`]s.
    ///
    /// This is the same receiver that [`state_events`](Self::state_events) returns. An mpsc
    /// receiver delivers each event once, so consumers reading through both accessors split
    /// the stream between them.
    pub fn events(&self) -> &Arc<Mutex<mpsc::Receiver<ConnectionEvent>>> {
        &self.event_rx
    }

    /// Alias of [`events`](Self::events); returns the same receiver.
    pub fn state_events(&self) -> &Arc<Mutex<mpsc::Receiver<ConnectionEvent>>> {
        &self.event_rx
    }

    /// Returns the receiver for inbound [`WebSocketMessage`]s.
    ///
    /// The first call wraps the underlying channel in a [`MessageReceiver`]. Every later
    /// call returns a clone of that same `Arc`.
    pub fn messages(&self) -> Arc<MessageReceiver> {
        let mut cached = self.message_receiver.lock().expect("message_receiver lock poisoned");
        let receiver = cached.get_or_insert_with(|| {
            // Runs at most once: the wrapper is cached in `cached` while its lock is held,
            // so the raw receiver is only ever taken a single time.
            let std_rx = self
                .message_rx_slot
                .lock()
                .expect("message_rx_slot lock poisoned")
                .take()
                .expect("message receiver already taken");
            Arc::new(MessageReceiver::new(std_rx))
        });
        Arc::clone(receiver)
    }

    /// Opens the connection, authenticates, and starts the supervisor thread.
    ///
    /// The blocking connect and the auth handshake both run on the calling thread. On
    /// success:
    /// - the state is `Connected`;
    /// - a writer queue of capacity `WRITE_QUEUE_CAPACITY` is installed;
    /// - a thread named `fugle-ws-supervisor` owns the socket.
    ///
    /// Events are emitted along the way: `Connecting`, `Connected`, `Authenticated`.
    ///
    /// Returns `Ok(())` without doing anything if a supervisor handle is already stored.
    ///
    /// # Errors
    ///
    /// - [`MarketDataError::ClientClosed`] if the client has been closed.
    /// - The error from the connect or auth step. The state is reset to `Disconnected` and
    ///   an `Unauthenticated` event (for `AuthError`) or an `Error` event is emitted.
    /// - [`MarketDataError::ConnectionError`] if the supervisor thread cannot be spawned.
    ///   The writer queue is removed, the state is reset to `Disconnected` and an `Error`
    ///   event is emitted.
    pub fn connect(&self) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        // NOTE(review): this check and the final store of the handle are separate critical
        // sections, so two concurrent `connect` calls can both get past this point.
        if self.supervisor_handle.lock().expect("supervisor handle lock poisoned").is_some() {
            return Ok(());
        }

        self.set_state(ConnectionState::Connecting);
        emit_event(&self.shared.event_tx, ConnectionEvent::Connecting);

        let mut ws =
            match do_blocking_connect(&self.shared.config, Arc::clone(&self.shared.tls_config)) {
                Ok(ws) => ws,
                Err(e) => {
                    self.set_state(ConnectionState::Disconnected);
                    self.emit_error(&e);
                    return Err(e);
                }
            };
        emit_event(&self.shared.event_tx, ConnectionEvent::Connected);

        self.set_state(ConnectionState::Authenticating);
        if let Err(e) = do_auth_handshake(&mut ws, &self.shared.config, &self.shared.message_tx) {
            self.set_state(ConnectionState::Disconnected);
            if let MarketDataError::AuthError { msg } = &e {
                emit_event(&self.shared.event_tx, ConnectionEvent::Unauthenticated {
                    message: msg.clone(),
                });
            } else {
                self.emit_error(&e);
            }
            return Err(e);
        }

        // The writer queue and the `Connected` state are published before the supervisor
        // starts; if spawning fails they are rolled back below.
        let (write_tx, write_rx) = mpsc::sync_channel::<String>(WRITE_QUEUE_CAPACITY);
        *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = Some(write_tx);
        self.set_state(ConnectionState::Connected);
        emit_event(&self.shared.event_tx, ConnectionEvent::Authenticated);

        let shared = Arc::clone(&self.shared);
        let spawned = thread::Builder::new()
            .name("fugle-ws-supervisor".to_string())
            .spawn(move || run_supervisor(ws, write_rx, shared));
        let handle = match spawned {
            Ok(handle) => handle,
            Err(e) => {
                // The closure (socket + writer receiver) was dropped with the failed spawn,
                // so nothing would ever drain the queue: undo the `Connected` publication
                // instead of leaving the client claiming a live connection.
                *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
                self.set_state(ConnectionState::Disconnected);
                let err = MarketDataError::ConnectionError {
                    msg: format!("Failed to spawn supervisor thread: {e}"),
                };
                self.emit_error(&err);
                return Err(err);
            }
        };
        *self.supervisor_handle.lock().expect("supervisor handle lock poisoned") = Some(handle);
        Ok(())
    }

    /// Stops the supervisor and permanently closes the client with code 1000.
    ///
    /// The supervisor thread (if any) is joined first. Then the state becomes `Closed` and
    /// a `Disconnected` event is emitted. Afterwards, `connect`, `reconnect`, `subscribe`,
    /// `subscribe_futopt`, `unsubscribe` and `send` return `ClientClosed`.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn disconnect(&self) -> Result<(), MarketDataError> {
        self.stop_supervisor();
        self.set_state(ConnectionState::Closed {
            code: Some(1000),
            reason: "Normal closure".to_string(),
        });
        emit_event(&self.shared.event_tx, ConnectionEvent::Disconnected {
            code: Some(1000),
            reason: "Normal closure".to_string(),
        });
        Ok(())
    }

    /// Closes the client immediately (code 1006) without waiting for the supervisor.
    ///
    /// Like `disconnect`, it sets the stop flag and drops the writer-queue sender. The
    /// supervisor's `JoinHandle` is dropped instead of joined, which detaches the thread.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn force_close(&self) -> Result<(), MarketDataError> {
        self.shared.should_stop.store(true, Ordering::SeqCst);
        *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
        // Dropping the handle detaches the thread rather than joining it.
        drop(self.supervisor_handle.lock().expect("supervisor handle lock poisoned").take());
        self.set_state(ConnectionState::Closed {
            code: Some(1006),
            reason: "Force closed".to_string(),
        });
        emit_event(&self.shared.event_tx, ConnectionEvent::Disconnected {
            code: Some(1006),
            reason: "Force closed".to_string(),
        });
        Ok(())
    }

    /// Subscribes to a stock channel.
    ///
    /// Every entry expanded by `frame_subscribe` is recorded in the subscription manager,
    /// whatever the connection state. The frame is queued for sending only while the client
    /// is `Connected`; otherwise nothing is sent here. (NOTE(review): confirm that the
    /// supervisor replays recorded subscriptions after connecting.)
    ///
    /// # Errors
    ///
    /// Returns `ClientClosed` if the client is closed. Also returns any framing error, or a
    /// queueing error from the writer queue when connected.
    pub fn subscribe(
        &self,
        sub: crate::websocket::channels::StockSubscription,
    ) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let (json, expanded) = frame_subscribe(sub)?;
        for entry in expanded {
            self.shared.subscriptions.subscribe(entry);
        }
        if self.is_connected() {
            self.enqueue_write(json)?;
        }
        Ok(())
    }

    /// Subscribes to a futures/options channel.
    ///
    /// Same recording and sending rules as [`subscribe`](Self::subscribe), framed with
    /// `frame_subscribe_futopt`.
    ///
    /// # Errors
    ///
    /// Returns `ClientClosed` if the client is closed. Also returns any framing error, or a
    /// queueing error from the writer queue when connected.
    pub fn subscribe_futopt(
        &self,
        sub: crate::websocket::channels::FutOptSubscription,
    ) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let (json, expanded) = frame_subscribe_futopt(sub)?;
        for entry in expanded {
            self.shared.subscriptions.subscribe(entry);
        }
        if self.is_connected() {
            self.enqueue_write(json)?;
        }
        Ok(())
    }

    /// Removes subscriptions by key and, when connected, sends an unsubscribe frame.
    ///
    /// Local entries are always removed. The wire id for each key is the server id
    /// recorded for it, or the key itself if none was recorded. An empty input is a no-op.
    ///
    /// # Errors
    ///
    /// Returns `ClientClosed` if the client is closed. Also returns any framing error, or a
    /// queueing error from the writer queue when connected.
    pub fn unsubscribe(
        &self,
        ids: impl IntoIterator<Item = impl Into<String>>,
    ) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let keys: Vec<String> = ids.into_iter().map(Into::into).collect();
        if keys.is_empty() {
            return Ok(());
        }
        let mut wire_ids = Vec::with_capacity(keys.len());
        for key in &keys {
            let id = self
                .shared
                .subscriptions
                .take_server_id(key)
                .unwrap_or_else(|| key.clone());
            self.shared.subscriptions.unsubscribe(key);
            wire_ids.push(id);
        }
        if !self.is_connected() {
            return Ok(());
        }
        self.enqueue_write(frame_unsubscribe(wire_ids)?)
    }

    /// Returns all recorded subscription requests.
    pub fn subscriptions(&self) -> Vec<SubscribeRequest> {
        self.shared.subscriptions.get_all()
    }

    /// Returns the keys of all recorded subscriptions.
    pub fn subscription_keys(&self) -> Vec<String> {
        self.shared.subscriptions.keys()
    }

    /// Tears down the current supervisor (if any) and connects again from scratch.
    ///
    /// Steps, in order: join the old supervisor, clear the stop flag, reset the
    /// reconnection manager, then delegate to [`connect`](Self::connect).
    ///
    /// # Errors
    ///
    /// Returns `ClientClosed` if the client is closed; otherwise any error from `connect`.
    pub fn reconnect(&self) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        self.stop_supervisor();
        self.shared.should_stop.store(false, Ordering::SeqCst);
        self.shared.reconnection.lock().expect("reconnection lock poisoned").reset();
        self.connect()
    }

    /// Serializes `request` and queues it for the supervisor to write.
    ///
    /// # Errors
    ///
    /// Returns `ClientClosed` if the client is closed. Also returns any framing error, or
    /// `ConnectionError` when no writer queue is installed or the supervisor has exited.
    pub fn send(&self, request: WebSocketRequest) -> Result<(), MarketDataError> {
        if self.is_closed() {
            return Err(MarketDataError::ClientClosed);
        }
        let json = frame_request(&request)?;
        self.enqueue_write(json)
    }

    /// Sets the stop flag, drops the writer sender, and joins the supervisor (if any).
    ///
    /// The `supervisor_handle` lock is held across the join, so a concurrent `connect`
    /// blocks on its handle check until the old supervisor has exited.
    fn stop_supervisor(&self) {
        self.shared.should_stop.store(true, Ordering::SeqCst);
        *self.shared.write_tx_slot.lock().expect("write_tx_slot lock poisoned") = None;
        let mut handle_slot =
            self.supervisor_handle.lock().expect("supervisor handle lock poisoned");
        if let Some(handle) = handle_slot.take() {
            // A panic inside the supervisor is deliberately ignored during shutdown.
            let _ = handle.join();
        }
    }

    /// Pushes a serialized frame onto the supervisor's writer queue.
    ///
    /// The sender is cloned out of `write_tx_slot` first, so the slot lock is not held
    /// while `SyncSender::send` blocks on a full queue.
    fn enqueue_write(&self, json: String) -> Result<(), MarketDataError> {
        let sender = self
            .shared
            .write_tx_slot
            .lock()
            .expect("write_tx_slot lock poisoned")
            .clone();
        let tx = sender.ok_or_else(|| MarketDataError::ConnectionError {
            msg: "Not connected".to_string(),
        })?;
        tx.send(json).map_err(|_| MarketDataError::ConnectionError {
            msg: "Writer queue closed (supervisor exited)".to_string(),
        })
    }

    /// Emits a `ConnectionEvent::Error` carrying `err`'s message and error code.
    fn emit_error(&self, err: &MarketDataError) {
        emit_event(&self.shared.event_tx, ConnectionEvent::Error {
            message: err.to_string(),
            code: err.to_error_code(),
        });
    }

    /// Overwrites the shared connection state.
    fn set_state(&self, new_state: ConnectionState) {
        *self.shared.state.write().expect("state lock poisoned") = new_state;
    }
}
impl Drop for WebSocketClient {
    /// Best-effort shutdown when the client goes out of scope.
    ///
    /// Performs the same teardown as `disconnect` (set the stop flag, drop the writer-queue
    /// sender, join the supervisor thread if one is stored). It does not change the state
    /// or emit events. Poisoned mutexes are recovered instead of unwrapped: a panic inside
    /// `drop` while the thread is already unwinding aborts the whole process.
    fn drop(&mut self) {
        self.shared.should_stop.store(true, Ordering::SeqCst);
        *self
            .shared
            .write_tx_slot
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = None;
        // `&mut self` guarantees exclusive access to our own mutex, so no locking is needed.
        let handle = self
            .supervisor_handle
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(handle) = handle {
            // A panic inside the supervisor is deliberately ignored during shutdown.
            let _ = handle.join();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AuthRequest;

    /// Builds an unconnected stock client authenticated with a dummy API key.
    fn fresh_client() -> WebSocketClient {
        WebSocketClient::new(ConnectionConfig::fugle_stock(AuthRequest::with_api_key("test")))
    }

    /// Subscribes `client` to the trades channel for symbol "2330", failing on error.
    fn subscribe_trades_2330(client: &WebSocketClient) {
        use crate::models::Channel;
        use crate::websocket::channels::StockSubscription;
        client
            .subscribe(StockSubscription::new(Channel::Trades, "2330"))
            .unwrap();
    }

    #[test]
    fn test_new_starts_disconnected() {
        let client = fresh_client();
        assert_eq!(client.state(), ConnectionState::Disconnected);
        assert!(!client.is_closed());
        assert!(!client.is_connected());
    }

    #[test]
    fn test_subscribe_before_connect_records_subscription() {
        let client = fresh_client();
        subscribe_trades_2330(&client);
        assert_eq!(client.subscription_keys().len(), 1);
    }

    #[test]
    fn test_unsubscribe_when_disconnected_removes_state() {
        let client = fresh_client();
        subscribe_trades_2330(&client);
        assert_eq!(client.subscription_keys().len(), 1);

        client.unsubscribe(["trades:2330"]).unwrap();
        assert_eq!(client.subscription_keys().len(), 0);
    }
}