use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, SinkExt, stream::SplitSink, stream::SplitStream};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{mpsc, broadcast, Mutex};
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
AccountType, Symbol,
ExchangeError, ExchangeResult,
ConnectionStatus, StreamEvent, SubscriptionRequest,
};
use crate::core::types::{WebSocketResult, WebSocketError, OrderbookCapabilities, WsBookChannel};
use crate::core::traits::WebSocketConnector;
use super::endpoints::{BitstampUrls, format_symbol};
use super::parser::BitstampParser;
/// Outgoing control frame of the form `{"event": ..., "data": {"channel": ...}}`,
/// e.g. a `bts:subscribe` request.
#[derive(Clone, Debug, Serialize)]
struct SubscribeMessage {
    /// Control event name such as `"bts:subscribe"`.
    event: String,
    /// Payload naming the channel the event applies to.
    data: ChannelData,
}
/// `data` payload of a [`SubscribeMessage`]: the Bitstamp channel to act on.
#[derive(Clone, Debug, Serialize)]
struct ChannelData {
    /// Channel name, e.g. `live_trades_btcusd` or `diff_order_book_btcusd`.
    channel: String,
}
/// Envelope shared by every frame the server pushes over the socket.
///
/// Only `event` is required; `channel` and `data` may be missing on control
/// events, in which case serde leaves the `Option` fields as `None`.
#[derive(Clone, Debug, Deserialize)]
struct IncomingMessage {
    /// Event name, e.g. `"trade"`, `"data"` or `"bts:subscription_succeeded"`.
    event: String,
    /// Channel the event belongs to, when present.
    channel: Option<String>,
    /// Raw payload, left undecoded for the parser.
    data: Option<Value>,
}
/// Full-duplex socket as produced by `connect_async` (plain TCP or TLS).
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Read half of a split [`WsStream`].
type WsReader = SplitStream<WsStream>;
/// Write half of a split [`WsStream`]; accepts outgoing [`Message`] frames.
type WsWriter = SplitSink<WsStream, Message>;
/// Client for Bitstamp's public WebSocket feed.
///
/// Every piece of state sits behind an `Arc` so the reader, ping and forwarding
/// tasks spawned on connect can share it with this handle.
pub struct BitstampWebSocket {
    // Connection state.
    /// Lifecycle state; written by connect/disconnect and by the reader task.
    status: Arc<Mutex<ConnectionStatus>>,
    /// Write half of the socket; `None` before the first connect and after disconnect.
    ws_writer: Arc<Mutex<Option<WsWriter>>>,

    // Event fan-out.
    /// Sender side of the internal event channel, held while connected.
    event_tx: Arc<Mutex<Option<mpsc::UnboundedSender<WebSocketResult<StreamEvent>>>>>,
    /// Broadcast sender that `event_stream` subscribes to. Guarded by a `std`
    /// mutex because it is locked from the synchronous `event_stream` method.
    broadcast_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,

    // Subscription bookkeeping.
    /// Requests that were successfully subscribed.
    subscriptions: Arc<Mutex<HashSet<SubscriptionRequest>>>,
    /// `live_trades_*` channel names whose messages are re-emitted as tickers.
    ticker_channels: Arc<Mutex<HashSet<String>>>,

    // Latency metrics.
    /// When the most recent client ping was sent.
    last_ping: Arc<Mutex<Instant>>,
    /// Latest ping round-trip time in milliseconds (0 until the first pong).
    ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
impl BitstampWebSocket {
pub async fn new() -> ExchangeResult<Self> {
Ok(Self {
status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
subscriptions: Arc::new(Mutex::new(HashSet::new())),
ticker_channels: Arc::new(Mutex::new(HashSet::new())),
event_tx: Arc::new(Mutex::new(None)),
broadcast_tx: Arc::new(StdMutex::new(None)),
ws_writer: Arc::new(Mutex::new(None)),
last_ping: Arc::new(Mutex::new(Instant::now())),
ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
})
}
async fn subscribe_channel(&self, channel: &str) -> ExchangeResult<()> {
let msg = SubscribeMessage {
event: "bts:subscribe".to_string(),
data: ChannelData {
channel: channel.to_string(),
},
};
let json = serde_json::to_string(&msg)
.map_err(|e| ExchangeError::Parse(format!("Failed to serialize: {}", e)))?;
let mut writer_guard = self.ws_writer.lock().await;
if let Some(writer) = writer_guard.as_mut() {
writer.send(Message::Text(json))
.await
.map_err(|e| ExchangeError::Network(format!("Failed to send message: {}", e)))?;
} else {
return Err(ExchangeError::Network("Not connected".to_string()));
}
Ok(())
}
pub async fn subscribe_ticker(&self, symbol: Symbol) -> ExchangeResult<()> {
let pair = format_symbol(&symbol, AccountType::Spot);
let channel = format!("live_trades_{}", pair);
self.ticker_channels.lock().await.insert(channel.clone());
self.subscribe_channel(&channel).await
}
pub async fn subscribe_trades(&self, symbol: Symbol) -> ExchangeResult<()> {
let pair = format_symbol(&symbol, AccountType::Spot);
let channel = format!("live_trades_{}", pair);
self.subscribe_channel(&channel).await
}
pub async fn subscribe_orderbook(&self, symbol: Symbol) -> ExchangeResult<()> {
let pair = format_symbol(&symbol, AccountType::Spot);
let channel = format!("order_book_{}", pair);
self.subscribe_channel(&channel).await
}
fn start_message_handler(
reader: WsReader,
ws_writer: Arc<Mutex<Option<WsWriter>>>,
event_tx: mpsc::UnboundedSender<WebSocketResult<StreamEvent>>,
status: Arc<Mutex<ConnectionStatus>>,
ticker_channels: Arc<Mutex<HashSet<String>>>,
last_ping: Arc<Mutex<Instant>>,
ws_ping_rtt_ms: Arc<Mutex<u64>>,
) {
tokio::spawn(async move {
let mut reader = reader;
loop {
match reader.next().await {
Some(Ok(Message::Text(text))) => {
if let Err(e) = Self::handle_message(&text, &event_tx, &ticker_channels).await {
let _ = event_tx.send(Err(e));
}
}
Some(Ok(Message::Ping(data))) => {
let mut writer_guard = ws_writer.lock().await;
if let Some(writer) = writer_guard.as_mut() {
let _ = writer.send(Message::Pong(data)).await;
}
}
Some(Ok(Message::Pong(_))) => {
let rtt = last_ping.lock().await.elapsed().as_millis() as u64;
*ws_ping_rtt_ms.lock().await = rtt;
}
Some(Ok(Message::Binary(_))) => {
}
Some(Ok(Message::Close(_))) => {
*status.lock().await = ConnectionStatus::Disconnected;
break;
}
Some(Ok(Message::Frame(_))) => {
}
Some(Err(_e)) => {
let _ = event_tx.send(Err(WebSocketError::ConnectionError(
"WebSocket read error".to_string()
)));
*status.lock().await = ConnectionStatus::Disconnected;
break;
}
None => {
*status.lock().await = ConnectionStatus::Disconnected;
break;
}
}
}
});
}
fn start_ping_task(
ws_writer: Arc<Mutex<Option<WsWriter>>>,
status: Arc<Mutex<ConnectionStatus>>,
last_ping: Arc<Mutex<Instant>>,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
interval.tick().await;
loop {
interval.tick().await;
let current_status = *status.lock().await;
if current_status != ConnectionStatus::Connected {
break;
}
let mut writer_guard = ws_writer.lock().await;
if let Some(writer) = writer_guard.as_mut() {
*last_ping.lock().await = Instant::now();
if writer.send(Message::Ping(vec![])).await.is_err() {
break;
}
} else {
break;
}
}
});
}
async fn handle_message(
text: &str,
event_tx: &mpsc::UnboundedSender<WebSocketResult<StreamEvent>>,
ticker_channels: &Arc<Mutex<HashSet<String>>>,
) -> WebSocketResult<()> {
let msg: IncomingMessage = serde_json::from_str(text)
.map_err(|e| WebSocketError::Parse(format!("Failed to parse message: {}", e)))?;
match msg.event.as_str() {
"pusher:connection_established" => {
return Ok(());
}
"pusher:pong" => {
return Ok(());
}
"pusher:error" => {
return Err(WebSocketError::ProtocolError(
format!("Pusher error: {:?}", msg.data)
));
}
"bts:subscription_succeeded" => {
return Ok(());
}
"bts:error" => {
return Err(WebSocketError::ProtocolError(
format!("Bitstamp error: {:?}", msg.data)
));
}
"bts:request_reconnect" => {
return Err(WebSocketError::ConnectionError(
"Server requested reconnection (bts:request_reconnect)".to_string()
));
}
"trade" | "data" => {
let is_ticker_channel = if let Some(ch) = msg.channel.as_ref() {
ticker_channels.lock().await.contains(ch)
} else {
false
};
if let Some(event) = Self::parse_data_message(&msg, is_ticker_channel)? {
let _ = event_tx.send(Ok(event));
}
}
_ => {
}
}
Ok(())
}
fn parse_data_message(msg: &IncomingMessage, as_ticker: bool) -> WebSocketResult<Option<StreamEvent>> {
let channel = msg.channel.as_ref()
.ok_or_else(|| WebSocketError::Parse("Missing channel".to_string()))?;
let json = serde_json::json!({
"channel": channel,
"event": &msg.event,
"data": msg.data
});
if channel.starts_with("live_trades_") {
if as_ticker {
let trade = BitstampParser::parse_ws_trade(&json)
.map_err(|e| WebSocketError::Parse(e.to_string()))?;
let ticker = crate::core::types::Ticker {
symbol: trade.symbol,
last_price: trade.price,
bid_price: None,
ask_price: None,
high_24h: None,
low_24h: None,
volume_24h: None,
quote_volume_24h: None,
price_change_24h: None,
price_change_percent_24h: None,
timestamp: trade.timestamp,
};
Ok(Some(StreamEvent::Ticker(ticker)))
} else {
let trade = BitstampParser::parse_ws_trade(&json)
.map_err(|e| WebSocketError::Parse(e.to_string()))?;
Ok(Some(StreamEvent::Trade(trade)))
}
} else if channel.starts_with("diff_order_book_") {
let orderbook = BitstampParser::parse_ws_orderbook(&json)
.map_err(|e| WebSocketError::Parse(e.to_string()))?;
Ok(Some(StreamEvent::OrderbookSnapshot(orderbook)))
} else if channel.starts_with("order_book_") {
let orderbook = BitstampParser::parse_ws_orderbook(&json)
.map_err(|e| WebSocketError::Parse(e.to_string()))?;
Ok(Some(StreamEvent::OrderbookSnapshot(orderbook)))
} else {
Ok(None)
}
}
}
#[async_trait]
impl WebSocketConnector for BitstampWebSocket {
async fn connect(&mut self, _account_type: AccountType) -> WebSocketResult<()> {
*self.status.lock().await = ConnectionStatus::Connecting;
let url = BitstampUrls::ws_url();
let (ws_stream, _) = connect_async(url)
.await
.map_err(|e| WebSocketError::ConnectionError(format!("Failed to connect: {}", e)))?;
let (writer, reader) = ws_stream.split();
*self.ws_writer.lock().await = Some(writer);
*self.status.lock().await = ConnectionStatus::Connected;
let (tx, mut rx) = mpsc::unbounded_channel();
*self.event_tx.lock().await = Some(tx.clone());
Self::start_message_handler(
reader,
self.ws_writer.clone(),
tx,
self.status.clone(),
self.ticker_channels.clone(),
self.last_ping.clone(),
self.ws_ping_rtt_ms.clone(),
);
Self::start_ping_task(
self.ws_writer.clone(),
self.status.clone(),
self.last_ping.clone(),
);
let (broadcast_sender, _) = broadcast::channel(1000);
*self.broadcast_tx.lock().unwrap() = Some(broadcast_sender);
let broadcast_tx = self.broadcast_tx.clone();
tokio::spawn(async move {
while let Some(event) = rx.recv().await {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(event);
}
}
let _ = broadcast_tx.lock().unwrap().take();
});
Ok(())
}
async fn disconnect(&mut self) -> WebSocketResult<()> {
*self.status.lock().await = ConnectionStatus::Disconnected;
*self.ws_writer.lock().await = None;
*self.event_tx.lock().await = None;
let _ = self.broadcast_tx.lock().unwrap().take();
Ok(())
}
fn connection_status(&self) -> ConnectionStatus {
match self.status.try_lock() {
Ok(guard) => *guard,
Err(_) => ConnectionStatus::Disconnected,
}
}
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
let result = match request.stream_type {
crate::core::types::StreamType::Ticker => {
self.subscribe_ticker(request.symbol.clone()).await
.map_err(|e| WebSocketError::Subscription(format!("{:?}", e)))
}
crate::core::types::StreamType::Trade => {
self.subscribe_trades(request.symbol.clone()).await
.map_err(|e| WebSocketError::Subscription(format!("{:?}", e)))
}
crate::core::types::StreamType::Orderbook => {
self.subscribe_orderbook(request.symbol.clone()).await
.map_err(|e| WebSocketError::Subscription(format!("{:?}", e)))
}
crate::core::types::StreamType::OrderbookDelta => {
let pair = format_symbol(&request.symbol, AccountType::Spot);
let channel = format!("diff_order_book_{}", pair);
self.subscribe_channel(&channel).await
.map_err(|e| WebSocketError::Subscription(format!("{:?}", e)))
}
_ => Err(WebSocketError::Subscription("Unsupported subscription type".to_string())),
};
if result.is_ok() {
self.subscriptions.lock().await.insert(request);
}
result
}
async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
self.subscriptions.lock().await.remove(&request);
Ok(())
}
fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
let rx = self.broadcast_tx.lock().unwrap().as_ref()
.map(|tx| tx.subscribe())
.unwrap_or_else(|| broadcast::channel(1).1);
Box::pin(tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(|res| async move {
res.ok()
}))
}
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
match self.subscriptions.try_lock() {
Ok(guard) => guard.iter().cloned().collect(),
Err(_) => Vec::new(),
}
}
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
Some(self.ws_ping_rtt_ms.clone())
}
fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
static BITSTAMP_CHANNELS: &[WsBookChannel] = &[
WsBookChannel::snapshot("order_book", 100, 1000),
WsBookChannel::delta("diff_order_book", None, None),
];
OrderbookCapabilities {
ws_depths: &[],
ws_default_depth: None,
rest_max_depth: None,
rest_depth_values: &[],
supports_snapshot: true,
supports_delta: true,
update_speeds_ms: &[],
default_speed_ms: None,
ws_channels: BITSTAMP_CHANNELS,
checksum: None,
has_sequence: false,
has_prev_sequence: false,
supports_aggregation: true,
aggregation_levels: &["0", "1", "2"],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_websocket_creation() {
        let ws = BitstampWebSocket::new().await;
        assert!(ws.is_ok());
    }

    // Pure serialization: no async runtime needed.
    #[test]
    fn test_subscribe_message() {
        let msg = SubscribeMessage {
            event: "bts:subscribe".to_string(),
            data: ChannelData {
                channel: "diff_order_book_btcusd".to_string(),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("bts:subscribe"));
        assert!(json.contains("diff_order_book_btcusd"));
    }

    #[test]
    fn test_pusher_message_parsing() {
        let established = r#"{"event":"pusher:connection_established","data":"{\"socket_id\":\"123\",\"activity_timeout\":120}"}"#;
        let parsed: IncomingMessage = serde_json::from_str(established).unwrap();
        assert_eq!(parsed.event, "pusher:connection_established");
        assert!(parsed.channel.is_none());
    }

    /// Acknowledgements and unknown events succeed without emitting anything.
    #[tokio::test]
    async fn test_control_events_emit_nothing() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let ticker_channels = Arc::new(Mutex::new(HashSet::new()));
        for text in [
            r#"{"event":"bts:subscription_succeeded","channel":"live_trades_btcusd","data":{}}"#,
            r#"{"event":"pusher:pong","data":{}}"#,
            r#"{"event":"some_future_event","data":{}}"#,
        ] {
            let result = BitstampWebSocket::handle_message(text, &tx, &ticker_channels).await;
            assert!(result.is_ok(), "unexpected error for {}", text);
        }
        assert!(rx.try_recv().is_err());
    }

    /// Server-side errors and reconnect requests are surfaced to the caller.
    #[tokio::test]
    async fn test_server_errors_are_reported() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let ticker_channels = Arc::new(Mutex::new(HashSet::new()));
        for text in [
            r#"{"event":"bts:error","data":{"message":"bad channel"}}"#,
            r#"{"event":"bts:request_reconnect","channel":"","data":""}"#,
        ] {
            let result = BitstampWebSocket::handle_message(text, &tx, &ticker_channels).await;
            assert!(result.is_err(), "expected error for {}", text);
        }
    }

    #[tokio::test]
    async fn test_invalid_json_is_parse_error() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let ticker_channels = Arc::new(Mutex::new(HashSet::new()));
        let result = BitstampWebSocket::handle_message("not json", &tx, &ticker_channels).await;
        assert!(matches!(result, Err(WebSocketError::Parse(_))));
    }
}