1use crate::errors::{PolyfillError, Result};
7use crate::types::*;
8use chrono::Utc;
9use futures::{SinkExt, Stream, StreamExt};
10use serde_json::Value;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use tokio::sync::mpsc;
14use tracing::{debug, error, info, warn};
15
16pub trait MarketStream: Stream<Item = Result<StreamMessage>> + Send + Sync {
18 fn subscribe(&mut self, subscription: Subscription) -> Result<()>;
20
21 fn unsubscribe(&mut self, token_ids: &[String]) -> Result<()>;
23
24 fn is_connected(&self) -> bool;
26
27 fn get_stats(&self) -> StreamStats;
29}
30
31#[derive(Debug)]
33#[allow(dead_code)]
34pub struct WebSocketStream {
35 connection: Option<
37 tokio_tungstenite::WebSocketStream<
38 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
39 >,
40 >,
41 url: String,
43 auth: Option<WssAuth>,
45 subscriptions: Vec<WssSubscription>,
47 tx: mpsc::UnboundedSender<StreamMessage>,
49 rx: mpsc::UnboundedReceiver<StreamMessage>,
51 stats: StreamStats,
53 reconnect_config: ReconnectConfig,
55}
56
57#[derive(Debug, Clone)]
59pub struct StreamStats {
60 pub messages_received: u64,
61 pub messages_sent: u64,
62 pub errors: u64,
63 pub last_message_time: Option<chrono::DateTime<Utc>>,
64 pub connection_uptime: std::time::Duration,
65 pub reconnect_count: u32,
66}
67
/// Exponential-backoff settings for re-establishing a dropped connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Maximum number of reconnect attempts before giving up.
    pub max_retries: u32,
    /// Wait after the first failed attempt.
    pub base_delay: std::time::Duration,
    /// Upper bound for the wait between attempts.
    pub max_delay: std::time::Duration,
    /// Factor applied to the wait after each failed attempt.
    pub backoff_multiplier: f64,
}
76
77impl Default for ReconnectConfig {
78 fn default() -> Self {
79 Self {
80 max_retries: 5,
81 base_delay: std::time::Duration::from_secs(1),
82 max_delay: std::time::Duration::from_secs(60),
83 backoff_multiplier: 2.0,
84 }
85 }
86}
87
88impl WebSocketStream {
89 pub fn new(url: &str) -> Self {
91 let (tx, rx) = mpsc::unbounded_channel();
92
93 Self {
94 connection: None,
95 url: url.to_string(),
96 auth: None,
97 subscriptions: Vec::new(),
98 tx,
99 rx,
100 stats: StreamStats {
101 messages_received: 0,
102 messages_sent: 0,
103 errors: 0,
104 last_message_time: None,
105 connection_uptime: std::time::Duration::ZERO,
106 reconnect_count: 0,
107 },
108 reconnect_config: ReconnectConfig::default(),
109 }
110 }
111
112 pub fn with_auth(mut self, auth: WssAuth) -> Self {
114 self.auth = Some(auth);
115 self
116 }
117
118 async fn connect(&mut self) -> Result<()> {
120 let (ws_stream, _) = tokio_tungstenite::connect_async(&self.url)
121 .await
122 .map_err(|e| {
123 PolyfillError::stream(
124 format!("WebSocket connection failed: {}", e),
125 crate::errors::StreamErrorKind::ConnectionFailed,
126 )
127 })?;
128
129 self.connection = Some(ws_stream);
130 info!("Connected to WebSocket stream at {}", self.url);
131 Ok(())
132 }
133
134 async fn send_message(&mut self, message: Value) -> Result<()> {
136 if let Some(connection) = &mut self.connection {
137 let text = serde_json::to_string(&message).map_err(|e| {
138 PolyfillError::parse(format!("Failed to serialize message: {}", e), None)
139 })?;
140
141 let ws_message = tokio_tungstenite::tungstenite::Message::Text(text);
142 connection.send(ws_message).await.map_err(|e| {
143 PolyfillError::stream(
144 format!("Failed to send message: {}", e),
145 crate::errors::StreamErrorKind::MessageCorrupted,
146 )
147 })?;
148
149 self.stats.messages_sent += 1;
150 }
151
152 Ok(())
153 }
154
155 pub async fn subscribe_async(&mut self, subscription: WssSubscription) -> Result<()> {
157 if self.connection.is_none() {
159 self.connect().await?;
160 }
161
162 let message = serde_json::json!({
164 "auth": subscription.auth,
165 "markets": subscription.markets,
166 "asset_ids": subscription.asset_ids,
167 "type": subscription.channel_type,
168 });
169
170 self.send_message(message).await?;
171 self.subscriptions.push(subscription.clone());
172
173 info!("Subscribed to {} channel", subscription.channel_type);
174 Ok(())
175 }
176
177 pub async fn subscribe_user_channel(&mut self, markets: Vec<String>) -> Result<()> {
179 let auth = self
180 .auth
181 .as_ref()
182 .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
183 .clone();
184
185 let subscription = WssSubscription {
186 auth,
187 markets: Some(markets),
188 asset_ids: None,
189 channel_type: "USER".to_string(),
190 };
191
192 self.subscribe_async(subscription).await
193 }
194
195 pub async fn subscribe_market_channel(&mut self, asset_ids: Vec<String>) -> Result<()> {
197 let auth = self
198 .auth
199 .as_ref()
200 .ok_or_else(|| PolyfillError::auth("No authentication provided for WebSocket"))?
201 .clone();
202
203 let subscription = WssSubscription {
204 auth,
205 markets: None,
206 asset_ids: Some(asset_ids),
207 channel_type: "MARKET".to_string(),
208 };
209
210 self.subscribe_async(subscription).await
211 }
212
213 pub async fn unsubscribe_async(&mut self, token_ids: &[String]) -> Result<()> {
215 self.subscriptions
218 .retain(|sub| match sub.channel_type.as_str() {
219 "USER" => {
220 if let Some(markets) = &sub.markets {
221 !token_ids.iter().any(|id| markets.contains(id))
222 } else {
223 true
224 }
225 },
226 "MARKET" => {
227 if let Some(asset_ids) = &sub.asset_ids {
228 !token_ids.iter().any(|id| asset_ids.contains(id))
229 } else {
230 true
231 }
232 },
233 _ => true,
234 });
235
236 info!("Unsubscribed from {} tokens", token_ids.len());
237 Ok(())
238 }
239
240 #[allow(dead_code)]
242 async fn handle_message(
243 &mut self,
244 message: tokio_tungstenite::tungstenite::Message,
245 ) -> Result<()> {
246 match message {
247 tokio_tungstenite::tungstenite::Message::Text(text) => {
248 debug!("Received WebSocket message: {}", text);
249
250 let stream_message = self.parse_polymarket_message(&text)?;
252
253 if let Err(e) = self.tx.send(stream_message) {
255 error!("Failed to send message to internal channel: {}", e);
256 }
257
258 self.stats.messages_received += 1;
259 self.stats.last_message_time = Some(Utc::now());
260 },
261 tokio_tungstenite::tungstenite::Message::Close(_) => {
262 info!("WebSocket connection closed by server");
263 self.connection = None;
264 },
265 tokio_tungstenite::tungstenite::Message::Ping(data) => {
266 if let Some(connection) = &mut self.connection {
268 let pong = tokio_tungstenite::tungstenite::Message::Pong(data);
269 if let Err(e) = connection.send(pong).await {
270 error!("Failed to send pong: {}", e);
271 }
272 }
273 },
274 tokio_tungstenite::tungstenite::Message::Pong(_) => {
275 debug!("Received pong");
277 },
278 tokio_tungstenite::tungstenite::Message::Binary(_) => {
279 warn!("Received binary message (not supported)");
280 },
281 tokio_tungstenite::tungstenite::Message::Frame(_) => {
282 warn!("Received raw frame (not supported)");
283 },
284 }
285
286 Ok(())
287 }
288
289 #[allow(dead_code)]
291 fn parse_polymarket_message(&self, text: &str) -> Result<StreamMessage> {
292 let value: Value = serde_json::from_str(text).map_err(|e| {
293 PolyfillError::parse(
294 format!("Failed to parse WebSocket message: {}", e),
295 Some(Box::new(e)),
296 )
297 })?;
298
299 let message_type = value.get("type").and_then(|v| v.as_str()).ok_or_else(|| {
301 PolyfillError::parse("Missing 'type' field in WebSocket message", None)
302 })?;
303
304 match message_type {
305 "book_update" => {
306 let data =
307 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
308 .map_err(|e| {
309 PolyfillError::parse(
310 format!("Failed to parse book update: {}", e),
311 Some(Box::new(e)),
312 )
313 })?;
314 Ok(StreamMessage::BookUpdate { data })
315 },
316 "trade" => {
317 let data =
318 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
319 .map_err(|e| {
320 PolyfillError::parse(
321 format!("Failed to parse trade: {}", e),
322 Some(Box::new(e)),
323 )
324 })?;
325 Ok(StreamMessage::Trade { data })
326 },
327 "order_update" => {
328 let data =
329 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
330 .map_err(|e| {
331 PolyfillError::parse(
332 format!("Failed to parse order update: {}", e),
333 Some(Box::new(e)),
334 )
335 })?;
336 Ok(StreamMessage::OrderUpdate { data })
337 },
338 "user_order_update" => {
339 let data =
340 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
341 .map_err(|e| {
342 PolyfillError::parse(
343 format!("Failed to parse user order update: {}", e),
344 Some(Box::new(e)),
345 )
346 })?;
347 Ok(StreamMessage::UserOrderUpdate { data })
348 },
349 "user_trade" => {
350 let data =
351 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
352 .map_err(|e| {
353 PolyfillError::parse(
354 format!("Failed to parse user trade: {}", e),
355 Some(Box::new(e)),
356 )
357 })?;
358 Ok(StreamMessage::UserTrade { data })
359 },
360 "market_book_update" => {
361 let data =
362 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
363 .map_err(|e| {
364 PolyfillError::parse(
365 format!("Failed to parse market book update: {}", e),
366 Some(Box::new(e)),
367 )
368 })?;
369 Ok(StreamMessage::MarketBookUpdate { data })
370 },
371 "market_trade" => {
372 let data =
373 serde_json::from_value(value.get("data").unwrap_or(&Value::Null).clone())
374 .map_err(|e| {
375 PolyfillError::parse(
376 format!("Failed to parse market trade: {}", e),
377 Some(Box::new(e)),
378 )
379 })?;
380 Ok(StreamMessage::MarketTrade { data })
381 },
382 "heartbeat" => {
383 let timestamp = value
384 .get("timestamp")
385 .and_then(|v| v.as_u64())
386 .map(|ts| chrono::DateTime::from_timestamp(ts as i64, 0).unwrap_or_default())
387 .unwrap_or_else(Utc::now);
388 Ok(StreamMessage::Heartbeat { timestamp })
389 },
390 _ => {
391 warn!("Unknown message type: {}", message_type);
392 Ok(StreamMessage::Heartbeat {
394 timestamp: Utc::now(),
395 })
396 },
397 }
398 }
399
400 #[allow(dead_code)]
402 async fn reconnect(&mut self) -> Result<()> {
403 let mut delay = self.reconnect_config.base_delay;
404 let mut retries = 0;
405
406 while retries < self.reconnect_config.max_retries {
407 warn!("Attempting to reconnect (attempt {})", retries + 1);
408
409 match self.connect().await {
410 Ok(()) => {
411 info!("Successfully reconnected");
412 self.stats.reconnect_count += 1;
413
414 let subscriptions = self.subscriptions.clone();
416 for subscription in subscriptions {
417 self.send_message(serde_json::to_value(subscription)?)
418 .await?;
419 }
420
421 return Ok(());
422 },
423 Err(e) => {
424 error!("Reconnection attempt {} failed: {}", retries + 1, e);
425 retries += 1;
426
427 if retries < self.reconnect_config.max_retries {
428 tokio::time::sleep(delay).await;
429 delay = std::cmp::min(
430 delay.mul_f64(self.reconnect_config.backoff_multiplier),
431 self.reconnect_config.max_delay,
432 );
433 }
434 },
435 }
436 }
437
438 Err(PolyfillError::stream(
439 format!(
440 "Failed to reconnect after {} attempts",
441 self.reconnect_config.max_retries
442 ),
443 crate::errors::StreamErrorKind::ConnectionFailed,
444 ))
445 }
446}
447
448impl Stream for WebSocketStream {
449 type Item = Result<StreamMessage>;
450
451 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
452 if let Poll::Ready(Some(message)) = self.rx.poll_recv(cx) {
454 return Poll::Ready(Some(Ok(message)));
455 }
456
457 if let Some(connection) = &mut self.connection {
459 match connection.poll_next_unpin(cx) {
460 Poll::Ready(Some(Ok(_message))) => {
461 Poll::Ready(Some(Ok(StreamMessage::Heartbeat {
463 timestamp: Utc::now(),
464 })))
465 },
466 Poll::Ready(Some(Err(e))) => {
467 error!("WebSocket error: {}", e);
468 self.stats.errors += 1;
469 Poll::Ready(Some(Err(e.into())))
470 },
471 Poll::Ready(None) => {
472 info!("WebSocket stream ended");
473 Poll::Ready(None)
474 },
475 Poll::Pending => Poll::Pending,
476 }
477 } else {
478 Poll::Ready(None)
479 }
480 }
481}
482
483impl MarketStream for WebSocketStream {
484 fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
485 Ok(())
487 }
488
489 fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
490 Ok(())
492 }
493
494 fn is_connected(&self) -> bool {
495 self.connection.is_some()
496 }
497
498 fn get_stats(&self) -> StreamStats {
499 self.stats.clone()
500 }
501}
502
503#[derive(Debug)]
505pub struct MockStream {
506 messages: Vec<Result<StreamMessage>>,
507 index: usize,
508 connected: bool,
509}
510
511impl Default for MockStream {
512 fn default() -> Self {
513 Self::new()
514 }
515}
516
517impl MockStream {
518 pub fn new() -> Self {
519 Self {
520 messages: Vec::new(),
521 index: 0,
522 connected: true,
523 }
524 }
525
526 pub fn add_message(&mut self, message: StreamMessage) {
527 self.messages.push(Ok(message));
528 }
529
530 pub fn add_error(&mut self, error: PolyfillError) {
531 self.messages.push(Err(error));
532 }
533
534 pub fn set_connected(&mut self, connected: bool) {
535 self.connected = connected;
536 }
537}
538
539impl Stream for MockStream {
540 type Item = Result<StreamMessage>;
541
542 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
543 if self.index >= self.messages.len() {
544 Poll::Ready(None)
545 } else {
546 let message = self.messages[self.index].clone();
547 self.index += 1;
548 Poll::Ready(Some(message))
549 }
550 }
551}
552
553impl MarketStream for MockStream {
554 fn subscribe(&mut self, _subscription: Subscription) -> Result<()> {
555 Ok(())
556 }
557
558 fn unsubscribe(&mut self, _token_ids: &[String]) -> Result<()> {
559 Ok(())
560 }
561
562 fn is_connected(&self) -> bool {
563 self.connected
564 }
565
566 fn get_stats(&self) -> StreamStats {
567 StreamStats {
568 messages_received: self.messages.len() as u64,
569 messages_sent: 0,
570 errors: self.messages.iter().filter(|m| m.is_err()).count() as u64,
571 last_message_time: None,
572 connection_uptime: std::time::Duration::ZERO,
573 reconnect_count: 0,
574 }
575 }
576}
577
578#[allow(dead_code)]
580pub struct StreamManager {
581 streams: Vec<Box<dyn MarketStream>>,
582 message_tx: mpsc::UnboundedSender<StreamMessage>,
583 message_rx: mpsc::UnboundedReceiver<StreamMessage>,
584}
585
586impl Default for StreamManager {
587 fn default() -> Self {
588 Self::new()
589 }
590}
591
592impl StreamManager {
593 pub fn new() -> Self {
594 let (message_tx, message_rx) = mpsc::unbounded_channel();
595
596 Self {
597 streams: Vec::new(),
598 message_tx,
599 message_rx,
600 }
601 }
602
603 pub fn add_stream(&mut self, stream: Box<dyn MarketStream>) {
604 self.streams.push(stream);
605 }
606
607 pub fn get_message_receiver(&mut self) -> mpsc::UnboundedReceiver<StreamMessage> {
608 let (_, rx) = mpsc::unbounded_channel();
612 rx
613 }
614
615 pub fn broadcast_message(&self, message: StreamMessage) -> Result<()> {
616 self.message_tx
617 .send(message)
618 .map_err(|e| PolyfillError::internal("Failed to broadcast message", e))
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a heartbeat stamped with the current time.
    fn heartbeat_now() -> StreamMessage {
        StreamMessage::Heartbeat {
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn test_mock_stream() {
        let mut mock = MockStream::new();
        mock.add_message(heartbeat_now());

        let delta = OrderDelta {
            token_id: "test".to_string(),
            timestamp: Utc::now(),
            side: Side::BUY,
            price: rust_decimal_macros::dec!(0.5),
            size: rust_decimal_macros::dec!(100),
            sequence: 1,
        };
        mock.add_message(StreamMessage::BookUpdate { data: delta });

        assert!(mock.is_connected());
        assert_eq!(mock.get_stats().messages_received, 2);
    }

    #[test]
    fn test_stream_manager() {
        let mut manager = StreamManager::new();
        manager.add_stream(Box::new(MockStream::new()));

        let outcome = manager.broadcast_message(heartbeat_now());
        assert!(outcome.is_ok());
    }
}