#![allow(dead_code)]
use crate::binance::Binance;
use ccxt_core::error::{Error, Result};
use ccxt_core::types::OrderBook;
use ccxt_core::types::financial::{Amount, Price};
use ccxt_core::types::orderbook::{OrderBookDelta, OrderBookEntry};
use rust_decimal::Decimal;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
/// State moved into the spawned message-loop task (built in `MessageRouter::start`).
///
/// Every shared field is an `Arc` clone of the corresponding `MessageRouter` field, so the
/// loop and the router observe the same client, flags and counters.
struct MessageLoopContext {
/// Slot holding the live client; `None` while the loop is between connections.
ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
/// Routing target for incoming payloads and source of streams to re-subscribe after reconnect.
subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
/// The loop keeps running while this is `true`; it clears it itself when it gives up.
is_connected: Arc<std::sync::atomic::AtomicBool>,
/// Retry policy (`should_retry` / `calculate_delay`) consulted before each reconnect attempt.
reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
/// Shared counter used to assign ids to outgoing subscribe requests.
request_id: Arc<std::sync::atomic::AtomicU64>,
/// Optional listen-key manager used when rebuilding the user-data URL on reconnect.
listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
/// The router's configured URL (`MessageRouter::ws_url`).
base_url: String,
/// The URL actually connected by `start` (the override, if one was given).
current_url: String,
}
/// Owns a Binance WebSocket connection and routes incoming frames to subscribers.
///
/// `start` connects and spawns a background message loop; `stop` aborts it and closes
/// the socket. Fields wrapped in `Arc` are shared with that loop.
pub struct MessageRouter {
/// Slot holding the live client; `None` when not connected.
ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
/// Receives routed payloads; also tracks which streams are active.
subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
/// Join handle of the spawned message loop, if one has been started.
router_task: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
/// `true` while the router considers itself connected; the loop exits when it is cleared.
is_connected: Arc<std::sync::atomic::AtomicBool>,
/// Held for the duration of `start` so concurrent starts are serialised.
connection_lock: Arc<Mutex<()>>,
/// Retry policy used by the message loop when the connection drops.
reconnect_config: Arc<tokio::sync::RwLock<super::subscriptions::ReconnectConfig>>,
/// Optional listen-key manager for user-data streams.
listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
/// Default endpoint used by `start` when no URL override is supplied.
ws_url: String,
/// Monotonic id source for SUBSCRIBE/UNSUBSCRIBE requests (starts at 1 in `new`).
request_id: Arc<std::sync::atomic::AtomicU64>,
}
impl MessageRouter {
pub fn new(
ws_url: String,
subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
) -> Self {
Self {
ws_client: Arc::new(tokio::sync::RwLock::new(None)),
subscription_manager,
router_task: Arc::new(Mutex::new(None)),
is_connected: Arc::new(std::sync::atomic::AtomicBool::new(false)),
connection_lock: Arc::new(Mutex::new(())),
reconnect_config: Arc::new(tokio::sync::RwLock::new(
super::subscriptions::ReconnectConfig::default(),
)),
listen_key_manager,
ws_url,
request_id: Arc::new(std::sync::atomic::AtomicU64::new(1)),
}
}
pub async fn start(&self, url_override: Option<String>) -> Result<()> {
let _lock = self.connection_lock.lock().await;
if self.is_connected() {
if url_override.is_some() {
self.stop().await?;
} else {
return Ok(());
}
}
let url = url_override.unwrap_or_else(|| self.ws_url.clone());
let config = ccxt_core::ws_client::WsConfig {
url: url.clone(),
..Default::default()
};
let client = ccxt_core::ws_client::WsClient::new(config);
client.connect().await?;
*self.ws_client.write().await = Some(client);
self.is_connected
.store(true, std::sync::atomic::Ordering::SeqCst);
let ws_client = self.ws_client.clone();
let subscription_manager = self.subscription_manager.clone();
let is_connected = self.is_connected.clone();
let reconnect_config = self.reconnect_config.clone();
let request_id = self.request_id.clone();
let listen_key_manager = self.listen_key_manager.clone();
let ws_url = self.ws_url.clone();
let current_url = url;
let ctx = MessageLoopContext {
ws_client,
subscription_manager,
is_connected,
reconnect_config,
request_id,
listen_key_manager,
base_url: ws_url,
current_url,
};
let handle = tokio::spawn(async move {
Self::message_loop(ctx).await;
});
*self.router_task.lock().await = Some(handle);
Ok(())
}
pub async fn stop(&self) -> Result<()> {
self.is_connected
.store(false, std::sync::atomic::Ordering::SeqCst);
let mut task_opt = self.router_task.lock().await;
if let Some(handle) = task_opt.take() {
handle.abort();
}
let mut client_opt = self.ws_client.write().await;
if let Some(client) = client_opt.take() {
let _ = client.disconnect().await;
}
Ok(())
}
pub async fn restart(&self) -> Result<()> {
self.stop().await?;
tokio::time::sleep(Duration::from_millis(100)).await;
self.start(None).await
}
pub fn get_url(&self) -> String {
self.ws_url.clone()
}
pub fn is_connected(&self) -> bool {
self.is_connected.load(std::sync::atomic::Ordering::SeqCst)
}
pub fn latency(&self) -> Option<i64> {
if let Ok(guard) = self.ws_client.try_read() {
if let Some(ref client) = *guard {
return client.latency();
}
}
None
}
pub fn reconnect_count(&self) -> u32 {
if let Ok(guard) = self.ws_client.try_read() {
if let Some(ref client) = *guard {
return client.reconnect_count();
}
}
0
}
pub async fn set_reconnect_config(&self, config: super::subscriptions::ReconnectConfig) {
*self.reconnect_config.write().await = config;
}
pub async fn get_reconnect_config(&self) -> super::subscriptions::ReconnectConfig {
self.reconnect_config.read().await.clone()
}
pub async fn subscribe(&self, streams: Vec<String>) -> Result<()> {
if streams.is_empty() {
return Ok(());
}
let client_opt = self.ws_client.read().await;
let client = client_opt
.as_ref()
.ok_or_else(|| Error::network("WebSocket not connected"))?;
let id = self
.request_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
#[allow(clippy::disallowed_methods)]
let request = serde_json::json!({
"method": "SUBSCRIBE",
"params": streams,
"id": id
});
client
.send(tokio_tungstenite::tungstenite::protocol::Message::Text(
request.to_string().into(),
))
.await?;
Ok(())
}
pub async fn unsubscribe(&self, streams: Vec<String>) -> Result<()> {
if streams.is_empty() {
return Ok(());
}
let client_opt = self.ws_client.read().await;
let client = client_opt
.as_ref()
.ok_or_else(|| Error::network("WebSocket not connected"))?;
let id = self
.request_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
#[allow(clippy::disallowed_methods)]
let request = serde_json::json!({
"method": "UNSUBSCRIBE",
"params": streams,
"id": id
});
client
.send(tokio_tungstenite::tungstenite::protocol::Message::Text(
request.to_string().into(),
))
.await?;
Ok(())
}
async fn message_loop(ctx: MessageLoopContext) {
let mut reconnect_attempt = 0;
Self::resubscribe_all(&ctx.ws_client, &ctx.subscription_manager, &ctx.request_id).await;
loop {
if !ctx.is_connected.load(std::sync::atomic::Ordering::SeqCst) {
break;
}
let has_client = ctx.ws_client.read().await.is_some();
if !has_client {
let config = ctx.reconnect_config.read().await;
if config.should_retry(reconnect_attempt) {
let delay = config.calculate_delay(reconnect_attempt);
drop(config);
tokio::time::sleep(Duration::from_millis(delay)).await;
if let Ok(()) = Self::reconnect(
&ctx.base_url,
&ctx.current_url,
ctx.ws_client.clone(),
ctx.listen_key_manager.clone(),
)
.await
{
Self::resubscribe_all(
&ctx.ws_client,
&ctx.subscription_manager,
&ctx.request_id,
)
.await;
reconnect_attempt = 0;
continue;
}
reconnect_attempt += 1;
continue;
}
ctx.is_connected
.store(false, std::sync::atomic::Ordering::SeqCst);
break;
}
let message_opt = {
let guard = ctx.ws_client.read().await;
if let Some(client) = guard.as_ref() {
client.receive().await
} else {
None
}
};
if let Some(value) = message_opt {
if let Err(_e) = Self::handle_message(
value,
ctx.subscription_manager.clone(),
ctx.listen_key_manager.clone(),
)
.await
{
continue;
}
reconnect_attempt = 0;
} else {
let config = ctx.reconnect_config.read().await;
if config.should_retry(reconnect_attempt) {
let delay = config.calculate_delay(reconnect_attempt);
drop(config);
tokio::time::sleep(Duration::from_millis(delay)).await;
if let Ok(()) = Self::reconnect(
&ctx.base_url,
&ctx.current_url,
ctx.ws_client.clone(),
ctx.listen_key_manager.clone(),
)
.await
{
Self::resubscribe_all(
&ctx.ws_client,
&ctx.subscription_manager,
&ctx.request_id,
)
.await;
reconnect_attempt = 0;
continue;
}
reconnect_attempt += 1;
continue;
}
ctx.is_connected
.store(false, std::sync::atomic::Ordering::SeqCst);
break;
}
}
}
async fn resubscribe_all(
ws_client: &Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
subscription_manager: &Arc<super::subscriptions::SubscriptionManager>,
request_id: &Arc<std::sync::atomic::AtomicU64>,
) {
let streams = subscription_manager.get_active_streams().await;
if streams.is_empty() {
return;
}
let client_opt = ws_client.read().await;
if let Some(client) = client_opt.as_ref() {
for chunk in streams.chunks(10) {
let id = request_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
#[allow(clippy::disallowed_methods)]
let request = serde_json::json!({
"method": "SUBSCRIBE",
"params": chunk,
"id": id
});
if let Err(e) = client
.send(tokio_tungstenite::tungstenite::protocol::Message::Text(
request.to_string().into(),
))
.await
{
tracing::error!("Failed to resubscribe: {}", e);
}
}
}
}
async fn handle_message(
message: Value,
subscription_manager: Arc<super::subscriptions::SubscriptionManager>,
listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
) -> Result<()> {
let stream_name = Self::extract_stream_name(&message)?;
let payload = if message.get("stream").is_some() && message.get("data").is_some() {
message.get("data").cloned().unwrap_or(message.clone())
} else {
message.clone()
};
if stream_name == "!userData" {
if let Some(event) = payload.get("e").and_then(|e| e.as_str()) {
if event == "listenKeyExpired" {
if let Some(manager) = listen_key_manager {
tracing::warn!(
"Listen key expired, regenerating and triggering reconnect..."
);
let _ = manager.regenerate().await;
return Err(Error::network("Listen key expired, reconnecting"));
}
}
}
}
let sent = subscription_manager
.send_to_stream(&stream_name, payload.clone())
.await;
if sent {
return Ok(());
}
let symbol_opt = payload.get("s").and_then(|s| s.as_str());
if let Some(symbol) = symbol_opt {
let normalized_symbol = symbol.to_lowercase();
let active_streams = subscription_manager
.get_subscriptions_by_symbol(&normalized_symbol)
.await;
tracing::debug!(
"Routing message for symbol {} (normalized: {}): stream_name={}, active_subscriptions={}",
symbol,
normalized_symbol,
stream_name,
active_streams.len()
);
let mut fallback_sent = false;
for sub in active_streams {
tracing::debug!(
"Checking subscription: stream={}, expected_starts_with={}",
sub.stream,
stream_name
);
if sub.stream.starts_with(&stream_name) {
if subscription_manager
.send_to_stream(&sub.stream, payload.clone())
.await
{
fallback_sent = true;
tracing::debug!("Successfully routed to fallback stream: {}", sub.stream);
}
}
}
if fallback_sent {
return Ok(());
}
}
Err(Error::generic("No subscribers for stream"))
}
pub fn extract_stream_name(message: &Value) -> Result<String> {
if let Some(stream) = message.get("stream").and_then(|s| s.as_str()) {
return Ok(stream.to_string());
}
if let Some(arr) = message.as_array() {
if let Some(first) = arr.first() {
if let Some(event_type) = first.get("e").and_then(|e| e.as_str()) {
match event_type {
"24hrTicker" => return Ok("!ticker@arr".to_string()),
"24hrMiniTicker" => return Ok("!miniTicker@arr".to_string()),
_ => {}
}
}
}
}
if let Some(event_type) = message.get("e").and_then(|e| e.as_str()) {
match event_type {
"outboundAccountPosition"
| "balanceUpdate"
| "executionReport"
| "listStatus"
| "ACCOUNT_UPDATE"
| "ORDER_TRADE_UPDATE"
| "listenKeyExpired" => {
return Ok("!userData".to_string());
}
_ => {}
}
if let Some(symbol) = message.get("s").and_then(|s| s.as_str()) {
let stream = match event_type {
"24hrTicker" => format!("{}@ticker", symbol.to_lowercase()),
"24hrMiniTicker" => format!("{}@miniTicker", symbol.to_lowercase()),
"depthUpdate" => format!("{}@depth", symbol.to_lowercase()),
"aggTrade" => format!("{}@aggTrade", symbol.to_lowercase()),
"trade" => format!("{}@trade", symbol.to_lowercase()),
"kline" => {
if let Some(kline) = message.get("k") {
if let Some(interval) = kline.get("i").and_then(|i| i.as_str()) {
format!("{}@kline_{}", symbol.to_lowercase(), interval)
} else {
return Err(Error::generic("Missing kline interval"));
}
} else {
return Err(Error::generic("Missing kline data"));
}
}
"markPriceUpdate" => format!("{}@markPrice", symbol.to_lowercase()),
"bookTicker" => format!("{}@bookTicker", symbol.to_lowercase()),
_ => {
return Err(Error::generic(format!(
"Unknown event type: {}",
event_type
)));
}
};
return Ok(stream);
}
}
if message.get("result").is_some() || message.get("error").is_some() {
return Err(Error::generic("Subscription response, skip routing"));
}
Err(Error::generic("Cannot extract stream name from message"))
}
async fn reconnect(
base_url: &str,
current_url: &str,
ws_client: Arc<tokio::sync::RwLock<Option<ccxt_core::ws_client::WsClient>>>,
listen_key_manager: Option<Arc<super::listen_key::ListenKeyManager>>,
) -> Result<()> {
{
let mut client_opt = ws_client.write().await;
if let Some(client) = client_opt.take() {
let _ = client.disconnect().await;
}
}
let mut final_url = current_url.to_string();
if let Some(manager) = listen_key_manager {
if current_url != base_url {
if let Ok(key) = manager.get_or_create().await {
let base = if let Some(stripped) = base_url.strip_suffix('/') {
stripped
} else {
base_url
};
final_url = format!("{}/{}", base, key);
}
}
}
let config = ccxt_core::ws_client::WsConfig {
url: final_url,
..Default::default()
};
let new_client = ccxt_core::ws_client::WsClient::new(config);
new_client.connect().await?;
*ws_client.write().await = Some(new_client);
Ok(())
}
}
impl Drop for MessageRouter {
    /// Stops the background message loop when the router is dropped.
    ///
    /// The spawned loop owns its own `Arc` clones of the shared state. Without this, it
    /// would keep running (and reconnecting per the reconnect policy) with no remaining
    /// way to stop it. `Drop` cannot `.await`, so the socket is not gracefully
    /// disconnected here; call [`MessageRouter::stop`] for that.
    fn drop(&mut self) {
        // Makes the loop exit at its next iteration even if the abort below is skipped.
        self.is_connected
            .store(false, std::sync::atomic::Ordering::SeqCst);
        // `&mut self` means no other borrow of the router exists, so this lock is
        // normally uncontended.
        if let Ok(mut task) = self.router_task.try_lock() {
            if let Some(handle) = task.take() {
                handle.abort();
            }
        }
    }
}
pub async fn handle_orderbook_delta(
symbol: &str,
delta_message: &Value,
is_futures: bool,
orderbooks: &Mutex<HashMap<String, OrderBook>>,
) -> Result<()> {
let first_update_id = delta_message["U"]
.as_i64()
.ok_or_else(|| Error::invalid_request("Missing first update ID in delta message"))?;
let final_update_id = delta_message["u"]
.as_i64()
.ok_or_else(|| Error::invalid_request("Missing final update ID in delta message"))?;
let prev_final_update_id = if is_futures {
delta_message["pu"].as_i64()
} else {
None
};
let timestamp = delta_message["E"]
.as_i64()
.unwrap_or_else(|| chrono::Utc::now().timestamp_millis());
let mut bids = Vec::new();
if let Some(bids_arr) = delta_message["b"].as_array() {
for bid in bids_arr {
if let (Some(price_str), Some(amount_str)) = (bid[0].as_str(), bid[1].as_str()) {
if let (Ok(price), Ok(amount)) =
(price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
{
bids.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
}
}
}
}
let mut asks = Vec::new();
if let Some(asks_arr) = delta_message["a"].as_array() {
for ask in asks_arr {
if let (Some(price_str), Some(amount_str)) = (ask[0].as_str(), ask[1].as_str()) {
if let (Ok(price), Ok(amount)) =
(price_str.parse::<Decimal>(), amount_str.parse::<Decimal>())
{
asks.push(OrderBookEntry::new(Price::new(price), Amount::new(amount)));
}
}
}
}
let delta = OrderBookDelta {
symbol: symbol.to_string(),
first_update_id,
final_update_id,
prev_final_update_id,
timestamp,
bids,
asks,
};
let mut orderbooks_map = orderbooks.lock().await;
let orderbook = orderbooks_map
.entry(symbol.to_string())
.or_insert_with(|| OrderBook::new(symbol.to_string(), timestamp));
if !orderbook.is_synced {
orderbook.buffer_delta(delta);
return Ok(());
}
if let Err(e) = orderbook.apply_delta(&delta, is_futures) {
if orderbook.needs_resync {
tracing::warn!("Orderbook {} needs resync due to: {}", symbol, e);
orderbook.buffer_delta(delta);
return Err(Error::invalid_request(format!("RESYNC_NEEDED: {}", e)));
}
return Err(Error::invalid_request(e));
}
Ok(())
}
pub async fn fetch_orderbook_snapshot(
exchange: &Binance,
symbol: &str,
limit: Option<i64>,
is_futures: bool,
orderbooks: &Mutex<HashMap<String, OrderBook>>,
) -> Result<OrderBook> {
let snapshot = exchange
.fetch_order_book(symbol, limit.map(|l| l as u32))
.await?;
let mut orderbooks_map = orderbooks.lock().await;
let cached_ob = orderbooks_map
.entry(symbol.to_string())
.or_insert_with(|| OrderBook::new(symbol.to_string(), snapshot.timestamp));
cached_ob.reset_from_snapshot(
snapshot.bids,
snapshot.asks,
snapshot.timestamp,
snapshot.nonce,
);
if let Ok(processed) = cached_ob.process_buffered_deltas(is_futures) {
tracing::debug!("Processed {} buffered deltas for {}", processed, symbol);
}
Ok(cached_ob.clone())
}
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use super::*;
    use crate::binance::ws::subscriptions::{SubscriptionManager, SubscriptionType};
    use serde_json::json;
    use std::sync::Arc;

    /// A combined-stream frame is named by its explicit `stream` field.
    #[test]
    fn test_extract_stream_name_combined() {
        let combined = json!({
            "stream": "btcusdt@ticker",
            "data": { "e": "24hrTicker", "s": "BTCUSDT" }
        });
        let name = MessageRouter::extract_stream_name(&combined)
            .expect("combined frame should carry a stream name");
        assert_eq!(name, "btcusdt@ticker");
    }

    /// A raw frame is named from its event type and lower-cased symbol.
    #[test]
    fn test_extract_stream_name_raw() {
        let raw = json!({ "e": "24hrTicker", "s": "BTCUSDT" });
        let name = MessageRouter::extract_stream_name(&raw)
            .expect("raw ticker frame should map to a stream name");
        assert_eq!(name, "btcusdt@ticker");
    }

    /// Subscribers of a combined stream receive the unwrapped `data` payload.
    #[tokio::test]
    async fn test_handle_message_unwrapping() {
        let subscriptions = Arc::new(SubscriptionManager::new());
        let (sender, mut receiver) = tokio::sync::mpsc::channel(100);
        subscriptions
            .add_subscription(
                "btcusdt@ticker".to_string(),
                "BTCUSDT".to_string(),
                SubscriptionType::Ticker,
                sender,
            )
            .await
            .expect("subscription should register");

        let frame = json!({
            "stream": "btcusdt@ticker",
            "data": { "e": "24hrTicker", "s": "BTCUSDT", "c": "50000.00" }
        });
        MessageRouter::handle_message(frame, subscriptions, None)
            .await
            .expect("frame should be routed");

        let delivered = receiver.recv().await.expect("payload should be delivered");
        assert!(delivered.get("stream").is_none());
        assert_eq!(delivered["e"], "24hrTicker");
        assert_eq!(delivered["c"], "50000.00");
    }

    /// `btcusdt@markPrice` falls back to the `btcusdt@markPrice@1s` subscription.
    #[tokio::test]
    async fn test_handle_message_mark_price_fallback() {
        let subscriptions = Arc::new(SubscriptionManager::new());
        let (sender, mut receiver) = tokio::sync::mpsc::channel(100);
        subscriptions
            .add_subscription(
                "btcusdt@markPrice@1s".to_string(),
                "btcusdt".to_string(),
                SubscriptionType::MarkPrice,
                sender,
            )
            .await
            .expect("subscription should register");

        let frame = json!({
            "e": "markPriceUpdate",
            "s": "BTCUSDT",
            "p": "50000.00",
            "E": 123456789
        });
        MessageRouter::handle_message(frame, subscriptions, None)
            .await
            .expect("frame should be routed via the prefix fallback");

        let delivered = receiver.recv().await.expect("payload should be delivered");
        assert_eq!(delivered["e"], "markPriceUpdate");
        assert_eq!(delivered["p"], "50000.00");
    }
}