1use crate::errors::{PolyError, Result};
10use crate::types::{ApiCredentials, OrderSummary, Side};
11use chrono::{DateTime, Utc};
12use futures::{SinkExt, StreamExt};
13use serde::Deserialize;
14use serde_json::{Value, json};
15use std::collections::VecDeque;
16use std::time::Duration;
17use tokio::net::TcpStream;
18use tokio::time::{sleep, timeout};
19use tokio_tungstenite::{
20 MaybeTlsStream, WebSocketStream, connect_async, tungstenite::protocol::Message,
21};
22use tracing::warn;
23
/// Base URL of the Polymarket CLOB WebSocket service; channel paths are appended to it.
const DEFAULT_WSS_BASE: &str = "wss://ws-subscriptions-clob.polymarket.com";
/// Path of the public market channel, appended to the base URL.
const MARKET_CHANNEL_PATH: &str = "/ws/market";
/// Path of the user channel (subscriptions carry API credentials), appended to the base URL.
const USER_CHANNEL_PATH: &str = "/ws/user";

/// Step of the linear reconnect backoff: the n-th failed attempt waits n × this.
const BASE_RECONNECT_DELAY: Duration = Duration::from_millis(250);
/// Ceiling applied to any single reconnect delay.
const MAX_RECONNECT_DELAY: Duration = Duration::from_millis(10_000);
/// Failed connection attempts after which connecting gives up with an error.
const MAX_RECONNECT_ATTEMPTS: u32 = 8;
/// Idle period after which a client's `next_event` sends a text `PING` keepalive.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(25_000);
31
32#[derive(Debug, Clone)]
34pub enum WssMarketEvent {
35 Book(MarketBook),
36 PriceChange(PriceChangeMessage),
37 TickSizeChange(TickSizeChangeMessage),
38 LastTrade(LastTradeMessage),
39}
40
41#[derive(Debug, Clone)]
43pub enum WssUserEvent {
44 Trade(WssUserTradeMessage),
45 Order(WssUserOrderMessage),
46}
47
48#[derive(Debug, Clone, Deserialize)]
50pub struct WssUserTradeMessage {
51 #[serde(rename = "event_type")]
52 pub event_type: String,
53 pub asset_id: String,
54 pub id: String,
55 pub last_update: String,
56 #[serde(default)]
57 pub maker_orders: Vec<MakerOrder>,
58 pub market: String,
59 pub matchtime: String,
60 pub outcome: String,
61 pub owner: String,
62 #[serde(with = "rust_decimal::serde::str")]
63 pub price: rust_decimal::Decimal,
64 pub side: Side,
65 #[serde(with = "rust_decimal::serde::str")]
66 pub size: rust_decimal::Decimal,
67 pub status: String,
68 pub taker_order_id: String,
69 pub timestamp: String,
70 pub trade_owner: String,
71 #[serde(rename = "type")]
72 pub message_type: String,
73}
74
75#[derive(Debug, Clone, Deserialize)]
77pub struct MakerOrder {
78 pub asset_id: String,
79 #[serde(with = "rust_decimal::serde::str")]
80 pub matched_amount: rust_decimal::Decimal,
81 pub order_id: String,
82 pub outcome: String,
83 pub owner: String,
84 #[serde(with = "rust_decimal::serde::str")]
85 pub price: rust_decimal::Decimal,
86}
87
88#[derive(Debug, Clone, Deserialize)]
90pub struct WssUserOrderMessage {
91 #[serde(rename = "event_type")]
92 pub event_type: String,
93 #[serde(default)]
94 pub associate_trades: Option<Vec<String>>,
95 pub asset_id: String,
96 pub id: String,
97 pub market: String,
98 pub order_owner: String,
99 #[serde(with = "rust_decimal::serde::str")]
100 pub original_size: rust_decimal::Decimal,
101 pub outcome: String,
102 pub owner: String,
103 #[serde(with = "rust_decimal::serde::str")]
104 pub price: rust_decimal::Decimal,
105 pub side: Side,
106 #[serde(with = "rust_decimal::serde::str")]
107 pub size_matched: rust_decimal::Decimal,
108 pub timestamp: String,
109 #[serde(rename = "type")]
110 pub message_type: String,
111}
112
113#[derive(Debug, Clone, Deserialize)]
115pub struct MarketBook {
116 #[serde(rename = "event_type")]
117 pub event_type: String,
118 pub asset_id: String,
119 pub market: String,
120 pub timestamp: String,
121 pub hash: String,
122 pub bids: Vec<OrderSummary>,
123 pub asks: Vec<OrderSummary>,
124}
125
126#[derive(Debug, Clone, Deserialize)]
128pub struct PriceChangeMessage {
129 #[serde(rename = "event_type")]
130 pub event_type: String,
131 pub market: String,
132 #[serde(rename = "price_changes")]
133 pub price_changes: Vec<PriceChangeEntry>,
134 pub timestamp: String,
135}
136
137#[derive(Debug, Clone, Deserialize)]
139pub struct PriceChangeEntry {
140 pub asset_id: String,
141 #[serde(with = "rust_decimal::serde::str")]
142 pub price: rust_decimal::Decimal,
143 #[serde(with = "rust_decimal::serde::str")]
144 pub size: rust_decimal::Decimal,
145 pub side: Side,
146 pub hash: String,
147 #[serde(with = "rust_decimal::serde::str")]
148 pub best_bid: rust_decimal::Decimal,
149 #[serde(with = "rust_decimal::serde::str")]
150 pub best_ask: rust_decimal::Decimal,
151}
152
153#[derive(Debug, Clone, Deserialize)]
155pub struct TickSizeChangeMessage {
156 #[serde(rename = "event_type")]
157 pub event_type: String,
158 pub asset_id: String,
159 pub market: String,
160 #[serde(rename = "old_tick_size", with = "rust_decimal::serde::str")]
161 pub old_tick_size: rust_decimal::Decimal,
162 #[serde(rename = "new_tick_size", with = "rust_decimal::serde::str")]
163 pub new_tick_size: rust_decimal::Decimal,
164 pub side: String,
165 pub timestamp: String,
166}
167
168#[derive(Debug, Clone, Deserialize)]
170pub struct LastTradeMessage {
171 #[serde(rename = "event_type")]
172 pub event_type: String,
173 pub asset_id: String,
174 pub fee_rate_bps: String,
175 pub market: String,
176 #[serde(with = "rust_decimal::serde::str")]
177 pub price: rust_decimal::Decimal,
178 #[serde(with = "rust_decimal::serde::str")]
179 pub size: rust_decimal::Decimal,
180 pub side: Side,
181 pub timestamp: String,
182}
183
184#[derive(Debug, Clone)]
186pub struct WssStats {
187 pub messages_received: u64,
188 pub errors: u64,
189 pub reconnect_count: u32,
190 pub last_message_time: Option<DateTime<Utc>>,
191}
192
193impl Default for WssStats {
194 fn default() -> Self {
195 Self {
196 messages_received: 0,
197 errors: 0,
198 reconnect_count: 0,
199 last_message_time: None,
200 }
201 }
202}
203
204pub struct WssMarketClient {
206 connect_url: String,
207 connection: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
208 subscribed_asset_ids: Vec<String>,
209 stats: WssStats,
210 disconnect_history: VecDeque<DateTime<Utc>>,
211 pending_events: VecDeque<WssMarketEvent>,
212}
213
214impl WssMarketClient {
215 pub fn new() -> Self {
217 Self::with_url(DEFAULT_WSS_BASE)
218 }
219
220 pub fn with_url(url: &str) -> Self {
222 let trimmed = url.trim_end_matches('/');
223 let connect_url = format!("{}{}", trimmed, MARKET_CHANNEL_PATH);
224 Self {
225 connection: None,
226 subscribed_asset_ids: Vec::new(),
227 stats: WssStats::default(),
228 disconnect_history: VecDeque::with_capacity(5),
229 connect_url,
230 pending_events: VecDeque::new(),
231 }
232 }
233
234 pub fn stats(&self) -> WssStats {
236 self.stats.clone()
237 }
238
239 fn format_subscription(&self) -> Value {
240 json!({
241 "type": "market",
242 "assets_ids": self.subscribed_asset_ids,
243 })
244 }
245
246 async fn send_subscription(&mut self) -> Result<()> {
247 if self.subscribed_asset_ids.is_empty() {
248 return Ok(());
249 }
250
251 let message = self.format_subscription();
252 self.send_raw_message(message).await
253 }
254
255 async fn send_raw_message(&mut self, message: Value) -> Result<()> {
256 if let Some(connection) = self.connection.as_mut() {
257 let text = serde_json::to_string(&message).map_err(|e| {
258 PolyError::parse(
259 format!("Failed to serialize subscription message: {}", e),
260 None,
261 )
262 })?;
263 connection
264 .send(Message::Text(text.into()))
265 .await
266 .map_err(|e| {
267 PolyError::stream(
268 format!("Failed to send message: {}", e),
269 crate::errors::StreamErrorKind::MessageCorrupted,
270 )
271 })?;
272 return Ok(());
273 }
274 Err(PolyError::stream(
275 "WebSocket connection not established",
276 crate::errors::StreamErrorKind::ConnectionFailed,
277 ))
278 }
279
280 async fn connect(&mut self) -> Result<()> {
281 let mut attempts = 0;
282 loop {
283 match connect_async(&self.connect_url).await {
284 Ok((socket, _)) => {
285 self.connection = Some(socket);
286 if attempts > 0 {
287 self.stats.reconnect_count += 1;
288 }
289 return Ok(());
290 }
291 Err(err) => {
292 attempts += 1;
293 let delay = self.reconnect_delay(attempts);
294 self.stats.errors += 1;
295 if attempts >= MAX_RECONNECT_ATTEMPTS {
296 return Err(PolyError::stream(
297 format!("Failed to connect after {} attempts: {}", attempts, err),
298 crate::errors::StreamErrorKind::ConnectionFailed,
299 ));
300 }
301 sleep(delay).await;
302 }
303 }
304 }
305 }
306
307 fn reconnect_delay(&self, attempts: u32) -> Duration {
308 let millis = BASE_RECONNECT_DELAY.as_millis() as u128 * attempts as u128;
309 let desired =
310 Duration::from_millis(millis.min(MAX_RECONNECT_DELAY.as_millis() as u128) as u64);
311 desired
312 }
313
314 async fn ensure_connection(&mut self) -> Result<()> {
315 if self.connection.is_none() {
316 self.connect().await?;
317 self.send_subscription().await?;
318 }
319 Ok(())
320 }
321
322 pub async fn subscribe(&mut self, asset_ids: Vec<String>) -> Result<()> {
324 self.subscribed_asset_ids = asset_ids;
325 self.ensure_connection().await?;
326 self.send_subscription().await
327 }
328
329 pub async fn next_event(&mut self) -> Result<WssMarketEvent> {
332 loop {
333 if let Some(evt) = self.pending_events.pop_front() {
334 return Ok(evt);
335 }
336 self.ensure_connection().await?;
337
338 match self.connection.as_mut().unwrap().next().await {
339 Some(Ok(Message::Text(text))) => {
340 let trimmed = text.trim();
341 if trimmed.eq_ignore_ascii_case("ping") || trimmed.eq_ignore_ascii_case("pong")
342 {
343 continue;
344 }
345 let first_char = trimmed.chars().next();
346 if first_char != Some('{') && first_char != Some('[') {
347 warn!("ignoring unexpected text frame: {}", trimmed);
348 continue;
349 }
350 let events = parse_market_events(&text)?;
351 self.stats.messages_received += events.len() as u64;
352 self.stats.last_message_time = Some(Utc::now());
353 for evt in events {
354 self.pending_events.push_back(evt);
355 }
356 if let Some(evt) = self.pending_events.pop_front() {
357 return Ok(evt);
358 }
359 continue;
360 }
361 Some(Ok(Message::Ping(payload))) => {
362 if let Some(connection) = self.connection.as_mut() {
363 let _ = connection.send(Message::Pong(payload)).await;
364 }
365 }
366 Some(Ok(Message::Pong(_))) => {}
367 Some(Ok(Message::Close(_))) => {
368 self.disconnect_history.push_back(Utc::now());
369 if self.disconnect_history.len() > 5 {
370 self.disconnect_history.pop_front();
371 }
372 self.connection = None;
373 }
374 Some(Ok(_)) => {}
375 Some(Err(err)) => {
376 warn!("WebSocket error: {}", err);
377 self.connection = None;
378 self.stats.errors += 1;
379 continue;
380 }
381 None => {
382 self.connection = None;
383 }
384 }
385 }
386 }
387}
388
389pub struct WssUserClient {
391 connect_url: String,
392 connection: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
393 subscribed_markets: Vec<String>,
394 stats: WssStats,
395 disconnect_history: VecDeque<DateTime<Utc>>,
396 pending_events: VecDeque<WssUserEvent>,
397 auth: ApiCredentials,
398}
399
400impl WssUserClient {
401 pub fn new(auth: ApiCredentials) -> Self {
403 Self::with_url(DEFAULT_WSS_BASE, auth)
404 }
405
406 pub fn with_url(url: &str, auth: ApiCredentials) -> Self {
408 let trimmed = url.trim_end_matches('/');
409 let connect_url = format!("{}{}", trimmed, USER_CHANNEL_PATH);
410 Self {
411 connection: None,
412 subscribed_markets: Vec::new(),
413 stats: WssStats::default(),
414 disconnect_history: VecDeque::with_capacity(5),
415 connect_url,
416 pending_events: VecDeque::new(),
417 auth,
418 }
419 }
420
421 pub fn stats(&self) -> WssStats {
423 self.stats.clone()
424 }
425
426 fn format_subscription(&self) -> Option<Value> {
427 if self.subscribed_markets.is_empty() {
428 return None;
429 }
430
431 Some(json!({
432 "type": "user",
433 "auth": {
434 "apiKey": self.auth.api_key,
435 "secret": self.auth.secret,
436 "passphrase": self.auth.passphrase,
437 },
438 "markets": self.subscribed_markets,
439 }))
440 }
441
442 async fn send_subscription(&mut self) -> Result<()> {
443 if let Some(message) = self.format_subscription() {
444 self.send_raw_message(message).await
445 } else {
446 Ok(())
447 }
448 }
449
450 async fn send_raw_message(&mut self, message: Value) -> Result<()> {
451 if let Some(connection) = self.connection.as_mut() {
452 let text = serde_json::to_string(&message).map_err(|e| {
453 PolyError::parse(
454 format!("Failed to serialize subscription message: {}", e),
455 None,
456 )
457 })?;
458 connection
459 .send(Message::Text(text.into()))
460 .await
461 .map_err(|e| {
462 PolyError::stream(
463 format!("Failed to send message: {}", e),
464 crate::errors::StreamErrorKind::MessageCorrupted,
465 )
466 })?;
467 return Ok(());
468 }
469 Err(PolyError::stream(
470 "WebSocket connection not established",
471 crate::errors::StreamErrorKind::ConnectionFailed,
472 ))
473 }
474
475 async fn connect(&mut self) -> Result<()> {
476 let mut attempts = 0;
477 loop {
478 match connect_async(&self.connect_url).await {
479 Ok((socket, _)) => {
480 self.connection = Some(socket);
481 if attempts > 0 {
482 self.stats.reconnect_count += 1;
483 }
484 return Ok(());
485 }
486 Err(err) => {
487 attempts += 1;
488 let delay = self.reconnect_delay(attempts);
489 self.stats.errors += 1;
490 if attempts >= MAX_RECONNECT_ATTEMPTS {
491 return Err(PolyError::stream(
492 format!("Failed to connect after {} attempts: {}", attempts, err),
493 crate::errors::StreamErrorKind::ConnectionFailed,
494 ));
495 }
496 sleep(delay).await;
497 }
498 }
499 }
500 }
501
502 fn reconnect_delay(&self, attempts: u32) -> Duration {
503 let millis = BASE_RECONNECT_DELAY.as_millis() as u128 * attempts as u128;
504 let desired =
505 Duration::from_millis(millis.min(MAX_RECONNECT_DELAY.as_millis() as u128) as u64);
506 desired
507 }
508
509 async fn ensure_connection(&mut self) -> Result<()> {
510 if self.connection.is_none() {
511 self.connect().await?;
512 self.send_subscription().await?;
513 }
514 Ok(())
515 }
516
517 pub async fn subscribe(&mut self, market_ids: Vec<String>) -> Result<()> {
519 self.subscribed_markets = market_ids;
520 self.ensure_connection().await?;
521 self.send_subscription().await
522 }
523
524 pub async fn next_event(&mut self) -> Result<WssUserEvent> {
527 loop {
528 if let Some(evt) = self.pending_events.pop_front() {
529 return Ok(evt);
530 }
531 self.ensure_connection().await?;
532
533 match timeout(KEEPALIVE_INTERVAL, self.connection.as_mut().unwrap().next()).await {
534 Ok(Some(Ok(Message::Text(text)))) => {
535 let trimmed = text.trim();
536 if trimmed.eq_ignore_ascii_case("ping") || trimmed.eq_ignore_ascii_case("pong")
537 {
538 continue;
539 }
540 let first_char = trimmed.chars().next();
541 if first_char != Some('{') && first_char != Some('[') {
542 warn!("ignoring unexpected text frame: {}", trimmed);
543 continue;
544 }
545 let events = parse_user_events(&text)?;
546 self.stats.messages_received += events.len() as u64;
547 self.stats.last_message_time = Some(Utc::now());
548 for evt in events {
549 self.pending_events.push_back(evt);
550 }
551 if let Some(evt) = self.pending_events.pop_front() {
552 return Ok(evt);
553 }
554 continue;
555 }
556 Ok(Some(Ok(Message::Ping(payload)))) => {
557 if let Some(connection) = self.connection.as_mut() {
558 let _ = connection.send(Message::Pong(payload)).await;
559 }
560 }
561 Ok(Some(Ok(Message::Pong(_)))) => {}
562 Ok(Some(Ok(Message::Close(_)))) => {
563 self.disconnect_history.push_back(Utc::now());
564 if self.disconnect_history.len() > 5 {
565 self.disconnect_history.pop_front();
566 }
567 self.connection = None;
568 }
569 Ok(Some(Ok(_))) => {}
570 Ok(Some(Err(err))) => {
571 warn!("WebSocket error: {}", err);
572 self.connection = None;
573 self.stats.errors += 1;
574 continue;
575 }
576 Ok(None) => {
577 self.connection = None;
578 }
579 Err(_) => {
580 if let Some(connection) = self.connection.as_mut() {
581 let _ = connection.send(Message::Text("PING".into())).await;
582 }
583 }
584 }
585 }
586 }
587}
588
589fn parse_market_events(text: &str) -> Result<Vec<WssMarketEvent>> {
590 let value: Value = serde_json::from_str(text)
591 .map_err(|err| PolyError::parse(format!("Invalid JSON: {}", err), Some(Box::new(err))))?;
592
593 if let Some(array) = value.as_array() {
594 array
595 .iter()
596 .map(parse_market_event_value)
597 .collect::<Result<Vec<_>>>()
598 } else {
599 Ok(vec![parse_market_event_value(&value)?])
600 }
601}
602
603fn parse_market_event_value(value: &Value) -> Result<WssMarketEvent> {
604 let event_type = value
605 .get("event_type")
606 .and_then(|v| v.as_str())
607 .or_else(|| value.get("type").and_then(|v| v.as_str()))
608 .ok_or_else(|| PolyError::parse("Missing event_type/type in market message", None))?;
609
610 match event_type {
611 "book" => {
612 let parsed: MarketBook = serde_json::from_value(value.clone()).map_err(|err| {
613 PolyError::parse(
614 format!("Failed to parse book message: {}", err),
615 Some(Box::new(err)),
616 )
617 })?;
618 Ok(WssMarketEvent::Book(parsed))
619 }
620 "price_change" => {
621 let parsed =
622 serde_json::from_value::<PriceChangeMessage>(value.clone()).map_err(|err| {
623 PolyError::parse(
624 format!("Failed to parse price_change: {}", err),
625 Some(Box::new(err)),
626 )
627 })?;
628 Ok(WssMarketEvent::PriceChange(parsed))
629 }
630 "tick_size_change" => {
631 let parsed =
632 serde_json::from_value::<TickSizeChangeMessage>(value.clone()).map_err(|err| {
633 PolyError::parse(
634 format!("Failed to parse tick_size_change: {}", err),
635 Some(Box::new(err)),
636 )
637 })?;
638 Ok(WssMarketEvent::TickSizeChange(parsed))
639 }
640 "last_trade_price" => {
641 let parsed =
642 serde_json::from_value::<LastTradeMessage>(value.clone()).map_err(|err| {
643 PolyError::parse(
644 format!("Failed to parse last_trade_price: {}", err),
645 Some(Box::new(err)),
646 )
647 })?;
648 Ok(WssMarketEvent::LastTrade(parsed))
649 }
650 other => Err(PolyError::parse(
651 format!("Unknown market event_type: {}", other),
652 None,
653 )),
654 }
655}
656
657fn parse_user_events(text: &str) -> Result<Vec<WssUserEvent>> {
658 let value: Value = serde_json::from_str(text)
659 .map_err(|err| PolyError::parse(format!("Invalid JSON: {}", err), Some(Box::new(err))))?;
660
661 if let Some(array) = value.as_array() {
662 array
663 .iter()
664 .map(parse_user_event_value)
665 .collect::<Result<Vec<_>>>()
666 } else {
667 Ok(vec![parse_user_event_value(&value)?])
668 }
669}
670
671fn parse_user_event_value(value: &Value) -> Result<WssUserEvent> {
672 let event_type = value
673 .get("event_type")
674 .and_then(|v| v.as_str())
675 .ok_or_else(|| PolyError::parse("Missing event_type in user message", None))?;
676
677 match event_type {
678 "trade" => {
679 let parsed =
680 serde_json::from_value::<WssUserTradeMessage>(value.clone()).map_err(|err| {
681 PolyError::parse(
682 format!("Failed to parse user trade message: {}", err),
683 Some(Box::new(err)),
684 )
685 })?;
686 Ok(WssUserEvent::Trade(parsed))
687 }
688 "order" => {
689 let parsed =
690 serde_json::from_value::<WssUserOrderMessage>(value.clone()).map_err(|err| {
691 PolyError::parse(
692 format!("Failed to parse user order message: {}", err),
693 Some(Box::new(err)),
694 )
695 })?;
696 Ok(WssUserEvent::Order(parsed))
697 }
698 other => Err(PolyError::parse(
699 format!("Unknown user event_type: {}", other),
700 None,
701 )),
702 }
703}