use std::time::{Duration, Instant};
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::time::sleep;
use tokio_tungstenite::{
connect_async,
tungstenite::Message,
MaybeTlsStream, WebSocketStream,
};
use crate::core::{AccountType, OrderbookCapabilities};
use super::parser::{
WsBookSnapshot, WsLastTradePrice, WsPriceChange, WsTickSizeChange,
WsBestBidAsk,
};
/// Market-data channel endpoint opened by `ClobWebSocket::connect`.
const WS_MARKET_URL: &str = "wss://ws-subscriptions-clob.polymarket.com/ws/market";
/// User-channel endpoint; not referenced by this client (hence the leading underscore).
const _WS_USER_URL: &str = "wss://ws-subscriptions-clob.polymarket.com/ws/user";
/// Seconds between the text `"PING"` keep-alives sent from `ClobWebSocket::recv`.
const PING_INTERVAL_SECS: u64 = 10;
/// Starting reconnect delay; restored after every successful connect.
const INITIAL_BACKOFF_SECS: u64 = 1;
/// Upper bound for the reconnect delay, which doubles after each failed attempt.
const MAX_BACKOFF_SECS: u64 = 60;
/// Failures reported by the Polymarket WebSocket client and its message parser.
#[derive(Debug)]
pub enum WsError {
    /// The WebSocket handshake/connection could not be established.
    Connection(String),
    /// An outgoing message could not be serialized or written to the socket.
    Send(String),
    /// Reading the next frame from the socket failed.
    Receive(String),
    /// An incoming message could not be decoded.
    Parse(String),
    /// An operation needed an open socket but none is held.
    Disconnected,
}
impl std::fmt::Display for WsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            WsError::Connection(ref detail) => write!(f, "Connection error: {}", detail),
            WsError::Send(ref detail) => write!(f, "Send error: {}", detail),
            WsError::Receive(ref detail) => write!(f, "Receive error: {}", detail),
            WsError::Parse(ref detail) => write!(f, "Parse error: {}", detail),
            WsError::Disconnected => f.write_str("WebSocket disconnected"),
        }
    }
}
impl std::error::Error for WsError {}
/// Summary returned by `ClobWebSocket::reconnect` after a successful reconnect.
#[derive(Debug, Clone)]
pub struct WsReconnectInfo {
    /// Total successful reconnects this client has performed, including this one.
    pub reconnection_count: u64,
    /// Number of token ids held by the client when the subscription was re-sent.
    pub markets_resubscribed: usize,
}
/// A message that `parse_event` could not map to a known event type.
#[derive(Debug, Clone)]
pub struct WsUnknownEvent {
    /// The JSON text that was passed to `parse_event`, unmodified.
    pub raw: String,
}
/// A decoded market-channel message or a client-side notification.
#[derive(Debug, Clone)]
pub enum WsEvent {
    /// `event_type == "book"`; bid/ask prices starting with `.` get a leading `0`.
    Book(WsBookSnapshot),
    /// `event_type == "price_change"`; a top-level `price`/`size` is folded into
    /// `changes` when `changes` is empty, and level prices are normalized as for `Book`.
    PriceChange(WsPriceChange),
    /// `event_type == "last_trade_price"`.
    LastTradePrice(WsLastTradePrice),
    /// `event_type == "tick_size_change"`.
    TickSizeChange(WsTickSizeChange),
    /// `event_type == "best_bid_ask"`.
    BestBidAsk(WsBestBidAsk),
    /// A text frame containing exactly `PONG`.
    Pong,
    /// Sent by `ClobWebSocket::start` after a successful reconnect.
    Reconnected(WsReconnectInfo),
    /// Missing or unrecognized `event_type`, or an empty JSON array.
    Unknown(WsUnknownEvent),
}
/// Stream type returned by `connect_async` and stored by `ClobWebSocket`.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
pub struct ClobWebSocket {
token_ids: Vec<String>,
enable_features: bool,
ws: Option<WsStream>,
last_ping: Instant,
backoff_delay: Duration,
reconnection_count: u64,
}
impl ClobWebSocket {
pub fn new(token_ids: Vec<String>, enable_features: bool) -> Self {
Self {
token_ids,
enable_features,
ws: None,
last_ping: Instant::now(),
backoff_delay: Duration::from_secs(INITIAL_BACKOFF_SECS),
reconnection_count: 0,
}
}
pub async fn connect(&mut self) -> Result<(), WsError> {
let (ws_stream, _) = connect_async(WS_MARKET_URL)
.await
.map_err(|e| WsError::Connection(e.to_string()))?;
self.ws = Some(ws_stream);
self.send_subscription().await?;
self.backoff_delay = Duration::from_secs(INITIAL_BACKOFF_SECS);
self.last_ping = Instant::now();
Ok(())
}
pub async fn reconnect(&mut self) -> Result<WsReconnectInfo, WsError> {
sleep(self.backoff_delay).await;
let result = self.connect().await;
if let Err(e) = result {
self.backoff_delay = std::cmp::min(
self.backoff_delay * 2,
Duration::from_secs(MAX_BACKOFF_SECS),
);
return Err(e);
}
self.reconnection_count += 1;
let markets_resubscribed = self.token_ids.len();
tracing::info!(
reconnection_count = self.reconnection_count,
markets_resubscribed,
"Polymarket WebSocket reconnected"
);
Ok(WsReconnectInfo {
reconnection_count: self.reconnection_count,
markets_resubscribed,
})
}
async fn send_subscription(&mut self) -> Result<(), WsError> {
let mut sub_value = serde_json::json!({
"type": "market",
"assets_ids": self.token_ids
});
if self.enable_features {
if let Some(obj) = sub_value.as_object_mut() {
obj.insert(
"custom_feature_enabled".to_string(),
Value::Bool(true),
);
}
}
let msg = serde_json::to_string(&sub_value)
.map_err(|e| WsError::Send(e.to_string()))?;
self.send_text(msg).await
}
pub async fn subscribe(&mut self, token_ids: Vec<String>) -> Result<(), WsError> {
let sub_value = serde_json::json!({
"assets_ids": token_ids,
"operation": "subscribe"
});
let msg = serde_json::to_string(&sub_value)
.map_err(|e| WsError::Send(e.to_string()))?;
self.send_text(msg).await?;
self.token_ids.extend(token_ids);
Ok(())
}
pub async fn unsubscribe(&mut self, token_ids: &[String]) -> Result<(), WsError> {
let sub_value = serde_json::json!({
"assets_ids": token_ids,
"operation": "unsubscribe"
});
let msg = serde_json::to_string(&sub_value)
.map_err(|e| WsError::Send(e.to_string()))?;
self.send_text(msg).await?;
self.token_ids.retain(|id| !token_ids.contains(id));
Ok(())
}
async fn send_text(&mut self, text: String) -> Result<(), WsError> {
let ws = self.ws.as_mut().ok_or(WsError::Disconnected)?;
ws.send(Message::Text(text))
.await
.map_err(|e| WsError::Send(e.to_string()))
}
async fn send_ping(&mut self) -> Result<(), WsError> {
self.send_text("PING".to_string()).await?;
self.last_ping = Instant::now();
Ok(())
}
pub async fn recv(&mut self) -> Result<Option<WsEvent>, WsError> {
loop {
if self.last_ping.elapsed() >= Duration::from_secs(PING_INTERVAL_SECS) {
self.send_ping().await?;
}
let ws = self.ws.as_mut().ok_or(WsError::Disconnected)?;
match ws.next().await {
Some(Ok(Message::Text(text))) => {
if text == "PONG" {
return Ok(Some(WsEvent::Pong));
}
return parse_event(&text).map(Some);
}
Some(Ok(Message::Close(_))) => return Ok(None),
Some(Err(e)) => return Err(WsError::Receive(e.to_string())),
None => return Ok(None),
_ => continue, }
}
}
pub async fn start(
&mut self,
tx: tokio::sync::mpsc::Sender<WsEvent>,
) -> Result<(), WsError> {
loop {
if self.ws.is_none() {
self.connect().await?;
}
match self.recv().await {
Ok(Some(event)) => {
let _ = tx.send(event).await;
}
Ok(None) => {
self.ws = None;
match self.reconnect().await {
Ok(info) => {
let _ = tx.send(WsEvent::Reconnected(info)).await;
}
Err(e) => {
tracing::warn!("Polymarket WS reconnect failed: {}", e);
}
}
}
Err(WsError::Disconnected) => {
self.ws = None;
match self.reconnect().await {
Ok(info) => {
let _ = tx.send(WsEvent::Reconnected(info)).await;
}
Err(e) => {
tracing::warn!("Polymarket WS reconnect failed: {}", e);
}
}
}
Err(e) => return Err(e),
}
}
}
pub async fn close(&mut self) {
if let Some(mut ws) = self.ws.take() {
let _ = ws.close(None).await;
}
}
pub fn is_connected(&self) -> bool {
self.ws.is_some()
}
pub fn reconnection_count(&self) -> u64 {
self.reconnection_count
}
pub fn subscribed_tokens(&self) -> &[String] {
&self.token_ids
}
pub fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
OrderbookCapabilities {
ws_depths: &[],
ws_default_depth: None,
rest_max_depth: None,
rest_depth_values: &[],
supports_snapshot: true,
supports_delta: true,
update_speeds_ms: &[],
default_speed_ms: None,
ws_channels: &[],
checksum: None,
has_sequence: false,
has_prev_sequence: false,
supports_aggregation: false,
aggregation_levels: &[],
}
}
}
pub fn parse_event(json: &str) -> Result<WsEvent, WsError> {
let raw: Value =
serde_json::from_str(json).map_err(|e| WsError::Parse(e.to_string()))?;
let value = if let Some(arr) = raw.as_array() {
match arr.first() {
Some(v) => v.clone(),
None => {
return Ok(WsEvent::Unknown(WsUnknownEvent { raw: json.to_string() }));
}
}
} else {
raw
};
let event_type = match value.get("event_type").and_then(|v| v.as_str()) {
Some(et) => et,
None => {
tracing::debug!("Polymarket WS: message without event_type: {}", json);
return Ok(WsEvent::Unknown(WsUnknownEvent { raw: json.to_string() }));
}
};
match event_type {
"book" => {
let mut snapshot: WsBookSnapshot = serde_json::from_value(value.clone())
.map_err(|e| WsError::Parse(format!("book parse: {}", e)))?;
for level in &mut snapshot.bids {
normalize_price_in_place(&mut level.price);
}
for level in &mut snapshot.asks {
normalize_price_in_place(&mut level.price);
}
Ok(WsEvent::Book(snapshot))
}
"price_change" => {
let mut change: WsPriceChange = serde_json::from_value(value)
.map_err(|e| WsError::Parse(format!("price_change parse: {}", e)))?;
if change.changes.is_empty() {
if let Some(price) = change.price.take() {
change.changes.push(super::parser::PolyPriceLevel {
price,
size: change.size.take().unwrap_or_default(),
});
}
}
for level in &mut change.changes {
normalize_price_in_place(&mut level.price);
}
Ok(WsEvent::PriceChange(change))
}
"last_trade_price" => {
let trade: WsLastTradePrice = serde_json::from_value(value)
.map_err(|e| WsError::Parse(format!("last_trade_price parse: {}", e)))?;
Ok(WsEvent::LastTradePrice(trade))
}
"tick_size_change" => {
let change: WsTickSizeChange = serde_json::from_value(value)
.map_err(|e| WsError::Parse(format!("tick_size_change parse: {}", e)))?;
Ok(WsEvent::TickSizeChange(change))
}
"best_bid_ask" => {
let bba: WsBestBidAsk = serde_json::from_value(value)
.map_err(|e| WsError::Parse(format!("best_bid_ask parse: {}", e)))?;
Ok(WsEvent::BestBidAsk(bba))
}
_ => {
tracing::debug!(
"Polymarket WS: unhandled event_type '{}': {}",
event_type,
json
);
Ok(WsEvent::Unknown(WsUnknownEvent { raw: json.to_string() }))
}
}
}
/// Prefix a bare-decimal price such as `".45"` with `0` (giving `"0.45"`), in place.
///
/// Any other string, including an empty one, is left untouched. The existing buffer
/// is reused rather than building a new `String` with `format!`.
fn normalize_price_in_place(price: &mut String) {
    if price.starts_with('.') {
        price.insert(0, '0');
    }
}
/// Return `price` with a leading `0` added when it starts with `.` (e.g. `".5"` → `"0.5"`).
///
/// Any other input is returned as an owned copy, unchanged.
pub fn normalize_price(price: &str) -> String {
    let is_bare_decimal = price.starts_with('.');
    if !is_bare_decimal {
        return price.to_owned();
    }
    format!("0{}", price)
}