1use crate::errors::{PolyfillError, Result};
7use crate::types::*;
8use chrono::Utc;
9use futures::{SinkExt, Stream, StreamExt};
10use serde_json::Value;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
16pub trait MarketStream: Stream<Item = Result<StreamMessage>> + Send + Sync {
18 fn subscribe(&mut self, subscription: Subscription) -> Result<()>;
20
21 fn unsubscribe(&mut self, token_ids: &[String]) -> Result<()>;
23
24 fn is_connected(&self) -> bool;
26
27 fn get_stats(&self) -> StreamStats;
29}
30
31#[derive(Debug)]
33#[allow(dead_code)]
34pub struct WebSocketStream {
35 connection: Option<
37 tokio_tungstenite::WebSocketStream<
38 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
39 >,
40 >,
41 url: String,
43 auth: Option<WssAuth>,
45 subscriptions: Vec<WssSubscription>,
47 tx: mpsc::UnboundedSender<StreamMessage>,
49 rx: mpsc::UnboundedReceiver<StreamMessage>,
51 stats: StreamStats,
53 reconnect_config: ReconnectConfig,
55}
56
57#[derive(Debug, Clone)]
59pub struct StreamStats {
60 pub messages_received: u64,
61 pub messages_sent: u64,
62 pub errors: u64,
63 pub last_message_time: Option<chrono::DateTime<Utc>>,
64 pub connection_uptime: std::time::Duration,
65 pub reconnect_count: u32,
66}
67
/// Exponential-backoff policy used when re-establishing a dropped connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of connection attempts before giving up.
    pub max_retries: u32,
    /// Delay after the first failed attempt.
    pub base_delay: std::time::Duration,
    /// Upper bound that the growing delay is clamped to.
    pub max_delay: std::time::Duration,
    /// Factor applied to the delay after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Five attempts, starting at one second and doubling up to one minute.
    fn default() -> Self {
        let base_delay = std::time::Duration::from_secs(1);
        let max_delay = std::time::Duration::from_secs(60);

        ReconnectConfig {
            backoff_multiplier: 2.0,
            max_retries: 5,
            base_delay,
            max_delay,
        }
    }
}
87
88impl WebSocketStream {
89 pub fn new(url: &str) -> Self {
91 let (tx, rx) = mpsc::unbounded_channel();
92
93 Self {
94 connection: None,
95 url: url.to_string(),
96 auth: None,
97 subscriptions: Vec::new(),
98 tx,
99 rx,
100 stats: StreamStats {
101 messages_received: 0,
102 messages_sent: 0,
103 errors: 0,
104 last_message_time: None,
105 connection_uptime: std::time::Duration::ZERO,
106 reconnect_count: 0,
107 },
108 reconnect_config: ReconnectConfig::default(),
109 }
110 }
111
112 pub fn with_auth(mut self, auth: WssAuth) -> Self {
114 self.auth = Some(auth);
115 self
116 }
117
118 async fn connect(&mut self) -> Result<()> {
120 let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url)
121 .await
122 .map_err(|e| {
123 PolyfillError::stream(
124 format!("WebSocket connection failed: {}", e),
125 crate::errors::StreamErrorKind::ConnectionFailed,
126 )
127 })?;
128
129 self.connection = Some(ws_stream);
130 info!("Connected to WebSocket stream at {}", self.url);
131 Ok(())
132 }
133
134 async fn send_message(&mut self, message: Value) -> Result<()> {
136 if let Some(connection) = &mut self.connection {
137 let text = serde_json::to_string(&message).map_err(|e| {
138 PolyfillError::parse(format!("Failed to serialize message: {}", e), None)
139 })?;
140
141 let ws_message = tokio_tungstenite::tungstenite::Message::Text(text);
142 connection.send(ws_message).await.map_err(|e| {
143 PolyfillError::stream(
144 format!("Failed to send message: {}", e),
145 crate::errors::StreamErrorKind::MessageCorrupted,
146 )
147 })?;
148
149 self.stats.messages_sent += 1;
150 }
151
152 Ok(())
153 }
154
155 pub async fn subscribe_async(&mut self, subscription: WssSubscription) -> Result<()> {
157 if self.connection.is_none() {
159 self.connect().await?;
160 }
161
162 let message = serde_json::to_value(&subscription).map_err(|e| {
165 PolyfillError::parse(format!("Failed to serialize subscription: {}", e), None)
166 })?;
167
168 self.send_message(message).await?;
169 self.subscriptions.push(subscription.clone());
170
171 info!("Subscribed to {} channel", subscription.channel_type);
172 Ok(())
173 }
174
175 pub async fn subscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
177 let auth = self
178 .auth
179 .as_ref()
180 .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
181 .clone();
182
183 let subscription = WssSubscription {
184 channel_type: "user".to_string(),
185 operation: Some("subscribe".to_string()),
186 markets,
187 asset_ids: Vec::new(),
188 initial_dump: Some(true),
189 custom_feature_enabled: None,
190 auth: Some(auth),
191 };
192
193 self.subscribe_async(subscription).await
194 }
195
196 pub async fn subscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
199 let subscription = WssSubscription {
200 channel_type: "market".to_string(),
201 operation: Some("subscribe".to_string()),
202 markets: Vec::new(),
203 asset_ids,
204 initial_dump: Some(true),
205 custom_feature_enabled: None,
206 auth: None,
207 };
208
209 self.subscribe_async(subscription).await
210 }
211
212 pub async fn subscribe_market_channel_with_features(&mut self, asset_ids: Vec<String>) -> Result<()> {
215 let subscription = WssSubscription {
216 channel_type: "market".to_string(),
217 operation: Some("subscribe".to_string()),
218 markets: Vec::new(),
219 asset_ids,
220 initial_dump: Some(true),
221 custom_feature_enabled: Some(true),
222 auth: None,
223 };
224
225 self.subscribe_async(subscription).await
226 }
227
228 pub async fn unsubscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
230 let subscription = WssSubscription {
231 channel_type: "market".to_string(),
232 operation: Some("unsubscribe".to_string()),
233 markets: Vec::new(),
234 asset_ids,
235 initial_dump: None,
236 custom_feature_enabled: None,
237 auth: None,
238 };
239
240 self.subscribe_async(subscription).await
241 }
242
243 pub async fn unsubscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
245 let auth = self
246 .auth
247 .as_ref()
248 .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
249 .clone();
250
251 let subscription = WssSubscription {
252 channel_type: "user".to_string(),
253 operation: Some("unsubscribe".to_string()),
254 markets,
255 asset_ids: Vec::new(),
256 initial_dump: None,
257 custom_feature_enabled: None,
258 auth: Some(auth),
259 };
260
261 self.subscribe_async(subscription).await
262 }
263
264 #[allow(dead_code)]
266 async fn handle_message(
267 &mut self,
268 message: tokio_tungstenite::tungstenite::Message,
269 ) -> Result<()> {
270 match message {
271 tokio_tungstenite::tungstenite::Message::Text(text) => {
272 debug!("Received WebSocket message: {}", text);
273
274 let stream_message = self.parse_polymarket_message(&text)?;
276
277 if let Err(e) = self.tx.send(stream_message) {
279 error!("Failed to send message to internal channel: {}", e);
280 }
281
282 self.stats.messages_received += 1;
283 self.stats.last_message_time = Some(Utc::now());
284 },
285 tokio_tungstenite::tungstenite::Message::Close(_) => {
286 info!("WebSocket connection closed by server");
287 self.connection = None;
288 },
289 tokio_tungstenite::tungstenite::Message::Ping(data) => {
290 if let Some(connection) = &mut self.connection {
292 let pong = tokio_tungstenite::tungstenite::Message::Pong(data);
293 if let Err(e) = connection.send(pong).await {
294 error!("Failed to send pong: {}", e);
295 }
296 }
297 },
298 tokio_tungstenite::tungstenite::Message::Pong(_) => {
299 debug!("Received pong");
301 },
302 tokio_tungstenite::tungstenite::Message::Binary(_) => {
303 warn!("Received binary message (not supported)");
304 },
305 tokio_tungstenite::tungstenite::Message::Frame(_) => {
306 warn!("Received raw frame (not supported)");
307 },
308 }
309
310 Ok(())
311 }
312
313 #[allow(dead_code)]
315 fn parse_polymarket_message(&self, text: &str) -> Result<StreamMessage> {
316 let value: Value = serde_json::from_str(text).map_err(|e| {
317 PolyfillError::parse(
318 format!("Failed to parse WebSocket message: {}", e),
319 Some(Box::new(e)),
320 )
321 })?;
322
323 let message_type = value.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
325 PolyfillError::parse("Missing 'type' field in WebSocket message", None)
326 })?;
327
328 match message_type {
329 "book_update" => {
330 let data =
331 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
332 .map_err(|e| {
333 PolyfillError::parse(
334 format!("Failed to parse book update: {}", e),
335 Some(Box::new(e)),
336 )
337 })?;
338 Ok(StreamMessage::BookUpdate { data })
339 },
340 "trade" => {
341 let data =
342 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
343 .map_err(|e| {
344 PolyfillError::parse(
345 format!("Failed to parse trade: {}", e),
346 Some(Box::new(e)),
347 )
348 })?;
349 Ok(StreamMessage::Trade { data })
350 },
351 "order_update" => {
352 let data =
353 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
354 .map_err(|e| {
355 PolyfillError::parse(
356 format!("Failed to parse order update: {}", e),
357 Some(Box::new(e)),
358 )
359 })?;
360 Ok(StreamMessage::OrderUpdate { data })
361 },
362 "user_order_update" => {
363 let data =
364 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
365 .map_err(|e| {
366 PolyfillError::parse(
367 format!("Failed to parse user order update: {}", e),
368 Some(Box::new(e)),
369 )
370 })?;
371 Ok(StreamMessage::UserOrderUpdate { data })
372 },
373 "user_trade" => {
374 let data =
375 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
376 .map_err(|e| {
377 PolyfillError::parse(
378 format!("Failed to parse user trade: {}", e),
379 Some(Box::new(e)),
380 )
381 })?;
382 Ok(StreamMessage::UserTrade { data })
383 },
384 "market_book_update" => {
385 let data =
386 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
387 .map_err(|e| {
388 PolyfillError::parse(
389 format!("Failed to parse market book update: {}", e),
390 Some(Box::new(e)),
391 )
392 })?;
393 Ok(StreamMessage::MarketBookUpdate { data })
394 },
395 "market_trade" => {
396 let data =
397 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
398 .map_err(|e| {
399 PolyfillError::parse(
400 format!("Failed to parse market trade: {}", e),
401 Some(Box::new(e)),
402 )
403 })?;
404 Ok(StreamMessage::MarketTrade { data })
405 },
406 "heartbeat" => {
407 let timestamp = value
408 .get("timestamp")
409 .and_then(|v| v.as_u64())
410 .map(|ts| chrono::DateTime::from_timestamp(ts as i64, 0).unwrap_or_default())
411 .unwrap_or_else(Utc::now);
412 Ok(StreamMessage::Heartbeat { timestamp })
413 },
414 _ => {
415 warn!("Unknown message type: {}", message_type);
416 Ok(StreamMessage::Heartbeat {
418 timestamp: Utc::now(),
419 })
420 },
421 }
422 }
423
424 #[allow(dead_code)]
426 async fn reconnect(&mut self) -> Result<()> {
427 let mut delay = self.reconnect_config.base_delay;
428 let mut retries = 0;
429
430 while retries < self.reconnect_config.max_retries {
431 warn!("Attempting to reconnect (attempt {})", retries + 1);
432
433 match self.connect().await {
434 Ok(()) => {
435 info!("Successfully reconnected");
436 self.stats.reconnect_count += 1;
437
438 let subscriptions = self.subscriptions.clone();
440 for subscription in subscriptions {
441 self.send_message(serde_json::to_value(subscription)?)
442 .await?;
443 }
444
445 return Ok(());
446 },
447 Err(e) => {
448 error!("Reconnection attempt {} failed: {}", retries + 1, e);
449 retries += 1;
450
451 if retries < self.reconnect_config.max_retries {
452 tokio::time::sleep(delay).await;
453 delay = std::cmp::min(
454 delay.mul_f64(self.reconnect_config.backoff_multiplier),
455 self.reconnect_config.max_delay,
456 );
457 }
458 },
459 }
460 }
461
462 Err(PolyfillError::stream(
463 format!(
464 "Failed to reconnect after {} attempts",
465 self.reconnect_config.max_retries
466 ),
467 crate::errors::StreamErrorKind::ConnectionFailed,
468 ))
469 }
470}
471
472impl Stream for WebSocketStream {
473 type Item = Result<StreamMessage>;
474
475 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
476 if let Poll::Ready(Some(message)) = self.rx.poll_recv(cx) {
478 return Poll::Ready(Some(Ok(message)));
479 }
480
481 if let Some(connection) = &mut self.connection {
483 match connection.poll_next_unpin(cx) {
484 Poll::Ready(Some(Ok(_message))) => {
485 Poll::Ready(Some(Ok(StreamMessage::Heartbeat {
487 timestamp: Utc::now(),
488 })))
489 },
490 Poll::Ready(Some(Err(e))) => {
491 error!("WebSocket error: {}", e);
492 self.stats.errors += 1;
493 Poll::Ready(Some(Err(e.into())))
494 },
495 Poll::Ready(None) => {
496 info!("WebSocket stream ended");
497 Poll::Ready(None)
498 },
499 Poll::Pending => Poll::Pending,
500 }
501 } else {
502 Poll::Ready(None)
503 }
504 }
505}
506
507impl MarketStream for WebSocketStream {
508 fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
509 Ok(())
511 }
512
513 fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
514 Ok(())
516 }
517
518 fn is_connected(&self) -> bool {
519 self.connection.is_some()
520 }
521
522 fn get_stats(&self) -> StreamStats {
523 self.stats.clone()
524 }
525}
526
527#[derive(Debug)]
529pub struct MockStream {
530 messages: Vec<Result<StreamMessage>>,
531 index: usize,
532 connected: bool,
533}
534
535impl Default for MockStream {
536 fn default() -> Self {
537 Self::new()
538 }
539}
540
541impl MockStream {
542 pub fn new() -> Self {
543 Self {
544 messages: Vec::new(),
545 index: 0,
546 connected: true,
547 }
548 }
549
550 pub fn add_message(&mut self, message: StreamMessage) {
551 self.messages.push(Ok(message));
552 }
553
554 pub fn add_error(&mut self, error: PolyfillError) {
555 self.messages.push(Err(error));
556 }
557
558 pub fn set_connected(&mut self, connected: bool) {
559 self.connected = connected;
560 }
561}
562
563impl Stream for MockStream {
564 type Item = Result<StreamMessage>;
565
566 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
567 if self.index >= self.messages.len() {
568 Poll::Ready(None)
569 } else {
570 let message = self.messages[self.index].clone();
571 self.index += 1;
572 Poll::Ready(Some(message))
573 }
574 }
575}
576
577impl MarketStream for MockStream {
578 fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
579 Ok(())
580 }
581
582 fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
583 Ok(())
584 }
585
586 fn is_connected(&self) -> bool {
587 self.connected
588 }
589
590 fn get_stats(&self) -> StreamStats {
591 StreamStats {
592 messages_received: self.messages.len() as u64,
593 messages_sent: 0,
594 errors: self.messages.iter().filter(|m| m.is_err()).count() as u64,
595 last_message_time: None,
596 connection_uptime: std::time::Duration::ZERO,
597 reconnect_count: 0,
598 }
599 }
600}
601
602#[allow(dead_code)]
604pub struct StreamManager {
605 streams: Vec<Box<dyn MarketStream>>,
606 message_tx: mpsc::UnboundedSender<StreamMessage>,
607 message_rx: mpsc::UnboundedReceiver<StreamMessage>,
608}
609
610impl Default for StreamManager {
611 fn default() -> Self {
612 Self::new()
613 }
614}
615
616impl StreamManager {
617 pub fn new() -> Self {
618 let (message_tx, message_rx) = mpsc::unbounded_channel();
619
620 Self {
621 streams: Vec::new(),
622 message_tx,
623 message_rx,
624 }
625 }
626
627 pub fn add_stream(&mut self, stream: Box<dyn MarketStream>) {
628 self.streams.push(stream);
629 }
630
631 pub fn get_message_receiver(&mut self) -> mpsc::UnboundedReceiver<StreamMessage> {
632 let (_, rx) = mpsc::unbounded_channel();
636 rx
637 }
638
639 pub fn broadcast_message(&self, message: StreamMessage) -> Result<()> {
640 self.message_tx
641 .send(message)
642 .map_err(|e| PolyfillError::internal("Failed to broadcast message", e))
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;

    /// A mock stream with two queued messages reports itself connected and
    /// counts both messages as received.
    #[test]
    fn test_mock_stream() {
        let heartbeat = StreamMessage::Heartbeat {
            timestamp: Utc::now(),
        };
        let delta = OrderDelta {
            token_id: "test".to_string(),
            timestamp: Utc::now(),
            side: Side::BUY,
            price: rust_decimal_macros::dec!(0.5),
            size: rust_decimal_macros::dec!(100),
            sequence: 1,
        };

        let mut stream = MockStream::new();
        stream.add_message(heartbeat);
        stream.add_message(StreamMessage::BookUpdate { data: delta });

        assert!(stream.is_connected());
        let stats = stream.get_stats();
        assert_eq!(stats.messages_received, 2);
    }

    /// Broadcasting through a manager that owns a mock stream succeeds.
    #[test]
    fn test_stream_manager() {
        let mut manager = StreamManager::new();
        manager.add_stream(Box::new(MockStream::new()));

        let outcome = manager.broadcast_message(StreamMessage::Heartbeat {
            timestamp: Utc::now(),
        });
        assert!(outcome.is_ok());
    }
}