// polyfill_rs/stream.rs

//! Async streaming functionality for the Polymarket client.
//!
//! This module provides high-performance streaming capabilities for
//! real-time market data and order updates.

6use crate::errors::{PolyfillError, Result};
7use crate::types::*;
8use futures::{Stream, SinkExt, StreamExt};
9use serde_json::Value;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tokio::sync::mpsc;
13use tracing::{debug, error, info, warn};
14use chrono::Utc;
15
16/// Trait for market data streams
17pub trait MarketStream: Stream<Item = Result<StreamMessage>> + Send + Sync {
18    /// Subscribe to market data for specific tokens
19    fn subscribe(&mut self, subscription: Subscription) -> Result<()>;
20    
21    /// Unsubscribe from market data
22    fn unsubscribe(&mut self, token_ids: &[String]) -> Result<()>;
23    
24    /// Check if the stream is connected
25    fn is_connected(&self) -> bool;
26    
27    /// Get connection statistics
28    fn get_stats(&self) -> StreamStats;
29}
30
31/// WebSocket-based market stream implementation
32#[derive(Debug)]
33pub struct WebSocketStream {
34    /// WebSocket connection
35    connection: Option<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
36    /// URL for the WebSocket connection
37    url: String,
38    /// Authentication credentials
39    auth: Option<WssAuth>,
40    /// Current subscriptions
41    subscriptions: Vec<WssSubscription>,
42    /// Message sender for internal communication
43    tx: mpsc::UnboundedSender<StreamMessage>,
44    /// Message receiver
45    rx: mpsc::UnboundedReceiver<StreamMessage>,
46    /// Connection statistics
47    stats: StreamStats,
48    /// Reconnection configuration
49    reconnect_config: ReconnectConfig,
50}
51
52/// Stream statistics
53#[derive(Debug, Clone)]
54pub struct StreamStats {
55    pub messages_received: u64,
56    pub messages_sent: u64,
57    pub errors: u64,
58    pub last_message_time: Option<chrono::DateTime<Utc>>,
59    pub connection_uptime: std::time::Duration,
60    pub reconnect_count: u32,
61}
62
/// Exponential-backoff policy used when re-establishing a dropped connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of connection attempts before giving up.
    pub max_retries: u32,
    /// Delay before the second attempt; grows after every failure.
    pub base_delay: std::time::Duration,
    /// Upper bound on the delay between attempts.
    pub max_delay: std::time::Duration,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Five attempts: start at one second, double each time, cap at one minute.
    fn default() -> Self {
        use std::time::Duration;

        let base_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(60);

        ReconnectConfig {
            backoff_multiplier: 2.0,
            base_delay,
            max_delay,
            max_retries: 5,
        }
    }
}

83impl WebSocketStream {
84    /// Create a new WebSocket stream
85    pub fn new(url: &str) -> Self {
86        let (tx, rx) = mpsc::unbounded_channel();
87        
88        Self {
89            connection: None,
90            url: url.to_string(),
91            auth: None,
92            subscriptions: Vec::new(),
93            tx,
94            rx,
95            stats: StreamStats {
96                messages_received: 0,
97                messages_sent: 0,
98                errors: 0,
99                last_message_time: None,
100                connection_uptime: std::time::Duration::ZERO,
101                reconnect_count: 0,
102            },
103            reconnect_config: ReconnectConfig::default(),
104        }
105    }
106
107    /// Set authentication credentials
108    pub fn with_auth(mut self, auth: WssAuth) -> Self {
109        self.auth = Some(auth);
110        self
111    }
112
113    /// Connect to the WebSocket
114    async fn connect(&mut self) -> Result<()> {
115        let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url).await
116            .map_err(|e| PolyfillError::stream(format!("WebSocket connection failed: {}", e), crate::errors::StreamErrorKind::ConnectionFailed))?;
117
118        self.connection = Some(ws_stream);
119        info!("Connected to WebSocket stream at {}", self.url);
120        Ok(())
121    }
122
123    /// Send a message to the WebSocket
124    async fn send_message(&mut self, message: Value) -> Result<()> {
125        if let Some(connection) = &mut self.connection {
126            let text = serde_json::to_string(&message)
127                .map_err(|e| PolyfillError::parse(format!("Failed to serialize message: {}", e), None))?;
128            
129            let ws_message = tokio_tungstenite::tungstenite::Message::Text(text);
130            connection.send(ws_message).await
131                .map_err(|e| PolyfillError::stream(format!("Failed to send message: {}", e), crate::errors::StreamErrorKind::MessageCorrupted))?;
132            
133            self.stats.messages_sent += 1;
134        }
135        
136        Ok(())
137    }
138
139    /// Subscribe to market data using official Polymarket WebSocket API
140    pub async fn subscribe_async(&mut self, subscription: WssSubscription) -> Result<()> {
141        // Ensure connection
142        if self.connection.is_none() {
143            self.connect().await?;
144        }
145
146        // Send subscription message in the format expected by Polymarket
147        let message = serde_json::json!({
148            "auth": subscription.auth,
149            "markets": subscription.markets,
150            "asset_ids": subscription.asset_ids,
151            "type": subscription.channel_type,
152        });
153
154        self.send_message(message).await?;
155        self.subscriptions.push(subscription.clone());
156        
157        info!("Subscribed to {} channel", subscription.channel_type);
158        Ok(())
159    }
160
161    /// Subscribe to user channel (orders and trades)
162    pub async fn subscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
163        let auth = self.auth.as_ref()
164            .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
165            .clone();
166
167        let subscription = WssSubscription {
168            auth,
169            markets: Some(markets),
170            asset_ids: None,
171            channel_type: "USER".to_string(),
172        };
173
174        self.subscribe_async(subscription).await
175    }
176
177    /// Subscribe to market channel (order book and trades)
178    pub async fn subscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
179        let auth = self.auth.as_ref()
180            .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
181            .clone();
182
183        let subscription = WssSubscription {
184            auth,
185            markets: None,
186            asset_ids: Some(asset_ids),
187            channel_type: "MARKET".to_string(),
188        };
189
190        self.subscribe_async(subscription).await
191    }
192
193    /// Unsubscribe from market data
194    pub async fn unsubscribe_async(&mut self, token_ids: &[String]) -> Result<()> {
195        // Note: Polymarket WebSocket API doesn't seem to have explicit unsubscribe
196        // We'll just remove from our local subscriptions
197        self.subscriptions.retain(|sub| {
198            match sub.channel_type.as_str() {
199                "USER" => {
200                    if let Some(markets) = &sub.markets {
201                        !token_ids.iter().any(|id| markets.contains(id))
202                    } else {
203                        true
204                    }
205                }
206                "MARKET" => {
207                    if let Some(asset_ids) = &sub.asset_ids {
208                        !token_ids.iter().any(|id| asset_ids.contains(id))
209                    } else {
210                        true
211                    }
212                }
213                _ => true
214            }
215        });
216        
217        info!("Unsubscribed from {} tokens", token_ids.len());
218        Ok(())
219    }
220
221    /// Handle incoming WebSocket messages
222    async fn handle_message(&mut self, message: tokio_tungstenite::tungstenite::Message) -> Result<()> {
223        match message {
224            tokio_tungstenite::tungstenite::Message::Text(text) => {
225                debug!("Received WebSocket message: {}", text);
226                
227                // Parse the message according to Polymarket's format
228                let stream_message = self.parse_polymarket_message(&text)?;
229                
230                // Send to internal channel
231                if let Err(e) = self.tx.send(stream_message) {
232                    error!("Failed to send message to internal channel: {}", e);
233                }
234                
235                self.stats.messages_received += 1;
236                self.stats.last_message_time = Some(Utc::now());
237            }
238            tokio_tungstenite::tungstenite::Message::Close(_) => {
239                info!("WebSocket connection closed by server");
240                self.connection = None;
241            }
242            tokio_tungstenite::tungstenite::Message::Ping(data) => {
243                // Respond with pong
244                if let Some(connection) = &mut self.connection {
245                    let pong = tokio_tungstenite::tungstenite::Message::Pong(data);
246                    if let Err(e) = connection.send(pong).await {
247                        error!("Failed to send pong: {}", e);
248                    }
249                }
250            }
251            tokio_tungstenite::tungstenite::Message::Pong(_) => {
252                // Handle pong if needed
253                debug!("Received pong");
254            }
255            tokio_tungstenite::tungstenite::Message::Binary(_) => {
256                warn!("Received binary message (not supported)");
257            }
258            tokio_tungstenite::tungstenite::Message::Frame(_) => {
259                warn!("Received raw frame (not supported)");
260            }
261        }
262        
263        Ok(())
264    }
265
266    /// Parse Polymarket WebSocket message format
267    fn parse_polymarket_message(&self, text: &str) -> Result<StreamMessage> {
268        let value: Value = serde_json::from_str(text)
269            .map_err(|e| PolyfillError::parse(format!("Failed to parse WebSocket message: {}", e), Some(Box::new(e))))?;
270
271        // Extract message type
272        let message_type = value.get("type")
273            .and_then(|v| v.as_str())
274            .ok_or_else(|| PolyfillError::parse("Missing 'type' field in WebSocket message", None))?;
275
276        match message_type {
277            "book_update" => {
278                let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
279                    .map_err(|e| PolyfillError::parse(format!("Failed to parse book update: {}", e), Some(Box::new(e))))?;
280                Ok(StreamMessage::BookUpdate { data })
281            }
282            "trade" => {
283                let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
284                    .map_err(|e| PolyfillError::parse(format!("Failed to parse trade: {}", e), Some(Box::new(e))))?;
285                Ok(StreamMessage::Trade { data })
286            }
287            "order_update" => {
288                let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
289                    .map_err(|e| PolyfillError::parse(format!("Failed to parse order update: {}", e), Some(Box::new(e))))?;
290                Ok(StreamMessage::OrderUpdate { data })
291            }
292            "user_order_update" => {
293                let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
294                    .map_err(|e| PolyfillError::parse(format!("Failed to parse user order update: {}", e), Some(Box::new(e))))?;
295                Ok(StreamMessage::UserOrderUpdate { data })
296            }
297            "user_trade" => {
298                let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
299                    .map_err(|e| PolyfillError::parse(format!("Failed to parse user trade: {}", e), Some(Box::new(e))))?;
300                Ok(StreamMessage::UserTrade { data })
301            }
302            "market_book_update" => {
303                let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
304                    .map_err(|e| PolyfillError::parse(format!("Failed to parse market book update: {}", e), Some(Box::new(e))))?;
305                Ok(StreamMessage::MarketBookUpdate { data })
306            }
307            "market_trade" => {
308                let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
309                    .map_err(|e| PolyfillError::parse(format!("Failed to parse market trade: {}", e), Some(Box::new(e))))?;
310                Ok(StreamMessage::MarketTrade { data })
311            }
312            "heartbeat" => {
313                let timestamp = value.get("timestamp")
314                    .and_then(|v| v.as_u64())
315                    .map(|ts| chrono::DateTime::from_timestamp(ts as i64, 0).unwrap_or_default())
316                    .unwrap_or_else(Utc::now);
317                Ok(StreamMessage::Heartbeat { timestamp })
318            }
319            _ => {
320                warn!("Unknown message type: {}", message_type);
321                // Return heartbeat as fallback
322                Ok(StreamMessage::Heartbeat { timestamp: Utc::now() })
323            }
324        }
325    }
326
327    /// Reconnect with exponential backoff
328    async fn reconnect(&mut self) -> Result<()> {
329        let mut delay = self.reconnect_config.base_delay;
330        let mut retries = 0;
331
332        while retries < self.reconnect_config.max_retries {
333            warn!("Attempting to reconnect (attempt {})", retries + 1);
334            
335            match self.connect().await {
336                Ok(()) => {
337                    info!("Successfully reconnected");
338                    self.stats.reconnect_count += 1;
339                    
340                    // Resubscribe to all previous subscriptions
341                    let subscriptions = self.subscriptions.clone();
342                    for subscription in subscriptions {
343                        self.send_message(serde_json::to_value(subscription)?).await?;
344                    }
345                    
346                    return Ok(());
347                }
348                Err(e) => {
349                    error!("Reconnection attempt {} failed: {}", retries + 1, e);
350                    retries += 1;
351                    
352                    if retries < self.reconnect_config.max_retries {
353                        tokio::time::sleep(delay).await;
354                        delay = std::cmp::min(
355                            delay.mul_f64(self.reconnect_config.backoff_multiplier),
356                            self.reconnect_config.max_delay
357                        );
358                    }
359                }
360            }
361        }
362
363        Err(PolyfillError::stream(
364            format!("Failed to reconnect after {} attempts", self.reconnect_config.max_retries),
365            crate::errors::StreamErrorKind::ConnectionFailed
366        ))
367    }
368}
369
370impl Stream for WebSocketStream {
371    type Item = Result<StreamMessage>;
372
373    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
374        // First check internal channel
375        if let Poll::Ready(Some(message)) = self.rx.poll_recv(cx) {
376            return Poll::Ready(Some(Ok(message)));
377        }
378
379        // Then check WebSocket connection
380        if let Some(connection) = &mut self.connection {
381            match connection.poll_next_unpin(cx) {
382                Poll::Ready(Some(Ok(_message))) => {
383                    // Simplified message handling
384                    Poll::Ready(Some(Ok(StreamMessage::Heartbeat { timestamp: Utc::now() })))
385                }
386                Poll::Ready(Some(Err(e))) => {
387                    error!("WebSocket error: {}", e);
388                    self.stats.errors += 1;
389                    Poll::Ready(Some(Err(e.into())))
390                }
391                Poll::Ready(None) => {
392                    info!("WebSocket stream ended");
393                    Poll::Ready(None)
394                }
395                Poll::Pending => Poll::Pending,
396            }
397        } else {
398            Poll::Ready(None)
399        }
400    }
401}
402
403impl MarketStream for WebSocketStream {
404    fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
405        // This is for backward compatibility - use subscribe_async for new code
406        Ok(())
407    }
408
409    fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
410        // This is for backward compatibility - use unsubscribe_async for new code
411        Ok(())
412    }
413
414    fn is_connected(&self) -> bool {
415        self.connection.is_some()
416    }
417
418    fn get_stats(&self) -> StreamStats {
419        self.stats.clone()
420    }
421}
422
423/// Mock stream for testing
424#[derive(Debug)]
425pub struct MockStream {
426    messages: Vec<Result<StreamMessage>>,
427    index: usize,
428    connected: bool,
429}
430
431impl MockStream {
432    pub fn new() -> Self {
433        Self {
434            messages: Vec::new(),
435            index: 0,
436            connected: true,
437        }
438    }
439
440    pub fn add_message(&mut self, message: StreamMessage) {
441        self.messages.push(Ok(message));
442    }
443
444    pub fn add_error(&mut self, error: PolyfillError) {
445        self.messages.push(Err(error));
446    }
447
448    pub fn set_connected(&mut self, connected: bool) {
449        self.connected = connected;
450    }
451}
452
453impl Stream for MockStream {
454    type Item = Result<StreamMessage>;
455
456    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
457        if self.index >= self.messages.len() {
458            Poll::Ready(None)
459        } else {
460            let message = self.messages[self.index].clone();
461            self.index += 1;
462            Poll::Ready(Some(message))
463        }
464    }
465}
466
467impl MarketStream for MockStream {
468    fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
469        Ok(())
470    }
471
472    fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
473        Ok(())
474    }
475
476    fn is_connected(&self) -> bool {
477        self.connected
478    }
479
480    fn get_stats(&self) -> StreamStats {
481        StreamStats {
482            messages_received: self.messages.len() as u64,
483            messages_sent: 0,
484            errors: self.messages.iter().filter(|m| m.is_err()).count() as u64,
485            last_message_time: None,
486            connection_uptime: std::time::Duration::ZERO,
487            reconnect_count: 0,
488        }
489    }
490}
491
492/// Stream manager for handling multiple streams
493pub struct StreamManager {
494    streams: Vec<Box<dyn MarketStream>>,
495    message_tx: mpsc::UnboundedSender<StreamMessage>,
496    message_rx: mpsc::UnboundedReceiver<StreamMessage>,
497}
498
499impl StreamManager {
500    pub fn new() -> Self {
501        let (message_tx, message_rx) = mpsc::unbounded_channel();
502        
503        Self {
504            streams: Vec::new(),
505            message_tx,
506            message_rx,
507        }
508    }
509
510    pub fn add_stream(&mut self, stream: Box<dyn MarketStream>) {
511        self.streams.push(stream);
512    }
513
514    pub fn get_message_receiver(&mut self) -> mpsc::UnboundedReceiver<StreamMessage> {
515        // Note: UnboundedReceiver doesn't implement Clone
516        // In a real implementation, you'd want to use a different approach
517        // For now, we'll return a dummy receiver
518        let (_, rx) = mpsc::unbounded_channel();
519        rx
520    }
521
522    pub fn broadcast_message(&self, message: StreamMessage) -> Result<()> {
523        self.message_tx.send(message)
524            .map_err(|e| PolyfillError::internal("Failed to broadcast message", e))
525    }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh mock reports itself connected and counts every queued item.
    #[test]
    fn test_mock_stream() {
        let heartbeat = StreamMessage::Heartbeat { timestamp: Utc::now() };
        let book_update = StreamMessage::BookUpdate {
            data: OrderDelta {
                token_id: "test".to_string(),
                timestamp: Utc::now(),
                side: Side::BUY,
                price: rust_decimal_macros::dec!(0.5),
                size: rust_decimal_macros::dec!(100),
                sequence: 1,
            },
        };

        let mut mock = MockStream::new();
        for queued in [heartbeat, book_update] {
            mock.add_message(queued);
        }

        assert!(mock.is_connected());
        assert_eq!(mock.get_stats().messages_received, 2);
    }

    /// Broadcasting succeeds while the manager still owns its receiver.
    #[test]
    fn test_stream_manager() {
        let mut manager = StreamManager::new();
        manager.add_stream(Box::new(MockStream::new()));

        let outcome = manager.broadcast_message(StreamMessage::Heartbeat { timestamp: Utc::now() });
        assert!(outcome.is_ok());
    }
}