use crate::errors::{self, PolyError, Result};
use crate::types::*;
use chrono::Utc;
use futures::{SinkExt, Stream, StreamExt};
use serde_json::Value;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::sync::mpsc;
use tracing::{error, info, warn};
pub trait MarketStream: Stream<Item = Result<StreamMessage>> + Send + Sync {
fn subscribe(&mut self, subscription: Subscription) -> Result<()>;
fn unsubscribe(&mut self, token_ids: &[String]) -> Result<()>;
fn is_connected(&self) -> bool;
fn get_stats(&self) -> StreamStats;
}
#[derive(Debug)]
pub struct WebSocketStream {
connection: Option<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
>,
url: String,
auth: Option<WssAuth>,
subscriptions: Vec<WssSubscription>,
stats: StreamStats,
connected_since: Option<Instant>,
}
#[derive(Debug, Clone)]
pub struct StreamStats {
pub messages_received: u64,
pub messages_sent: u64,
pub errors: u64,
pub last_message_time: Option<chrono::DateTime<Utc>>,
pub connection_uptime: std::time::Duration,
pub reconnect_count: u32,
}
impl WebSocketStream {
pub fn new(url: &str) -> Self {
Self {
connection: None,
url: url.to_string(),
auth: None,
subscriptions: Vec::new(),
stats: StreamStats {
messages_received: 0,
messages_sent: 0,
errors: 0,
last_message_time: None,
connection_uptime: std::time::Duration::ZERO,
reconnect_count: 0,
},
connected_since: None,
}
}
pub fn with_auth(mut self, auth: WssAuth) -> Self {
self.auth = Some(auth);
self
}
async fn connect(&mut self) -> Result<()> {
let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url)
.await
.map_err(|e| {
PolyError::stream(
format!("WebSocket connection failed: {}", e),
crate::errors::StreamErrorKind::ConnectionFailed,
)
})?;
self.connection = Some(ws_stream);
self.connected_since = Some(Instant::now());
self.stats.connection_uptime = std::time::Duration::ZERO;
info!("Connected to WebSocket stream at {}", self.url);
Ok(())
}
async fn send_message(&mut self, message: Value) -> Result<()> {
if let Some(connection) = &mut self.connection {
let text = serde_json::to_string(&message).map_err(|e| {
PolyError::parse(format!("Failed to serialize message: {}", e), None)
})?;
let ws_message = tokio_tungstenite::tungstenite::Message::Text(text.into());
connection.send(ws_message).await.map_err(|e| {
PolyError::stream(
format!("Failed to send message: {}", e),
crate::errors::StreamErrorKind::MessageCorrupted,
)
})?;
self.stats.messages_sent += 1;
}
Ok(())
}
pub async fn subscribe_async(&mut self, subscription: WssSubscription) -> Result<()> {
if self.connection.is_none() {
self.connect().await?;
}
let message = serde_json::json!({
"auth": subscription.auth,
"markets": subscription.markets,
"asset_ids": subscription.asset_ids,
"type": subscription.channel_type,
});
self.send_message(message).await?;
self.subscriptions.push(subscription.clone());
info!("Subscribed to {} channel", subscription.channel_type);
Ok(())
}
pub async fn subscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
let auth = self
.auth
.as_ref()
.ok_or_else(|| PolyError::auth("No authentication provided for WebSocket"))?
.clone();
let subscription = WssSubscription {
auth,
markets: Some(markets),
asset_ids: None,
channel_type: "USER".to_string(),
};
self.subscribe_async(subscription).await
}
pub async fn subscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
let auth = self
.auth
.as_ref()
.ok_or_else(|| PolyError::auth("No authentication provided for WebSocket"))?
.clone();
let subscription = WssSubscription {
auth,
markets: None,
asset_ids: Some(asset_ids),
channel_type: "MARKET".to_string(),
};
self.subscribe_async(subscription).await
}
pub async fn unsubscribe_async(&mut self, token_ids: &[String]) -> Result<()> {
self.subscriptions
.retain(|sub| match sub.channel_type.as_str() {
"USER" => {
if let Some(markets) = &sub.markets {
!token_ids.iter().any(|id| markets.contains(id))
} else {
true
}
}
"MARKET" => {
if let Some(asset_ids) = &sub.asset_ids {
!token_ids.iter().any(|id| asset_ids.contains(id))
} else {
true
}
}
_ => true,
});
info!("Unsubscribed from {} tokens", token_ids.len());
Ok(())
}
fn update_connection_uptime(&mut self) {
if let Some(started) = self.connected_since {
self.stats.connection_uptime = started.elapsed();
}
}
fn parse_polymarket_message(&self, text: &str) -> Result<StreamMessage> {
let value: Value = serde_json::from_str(text).map_err(|e| {
PolyError::parse(
format!("Failed to parse WebSocket message: {}", e),
Some(Box::new(e)),
)
})?;
let message_type = value
.get("type")
.and_then(|v| v.as_str())
.ok_or_else(|| PolyError::parse("Missing 'type' field in WebSocket message", None))?;
match message_type {
"book_update" => {
let data =
serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
.map_err(|e| {
PolyError::parse(
format!("Failed to parse book update: {}", e),
Some(Box::new(e)),
)
})?;
Ok(StreamMessage::BookUpdate { data })
}
"trade" => {
let data =
serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
.map_err(|e| {
PolyError::parse(
format!("Failed to parse trade: {}", e),
Some(Box::new(e)),
)
})?;
Ok(StreamMessage::Trade { data })
}
"order_update" => {
let data =
serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
.map_err(|e| {
PolyError::parse(
format!("Failed to parse order update: {}", e),
Some(Box::new(e)),
)
})?;
Ok(StreamMessage::OrderUpdate { data })
}
"user_order_update" => {
let data =
serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
.map_err(|e| {
PolyError::parse(
format!("Failed to parse user order update: {}", e),
Some(Box::new(e)),
)
})?;
Ok(StreamMessage::UserOrderUpdate { data })
}
"user_trade" => {
let data =
serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
.map_err(|e| {
PolyError::parse(
format!("Failed to parse user trade: {}", e),
Some(Box::new(e)),
)
})?;
Ok(StreamMessage::UserTrade { data })
}
"market_book_update" => {
let data =
serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
.map_err(|e| {
PolyError::parse(
format!("Failed to parse market book update: {}", e),
Some(Box::new(e)),
)
})?;
Ok(StreamMessage::MarketBookUpdate { data })
}
"market_trade" => {
let data =
serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
.map_err(|e| {
PolyError::parse(
format!("Failed to parse market trade: {}", e),
Some(Box::new(e)),
)
})?;
Ok(StreamMessage::MarketTrade { data })
}
"heartbeat" => {
let timestamp = value
.get("timestamp")
.and_then(|v| v.as_u64())
.map(|ts| chrono::DateTime::from_timestamp(ts as i64, 0).unwrap_or_default())
.unwrap_or_else(Utc::now);
Ok(StreamMessage::Heartbeat { timestamp })
}
_ => {
warn!("Unknown message type: {}", message_type);
Ok(StreamMessage::Heartbeat {
timestamp: Utc::now(),
})
}
}
}
}
impl Stream for WebSocketStream {
type Item = Result<StreamMessage>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let connection = match self.connection.as_mut() {
Some(connection) => connection,
None => return Poll::Ready(None),
};
match connection.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(message))) => match message {
tokio_tungstenite::tungstenite::Message::Text(text) => {
match self.parse_polymarket_message(&text) {
Ok(stream_msg) => {
self.stats.messages_received += 1;
self.stats.last_message_time = Some(Utc::now());
self.update_connection_uptime();
Poll::Ready(Some(Ok(stream_msg)))
}
Err(err) => {
self.stats.errors += 1;
Poll::Ready(Some(Err(err)))
}
}
}
tokio_tungstenite::tungstenite::Message::Close(_) => {
info!("WebSocket connection closed by server");
self.connection = None;
self.connected_since = None;
Poll::Ready(None)
}
tokio_tungstenite::tungstenite::Message::Ping(_) => Poll::Pending,
tokio_tungstenite::tungstenite::Message::Pong(_) => Poll::Pending,
tokio_tungstenite::tungstenite::Message::Binary(_) => {
warn!("Received binary message (not supported)");
Poll::Pending
}
tokio_tungstenite::tungstenite::Message::Frame(_) => {
warn!("Received raw frame (not supported)");
Poll::Pending
}
},
Poll::Ready(Some(Err(e))) => {
error!("WebSocket error: {}", e);
self.stats.errors += 1;
self.connected_since = None;
Poll::Ready(Some(Err(errors::PolyError::Stream {
message: e.to_string(),
kind: errors::StreamErrorKind::Unknown,
})))
}
Poll::Ready(None) => {
self.connected_since = None;
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
impl MarketStream for WebSocketStream {
fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
Ok(())
}
fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
Ok(())
}
fn is_connected(&self) -> bool {
self.connection.is_some()
}
fn get_stats(&self) -> StreamStats {
let mut stats = self.stats.clone();
if let Some(started) = self.connected_since {
stats.connection_uptime = started.elapsed();
}
stats
}
}
#[derive(Debug)]
pub struct MockStream {
messages: Vec<Result<StreamMessage>>,
index: usize,
connected: bool,
}
impl MockStream {
pub fn new() -> Self {
Self {
messages: Vec::new(),
index: 0,
connected: true,
}
}
pub fn add_message(&mut self, message: StreamMessage) {
self.messages.push(Ok(message));
}
pub fn add_error(&mut self, error: PolyError) {
self.messages.push(Err(error));
}
pub fn set_connected(&mut self, connected: bool) {
self.connected = connected;
}
}
impl Default for MockStream {
fn default() -> Self {
Self::new()
}
}
impl Stream for MockStream {
type Item = Result<StreamMessage>;
fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.index >= self.messages.len() {
Poll::Ready(None)
} else {
let message = self.messages[self.index].clone();
self.index += 1;
Poll::Ready(Some(message))
}
}
}
impl MarketStream for MockStream {
fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
Ok(())
}
fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
Ok(())
}
fn is_connected(&self) -> bool {
self.connected
}
fn get_stats(&self) -> StreamStats {
StreamStats {
messages_received: self.messages.len() as u64,
messages_sent: 0,
errors: self.messages.iter().filter(|m| m.is_err()).count() as u64,
last_message_time: None,
connection_uptime: std::time::Duration::ZERO,
reconnect_count: 0,
}
}
}
pub struct StreamManager {
streams: Vec<Box<dyn MarketStream>>,
message_tx: mpsc::UnboundedSender<StreamMessage>,
}
impl StreamManager {
pub fn new() -> Self {
let (message_tx, _message_rx) = mpsc::unbounded_channel();
Self {
streams: Vec::new(),
message_tx,
}
}
pub fn add_stream(&mut self, stream: Box<dyn MarketStream>) {
self.streams.push(stream);
}
pub fn get_message_receiver(&mut self) -> mpsc::UnboundedReceiver<StreamMessage> {
let (_, rx) = mpsc::unbounded_channel();
rx
}
pub fn broadcast_message(&self, message: StreamMessage) -> Result<()> {
self.message_tx
.send(message)
.map_err(|e| PolyError::internal("Failed to broadcast message", e))
}
}
impl Default for StreamManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mock_stream() {
let mut stream = MockStream::new();
stream.add_message(StreamMessage::Heartbeat {
timestamp: Utc::now(),
});
stream.add_message(StreamMessage::BookUpdate {
data: OrderDelta {
token_id: "test".to_string(),
timestamp: Utc::now(),
side: Side::BUY,
price: rust_decimal_macros::dec!(0.5),
size: rust_decimal_macros::dec!(100),
sequence: 1,
},
});
assert!(stream.is_connected());
assert_eq!(stream.get_stats().messages_received, 2);
}
}