polyte_clob/ws/
client.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4    time::Duration,
5};
6
7use futures_util::{SinkExt, Stream, StreamExt};
8use tokio::{net::TcpStream, time::interval};
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
11use super::{
12    auth::ApiCredentials,
13    error::WebSocketError,
14    market::MarketMessage,
15    subscription::{ChannelType, MarketSubscription, UserSubscription, WS_MARKET_URL, WS_USER_URL},
16    user::UserMessage,
17    Channel,
18};
19
20/// WebSocket client for Polymarket real-time updates.
21///
22/// Provides streaming access to market data (order book, prices) and user-specific
23/// updates (orders, trades).
24///
25/// # Example
26///
27/// ```no_run
28/// use polyte_clob::ws::WebSocket;
29/// use futures_util::StreamExt;
30///
31/// #[tokio::main]
32/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
33///     let mut ws = WebSocket::connect_market(vec!["asset_id".to_string()]).await?;
34///
35///     while let Some(msg) = ws.next().await {
36///         println!("Received: {:?}", msg?);
37///     }
38///
39///     Ok(())
40/// }
41/// ```
42pub struct WebSocket {
43    inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
44    channel_type: ChannelType,
45}
46
47impl WebSocket {
48    /// Connect to the market channel for public order book and price updates.
49    ///
50    /// # Arguments
51    ///
52    /// * `asset_ids` - Token IDs to subscribe to
53    ///
54    /// # Example
55    ///
56    /// ```no_run
57    /// use polyte_clob::ws::WebSocket;
58    ///
59    /// #[tokio::main]
60    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
61    ///     let ws = WebSocket::connect_market(vec![
62    ///         "token_id_1".to_string(),
63    ///         "token_id_2".to_string(),
64    ///     ]).await?;
65    ///     Ok(())
66    /// }
67    /// ```
68    pub async fn connect_market(asset_ids: Vec<String>) -> Result<Self, WebSocketError> {
69        let (mut ws, _) = connect_async(WS_MARKET_URL).await?;
70
71        let subscription = MarketSubscription::new(asset_ids);
72        let msg = serde_json::to_string(&subscription)?;
73        ws.send(Message::Text(msg.into())).await?;
74
75        Ok(Self {
76            inner: ws,
77            channel_type: ChannelType::Market,
78        })
79    }
80
81    /// Connect to the user channel for authenticated order and trade updates.
82    ///
83    /// # Arguments
84    ///
85    /// * `market_ids` - Condition IDs to subscribe to
86    /// * `credentials` - API credentials for authentication
87    ///
88    /// # Example
89    ///
90    /// ```no_run
91    /// use polyte_clob::ws::{ApiCredentials, WebSocket};
92    ///
93    /// #[tokio::main]
94    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
95    ///     let credentials = ApiCredentials::from_env()?;
96    ///     let ws = WebSocket::connect_user(
97    ///         vec!["condition_id".to_string()],
98    ///         credentials,
99    ///     ).await?;
100    ///     Ok(())
101    /// }
102    /// ```
103    pub async fn connect_user(
104        market_ids: Vec<String>,
105        credentials: ApiCredentials,
106    ) -> Result<Self, WebSocketError> {
107        let (mut ws, _) = connect_async(WS_USER_URL).await?;
108
109        let subscription = UserSubscription::new(market_ids, credentials);
110        let msg = serde_json::to_string(&subscription)?;
111        ws.send(Message::Text(msg.into())).await?;
112
113        Ok(Self {
114            inner: ws,
115            channel_type: ChannelType::User,
116        })
117    }
118
119    /// Send a ping message to keep the connection alive.
120    ///
121    /// The Polymarket WebSocket expects "PING" text messages every ~10 seconds.
122    pub async fn ping(&mut self) -> Result<(), WebSocketError> {
123        self.inner.send(Message::Text("PING".into())).await?;
124        Ok(())
125    }
126
127    /// Close the WebSocket connection.
128    pub async fn close(&mut self) -> Result<(), WebSocketError> {
129        self.inner.close(None).await?;
130        Ok(())
131    }
132
133    /// Get the channel type this WebSocket is connected to.
134    pub fn channel_type(&self) -> ChannelType {
135        self.channel_type
136    }
137
138    /// Parse a text message based on the channel type.
139    fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
140        // Skip PONG responses and empty messages
141        if text == "PONG" || text == "{}" || text.is_empty() {
142            return Ok(None);
143        }
144
145        // Skip messages without event_type (heartbeats, acks, etc.)
146        if !text.contains("event_type") {
147            tracing::debug!("Skipping non-event message: {}", text);
148            return Ok(None);
149        }
150
151        match self.channel_type {
152            ChannelType::Market => {
153                let msg = MarketMessage::from_json(text)?;
154                Ok(Some(Channel::Market(msg)))
155            }
156            ChannelType::User => {
157                let msg = UserMessage::from_json(text)?;
158                Ok(Some(Channel::User(msg)))
159            }
160        }
161    }
162}
163
164impl Stream for WebSocket {
165    type Item = Result<Channel, WebSocketError>;
166
167    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
168        loop {
169            match Pin::new(&mut self.inner).poll_next(cx) {
170                Poll::Ready(Some(Ok(msg))) => match msg {
171                    Message::Text(text) => match self.parse_message(&text) {
172                        Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
173                        Ok(None) => continue, // Skip PONG, poll again
174                        Err(e) => return Poll::Ready(Some(Err(e))),
175                    },
176                    Message::Binary(data) => {
177                        // Try to parse as text
178                        if let Ok(text) = String::from_utf8(data.to_vec()) {
179                            match self.parse_message(&text) {
180                                Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
181                                Ok(None) => continue,
182                                Err(e) => return Poll::Ready(Some(Err(e))),
183                            }
184                        }
185                        continue;
186                    }
187                    Message::Ping(_) | Message::Pong(_) => continue,
188                    Message::Close(_) => return Poll::Ready(None),
189                    Message::Frame(_) => continue,
190                },
191                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
192                Poll::Ready(None) => return Poll::Ready(None),
193                Poll::Pending => return Poll::Pending,
194            }
195        }
196    }
197}
198
/// Builder for WebSocket connections with additional configuration.
///
/// Allows overriding the market and user channel URLs and the keep-alive ping
/// interval before connecting. The `connect_*` methods consume the builder, so it
/// derives `Clone` to allow reusing one configuration, for example when reconnecting.
#[derive(Debug, Clone)]
pub struct WebSocketBuilder {
    /// URL used by `connect_market`.
    market_url: String,
    /// URL used by `connect_user`.
    user_url: String,
    /// Requested keep-alive interval; `None` means "use the default".
    ping_interval: Option<Duration>,
}
205
206impl Default for WebSocketBuilder {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212impl WebSocketBuilder {
213    /// Create a new WebSocket builder.
214    pub fn new() -> Self {
215        Self {
216            market_url: WS_MARKET_URL.to_string(),
217            user_url: WS_USER_URL.to_string(),
218            ping_interval: None,
219        }
220    }
221
222    /// Set a custom WebSocket URL for market channel.
223    pub fn market_url(mut self, url: impl Into<String>) -> Self {
224        self.market_url = url.into();
225        self
226    }
227
228    /// Set a custom WebSocket URL for user channel.
229    pub fn user_url(mut self, url: impl Into<String>) -> Self {
230        self.user_url = url.into();
231        self
232    }
233
234    /// Set the ping interval for keep-alive messages.
235    ///
236    /// If set, the returned `WebSocketWithPing` will automatically send
237    /// ping messages at this interval.
238    pub fn ping_interval(mut self, interval: Duration) -> Self {
239        self.ping_interval = Some(interval);
240        self
241    }
242
243    /// Connect to the market channel.
244    pub async fn connect_market(
245        self,
246        asset_ids: Vec<String>,
247    ) -> Result<WebSocketWithPing, WebSocketError> {
248        let (mut ws, _) = connect_async(&self.market_url).await?;
249
250        let subscription = MarketSubscription::new(asset_ids);
251        let msg = serde_json::to_string(&subscription)?;
252        ws.send(Message::Text(msg.into())).await?;
253
254        Ok(WebSocketWithPing {
255            inner: ws,
256            channel_type: ChannelType::Market,
257            ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
258        })
259    }
260
261    /// Connect to the user channel.
262    pub async fn connect_user(
263        self,
264        market_ids: Vec<String>,
265        credentials: ApiCredentials,
266    ) -> Result<WebSocketWithPing, WebSocketError> {
267        let (mut ws, _) = connect_async(&self.user_url).await?;
268
269        let subscription = UserSubscription::new(market_ids, credentials);
270        let msg = serde_json::to_string(&subscription)?;
271        ws.send(Message::Text(msg.into())).await?;
272
273        Ok(WebSocketWithPing {
274            inner: ws,
275            channel_type: ChannelType::User,
276            ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
277        })
278    }
279}
280
281/// WebSocket client with automatic ping handling.
282///
283/// Use this when you need automatic keep-alive pings. Call `run` to process
284/// messages with automatic ping handling.
285pub struct WebSocketWithPing {
286    inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
287    channel_type: ChannelType,
288    ping_interval: Duration,
289}
290
291impl WebSocketWithPing {
292    /// Run the WebSocket message loop with automatic ping handling.
293    ///
294    /// This method will:
295    /// - Send ping messages at the configured interval
296    /// - Call the provided handler for each received message
297    /// - Return when the connection is closed or an error occurs
298    ///
299    /// # Arguments
300    ///
301    /// * `handler` - Async function called for each received channel message
302    ///
303    /// # Example
304    ///
305    /// ```no_run
306    /// use polyte_clob::ws::{WebSocketBuilder, Channel};
307    /// use std::time::Duration;
308    ///
309    /// #[tokio::main]
310    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
311    ///     let ws = WebSocketBuilder::new()
312    ///         .ping_interval(Duration::from_secs(10))
313    ///         .connect_market(vec!["asset_id".to_string()])
314    ///         .await?;
315    ///
316    ///     ws.run(|msg| async move {
317    ///         println!("Received: {:?}", msg);
318    ///         Ok(())
319    ///     }).await?;
320    ///
321    ///     Ok(())
322    /// }
323    /// ```
324    pub async fn run<F, Fut>(mut self, mut handler: F) -> Result<(), WebSocketError>
325    where
326        F: FnMut(Channel) -> Fut,
327        Fut: std::future::Future<Output = Result<(), WebSocketError>>,
328    {
329        let mut ping_interval = interval(self.ping_interval);
330
331        loop {
332            tokio::select! {
333                _ = ping_interval.tick() => {
334                    self.inner.send(Message::Text("PING".into())).await?;
335                }
336                msg = self.inner.next() => {
337                    match msg {
338                        Some(Ok(Message::Text(text))) => {
339                            if text.as_str() == "PONG" {
340                                continue;
341                            }
342                            let channel = self.parse_message(&text)?;
343                            if let Some(channel) = channel {
344                                handler(channel).await?;
345                            }
346                        }
347                        Some(Ok(Message::Binary(data))) => {
348                            if let Ok(text) = String::from_utf8(data.to_vec()) {
349                                if text == "PONG" {
350                                    continue;
351                                }
352                                let channel = self.parse_message(&text)?;
353                                if let Some(channel) = channel {
354                                    handler(channel).await?;
355                                }
356                            }
357                        }
358                        Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) | Some(Ok(Message::Frame(_))) => continue,
359                        Some(Ok(Message::Close(_))) => return Ok(()),
360                        Some(Err(e)) => return Err(e.into()),
361                        None => return Ok(()),
362                    }
363                }
364            }
365        }
366    }
367
368    /// Get the channel type this WebSocket is connected to.
369    pub fn channel_type(&self) -> ChannelType {
370        self.channel_type
371    }
372
373    /// Parse a text message based on the channel type.
374    fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
375        // Skip PONG responses and empty messages
376        if text == "PONG" || text == "{}" || text.is_empty() {
377            return Ok(None);
378        }
379
380        // Skip messages without event_type (heartbeats, acks, etc.)
381        if !text.contains("event_type") {
382            tracing::debug!("Skipping non-event message: {}", text);
383            return Ok(None);
384        }
385
386        match self.channel_type {
387            ChannelType::Market => {
388                let msg = MarketMessage::from_json(text)?;
389                Ok(Some(Channel::Market(msg)))
390            }
391            ChannelType::User => {
392                let msg = UserMessage::from_json(text)?;
393                Ok(Some(Channel::User(msg)))
394            }
395        }
396    }
397}