1use crate::auth::UsAuth;
2use crate::error::PolymarketUsError;
3use futures_util::{SinkExt, StreamExt};
4use http::HeaderValue;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::future::Future;
8use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, Notify};
12use tokio_tungstenite::{
13 connect_async,
14 tungstenite::{client::IntoClientRequest, Message},
15};
16
/// Process-wide counter that `next_tracking_id` increments and embeds in every
/// generated id, keeping ids created within the same millisecond distinct.
static TRACKING_COUNTER: AtomicU64 = AtomicU64::new(1);
18
/// Client for the Polymarket US websocket streaming endpoint.
///
/// Holds the websocket URL and optional credentials. Every call to
/// [`connect`](Self::connect) / [`connect_with_config`](Self::connect_with_config)
/// spawns its own background task with its own connection.
#[derive(Clone)]
pub struct PolymarketUsStreamClient {
    /// Websocket URL, already normalised by `normalize_stream_url` in `new`.
    base_url: String,
    /// Credentials used to sign the websocket upgrade request; `None` connects
    /// without auth headers.
    auth: Option<UsAuth>,
}
24
25impl PolymarketUsStreamClient {
26 pub fn new(base_url: impl Into<String>, auth: Option<UsAuth>) -> Self {
27 Self {
28 base_url: normalize_stream_url(base_url.into()),
29 auth,
30 }
31 }
32
33 pub fn from_gateway_base_url(
34 gateway_base_url: impl Into<String>,
35 auth: Option<UsAuth>,
36 ) -> Self {
37 let gateway_base_url = gateway_base_url.into();
38 Self::new(derive_stream_url(&gateway_base_url), auth)
39 }
40
41 pub fn base_url(&self) -> &str {
42 &self.base_url
43 }
44
45 pub async fn connect(
46 &self,
47 subscriptions: Vec<StreamSubscription>,
48 ) -> Result<ManagedStream, PolymarketUsError> {
49 self.connect_with_config(subscriptions, StreamConnectConfig::default())
50 .await
51 }
52
53 pub async fn connect_with_config(
54 &self,
55 subscriptions: Vec<StreamSubscription>,
56 config: StreamConnectConfig,
57 ) -> Result<ManagedStream, PolymarketUsError> {
58 if subscriptions.is_empty() {
59 return Err(PolymarketUsError::InvalidStreamConfig(
60 "at least one subscription is required".to_string(),
61 ));
62 }
63
64 let (tx, rx) = mpsc::channel(256);
65 let shutdown = Arc::new(StreamShutdown::new());
66 let base_url = self.base_url.clone();
67 let auth = self.auth.clone();
68 let shutdown_task = shutdown.clone();
69
70 tokio::spawn(async move {
71 let runner = StreamRunner {
72 base_url,
73 auth,
74 subscriptions,
75 config,
76 tx,
77 shutdown: shutdown_task,
78 };
79 runner.run().await;
80 });
81
82 Ok(ManagedStream {
83 receiver: rx,
84 shutdown,
85 })
86 }
87
88 pub async fn run<F, Fut>(
89 &self,
90 subscriptions: Vec<StreamSubscription>,
91 config: StreamConnectConfig,
92 mut on_message: F,
93 ) -> Result<(), PolymarketUsError>
94 where
95 F: FnMut(StreamMessage) -> Fut,
96 Fut: Future<Output = ()>,
97 {
98 let mut stream = self.connect_with_config(subscriptions, config).await?;
99 while let Some(message) = stream.next().await {
100 on_message(message).await;
101 }
102 Ok(())
103 }
104}
105
106pub struct ManagedStream {
107 receiver: mpsc::Receiver<StreamMessage>,
108 shutdown: Arc<StreamShutdown>,
109}
110
111impl ManagedStream {
112 pub async fn next(&mut self) -> Option<StreamMessage> {
113 self.receiver.recv().await
114 }
115
116 pub fn shutdown(&self) {
117 self.shutdown.shutdown();
118 }
119
120 pub fn is_shutdown(&self) -> bool {
121 self.shutdown.is_shutdown()
122 }
123}
124
125#[derive(Debug, Clone)]
126pub struct StreamConnectConfig {
127 pub tracking_id: String,
128 pub responses_debounced: bool,
129 pub reconnect: ReconnectConfig,
130}
131
132impl Default for StreamConnectConfig {
133 fn default() -> Self {
134 Self {
135 tracking_id: next_tracking_id("session"),
136 responses_debounced: false,
137 reconnect: ReconnectConfig::default(),
138 }
139 }
140}
141
142impl StreamConnectConfig {
143 pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
144 self.tracking_id = tracking_id.into();
145 self
146 }
147
148 pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
149 self.responses_debounced = responses_debounced;
150 self
151 }
152
153 pub fn with_reconnect(mut self, reconnect: ReconnectConfig) -> Self {
154 self.reconnect = reconnect;
155 self
156 }
157}
158
/// Exponential back-off policy used between reconnect attempts.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// When `false`, the stream ends after the first disconnect or connect failure.
    pub enabled: bool,
    /// Maximum number of reconnect attempts before giving up; `None` retries
    /// indefinitely.
    pub max_attempts: Option<usize>,
    /// Delay used for attempts `0` and `1`, and the base of the exponential growth.
    pub initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Growth factor applied once per attempt after the first.
    pub multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Enabled, unlimited attempts, 250 ms initial delay doubling up to 10 s.
    fn default() -> Self {
        Self {
            enabled: true,
            max_attempts: None,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(10),
            multiplier: 2.0,
        }
    }
}

impl ReconnectConfig {
    /// Default timings with reconnects turned off.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::default()
        }
    }

    /// Delay to wait before reconnect attempt number `attempt`.
    ///
    /// Attempts `0` and `1` wait `initial_delay`. Attempt `n >= 2` waits
    /// `initial_delay * multiplier^(n - 1)`. Every result is capped at `max_delay`.
    ///
    /// This function never panics. A product that does not fit in a `Duration`,
    /// or that is NaN or negative because of an unusual `multiplier`, saturates
    /// to `max_delay`.
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            return self.initial_delay.min(self.max_delay);
        }

        // `powi` takes an `i32`. Clamp first so huge attempt counts cannot wrap
        // into negative exponents through the `as` cast.
        let exponent = attempt.saturating_sub(1).min(i32::MAX as usize) as i32;
        let seconds = self.initial_delay.as_secs_f64() * self.multiplier.powi(exponent);

        // `Duration::mul_f64` panics once the product no longer fits in a
        // `Duration`. With the defaults that happens after roughly 66 doublings,
        // and the panic would kill an indefinitely retrying stream task. Fall back
        // to the cap instead.
        Duration::try_from_secs_f64(seconds)
            .map_or(self.max_delay, |delay| delay.min(self.max_delay))
    }
}
199
/// One subscription request, serialised as a camelCase JSON object and sent to
/// the server after each connect.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StreamSubscription {
    /// Channel name, e.g. `order_snapshot` or `market_data_lite`.
    pub channel: String,
    /// Sent as `trackingId`; `new` fills it with a generated `sub-…` id.
    pub tracking_id: String,
    /// Sent as `responsesDebounced`. When `None`, the runner substitutes
    /// `StreamConnectConfig::responses_debounced` before sending.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub responses_debounced: Option<bool>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub symbol: Option<String>,
    /// Sent as `marketId`; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub market_id: Option<String>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub outcome: Option<String>,
    /// Additional fields flattened into the top-level JSON object.
    // NOTE(review): avoid keys that collide with the named fields above; the
    // flattened map is not deduplicated against them.
    #[serde(default, flatten)]
    pub extra: Map<String, Value>,
}
216
217impl StreamSubscription {
218 pub fn new(channel: impl Into<String>) -> Self {
219 Self {
220 channel: channel.into(),
221 tracking_id: next_tracking_id("sub"),
222 responses_debounced: None,
223 symbol: None,
224 market_id: None,
225 outcome: None,
226 extra: Map::new(),
227 }
228 }
229
230 pub fn order_snapshot(symbol: impl Into<String>) -> Self {
231 let mut subscription = Self::new("order_snapshot");
232 subscription.symbol = Some(symbol.into());
233 subscription
234 }
235
236 pub fn market_data_lite(symbol: impl Into<String>) -> Self {
237 let mut subscription = Self::new("market_data_lite");
238 subscription.symbol = Some(symbol.into());
239 subscription
240 }
241
242 pub fn with_tracking_id(mut self, tracking_id: impl Into<String>) -> Self {
243 self.tracking_id = tracking_id.into();
244 self
245 }
246
247 pub fn with_responses_debounced(mut self, responses_debounced: bool) -> Self {
248 self.responses_debounced = Some(responses_debounced);
249 self
250 }
251
252 pub fn with_symbol(mut self, symbol: impl Into<String>) -> Self {
253 self.symbol = Some(symbol.into());
254 self
255 }
256
257 pub fn with_market_id(mut self, market_id: impl Into<String>) -> Self {
258 self.market_id = Some(market_id.into());
259 self
260 }
261
262 pub fn with_outcome(mut self, outcome: impl Into<String>) -> Self {
263 self.outcome = Some(outcome.into());
264 self
265 }
266
267 pub fn insert_extra(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
268 self.extra.insert(key.into(), value.into());
269 self
270 }
271}
272
/// A message delivered through `ManagedStream`: either server data or a control
/// event.
#[derive(Debug, Clone)]
pub struct StreamMessage {
    /// For server frames, the first string found under `trackingId`,
    /// `tracking_id`, `trackingID` or `id` (`None` if absent). For control events
    /// generated by the runner, the session tracking id.
    pub tracking_id: Option<String>,
    pub kind: StreamMessageKind,
}

/// Top-level split between market/order data and connection-level events.
#[derive(Debug, Clone)]
pub enum StreamMessageKind {
    Data(StreamDataEvent),
    Control(StreamControlEvent),
}

/// Data frames, classified by the frame's event type. Each variant carries the
/// payload chosen by `extract_payload`.
#[derive(Debug, Clone)]
pub enum StreamDataEvent {
    /// Event type `order_snapshot`.
    OrderSnapshot(Value),
    /// Event type `market_data_lite`.
    MarketDataLite(Value),
    /// Event types `order_book_delta`, `orderbook_delta`, `book_delta`.
    OrderBookDelta(Value),
    /// Event types `order_update`, `order_updates`, `user_order`, `fill`.
    OrderUpdate(Value),
    /// Any other event type. Non-object JSON frames use event type `"unknown"`.
    Other { event_type: String, payload: Value },
}

/// Connection-level events, emitted by the runner or parsed from server frames.
#[derive(Debug, Clone)]
pub enum StreamControlEvent {
    /// Emitted after the websocket handshake succeeds, before subscriptions are
    /// sent.
    Connected { session_tracking_id: String },
    /// Server frames with event type `subscription`, `subscribe`, `subscribed`
    /// or `ack`.
    SubscriptionAck { event_type: String, payload: Value },
    /// Emitted before sleeping `delay_ms` milliseconds and retrying. `attempt`
    /// starts at 1.
    Reconnecting { attempt: usize, delay_ms: u64 },
    /// Final event, sent best-effort when the runner task stops.
    Closed,
    /// Either a runner/connection error rendered with `Display`, or the payload
    /// of a server frame whose event type is `error`.
    Error(String),
}
302
303impl StreamMessage {
304 pub fn control(tracking_id: Option<String>, event: StreamControlEvent) -> Self {
305 Self {
306 tracking_id,
307 kind: StreamMessageKind::Control(event),
308 }
309 }
310
311 pub fn data(tracking_id: Option<String>, event: StreamDataEvent) -> Self {
312 Self {
313 tracking_id,
314 kind: StreamMessageKind::Data(event),
315 }
316 }
317}
318
/// State owned by the background task spawned in `connect_with_config`.
struct StreamRunner {
    /// Websocket URL to dial.
    base_url: String,
    /// Optional credentials used to sign the upgrade request.
    auth: Option<UsAuth>,
    /// Subscriptions sent again after every successful connect.
    subscriptions: Vec<StreamSubscription>,
    /// Session tracking id, debounce default and reconnect policy.
    config: StreamConnectConfig,
    /// Sending half of the channel read by `ManagedStream::next`.
    tx: mpsc::Sender<StreamMessage>,
    /// Shutdown signal shared with the `ManagedStream` handle.
    shutdown: Arc<StreamShutdown>,
}
327
328impl StreamRunner {
329 async fn run(self) {
330 let mut attempt = 0usize;
331
332 loop {
333 if self.shutdown.is_shutdown() || self.tx.is_closed() {
334 break;
335 }
336
337 match self.connect_and_consume().await {
338 Ok(()) => {
339 if !self.config.reconnect.enabled {
340 break;
341 }
342 }
343 Err(err) => {
344 if !self
345 .emit(StreamMessage::control(
346 Some(self.config.tracking_id.clone()),
347 StreamControlEvent::Error(err.to_string()),
348 ))
349 .await
350 {
351 break;
352 }
353 }
354 }
355
356 if !self.config.reconnect.enabled {
357 break;
358 }
359
360 attempt += 1;
361 if let Some(max_attempts) = self.config.reconnect.max_attempts {
362 if attempt > max_attempts {
363 break;
364 }
365 }
366
367 let delay = self.config.reconnect.delay_for_attempt(attempt);
368 if !self
369 .emit(StreamMessage::control(
370 Some(self.config.tracking_id.clone()),
371 StreamControlEvent::Reconnecting {
372 attempt,
373 delay_ms: delay.as_millis() as u64,
374 },
375 ))
376 .await
377 {
378 break;
379 }
380
381 tokio::select! {
382 _ = self.shutdown.notified() => break,
383 _ = tokio::time::sleep(delay) => {}
384 }
385 }
386
387 let _ = self
388 .emit(StreamMessage::control(
389 Some(self.config.tracking_id.clone()),
390 StreamControlEvent::Closed,
391 ))
392 .await;
393 }
394
395 async fn connect_and_consume(&self) -> Result<(), PolymarketUsError> {
396 let mut request = self
397 .base_url
398 .as_str()
399 .into_client_request()
400 .map_err(|err| {
401 PolymarketUsError::InvalidStreamConfig(format!(
402 "invalid websocket URL {}: {err}",
403 self.base_url
404 ))
405 })?;
406
407 if let Some(auth) = &self.auth {
408 let path = request
409 .uri()
410 .path_and_query()
411 .map(|path| path.as_str())
412 .unwrap_or("/");
413 for (name, value) in auth.signed_headers("GET", path) {
414 let header_value = HeaderValue::from_str(&value).map_err(|err| {
415 PolymarketUsError::InvalidStreamConfig(format!(
416 "invalid websocket auth header value for {name}: {err}"
417 ))
418 })?;
419 request.headers_mut().insert(name, header_value);
420 }
421 }
422
423 let (mut websocket, _) = connect_async(request).await?;
424 let _ = self
425 .emit(StreamMessage::control(
426 Some(self.config.tracking_id.clone()),
427 StreamControlEvent::Connected {
428 session_tracking_id: self.config.tracking_id.clone(),
429 },
430 ))
431 .await;
432
433 self.send_subscriptions(&mut websocket).await?;
434
435 let shutdown_wait = self.shutdown.notified();
436 tokio::pin!(shutdown_wait);
437
438 loop {
439 tokio::select! {
440 _ = &mut shutdown_wait => {
441 let _ = websocket.close(None).await;
442 break;
443 }
444 message = websocket.next() => {
445 let Some(message) = message else {
446 break;
447 };
448
449 match message {
450 Ok(Message::Text(text)) => {
451 self.handle_text(&text).await?;
452 }
453 Ok(Message::Binary(bytes)) => {
454 let text = String::from_utf8(bytes.to_vec()).map_err(|err| {
455 PolymarketUsError::InvalidStreamConfig(format!(
456 "received non-UTF8 websocket payload: {err}"
457 ))
458 })?;
459 self.handle_text(&text).await?;
460 }
461 Ok(Message::Close(_)) => break,
462 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
463 Ok(_) => {}
464 Err(err) => return Err(err.into()),
465 }
466 }
467 }
468 }
469
470 Ok(())
471 }
472
473 async fn send_subscriptions(
474 &self,
475 websocket: &mut tokio_tungstenite::WebSocketStream<
476 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
477 >,
478 ) -> Result<(), PolymarketUsError> {
479 for subscription in &self.subscriptions {
480 let mut prepared = subscription.clone();
481 if prepared.responses_debounced.is_none() {
482 prepared.responses_debounced = Some(self.config.responses_debounced);
483 }
484
485 let payload = serde_json::to_string(&prepared)?;
486 websocket.send(Message::Text(payload.into())).await?;
487 }
488
489 Ok(())
490 }
491
492 async fn handle_text(&self, text: &str) -> Result<(), PolymarketUsError> {
493 let json: Value = serde_json::from_str(text)?;
494 if let Some(message) = parse_stream_message(json) {
495 if !self.emit(message).await {
496 return Ok(());
497 }
498 }
499 Ok(())
500 }
501
502 async fn emit(&self, message: StreamMessage) -> bool {
503 self.tx.send(message).await.is_ok()
504 }
505}
506
507struct StreamShutdown {
508 requested: AtomicBool,
509 notify: Notify,
510}
511
512impl StreamShutdown {
513 fn new() -> Self {
514 Self {
515 requested: AtomicBool::new(false),
516 notify: Notify::new(),
517 }
518 }
519
520 fn shutdown(&self) {
521 if !self.requested.swap(true, Ordering::SeqCst) {
522 self.notify.notify_waiters();
523 }
524 }
525
526 fn is_shutdown(&self) -> bool {
527 self.requested.load(Ordering::SeqCst)
528 }
529
530 fn notified(&self) -> impl Future<Output = ()> + '_ {
531 self.notify.notified()
532 }
533}
534
535fn parse_stream_message(json: Value) -> Option<StreamMessage> {
536 match json {
537 Value::Object(map) => {
538 let tracking_id = extract_tracking_id(&map);
539 let event_type = extract_event_type(&map);
540 let payload = extract_payload(&map);
541
542 let kind = match event_type.as_str() {
543 "order_snapshot" => {
544 StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(payload))
545 }
546 "market_data_lite" => {
547 StreamMessageKind::Data(StreamDataEvent::MarketDataLite(payload))
548 }
549 "order_book_delta" | "orderbook_delta" | "book_delta" => {
550 StreamMessageKind::Data(StreamDataEvent::OrderBookDelta(payload))
551 }
552 "order_update" | "order_updates" | "user_order" | "fill" => {
553 StreamMessageKind::Data(StreamDataEvent::OrderUpdate(payload))
554 }
555 "subscription" | "subscribe" | "subscribed" | "ack" => {
556 StreamMessageKind::Control(StreamControlEvent::SubscriptionAck {
557 event_type: event_type.clone(),
558 payload,
559 })
560 }
561 "error" => {
562 StreamMessageKind::Control(StreamControlEvent::Error(payload.to_string()))
563 }
564 _ => StreamMessageKind::Data(StreamDataEvent::Other {
565 event_type: event_type.clone(),
566 payload,
567 }),
568 };
569
570 Some(StreamMessage { tracking_id, kind })
571 }
572 other => Some(StreamMessage::data(
573 None,
574 StreamDataEvent::Other {
575 event_type: "unknown".to_string(),
576 payload: other,
577 },
578 )),
579 }
580}
581
582fn extract_tracking_id(map: &Map<String, Value>) -> Option<String> {
583 ["trackingId", "tracking_id", "trackingID", "id"]
584 .iter()
585 .find_map(|key| map.get(*key).and_then(Value::as_str).map(ToOwned::to_owned))
586}
587
588fn extract_event_type(map: &Map<String, Value>) -> String {
589 for key in ["event", "type", "channel", "name", "topic"] {
590 if let Some(value) = map.get(key).and_then(Value::as_str) {
591 return value.to_string();
592 }
593 }
594
595 if map.len() == 1 {
596 return map
597 .keys()
598 .next()
599 .cloned()
600 .unwrap_or_else(|| "unknown".to_string());
601 }
602
603 "unknown".to_string()
604}
605
606fn extract_payload(map: &Map<String, Value>) -> Value {
607 for key in ["data", "payload", "body", "message", "result"] {
608 if let Some(value) = map.get(key) {
609 return value.clone();
610 }
611 }
612
613 if map.len() == 1 {
614 return map.values().next().cloned().unwrap_or(Value::Null);
615 }
616
617 Value::Object(map.clone())
618}
619
620fn next_tracking_id(prefix: &str) -> String {
621 let ordinal = TRACKING_COUNTER.fetch_add(1, Ordering::Relaxed);
622 format!(
623 "{prefix}-{}-{ordinal}",
624 chrono::Utc::now().timestamp_millis()
625 )
626}
627
/// Converts a user-supplied endpoint into the websocket URL the client dials.
///
/// Trailing slashes are removed first. Then:
/// - `ws://` / `wss://` URLs are kept as-is;
/// - `https://host` becomes `wss://host/ws`;
/// - `http://host` becomes `ws://host/ws`;
/// - a scheme-less `host` is assumed to use TLS and becomes `wss://host/ws`.
fn normalize_stream_url(url: String) -> String {
    let trimmed = url.trim_end_matches('/');
    if trimmed.starts_with("ws://") || trimmed.starts_with("wss://") {
        trimmed.to_string()
    } else if let Some(rest) = trimmed.strip_prefix("https://") {
        format!("wss://{rest}/ws")
    } else if let Some(rest) = trimmed.strip_prefix("http://") {
        format!("ws://{rest}/ws")
    } else {
        format!("wss://{trimmed}/ws")
    }
}

/// Derives the websocket URL from the REST gateway base URL.
///
/// Delegates to [`normalize_stream_url`] so the gateway-based constructor and the
/// plain constructor share one set of URL rules and cannot drift apart.
fn derive_stream_url(gateway_base_url: &str) -> String {
    normalize_stream_url(gateway_base_url.to_owned())
}
653
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn reconnect_delay_caps_at_max() {
        let policy = ReconnectConfig {
            enabled: true,
            max_attempts: None,
            initial_delay: Duration::from_millis(250),
            max_delay: Duration::from_secs(1),
            multiplier: 3.0,
        };

        // (attempt, expected delay): growth by 3x from the second retry, capped at 1 s.
        let schedule = [
            (0, Duration::from_millis(250)),
            (1, Duration::from_millis(250)),
            (2, Duration::from_millis(750)),
            (3, Duration::from_secs(1)),
            (10, Duration::from_secs(1)),
        ];
        for (attempt, expected) in schedule {
            assert_eq!(policy.delay_for_attempt(attempt), expected, "attempt {attempt}");
        }
    }

    #[test]
    fn subscription_serializes_debounced_flag_and_tracking_id() {
        let subscription = StreamSubscription::order_snapshot("ABC")
            .with_tracking_id("tracking-1")
            .with_responses_debounced(true)
            .insert_extra("bookLevel", json!(2));
        let encoded = serde_json::to_value(subscription).unwrap();

        let expectations = [
            ("channel", json!("order_snapshot")),
            ("trackingId", json!("tracking-1")),
            ("responsesDebounced", json!(true)),
            ("symbol", json!("ABC")),
            ("bookLevel", json!(2)),
        ];
        for (field, expected) in expectations {
            assert_eq!(encoded[field], expected, "field {field}");
        }
    }

    #[test]
    fn parses_order_snapshot_event() {
        let frame = json!({
            "event": "order_snapshot",
            "trackingId": "abc-123",
            "data": { "bids": [1, 2], "asks": [3, 4] }
        });
        let parsed = parse_stream_message(frame).expect("message");

        assert_eq!(parsed.tracking_id.as_deref(), Some("abc-123"));
        match parsed.kind {
            StreamMessageKind::Data(StreamDataEvent::OrderSnapshot(body)) => {
                assert_eq!(body["bids"][0], 1);
                assert_eq!(body["asks"][1], 4);
            }
            unexpected => panic!("unexpected event: {unexpected:?}"),
        }
    }

    #[test]
    fn derives_stream_url_from_gateway_base_url() {
        let derived = derive_stream_url("https://gateway.polymarket.us");
        assert_eq!(derived, "wss://gateway.polymarket.us/ws");

        let passthrough = normalize_stream_url("wss://custom.example/ws".to_string());
        assert_eq!(passthrough, "wss://custom.example/ws");
    }
}
722
723