polyfill_rs/
stream.rs

1//! Async streaming functionality for Polymarket client
2//!
3//! This module provides high-performance streaming capabilities for
4//! real-time market data and order updates.
5
6use crate::errors::{PolyfillError, Result};
7use crate::types::*;
8use chrono::Utc;
9use futures::{SinkExt, Stream, StreamExt};
10use serde_json::Value;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
16/// Trait for market data streams
17pub trait MarketStream: Stream<Item = Result<StreamMessage>> + Send + Sync {
18    /// Subscribe to market data for specific tokens
19    fn subscribe(&mut self, subscription: Subscription) -> Result<()>;
20
21    /// Unsubscribe from market data
22    fn unsubscribe(&mut self, token_ids: &[String]) -> Result<()>;
23
24    /// Check if the stream is connected
25    fn is_connected(&self) -> bool;
26
27    /// Get connection statistics
28    fn get_stats(&self) -> StreamStats;
29}
30
31/// WebSocket-based market stream implementation
32#[derive(Debug)]
33#[allow(dead_code)]
34pub struct WebSocketStream {
35    /// WebSocket connection
36    connection: Option<
37        tokio_tungstenite::WebSocketStream<
38            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
39        >,
40    >,
41    /// URL for the WebSocket connection
42    url: String,
43    /// Authentication credentials
44    auth: Option<WssAuth>,
45    /// Current subscriptions
46    subscriptions: Vec<WssSubscription>,
47    /// Message sender for internal communication
48    tx: mpsc::UnboundedSender<StreamMessage>,
49    /// Message receiver
50    rx: mpsc::UnboundedReceiver<StreamMessage>,
51    /// Connection statistics
52    stats: StreamStats,
53    /// Reconnection configuration
54    reconnect_config: ReconnectConfig,
55}
56
57/// Stream statistics
58#[derive(Debug, Clone)]
59pub struct StreamStats {
60    pub messages_received: u64,
61    pub messages_sent: u64,
62    pub errors: u64,
63    pub last_message_time: Option<chrono::DateTime<Utc>>,
64    pub connection_uptime: std::time::Duration,
65    pub reconnect_count: u32,
66}
67
/// Exponential-backoff policy for re-establishing a dropped connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of connection attempts before giving up.
    pub max_retries: u32,
    /// Delay slept after the first failed attempt.
    pub base_delay: std::time::Duration,
    /// Upper bound on any single delay.
    pub max_delay: std::time::Duration,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Five attempts, starting at one second and doubling up to a one-minute cap.
    fn default() -> Self {
        let base_delay = std::time::Duration::from_millis(1_000);
        let max_delay = std::time::Duration::from_secs(60);

        Self {
            backoff_multiplier: 2.0,
            base_delay,
            max_delay,
            max_retries: 5,
        }
    }
}
87
88impl WebSocketStream {
89    /// Create a new WebSocket stream
90    pub fn new(url: &str) -> Self {
91        let (tx, rx) = mpsc::unbounded_channel();
92
93        Self {
94            connection: None,
95            url: url.to_string(),
96            auth: None,
97            subscriptions: Vec::new(),
98            tx,
99            rx,
100            stats: StreamStats {
101                messages_received: 0,
102                messages_sent: 0,
103                errors: 0,
104                last_message_time: None,
105                connection_uptime: std::time::Duration::ZERO,
106                reconnect_count: 0,
107            },
108            reconnect_config: ReconnectConfig::default(),
109        }
110    }
111
112    /// Set authentication credentials
113    pub fn with_auth(mut self, auth: WssAuth) -> Self {
114        self.auth = Some(auth);
115        self
116    }
117
118    /// Connect to the WebSocket
119    async fn connect(&mut self) -> Result<()> {
120        let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url)
121            .await
122            .map_err(|e| {
123                PolyfillError::stream(
124                    format!("WebSocket connection failed: {}", e),
125                    crate::errors::StreamErrorKind::ConnectionFailed,
126                )
127            })?;
128
129        self.connection = Some(ws_stream);
130        info!("Connected to WebSocket stream at {}", self.url);
131        Ok(())
132    }
133
134    /// Send a message to the WebSocket
135    async fn send_message(&mut self, message: Value) -> Result<()> {
136        if let Some(connection) = &mut self.connection {
137            let text = serde_json::to_string(&message).map_err(|e| {
138                PolyfillError::parse(format!("Failed to serialize message: {}", e), None)
139            })?;
140
141            let ws_message = tokio_tungstenite::tungstenite::Message::Text(text);
142            connection.send(ws_message).await.map_err(|e| {
143                PolyfillError::stream(
144                    format!("Failed to send message: {}", e),
145                    crate::errors::StreamErrorKind::MessageCorrupted,
146                )
147            })?;
148
149            self.stats.messages_sent += 1;
150        }
151
152        Ok(())
153    }
154
155    /// Subscribe to market data using official Polymarket WebSocket API
156    pub async fn subscribe_async(&mut self, subscription: WssSubscription) -> Result<()> {
157        // Ensure connection
158        if self.connection.is_none() {
159            self.connect().await?;
160        }
161
162        // Send subscription message in the format expected by Polymarket
163        // The subscription struct will serialize correctly with proper field names
164        let message = serde_json::to_value(&subscription).map_err(|e| {
165            PolyfillError::parse(format!("Failed to serialize subscription: {}", e), None)
166        })?;
167
168        self.send_message(message).await?;
169        self.subscriptions.push(subscription.clone());
170
171        info!("Subscribed to {} channel", subscription.channel_type);
172        Ok(())
173    }
174
175    /// Subscribe to user channel (orders and trades)
176    pub async fn subscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
177        let auth = self
178            .auth
179            .as_ref()
180            .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
181            .clone();
182
183        let subscription = WssSubscription {
184            channel_type: "user".to_string(),
185            operation: Some("subscribe".to_string()),
186            markets,
187            asset_ids: Vec::new(),
188            initial_dump: Some(true),
189            custom_feature_enabled: None,
190            auth: Some(auth),
191        };
192
193        self.subscribe_async(subscription).await
194    }
195
196    /// Subscribe to market channel (order book and trades)
197    /// Market subscriptions do not require authentication
198    pub async fn subscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
199        let subscription = WssSubscription {
200            channel_type: "market".to_string(),
201            operation: Some("subscribe".to_string()),
202            markets: Vec::new(),
203            asset_ids,
204            initial_dump: Some(true),
205            custom_feature_enabled: None,
206            auth: None,
207        };
208
209        self.subscribe_async(subscription).await
210    }
211
212    /// Subscribe to market channel with custom features enabled
213    /// Custom features include: best_bid_ask, new_market, market_resolved events
214    pub async fn subscribe_market_channel_with_features(&mut self, asset_ids: Vec<String>) -> Result<()> {
215        let subscription = WssSubscription {
216            channel_type: "market".to_string(),
217            operation: Some("subscribe".to_string()),
218            markets: Vec::new(),
219            asset_ids,
220            initial_dump: Some(true),
221            custom_feature_enabled: Some(true),
222            auth: None,
223        };
224
225        self.subscribe_async(subscription).await
226    }
227
228    /// Unsubscribe from market channel
229    pub async fn unsubscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
230        let subscription = WssSubscription {
231            channel_type: "market".to_string(),
232            operation: Some("unsubscribe".to_string()),
233            markets: Vec::new(),
234            asset_ids,
235            initial_dump: None,
236            custom_feature_enabled: None,
237            auth: None,
238        };
239
240        self.subscribe_async(subscription).await
241    }
242
243    /// Unsubscribe from user channel
244    pub async fn unsubscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
245        let auth = self
246            .auth
247            .as_ref()
248            .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
249            .clone();
250
251        let subscription = WssSubscription {
252            channel_type: "user".to_string(),
253            operation: Some("unsubscribe".to_string()),
254            markets,
255            asset_ids: Vec::new(),
256            initial_dump: None,
257            custom_feature_enabled: None,
258            auth: Some(auth),
259        };
260
261        self.subscribe_async(subscription).await
262    }
263
264    /// Handle incoming WebSocket messages
265    #[allow(dead_code)]
266    async fn handle_message(
267        &mut self,
268        message: tokio_tungstenite::tungstenite::Message,
269    ) -> Result<()> {
270        match message {
271            tokio_tungstenite::tungstenite::Message::Text(text) => {
272                debug!("Received WebSocket message: {}", text);
273
274                // Parse the message according to Polymarket's format
275                let stream_message = self.parse_polymarket_message(&text)?;
276
277                // Send to internal channel
278                if let Err(e) = self.tx.send(stream_message) {
279                    error!("Failed to send message to internal channel: {}", e);
280                }
281
282                self.stats.messages_received += 1;
283                self.stats.last_message_time = Some(Utc::now());
284            },
285            tokio_tungstenite::tungstenite::Message::Close(_) => {
286                info!("WebSocket connection closed by server");
287                self.connection = None;
288            },
289            tokio_tungstenite::tungstenite::Message::Ping(data) => {
290                // Respond with pong
291                if let Some(connection) = &mut self.connection {
292                    let pong = tokio_tungstenite::tungstenite::Message::Pong(data);
293                    if let Err(e) = connection.send(pong).await {
294                        error!("Failed to send pong: {}", e);
295                    }
296                }
297            },
298            tokio_tungstenite::tungstenite::Message::Pong(_) => {
299                // Handle pong if needed
300                debug!("Received pong");
301            },
302            tokio_tungstenite::tungstenite::Message::Binary(_) => {
303                warn!("Received binary message (not supported)");
304            },
305            tokio_tungstenite::tungstenite::Message::Frame(_) => {
306                warn!("Received raw frame (not supported)");
307            },
308        }
309
310        Ok(())
311    }
312
313    /// Parse Polymarket WebSocket message format
314    #[allow(dead_code)]
315    fn parse_polymarket_message(&self, text: &str) -> Result<StreamMessage> {
316        let value: Value = serde_json::from_str(text).map_err(|e| {
317            PolyfillError::parse(
318                format!("Failed to parse WebSocket message: {}", e),
319                Some(Box::new(e)),
320            )
321        })?;
322
323        // Extract message type
324        let message_type = value.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
325            PolyfillError::parse("Missing 'type' field in WebSocket message", None)
326        })?;
327
328        match message_type {
329            "book_update" => {
330                let data =
331                    serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
332                        .map_err(|e| {
333                            PolyfillError::parse(
334                                format!("Failed to parse book update: {}", e),
335                                Some(Box::new(e)),
336                            )
337                        })?;
338                Ok(StreamMessage::BookUpdate { data })
339            },
340            "trade" => {
341                let data =
342                    serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
343                        .map_err(|e| {
344                            PolyfillError::parse(
345                                format!("Failed to parse trade: {}", e),
346                                Some(Box::new(e)),
347                            )
348                        })?;
349                Ok(StreamMessage::Trade { data })
350            },
351            "order_update" => {
352                let data =
353                    serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
354                        .map_err(|e| {
355                            PolyfillError::parse(
356                                format!("Failed to parse order update: {}", e),
357                                Some(Box::new(e)),
358                            )
359                        })?;
360                Ok(StreamMessage::OrderUpdate { data })
361            },
362            "user_order_update" => {
363                let data =
364                    serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
365                        .map_err(|e| {
366                            PolyfillError::parse(
367                                format!("Failed to parse user order update: {}", e),
368                                Some(Box::new(e)),
369                            )
370                        })?;
371                Ok(StreamMessage::UserOrderUpdate { data })
372            },
373            "user_trade" => {
374                let data =
375                    serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
376                        .map_err(|e| {
377                            PolyfillError::parse(
378                                format!("Failed to parse user trade: {}", e),
379                                Some(Box::new(e)),
380                            )
381                        })?;
382                Ok(StreamMessage::UserTrade { data })
383            },
384            "market_book_update" => {
385                let data =
386                    serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
387                        .map_err(|e| {
388                            PolyfillError::parse(
389                                format!("Failed to parse market book update: {}", e),
390                                Some(Box::new(e)),
391                            )
392                        })?;
393                Ok(StreamMessage::MarketBookUpdate { data })
394            },
395            "market_trade" => {
396                let data =
397                    serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
398                        .map_err(|e| {
399                            PolyfillError::parse(
400                                format!("Failed to parse market trade: {}", e),
401                                Some(Box::new(e)),
402                            )
403                        })?;
404                Ok(StreamMessage::MarketTrade { data })
405            },
406            "heartbeat" => {
407                let timestamp = value
408                    .get("timestamp")
409                    .and_then(|v| v.as_u64())
410                    .map(|ts| chrono::DateTime::from_timestamp(ts as i64, 0).unwrap_or_default())
411                    .unwrap_or_else(Utc::now);
412                Ok(StreamMessage::Heartbeat { timestamp })
413            },
414            _ => {
415                warn!("Unknown message type: {}", message_type);
416                // Return heartbeat as fallback
417                Ok(StreamMessage::Heartbeat {
418                    timestamp: Utc::now(),
419                })
420            },
421        }
422    }
423
424    /// Reconnect with exponential backoff
425    #[allow(dead_code)]
426    async fn reconnect(&mut self) -> Result<()> {
427        let mut delay = self.reconnect_config.base_delay;
428        let mut retries = 0;
429
430        while retries < self.reconnect_config.max_retries {
431            warn!("Attempting to reconnect (attempt {})", retries + 1);
432
433            match self.connect().await {
434                Ok(()) => {
435                    info!("Successfully reconnected");
436                    self.stats.reconnect_count += 1;
437
438                    // Resubscribe to all previous subscriptions
439                    let subscriptions = self.subscriptions.clone();
440                    for subscription in subscriptions {
441                        self.send_message(serde_json::to_value(subscription)?)
442                            .await?;
443                    }
444
445                    return Ok(());
446                },
447                Err(e) => {
448                    error!("Reconnection attempt {} failed: {}", retries + 1, e);
449                    retries += 1;
450
451                    if retries < self.reconnect_config.max_retries {
452                        tokio::time::sleep(delay).await;
453                        delay = std::cmp::min(
454                            delay.mul_f64(self.reconnect_config.backoff_multiplier),
455                            self.reconnect_config.max_delay,
456                        );
457                    }
458                },
459            }
460        }
461
462        Err(PolyfillError::stream(
463            format!(
464                "Failed to reconnect after {} attempts",
465                self.reconnect_config.max_retries
466            ),
467            crate::errors::StreamErrorKind::ConnectionFailed,
468        ))
469    }
470}
471
472impl Stream for WebSocketStream {
473    type Item = Result<StreamMessage>;
474
475    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
476        // First check internal channel
477        if let Poll::Ready(Some(message)) = self.rx.poll_recv(cx) {
478            return Poll::Ready(Some(Ok(message)));
479        }
480
481        // Then check WebSocket connection
482        if let Some(connection) = &mut self.connection {
483            match connection.poll_next_unpin(cx) {
484                Poll::Ready(Some(Ok(_message))) => {
485                    // Simplified message handling
486                    Poll::Ready(Some(Ok(StreamMessage::Heartbeat {
487                        timestamp: Utc::now(),
488                    })))
489                },
490                Poll::Ready(Some(Err(e))) => {
491                    error!("WebSocket error: {}", e);
492                    self.stats.errors += 1;
493                    Poll::Ready(Some(Err(e.into())))
494                },
495                Poll::Ready(None) => {
496                    info!("WebSocket stream ended");
497                    Poll::Ready(None)
498                },
499                Poll::Pending => Poll::Pending,
500            }
501        } else {
502            Poll::Ready(None)
503        }
504    }
505}
506
507impl MarketStream for WebSocketStream {
508    fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
509        // This is for backward compatibility - use subscribe_async for new code
510        Ok(())
511    }
512
513    fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
514        // This is for backward compatibility - use unsubscribe_async for new code
515        Ok(())
516    }
517
518    fn is_connected(&self) -> bool {
519        self.connection.is_some()
520    }
521
522    fn get_stats(&self) -> StreamStats {
523        self.stats.clone()
524    }
525}
526
527/// Mock stream for testing
528#[derive(Debug)]
529pub struct MockStream {
530    messages: Vec<Result<StreamMessage>>,
531    index: usize,
532    connected: bool,
533}
534
535impl Default for MockStream {
536    fn default() -> Self {
537        Self::new()
538    }
539}
540
541impl MockStream {
542    pub fn new() -> Self {
543        Self {
544            messages: Vec::new(),
545            index: 0,
546            connected: true,
547        }
548    }
549
550    pub fn add_message(&mut self, message: StreamMessage) {
551        self.messages.push(Ok(message));
552    }
553
554    pub fn add_error(&mut self, error: PolyfillError) {
555        self.messages.push(Err(error));
556    }
557
558    pub fn set_connected(&mut self, connected: bool) {
559        self.connected = connected;
560    }
561}
562
563impl Stream for MockStream {
564    type Item = Result<StreamMessage>;
565
566    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
567        if self.index >= self.messages.len() {
568            Poll::Ready(None)
569        } else {
570            let message = self.messages[self.index].clone();
571            self.index += 1;
572            Poll::Ready(Some(message))
573        }
574    }
575}
576
577impl MarketStream for MockStream {
578    fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
579        Ok(())
580    }
581
582    fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
583        Ok(())
584    }
585
586    fn is_connected(&self) -> bool {
587        self.connected
588    }
589
590    fn get_stats(&self) -> StreamStats {
591        StreamStats {
592            messages_received: self.messages.len() as u64,
593            messages_sent: 0,
594            errors: self.messages.iter().filter(|m| m.is_err()).count() as u64,
595            last_message_time: None,
596            connection_uptime: std::time::Duration::ZERO,
597            reconnect_count: 0,
598        }
599    }
600}
601
602/// Stream manager for handling multiple streams
603#[allow(dead_code)]
604pub struct StreamManager {
605    streams: Vec<Box<dyn MarketStream>>,
606    message_tx: mpsc::UnboundedSender<StreamMessage>,
607    message_rx: mpsc::UnboundedReceiver<StreamMessage>,
608}
609
610impl Default for StreamManager {
611    fn default() -> Self {
612        Self::new()
613    }
614}
615
616impl StreamManager {
617    pub fn new() -> Self {
618        let (message_tx, message_rx) = mpsc::unbounded_channel();
619
620        Self {
621            streams: Vec::new(),
622            message_tx,
623            message_rx,
624        }
625    }
626
627    pub fn add_stream(&mut self, stream: Box<dyn MarketStream>) {
628        self.streams.push(stream);
629    }
630
631    pub fn get_message_receiver(&mut self) -> mpsc::UnboundedReceiver<StreamMessage> {
632        // Note: UnboundedReceiver doesn't implement Clone
633        // In a real implementation, you'd want to use a different approach
634        // For now, we'll return a dummy receiver
635        let (_, rx) = mpsc::unbounded_channel();
636        rx
637    }
638
639    pub fn broadcast_message(&self, message: StreamMessage) -> Result<()> {
640        self.message_tx
641            .send(message)
642            .map_err(|e| PolyfillError::internal("Failed to broadcast message", e))
643    }
644}
645
#[cfg(test)]
mod tests {
    use super::*;

    /// Every scripted message is reflected in the mock's statistics.
    #[test]
    fn test_mock_stream() {
        let heartbeat = StreamMessage::Heartbeat {
            timestamp: Utc::now(),
        };
        let book_update = StreamMessage::BookUpdate {
            data: OrderDelta {
                token_id: "test".to_string(),
                timestamp: Utc::now(),
                side: Side::BUY,
                price: rust_decimal_macros::dec!(0.5),
                size: rust_decimal_macros::dec!(100),
                sequence: 1,
            },
        };

        let mut stream = MockStream::new();
        stream.add_message(heartbeat);
        stream.add_message(book_update);

        assert!(stream.is_connected());
        assert_eq!(stream.get_stats().messages_received, 2);
    }

    /// Broadcasting succeeds while the manager still owns the receiver.
    #[test]
    fn test_stream_manager() {
        let mut manager = StreamManager::new();
        manager.add_stream(Box::new(MockStream::new()));

        let outcome = manager.broadcast_message(StreamMessage::Heartbeat {
            timestamp: Utc::now(),
        });
        assert!(outcome.is_ok());
    }
}