// polymarket_us/stream.rs

1use crate::auth::UsAuth;
2use crate::error::PolymarketUsError;
3use futures_util::{SinkExt, StreamExt};
4use http::HeaderValue;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::future::Future;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, Notify};
12use tokio_tungstenite::{
13    connect_async,
14    tungstenite::{client::IntoClientRequest, Message},
15};
16
/// Process-wide sequence number mixed into every id produced by `next_tracking_id`, so two
/// ids generated within the same millisecond still differ. The first id uses ordinal 1.
static TRACKING_COUNTER: AtomicU64 = AtomicU64::new(1);
18
19#[derive(Clone)]
20pub struct PolymarketUsStreamClient {
21    base_url: String,
22    auth: Option<UsAuth>,
23}
24
25impl PolymarketUsStreamClient {
26    pub fn new(base_url: impl Into<String>, auth: Option<UsAuth>) -> Self {
27        Self {
28            base_url: normalize_stream_url(base_url.into()),
29            auth,
30        }
31    }
32
33    pub fn from_gateway_base_url(
34        gateway_base_url: impl Into<String>,
35        auth: Option<UsAuth>,
36    ) -> Self {
37        let gateway_base_url = gateway_base_url.into();
38        Self::new(derive_stream_url(&gateway_base_url), auth)
39    }
40
41    pub fn base_url(&self) -> &str {
42        &self.base_url
43    }
44
45    pub async fn connect(
46        &self,
47        subscriptions: Vec<StreamSubscription>,
48    ) -> Result<ManagedStream, PolymarketUsError> {
49        self.connect_with_config(subscriptions, StreamConnectConfig::default())
50            .await
51    }
52
53    pub async fn connect_with_config(
54        &self,
55        subscriptions: Vec<StreamSubscription>,
56        config: StreamConnectConfig,
57    ) -> Result<ManagedStream, PolymarketUsError> {
58        if subscriptions.is_empty() {
59            return Err(PolymarketUsError::InvalidStreamConfig(
60                "at least one subscription is required".to_string(),
61            ));
62        }
63
64        let (tx, rx) = mpsc::channel(256);
65        let shutdown = Arc::new(StreamShutdown::new());
66        let base_url = self.base_url.clone();
67        let auth = self.auth.clone();
68        let shutdown_task = shutdown.clone();
69
70        tokio::spawn(async move {
71            let runner = StreamRunner {
72                base_url,
73                auth,
74                subscriptions,
75                config,
76                tx,
77                shutdown: shutdown_task,
78            };
79            runner.run().await;
80        });
81
82        Ok(ManagedStream {
83            receiver: rx,
84            shutdown,
85        })
86    }
87
88    pub async fn run<F, Fut>(
89        &self,
90        subscriptions: Vec<StreamSubscription>,
91        config: StreamConnectConfig,
92        mut on_message: F,
93    ) -> Result<(), PolymarketUsError>
94    where
95        F: FnMut(StreamMessage) -> Fut,
96        Fut: Future<Output = ()>,
97    {
98        let mut stream = self.connect_with_config(subscriptions, config).await?;
99        while let Some(message) = stream.next().await {
100            on_message(message).await;
101        }
102        Ok(())
103    }
104}
105
106pub struct ManagedStream {
107    receiver: mpsc::Receiver<StreamMessage>,
108    shutdown: Arc<StreamShutdown>,
109}
110
111impl ManagedStream {
112    pub async fn next(&mut self) -> Option<StreamMessage> {
113        self.receiver.recv().await
114    }
115
116    pub fn shutdown(&self) {
117        self.shutdown.shutdown();
118    }
119
120    pub fn is_shutdown(&self) -> bool {
121        self.shutdown.is_shutdown()
122    }
123}
124
125#[derive(Debug, Clone)]
126pub struct StreamConnectConfig {
127    pub tracking_id: String,
128    pub responses_debounced: bool,
129    pub reconnect: ReconnectConfig,
130}
131
132impl Default for StreamConnectConfig {
133    fn default() -> Self {
134        Self {
135            tracking_id: next_tracking_id("session"),
136            responses_debounced: false,
137            reconnect: ReconnectConfig::default(),
138        }
139    }
140}
141
142impl StreamConnectConfig {
143    pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
144        self.tracking_id = tracking_id.into();
145        self
146    }
147
148    pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
149        self.responses_debounced = responses_debounced;
150        self
151    }
152
153    pub fn with_reconnect(mut self, reconnect: ReconnectConfig) -> Self {
154        self.reconnect = reconnect;
155        self
156    }
157}
158
/// Exponential-backoff reconnect policy.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// When `false`, the stream ends after the first session instead of reconnecting.
    pub enabled: bool,
    /// Upper bound on reconnect attempts; `None` retries indefinitely.
    pub max_attempts: Option<usize>,
    /// Delay before the first reconnect attempt.
    pub initial_delay: Duration,
    /// Cap applied to every computed delay.
    pub max_delay: Duration,
    /// Growth factor applied per additional attempt.
    pub multiplier: f64,
}
167
168impl Default for ReconnectConfig {
169    fn default() -> Self {
170        Self {
171            enabled: true,
172            max_attempts: None,
173            initial_delay: Duration::from_millis(250),
174            max_delay: Duration::from_secs(10),
175            multiplier: 2.0,
176        }
177    }
178}
179
180impl ReconnectConfig {
181    pub fn disabled() -> Self {
182        Self {
183            enabled: false,
184            ..Self::default()
185        }
186    }
187
188    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
189        if attempt == 0 {
190            return self.initial_delay.min(self.max_delay);
191        }
192
193        let scaled = self
194            .initial_delay
195            .mul_f64(self.multiplier.powi(attempt.saturating_sub(1) as i32));
196        scaled.min(self.max_delay)
197    }
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
201#[serde(rename_all = "camelCase")]
202pub struct StreamSubscription {
203    pub channel: String,
204    pub tracking_id: String,
205    #[serde(default, skip_serializing_if = "Option::is_none")]
206    pub responses_debounced: Option<bool>,
207    #[serde(default, skip_serializing_if = "Option::is_none")]
208    pub symbol: Option<String>,
209    #[serde(default, skip_serializing_if = "Option::is_none")]
210    pub market_id: Option<String>,
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub outcome: Option<String>,
213    #[serde(default, flatten)]
214    pub extra: Map<String, Value>,
215}
216
217impl StreamSubscription {
218    pub fn new(channel: impl Into<String>) -> Self {
219        Self {
220            channel: channel.into(),
221            tracking_id: next_tracking_id("sub"),
222            responses_debounced: None,
223            symbol: None,
224            market_id: None,
225            outcome: None,
226            extra: Map::new(),
227        }
228    }
229
230    pub fn order_snapshot(symbol: impl Into<String>) -> Self {
231        let mut subscription = Self::new("order_snapshot");
232        subscription.symbol = Some(symbol.into());
233        subscription
234    }
235
236    pub fn market_data_lite(symbol: impl Into<String>) -> Self {
237        let mut subscription = Self::new("market_data_lite");
238        subscription.symbol = Some(symbol.into());
239        subscription
240    }
241
242    pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
243        self.tracking_id = tracking_id.into();
244        self
245    }
246
247    pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
248        self.responses_debounced = Some(responses_debounced);
249        self
250    }
251
252    pub fn with_symbol(mut self, symbol: impl Into<String>) -> Self {
253        self.symbol = Some(symbol.into());
254        self
255    }
256
257    pub fn with_market_id(mut self, market_id: impl Into<String>) -> Self {
258        self.market_id = Some(market_id.into());
259        self
260    }
261
262    pub fn with_outcome(mut self, outcome: impl Into<String>) -> Self {
263        self.outcome = Some(outcome.into());
264        self
265    }
266
267    pub fn insert_extra(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
268        self.extra.insert(key.into(), value.into());
269        self
270    }
271}
272
273#[derive(Debug, Clone)]
274pub struct StreamMessage {
275    pub tracking_id: Option<String>,
276    pub kind: StreamMessageKind,
277}
278
279#[derive(Debug, Clone)]
280pub enum StreamMessageKind {
281    Data(StreamDataEvent),
282    Control(StreamControlEvent),
283}
284
285#[derive(Debug, Clone)]
286pub enum StreamDataEvent {
287    OrderSnapshot(Value),
288    MarketDataLite(Value),
289    OrderBookDelta(Value),
290    OrderUpdate(Value),
291    Other { event_type: String, payload: Value },
292}
293
294#[derive(Debug, Clone)]
295pub enum StreamControlEvent {
296    Connected { session_tracking_id: String },
297    SubscriptionAck { event_type: String, payload: Value },
298    Reconnecting { attempt: usize, delay_ms: u64 },
299    Closed,
300    Error(String),
301}
302
303impl StreamMessage {
304    pub fn control(tracking_id: Option<String>, event: StreamControlEvent) -> Self {
305        Self {
306            tracking_id,
307            kind: StreamMessageKind::Control(event),
308        }
309    }
310
311    pub fn data(tracking_id: Option<String>, event: StreamDataEvent) -> Self {
312        Self {
313            tracking_id,
314            kind: StreamMessageKind::Data(event),
315        }
316    }
317}
318
319struct StreamRunner {
320    base_url: String,
321    auth: Option<UsAuth>,
322    subscriptions: Vec<StreamSubscription>,
323    config: StreamConnectConfig,
324    tx: mpsc::Sender<StreamMessage>,
325    shutdown: Arc<StreamShutdown>,
326}
327
328impl StreamRunner {
329    async fn run(self) {
330        let mut attempt = 0usize;
331
332        loop {
333            if self.shutdown.is_shutdown() || self.tx.is_closed() {
334                break;
335            }
336
337            match self.connect_and_consume().await {
338                Ok(()) => {
339                    if !self.config.reconnect.enabled {
340                        break;
341                    }
342                }
343                Err(err) => {
344                    if !self
345                        .emit(StreamMessage::control(
346                            Some(self.config.tracking_id.clone()),
347                            StreamControlEvent::Error(err.to_string()),
348                        ))
349                        .await
350                    {
351                        break;
352                    }
353                }
354            }
355
356            if !self.config.reconnect.enabled {
357                break;
358            }
359
360            attempt += 1;
361            if let Some(max_attempts) = self.config.reconnect.max_attempts {
362                if attempt > max_attempts {
363                    break;
364                }
365            }
366
367            let delay = self.config.reconnect.delay_for_attempt(attempt);
368            if !self
369                .emit(StreamMessage::control(
370                    Some(self.config.tracking_id.clone()),
371                    StreamControlEvent::Reconnecting {
372                        attempt,
373                        delay_ms: delay.as_millis() as u64,
374                    },
375                ))
376                .await
377            {
378                break;
379            }
380
381            tokio::select! {
382                _ = self.shutdown.notified() => break,
383                _ = tokio::time::sleep(delay) => {}
384            }
385        }
386
387        let _ = self
388            .emit(StreamMessage::control(
389                Some(self.config.tracking_id.clone()),
390                StreamControlEvent::Closed,
391            ))
392            .await;
393    }
394
395    async fn connect_and_consume(&self) -> Result<(), PolymarketUsError> {
396        let mut request = self
397            .base_url
398            .as_str()
399            .into_client_request()
400            .map_err(|err| {
401                PolymarketUsError::InvalidStreamConfig(format!(
402                    "invalid websocket URL {}: {err}",
403                    self.base_url
404                ))
405            })?;
406
407        if let Some(auth) = &self.auth {
408            let path = request
409                .uri()
410                .path_and_query()
411                .map(|path| path.as_str())
412                .unwrap_or("/");
413            for (name, value) in auth.signed_headers("GET", path) {
414                let header_value = HeaderValue::from_str(&value).map_err(|err| {
415                    PolymarketUsError::InvalidStreamConfig(format!(
416                        "invalid websocket auth header value for {name}: {err}"
417                    ))
418                })?;
419                request.headers_mut().insert(name, header_value);
420            }
421        }
422
423        let (mut websocket, _) = connect_async(request).await?;
424        let _ = self
425            .emit(StreamMessage::control(
426                Some(self.config.tracking_id.clone()),
427                StreamControlEvent::Connected {
428                    session_tracking_id: self.config.tracking_id.clone(),
429                },
430            ))
431            .await;
432
433        self.send_subscriptions(&mut websocket).await?;
434
435        let shutdown_wait = self.shutdown.notified();
436        tokio::pin!(shutdown_wait);
437
438        loop {
439            tokio::select! {
440                _ = &mut shutdown_wait => {
441                    let _ = websocket.close(None).await;
442                    break;
443                }
444                message = websocket.next() => {
445                    let Some(message) = message else {
446                        break;
447                    };
448
449                    match message {
450                        Ok(Message::Text(text)) => {
451                            self.handle_text(&text).await?;
452                        }
453                        Ok(Message::Binary(bytes)) => {
454                            let text = String::from_utf8(bytes.to_vec()).map_err(|err| {
455                                PolymarketUsError::InvalidStreamConfig(format!(
456                                    "received non-UTF8 websocket payload: {err}"
457                                ))
458                            })?;
459                            self.handle_text(&text).await?;
460                        }
461                        Ok(Message::Close(_)) => break,
462                        Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
463                        Ok(_) => {}
464                        Err(err) => return Err(err.into()),
465                    }
466                }
467            }
468        }
469
470        Ok(())
471    }
472
473    async fn send_subscriptions(
474        &self,
475        websocket: &mut tokio_tungstenite::WebSocketStream<
476            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
477        >,
478    ) -> Result<(), PolymarketUsError> {
479        for subscription in &self.subscriptions {
480            let mut prepared = subscription.clone();
481            if prepared.responses_debounced.is_none() {
482                prepared.responses_debounced = Some(self.config.responses_debounced);
483            }
484
485            let payload = serde_json::to_string(&prepared)?;
486            websocket.send(Message::Text(payload.into())).await?;
487        }
488
489        Ok(())
490    }
491
492    async fn handle_text(&self, text: &str) -> Result<(), PolymarketUsError> {
493        let json: Value = serde_json::from_str(text)?;
494        if let Some(message) = parse_stream_message(json) {
495            if !self.emit(message).await {
496                return Ok(());
497            }
498        }
499        Ok(())
500    }
501
502    async fn emit(&self, message: StreamMessage) -> bool {
503        self.tx.send(message).await.is_ok()
504    }
505}
506
507struct StreamShutdown {
508    requested: AtomicBool,
509    notify: Notify,
510}
511
512impl StreamShutdown {
513    fn new() -> Self {
514        Self {
515            requested: AtomicBool::new(false),
516            notify: Notify::new(),
517        }
518    }
519
520    fn shutdown(&self) {
521        if !self.requested.swap(true, Ordering::SeqCst) {
522            self.notify.notify_waiters();
523        }
524    }
525
526    fn is_shutdown(&self) -> bool {
527        self.requested.load(Ordering::SeqCst)
528    }
529
530    fn notified(&self) -> impl Future<Output = ()> + '_ {
531        self.notify.notified()
532    }
533}
534
535fn parse_stream_message(json: Value) -> Option<StreamMessage> {
536    match json {
537        Value::Object(map) => {
538            let tracking_id = extract_tracking_id(&map);
539            let event_type = extract_event_type(&map);
540            let payload = extract_payload(&map);
541
542            let kind = match event_type.as_str() {
543                "order_snapshot" => {
544                    StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(payload))
545                }
546                "market_data_lite" => {
547                    StreamMessageKind::Data(StreamDataEvent::MarketDataLite(payload))
548                }
549                "order_book_delta" | "orderbook_delta" | "book_delta" => {
550                    StreamMessageKind::Data(StreamDataEvent::OrderBookDelta(payload))
551                }
552                "order_update" | "order_updates" | "user_order" | "fill" => {
553                    StreamMessageKind::Data(StreamDataEvent::OrderUpdate(payload))
554                }
555                "subscription" | "subscribe" | "subscribed" | "ack" => {
556                    StreamMessageKind::Control(StreamControlEvent::SubscriptionAck {
557                        event_type: event_type.clone(),
558                        payload,
559                    })
560                }
561                "error" => {
562                    StreamMessageKind::Control(StreamControlEvent::Error(payload.to_string()))
563                }
564                _ => StreamMessageKind::Data(StreamDataEvent::Other {
565                    event_type: event_type.clone(),
566                    payload,
567                }),
568            };
569
570            Some(StreamMessage { tracking_id, kind })
571        }
572        other => Some(StreamMessage::data(
573            None,
574            StreamDataEvent::Other {
575                event_type: "unknown".to_string(),
576                payload: other,
577            },
578        )),
579    }
580}
581
582fn extract_tracking_id(map: &Map<String, Value>) -> Option<String> {
583    ["trackingId", "tracking_id", "trackingID", "id"]
584        .iter()
585        .find_map(|key| map.get(*key).and_then(Value::as_str).map(ToOwned::to_owned))
586}
587
588fn extract_event_type(map: &Map<String, Value>) -> String {
589    for key in ["event", "type", "channel", "name", "topic"] {
590        if let Some(value) = map.get(key).and_then(Value::as_str) {
591            return value.to_string();
592        }
593    }
594
595    if map.len() == 1 {
596        return map
597            .keys()
598            .next()
599            .cloned()
600            .unwrap_or_else(|| "unknown".to_string());
601    }
602
603    "unknown".to_string()
604}
605
606fn extract_payload(map: &Map<String, Value>) -> Value {
607    for key in ["data", "payload", "body", "message", "result"] {
608        if let Some(value) = map.get(key) {
609            return value.clone();
610        }
611    }
612
613    if map.len() == 1 {
614        return map.values().next().cloned().unwrap_or(Value::Null);
615    }
616
617    Value::Object(map.clone())
618}
619
620fn next_tracking_id(prefix: &str) -> String {
621    let ordinal = TRACKING_COUNTER.fetch_add(1, Ordering::Relaxed);
622    format!(
623        "{prefix}-{}-{ordinal}",
624        chrono::Utc::now().timestamp_millis()
625    )
626}
627
/// Turns a user-supplied URL into a websocket URL.
///
/// Trailing slashes are trimmed. `ws://`/`wss://` URLs are returned unchanged. `https://`
/// maps to `wss://` and `http://` maps to `ws://`, each with `/ws` appended. A bare host is
/// treated as `wss://{host}/ws`.
fn normalize_stream_url(url: String) -> String {
    let base = url.trim_end_matches('/');
    if ["ws://", "wss://"]
        .into_iter()
        .any(|scheme| base.starts_with(scheme))
    {
        return base.to_owned();
    }

    let (scheme, authority_and_path) = if let Some(rest) = base.strip_prefix("https://") {
        ("wss", rest)
    } else if let Some(rest) = base.strip_prefix("http://") {
        ("ws", rest)
    } else {
        ("wss", base)
    };
    format!("{scheme}://{authority_and_path}/ws")
}
640
641fn derive_stream_url(gateway_base_url: &str) -> String {
642    let trimmed = gateway_base_url.trim_end_matches('/');
643    if trimmed.starts_with("ws://") || trimmed.starts_with("wss://") {
644        trimmed.to_string()
645    } else if let Some(rest) = trimmed.strip_prefix("https://") {
646        format!("wss://{rest}/ws")
647    } else if let Some(rest) = trimmed.strip_prefix("http://") {
648        format!("ws://{rest}/ws")
649    } else {
650        format!("wss://{trimmed}/ws")
651    }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn reconnect_delay_caps_at_max() {
        let policy = ReconnectConfig {
            enabled: true,
            max_attempts: None,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(1),
            multiplier: 3.0,
        };

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(250));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(250));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(750));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(1));
    }

    #[test]
    fn disabled_policy_keeps_default_delays() {
        let policy = ReconnectConfig::disabled();
        assert!(!policy.enabled);
        assert_eq!(policy.initial_delay, Duration::from_millis(250));
        assert_eq!(policy.max_delay, Duration::from_secs(10));
    }

    #[test]
    fn subscription_serializes_debounced_flag_and_tracking_id() {
        let subscription = StreamSubscription::order_snapshot("ABC")
            .with_tracking_id("tracking-1")
            .with_responses_debounced(true)
            .insert_extra("bookLevel", json!(2));

        let json = serde_json::to_value(subscription).unwrap();
        assert_eq!(json["channel"], "order_snapshot");
        assert_eq!(json["trackingId"], "tracking-1");
        assert_eq!(json["responsesDebounced"], true);
        assert_eq!(json["symbol"], "ABC");
        assert_eq!(json["bookLevel"], 2);
    }

    #[test]
    fn parses_order_snapshot_event() {
        let message = parse_stream_message(json!({
            "event": "order_snapshot",
            "trackingId": "abc-123",
            "data": { "bids": [1, 2], "asks": [3, 4] }
        }))
        .expect("message");

        assert_eq!(message.tracking_id.as_deref(), Some("abc-123"));
        match message.kind {
            StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(payload)) => {
                assert_eq!(payload["bids"][0], 1);
                assert_eq!(payload["asks"][1], 4);
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[test]
    fn parses_non_object_frame_as_unknown() {
        let message = parse_stream_message(json!([1, 2, 3])).expect("message");

        assert!(message.tracking_id.is_none());
        match message.kind {
            StreamMessageKind::Data(StreamDataEvent::Other { event_type, payload }) => {
                assert_eq!(event_type, "unknown");
                assert_eq!(payload, json!([1, 2, 3]));
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[test]
    fn routes_subscription_ack_to_control() {
        let message = parse_stream_message(json!({
            "type": "subscribed",
            "id": "sub-1",
            "result": { "ok": true }
        }))
        .expect("message");

        assert_eq!(message.tracking_id.as_deref(), Some("sub-1"));
        match message.kind {
            StreamMessageKind::Control(StreamControlEvent::SubscriptionAck {
                event_type,
                payload,
            }) => {
                assert_eq!(event_type, "subscribed");
                assert_eq!(payload["ok"], true);
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[test]
    fn single_key_envelope_uses_key_as_event_type() {
        let message =
            parse_stream_message(json!({ "market_data_lite": { "bid": 0.4 } })).expect("message");

        match message.kind {
            StreamMessageKind::Data(StreamDataEvent::MarketDataLite(payload)) => {
                assert_eq!(payload["bid"], 0.4);
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[test]
    fn derives_stream_url_from_gateway_base_url() {
        assert_eq!(
            derive_stream_url("https://gateway.polymarket.us"),
            "wss://gateway.polymarket.us/ws"
        );
        assert_eq!(
            normalize_stream_url("wss://custom.example/ws".to_string()),
            "wss://custom.example/ws"
        );
    }

    #[test]
    fn normalizes_http_schemes_and_bare_hosts() {
        assert_eq!(
            normalize_stream_url("http://localhost:8080/".to_string()),
            "ws://localhost:8080/ws"
        );
        assert_eq!(
            normalize_stream_url("api.example.com".to_string()),
            "wss://api.example.com/ws"
        );
    }
}
722
723