use std::{
    collections::HashSet,
    fmt::Debug,
    sync::{
        Arc, Mutex,
        atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
    },
};
use futures_util::Stream;
use nautilus_common::live::get_runtime;
use nautilus_core::{AtomicMap, string::REDACTED};
use nautilus_model::instruments::{Instrument, InstrumentAny};
use nautilus_network::{
mode::ConnectionMode,
websocket::{
PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
},
};
use tokio_util::sync::CancellationToken;
use ustr::Ustr;
use super::{
super::error::{BinanceWsError, BinanceWsResult},
handler::BinanceSpotWsFeedHandler,
messages::{BinanceSpotWsMessage, BinanceSpotWsStreamsCommand},
subscription::{MAX_CONNECTIONS, MAX_STREAMS_PER_CONNECTION},
};
use crate::common::{
consts::{
BINANCE_API_KEY_HEADER, BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION, BINANCE_SPOT_SBE_WS_URL,
BINANCE_WS_CONNECTION_QUOTA, BINANCE_WS_SUBSCRIPTION_QUOTA,
},
credential::Ed25519Credential,
};
/// One pooled WebSocket connection together with the task that drives its feed handler.
struct ConnectionSlot {
/// Command channel into this slot's feed handler (subscribe / unsubscribe / disconnect).
cmd_tx: tokio::sync::mpsc::UnboundedSender<BinanceSpotWsStreamsCommand>,
/// Streams the client has assigned to this connection; used for capacity accounting
/// when subscribing and for routing unsubscribes.
streams: Vec<String>,
/// Subscription state whose clone is handed to the feed handler at creation
/// (presumably shares the underlying state — TODO confirm); summed by `subscription_count`.
subscriptions_state: SubscriptionState,
/// Handle of the spawned forwarding task; awaited when the client closes.
task_handle: tokio::task::JoinHandle<()>,
/// Token that stops the forwarding task when cancelled.
cancellation_token: CancellationToken,
/// Connection mode published by the underlying `WebSocketClient`, compared against
/// `ConnectionMode` discriminants.
connection_mode: Arc<AtomicU8>,
}
/// Binance Spot SBE market-data WebSocket client that spreads stream subscriptions
/// across a pool of connections.
///
/// Apart from `url` and `heartbeat`, every field lives behind an `Arc`, so clones share
/// the same pool, output channel, request-id counter and instrument cache.
#[derive(Clone)]
pub struct BinanceSpotWebSocketClient {
/// Endpoint used for every connection in the pool.
url: String,
/// Optional credential; when present its API key is sent as a header on each connection.
credential: Option<Arc<Ed25519Credential>>,
/// Heartbeat value forwarded to `WebSocketConfig` (units defined there — TODO confirm).
heartbeat: Option<u64>,
/// Shutdown flag: cleared by `connect`, set by `close`, shared with every feed handler.
signal: Arc<AtomicBool>,
/// Pooled connections.
slots: Arc<Mutex<Vec<ConnectionSlot>>>,
/// Sender half of the merged output channel; cloned into each new connection.
out_tx: Arc<Mutex<Option<tokio::sync::mpsc::UnboundedSender<BinanceSpotWsMessage>>>>,
/// Receiver half of the merged output channel; taken by `stream`.
out_rx: Arc<Mutex<Option<tokio::sync::mpsc::UnboundedReceiver<BinanceSpotWsMessage>>>>,
/// Request-id counter shared by all feed handlers.
request_id_counter: Arc<AtomicU64>,
/// Instrument definitions keyed by symbol.
instruments_cache: Arc<AtomicMap<Ustr, InstrumentAny>>,
}
impl Debug for BinanceSpotWebSocketClient {
    /// Formats the client for diagnostics.
    ///
    /// The URL and heartbeat are shown as-is. The credential is shown only as a redaction
    /// marker, never its contents. Internal pool state is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let redacted_credential = self.credential.as_ref().map(|_| REDACTED);

        let mut builder = f.debug_struct(stringify!(BinanceSpotWebSocketClient));
        builder.field("url", &self.url);
        builder.field("credential", &redacted_credential);
        builder.field("heartbeat", &self.heartbeat);
        builder.finish_non_exhaustive()
    }
}
impl Default for BinanceSpotWebSocketClient {
    /// Creates an unauthenticated client for the default SBE stream URL with no heartbeat.
    ///
    /// # Panics
    ///
    /// Not in practice. `new` can only fail while building a credential, and none is
    /// requested here.
    fn default() -> Self {
        Self::new(None, None, None, None)
            .expect("constructing a client without credentials is infallible")
    }
}
impl BinanceSpotWebSocketClient {
    /// Creates a new, unconnected client.
    ///
    /// `url` defaults to `BINANCE_SPOT_SBE_WS_URL`. A credential is built only when both
    /// `api_key` and `api_secret` are supplied. If only one of them is given, the client is
    /// created without a credential.
    ///
    /// # Errors
    ///
    /// Returns an error if the Ed25519 credential cannot be built from the key and secret.
    pub fn new(
        url: Option<String>,
        api_key: Option<String>,
        api_secret: Option<String>,
        heartbeat: Option<u64>,
    ) -> anyhow::Result<Self> {
        let url = url.unwrap_or_else(|| BINANCE_SPOT_SBE_WS_URL.to_string());
        let credential = match (api_key, api_secret) {
            (Some(key), Some(secret)) => Some(Arc::new(Ed25519Credential::new(key, &secret)?)),
            _ => None,
        };

        Ok(Self {
            url,
            credential,
            heartbeat,
            signal: Arc::new(AtomicBool::new(false)),
            slots: Arc::new(Mutex::new(Vec::new())),
            out_tx: Arc::new(Mutex::new(None)),
            out_rx: Arc::new(Mutex::new(None)),
            request_id_counter: Arc::new(AtomicU64::new(1)),
            instruments_cache: Arc::new(AtomicMap::new()),
        })
    }

    /// Returns `true` if at least one pooled connection reports `ConnectionMode::Active`.
    #[must_use]
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub fn is_active(&self) -> bool {
        let slots = self.slots.lock().expect("slots lock poisoned");
        slots
            .iter()
            .any(|s| s.connection_mode.load(Ordering::Relaxed) == ConnectionMode::Active as u8)
    }

    /// Returns `true` if the pool is empty or every pooled connection reports
    /// `ConnectionMode::Closed`.
    #[must_use]
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub fn is_closed(&self) -> bool {
        let slots = self.slots.lock().expect("slots lock poisoned");
        slots.is_empty()
            || slots
                .iter()
                .all(|s| s.connection_mode.load(Ordering::Relaxed) == ConnectionMode::Closed as u8)
    }

    /// Returns the sum of every slot's `SubscriptionState::len`.
    #[must_use]
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub fn subscription_count(&self) -> usize {
        let slots = self.slots.lock().expect("slots lock poisoned");
        slots.iter().map(|s| s.subscriptions_state.len()).sum()
    }

    /// Opens the merged output channel and the first pooled connection.
    ///
    /// Clears the shutdown signal, installs a fresh output channel (replacing any previous
    /// one), then creates one connection and adds it to the pool. [`Self::subscribe`]
    /// opens further connections on demand.
    ///
    /// NOTE(review): calling this while slots already exist keeps those slots. Their
    /// forwarding tasks still hold the previous output sender, so confirm that callers
    /// `close` first.
    ///
    /// # Errors
    ///
    /// Returns an error if the initial connection cannot be created.
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub async fn connect(&mut self) -> BinanceWsResult<()> {
        self.signal.store(false, Ordering::Relaxed);

        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
        *self.out_tx.lock().expect("out_tx lock poisoned") = Some(out_tx);
        *self.out_rx.lock().expect("out_rx lock poisoned") = Some(out_rx);

        let slot = self.create_connection().await?;
        self.slots.lock().expect("slots lock poisoned").push(slot);

        log::info!(
            "Connected to Binance Spot SBE stream pool: url={}",
            self.url
        );
        Ok(())
    }

    /// Shuts down every pooled connection and drops the output channel.
    ///
    /// Sets the shutdown signal, then for each slot cancels its task, sends a best-effort
    /// `Disconnect` command (send failures are ignored) and waits for the task to finish.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub async fn close(&mut self) -> BinanceWsResult<()> {
        self.signal.store(true, Ordering::Relaxed);

        // Take the slots out under the lock so it is not held across the awaits below.
        let slots: Vec<ConnectionSlot> = {
            let mut guard = self.slots.lock().expect("slots lock poisoned");
            guard.drain(..).collect()
        };

        for slot in slots {
            slot.cancellation_token.cancel();
            let _ = slot.cmd_tx.send(BinanceSpotWsStreamsCommand::Disconnect);
            let _ = slot.task_handle.await;
        }

        *self.out_tx.lock().expect("out_tx lock poisoned") = None;
        *self.out_rx.lock().expect("out_rx lock poisoned") = None;

        log::info!("Disconnected from Binance Spot SBE stream pool");
        Ok(())
    }

    /// Subscribes to `streams`, distributing them across the connection pool.
    ///
    /// Streams already assigned to a slot, and duplicates within `streams`, are ignored.
    /// New connections are opened until the pool has room for every new stream or
    /// `MAX_CONNECTIONS` is reached. Each stream then goes to the first slot with free
    /// capacity. Assignment is all-or-nothing: if the pool cannot hold every new stream,
    /// no command is sent.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - a new connection cannot be created;
    /// - the pool is exhausted;
    /// - a slot's handler no longer accepts commands. Batches for earlier slots may
    ///   already have been sent in that case.
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub async fn subscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
        let new_streams = {
            let slots = self.slots.lock().expect("slots lock poisoned");
            Self::unassigned_streams(&slots, streams)
        };

        if new_streams.is_empty() {
            return Ok(());
        }

        // Grow the pool until there is room for every new stream or the connection cap is
        // reached. The lock is released around `create_connection` because it awaits.
        loop {
            let (remaining_capacity, slot_count) = {
                let slots = self.slots.lock().expect("slots lock poisoned");
                let cap: usize = slots
                    .iter()
                    .map(|s| MAX_STREAMS_PER_CONNECTION.saturating_sub(s.streams.len()))
                    .sum();
                (cap, slots.len())
            };

            if remaining_capacity >= new_streams.len() || slot_count >= MAX_CONNECTIONS {
                break;
            }

            let new_slot = self.create_connection().await?;
            let slot_count = {
                let mut slots = self.slots.lock().expect("slots lock poisoned");
                slots.push(new_slot);
                slots.len()
            };
            log::info!("Pool slot {} connected: url={}", slot_count - 1, self.url);
        }

        let mut slots = self.slots.lock().expect("slots lock poisoned");

        // Re-filter under the lock. A concurrent `subscribe` may have assigned some of these
        // streams while the lock was released during `create_connection`.
        let new_streams = Self::unassigned_streams(&slots, new_streams);
        if new_streams.is_empty() {
            return Ok(());
        }

        // Assign every stream before sending anything, so that a pool-exhausted error leaves
        // all slots untouched.
        let mut slot_counts: Vec<usize> = slots.iter().map(|s| s.streams.len()).collect();
        let mut batches: Vec<Vec<String>> = vec![Vec::new(); slots.len()];
        for stream in new_streams {
            let slot_idx = slot_counts
                .iter()
                .position(|&count| count < MAX_STREAMS_PER_CONNECTION)
                .ok_or_else(|| {
                    let max_total = MAX_CONNECTIONS * MAX_STREAMS_PER_CONNECTION;
                    BinanceWsError::ClientError(format!(
                        "Pool exhausted: {max_total} total subscriptions \
                        ({MAX_CONNECTIONS} connections x {MAX_STREAMS_PER_CONNECTION} streams)"
                    ))
                })?;
            slot_counts[slot_idx] += 1;
            batches[slot_idx].push(stream);
        }

        for (slot_idx, batch) in batches.into_iter().enumerate() {
            if batch.is_empty() {
                continue;
            }
            let slot = &mut slots[slot_idx];
            slot.cmd_tx
                .send(BinanceSpotWsStreamsCommand::Subscribe {
                    streams: batch.clone(),
                })
                .map_err(|e| {
                    BinanceWsError::ClientError(format!(
                        "Handler not available for pool slot {slot_idx}: {e}"
                    ))
                })?;
            slot.streams.extend(batch);
        }

        Ok(())
    }

    /// Unsubscribes from `streams`, sending one `Unsubscribe` command per affected slot.
    ///
    /// Streams not assigned to any slot, and duplicates within `streams`, are ignored.
    /// Slots left without streams stay open.
    ///
    /// # Errors
    ///
    /// Returns an error if a slot's handler no longer accepts commands. Batches for
    /// earlier slots may already have been sent and removed in that case.
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub async fn unsubscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
        let mut slots = self.slots.lock().expect("slots lock poisoned");

        let mut batches: Vec<Vec<String>> = vec![Vec::new(); slots.len()];
        let mut seen: HashSet<&str> = HashSet::with_capacity(streams.len());
        for stream in &streams {
            if !seen.insert(stream.as_str()) {
                continue;
            }
            if let Some(slot_idx) = slots.iter().position(|s| s.streams.contains(stream)) {
                batches[slot_idx].push(stream.clone());
            }
        }

        for (slot_idx, batch) in batches.into_iter().enumerate() {
            if batch.is_empty() {
                continue;
            }
            let slot = &mut slots[slot_idx];
            slot.cmd_tx
                .send(BinanceSpotWsStreamsCommand::Unsubscribe {
                    streams: batch.clone(),
                })
                .map_err(|e| {
                    BinanceWsError::ClientError(format!(
                        "Handler not available for pool slot {slot_idx}: {e}"
                    ))
                })?;
            let removed: HashSet<&str> = batch.iter().map(String::as_str).collect();
            slot.streams.retain(|s| !removed.contains(s.as_str()));
        }

        Ok(())
    }

    /// Returns the merged output of all pooled connections.
    ///
    /// The receiver is taken on the first call after [`Self::connect`]. Later calls, or
    /// calls made before `connect`, return a stream that ends immediately.
    #[allow(clippy::missing_panics_doc, reason = "mutex poisoning is not expected")]
    pub fn stream(&self) -> impl Stream<Item = BinanceSpotWsMessage> + 'static {
        let out_rx = self.out_rx.lock().expect("out_rx lock poisoned").take();
        async_stream::stream! {
            if let Some(mut rx) = out_rx {
                while let Some(msg) = rx.recv().await {
                    yield msg;
                }
            }
        }
    }

    /// Inserts `instruments` into the shared cache, keyed by symbol, in a single
    /// `AtomicMap::rcu` update.
    pub fn cache_instruments(&self, instruments: &[InstrumentAny]) {
        self.instruments_cache.rcu(|m| {
            for inst in instruments {
                m.insert(inst.symbol().inner(), inst.clone());
            }
        });
    }

    /// Inserts (or replaces) a single instrument in the shared cache, keyed by symbol.
    pub fn cache_instrument(&self, instrument: InstrumentAny) {
        self.instruments_cache
            .insert(instrument.symbol().inner(), instrument);
    }

    /// Returns a handle to the shared instrument cache.
    #[must_use]
    pub fn instruments_cache(&self) -> Arc<AtomicMap<Ustr, InstrumentAny>> {
        self.instruments_cache.clone()
    }

    /// Returns a clone of the cached instrument for `symbol`, if any.
    #[must_use]
    pub fn get_instrument(&self, symbol: &str) -> Option<InstrumentAny> {
        self.instruments_cache.get_cloned(&Ustr::from(symbol))
    }

    /// Returns the entries of `streams` that are not assigned to any slot, keeping their
    /// original order and dropping duplicates.
    fn unassigned_streams(slots: &[ConnectionSlot], streams: Vec<String>) -> Vec<String> {
        let assigned: HashSet<&str> = slots
            .iter()
            .flat_map(|slot| slot.streams.iter().map(String::as_str))
            .collect();
        let mut seen: HashSet<String> = HashSet::with_capacity(streams.len());
        streams
            .into_iter()
            .filter(|s| !assigned.contains(s.as_str()) && seen.insert(s.clone()))
            .collect()
    }

    /// Opens one WebSocket connection and spawns the task that forwards its handler output.
    ///
    /// The spawned task forwards every handler message to the shared output channel. On
    /// `Reconnected`, it first marks all tracked topics as failed and re-sends them as one
    /// `Subscribe` command, then forwards the event. The task exits on cancellation, when
    /// the output channel is closed, or when the handler stream ends.
    ///
    /// # Errors
    ///
    /// Returns `ClientError` if the output channel has not been created by `connect`, or if
    /// the client cannot be handed to the handler. Returns `NetworkError` if the WebSocket
    /// connection fails.
    async fn create_connection(&self) -> BinanceWsResult<ConnectionSlot> {
        let out_tx = self
            .out_tx
            .lock()
            .expect("out_tx lock poisoned")
            .clone()
            .ok_or_else(|| {
                BinanceWsError::ClientError("Output channel not initialized".to_string())
            })?;

        let (raw_handler, raw_rx) = channel_message_handler();
        // No-op: this client takes no action of its own on pings.
        let ping_handler: PingHandler = Arc::new(|_| {});

        let headers = if let Some(ref cred) = self.credential {
            vec![(
                BINANCE_API_KEY_HEADER.to_string(),
                cred.api_key().to_string(),
            )]
        } else {
            vec![]
        };

        let config = WebSocketConfig {
            url: self.url.clone(),
            headers,
            heartbeat: self.heartbeat,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(500),
            reconnect_delay_max_ms: Some(5_000),
            reconnect_backoff_factor: Some(2.0),
            reconnect_jitter_ms: Some(250),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
        };

        let keyed_quotas = vec![(
            BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION[0].as_str().to_string(),
            *BINANCE_WS_SUBSCRIPTION_QUOTA,
        )];

        let client = WebSocketClient::connect(
            config,
            Some(raw_handler),
            Some(ping_handler),
            None,
            keyed_quotas,
            Some(*BINANCE_WS_CONNECTION_QUOTA),
        )
        .await
        .map_err(|e| {
            log::error!("WebSocket connection failed: {e}");
            BinanceWsError::NetworkError(e.to_string())
        })?;

        let connection_mode = client.connection_mode_atomic();
        let subscriptions_state = SubscriptionState::new('@');
        let cancellation_token = CancellationToken::new();
        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();

        let mut handler = BinanceSpotWsFeedHandler::new(
            self.signal.clone(),
            cmd_rx,
            raw_rx,
            out_tx.clone(),
            subscriptions_state.clone(),
            self.request_id_counter.clone(),
        );

        cmd_tx
            .send(BinanceSpotWsStreamsCommand::SetClient(client))
            .map_err(|e| BinanceWsError::ClientError(format!("Failed to set client: {e}")))?;

        let signal = self.signal.clone();
        let token = cancellation_token.clone();
        let subs = subscriptions_state.clone();
        let resubscribe_tx = cmd_tx.clone();

        let task_handle = get_runtime().spawn(async move {
            loop {
                tokio::select! {
                    () = token.cancelled() => {
                        log::debug!("Handler task cancelled");
                        break;
                    }
                    result = handler.next() => {
                        match result {
                            Some(BinanceSpotWsMessage::Reconnected) => {
                                log::info!("WebSocket reconnected, restoring subscriptions");
                                // Mark every tracked topic as failed, then re-send the full
                                // topic list as a single Subscribe command.
                                let all_topics = subs.all_topics();
                                for topic in &all_topics {
                                    subs.mark_failure(topic);
                                }
                                let streams = subs.all_topics();
                                if !streams.is_empty()
                                    && let Err(e) = resubscribe_tx
                                        .send(BinanceSpotWsStreamsCommand::Subscribe { streams })
                                {
                                    log::error!("Failed to resubscribe after reconnect: {e}");
                                }
                                if out_tx.send(BinanceSpotWsMessage::Reconnected).is_err() {
                                    log::debug!("Output channel closed");
                                    break;
                                }
                            }
                            Some(msg) => {
                                if out_tx.send(msg).is_err() {
                                    log::debug!("Output channel closed");
                                    break;
                                }
                            }
                            None => {
                                if signal.load(Ordering::Relaxed) {
                                    log::debug!("Handler received shutdown signal");
                                } else {
                                    log::warn!("Handler loop ended unexpectedly");
                                }
                                break;
                            }
                        }
                    }
                }
            }
        });

        Ok(ConnectionSlot {
            cmd_tx,
            streams: Vec::new(),
            subscriptions_state,
            task_handle,
            cancellation_token,
            connection_mode,
        })
    }
}