Skip to main content

polymarket_us/
stream.rs

1use crate::auth::UsAuth;
2use crate::error::PolymarketUsError;
3use futures_util::{SinkExt, StreamExt};
4use http::HeaderValue;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::future::Future;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, Notify};
12use tokio_tungstenite::{
13    connect_async,
14    tungstenite::{client::IntoClientRequest, Message},
15};
16
/// Process-wide ordinal appended to every generated tracking id so that ids
/// minted within the same millisecond are still distinct.
static TRACKING_COUNTER: AtomicU64 = AtomicU64::new(1);
18
19// ---------------------------------------------------------------------------
20// Subscription channel enum
21// ---------------------------------------------------------------------------
22
23/// All known WebSocket subscription channels.
24///
25/// Pass a variant to the typed constructors on [`StreamSubscription`], or use
26/// [`SubscriptionChannel::as_str`] to get the wire-format channel name.
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29#[non_exhaustive]
30pub enum SubscriptionChannel {
31    /// Initial snapshot of all open orders (private).
32    OrderSnapshot,
33    /// Real-time order lifecycle changes (private).
34    OrderUpdate,
35    /// Full order-book depth (public).
36    MarketData,
37    /// Best-bid/offer only (public).
38    MarketDataLite,
39    /// Initial snapshot of portfolio positions (private).
40    PositionSnapshot,
41    /// Real-time position changes (private).
42    PositionUpdate,
43    /// Initial snapshot of account balances (private).
44    BalanceSnapshot,
45    /// Real-time balance changes (private).
46    BalanceUpdate,
47    /// Trade execution feed (public).
48    Trade,
49    /// Server heartbeat — useful as an aliveness check.
50    Heartbeat,
51}
52
53impl SubscriptionChannel {
54    /// Returns the snake_case wire-format channel name.
55    pub fn as_str(self) -> &'static str {
56        match self {
57            Self::OrderSnapshot => "order_snapshot",
58            Self::OrderUpdate => "order_update",
59            Self::MarketData => "market_data",
60            Self::MarketDataLite => "market_data_lite",
61            Self::PositionSnapshot => "position_snapshot",
62            Self::PositionUpdate => "position_update",
63            Self::BalanceSnapshot => "balance_snapshot",
64            Self::BalanceUpdate => "balance_update",
65            Self::Trade => "trade",
66            Self::Heartbeat => "heartbeat",
67        }
68    }
69}
70
71impl std::fmt::Display for SubscriptionChannel {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.write_str(self.as_str())
74    }
75}
76
77// ---------------------------------------------------------------------------
78// Internal command sent from ManagedStream → StreamRunner
79// ---------------------------------------------------------------------------
80
81enum StreamCommand {
82    Subscribe(StreamSubscription),
83    Unsubscribe(String), // tracking_id
84}
85
86#[derive(Clone)]
87pub struct PolymarketUsStreamClient {
88    base_url: String,
89    auth: Option<UsAuth>,
90}
91
92impl PolymarketUsStreamClient {
93    pub fn new(base_url: impl Into<String>, auth: Option<UsAuth>) -> Self {
94        Self {
95            base_url: normalize_stream_url(base_url.into()),
96            auth,
97        }
98    }
99
100    pub fn from_gateway_base_url(
101        gateway_base_url: impl Into<String>,
102        auth: Option<UsAuth>,
103    ) -> Self {
104        let gateway_base_url = gateway_base_url.into();
105        Self::new(derive_stream_url(&gateway_base_url), auth)
106    }
107
108    pub fn base_url(&self) -> &str {
109        &self.base_url
110    }
111
112    pub async fn connect(
113        &self,
114        subscriptions: Vec<StreamSubscription>,
115    ) -> Result<ManagedStream, PolymarketUsError> {
116        self.connect_with_config(subscriptions, StreamConnectConfig::default())
117            .await
118    }
119
120    pub async fn connect_with_config(
121        &self,
122        subscriptions: Vec<StreamSubscription>,
123        config: StreamConnectConfig,
124    ) -> Result<ManagedStream, PolymarketUsError> {
125        if subscriptions.is_empty() {
126            return Err(PolymarketUsError::InvalidStreamConfig(
127                "at least one subscription is required".to_string(),
128            ));
129        }
130
131        let (tx, rx) = mpsc::channel(256);
132        let (cmd_tx, cmd_rx) = mpsc::channel(64);
133        let shutdown = Arc::new(StreamShutdown::new());
134        let base_url = self.base_url.clone();
135        let auth = self.auth.clone();
136        let shutdown_task = shutdown.clone();
137
138        tokio::spawn(async move {
139            let runner = StreamRunner {
140                base_url,
141                auth,
142                subscriptions,
143                config,
144                tx,
145                shutdown: shutdown_task,
146                cmd_rx,
147            };
148            runner.run().await;
149        });
150
151        Ok(ManagedStream {
152            receiver: rx,
153            shutdown,
154            cmd_tx,
155        })
156    }
157
158    pub async fn run<F, Fut>(
159        &self,
160        subscriptions: Vec<StreamSubscription>,
161        config: StreamConnectConfig,
162        mut on_message: F,
163    ) -> Result<(), PolymarketUsError>
164    where
165        F: FnMut(StreamMessage) -> Fut,
166        Fut: Future<Output = ()>,
167    {
168        let mut stream = self.connect_with_config(subscriptions, config).await?;
169        while let Some(message) = stream.next().await {
170            on_message(message).await;
171        }
172        Ok(())
173    }
174}
175
176pub struct ManagedStream {
177    receiver: mpsc::Receiver<StreamMessage>,
178    shutdown: Arc<StreamShutdown>,
179    cmd_tx: mpsc::Sender<StreamCommand>,
180}
181
182impl ManagedStream {
183    pub async fn next(&mut self) -> Option<StreamMessage> {
184        self.receiver.recv().await
185    }
186
187    pub fn shutdown(&self) {
188        self.shutdown.shutdown();
189    }
190
191    pub fn is_shutdown(&self) -> bool {
192        self.shutdown.is_shutdown()
193    }
194
195    /// Dynamically add a subscription to the live connection.
196    ///
197    /// The subscription frame is sent immediately over the existing WebSocket
198    /// and re-sent automatically after every reconnect.
199    pub async fn subscribe(&self, sub: StreamSubscription) -> Result<(), PolymarketUsError> {
200        self.cmd_tx
201            .send(StreamCommand::Subscribe(sub))
202            .await
203            .map_err(|_| PolymarketUsError::InvalidStreamConfig("stream is closed".to_string()))
204    }
205
206    /// Remove a subscription by its `tracking_id`.
207    ///
208    /// The subscription is removed from the reconnect list immediately.
209    /// An unsubscribe frame is also sent to the server over the live connection.
210    pub async fn unsubscribe(&self, tracking_id: &str) -> Result<(), PolymarketUsError> {
211        self.cmd_tx
212            .send(StreamCommand::Unsubscribe(tracking_id.to_string()))
213            .await
214            .map_err(|_| PolymarketUsError::InvalidStreamConfig("stream is closed".to_string()))
215    }
216}
217
218#[derive(Debug, Clone)]
219pub struct StreamConnectConfig {
220    pub tracking_id: String,
221    pub responses_debounced: bool,
222    pub reconnect: ReconnectConfig,
223}
224
225impl Default for StreamConnectConfig {
226    fn default() -> Self {
227        Self {
228            tracking_id: next_tracking_id("session"),
229            responses_debounced: false,
230            reconnect: ReconnectConfig::default(),
231        }
232    }
233}
234
235impl StreamConnectConfig {
236    pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
237        self.tracking_id = tracking_id.into();
238        self
239    }
240
241    pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
242        self.responses_debounced = responses_debounced;
243        self
244    }
245
246    pub fn with_reconnect(mut self, reconnect: ReconnectConfig) -> Self {
247        self.reconnect = reconnect;
248        self
249    }
250}
251
/// Exponential-backoff policy applied between reconnect attempts.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Whether to reconnect at all after the connection ends or fails.
    pub enabled: bool,
    /// Maximum number of reconnect attempts before giving up; `None` retries forever.
    pub max_attempts: Option<usize>,
    /// Delay before the first reconnect attempt.
    pub initial_delay: Duration,
    /// Upper bound on any single delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each further attempt.
    pub multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Enabled, unlimited attempts, 250 ms doubling up to 10 s.
    fn default() -> Self {
        Self {
            enabled: true,
            max_attempts: None,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(10),
            multiplier: 2.0,
        }
    }
}

impl ReconnectConfig {
    /// The default policy with reconnecting turned off.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }

    /// Returns the backoff delay before reconnect attempt `attempt` (1-based).
    ///
    /// Attempts `0` and `1` wait `initial_delay`. Attempt `n > 1` waits
    /// `initial_delay * multiplier^(n - 1)`. The result is always capped at
    /// `max_delay`, and this never panics. If the exponential grows past what
    /// `Duration` can hold, or the policy yields a NaN or negative delay, the
    /// cap is returned instead.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            return self.initial_delay.min(self.max_delay);
        }

        // Clamp before casting so huge attempt counts cannot wrap to a negative exponent.
        let exponent = (attempt - 1).min(i32::MAX as usize) as i32;
        // Same arithmetic as `Duration::mul_f64`, but converted fallibly.
        let seconds = self.multiplier.powi(exponent) * self.initial_delay.as_secs_f64();
        // `try_from_secs_f64` rejects NaN, negative and out-of-range values.
        // Each of those is an overflow or a misconfigured policy, so fall back
        // to the cap rather than panicking inside the reconnect loop.
        Duration::try_from_secs_f64(seconds)
            .map_or(self.max_delay, |delay| delay.min(self.max_delay))
    }
}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294#[serde(rename_all = "camelCase")]
295pub struct StreamSubscription {
296    pub channel: String,
297    pub tracking_id: String,
298    #[serde(default, skip_serializing_if = "Option::is_none")]
299    pub responses_debounced: Option<bool>,
300    #[serde(default, skip_serializing_if = "Option::is_none")]
301    pub symbol: Option<String>,
302    #[serde(default, skip_serializing_if = "Option::is_none")]
303    pub market_id: Option<String>,
304    #[serde(default, skip_serializing_if = "Option::is_none")]
305    pub outcome: Option<String>,
306    #[serde(default, flatten)]
307    pub extra: Map<String, Value>,
308}
309
310impl StreamSubscription {
311    pub fn new(channel: impl Into<String>) -> Self {
312        Self {
313            channel: channel.into(),
314            tracking_id: next_tracking_id("sub"),
315            responses_debounced: None,
316            symbol: None,
317            market_id: None,
318            outcome: None,
319            extra: Map::new(),
320        }
321    }
322
323    /// Create a subscription for the given typed channel.
324    pub fn for_channel(channel: SubscriptionChannel) -> Self {
325        Self::new(channel.as_str())
326    }
327
328    // --- Market (public) ---
329
330    /// Full order-book depth updates for a market symbol.
331    pub fn market_data(symbol: impl Into<String>) -> Self {
332        let mut s = Self::new(SubscriptionChannel::MarketData.as_str());
333        s.symbol = Some(symbol.into());
334        s
335    }
336
337    /// Best-bid/offer updates for a market symbol (lightweight).
338    pub fn market_data_lite(symbol: impl Into<String>) -> Self {
339        let mut s = Self::new(SubscriptionChannel::MarketDataLite.as_str());
340        s.symbol = Some(symbol.into());
341        s
342    }
343
344    /// Trade executions for a market symbol.
345    pub fn trades(symbol: impl Into<String>) -> Self {
346        let mut s = Self::new(SubscriptionChannel::Trade.as_str());
347        s.symbol = Some(symbol.into());
348        s
349    }
350
351    /// Server heartbeat channel — useful for keepalive monitoring.
352    pub fn heartbeat() -> Self {
353        Self::new(SubscriptionChannel::Heartbeat.as_str())
354    }
355
356    // --- Private (authenticated) ---
357
358    /// Initial snapshot of all open orders for a symbol.
359    pub fn order_snapshot(symbol: impl Into<String>) -> Self {
360        let mut s = Self::new(SubscriptionChannel::OrderSnapshot.as_str());
361        s.symbol = Some(symbol.into());
362        s
363    }
364
365    /// Real-time order lifecycle events.
366    pub fn order_update() -> Self {
367        Self::new(SubscriptionChannel::OrderUpdate.as_str())
368    }
369
370    /// Initial snapshot of all portfolio positions.
371    pub fn position_snapshot() -> Self {
372        Self::new(SubscriptionChannel::PositionSnapshot.as_str())
373    }
374
375    /// Real-time position changes.
376    pub fn position_update() -> Self {
377        Self::new(SubscriptionChannel::PositionUpdate.as_str())
378    }
379
380    /// Initial snapshot of account balances.
381    pub fn balance_snapshot() -> Self {
382        Self::new(SubscriptionChannel::BalanceSnapshot.as_str())
383    }
384
385    /// Real-time balance changes.
386    pub fn balance_update() -> Self {
387        Self::new(SubscriptionChannel::BalanceUpdate.as_str())
388    }
389
390    // --- Builder methods ---
391
392    pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
393        self.tracking_id = tracking_id.into();
394        self
395    }
396
397    pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
398        self.responses_debounced = Some(responses_debounced);
399        self
400    }
401
402    pub fn with_symbol(mut self, symbol: impl Into<String>) -> Self {
403        self.symbol = Some(symbol.into());
404        self
405    }
406
407    pub fn with_market_id(mut self, market_id: impl Into<String>) -> Self {
408        self.market_id = Some(market_id.into());
409        self
410    }
411
412    pub fn with_outcome(mut self, outcome: impl Into<String>) -> Self {
413        self.outcome = Some(outcome.into());
414        self
415    }
416
417    pub fn insert_extra(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
418        self.extra.insert(key.into(), value.into());
419        self
420    }
421}
422
423#[derive(Debug, Clone)]
424pub struct StreamMessage {
425    pub tracking_id: Option<String>,
426    pub kind: StreamMessageKind,
427}
428
429#[derive(Debug, Clone)]
430#[non_exhaustive]
431pub enum StreamMessageKind {
432    Data(StreamDataEvent),
433    Control(StreamControlEvent),
434}
435
436#[derive(Debug, Clone)]
437#[non_exhaustive]
438pub enum StreamDataEvent {
439    /// Initial snapshot of all open orders (private channel).
440    OrderSnapshot(Value),
441    /// Real-time order lifecycle update (private channel).
442    OrderUpdate(Value),
443    /// Full order-book depth update.
444    MarketData(Value),
445    /// Best-bid/offer update (lightweight).
446    MarketDataLite(Value),
447    /// Order-book delta / incremental update.
448    OrderBookDelta(Value),
449    /// Initial snapshot of all portfolio positions (private channel).
450    PositionSnapshot(Value),
451    /// Real-time position change (private channel).
452    PositionUpdate(Value),
453    /// Initial snapshot of account balances (private channel).
454    BalanceSnapshot(Value),
455    /// Real-time balance change (private channel).
456    BalanceUpdate(Value),
457    /// Trade execution event.
458    Trade(Value),
459    /// Server heartbeat — no payload.
460    Heartbeat,
461    /// Any server event not yet modelled by this SDK.
462    Other { event_type: String, payload: Value },
463}
464
465#[derive(Debug, Clone)]
466#[non_exhaustive]
467pub enum StreamControlEvent {
468    Connected { session_tracking_id: String },
469    SubscriptionAck { event_type: String, payload: Value },
470    Reconnecting { attempt: usize, delay_ms: u64 },
471    Closed,
472    Error(String),
473}
474
475impl StreamMessage {
476    pub fn control(tracking_id: Option<String>, event: StreamControlEvent) -> Self {
477        Self {
478            tracking_id,
479            kind: StreamMessageKind::Control(event),
480        }
481    }
482
483    pub fn data(tracking_id: Option<String>, event: StreamDataEvent) -> Self {
484        Self {
485            tracking_id,
486            kind: StreamMessageKind::Data(event),
487        }
488    }
489}
490
491struct StreamRunner {
492    base_url: String,
493    auth: Option<UsAuth>,
494    subscriptions: Vec<StreamSubscription>,
495    config: StreamConnectConfig,
496    tx: mpsc::Sender<StreamMessage>,
497    shutdown: Arc<StreamShutdown>,
498    cmd_rx: mpsc::Receiver<StreamCommand>,
499}
500
501impl StreamRunner {
502    async fn run(mut self) {
503        let mut attempt = 0usize;
504
505        loop {
506            if self.shutdown.is_shutdown() || self.tx.is_closed() {
507                break;
508            }
509
510            match self.connect_and_consume().await {
511                Ok(()) => {
512                    if !self.config.reconnect.enabled {
513                        break;
514                    }
515                }
516                Err(err) => {
517                    if !self
518                        .emit(StreamMessage::control(
519                            Some(self.config.tracking_id.clone()),
520                            StreamControlEvent::Error(err.to_string()),
521                        ))
522                        .await
523                    {
524                        break;
525                    }
526                }
527            }
528
529            if !self.config.reconnect.enabled {
530                break;
531            }
532
533            attempt += 1;
534            if let Some(max_attempts) = self.config.reconnect.max_attempts {
535                if attempt > max_attempts {
536                    break;
537                }
538            }
539
540            let delay = self.config.reconnect.delay_for_attempt(attempt);
541            if !self
542                .emit(StreamMessage::control(
543                    Some(self.config.tracking_id.clone()),
544                    StreamControlEvent::Reconnecting {
545                        attempt,
546                        delay_ms: delay.as_millis() as u64,
547                    },
548                ))
549                .await
550            {
551                break;
552            }
553
554            let shutdown = Arc::clone(&self.shutdown);
555            tokio::select! {
556                _ = shutdown.notified() => break,
557                _ = tokio::time::sleep(delay) => {}
558            }
559        }
560
561        let _ = self
562            .emit(StreamMessage::control(
563                Some(self.config.tracking_id.clone()),
564                StreamControlEvent::Closed,
565            ))
566            .await;
567    }
568
569    async fn connect_and_consume(&mut self) -> Result<(), PolymarketUsError> {
570        let mut request = self
571            .base_url
572            .as_str()
573            .into_client_request()
574            .map_err(|err| {
575                PolymarketUsError::InvalidStreamConfig(format!(
576                    "invalid websocket URL {}: {err}",
577                    self.base_url
578                ))
579            })?;
580
581        if let Some(auth) = &self.auth {
582            let path = request
583                .uri()
584                .path_and_query()
585                .map(|path| path.as_str())
586                .unwrap_or("/");
587            for (name, value) in auth.signed_headers("GET", path) {
588                let header_value = HeaderValue::from_str(&value).map_err(|err| {
589                    PolymarketUsError::InvalidStreamConfig(format!(
590                        "invalid websocket auth header value for {name}: {err}"
591                    ))
592                })?;
593                request.headers_mut().insert(name, header_value);
594            }
595        }
596
597        let (mut websocket, _) = connect_async(request).await?;
598        let _ = self
599            .emit(StreamMessage::control(
600                Some(self.config.tracking_id.clone()),
601                StreamControlEvent::Connected {
602                    session_tracking_id: self.config.tracking_id.clone(),
603                },
604            ))
605            .await;
606
607        self.send_all_subscriptions(&mut websocket).await?;
608
609        // Clone the Arc so the future borrows it, not &mut self, allowing
610        // cmd_rx to be used in the same select! block.
611        let shutdown = Arc::clone(&self.shutdown);
612        let shutdown_wait = shutdown.notified();
613        tokio::pin!(shutdown_wait);
614
615        loop {
616            tokio::select! {
617                _ = &mut shutdown_wait => {
618                    let _ = websocket.close(None).await;
619                    break;
620                }
621                message = websocket.next() => {
622                    let Some(message) = message else {
623                        break;
624                    };
625
626                    match message {
627                        Ok(Message::Text(text)) => {
628                            self.handle_text(&text).await?;
629                        }
630                        Ok(Message::Binary(bytes)) => {
631                            let text = String::from_utf8(bytes.to_vec()).map_err(|err| {
632                                PolymarketUsError::InvalidStreamConfig(format!(
633                                    "received non-UTF8 websocket payload: {err}"
634                                ))
635                            })?;
636                            self.handle_text(&text).await?;
637                        }
638                        Ok(Message::Close(_)) => break,
639                        Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
640                        Ok(_) => {}
641                        Err(err) => return Err(err.into()),
642                    }
643                }
644                cmd = self.cmd_rx.recv() => {
645                    match cmd {
646                        Some(StreamCommand::Subscribe(sub)) => {
647                            self.send_subscription(&mut websocket, &sub).await?;
648                            self.subscriptions.push(sub);
649                        }
650                        Some(StreamCommand::Unsubscribe(tracking_id)) => {
651                            self.subscriptions.retain(|s| s.tracking_id != tracking_id);
652                            // Best-effort unsubscribe frame; server may not support it.
653                            let frame = serde_json::json!({
654                                "type": "unsubscribe",
655                                "trackingId": tracking_id,
656                            });
657                            let _ = websocket
658                                .send(Message::Text(frame.to_string()))
659                                .await;
660                        }
661                        None => break,
662                    }
663                }
664            }
665        }
666
667        Ok(())
668    }
669
670    async fn send_all_subscriptions(
671        &self,
672        websocket: &mut tokio_tungstenite::WebSocketStream<
673            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
674        >,
675    ) -> Result<(), PolymarketUsError> {
676        for subscription in &self.subscriptions {
677            self.send_subscription(websocket, subscription).await?;
678        }
679        Ok(())
680    }
681
682    async fn send_subscription(
683        &self,
684        websocket: &mut tokio_tungstenite::WebSocketStream<
685            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
686        >,
687        subscription: &StreamSubscription,
688    ) -> Result<(), PolymarketUsError> {
689        let mut prepared = subscription.clone();
690        if prepared.responses_debounced.is_none() {
691            prepared.responses_debounced = Some(self.config.responses_debounced);
692        }
693        let payload = serde_json::to_string(&prepared)?;
694        websocket.send(Message::Text(payload)).await?;
695        Ok(())
696    }
697
698    async fn handle_text(&self, text: &str) -> Result<(), PolymarketUsError> {
699        let json: Value = serde_json::from_str(text)?;
700        if let Some(message) = parse_stream_message(json) {
701            if !self.emit(message).await {
702                return Ok(());
703            }
704        }
705        Ok(())
706    }
707
708    async fn emit(&self, message: StreamMessage) -> bool {
709        self.tx.send(message).await.is_ok()
710    }
711}
712
713struct StreamShutdown {
714    requested: AtomicBool,
715    notify: Notify,
716}
717
718impl StreamShutdown {
719    fn new() -> Self {
720        Self {
721            requested: AtomicBool::new(false),
722            notify: Notify::new(),
723        }
724    }
725
726    fn shutdown(&self) {
727        if !self.requested.swap(true, Ordering::SeqCst) {
728            self.notify.notify_waiters();
729        }
730    }
731
732    fn is_shutdown(&self) -> bool {
733        self.requested.load(Ordering::SeqCst)
734    }
735
736    fn notified(&self) -> impl Future<Output = ()> + '_ {
737        self.notify.notified()
738    }
739}
740
741fn parse_stream_message(json: Value) -> Option<StreamMessage> {
742    match json {
743        Value::Object(map) => {
744            let tracking_id = extract_tracking_id(&map);
745            let event_type = extract_event_type(&map);
746            let payload = extract_payload(&map);
747
748            let kind = match event_type.as_str() {
749                // --- Order channels ---
750                "order_snapshot" | "orderSnapshot" => {
751                    StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(payload))
752                }
753                "order_update" | "order_updates" | "orderUpdate" | "user_order" | "fill" => {
754                    StreamMessageKind::Data(StreamDataEvent::OrderUpdate(payload))
755                }
756                // --- Market channels ---
757                "market_data" | "marketData" => {
758                    StreamMessageKind::Data(StreamDataEvent::MarketData(payload))
759                }
760                "market_data_lite" | "marketDataLite" => {
761                    StreamMessageKind::Data(StreamDataEvent::MarketDataLite(payload))
762                }
763                "order_book_delta" | "orderbook_delta" | "book_delta" | "bookDelta" => {
764                    StreamMessageKind::Data(StreamDataEvent::OrderBookDelta(payload))
765                }
766                "trade" | "trades" => StreamMessageKind::Data(StreamDataEvent::Trade(payload)),
767                // --- Position channels ---
768                "position_snapshot" | "positionSnapshot" => {
769                    StreamMessageKind::Data(StreamDataEvent::PositionSnapshot(payload))
770                }
771                "position_update" | "positionUpdate" => {
772                    StreamMessageKind::Data(StreamDataEvent::PositionUpdate(payload))
773                }
774                // --- Balance channels ---
775                "balance_snapshot" | "balanceSnapshot" => {
776                    StreamMessageKind::Data(StreamDataEvent::BalanceSnapshot(payload))
777                }
778                "balance_update" | "balanceUpdate" => {
779                    StreamMessageKind::Data(StreamDataEvent::BalanceUpdate(payload))
780                }
781                // --- Keepalive ---
782                "heartbeat" | "ping" | "pong" => {
783                    StreamMessageKind::Data(StreamDataEvent::Heartbeat)
784                }
785                // --- Control ---
786                "subscription" | "subscribe" | "subscribed" | "ack" => {
787                    StreamMessageKind::Control(StreamControlEvent::SubscriptionAck {
788                        event_type: event_type.clone(),
789                        payload,
790                    })
791                }
792                "error" => {
793                    StreamMessageKind::Control(StreamControlEvent::Error(payload.to_string()))
794                }
795                _ => StreamMessageKind::Data(StreamDataEvent::Other {
796                    event_type: event_type.clone(),
797                    payload,
798                }),
799            };
800
801            Some(StreamMessage { tracking_id, kind })
802        }
803        other => Some(StreamMessage::data(
804            None,
805            StreamDataEvent::Other {
806                event_type: "unknown".to_string(),
807                payload: other,
808            },
809        )),
810    }
811}
812
813fn extract_tracking_id(map: &Map<String, Value>) -> Option<String> {
814    ["trackingId", "tracking_id", "trackingID", "id"]
815        .iter()
816        .find_map(|key| map.get(*key).and_then(Value::as_str).map(ToOwned::to_owned))
817}
818
819fn extract_event_type(map: &Map<String, Value>) -> String {
820    for key in ["event", "type", "channel", "name", "topic"] {
821        if let Some(value) = map.get(key).and_then(Value::as_str) {
822            return value.to_string();
823        }
824    }
825
826    if map.len() == 1 {
827        return map
828            .keys()
829            .next()
830            .cloned()
831            .unwrap_or_else(|| "unknown".to_string());
832    }
833
834    "unknown".to_string()
835}
836
837fn extract_payload(map: &Map<String, Value>) -> Value {
838    for key in ["data", "payload", "body", "message", "result"] {
839        if let Some(value) = map.get(key) {
840            return value.clone();
841        }
842    }
843
844    if map.len() == 1 {
845        return map.values().next().cloned().unwrap_or(Value::Null);
846    }
847
848    Value::Object(map.clone())
849}
850
851fn next_tracking_id(prefix: &str) -> String {
852    let ordinal = TRACKING_COUNTER.fetch_add(1, Ordering::Relaxed);
853    format!(
854        "{prefix}-{}-{ordinal}",
855        chrono::Utc::now().timestamp_millis()
856    )
857}
858
/// Normalizes a user-supplied stream URL into a WebSocket URL.
///
/// Trailing slashes are removed. Apart from that:
/// - `ws://` and `wss://` URLs are returned unchanged;
/// - `https://host` becomes `wss://host/ws`;
/// - `http://host` becomes `ws://host/ws`;
/// - a scheme-less `host` becomes `wss://host/ws`.
fn normalize_stream_url(mut url: String) -> String {
    // Trim in place so already-valid WebSocket URLs reuse the caller's buffer.
    let trimmed_len = url.trim_end_matches('/').len();
    url.truncate(trimmed_len);

    if url.starts_with("ws://") || url.starts_with("wss://") {
        return url;
    }

    let (scheme, rest) = if let Some(rest) = url.strip_prefix("https://") {
        ("wss", rest)
    } else if let Some(rest) = url.strip_prefix("http://") {
        ("ws", rest)
    } else {
        ("wss", url.as_str())
    };
    format!("{scheme}://{rest}/ws")
}

/// Derives the stream URL from the REST gateway base URL.
///
/// Applies exactly the rules of [`normalize_stream_url`]. It delegates to that
/// function so the two conversions cannot drift apart.
fn derive_stream_url(gateway_base_url: &str) -> String {
    normalize_stream_url(gateway_base_url.to_owned())
}
884
#[cfg(test)]
mod tests {
    //! Unit tests for reconnect back-off, subscription serialization, stream
    //! message parsing, and stream URL derivation/normalization.

    use super::*;
    use serde_json::json;

    #[test]
    fn reconnect_delay_caps_at_max() {
        let policy = ReconnectConfig {
            enabled: true,
            max_attempts: None,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(1),
            multiplier: 3.0,
        };

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(250));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(250));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(750));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(1));
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(1));
    }

    #[test]
    fn subscription_serializes_debounced_flag_and_tracking_id() {
        let subscription = StreamSubscription::order_snapshot("ABC")
            .with_tracking_id("tracking-1")
            .with_responses_debounced(true)
            .insert_extra("bookLevel", json!(2));

        let json = serde_json::to_value(subscription).unwrap();
        assert_eq!(json["channel"], "order_snapshot");
        assert_eq!(json["trackingId"], "tracking-1");
        assert_eq!(json["responsesDebounced"], true);
        assert_eq!(json["symbol"], "ABC");
        assert_eq!(json["bookLevel"], 2);
    }

    #[test]
    fn parses_order_snapshot_event() {
        let message = parse_stream_message(json!({
            "event": "order_snapshot",
            "trackingId": "abc-123",
            "data": { "bids": [1, 2], "asks": [3, 4] }
        }))
        .expect("message");

        assert_eq!(message.tracking_id.as_deref(), Some("abc-123"));
        match message.kind {
            StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(payload)) => {
                assert_eq!(payload["bids"][0], 1);
                assert_eq!(payload["asks"][1], 4);
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[test]
    fn parses_position_snapshot_event() {
        let message = parse_stream_message(json!({
            "event": "position_snapshot",
            "data": { "positions": [] }
        }))
        .expect("message");
        assert!(
            matches!(
                message.kind,
                StreamMessageKind::Data(StreamDataEvent::PositionSnapshot(_))
            ),
            "expected PositionSnapshot"
        );
    }

    #[test]
    fn parses_balance_update_event() {
        let message = parse_stream_message(json!({
            "event": "balance_update",
            "data": { "currency": "USD", "balance": "1000.00" }
        }))
        .expect("message");
        assert!(
            matches!(
                message.kind,
                StreamMessageKind::Data(StreamDataEvent::BalanceUpdate(_))
            ),
            "expected BalanceUpdate"
        );
    }

    #[test]
    fn parses_trade_event() {
        let message = parse_stream_message(json!({
            "event": "trade",
            "data": { "price": "0.55", "size": "100" }
        }))
        .expect("message");
        assert!(
            matches!(
                message.kind,
                StreamMessageKind::Data(StreamDataEvent::Trade(_))
            ),
            "expected Trade"
        );
    }

    #[test]
    fn parses_heartbeat_event() {
        let message = parse_stream_message(json!({ "event": "heartbeat" })).expect("message");
        assert!(
            matches!(
                message.kind,
                StreamMessageKind::Data(StreamDataEvent::Heartbeat)
            ),
            "expected Heartbeat"
        );
    }

    #[test]
    fn parses_market_data_lite_event() {
        let message = parse_stream_message(json!({
            "event": "market_data_lite",
            "data": { "bid": "0.50", "ask": "0.55" }
        }))
        .expect("message");
        assert!(
            matches!(
                message.kind,
                StreamMessageKind::Data(StreamDataEvent::MarketDataLite(_))
            ),
            "expected MarketDataLite"
        );
    }

    #[test]
    fn subscription_channel_as_str() {
        assert_eq!(
            SubscriptionChannel::OrderSnapshot.as_str(),
            "order_snapshot"
        );
        assert_eq!(
            SubscriptionChannel::MarketDataLite.as_str(),
            "market_data_lite"
        );
        assert_eq!(
            SubscriptionChannel::PositionUpdate.as_str(),
            "position_update"
        );
        assert_eq!(
            SubscriptionChannel::BalanceSnapshot.as_str(),
            "balance_snapshot"
        );
        assert_eq!(SubscriptionChannel::Trade.as_str(), "trade");
        assert_eq!(SubscriptionChannel::Heartbeat.as_str(), "heartbeat");
    }

    #[test]
    fn subscription_constructors_set_channel() {
        assert_eq!(StreamSubscription::market_data("X").channel, "market_data");
        assert_eq!(
            StreamSubscription::market_data_lite("X").channel,
            "market_data_lite"
        );
        assert_eq!(StreamSubscription::trades("X").channel, "trade");
        assert_eq!(StreamSubscription::heartbeat().channel, "heartbeat");
        assert_eq!(StreamSubscription::order_update().channel, "order_update");
        assert_eq!(
            StreamSubscription::position_snapshot().channel,
            "position_snapshot"
        );
        assert_eq!(
            StreamSubscription::position_update().channel,
            "position_update"
        );
        assert_eq!(
            StreamSubscription::balance_snapshot().channel,
            "balance_snapshot"
        );
        assert_eq!(
            StreamSubscription::balance_update().channel,
            "balance_update"
        );
    }

    #[test]
    fn for_channel_constructor() {
        let sub = StreamSubscription::for_channel(SubscriptionChannel::BalanceUpdate);
        assert_eq!(sub.channel, "balance_update");
    }

    #[test]
    fn derives_stream_url_from_gateway_base_url() {
        assert_eq!(
            derive_stream_url("https://gateway.polymarket.us"),
            "wss://gateway.polymarket.us/ws"
        );
        // Trailing slashes are stripped before `/ws` is appended.
        assert_eq!(
            derive_stream_url("https://gateway.polymarket.us/"),
            "wss://gateway.polymarket.us/ws"
        );
        // Plain HTTP maps to the insecure WebSocket scheme.
        assert_eq!(
            derive_stream_url("http://localhost:8080"),
            "ws://localhost:8080/ws"
        );
        // A bare host (no scheme) defaults to the secure WebSocket scheme.
        assert_eq!(
            derive_stream_url("gateway.example"),
            "wss://gateway.example/ws"
        );
        // URLs that are already WebSocket URLs pass through (minus trailing slashes).
        assert_eq!(
            derive_stream_url("wss://custom.example/ws"),
            "wss://custom.example/ws"
        );
        assert_eq!(
            derive_stream_url("ws://custom.example/stream/"),
            "ws://custom.example/stream"
        );
    }

    #[test]
    fn normalizes_explicit_stream_url() {
        // Explicit WebSocket URLs are kept as-is.
        assert_eq!(
            normalize_stream_url("wss://custom.example/ws".to_string()),
            "wss://custom.example/ws"
        );
        // HTTP(S) URLs are mapped to the matching WebSocket scheme with `/ws` appended.
        assert_eq!(
            normalize_stream_url("https://custom.example".to_string()),
            "wss://custom.example/ws"
        );
        assert_eq!(
            normalize_stream_url("http://localhost:8080/".to_string()),
            "ws://localhost:8080/ws"
        );
    }
}