polysqueeze/
wss.rs

//! Lightweight WSS clients for Polymarket's CLOB WebSocket channels.
//!
//! Two clients are provided. Each maintains a single reconnecting connection
//! and replays its most recent subscription after every (re)connect:
//!
//! * [`WssMarketClient`] — the public market channel
//!   (`wss://ws-subscriptions-clob.polymarket.com/ws/market`), exposing typed
//!   events for books, price changes, tick size changes, and last trade prices.
//! * [`WssUserClient`] — the authenticated user channel (`.../ws/user`),
//!   exposing the caller's own trade and order notifications.
8
9use crate::errors::{PolyError, Result};
10use crate::types::{ApiCredentials, OrderSummary, Side};
11use chrono::{DateTime, Utc};
12use futures::{SinkExt, StreamExt};
13use serde::Deserialize;
14use serde_json::{Value, json};
15use std::collections::VecDeque;
16use std::time::Duration;
17use tokio::net::TcpStream;
18use tokio::time::{sleep, timeout};
19use tokio_tungstenite::{
20    MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message,
21};
22use tracing::warn;
23
/// Base URL of the Polymarket CLOB WebSocket service (no trailing slash).
const DEFAULT_WSS_BASE: &str = "wss://ws-subscriptions-clob.polymarket.com";
/// Path of the public market channel, appended to the base URL.
const MARKET_CHANNEL_PATH: &str = "/ws/market";
/// Path of the authenticated user channel, appended to the base URL.
const USER_CHANNEL_PATH: &str = "/ws/user";

/// Backoff unit: the retry delay grows linearly with the attempt number.
const BASE_RECONNECT_DELAY: Duration = Duration::from_millis(250);
/// Upper bound for a single retry delay.
const MAX_RECONNECT_DELAY: Duration = Duration::from_millis(10_000);
/// Consecutive failed connection attempts tolerated before giving up.
const MAX_RECONNECT_ATTEMPTS: u32 = 8;
/// Idle period after which a client sends a text `PING` keepalive.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(25_000);
31
32/// Represents a parsed market broadcast from the public market channel.
33#[derive(Debug, Clone)]
34pub enum WssMarketEvent {
35    Book(MarketBook),
36    PriceChange(PriceChangeMessage),
37    TickSizeChange(TickSizeChangeMessage),
38    LastTrade(LastTradeMessage),
39}
40
41/// Events emitted by the authenticated user channel.
42#[derive(Debug, Clone)]
43pub enum WssUserEvent {
44    Trade(WssUserTradeMessage),
45    Order(WssUserOrderMessage),
46}
47
48/// Trade notifications scoped to the authenticated user.
49#[derive(Debug, Clone, Deserialize)]
50pub struct WssUserTradeMessage {
51    #[serde(rename = "event_type")]
52    pub event_type: String,
53    pub asset_id: String,
54    pub id: String,
55    pub last_update: String,
56    #[serde(default)]
57    pub maker_orders: Vec<MakerOrder>,
58    pub market: String,
59    pub matchtime: String,
60    pub outcome: String,
61    pub owner: String,
62    #[serde(with = "rust_decimal::serde::str")]
63    pub price: rust_decimal::Decimal,
64    pub side: Side,
65    #[serde(with = "rust_decimal::serde::str")]
66    pub size: rust_decimal::Decimal,
67    pub status: String,
68    pub taker_order_id: String,
69    pub timestamp: String,
70    pub trade_owner: String,
71    #[serde(rename = "type")]
72    pub message_type: String,
73}
74
75/// Maker order details included in user trade events.
76#[derive(Debug, Clone, Deserialize)]
77pub struct MakerOrder {
78    pub asset_id: String,
79    #[serde(with = "rust_decimal::serde::str")]
80    pub matched_amount: rust_decimal::Decimal,
81    pub order_id: String,
82    pub outcome: String,
83    pub owner: String,
84    #[serde(with = "rust_decimal::serde::str")]
85    pub price: rust_decimal::Decimal,
86}
87
88/// Order notifications scoped to the authenticated user.
89#[derive(Debug, Clone, Deserialize)]
90pub struct WssUserOrderMessage {
91    #[serde(rename = "event_type")]
92    pub event_type: String,
93    #[serde(default)]
94    pub associate_trades: Option<Vec<String>>,
95    pub asset_id: String,
96    pub id: String,
97    pub market: String,
98    pub order_owner: String,
99    #[serde(with = "rust_decimal::serde::str")]
100    pub original_size: rust_decimal::Decimal,
101    pub outcome: String,
102    pub owner: String,
103    #[serde(with = "rust_decimal::serde::str")]
104    pub price: rust_decimal::Decimal,
105    pub side: Side,
106    #[serde(with = "rust_decimal::serde::str")]
107    pub size_matched: rust_decimal::Decimal,
108    pub timestamp: String,
109    #[serde(rename = "type")]
110    pub message_type: String,
111}
112
113/// Book summary message
114#[derive(Debug, Clone, Deserialize)]
115pub struct MarketBook {
116    #[serde(rename = "event_type")]
117    pub event_type: String,
118    pub asset_id: String,
119    pub market: String,
120    pub timestamp: String,
121    pub hash: String,
122    pub bids: Vec<OrderSummary>,
123    pub asks: Vec<OrderSummary>,
124}
125
126/// Payload for price change notifications.
127#[derive(Debug, Clone, Deserialize)]
128pub struct PriceChangeMessage {
129    #[serde(rename = "event_type")]
130    pub event_type: String,
131    pub market: String,
132    #[serde(rename = "price_changes")]
133    pub price_changes: Vec<PriceChangeEntry>,
134    pub timestamp: String,
135}
136
137/// Individual price change entry.
138#[derive(Debug, Clone, Deserialize)]
139pub struct PriceChangeEntry {
140    pub asset_id: String,
141    #[serde(with = "rust_decimal::serde::str")]
142    pub price: rust_decimal::Decimal,
143    #[serde(with = "rust_decimal::serde::str")]
144    pub size: rust_decimal::Decimal,
145    pub side: Side,
146    pub hash: String,
147    #[serde(with = "rust_decimal::serde::str")]
148    pub best_bid: rust_decimal::Decimal,
149    #[serde(with = "rust_decimal::serde::str")]
150    pub best_ask: rust_decimal::Decimal,
151}
152
153/// Tick size change events.
154#[derive(Debug, Clone, Deserialize)]
155pub struct TickSizeChangeMessage {
156    #[serde(rename = "event_type")]
157    pub event_type: String,
158    pub asset_id: String,
159    pub market: String,
160    #[serde(rename = "old_tick_size", with = "rust_decimal::serde::str")]
161    pub old_tick_size: rust_decimal::Decimal,
162    #[serde(rename = "new_tick_size", with = "rust_decimal::serde::str")]
163    pub new_tick_size: rust_decimal::Decimal,
164    pub side: String,
165    pub timestamp: String,
166}
167
168/// Trade events emitted when a trade settles.
169#[derive(Debug, Clone, Deserialize)]
170pub struct LastTradeMessage {
171    #[serde(rename = "event_type")]
172    pub event_type: String,
173    pub asset_id: String,
174    pub fee_rate_bps: String,
175    pub market: String,
176    #[serde(with = "rust_decimal::serde::str")]
177    pub price: rust_decimal::Decimal,
178    #[serde(with = "rust_decimal::serde::str")]
179    pub size: rust_decimal::Decimal,
180    pub side: Side,
181    pub timestamp: String,
182}
183
184/// Simple stats for monitoring connection health.
185#[derive(Debug, Clone)]
186pub struct WssStats {
187    pub messages_received: u64,
188    pub errors: u64,
189    pub reconnect_count: u32,
190    pub last_message_time: Option<DateTime<Utc>>,
191}
192
193impl Default for WssStats {
194    fn default() -> Self {
195        Self {
196            messages_received: 0,
197            errors: 0,
198            reconnect_count: 0,
199            last_message_time: None,
200        }
201    }
202}
203
204/// Reconnecting client for the market channel.
205pub struct WssMarketClient {
206    connect_url: String,
207    connection: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
208    subscribed_asset_ids: Vec<String>,
209    stats: WssStats,
210    disconnect_history: VecDeque<DateTime<Utc>>,
211    pending_events: VecDeque<WssMarketEvent>,
212}
213
214impl WssMarketClient {
215    /// Create a new instance using the default Polymarket WSS base.
216    pub fn new() -> Self {
217        Self::with_url(DEFAULT_WSS_BASE)
218    }
219
220    /// Create a new client against a custom endpoint (useful for tests).
221    pub fn with_url(url: &str) -> Self {
222        let trimmed = url.trim_end_matches('/');
223        let connect_url = format!("{}{}", trimmed, MARKET_CHANNEL_PATH);
224        Self {
225            connection: None,
226            subscribed_asset_ids: Vec::new(),
227            stats: WssStats::default(),
228            disconnect_history: VecDeque::with_capacity(5),
229            connect_url,
230            pending_events: VecDeque::new(),
231        }
232    }
233
234    /// Access connection stats for observability.
235    pub fn stats(&self) -> WssStats {
236        self.stats.clone()
237    }
238
239    fn format_subscription(&self) -> Value {
240        json!({
241            "type": "market",
242            "assets_ids": self.subscribed_asset_ids,
243        })
244    }
245
246    async fn send_subscription(&mut self) -> Result<()> {
247        if self.subscribed_asset_ids.is_empty() {
248            return Ok(());
249        }
250
251        let message = self.format_subscription();
252        self.send_raw_message(message).await
253    }
254
255    async fn send_raw_message(&mut self, message: Value) -> Result<()> {
256        if let Some(connection) = self.connection.as_mut() {
257            let text = serde_json::to_string(&message).map_err(|e| {
258                PolyError::parse(
259                    format!("Failed to serialize subscription message: {}", e),
260                    None,
261                )
262            })?;
263            connection
264                .send(Message::Text(text.into()))
265                .await
266                .map_err(|e| {
267                    PolyError::stream(
268                        format!("Failed to send message: {}", e),
269                        crate::errors::StreamErrorKind::MessageCorrupted,
270                    )
271                })?;
272            return Ok(());
273        }
274        Err(PolyError::stream(
275            "WebSocket connection not established",
276            crate::errors::StreamErrorKind::ConnectionFailed,
277        ))
278    }
279
280    async fn connect(&mut self) -> Result<()> {
281        let mut attempts = 0;
282        loop {
283            match connect_async(&self.connect_url).await {
284                Ok((socket, _)) => {
285                    self.connection = Some(socket);
286                    if attempts > 0 {
287                        self.stats.reconnect_count += 1;
288                    }
289                    return Ok(());
290                }
291                Err(err) => {
292                    attempts += 1;
293                    let delay = self.reconnect_delay(attempts);
294                    self.stats.errors += 1;
295                    if attempts >= MAX_RECONNECT_ATTEMPTS {
296                        return Err(PolyError::stream(
297                            format!("Failed to connect after {} attempts: {}", attempts, err),
298                            crate::errors::StreamErrorKind::ConnectionFailed,
299                        ));
300                    }
301                    sleep(delay).await;
302                }
303            }
304        }
305    }
306
307    fn reconnect_delay(&self, attempts: u32) -> Duration {
308        let millis = BASE_RECONNECT_DELAY.as_millis() as u128 * attempts as u128;
309        let desired =
310            Duration::from_millis(millis.min(MAX_RECONNECT_DELAY.as_millis() as u128) as u64);
311        desired
312    }
313
314    async fn ensure_connection(&mut self) -> Result<()> {
315        if self.connection.is_none() {
316            self.connect().await?;
317            self.send_subscription().await?;
318        }
319        Ok(())
320    }
321
322    /// Subscribe to the market channel for the provided token/market IDs.
323    pub async fn subscribe(&mut self, asset_ids: Vec<String>) -> Result<()> {
324        self.subscribed_asset_ids = asset_ids;
325        self.ensure_connection().await?;
326        self.send_subscription().await
327    }
328
329    /// Read the next market channel event, reconnecting transparently when
330    /// the socket drops.
331    pub async fn next_event(&mut self) -> Result<WssMarketEvent> {
332        loop {
333            if let Some(evt) = self.pending_events.pop_front() {
334                return Ok(evt);
335            }
336            self.ensure_connection().await?;
337
338            match self.connection.as_mut().unwrap().next().await {
339                Some(Ok(Message::Text(text))) => {
340                    let trimmed = text.trim();
341                    if trimmed.eq_ignore_ascii_case("ping") || trimmed.eq_ignore_ascii_case("pong")
342                    {
343                        continue;
344                    }
345                    let first_char = trimmed.chars().next();
346                    if first_char != Some('{') && first_char != Some('[') {
347                        warn!("ignoring unexpected text frame: {}", trimmed);
348                        continue;
349                    }
350                    let events = parse_market_events(&text)?;
351                    self.stats.messages_received += events.len() as u64;
352                    self.stats.last_message_time = Some(Utc::now());
353                    for evt in events {
354                        self.pending_events.push_back(evt);
355                    }
356                    if let Some(evt) = self.pending_events.pop_front() {
357                        return Ok(evt);
358                    }
359                    continue;
360                }
361                Some(Ok(Message::Ping(payload))) => {
362                    if let Some(connection) = self.connection.as_mut() {
363                        let _ = connection.send(Message::Pong(payload)).await;
364                    }
365                }
366                Some(Ok(Message::Pong(_))) => {}
367                Some(Ok(Message::Close(_))) => {
368                    self.disconnect_history.push_back(Utc::now());
369                    if self.disconnect_history.len() > 5 {
370                        self.disconnect_history.pop_front();
371                    }
372                    self.connection = None;
373                }
374                Some(Ok(_)) => {}
375                Some(Err(err)) => {
376                    warn!("WebSocket error: {}", err);
377                    self.connection = None;
378                    self.stats.errors += 1;
379                    continue;
380                }
381                None => {
382                    self.connection = None;
383                }
384            }
385        }
386    }
387}
388
389/// Reconnecting client for the authenticated user channel.
390pub struct WssUserClient {
391    connect_url: String,
392    connection: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
393    subscribed_markets: Vec<String>,
394    stats: WssStats,
395    disconnect_history: VecDeque<DateTime<Utc>>,
396    pending_events: VecDeque<WssUserEvent>,
397    auth: ApiCredentials,
398}
399
400impl WssUserClient {
401    /// Create a new instance using the default Polymarket WSS base.
402    pub fn new(auth: ApiCredentials) -> Self {
403        Self::with_url(DEFAULT_WSS_BASE, auth)
404    }
405
406    /// Create a new client against a custom endpoint (useful for tests).
407    pub fn with_url(url: &str, auth: ApiCredentials) -> Self {
408        let trimmed = url.trim_end_matches('/');
409        let connect_url = format!("{}{}", trimmed, USER_CHANNEL_PATH);
410        Self {
411            connection: None,
412            subscribed_markets: Vec::new(),
413            stats: WssStats::default(),
414            disconnect_history: VecDeque::with_capacity(5),
415            connect_url,
416            pending_events: VecDeque::new(),
417            auth,
418        }
419    }
420
421    /// Access connection stats for observability.
422    pub fn stats(&self) -> WssStats {
423        self.stats.clone()
424    }
425
426    fn format_subscription(&self) -> Option<Value> {
427        if self.subscribed_markets.is_empty() {
428            return None;
429        }
430
431        Some(json!({
432            "type": "user",
433            "auth": {
434                "apiKey": self.auth.api_key,
435                "secret": self.auth.secret,
436                "passphrase": self.auth.passphrase,
437            },
438            "markets": self.subscribed_markets,
439        }))
440    }
441
442    async fn send_subscription(&mut self) -> Result<()> {
443        if let Some(message) = self.format_subscription() {
444            self.send_raw_message(message).await
445        } else {
446            Ok(())
447        }
448    }
449
450    async fn send_raw_message(&mut self, message: Value) -> Result<()> {
451        if let Some(connection) = self.connection.as_mut() {
452            let text = serde_json::to_string(&message).map_err(|e| {
453                PolyError::parse(
454                    format!("Failed to serialize subscription message: {}", e),
455                    None,
456                )
457            })?;
458            connection
459                .send(Message::Text(text.into()))
460                .await
461                .map_err(|e| {
462                    PolyError::stream(
463                        format!("Failed to send message: {}", e),
464                        crate::errors::StreamErrorKind::MessageCorrupted,
465                    )
466                })?;
467            return Ok(());
468        }
469        Err(PolyError::stream(
470            "WebSocket connection not established",
471            crate::errors::StreamErrorKind::ConnectionFailed,
472        ))
473    }
474
475    async fn connect(&mut self) -> Result<()> {
476        let mut attempts = 0;
477        loop {
478            match connect_async(&self.connect_url).await {
479                Ok((socket, _)) => {
480                    self.connection = Some(socket);
481                    if attempts > 0 {
482                        self.stats.reconnect_count += 1;
483                    }
484                    return Ok(());
485                }
486                Err(err) => {
487                    attempts += 1;
488                    let delay = self.reconnect_delay(attempts);
489                    self.stats.errors += 1;
490                    if attempts >= MAX_RECONNECT_ATTEMPTS {
491                        return Err(PolyError::stream(
492                            format!("Failed to connect after {} attempts: {}", attempts, err),
493                            crate::errors::StreamErrorKind::ConnectionFailed,
494                        ));
495                    }
496                    sleep(delay).await;
497                }
498            }
499        }
500    }
501
502    fn reconnect_delay(&self, attempts: u32) -> Duration {
503        let millis = BASE_RECONNECT_DELAY.as_millis() as u128 * attempts as u128;
504        let desired =
505            Duration::from_millis(millis.min(MAX_RECONNECT_DELAY.as_millis() as u128) as u64);
506        desired
507    }
508
509    async fn ensure_connection(&mut self) -> Result<()> {
510        if self.connection.is_none() {
511            self.connect().await?;
512            self.send_subscription().await?;
513        }
514        Ok(())
515    }
516
517    /// Subscribe to the user channel for the provided market IDs.
518    pub async fn subscribe(&mut self, market_ids: Vec<String>) -> Result<()> {
519        self.subscribed_markets = market_ids;
520        self.ensure_connection().await?;
521        self.send_subscription().await
522    }
523
524    /// Read the next user channel event, reconnecting transparently when the
525    /// socket drops.
526    pub async fn next_event(&mut self) -> Result<WssUserEvent> {
527        loop {
528            if let Some(evt) = self.pending_events.pop_front() {
529                return Ok(evt);
530            }
531            self.ensure_connection().await?;
532
533            match timeout(KEEPALIVE_INTERVAL, self.connection.as_mut().unwrap().next()).await {
534                Ok(Some(Ok(Message::Text(text)))) => {
535                    let trimmed = text.trim();
536                    if trimmed.eq_ignore_ascii_case("ping") || trimmed.eq_ignore_ascii_case("pong")
537                    {
538                        continue;
539                    }
540                    let first_char = trimmed.chars().next();
541                    if first_char != Some('{') && first_char != Some('[') {
542                        warn!("ignoring unexpected text frame: {}", trimmed);
543                        continue;
544                    }
545                    let events = parse_user_events(&text)?;
546                    self.stats.messages_received += events.len() as u64;
547                    self.stats.last_message_time = Some(Utc::now());
548                    for evt in events {
549                        self.pending_events.push_back(evt);
550                    }
551                    if let Some(evt) = self.pending_events.pop_front() {
552                        return Ok(evt);
553                    }
554                    continue;
555                }
556                Ok(Some(Ok(Message::Ping(payload)))) => {
557                    if let Some(connection) = self.connection.as_mut() {
558                        let _ = connection.send(Message::Pong(payload)).await;
559                    }
560                }
561                Ok(Some(Ok(Message::Pong(_)))) => {}
562                Ok(Some(Ok(Message::Close(_)))) => {
563                    self.disconnect_history.push_back(Utc::now());
564                    if self.disconnect_history.len() > 5 {
565                        self.disconnect_history.pop_front();
566                    }
567                    self.connection = None;
568                }
569                Ok(Some(Ok(_))) => {}
570                Ok(Some(Err(err))) => {
571                    warn!("WebSocket error: {}", err);
572                    self.connection = None;
573                    self.stats.errors += 1;
574                    continue;
575                }
576                Ok(None) => {
577                    self.connection = None;
578                }
579                Err(_) => {
580                    if let Some(connection) = self.connection.as_mut() {
581                        let _ = connection.send(Message::Text("PING".into())).await;
582                    }
583                }
584            }
585        }
586    }
587}
588
589fn parse_market_events(text: &str) -> Result<Vec<WssMarketEvent>> {
590    let value: Value = serde_json::from_str(text)
591        .map_err(|err| PolyError::parse(format!("Invalid JSON: {}", err), Some(Box::new(err))))?;
592
593    if let Some(array) = value.as_array() {
594        array
595            .iter()
596            .map(parse_market_event_value)
597            .collect::<Result<Vec<_>>>()
598    } else {
599        Ok(vec![parse_market_event_value(&value)?])
600    }
601}
602
603fn parse_market_event_value(value: &Value) -> Result<WssMarketEvent> {
604    let event_type = value
605        .get("event_type")
606        .and_then(|v| v.as_str())
607        .or_else(|| value.get("type").and_then(|v| v.as_str()))
608        .ok_or_else(|| PolyError::parse("Missing event_type/type in market message", None))?;
609
610    match event_type {
611        "book" => {
612            let parsed: MarketBook = serde_json::from_value(value.clone()).map_err(|err| {
613                PolyError::parse(
614                    format!("Failed to parse book message: {}", err),
615                    Some(Box::new(err)),
616                )
617            })?;
618            Ok(WssMarketEvent::Book(parsed))
619        }
620        "price_change" => {
621            let parsed =
622                serde_json::from_value::<PriceChangeMessage>(value.clone()).map_err(|err| {
623                    PolyError::parse(
624                        format!("Failed to parse price_change: {}", err),
625                        Some(Box::new(err)),
626                    )
627                })?;
628            Ok(WssMarketEvent::PriceChange(parsed))
629        }
630        "tick_size_change" => {
631            let parsed =
632                serde_json::from_value::<TickSizeChangeMessage>(value.clone()).map_err(|err| {
633                    PolyError::parse(
634                        format!("Failed to parse tick_size_change: {}", err),
635                        Some(Box::new(err)),
636                    )
637                })?;
638            Ok(WssMarketEvent::TickSizeChange(parsed))
639        }
640        "last_trade_price" => {
641            let parsed =
642                serde_json::from_value::<LastTradeMessage>(value.clone()).map_err(|err| {
643                    PolyError::parse(
644                        format!("Failed to parse last_trade_price: {}", err),
645                        Some(Box::new(err)),
646                    )
647                })?;
648            Ok(WssMarketEvent::LastTrade(parsed))
649        }
650        other => Err(PolyError::parse(
651            format!("Unknown market event_type: {}", other),
652            None,
653        )),
654    }
655}
656
657fn parse_user_events(text: &str) -> Result<Vec<WssUserEvent>> {
658    let value: Value = serde_json::from_str(text)
659        .map_err(|err| PolyError::parse(format!("Invalid JSON: {}", err), Some(Box::new(err))))?;
660
661    if let Some(array) = value.as_array() {
662        array
663            .iter()
664            .map(parse_user_event_value)
665            .collect::<Result<Vec<_>>>()
666    } else {
667        Ok(vec![parse_user_event_value(&value)?])
668    }
669}
670
671fn parse_user_event_value(value: &Value) -> Result<WssUserEvent> {
672    let event_type = value
673        .get("event_type")
674        .and_then(|v| v.as_str())
675        .ok_or_else(|| PolyError::parse("Missing event_type in user message", None))?;
676
677    match event_type {
678        "trade" => {
679            let parsed =
680                serde_json::from_value::<WssUserTradeMessage>(value.clone()).map_err(|err| {
681                    PolyError::parse(
682                        format!("Failed to parse user trade message: {}", err),
683                        Some(Box::new(err)),
684                    )
685                })?;
686            Ok(WssUserEvent::Trade(parsed))
687        }
688        "order" => {
689            let parsed =
690                serde_json::from_value::<WssUserOrderMessage>(value.clone()).map_err(|err| {
691                    PolyError::parse(
692                        format!("Failed to parse user order message: {}", err),
693                        Some(Box::new(err)),
694                    )
695                })?;
696            Ok(WssUserEvent::Order(parsed))
697        }
698        other => Err(PolyError::parse(
699            format!("Unknown user event_type: {}", other),
700            None,
701        )),
702    }
703}