polymarket_api/websocket/
mod.rs

1//! Polymarket WebSocket client and types
2//!
3//! This module provides a WebSocket client for connecting to Polymarket's
4//! real-time market data stream, along with all the data types for messages
5//! and updates received over the WebSocket connection.
6
7pub mod messages;
8pub mod types;
9
10use {
11    anyhow::{Context, Result},
12    futures_util::{SinkExt, StreamExt},
13    serde::{Deserialize, Serialize},
14    std::collections::HashMap,
15    tokio_tungstenite::{connect_async, tungstenite::Message},
16};
17
18#[cfg(feature = "tracing")]
19use tracing::{error, warn};
20
21pub use {
22    messages::{Auth, SubscribedMessage, SubscriptionMessage, UpdateSubscriptionMessage},
23    types::{ErrorMessage, OrderUpdate, OrderbookUpdate, PriceLevel, PriceUpdate, TradeUpdate},
24};
25
/// Endpoint that [`PolymarketWebSocket::connect_and_listen`] connects to
/// before subscribing to the `market` channel.
const WS_URL: &str = "wss://ws-subscriptions-clob.polymarket.com/ws/market";
27
28/// Main WebSocket message enum that can represent any message type received from the API
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(tag = "type")]
31pub enum WebSocketMessage {
32    #[serde(rename = "orderbook")]
33    Orderbook(OrderbookUpdate),
34    #[serde(rename = "trade")]
35    Trade(TradeUpdate),
36    #[serde(rename = "order")]
37    Order(OrderUpdate),
38    #[serde(rename = "price")]
39    Price(PriceUpdate),
40    #[serde(rename = "error")]
41    Error(ErrorMessage),
42    #[serde(rename = "subscribed")]
43    Subscribed(SubscribedMessage),
44    #[serde(other)]
45    Unknown,
46}
47
48/// WebSocket client for connecting to Polymarket's market data stream
49pub struct PolymarketWebSocket {
50    pub(crate) asset_ids: Vec<String>,
51    market_info_cache: HashMap<String, crate::gamma::MarketInfo>,
52}
53
54impl PolymarketWebSocket {
55    /// Create a new WebSocket client for the given asset IDs
56    pub fn new(asset_ids: Vec<String>) -> Self {
57        Self {
58            asset_ids,
59            market_info_cache: HashMap::new(),
60        }
61    }
62
63    /// Connect to the WebSocket and listen for updates
64    ///
65    /// The callback function will be called for each message received.
66    pub async fn connect_and_listen<F>(&mut self, mut on_update: F) -> Result<()>
67    where
68        F: FnMut(WebSocketMessage) + Send,
69    {
70        let (ws_stream, _) = connect_async(WS_URL)
71            .await
72            .context("Failed to connect to WebSocket")?;
73
74        let (mut write, mut read) = ws_stream.split();
75
76        // Subscribe to market channel
77        let subscribe_msg = SubscriptionMessage {
78            auth: None, // No auth needed for public market data
79            markets: None,
80            assets_ids: Some(self.asset_ids.clone()),
81            channel_type: "market".to_string(), // Use lowercase as per Polymarket docs
82            custom_feature_enabled: None,
83        };
84
85        let subscribe_json = serde_json::to_string(&subscribe_msg)?;
86        write
87            .send(Message::Text(subscribe_json))
88            .await
89            .context("Failed to send subscription message")?;
90
91        // Listen for messages
92        while let Some(msg) = read.next().await {
93            match msg {
94                Ok(Message::Text(text)) => {
95                    // Try to parse as WebSocketMessage first
96                    if let Ok(ws_msg) = serde_json::from_str::<WebSocketMessage>(&text) {
97                        on_update(ws_msg);
98                    } else if let Ok(subscribed) = serde_json::from_str::<SubscribedMessage>(&text)
99                    {
100                        on_update(WebSocketMessage::Subscribed(subscribed));
101                    } else if let Ok(err) = serde_json::from_str::<ErrorMessage>(&text) {
102                        on_update(WebSocketMessage::Error(err));
103                    } else {
104                        // Try to parse by checking for type field
105                        if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text)
106                            && let Some(msg_type) = json.get("type").and_then(|v| v.as_str())
107                        {
108                            match msg_type {
109                                "orderbook" => {
110                                    if let Ok(update) =
111                                        serde_json::from_value::<OrderbookUpdate>(json)
112                                    {
113                                        on_update(WebSocketMessage::Orderbook(update));
114                                    }
115                                },
116                                "trade" => {
117                                    if let Ok(update) = serde_json::from_value::<TradeUpdate>(json)
118                                    {
119                                        on_update(WebSocketMessage::Trade(update));
120                                    }
121                                },
122                                "order" => {
123                                    if let Ok(update) = serde_json::from_value::<OrderUpdate>(json)
124                                    {
125                                        on_update(WebSocketMessage::Order(update));
126                                    }
127                                },
128                                "price" => {
129                                    if let Ok(update) = serde_json::from_value::<PriceUpdate>(json)
130                                    {
131                                        on_update(WebSocketMessage::Price(update));
132                                    }
133                                },
134                                _ => {
135                                    // Unknown message type, log for debugging
136                                    #[cfg(feature = "tracing")]
137                                    warn!("Unknown message type: {}", text);
138                                    #[cfg(not(feature = "tracing"))]
139                                    eprintln!("Unknown message type: {}", text);
140                                },
141                            }
142                        }
143                    }
144                },
145                Ok(Message::Ping(data)) => {
146                    // Respond to ping with pong
147                    if let Err(e) = write.send(Message::Pong(data)).await {
148                        #[cfg(feature = "tracing")]
149                        error!("Failed to send pong: {}", e);
150                        #[cfg(not(feature = "tracing"))]
151                        eprintln!("Failed to send pong: {}", e);
152                        break;
153                    }
154                },
155                Ok(Message::Close(_)) => {
156                    break;
157                },
158                Err(e) => {
159                    #[cfg(feature = "tracing")]
160                    error!("WebSocket error: {}", e);
161                    #[cfg(not(feature = "tracing"))]
162                    eprintln!("WebSocket error: {}", e);
163                    break;
164                },
165                _ => {},
166            }
167        }
168
169        Ok(())
170    }
171
172    /// Update cached market info for an asset
173    pub fn update_market_info(&mut self, asset_id: String, info: crate::gamma::MarketInfo) {
174        self.market_info_cache.insert(asset_id, info);
175    }
176
177    /// Get cached market info for an asset
178    pub fn get_market_info(&self, asset_id: &str) -> Option<&crate::gamma::MarketInfo> {
179        self.market_info_cache.get(asset_id)
180    }
181}