use std::{
collections::VecDeque,
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
};
use ahash::AHashMap;
use nautilus_network::{
RECONNECTED,
websocket::{SubscriptionState, WebSocketClient},
};
use tokio_tungstenite::tungstenite::Message;
use ustr::Ustr;
pub use super::parse::{MarketDataMessage, decode_market_data};
use super::{
messages::{
BinanceSpotServerShutdownMsg, BinanceSpotWsMessage, BinanceSpotWsStreamsCommand,
BinanceWsErrorMsg, BinanceWsErrorResponse, BinanceWsResponse, BinanceWsSubscription,
},
parse::decode_market_data as decode_sbe,
};
use crate::common::consts::BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION;
/// Feed handler that turns commands and raw WebSocket frames into
/// [`BinanceSpotWsMessage`] values, one per call to `next()`.
///
/// It owns the active WebSocket client, tracks subscription state, and
/// correlates subscribe request ids with the streams they covered.
pub(super) struct BinanceSpotWsFeedHandler {
    // Stored but never read by the handler methods in this file.
    // NOTE(review): presumably reserved for a shutdown signal — confirm.
    #[allow(dead_code)]
    signal: Arc<AtomicBool>,
    /// Active client used for outbound requests; `None` until a `SetClient`
    /// command arrives and again after `Disconnect`.
    inner: Option<WebSocketClient>,
    /// Control commands (set client, disconnect, subscribe, unsubscribe).
    cmd_rx: tokio::sync::mpsc::UnboundedReceiver<BinanceSpotWsStreamsCommand>,
    /// Raw frames from the network layer, including the `RECONNECTED` marker.
    raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
    // Stored but not sent on in this file: messages are returned from `next()`
    // instead. NOTE(review): kept for the owner's wiring — confirm if still needed.
    #[allow(dead_code)]
    out_tx: tokio::sync::mpsc::UnboundedSender<BinanceSpotWsMessage>,
    /// Per-stream subscription state (pending / confirmed / failed).
    subscriptions: SubscriptionState,
    /// Shared source of request ids for SUBSCRIBE/UNSUBSCRIBE payloads.
    request_id_counter: Arc<AtomicU64>,
    /// Messages decoded from a frame but not yet returned by `next()`.
    pending_messages: VecDeque<BinanceSpotWsMessage>,
    /// In-flight subscribe requests, keyed by request id, mapped to their streams.
    pending_requests: AHashMap<u64, Vec<String>>,
}
impl BinanceSpotWsFeedHandler {
    /// Creates a handler with no active WebSocket client.
    ///
    /// The client is supplied later through a
    /// [`BinanceSpotWsStreamsCommand::SetClient`] command on `cmd_rx`.
    pub(super) fn new(
        signal: Arc<AtomicBool>,
        cmd_rx: tokio::sync::mpsc::UnboundedReceiver<BinanceSpotWsStreamsCommand>,
        raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
        out_tx: tokio::sync::mpsc::UnboundedSender<BinanceSpotWsMessage>,
        subscriptions: SubscriptionState,
        request_id_counter: Arc<AtomicU64>,
    ) -> Self {
        Self {
            signal,
            inner: None,
            cmd_rx,
            raw_rx,
            out_tx,
            subscriptions,
            request_id_counter,
            pending_messages: VecDeque::new(),
            pending_requests: AHashMap::new(),
        }
    }

    /// Returns the next message produced by the handler.
    ///
    /// Messages buffered from an earlier frame are returned first. Otherwise the
    /// handler waits on both the command channel and the raw frame channel:
    /// - commands are applied in place and do not yield a message, except
    ///   `Disconnect`, which ends the stream;
    /// - the `RECONNECTED` marker yields [`BinanceSpotWsMessage::Reconnected`];
    /// - other frames are decoded, the first resulting message is returned and
    ///   any remainder is buffered for later calls.
    ///
    /// Returns `None` after `Disconnect`, or once both channels are closed.
    pub(super) async fn next(&mut self) -> Option<BinanceSpotWsMessage> {
        if let Some(message) = self.pending_messages.pop_front() {
            return Some(message);
        }

        loop {
            tokio::select! {
                Some(cmd) = self.cmd_rx.recv() => {
                    match cmd {
                        BinanceSpotWsStreamsCommand::SetClient(client) => {
                            log::debug!("Handler received WebSocket client");
                            self.inner = Some(client);
                        }
                        BinanceSpotWsStreamsCommand::Disconnect => {
                            log::debug!("Handler disconnecting WebSocket client");
                            self.inner = None;
                            return None;
                        }
                        BinanceSpotWsStreamsCommand::Subscribe { streams } => {
                            if let Err(e) = self.handle_subscribe(streams).await {
                                log::error!("Failed to handle subscribe command: {e}");
                            }
                        }
                        BinanceSpotWsStreamsCommand::Unsubscribe { streams } => {
                            if let Err(e) = self.handle_unsubscribe(streams).await {
                                log::error!("Failed to handle unsubscribe command: {e}");
                            }
                        }
                    }
                }
                Some(msg) = self.raw_rx.recv() => {
                    if let Message::Text(ref text) = msg
                        && text.as_str() == RECONNECTED
                    {
                        log::info!("Handler received reconnection signal");
                        return Some(BinanceSpotWsMessage::Reconnected);
                    }

                    // One frame may decode into several messages: hand back the
                    // first now and buffer the rest for subsequent calls.
                    let mut messages = self.handle_message(msg).into_iter();
                    if let Some(first) = messages.next() {
                        self.pending_messages.extend(messages);
                        return Some(first);
                    }
                }
                // Both channels closed (every branch pattern failed).
                else => {
                    return None;
                }
            }
        }
    }

    /// Dispatches a raw frame to the binary (SBE) or text (JSON) decoder.
    ///
    /// Close, ping, pong and raw frames produce no messages.
    fn handle_message(&mut self, msg: Message) -> Vec<BinanceSpotWsMessage> {
        match msg {
            Message::Binary(data) => self.handle_binary_frame(&data),
            Message::Text(text) => self.handle_text_frame(&text),
            Message::Close(_) => {
                log::debug!("Received close frame");
                vec![]
            }
            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => vec![],
        }
    }

    /// Decodes an SBE market-data frame into its matching message variant.
    ///
    /// On a decode error the error is logged and the undecoded bytes are
    /// forwarded as [`BinanceSpotWsMessage::RawBinary`].
    fn handle_binary_frame(&self, data: &[u8]) -> Vec<BinanceSpotWsMessage> {
        match decode_sbe(data) {
            Ok(MarketDataMessage::Trades(event)) => {
                vec![BinanceSpotWsMessage::Trades(event)]
            }
            Ok(MarketDataMessage::BestBidAsk(event)) => {
                vec![BinanceSpotWsMessage::BestBidAsk(event)]
            }
            Ok(MarketDataMessage::DepthSnapshot(event)) => {
                vec![BinanceSpotWsMessage::DepthSnapshot(event)]
            }
            Ok(MarketDataMessage::DepthDiff(event)) => {
                vec![BinanceSpotWsMessage::DepthDiff(event)]
            }
            Err(e) => {
                log::error!("SBE decode error: {e}");
                vec![BinanceSpotWsMessage::RawBinary(data.to_vec())]
            }
        }
    }

    /// Decodes a JSON text frame.
    ///
    /// Parsing is attempted in order:
    /// 1. An error response. If its id matches an in-flight subscribe request,
    ///    those streams are marked failed. An `Error` message is always emitted.
    /// 2. A request response. It confirms or fails the matching subscribe
    ///    request and emits nothing.
    /// 3. Anything else is classified by [`classify_unsolicited_json`].
    fn handle_text_frame(&mut self, text: &str) -> Vec<BinanceSpotWsMessage> {
        if let Ok(error) = serde_json::from_str::<BinanceWsErrorResponse>(text) {
            if let Some(id) = error.id
                && let Some(streams) = self.pending_requests.remove(&id)
            {
                for stream in &streams {
                    self.subscriptions.mark_failure(stream);
                }
                log::warn!(
                    "Subscription request failed: id={id}, streams={streams:?}, code={}, msg={}",
                    error.code,
                    error.msg
                );
            }
            return vec![BinanceSpotWsMessage::Error(BinanceWsErrorMsg {
                code: error.code,
                msg: error.msg,
            })];
        }

        if let Ok(response) = serde_json::from_str::<BinanceWsResponse>(text) {
            self.handle_subscription_response(&response);
            return vec![];
        }

        classify_unsolicited_json(text)
    }

    /// Resolves an in-flight subscribe request from its response.
    ///
    /// A `None` result confirms every stream in the request. Any other result
    /// marks them failed. Responses whose id is not tracked (including replies
    /// to unsubscribe requests, which are never registered) are only logged.
    fn handle_subscription_response(&mut self, response: &BinanceWsResponse) {
        if let Some(streams) = self.pending_requests.remove(&response.id) {
            if response.result.is_none() {
                for stream in &streams {
                    self.subscriptions.confirm_subscribe(stream);
                }
                log::debug!("Subscription confirmed: streams={streams:?}");
            } else {
                for stream in &streams {
                    self.subscriptions.mark_failure(stream);
                }
                log::warn!(
                    "Subscription failed: streams={streams:?}, result={:?}",
                    response.result
                );
            }
        } else {
            log::debug!("Received response for unknown request: id={}", response.id);
        }
    }

    /// Sends a SUBSCRIBE request for `streams` and tracks it until the server replies.
    ///
    /// The streams are marked as pending subscriptions, and the request id is
    /// recorded so the response can confirm or fail them.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be serialized (nothing is recorded
    /// in that case). It also returns an error if the send fails; then the
    /// request id is dropped again, because no response can arrive for it.
    async fn handle_subscribe(&mut self, streams: Vec<String>) -> anyhow::Result<()> {
        let request_id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
        let request = BinanceWsSubscription::subscribe(streams.clone(), request_id);
        let payload = serde_json::to_string(&request)?;

        for stream in &streams {
            self.subscriptions.mark_subscribe(stream);
        }
        self.pending_requests.insert(request_id, streams);

        if let Err(e) = self
            .send_text(payload, Some(BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION.as_slice()))
            .await
        {
            // The request never reached the server, so its id would otherwise sit
            // in `pending_requests` forever. The streams keep their pending mark,
            // as before. NOTE(review): presumably so they are retried on
            // (re)connect — confirm against SubscriptionState usage.
            self.pending_requests.remove(&request_id);
            return Err(e);
        }

        Ok(())
    }

    /// Sends an UNSUBSCRIBE request for `streams`.
    ///
    /// Once the send succeeds, the streams are immediately marked as
    /// unsubscribed and confirmed locally. No pending request is registered, so
    /// the server's reply is handled as a response for an unknown request.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization or the send fails. Subscription state
    /// is left untouched in that case.
    async fn handle_unsubscribe(&self, streams: Vec<String>) -> anyhow::Result<()> {
        let request_id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
        let request = BinanceWsSubscription::unsubscribe(streams.clone(), request_id);
        let payload = serde_json::to_string(&request)?;

        self.send_text(
            payload,
            Some(BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION.as_slice()),
        )
        .await?;

        for stream in &streams {
            self.subscriptions.mark_unsubscribe(stream);
            self.subscriptions.confirm_unsubscribe(stream);
        }

        Ok(())
    }

    /// Sends a text payload through the active client, applying `rate_limit_keys`.
    ///
    /// # Errors
    ///
    /// Returns an error if no client is set or the client reports a send failure.
    async fn send_text(
        &self,
        payload: String,
        rate_limit_keys: Option<&[Ustr]>,
    ) -> anyhow::Result<()> {
        let Some(client) = &self.inner else {
            anyhow::bail!("No active WebSocket client");
        };

        client
            .send_text(payload, rate_limit_keys)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to send message: {e}"))?;

        Ok(())
    }
}
/// Classifies a JSON text frame that is neither an error nor a request response.
///
/// Returns:
/// - [`BinanceSpotWsMessage::ServerShutdown`] when the payload has
///   `"e": "serverShutdown"` and deserializes into a shutdown message;
/// - [`BinanceSpotWsMessage::RawJson`] for any other valid JSON, including a
///   shutdown event that fails to deserialize;
/// - an empty vector, after logging a warning, when `text` is not valid JSON.
fn classify_unsolicited_json(text: &str) -> Vec<BinanceSpotWsMessage> {
    let value = match serde_json::from_str::<serde_json::Value>(text) {
        Ok(parsed) => parsed,
        Err(_) => {
            log::warn!("Failed to parse JSON message: {text}");
            return Vec::new();
        }
    };

    // Only pay for the clone + typed decode when the event tag matches.
    let shutdown = value
        .get("e")
        .and_then(serde_json::Value::as_str)
        .filter(|event| *event == "serverShutdown")
        .and_then(|_| serde_json::from_value::<BinanceSpotServerShutdownMsg>(value.clone()).ok());

    match shutdown {
        Some(notice) => {
            log::warn!(
                "Binance server shutdown notice received (event_time={}); disconnect expected ~10 minutes from event",
                notice.event_time,
            );
            vec![BinanceSpotWsMessage::ServerShutdown(notice)]
        }
        None => vec![BinanceSpotWsMessage::RawJson(value)],
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64},
    };

    use nautilus_network::websocket::SubscriptionState;
    use rstest::rstest;

    use super::*;

    /// Builds a handler over throwaway channels, with the request-id counter
    /// starting at 2, for driving the synchronous frame handlers directly.
    fn make_handler() -> BinanceSpotWsFeedHandler {
        let (_cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (_raw_tx, raw_rx) = tokio::sync::mpsc::unbounded_channel();
        let (out_tx, _out_rx) = tokio::sync::mpsc::unbounded_channel();

        BinanceSpotWsFeedHandler::new(
            Arc::new(AtomicBool::new(false)),
            cmd_rx,
            raw_rx,
            out_tx,
            SubscriptionState::new('@'),
            Arc::new(AtomicU64::new(2)),
        )
    }

    #[rstest]
    fn test_classify_unsolicited_json_server_shutdown_emits_variant() {
        let emitted = classify_unsolicited_json(r#"{"e":"serverShutdown","E":1700000000000}"#);
        assert_eq!(emitted.len(), 1);

        let BinanceSpotWsMessage::ServerShutdown(notice) = &emitted[0] else {
            panic!("expected ServerShutdown variant, was {:?}", emitted[0]);
        };
        assert_eq!(notice.event_type, "serverShutdown");
        assert_eq!(notice.event_time, 1_700_000_000_000);
    }

    #[rstest]
    fn test_classify_unsolicited_json_unrelated_emits_raw_json() {
        let emitted = classify_unsolicited_json(r#"{"e":"trade","p":"50000"}"#);
        assert_eq!(emitted.len(), 1);
        assert!(matches!(
            emitted.first(),
            Some(BinanceSpotWsMessage::RawJson(_))
        ));
    }

    #[rstest]
    fn test_classify_unsolicited_json_invalid_returns_empty() {
        assert!(classify_unsolicited_json("not json").is_empty());
    }

    #[rstest]
    fn test_handle_text_frame_error_with_id_emits_error_and_clears_pending_request() {
        let mut handler = make_handler();
        handler
            .pending_requests
            .insert(1, vec!["btcusdt@trade".to_string()]);

        let emitted = handler.handle_text_frame(r#"{"code":2,"msg":"Invalid request","id":1}"#);
        assert_eq!(emitted.len(), 1);

        let BinanceSpotWsMessage::Error(err) = &emitted[0] else {
            panic!("expected Error variant, was {:?}", emitted[0]);
        };
        assert_eq!(err.code, 2);
        assert_eq!(err.msg, "Invalid request");
        assert!(handler.pending_requests.is_empty());
    }
}