Skip to main content

polyoxide_clob/ws/
client.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4    time::Duration,
5};
6
7use futures_util::{SinkExt, Stream, StreamExt};
8use tokio::{net::TcpStream, time::interval};
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
11use super::{
12    auth::ApiCredentials,
13    error::WebSocketError,
14    market::MarketMessage,
15    subscription::{ChannelType, MarketSubscription, UserSubscription, WS_MARKET_URL, WS_USER_URL},
16    user::UserMessage,
17    Channel,
18};
19
/// Upper bound on how many IDs may be subscribed over a single WebSocket
/// connection; the connect functions in this module reject longer lists.
const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 500;
22
23/// WebSocket client for Polymarket real-time updates.
24///
25/// Provides streaming access to market data (order book, prices) and user-specific
26/// updates (orders, trades).
27///
28/// # Example
29///
30/// ```no_run
31/// use polyoxide_clob::ws::WebSocket;
32/// use futures_util::StreamExt;
33///
34/// #[tokio::main]
35/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
36///     let mut ws = WebSocket::connect_market(vec!["asset_id".to_string()]).await?;
37///
38///     while let Some(msg) = ws.next().await {
39///         println!("Received: {:?}", msg?);
40///     }
41///
42///     Ok(())
43/// }
44/// ```
45pub struct WebSocket {
46    inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
47    channel_type: ChannelType,
48}
49
50impl WebSocket {
51    /// Connect to the market channel for public order book and price updates.
52    ///
53    /// # Arguments
54    ///
55    /// * `asset_ids` - Token IDs to subscribe to
56    ///
57    /// # Example
58    ///
59    /// ```no_run
60    /// use polyoxide_clob::ws::WebSocket;
61    ///
62    /// #[tokio::main]
63    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
64    ///     let ws = WebSocket::connect_market(vec![
65    ///         "token_id_1".to_string(),
66    ///         "token_id_2".to_string(),
67    ///     ]).await?;
68    ///     Ok(())
69    /// }
70    /// ```
71    pub async fn connect_market(asset_ids: Vec<String>) -> Result<Self, WebSocketError> {
72        if asset_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
73            return Err(WebSocketError::InvalidMessage(format!(
74                "Too many subscriptions ({}), max {}",
75                asset_ids.len(),
76                MAX_SUBSCRIPTIONS_PER_CONNECTION
77            )));
78        }
79        let (mut ws, _) = connect_async(WS_MARKET_URL).await?;
80
81        let subscription = MarketSubscription::new(asset_ids);
82        let msg = serde_json::to_string(&subscription)?;
83        ws.send(Message::Text(msg.into())).await?;
84
85        Ok(Self {
86            inner: ws,
87            channel_type: ChannelType::Market,
88        })
89    }
90
91    /// Connect to the user channel for authenticated order and trade updates.
92    ///
93    /// # Arguments
94    ///
95    /// * `market_ids` - Condition IDs to subscribe to
96    /// * `credentials` - API credentials for authentication
97    ///
98    /// # Example
99    ///
100    /// ```no_run
101    /// use polyoxide_clob::ws::{ApiCredentials, WebSocket};
102    ///
103    /// #[tokio::main]
104    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
105    ///     let credentials = ApiCredentials::from_env()?;
106    ///     let ws = WebSocket::connect_user(
107    ///         vec!["condition_id".to_string()],
108    ///         credentials,
109    ///     ).await?;
110    ///     Ok(())
111    /// }
112    /// ```
113    pub async fn connect_user(
114        market_ids: Vec<String>,
115        credentials: ApiCredentials,
116    ) -> Result<Self, WebSocketError> {
117        if market_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
118            return Err(WebSocketError::InvalidMessage(format!(
119                "Too many subscriptions ({}), max {}",
120                market_ids.len(),
121                MAX_SUBSCRIPTIONS_PER_CONNECTION
122            )));
123        }
124        let (mut ws, _) = connect_async(WS_USER_URL).await?;
125
126        let subscription = UserSubscription::new(market_ids, credentials);
127        let msg = serde_json::to_string(&subscription)?;
128        ws.send(Message::Text(msg.into())).await?;
129
130        Ok(Self {
131            inner: ws,
132            channel_type: ChannelType::User,
133        })
134    }
135
136    /// Send a ping message to keep the connection alive.
137    ///
138    /// The Polymarket WebSocket expects "PING" text messages every ~10 seconds.
139    pub async fn ping(&mut self) -> Result<(), WebSocketError> {
140        self.inner.send(Message::Text("PING".into())).await?;
141        Ok(())
142    }
143
144    /// Close the WebSocket connection.
145    pub async fn close(&mut self) -> Result<(), WebSocketError> {
146        self.inner.close(None).await?;
147        Ok(())
148    }
149
150    /// Get the channel type this WebSocket is connected to.
151    pub fn channel_type(&self) -> ChannelType {
152        self.channel_type
153    }
154
155    /// Parse a text message based on the channel type.
156    fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
157        // Skip PONG responses and empty messages
158        if text == "PONG" || text == "{}" || text.is_empty() {
159            return Ok(None);
160        }
161
162        // Skip messages without event_type (heartbeats, acks, etc.)
163        if !text.contains("event_type") {
164            tracing::trace!("Skipping non-event message: {}", text);
165            return Ok(None);
166        }
167
168        match self.channel_type {
169            ChannelType::Market => {
170                let msg = MarketMessage::from_json(text)?;
171                Ok(Some(Channel::Market(msg)))
172            }
173            ChannelType::User => {
174                let msg = UserMessage::from_json(text)?;
175                Ok(Some(Channel::User(msg)))
176            }
177        }
178    }
179}
180
181impl Stream for WebSocket {
182    type Item = Result<Channel, WebSocketError>;
183
184    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185        loop {
186            match Pin::new(&mut self.inner).poll_next(cx) {
187                Poll::Ready(Some(Ok(msg))) => match msg {
188                    Message::Text(text) => match self.parse_message(&text) {
189                        Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
190                        Ok(None) => continue, // Skip PONG, poll again
191                        Err(e) => return Poll::Ready(Some(Err(e))),
192                    },
193                    Message::Binary(data) => {
194                        // Try to parse as text
195                        if let Ok(text) = String::from_utf8(data.to_vec()) {
196                            match self.parse_message(&text) {
197                                Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
198                                Ok(None) => continue,
199                                Err(e) => return Poll::Ready(Some(Err(e))),
200                            }
201                        }
202                        continue;
203                    }
204                    Message::Ping(_) | Message::Pong(_) => continue,
205                    Message::Close(_) => return Poll::Ready(None),
206                    Message::Frame(_) => continue,
207                },
208                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
209                Poll::Ready(None) => return Poll::Ready(None),
210                Poll::Pending => return Poll::Pending,
211            }
212        }
213    }
214}
215
216/// Builder for WebSocket connections with additional configuration.
217pub struct WebSocketBuilder {
218    market_url: String,
219    user_url: String,
220    ping_interval: Option<Duration>,
221}
222
223impl Default for WebSocketBuilder {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
229impl WebSocketBuilder {
230    /// Create a new WebSocket builder.
231    pub fn new() -> Self {
232        Self {
233            market_url: WS_MARKET_URL.to_string(),
234            user_url: WS_USER_URL.to_string(),
235            ping_interval: None,
236        }
237    }
238
239    /// Set a custom WebSocket URL for market channel.
240    pub fn market_url(mut self, url: impl Into<String>) -> Self {
241        self.market_url = url.into();
242        self
243    }
244
245    /// Set a custom WebSocket URL for user channel.
246    pub fn user_url(mut self, url: impl Into<String>) -> Self {
247        self.user_url = url.into();
248        self
249    }
250
251    /// Set the ping interval for keep-alive messages.
252    ///
253    /// If set, the returned `WebSocketWithPing` will automatically send
254    /// ping messages at this interval.
255    pub fn ping_interval(mut self, interval: Duration) -> Self {
256        self.ping_interval = Some(interval);
257        self
258    }
259
260    /// Connect to the market channel.
261    pub async fn connect_market(
262        self,
263        asset_ids: Vec<String>,
264    ) -> Result<WebSocketWithPing, WebSocketError> {
265        if asset_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
266            return Err(WebSocketError::InvalidMessage(format!(
267                "Too many subscriptions ({}), max {}",
268                asset_ids.len(),
269                MAX_SUBSCRIPTIONS_PER_CONNECTION
270            )));
271        }
272        let (mut ws, _) = connect_async(&self.market_url).await?;
273
274        let subscription = MarketSubscription::new(asset_ids);
275        let msg = serde_json::to_string(&subscription)?;
276        ws.send(Message::Text(msg.into())).await?;
277
278        Ok(WebSocketWithPing {
279            inner: ws,
280            channel_type: ChannelType::Market,
281            ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
282        })
283    }
284
285    /// Connect to the user channel.
286    pub async fn connect_user(
287        self,
288        market_ids: Vec<String>,
289        credentials: ApiCredentials,
290    ) -> Result<WebSocketWithPing, WebSocketError> {
291        if market_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
292            return Err(WebSocketError::InvalidMessage(format!(
293                "Too many subscriptions ({}), max {}",
294                market_ids.len(),
295                MAX_SUBSCRIPTIONS_PER_CONNECTION
296            )));
297        }
298        let (mut ws, _) = connect_async(&self.user_url).await?;
299
300        let subscription = UserSubscription::new(market_ids, credentials);
301        let msg = serde_json::to_string(&subscription)?;
302        ws.send(Message::Text(msg.into())).await?;
303
304        Ok(WebSocketWithPing {
305            inner: ws,
306            channel_type: ChannelType::User,
307            ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
308        })
309    }
310}
311
312/// WebSocket client with automatic ping handling.
313///
314/// Use this when you need automatic keep-alive pings. Call `run` to process
315/// messages with automatic ping handling.
316pub struct WebSocketWithPing {
317    inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
318    channel_type: ChannelType,
319    ping_interval: Duration,
320}
321
322impl WebSocketWithPing {
323    /// Run the WebSocket message loop with automatic ping handling.
324    ///
325    /// This method will:
326    /// - Send ping messages at the configured interval
327    /// - Call the provided handler for each received message
328    /// - Return when the connection is closed or an error occurs
329    ///
330    /// # Arguments
331    ///
332    /// * `handler` - Async function called for each received channel message
333    ///
334    /// # Example
335    ///
336    /// ```no_run
337    /// use polyoxide_clob::ws::{WebSocketBuilder, Channel};
338    /// use std::time::Duration;
339    ///
340    /// #[tokio::main]
341    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
342    ///     let ws = WebSocketBuilder::new()
343    ///         .ping_interval(Duration::from_secs(10))
344    ///         .connect_market(vec!["asset_id".to_string()])
345    ///         .await?;
346    ///
347    ///     ws.run(|msg| async move {
348    ///         println!("Received: {:?}", msg);
349    ///         Ok(())
350    ///     }).await?;
351    ///
352    ///     Ok(())
353    /// }
354    /// ```
355    pub async fn run<F, Fut>(mut self, mut handler: F) -> Result<(), WebSocketError>
356    where
357        F: FnMut(Channel) -> Fut,
358        Fut: std::future::Future<Output = Result<(), WebSocketError>>,
359    {
360        let mut ping_interval = interval(self.ping_interval);
361
362        loop {
363            tokio::select! {
364                _ = ping_interval.tick() => {
365                    self.inner.send(Message::Text("PING".into())).await?;
366                }
367                msg = self.inner.next() => {
368                    match msg {
369                        Some(Ok(Message::Text(text))) => {
370                            if text.as_str() == "PONG" {
371                                continue;
372                            }
373                            let channel = self.parse_message(&text)?;
374                            if let Some(channel) = channel {
375                                handler(channel).await?;
376                            }
377                        }
378                        Some(Ok(Message::Binary(data))) => {
379                            if let Ok(text) = String::from_utf8(data.to_vec()) {
380                                if text == "PONG" {
381                                    continue;
382                                }
383                                let channel = self.parse_message(&text)?;
384                                if let Some(channel) = channel {
385                                    handler(channel).await?;
386                                }
387                            }
388                        }
389                        Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) | Some(Ok(Message::Frame(_))) => continue,
390                        Some(Ok(Message::Close(_))) => return Ok(()),
391                        Some(Err(e)) => return Err(e.into()),
392                        None => return Ok(()),
393                    }
394                }
395            }
396        }
397    }
398
399    /// Get the channel type this WebSocket is connected to.
400    pub fn channel_type(&self) -> ChannelType {
401        self.channel_type
402    }
403
404    /// Parse a text message based on the channel type.
405    fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
406        // Skip PONG responses and empty messages
407        if text == "PONG" || text == "{}" || text.is_empty() {
408            return Ok(None);
409        }
410
411        // Skip messages without event_type (heartbeats, acks, etc.)
412        if !text.contains("event_type") {
413            tracing::trace!("Skipping non-event message: {}", text);
414            return Ok(None);
415        }
416
417        match self.channel_type {
418            ChannelType::Market => {
419                let msg = MarketMessage::from_json(text)?;
420                Ok(Some(Channel::Market(msg)))
421            }
422            ChannelType::User => {
423                let msg = UserMessage::from_json(text)?;
424                Ok(Some(Channel::User(msg)))
425            }
426        }
427    }
428}