1use crate::auth::UsAuth;
2use crate::error::PolymarketUsError;
3use futures_util::{SinkExt, StreamExt};
4use http::HeaderValue;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::future::Future;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, Notify};
12use tokio_tungstenite::{
13 connect_async,
14 tungstenite::{client::IntoClientRequest, Message},
15};
16
/// Process-wide sequence number embedded in every generated tracking id (see
/// `next_tracking_id`), so ids created within the same millisecond stay distinct.
static TRACKING_COUNTER: AtomicU64 = AtomicU64::new(1);
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29#[non_exhaustive]
30pub enum SubscriptionChannel {
31 OrderSnapshot,
33 OrderUpdate,
35 MarketData,
37 MarketDataLite,
39 PositionSnapshot,
41 PositionUpdate,
43 BalanceSnapshot,
45 BalanceUpdate,
47 Trade,
49 Heartbeat,
51}
52
53impl SubscriptionChannel {
54 pub fn as_str(self) -> &'static str {
56 match self {
57 Self::OrderSnapshot => "order_snapshot",
58 Self::OrderUpdate => "order_update",
59 Self::MarketData => "market_data",
60 Self::MarketDataLite => "market_data_lite",
61 Self::PositionSnapshot => "position_snapshot",
62 Self::PositionUpdate => "position_update",
63 Self::BalanceSnapshot => "balance_snapshot",
64 Self::BalanceUpdate => "balance_update",
65 Self::Trade => "trade",
66 Self::Heartbeat => "heartbeat",
67 }
68 }
69}
70
71impl std::fmt::Display for SubscriptionChannel {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.write_str(self.as_str())
74 }
75}
76
77enum StreamCommand {
82 Subscribe(StreamSubscription),
83 Unsubscribe(String), }
85
86#[derive(Clone)]
87pub struct PolymarketUsStreamClient {
88 base_url: String,
89 auth: Option<UsAuth>,
90}
91
92impl PolymarketUsStreamClient {
93 pub fn new(base_url: impl Into<String>, auth: Option<UsAuth>) -> Self {
94 Self {
95 base_url: normalize_stream_url(base_url.into()),
96 auth,
97 }
98 }
99
100 pub fn from_gateway_base_url(
101 gateway_base_url: impl Into<String>,
102 auth: Option<UsAuth>,
103 ) -> Self {
104 let gateway_base_url = gateway_base_url.into();
105 Self::new(derive_stream_url(&gateway_base_url), auth)
106 }
107
108 pub fn base_url(&self) -> &str {
109 &self.base_url
110 }
111
112 pub async fn connect(
113 &self,
114 subscriptions: Vec<StreamSubscription>,
115 ) -> Result<ManagedStream, PolymarketUsError> {
116 self.connect_with_config(subscriptions, StreamConnectConfig::default())
117 .await
118 }
119
120 pub async fn connect_with_config(
121 &self,
122 subscriptions: Vec<StreamSubscription>,
123 config: StreamConnectConfig,
124 ) -> Result<ManagedStream, PolymarketUsError> {
125 if subscriptions.is_empty() {
126 return Err(PolymarketUsError::InvalidStreamConfig(
127 "at least one subscription is required".to_string(),
128 ));
129 }
130
131 let (tx, rx) = mpsc::channel(256);
132 let (cmd_tx, cmd_rx) = mpsc::channel(64);
133 let shutdown = Arc::new(StreamShutdown::new());
134 let base_url = self.base_url.clone();
135 let auth = self.auth.clone();
136 let shutdown_task = shutdown.clone();
137
138 tokio::spawn(async move {
139 let runner = StreamRunner {
140 base_url,
141 auth,
142 subscriptions,
143 config,
144 tx,
145 shutdown: shutdown_task,
146 cmd_rx,
147 };
148 runner.run().await;
149 });
150
151 Ok(ManagedStream {
152 receiver: rx,
153 shutdown,
154 cmd_tx,
155 })
156 }
157
158 pub async fn run<F, Fut>(
159 &self,
160 subscriptions: Vec<StreamSubscription>,
161 config: StreamConnectConfig,
162 mut on_message: F,
163 ) -> Result<(), PolymarketUsError>
164 where
165 F: FnMut(StreamMessage) -> Fut,
166 Fut: Future<Output = ()>,
167 {
168 let mut stream = self.connect_with_config(subscriptions, config).await?;
169 while let Some(message) = stream.next().await {
170 on_message(message).await;
171 }
172 Ok(())
173 }
174}
175
176pub struct ManagedStream {
177 receiver: mpsc::Receiver<StreamMessage>,
178 shutdown: Arc<StreamShutdown>,
179 cmd_tx: mpsc::Sender<StreamCommand>,
180}
181
182impl ManagedStream {
183 pub async fn next(&mut self) -> Option<StreamMessage> {
184 self.receiver.recv().await
185 }
186
187 pub fn shutdown(&self) {
188 self.shutdown.shutdown();
189 }
190
191 pub fn is_shutdown(&self) -> bool {
192 self.shutdown.is_shutdown()
193 }
194
195 pub async fn subscribe(&self, sub: StreamSubscription) -> Result<(), PolymarketUsError> {
200 self.cmd_tx
201 .send(StreamCommand::Subscribe(sub))
202 .await
203 .map_err(|_| PolymarketUsError::InvalidStreamConfig("stream is closed".to_string()))
204 }
205
206 pub async fn unsubscribe(&self, tracking_id: &str) -> Result<(), PolymarketUsError> {
211 self.cmd_tx
212 .send(StreamCommand::Unsubscribe(tracking_id.to_string()))
213 .await
214 .map_err(|_| PolymarketUsError::InvalidStreamConfig("stream is closed".to_string()))
215 }
216}
217
218#[derive(Debug, Clone)]
219pub struct StreamConnectConfig {
220 pub tracking_id: String,
221 pub responses_debounced: bool,
222 pub reconnect: ReconnectConfig,
223}
224
225impl Default for StreamConnectConfig {
226 fn default() -> Self {
227 Self {
228 tracking_id: next_tracking_id("session"),
229 responses_debounced: false,
230 reconnect: ReconnectConfig::default(),
231 }
232 }
233}
234
235impl StreamConnectConfig {
236 pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
237 self.tracking_id = tracking_id.into();
238 self
239 }
240
241 pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
242 self.responses_debounced = responses_debounced;
243 self
244 }
245
246 pub fn with_reconnect(mut self, reconnect: ReconnectConfig) -> Self {
247 self.reconnect = reconnect;
248 self
249 }
250}
251
/// Exponential-backoff policy applied by the stream runner between reconnects.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Whether the runner reconnects at all after a session ends or fails.
    pub enabled: bool,
    /// Maximum number of reconnect attempts before giving up; `None` means no limit.
    pub max_attempts: Option<usize>,
    /// Delay before the first reconnect attempt.
    pub initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Factor by which the delay grows with each further attempt.
    pub multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Enabled, unlimited attempts, 250 ms doubling up to 10 s.
    fn default() -> Self {
        Self {
            enabled: true,
            max_attempts: None,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(10),
            multiplier: 2.0,
        }
    }
}

impl ReconnectConfig {
    /// The default timings with reconnecting switched off.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }

    /// Delay to wait before reconnect attempt number `attempt`.
    ///
    /// Attempts `0` and `1` use `initial_delay`; attempt `n > 1` uses
    /// `initial_delay * multiplier^(n - 1)`. The result never exceeds `max_delay`.
    /// Products that are NaN, negative, infinite or too large for a `Duration`
    /// saturate to `max_delay` instead of panicking.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt <= 1 {
            return self.initial_delay.min(self.max_delay);
        }

        // `attempt - 1` can exceed `i32::MAX`; saturating keeps the power monotonic
        // (it simply becomes infinite or NaN and is clamped below).
        let exponent = i32::try_from(attempt - 1).unwrap_or(i32::MAX);
        let scaled_secs = self.initial_delay.as_secs_f64() * self.multiplier.powi(exponent);

        // `Duration::mul_f64` / `from_secs_f64` panic on NaN, negative, infinite or
        // overflowing values. An unlimited reconnect loop reaches such exponents
        // after ~70 attempts with the defaults, so clamp before converting.
        if scaled_secs.is_nan()
            || scaled_secs < 0.0
            || scaled_secs >= self.max_delay.as_secs_f64()
        {
            return self.max_delay;
        }
        Duration::from_secs_f64(scaled_secs).min(self.max_delay)
    }
}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294#[serde(rename_all = "camelCase")]
295pub struct StreamSubscription {
296 pub channel: String,
297 pub tracking_id: String,
298 #[serde(default, skip_serializing_if = "Option::is_none")]
299 pub responses_debounced: Option<bool>,
300 #[serde(default, skip_serializing_if = "Option::is_none")]
301 pub symbol: Option<String>,
302 #[serde(default, skip_serializing_if = "Option::is_none")]
303 pub market_id: Option<String>,
304 #[serde(default, skip_serializing_if = "Option::is_none")]
305 pub outcome: Option<String>,
306 #[serde(default, flatten)]
307 pub extra: Map<String, Value>,
308}
309
310impl StreamSubscription {
311 pub fn new(channel: impl Into<String>) -> Self {
312 Self {
313 channel: channel.into(),
314 tracking_id: next_tracking_id("sub"),
315 responses_debounced: None,
316 symbol: None,
317 market_id: None,
318 outcome: None,
319 extra: Map::new(),
320 }
321 }
322
323 pub fn for_channel(channel: SubscriptionChannel) -> Self {
325 Self::new(channel.as_str())
326 }
327
328 pub fn market_data(symbol: impl Into<String>) -> Self {
332 let mut s = Self::new(SubscriptionChannel::MarketData.as_str());
333 s.symbol = Some(symbol.into());
334 s
335 }
336
337 pub fn market_data_lite(symbol: impl Into<String>) -> Self {
339 let mut s = Self::new(SubscriptionChannel::MarketDataLite.as_str());
340 s.symbol = Some(symbol.into());
341 s
342 }
343
344 pub fn trades(symbol: impl Into<String>) -> Self {
346 let mut s = Self::new(SubscriptionChannel::Trade.as_str());
347 s.symbol = Some(symbol.into());
348 s
349 }
350
351 pub fn heartbeat() -> Self {
353 Self::new(SubscriptionChannel::Heartbeat.as_str())
354 }
355
356 pub fn order_snapshot(symbol: impl Into<String>) -> Self {
360 let mut s = Self::new(SubscriptionChannel::OrderSnapshot.as_str());
361 s.symbol = Some(symbol.into());
362 s
363 }
364
365 pub fn order_update() -> Self {
367 Self::new(SubscriptionChannel::OrderUpdate.as_str())
368 }
369
370 pub fn position_snapshot() -> Self {
372 Self::new(SubscriptionChannel::PositionSnapshot.as_str())
373 }
374
375 pub fn position_update() -> Self {
377 Self::new(SubscriptionChannel::PositionUpdate.as_str())
378 }
379
380 pub fn balance_snapshot() -> Self {
382 Self::new(SubscriptionChannel::BalanceSnapshot.as_str())
383 }
384
385 pub fn balance_update() -> Self {
387 Self::new(SubscriptionChannel::BalanceUpdate.as_str())
388 }
389
390 pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
393 self.tracking_id = tracking_id.into();
394 self
395 }
396
397 pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
398 self.responses_debounced = Some(responses_debounced);
399 self
400 }
401
402 pub fn with_symbol(mut self, symbol: impl Into<String>) -> Self {
403 self.symbol = Some(symbol.into());
404 self
405 }
406
407 pub fn with_market_id(mut self, market_id: impl Into<String>) -> Self {
408 self.market_id = Some(market_id.into());
409 self
410 }
411
412 pub fn with_outcome(mut self, outcome: impl Into<String>) -> Self {
413 self.outcome = Some(outcome.into());
414 self
415 }
416
417 pub fn insert_extra(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
418 self.extra.insert(key.into(), value.into());
419 self
420 }
421}
422
423#[derive(Debug, Clone)]
424pub struct StreamMessage {
425 pub tracking_id: Option<String>,
426 pub kind: StreamMessageKind,
427}
428
429#[derive(Debug, Clone)]
430#[non_exhaustive]
431pub enum StreamMessageKind {
432 Data(StreamDataEvent),
433 Control(StreamControlEvent),
434}
435
436#[derive(Debug, Clone)]
437#[non_exhaustive]
438pub enum StreamDataEvent {
439 OrderSnapshot(Value),
441 OrderUpdate(Value),
443 MarketData(Value),
445 MarketDataLite(Value),
447 OrderBookDelta(Value),
449 PositionSnapshot(Value),
451 PositionUpdate(Value),
453 BalanceSnapshot(Value),
455 BalanceUpdate(Value),
457 Trade(Value),
459 Heartbeat,
461 Other { event_type: String, payload: Value },
463}
464
465#[derive(Debug, Clone)]
466#[non_exhaustive]
467pub enum StreamControlEvent {
468 Connected { session_tracking_id: String },
469 SubscriptionAck { event_type: String, payload: Value },
470 Reconnecting { attempt: usize, delay_ms: u64 },
471 Closed,
472 Error(String),
473}
474
475impl StreamMessage {
476 pub fn control(tracking_id: Option<String>, event: StreamControlEvent) -> Self {
477 Self {
478 tracking_id,
479 kind: StreamMessageKind::Control(event),
480 }
481 }
482
483 pub fn data(tracking_id: Option<String>, event: StreamDataEvent) -> Self {
484 Self {
485 tracking_id,
486 kind: StreamMessageKind::Data(event),
487 }
488 }
489}
490
491struct StreamRunner {
492 base_url: String,
493 auth: Option<UsAuth>,
494 subscriptions: Vec<StreamSubscription>,
495 config: StreamConnectConfig,
496 tx: mpsc::Sender<StreamMessage>,
497 shutdown: Arc<StreamShutdown>,
498 cmd_rx: mpsc::Receiver<StreamCommand>,
499}
500
501impl StreamRunner {
502 async fn run(mut self) {
503 let mut attempt = 0usize;
504
505 loop {
506 if self.shutdown.is_shutdown() || self.tx.is_closed() {
507 break;
508 }
509
510 match self.connect_and_consume().await {
511 Ok(()) => {
512 if !self.config.reconnect.enabled {
513 break;
514 }
515 }
516 Err(err) => {
517 if !self
518 .emit(StreamMessage::control(
519 Some(self.config.tracking_id.clone()),
520 StreamControlEvent::Error(err.to_string()),
521 ))
522 .await
523 {
524 break;
525 }
526 }
527 }
528
529 if !self.config.reconnect.enabled {
530 break;
531 }
532
533 attempt += 1;
534 if let Some(max_attempts) = self.config.reconnect.max_attempts {
535 if attempt > max_attempts {
536 break;
537 }
538 }
539
540 let delay = self.config.reconnect.delay_for_attempt(attempt);
541 if !self
542 .emit(StreamMessage::control(
543 Some(self.config.tracking_id.clone()),
544 StreamControlEvent::Reconnecting {
545 attempt,
546 delay_ms: delay.as_millis() as u64,
547 },
548 ))
549 .await
550 {
551 break;
552 }
553
554 let shutdown = Arc::clone(&self.shutdown);
555 tokio::select! {
556 _ = shutdown.notified() => break,
557 _ = tokio::time::sleep(delay) => {}
558 }
559 }
560
561 let _ = self
562 .emit(StreamMessage::control(
563 Some(self.config.tracking_id.clone()),
564 StreamControlEvent::Closed,
565 ))
566 .await;
567 }
568
569 async fn connect_and_consume(&mut self) -> Result<(), PolymarketUsError> {
570 let mut request = self
571 .base_url
572 .as_str()
573 .into_client_request()
574 .map_err(|err| {
575 PolymarketUsError::InvalidStreamConfig(format!(
576 "invalid websocket URL {}: {err}",
577 self.base_url
578 ))
579 })?;
580
581 if let Some(auth) = &self.auth {
582 let path = request
583 .uri()
584 .path_and_query()
585 .map(|path| path.as_str())
586 .unwrap_or("/");
587 for (name, value) in auth.signed_headers("GET", path) {
588 let header_value = HeaderValue::from_str(&value).map_err(|err| {
589 PolymarketUsError::InvalidStreamConfig(format!(
590 "invalid websocket auth header value for {name}: {err}"
591 ))
592 })?;
593 request.headers_mut().insert(name, header_value);
594 }
595 }
596
597 let (mut websocket, _) = connect_async(request).await?;
598 let _ = self
599 .emit(StreamMessage::control(
600 Some(self.config.tracking_id.clone()),
601 StreamControlEvent::Connected {
602 session_tracking_id: self.config.tracking_id.clone(),
603 },
604 ))
605 .await;
606
607 self.send_all_subscriptions(&mut websocket).await?;
608
609 let shutdown = Arc::clone(&self.shutdown);
612 let shutdown_wait = shutdown.notified();
613 tokio::pin!(shutdown_wait);
614
615 loop {
616 tokio::select! {
617 _ = &mut shutdown_wait => {
618 let _ = websocket.close(None).await;
619 break;
620 }
621 message = websocket.next() => {
622 let Some(message) = message else {
623 break;
624 };
625
626 match message {
627 Ok(Message::Text(text)) => {
628 self.handle_text(&text).await?;
629 }
630 Ok(Message::Binary(bytes)) => {
631 let text = String::from_utf8(bytes.to_vec()).map_err(|err| {
632 PolymarketUsError::InvalidStreamConfig(format!(
633 "received non-UTF8 websocket payload: {err}"
634 ))
635 })?;
636 self.handle_text(&text).await?;
637 }
638 Ok(Message::Close(_)) => break,
639 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
640 Ok(_) => {}
641 Err(err) => return Err(err.into()),
642 }
643 }
644 cmd = self.cmd_rx.recv() => {
645 match cmd {
646 Some(StreamCommand::Subscribe(sub)) => {
647 self.send_subscription(&mut websocket, &sub).await?;
648 self.subscriptions.push(sub);
649 }
650 Some(StreamCommand::Unsubscribe(tracking_id)) => {
651 self.subscriptions.retain(|s| s.tracking_id != tracking_id);
652 let frame = serde_json::json!({
654 "type": "unsubscribe",
655 "trackingId": tracking_id,
656 });
657 let _ = websocket
658 .send(Message::Text(frame.to_string()))
659 .await;
660 }
661 None => break,
662 }
663 }
664 }
665 }
666
667 Ok(())
668 }
669
670 async fn send_all_subscriptions(
671 &self,
672 websocket: &mut tokio_tungstenite::WebSocketStream<
673 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
674 >,
675 ) -> Result<(), PolymarketUsError> {
676 for subscription in &self.subscriptions {
677 self.send_subscription(websocket, subscription).await?;
678 }
679 Ok(())
680 }
681
682 async fn send_subscription(
683 &self,
684 websocket: &mut tokio_tungstenite::WebSocketStream<
685 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
686 >,
687 subscription: &StreamSubscription,
688 ) -> Result<(), PolymarketUsError> {
689 let mut prepared = subscription.clone();
690 if prepared.responses_debounced.is_none() {
691 prepared.responses_debounced = Some(self.config.responses_debounced);
692 }
693 let payload = serde_json::to_string(&prepared)?;
694 websocket.send(Message::Text(payload)).await?;
695 Ok(())
696 }
697
698 async fn handle_text(&self, text: &str) -> Result<(), PolymarketUsError> {
699 let json: Value = serde_json::from_str(text)?;
700 if let Some(message) = parse_stream_message(json) {
701 if !self.emit(message).await {
702 return Ok(());
703 }
704 }
705 Ok(())
706 }
707
708 async fn emit(&self, message: StreamMessage) -> bool {
709 self.tx.send(message).await.is_ok()
710 }
711}
712
713struct StreamShutdown {
714 requested: AtomicBool,
715 notify: Notify,
716}
717
718impl StreamShutdown {
719 fn new() -> Self {
720 Self {
721 requested: AtomicBool::new(false),
722 notify: Notify::new(),
723 }
724 }
725
726 fn shutdown(&self) {
727 if !self.requested.swap(true, Ordering::SeqCst) {
728 self.notify.notify_waiters();
729 }
730 }
731
732 fn is_shutdown(&self) -> bool {
733 self.requested.load(Ordering::SeqCst)
734 }
735
736 fn notified(&self) -> impl Future<Output = ()> + '_ {
737 self.notify.notified()
738 }
739}
740
741fn parse_stream_message(json: Value) -> Option<StreamMessage> {
742 match json {
743 Value::Object(map) => {
744 let tracking_id = extract_tracking_id(&map);
745 let event_type = extract_event_type(&map);
746 let payload = extract_payload(&map);
747
748 let kind = match event_type.as_str() {
749 "order_snapshot" | "orderSnapshot" => {
751 StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(payload))
752 }
753 "order_update" | "order_updates" | "orderUpdate" | "user_order" | "fill" => {
754 StreamMessageKind::Data(StreamDataEvent::OrderUpdate(payload))
755 }
756 "market_data" | "marketData" => {
758 StreamMessageKind::Data(StreamDataEvent::MarketData(payload))
759 }
760 "market_data_lite" | "marketDataLite" => {
761 StreamMessageKind::Data(StreamDataEvent::MarketDataLite(payload))
762 }
763 "order_book_delta" | "orderbook_delta" | "book_delta" | "bookDelta" => {
764 StreamMessageKind::Data(StreamDataEvent::OrderBookDelta(payload))
765 }
766 "trade" | "trades" => StreamMessageKind::Data(StreamDataEvent::Trade(payload)),
767 "position_snapshot" | "positionSnapshot" => {
769 StreamMessageKind::Data(StreamDataEvent::PositionSnapshot(payload))
770 }
771 "position_update" | "positionUpdate" => {
772 StreamMessageKind::Data(StreamDataEvent::PositionUpdate(payload))
773 }
774 "balance_snapshot" | "balanceSnapshot" => {
776 StreamMessageKind::Data(StreamDataEvent::BalanceSnapshot(payload))
777 }
778 "balance_update" | "balanceUpdate" => {
779 StreamMessageKind::Data(StreamDataEvent::BalanceUpdate(payload))
780 }
781 "heartbeat" | "ping" | "pong" => {
783 StreamMessageKind::Data(StreamDataEvent::Heartbeat)
784 }
785 "subscription" | "subscribe" | "subscribed" | "ack" => {
787 StreamMessageKind::Control(StreamControlEvent::SubscriptionAck {
788 event_type: event_type.clone(),
789 payload,
790 })
791 }
792 "error" => {
793 StreamMessageKind::Control(StreamControlEvent::Error(payload.to_string()))
794 }
795 _ => StreamMessageKind::Data(StreamDataEvent::Other {
796 event_type: event_type.clone(),
797 payload,
798 }),
799 };
800
801 Some(StreamMessage { tracking_id, kind })
802 }
803 other => Some(StreamMessage::data(
804 None,
805 StreamDataEvent::Other {
806 event_type: "unknown".to_string(),
807 payload: other,
808 },
809 )),
810 }
811}
812
813fn extract_tracking_id(map: &Map<String, Value>) -> Option<String> {
814 ["trackingId", "tracking_id", "trackingID", "id"]
815 .iter()
816 .find_map(|key| map.get(*key).and_then(Value::as_str).map(ToOwned::to_owned))
817}
818
819fn extract_event_type(map: &Map<String, Value>) -> String {
820 for key in ["event", "type", "channel", "name", "topic"] {
821 if let Some(value) = map.get(key).and_then(Value::as_str) {
822 return value.to_string();
823 }
824 }
825
826 if map.len() == 1 {
827 return map
828 .keys()
829 .next()
830 .cloned()
831 .unwrap_or_else(|| "unknown".to_string());
832 }
833
834 "unknown".to_string()
835}
836
837fn extract_payload(map: &Map<String, Value>) -> Value {
838 for key in ["data", "payload", "body", "message", "result"] {
839 if let Some(value) = map.get(key) {
840 return value.clone();
841 }
842 }
843
844 if map.len() == 1 {
845 return map.values().next().cloned().unwrap_or(Value::Null);
846 }
847
848 Value::Object(map.clone())
849}
850
851fn next_tracking_id(prefix: &str) -> String {
852 let ordinal = TRACKING_COUNTER.fetch_add(1, Ordering::Relaxed);
853 format!(
854 "{prefix}-{}-{ordinal}",
855 chrono::Utc::now().timestamp_millis()
856 )
857}
858
/// Converts a user-supplied endpoint into a websocket URL.
///
/// * `ws://` / `wss://` URLs are returned unchanged apart from trailing slashes.
/// * `https://` becomes `wss://`, `http://` becomes `ws://`, and scheme-less input
///   is treated as `wss://`. In these cases the `/ws` stream path is appended
///   unless the URL already ends with it.
fn normalize_stream_url(url: String) -> String {
    let trimmed = url.trim_end_matches('/');
    if trimmed.starts_with("ws://") || trimmed.starts_with("wss://") {
        return trimmed.to_string();
    }

    let (scheme, rest) = if let Some(rest) = trimmed.strip_prefix("https://") {
        ("wss", rest)
    } else if let Some(rest) = trimmed.strip_prefix("http://") {
        ("ws", rest)
    } else {
        ("wss", trimmed)
    };

    // Avoid producing `.../ws/ws` when the caller already pointed at the stream path.
    if rest.ends_with("/ws") {
        format!("{scheme}://{rest}")
    } else {
        format!("{scheme}://{rest}/ws")
    }
}

/// Derives the stream URL from a REST gateway base URL.
///
/// Shares its rules with [`normalize_stream_url`] (it used to be a verbatim copy),
/// e.g. `https://gateway.polymarket.us` becomes `wss://gateway.polymarket.us/ws`.
fn derive_stream_url(gateway_base_url: &str) -> String {
    normalize_stream_url(gateway_base_url.to_string())
}
884
885#[cfg(test)]
886mod tests {
887 use super::*;
888 use serde_json::json;
889
890 #[test]
891 fn reconnect_delay_caps_at_max() {
892 let policy = ReconnectConfig {
893 enabled: true,
894 max_attempts: None,
895 initial_delay: Duration::from_millis(250),
896 max_delay: Duration::from_secs(1),
897 multiplier: 3.0,
898 };
899
900 assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(250));
901 assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(250));
902 assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(750));
903 assert_eq!(policy.delay_for_attempt(3), Duration::from_secs(1));
904 assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(1));
905 }
906
907 #[test]
908 fn subscription_serializes_debounced_flag_and_tracking_id() {
909 let subscription = StreamSubscription::order_snapshot("ABC")
910 .with_tracking_id("tracking-1")
911 .with_responses_debounced(true)
912 .insert_extra("bookLevel", json!(2));
913
914 let json = serde_json::to_value(subscription).unwrap();
915 assert_eq!(json["channel"], "order_snapshot");
916 assert_eq!(json["trackingId"], "tracking-1");
917 assert_eq!(json["responsesDebounced"], true);
918 assert_eq!(json["symbol"], "ABC");
919 assert_eq!(json["bookLevel"], 2);
920 }
921
922 #[test]
923 fn parses_order_snapshot_event() {
924 let message = parse_stream_message(json!({
925 "event": "order_snapshot",
926 "trackingId": "abc-123",
927 "data": { "bids": [1, 2], "asks": [3, 4] }
928 }))
929 .expect("message");
930
931 assert_eq!(message.tracking_id.as_deref(), Some("abc-123"));
932 match message.kind {
933 StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(payload)) => {
934 assert_eq!(payload["bids"][0], 1);
935 assert_eq!(payload["asks"][1], 4);
936 }
937 other => panic!("unexpected event: {other:?}"),
938 }
939 }
940
941 #[test]
942 fn parses_position_snapshot_event() {
943 let message = parse_stream_message(json!({
944 "event": "position_snapshot",
945 "data": { "positions": [] }
946 }))
947 .expect("message");
948 assert!(
949 matches!(
950 message.kind,
951 StreamMessageKind::Data(StreamDataEvent::PositionSnapshot(_))
952 ),
953 "expected PositionSnapshot"
954 );
955 }
956
957 #[test]
958 fn parses_balance_update_event() {
959 let message = parse_stream_message(json!({
960 "event": "balance_update",
961 "data": { "currency": "USD", "balance": "1000.00" }
962 }))
963 .expect("message");
964 assert!(
965 matches!(
966 message.kind,
967 StreamMessageKind::Data(StreamDataEvent::BalanceUpdate(_))
968 ),
969 "expected BalanceUpdate"
970 );
971 }
972
973 #[test]
974 fn parses_trade_event() {
975 let message = parse_stream_message(json!({
976 "event": "trade",
977 "data": { "price": "0.55", "size": "100" }
978 }))
979 .expect("message");
980 assert!(
981 matches!(
982 message.kind,
983 StreamMessageKind::Data(StreamDataEvent::Trade(_))
984 ),
985 "expected Trade"
986 );
987 }
988
989 #[test]
990 fn parses_heartbeat_event() {
991 let message = parse_stream_message(json!({ "event": "heartbeat" })).expect("message");
992 assert!(
993 matches!(
994 message.kind,
995 StreamMessageKind::Data(StreamDataEvent::Heartbeat)
996 ),
997 "expected Heartbeat"
998 );
999 }
1000
1001 #[test]
1002 fn parses_market_data_lite_event() {
1003 let message = parse_stream_message(json!({
1004 "event": "market_data_lite",
1005 "data": { "bid": "0.50", "ask": "0.55" }
1006 }))
1007 .expect("message");
1008 assert!(
1009 matches!(
1010 message.kind,
1011 StreamMessageKind::Data(StreamDataEvent::MarketDataLite(_))
1012 ),
1013 "expected MarketDataLite"
1014 );
1015 }
1016
1017 #[test]
1018 fn subscription_channel_as_str() {
1019 assert_eq!(
1020 SubscriptionChannel::OrderSnapshot.as_str(),
1021 "order_snapshot"
1022 );
1023 assert_eq!(
1024 SubscriptionChannel::MarketDataLite.as_str(),
1025 "market_data_lite"
1026 );
1027 assert_eq!(
1028 SubscriptionChannel::PositionUpdate.as_str(),
1029 "position_update"
1030 );
1031 assert_eq!(
1032 SubscriptionChannel::BalanceSnapshot.as_str(),
1033 "balance_snapshot"
1034 );
1035 assert_eq!(SubscriptionChannel::Trade.as_str(), "trade");
1036 assert_eq!(SubscriptionChannel::Heartbeat.as_str(), "heartbeat");
1037 }
1038
1039 #[test]
1040 fn subscription_constructors_set_channel() {
1041 assert_eq!(StreamSubscription::market_data("X").channel, "market_data");
1042 assert_eq!(
1043 StreamSubscription::market_data_lite("X").channel,
1044 "market_data_lite"
1045 );
1046 assert_eq!(StreamSubscription::trades("X").channel, "trade");
1047 assert_eq!(StreamSubscription::heartbeat().channel, "heartbeat");
1048 assert_eq!(StreamSubscription::order_update().channel, "order_update");
1049 assert_eq!(
1050 StreamSubscription::position_snapshot().channel,
1051 "position_snapshot"
1052 );
1053 assert_eq!(
1054 StreamSubscription::position_update().channel,
1055 "position_update"
1056 );
1057 assert_eq!(
1058 StreamSubscription::balance_snapshot().channel,
1059 "balance_snapshot"
1060 );
1061 assert_eq!(
1062 StreamSubscription::balance_update().channel,
1063 "balance_update"
1064 );
1065 }
1066
1067 #[test]
1068 fn for_channel_constructor() {
1069 let sub = StreamSubscription::for_channel(SubscriptionChannel::BalanceUpdate);
1070 assert_eq!(sub.channel, "balance_update");
1071 }
1072
1073 #[test]
1074 fn derives_stream_url_from_gateway_base_url() {
1075 assert_eq!(
1076 derive_stream_url("https://gateway.polymarket.us"),
1077 "wss://gateway.polymarket.us/ws"
1078 );
1079 assert_eq!(
1080 normalize_stream_url("wss://custom.example/ws".to_string()),
1081 "wss://custom.example/ws"
1082 );
1083 }
1084}