1use crate::errors::{PolyfillError, Result};
7use crate::types::*;
8use futures::{Stream, SinkExt, StreamExt};
9use serde_json::Value;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use tokio::sync::mpsc;
13use tracing::{debug, error, info, warn};
14use chrono::Utc;
15
16pub trait MarketStream: Stream<Item = Result<StreamMessage>> + Send + Sync {
18 fn subscribe(&mut self, subscription: Subscription) -> Result<()>;
20
21 fn unsubscribe(&mut self, token_ids: &[String]) -> Result<()>;
23
24 fn is_connected(&self) -> bool;
26
27 fn get_stats(&self) -> StreamStats;
29}
30
31#[derive(Debug)]
33pub struct WebSocketStream {
34 connection: Option<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>,
36 url: String,
38 auth: Option<WssAuth>,
40 subscriptions: Vec<WssSubscription>,
42 tx: mpsc::UnboundedSender<StreamMessage>,
44 rx: mpsc::UnboundedReceiver<StreamMessage>,
46 stats: StreamStats,
48 reconnect_config: ReconnectConfig,
50}
51
52#[derive(Debug, Clone)]
54pub struct StreamStats {
55 pub messages_received: u64,
56 pub messages_sent: u64,
57 pub errors: u64,
58 pub last_message_time: Option<chrono::DateTime<Utc>>,
59 pub connection_uptime: std::time::Duration,
60 pub reconnect_count: u32,
61}
62
/// Exponential-backoff policy for re-establishing a dropped connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of connection attempts before giving up.
    pub max_retries: u32,
    /// Wait applied after the first failed attempt.
    pub base_delay: std::time::Duration,
    /// Upper bound on the wait between attempts.
    pub max_delay: std::time::Duration,
    /// Factor the wait is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}
71
72impl Default for ReconnectConfig {
73 fn default() -> Self {
74 Self {
75 max_retries: 5,
76 base_delay: std::time::Duration::from_secs(1),
77 max_delay: std::time::Duration::from_secs(60),
78 backoff_multiplier: 2.0,
79 }
80 }
81}
82
83impl WebSocketStream {
84 pub fn new(url: &str) -> Self {
86 let (tx, rx) = mpsc::unbounded_channel();
87
88 Self {
89 connection: None,
90 url: url.to_string(),
91 auth: None,
92 subscriptions: Vec::new(),
93 tx,
94 rx,
95 stats: StreamStats {
96 messages_received: 0,
97 messages_sent: 0,
98 errors: 0,
99 last_message_time: None,
100 connection_uptime: std::time::Duration::ZERO,
101 reconnect_count: 0,
102 },
103 reconnect_config: ReconnectConfig::default(),
104 }
105 }
106
107 pub fn with_auth(mut self, auth: WssAuth) -> Self {
109 self.auth = Some(auth);
110 self
111 }
112
113 async fn connect(&mut self) -> Result<()> {
115 let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url).await
116 .map_err(|e| PolyfillError::stream(format!("WebSocket connection failed: {}", e), crate::errors::StreamErrorKind::ConnectionFailed))?;
117
118 self.connection = Some(ws_stream);
119 info!("Connected to WebSocket stream at {}", self.url);
120 Ok(())
121 }
122
123 async fn send_message(&mut self, message: Value) -> Result<()> {
125 if let Some(connection) = &mut self.connection {
126 let text = serde_json::to_string(&message)
127 .map_err(|e| PolyfillError::parse(format!("Failed to serialize message: {}", e), None))?;
128
129 let ws_message = tokio_tungstenite::tungstenite::Message::Text(text);
130 connection.send(ws_message).await
131 .map_err(|e| PolyfillError::stream(format!("Failed to send message: {}", e), crate::errors::StreamErrorKind::MessageCorrupted))?;
132
133 self.stats.messages_sent += 1;
134 }
135
136 Ok(())
137 }
138
139 pub async fn subscribe_async(&mut self, subscription: WssSubscription) -> Result<()> {
141 if self.connection.is_none() {
143 self.connect().await?;
144 }
145
146 let message = serde_json::json!({
148 "auth": subscription.auth,
149 "markets": subscription.markets,
150 "asset_ids": subscription.asset_ids,
151 "type": subscription.channel_type,
152 });
153
154 self.send_message(message).await?;
155 self.subscriptions.push(subscription.clone());
156
157 info!("Subscribed to {} channel", subscription.channel_type);
158 Ok(())
159 }
160
161 pub async fn subscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
163 let auth = self.auth.as_ref()
164 .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
165 .clone();
166
167 let subscription = WssSubscription {
168 auth,
169 markets: Some(markets),
170 asset_ids: None,
171 channel_type: "USER".to_string(),
172 };
173
174 self.subscribe_async(subscription).await
175 }
176
177 pub async fn subscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
179 let auth = self.auth.as_ref()
180 .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
181 .clone();
182
183 let subscription = WssSubscription {
184 auth,
185 markets: None,
186 asset_ids: Some(asset_ids),
187 channel_type: "MARKET".to_string(),
188 };
189
190 self.subscribe_async(subscription).await
191 }
192
193 pub async fn unsubscribe_async(&mut self, token_ids: &[String]) -> Result<()> {
195 self.subscriptions.retain(|sub| {
198 match sub.channel_type.as_str() {
199 "USER" => {
200 if let Some(markets) = &sub.markets {
201 !token_ids.iter().any(|id| markets.contains(id))
202 } else {
203 true
204 }
205 }
206 "MARKET" => {
207 if let Some(asset_ids) = &sub.asset_ids {
208 !token_ids.iter().any(|id| asset_ids.contains(id))
209 } else {
210 true
211 }
212 }
213 _ => true
214 }
215 });
216
217 info!("Unsubscribed from {} tokens", token_ids.len());
218 Ok(())
219 }
220
221 async fn handle_message(&mut self, message: tokio_tungstenite::tungstenite::Message) -> Result<()> {
223 match message {
224 tokio_tungstenite::tungstenite::Message::Text(text) => {
225 debug!("Received WebSocket message: {}", text);
226
227 let stream_message = self.parse_polymarket_message(&text)?;
229
230 if let Err(e) = self.tx.send(stream_message) {
232 error!("Failed to send message to internal channel: {}", e);
233 }
234
235 self.stats.messages_received += 1;
236 self.stats.last_message_time = Some(Utc::now());
237 }
238 tokio_tungstenite::tungstenite::Message::Close(_) => {
239 info!("WebSocket connection closed by server");
240 self.connection = None;
241 }
242 tokio_tungstenite::tungstenite::Message::Ping(data) => {
243 if let Some(connection) = &mut self.connection {
245 let pong = tokio_tungstenite::tungstenite::Message::Pong(data);
246 if let Err(e) = connection.send(pong).await {
247 error!("Failed to send pong: {}", e);
248 }
249 }
250 }
251 tokio_tungstenite::tungstenite::Message::Pong(_) => {
252 debug!("Received pong");
254 }
255 tokio_tungstenite::tungstenite::Message::Binary(_) => {
256 warn!("Received binary message (not supported)");
257 }
258 tokio_tungstenite::tungstenite::Message::Frame(_) => {
259 warn!("Received raw frame (not supported)");
260 }
261 }
262
263 Ok(())
264 }
265
266 fn parse_polymarket_message(&self, text: &str) -> Result<StreamMessage> {
268 let value: Value = serde_json::from_str(text)
269 .map_err(|e| PolyfillError::parse(format!("Failed to parse WebSocket message: {}", e), Some(Box::new(e))))?;
270
271 let message_type = value.get("type")
273 .and_then(|v| v.as_str())
274 .ok_or_else(|| PolyfillError::parse("Missing 'type' field in WebSocket message", None))?;
275
276 match message_type {
277 "book_update" => {
278 let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
279 .map_err(|e| PolyfillError::parse(format!("Failed to parse book update: {}", e), Some(Box::new(e))))?;
280 Ok(StreamMessage::BookUpdate { data })
281 }
282 "trade" => {
283 let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
284 .map_err(|e| PolyfillError::parse(format!("Failed to parse trade: {}", e), Some(Box::new(e))))?;
285 Ok(StreamMessage::Trade { data })
286 }
287 "order_update" => {
288 let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
289 .map_err(|e| PolyfillError::parse(format!("Failed to parse order update: {}", e), Some(Box::new(e))))?;
290 Ok(StreamMessage::OrderUpdate { data })
291 }
292 "user_order_update" => {
293 let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
294 .map_err(|e| PolyfillError::parse(format!("Failed to parse user order update: {}", e), Some(Box::new(e))))?;
295 Ok(StreamMessage::UserOrderUpdate { data })
296 }
297 "user_trade" => {
298 let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
299 .map_err(|e| PolyfillError::parse(format!("Failed to parse user trade: {}", e), Some(Box::new(e))))?;
300 Ok(StreamMessage::UserTrade { data })
301 }
302 "market_book_update" => {
303 let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
304 .map_err(|e| PolyfillError::parse(format!("Failed to parse market book update: {}", e), Some(Box::new(e))))?;
305 Ok(StreamMessage::MarketBookUpdate { data })
306 }
307 "market_trade" => {
308 let data = serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
309 .map_err(|e| PolyfillError::parse(format!("Failed to parse market trade: {}", e), Some(Box::new(e))))?;
310 Ok(StreamMessage::MarketTrade { data })
311 }
312 "heartbeat" => {
313 let timestamp = value.get("timestamp")
314 .and_then(|v| v.as_u64())
315 .map(|ts| chrono::DateTime::from_timestamp(ts as i64, 0).unwrap_or_default())
316 .unwrap_or_else(Utc::now);
317 Ok(StreamMessage::Heartbeat { timestamp })
318 }
319 _ => {
320 warn!("Unknown message type: {}", message_type);
321 Ok(StreamMessage::Heartbeat { timestamp: Utc::now() })
323 }
324 }
325 }
326
327 async fn reconnect(&mut self) -> Result<()> {
329 let mut delay = self.reconnect_config.base_delay;
330 let mut retries = 0;
331
332 while retries < self.reconnect_config.max_retries {
333 warn!("Attempting to reconnect (attempt {})", retries + 1);
334
335 match self.connect().await {
336 Ok(()) => {
337 info!("Successfully reconnected");
338 self.stats.reconnect_count += 1;
339
340 let subscriptions = self.subscriptions.clone();
342 for subscription in subscriptions {
343 self.send_message(serde_json::to_value(subscription)?).await?;
344 }
345
346 return Ok(());
347 }
348 Err(e) => {
349 error!("Reconnection attempt {} failed: {}", retries + 1, e);
350 retries += 1;
351
352 if retries < self.reconnect_config.max_retries {
353 tokio::time::sleep(delay).await;
354 delay = std::cmp::min(
355 delay.mul_f64(self.reconnect_config.backoff_multiplier),
356 self.reconnect_config.max_delay
357 );
358 }
359 }
360 }
361 }
362
363 Err(PolyfillError::stream(
364 format!("Failed to reconnect after {} attempts", self.reconnect_config.max_retries),
365 crate::errors::StreamErrorKind::ConnectionFailed
366 ))
367 }
368}
369
370impl Stream for WebSocketStream {
371 type Item = Result<StreamMessage>;
372
373 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
374 if let Poll::Ready(Some(message)) = self.rx.poll_recv(cx) {
376 return Poll::Ready(Some(Ok(message)));
377 }
378
379 if let Some(connection) = &mut self.connection {
381 match connection.poll_next_unpin(cx) {
382 Poll::Ready(Some(Ok(_message))) => {
383 Poll::Ready(Some(Ok(StreamMessage::Heartbeat { timestamp: Utc::now() })))
385 }
386 Poll::Ready(Some(Err(e))) => {
387 error!("WebSocket error: {}", e);
388 self.stats.errors += 1;
389 Poll::Ready(Some(Err(e.into())))
390 }
391 Poll::Ready(None) => {
392 info!("WebSocket stream ended");
393 Poll::Ready(None)
394 }
395 Poll::Pending => Poll::Pending,
396 }
397 } else {
398 Poll::Ready(None)
399 }
400 }
401}
402
403impl MarketStream for WebSocketStream {
404 fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
405 Ok(())
407 }
408
409 fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
410 Ok(())
412 }
413
414 fn is_connected(&self) -> bool {
415 self.connection.is_some()
416 }
417
418 fn get_stats(&self) -> StreamStats {
419 self.stats.clone()
420 }
421}
422
423#[derive(Debug)]
425pub struct MockStream {
426 messages: Vec<Result<StreamMessage>>,
427 index: usize,
428 connected: bool,
429}
430
431impl MockStream {
432 pub fn new() -> Self {
433 Self {
434 messages: Vec::new(),
435 index: 0,
436 connected: true,
437 }
438 }
439
440 pub fn add_message(&mut self, message: StreamMessage) {
441 self.messages.push(Ok(message));
442 }
443
444 pub fn add_error(&mut self, error: PolyfillError) {
445 self.messages.push(Err(error));
446 }
447
448 pub fn set_connected(&mut self, connected: bool) {
449 self.connected = connected;
450 }
451}
452
453impl Stream for MockStream {
454 type Item = Result<StreamMessage>;
455
456 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
457 if self.index >= self.messages.len() {
458 Poll::Ready(None)
459 } else {
460 let message = self.messages[self.index].clone();
461 self.index += 1;
462 Poll::Ready(Some(message))
463 }
464 }
465}
466
467impl MarketStream for MockStream {
468 fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
469 Ok(())
470 }
471
472 fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
473 Ok(())
474 }
475
476 fn is_connected(&self) -> bool {
477 self.connected
478 }
479
480 fn get_stats(&self) -> StreamStats {
481 StreamStats {
482 messages_received: self.messages.len() as u64,
483 messages_sent: 0,
484 errors: self.messages.iter().filter(|m| m.is_err()).count() as u64,
485 last_message_time: None,
486 connection_uptime: std::time::Duration::ZERO,
487 reconnect_count: 0,
488 }
489 }
490}
491
492pub struct StreamManager {
494 streams: Vec<Box<dyn MarketStream>>,
495 message_tx: mpsc::UnboundedSender<StreamMessage>,
496 message_rx: mpsc::UnboundedReceiver<StreamMessage>,
497}
498
499impl StreamManager {
500 pub fn new() -> Self {
501 let (message_tx, message_rx) = mpsc::unbounded_channel();
502
503 Self {
504 streams: Vec::new(),
505 message_tx,
506 message_rx,
507 }
508 }
509
510 pub fn add_stream(&mut self, stream: Box<dyn MarketStream>) {
511 self.streams.push(stream);
512 }
513
514 pub fn get_message_receiver(&mut self) -> mpsc::UnboundedReceiver<StreamMessage> {
515 let (_, rx) = mpsc::unbounded_channel();
519 rx
520 }
521
522 pub fn broadcast_message(&self, message: StreamMessage) -> Result<()> {
523 self.message_tx.send(message)
524 .map_err(|e| PolyfillError::internal("Failed to broadcast message", e))
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal book delta used as a non-heartbeat fixture.
    fn sample_delta() -> OrderDelta {
        OrderDelta {
            token_id: "test".to_string(),
            timestamp: Utc::now(),
            side: Side::BUY,
            price: rust_decimal_macros::dec!(0.5),
            size: rust_decimal_macros::dec!(100),
            sequence: 1,
        }
    }

    #[test]
    fn test_mock_stream() {
        let mut mock = MockStream::new();
        mock.add_message(StreamMessage::Heartbeat { timestamp: Utc::now() });
        mock.add_message(StreamMessage::BookUpdate { data: sample_delta() });

        assert!(mock.is_connected());
        assert_eq!(mock.get_stats().messages_received, 2);
    }

    #[test]
    fn test_stream_manager() {
        let mut manager = StreamManager::new();
        manager.add_stream(Box::new(MockStream::new()));

        let heartbeat = StreamMessage::Heartbeat { timestamp: Utc::now() };
        assert!(manager.broadcast_message(heartbeat).is_ok());
    }
}