use crate::errors::{PolyfillError, Result};
use crate::types::*;
use crate::ws_hot_path::{WsBookApplyStats, WsBookUpdateProcessor};
use chrono::Utc;
use futures::{SinkExt, Stream, StreamExt};
use serde_json::Value;
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
pub trait MarketStream: Stream<Item = Result<StreamMessage>> + Send + Sync {
fn subscribe(&mut self, subscription: Subscription) -> Result<()>;
fn unsubscribe(&mut self, token_ids: &[String]) -> Result<()>;
fn is_connected(&self) -> bool;
fn get_stats(&self) -> StreamStats;
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct WebSocketStream {
connection: Option<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
>,
url: String,
auth: Option<WssAuth>,
subscriptions: Vec<WssSubscription>,
pending: VecDeque<StreamMessage>,
pending_capacity: usize,
stats: StreamStats,
reconnect_config: ReconnectConfig,
}
#[derive(Debug, Clone)]
pub struct StreamStats {
pub messages_received: u64,
pub messages_sent: u64,
pub errors: u64,
pub dropped_messages: u64,
pub last_message_time: Option<chrono::DateTime<Utc>>,
pub connection_uptime: std::time::Duration,
pub reconnect_count: u32,
}
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
pub max_retries: u32,
pub base_delay: std::time::Duration,
pub max_delay: std::time::Duration,
pub backoff_multiplier: f64,
}
impl Default for ReconnectConfig {
fn default() -> Self {
Self {
max_retries: 5,
base_delay: std::time::Duration::from_secs(1),
max_delay: std::time::Duration::from_secs(60),
backoff_multiplier: 2.0,
}
}
}
impl WebSocketStream {
pub fn new(url: &str) -> Self {
let pending_capacity = 1024;
Self {
connection: None,
url: url.to_string(),
auth: None,
subscriptions: Vec::new(),
pending: VecDeque::with_capacity(pending_capacity),
pending_capacity,
stats: StreamStats {
messages_received: 0,
messages_sent: 0,
errors: 0,
dropped_messages: 0,
last_message_time: None,
connection_uptime: std::time::Duration::ZERO,
reconnect_count: 0,
},
reconnect_config: ReconnectConfig::default(),
}
}
fn enqueue(&mut self, message: StreamMessage) {
if self.pending.len() >= self.pending_capacity {
let _ = self.pending.pop_front();
self.stats.dropped_messages += 1;
}
self.pending.push_back(message);
}
pub fn with_auth(mut self, auth: WssAuth) -> Self {
self.auth = Some(auth);
self
}
async fn connect(&mut self) -> Result<()> {
let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url)
.await
.map_err(|e| {
PolyfillError::stream(
format!("WebSocket connection failed: {}", e),
crate::errors::StreamErrorKind::ConnectionFailed,
)
})?;
self.connection = Some(ws_stream);
info!("Connected to WebSocket stream at {}", self.url);
Ok(())
}
async fn send_message(&mut self, message: Value) -> Result<()> {
if let Some(connection) = &mut self.connection {
let text = serde_json::to_string(&message).map_err(|e| {
PolyfillError::parse(format!("Failed to serialize message: {}", e), None)
})?;
let ws_message = tokio_tungstenite::tungstenite::Message::Text(text.into());
connection.send(ws_message).await.map_err(|e| {
PolyfillError::stream(
format!("Failed to send message: {}", e),
crate::errors::StreamErrorKind::MessageCorrupted,
)
})?;
self.stats.messages_sent += 1;
}
Ok(())
}
pub async fn subscribe_async(&mut self, subscription: WssSubscription) -> Result<()> {
if self.connection.is_none() {
self.connect().await?;
}
let message = serde_json::to_value(&subscription).map_err(|e| {
PolyfillError::parse(format!("Failed to serialize subscription: {}", e), None)
})?;
self.send_message(message).await?;
self.subscriptions.push(subscription.clone());
info!("Subscribed to {} channel", subscription.channel_type);
Ok(())
}
pub async fn subscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
let auth = self
.auth
.as_ref()
.ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
.clone();
let subscription = WssSubscription {
channel_type: "user".to_string(),
operation: Some("subscribe".to_string()),
markets,
asset_ids: Vec::new(),
initial_dump: Some(true),
custom_feature_enabled: None,
auth: Some(auth),
};
self.subscribe_async(subscription).await
}
pub async fn subscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
let subscription = WssSubscription {
channel_type: "market".to_string(),
operation: Some("subscribe".to_string()),
markets: Vec::new(),
asset_ids,
initial_dump: Some(true),
custom_feature_enabled: None,
auth: None,
};
self.subscribe_async(subscription).await
}
pub async fn subscribe_market_channel_with_features(
&mut self,
asset_ids: Vec<String>,
) -> Result<()> {
let subscription = WssSubscription {
channel_type: "market".to_string(),
operation: Some("subscribe".to_string()),
markets: Vec::new(),
asset_ids,
initial_dump: Some(true),
custom_feature_enabled: Some(true),
auth: None,
};
self.subscribe_async(subscription).await
}
pub async fn unsubscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
let subscription = WssSubscription {
channel_type: "market".to_string(),
operation: Some("unsubscribe".to_string()),
markets: Vec::new(),
asset_ids,
initial_dump: None,
custom_feature_enabled: None,
auth: None,
};
self.subscribe_async(subscription).await
}
pub async fn unsubscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
let auth = self
.auth
.as_ref()
.ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
.clone();
let subscription = WssSubscription {
channel_type: "user".to_string(),
operation: Some("unsubscribe".to_string()),
markets,
asset_ids: Vec::new(),
initial_dump: None,
custom_feature_enabled: None,
auth: Some(auth),
};
self.subscribe_async(subscription).await
}
#[allow(dead_code)]
async fn handle_message(
&mut self,
message: tokio_tungstenite::tungstenite::Message,
) -> Result<()> {
match message {
tokio_tungstenite::tungstenite::Message::Text(text) => {
debug!("Received WebSocket message: {}", text);
let stream_messages = crate::decode::parse_stream_messages(&text)?;
for stream_message in stream_messages {
self.enqueue(stream_message);
}
self.stats.messages_received += 1;
self.stats.last_message_time = Some(Utc::now());
},
tokio_tungstenite::tungstenite::Message::Close(_) => {
info!("WebSocket connection closed by server");
self.connection = None;
},
tokio_tungstenite::tungstenite::Message::Ping(data) => {
if let Some(connection) = &mut self.connection {
let pong = tokio_tungstenite::tungstenite::Message::Pong(data);
if let Err(e) = connection.send(pong).await {
error!("Failed to send pong: {}", e);
}
}
},
tokio_tungstenite::tungstenite::Message::Pong(_) => {
debug!("Received pong");
},
tokio_tungstenite::tungstenite::Message::Binary(_) => {
warn!("Received binary message (not supported)");
},
tokio_tungstenite::tungstenite::Message::Frame(_) => {
warn!("Received raw frame (not supported)");
},
}
Ok(())
}
#[allow(dead_code)]
fn parse_polymarket_messages(&self, text: &str) -> Result<Vec<StreamMessage>> {
crate::decode::parse_stream_messages(text)
}
#[allow(dead_code)]
async fn reconnect(&mut self) -> Result<()> {
let mut delay = self.reconnect_config.base_delay;
let mut retries = 0;
while retries < self.reconnect_config.max_retries {
warn!("Attempting to reconnect (attempt {})", retries + 1);
match self.connect().await {
Ok(()) => {
info!("Successfully reconnected");
self.stats.reconnect_count += 1;
let subscriptions = self.subscriptions.clone();
for subscription in subscriptions {
self.send_message(serde_json::to_value(subscription)?)
.await?;
}
return Ok(());
},
Err(e) => {
error!("Reconnection attempt {} failed: {}", retries + 1, e);
retries += 1;
if retries < self.reconnect_config.max_retries {
tokio::time::sleep(delay).await;
delay = std::cmp::min(
delay.mul_f64(self.reconnect_config.backoff_multiplier),
self.reconnect_config.max_delay,
);
}
},
}
}
Err(PolyfillError::stream(
format!(
"Failed to reconnect after {} attempts",
self.reconnect_config.max_retries
),
crate::errors::StreamErrorKind::ConnectionFailed,
))
}
}
pub struct WebSocketBookApplier<'a> {
stream: WebSocketStream,
books: &'a crate::book::OrderBookManager,
processor: WsBookUpdateProcessor,
}
impl WebSocketStream {
pub fn into_book_applier<'a>(
mut self,
books: &'a crate::book::OrderBookManager,
processor: WsBookUpdateProcessor,
) -> WebSocketBookApplier<'a> {
self.pending.clear();
WebSocketBookApplier {
stream: self,
books,
processor,
}
}
}
impl<'a> WebSocketBookApplier<'a> {
pub fn stream_mut(&mut self) -> &mut WebSocketStream {
&mut self.stream
}
pub fn stream_stats(&self) -> StreamStats {
self.stream.stats.clone()
}
pub fn processor_mut(&mut self) -> &mut WsBookUpdateProcessor {
&mut self.processor
}
pub fn apply_text_message(&mut self, text: String) -> Result<WsBookApplyStats> {
let stats = self.processor.process_text(text, self.books)?;
self.stream.stats.messages_received += 1;
self.stream.stats.last_message_time = Some(Utc::now());
Ok(stats)
}
fn apply_text_utf8(
&mut self,
text: tokio_tungstenite::tungstenite::Utf8Bytes,
) -> Result<WsBookApplyStats> {
let bytes: bytes::Bytes = text.into();
let vec: Vec<u8> = bytes.into();
let owned = unsafe { String::from_utf8_unchecked(vec) };
self.apply_text_message(owned)
}
}
impl<'a> Stream for WebSocketBookApplier<'a> {
type Item = Result<WsBookApplyStats>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
let Some(connection) = &mut self.stream.connection else {
return Poll::Ready(None);
};
match connection.poll_next_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(msg))) => match msg {
tokio_tungstenite::tungstenite::Message::Text(text) => {
match self.apply_text_utf8(text) {
Ok(stats) => {
if stats.book_messages == 0 {
continue;
}
return Poll::Ready(Some(Ok(stats)));
},
Err(e) => {
self.stream.stats.errors += 1;
return Poll::Ready(Some(Err(e)));
},
}
},
tokio_tungstenite::tungstenite::Message::Close(_) => {
info!("WebSocket connection closed by server");
self.stream.connection = None;
return Poll::Ready(None);
},
tokio_tungstenite::tungstenite::Message::Ping(_) => {
continue;
},
tokio_tungstenite::tungstenite::Message::Pong(_) => continue,
tokio_tungstenite::tungstenite::Message::Binary(_) => continue,
tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
},
Poll::Ready(Some(Err(e))) => {
error!("WebSocket error: {}", e);
self.stream.stats.errors += 1;
return Poll::Ready(Some(Err(e.into())));
},
Poll::Ready(None) => {
info!("WebSocket stream ended");
return Poll::Ready(None);
},
}
}
}
}
impl Stream for WebSocketStream {
type Item = Result<StreamMessage>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if let Some(message) = self.pending.pop_front() {
return Poll::Ready(Some(Ok(message)));
}
let Some(connection) = &mut self.connection else {
return Poll::Ready(None);
};
match connection.poll_next_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(ws_message))) => match ws_message {
tokio_tungstenite::tungstenite::Message::Text(text) => {
match crate::decode::parse_stream_messages(&text) {
Ok(messages) => {
let mut iter = messages.into_iter();
let Some(first) = iter.next() else {
continue;
};
for msg in iter {
self.enqueue(msg);
}
self.stats.messages_received += 1;
self.stats.last_message_time = Some(Utc::now());
return Poll::Ready(Some(Ok(first)));
},
Err(e) => {
self.stats.errors += 1;
return Poll::Ready(Some(Err(e)));
},
}
},
tokio_tungstenite::tungstenite::Message::Close(_) => {
info!("WebSocket connection closed by server");
self.connection = None;
return Poll::Ready(None);
},
tokio_tungstenite::tungstenite::Message::Ping(_) => {
continue;
},
tokio_tungstenite::tungstenite::Message::Pong(_) => continue,
tokio_tungstenite::tungstenite::Message::Binary(_) => continue,
tokio_tungstenite::tungstenite::Message::Frame(_) => continue,
},
Poll::Ready(Some(Err(e))) => {
error!("WebSocket error: {}", e);
self.stats.errors += 1;
return Poll::Ready(Some(Err(e.into())));
},
Poll::Ready(None) => {
info!("WebSocket stream ended");
return Poll::Ready(None);
},
}
}
}
}
impl MarketStream for WebSocketStream {
fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
Ok(())
}
fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
Ok(())
}
fn is_connected(&self) -> bool {
self.connection.is_some()
}
fn get_stats(&self) -> StreamStats {
self.stats.clone()
}
}
#[derive(Debug)]
pub struct MockStream {
messages: Vec<Result<StreamMessage>>,
index: usize,
connected: bool,
}
impl Default for MockStream {
fn default() -> Self {
Self::new()
}
}
impl MockStream {
pub fn new() -> Self {
Self {
messages: Vec::new(),
index: 0,
connected: true,
}
}
pub fn add_message(&mut self, message: StreamMessage) {
self.messages.push(Ok(message));
}
pub fn add_error(&mut self, error: PolyfillError) {
self.messages.push(Err(error));
}
pub fn set_connected(&mut self, connected: bool) {
self.connected = connected;
}
}
impl Stream for MockStream {
type Item = Result<StreamMessage>;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.index >= self.messages.len() {
Poll::Ready(None)
} else {
let message = self.messages[self.index].clone();
self.index += 1;
Poll::Ready(Some(message))
}
}
}
impl MarketStream for MockStream {
fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
Ok(())
}
fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
Ok(())
}
fn is_connected(&self) -> bool {
self.connected
}
fn get_stats(&self) -> StreamStats {
StreamStats {
messages_received: self.messages.len() as u64,
messages_sent: 0,
errors: self.messages.iter().filter(|m| m.is_err()).count() as u64,
dropped_messages: 0,
last_message_time: None,
connection_uptime: std::time::Duration::ZERO,
reconnect_count: 0,
}
}
}
#[allow(dead_code)]
pub struct StreamManager {
streams: Vec<Box<dyn MarketStream>>,
message_tx: mpsc::UnboundedSender<StreamMessage>,
message_rx: mpsc::UnboundedReceiver<StreamMessage>,
}
impl Default for StreamManager {
fn default() -> Self {
Self::new()
}
}
impl StreamManager {
pub fn new() -> Self {
let (message_tx, message_rx) = mpsc::unbounded_channel();
Self {
streams: Vec::new(),
message_tx,
message_rx,
}
}
pub fn add_stream(&mut self, stream: Box<dyn MarketStream>) {
self.streams.push(stream);
}
pub fn get_message_receiver(&mut self) -> mpsc::UnboundedReceiver<StreamMessage> {
let (_, rx) = mpsc::unbounded_channel();
rx
}
pub fn broadcast_message(&self, message: StreamMessage) -> Result<()> {
self.message_tx
.send(message)
.map_err(|e| PolyfillError::internal("Failed to broadcast message", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use rust_decimal::Decimal;
use std::str::FromStr;
#[test]
fn test_mock_stream() {
let mut stream = MockStream::new();
stream.add_message(StreamMessage::Book(BookUpdate {
asset_id: "1".to_string(),
market: "0xabc".to_string(),
timestamp: 1_234_567_890,
bids: vec![],
asks: vec![],
hash: None,
}));
stream.add_message(StreamMessage::PriceChange(PriceChange {
market: "0xabc".to_string(),
timestamp: 1_234_567_891,
price_changes: vec![],
}));
assert!(stream.is_connected());
assert_eq!(stream.get_stats().messages_received, 2);
}
#[test]
fn test_stream_manager() {
let mut manager = StreamManager::new();
let mock_stream = Box::new(MockStream::new());
manager.add_stream(mock_stream);
let message = StreamMessage::Book(BookUpdate {
asset_id: "1".to_string(),
market: "0xabc".to_string(),
timestamp: 1_234_567_890,
bids: vec![],
asks: vec![],
hash: None,
});
assert!(manager.broadcast_message(message).is_ok());
}
#[test]
fn test_websocket_book_applier_apply_text_message_updates_book() {
let books = crate::book::OrderBookManager::new(64);
let _ = books.get_or_create_book("12345").unwrap();
let processor = WsBookUpdateProcessor::new(1024);
let stream = WebSocketStream::new("wss://example.com/ws");
let mut applier = stream.into_book_applier(&books, processor);
let msg = r#"{"event_type":"book","asset_id":"12345","timestamp":1,"bids":[{"price":"0.75","size":"10"}],"asks":[{"price":"0.76","size":"5"}]}"#.to_string();
let stats = applier.apply_text_message(msg).unwrap();
assert_eq!(stats.book_messages, 1);
assert_eq!(stats.book_levels_applied, 2);
let snapshot = books.get_book("12345").unwrap();
assert_eq!(snapshot.bids.len(), 1);
assert_eq!(snapshot.asks.len(), 1);
assert_eq!(snapshot.bids[0].price, Decimal::from_str("0.75").unwrap());
assert_eq!(snapshot.bids[0].size, Decimal::from_str("10").unwrap());
assert_eq!(snapshot.asks[0].price, Decimal::from_str("0.76").unwrap());
assert_eq!(snapshot.asks[0].size, Decimal::from_str("5").unwrap());
}
}