polynode 0.13.10

Rust SDK for the PolyNode API — real-time Polymarket data
Documentation
//! WebSocket stream — connects, reads events, handles reconnection.

use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;

use super::codec::decode_frame;
use super::subscription::Subscription;
use crate::error::{Error, Result};
use crate::types::events::{PolyNodeEvent, PriceFeedEvent};
use crate::types::ws_messages::{RawWsMessage, WsMessage};

/// Options for the WebSocket stream.
///
/// Start from `StreamOptions::default()` and override individual fields with
/// struct-update syntax, e.g. `StreamOptions { compress: false, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct StreamOptions {
    /// Ask the server for zlib-compressed frames (adds `compress=zlib` to the URL).
    pub compress: bool,
    /// Reconnect automatically after a failed or dropped connection.
    pub auto_reconnect: bool,
    /// Limit on consecutive reconnect attempts; `None` means retry without limit.
    pub max_reconnect_attempts: Option<u32>,
    /// Base delay before the first reconnect attempt; doubled for each further
    /// consecutive attempt, with jitter applied on top.
    pub initial_backoff: Duration,
    /// Ceiling for the doubled backoff delay (applied before jitter).
    pub max_backoff: Duration,
}

impl Default for StreamOptions {
    /// Compression on, unlimited automatic reconnects, backoff from 1s up to 30s.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(1_000);
        let max_backoff = Duration::from_millis(30_000);
        Self {
            initial_backoff,
            max_backoff,
            max_reconnect_attempts: None,
            auto_reconnect: true,
            compress: true,
        }
    }
}

/// Requests sent from a `WsStream` handle to its background connection task.
enum Command {
    /// Send a subscribe message; the subscription is also recorded so it can be
    /// replayed after a reconnect.
    Subscribe(Subscription),
    /// Send an unsubscribe message for one subscription id, or for every
    /// subscription when `None` (which also forgets all recorded subscriptions).
    Unsubscribe(Option<String>),
    /// Send a WebSocket close frame and stop the background task.
    Close,
}

/// A WebSocket event stream from PolyNode.
///
/// This is a handle to a background task spawned by `WsStream::connect`:
/// messages produced by the task arrive on `rx`, and subscription changes are
/// forwarded to it over `cmd_tx`.
pub struct WsStream {
    /// Parsed messages and errors produced by the background task.
    rx: mpsc::Receiver<Result<WsMessage>>,
    /// Command channel to the background task.
    cmd_tx: mpsc::Sender<Command>,
    /// Handle of the spawned task. It is held but never awaited or aborted;
    /// dropping a tokio `JoinHandle` detaches the task rather than cancelling it.
    _handle: tokio::task::JoinHandle<()>,
}

impl WsStream {
    /// Spawns the background connection task and returns a handle to it.
    ///
    /// The WebSocket handshake happens inside the spawned task, so this never
    /// fails and never waits on the network; connection errors are delivered
    /// later through [`WsStream::next`] as `Err` items.
    pub(crate) async fn connect(
        api_key: &str,
        ws_url: &str,
        options: StreamOptions,
    ) -> Result<Self> {
        let url = build_ws_url(ws_url, api_key, options.compress);

        // Bounded channels: when the consumer falls behind, the task's
        // `send(..).await` blocks, which pauses reading from the socket.
        let (msg_tx, msg_rx) = mpsc::channel(1024);
        let (cmd_tx, cmd_rx) = mpsc::channel(64);

        let handle = tokio::spawn(ws_task(url, options, msg_tx, cmd_rx));

        Ok(Self {
            rx: msg_rx,
            cmd_tx,
            _handle: handle,
        })
    }

    /// Receive the next message. Returns None when the stream is closed.
    ///
    /// `None` is returned once the background task has exited and every
    /// buffered message has been received.
    pub async fn next(&mut self) -> Option<Result<WsMessage>> {
        self.rx.recv().await
    }

    /// Subscribe to events.
    ///
    /// The request is queued for the background task, which sends it and
    /// replays it after any reconnect.
    ///
    /// # Errors
    /// Returns `Error::Disconnected` if the background task has already stopped.
    pub async fn subscribe(&self, sub: Subscription) -> Result<()> {
        self.cmd_tx
            .send(Command::Subscribe(sub))
            .await
            .map_err(|_| Error::Disconnected)
    }

    /// Unsubscribe from a specific subscription or all.
    ///
    /// Pass `Some(id)` to drop one subscription, or `None` to drop all of them.
    ///
    /// # Errors
    /// Returns `Error::Disconnected` if the background task has already stopped.
    pub async fn unsubscribe(&self, subscription_id: Option<String>) -> Result<()> {
        self.cmd_tx
            .send(Command::Unsubscribe(subscription_id))
            .await
            .map_err(|_| Error::Disconnected)
    }

    /// Close the connection.
    ///
    /// Best-effort: asks the background task to send a close frame and stop,
    /// then drops this handle. Always returns `Ok(())`.
    pub async fn close(self) -> Result<()> {
        let _ = self.cmd_tx.send(Command::Close).await;
        Ok(())
    }
}

/// Builds the connection URL from `ws_url`, the API key and the compression flag.
///
/// The key is percent-encoded, and the parameters are appended with `&` when
/// `ws_url` already carries a query string.
fn build_ws_url(ws_url: &str, api_key: &str, compress: bool) -> String {
    let separator = if ws_url.contains('?') { '&' } else { '?' };
    let mut url = format!("{}{}key={}", ws_url, separator, encode_query_component(api_key));
    if compress {
        url.push_str("&compress=zlib");
    }
    url
}

/// Percent-encodes every byte of `value` outside the RFC 3986 unreserved set.
fn encode_query_component(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                out.push(char::from(byte));
            }
            _ => {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    out
}

/// Background task that owns the WebSocket connection behind a `WsStream`.
///
/// Each pass of the outer loop does three things:
/// - connects to `url`;
/// - replays every subscription recorded in `active_subs`;
/// - multiplexes incoming frames (decoded, parsed, and forwarded to `msg_tx`)
///   with commands from `cmd_rx`.
///
/// When the connection fails or drops, `wait_before_reconnect` applies the
/// reconnect policy from `options`.
///
/// The task returns when:
/// - a `Command::Close` arrives or the command channel is closed;
/// - the receiving side of `msg_tx` is gone;
/// - the reconnect policy says to stop.
///
/// Returning drops `msg_tx`, which ends the stream for the consumer.
async fn ws_task(
    url: String,
    options: StreamOptions,
    msg_tx: mpsc::Sender<Result<WsMessage>>,
    mut cmd_rx: mpsc::Receiver<Command>,
) {
    // Subscriptions to replay after every reconnect.
    let mut active_subs: Vec<Subscription> = Vec::new();
    // Consecutive failed attempts; drives the backoff delay and the attempt limit.
    let mut reconnect_attempts: u32 = 0;

    'outer: loop {
        // Connect
        let ws_stream = match tokio_tungstenite::connect_async(&url).await {
            Ok((stream, _)) => {
                // NOTE(review): the counter resets on every successful handshake. A server
                // that accepts and then immediately drops the connection is therefore
                // retried indefinitely at roughly `initial_backoff`, even when
                // `max_reconnect_attempts` is set.
                reconnect_attempts = 0;
                stream
            }
            Err(e) => {
                // Surface the failure to the consumer, then back off (or give up).
                let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
                if !wait_before_reconnect(&options, &mut reconnect_attempts, &msg_tx).await {
                    break;
                }
                continue;
            }
        };

        let (mut write, mut read) = ws_stream.split();

        // Re-subscribe all active subscriptions after reconnect. A failed send means the
        // connection is already unusable: go through the reconnect policy below instead
        // of reconnecting immediately. Reconnecting immediately would skip the backoff
        // and ignore `auto_reconnect` / `max_reconnect_attempts`.
        let mut resubscribed = true;
        for sub in &active_subs {
            let msg_text = serde_json::to_string(&sub.to_message()).unwrap();
            if write.send(Message::Text(msg_text.into())).await.is_err() {
                resubscribed = false;
                break;
            }
        }

        if resubscribed {
            // Pump frames and commands until the connection ends (inner `break`)
            // or the task must stop (`break 'outer`).
            loop {
                tokio::select! {
                    frame = read.next() => {
                        match frame {
                            Some(Ok(msg)) => {
                                match decode_frame(msg) {
                                    Ok(Some(text)) => {
                                        if let Some(ws_msg) = parse_message(&text) {
                                            // The consumer dropped the stream: nothing left to do.
                                            if msg_tx.send(Ok(ws_msg)).await.is_err() {
                                                break 'outer;
                                            }
                                        }
                                    }
                                    Ok(None) => {} // ping/pong
                                    Err(Error::ConnectionClosed) => break,
                                    Err(e) => {
                                        let _ = msg_tx.send(Err(e)).await;
                                    }
                                }
                            }
                            Some(Err(e)) => {
                                let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
                                break;
                            }
                            None => break, // stream ended
                        }
                    }
                    cmd = cmd_rx.recv() => {
                        match cmd {
                            Some(Command::Subscribe(sub)) => {
                                let msg_text = serde_json::to_string(&sub.to_message()).unwrap();
                                // Recorded before sending so a failed send is retried
                                // by the replay after reconnect.
                                active_subs.push(sub);
                                if write.send(Message::Text(msg_text.into())).await.is_err() {
                                    break;
                                }
                            }
                            Some(Command::Unsubscribe(id)) => {
                                // NOTE(review): `Some(id)` does not remove the matching entry
                                // from `active_subs`, so that subscription is replayed after
                                // the next reconnect.
                                let msg = if let Some(ref sid) = id {
                                    serde_json::json!({"action": "unsubscribe", "subscription_id": sid})
                                } else {
                                    active_subs.clear();
                                    serde_json::json!({"action": "unsubscribe"})
                                };
                                let msg_text = serde_json::to_string(&msg).unwrap();
                                if write.send(Message::Text(msg_text.into())).await.is_err() {
                                    break;
                                }
                            }
                            Some(Command::Close) | None => {
                                let _ = write.send(Message::Close(None)).await;
                                break 'outer;
                            }
                        }
                    }
                }
            }
        }

        // Reconnect logic
        if !wait_before_reconnect(&options, &mut reconnect_attempts, &msg_tx).await {
            break;
        }
    }
}

/// Applies the reconnect policy after a failed or dropped connection.
///
/// Returns `false` when the task should stop, in any of these cases:
/// - reconnects are disabled;
/// - the attempt limit is reached;
/// - the consumer dropped its receiver, before or during the wait.
///
/// Otherwise it sleeps for the backoff delay, bumps `attempts`, and returns `true`.
async fn wait_before_reconnect(
    options: &StreamOptions,
    attempts: &mut u32,
    msg_tx: &mpsc::Sender<Result<WsMessage>>,
) -> bool {
    // Without this check a dropped `WsStream` would leave the task retrying forever.
    if msg_tx.is_closed() || !should_reconnect(options, *attempts) {
        return false;
    }
    let delay = backoff_delay(options, *attempts);
    *attempts = attempts.saturating_add(1);
    tokio::select! {
        _ = tokio::time::sleep(delay) => true,
        // Stop promptly if the consumer goes away mid-backoff.
        _ = msg_tx.closed() => false,
    }
}

/// Decides whether another reconnect attempt is allowed.
///
/// `attempts` is the number of consecutive attempts already made. Reconnecting
/// requires `auto_reconnect`, and, when `max_reconnect_attempts` is set,
/// `attempts` must still be below that limit.
fn should_reconnect(options: &StreamOptions, attempts: u32) -> bool {
    let limit_reached =
        matches!(options.max_reconnect_attempts, Some(limit) if attempts >= limit);
    options.auto_reconnect && !limit_reached
}

/// Computes the wait before reconnect attempt number `attempts` (0-based).
///
/// The base delay is `initial_backoff * 2^attempts`, capped at `max_backoff`.
/// The result is then jittered to 50–100% of that capped delay so that many
/// clients do not reconnect in lockstep.
fn backoff_delay(options: &StreamOptions, attempts: u32) -> Duration {
    // Durations above u64::MAX milliseconds saturate instead of truncating.
    let to_millis = |d: Duration| d.as_millis().min(u128::from(u64::MAX)) as u64;
    let base = to_millis(options.initial_backoff);
    let max = to_millis(options.max_backoff);
    // Saturate instead of overflowing. `base * 2^attempts` overflows u64 after ~55
    // consecutive failures with a 1s base (and `2^attempts` itself at 64). That
    // panics in debug builds and wraps to a tiny or zero delay in release builds,
    // which turns into a tight reconnect loop.
    let factor = 2u64.checked_pow(attempts).unwrap_or(u64::MAX);
    let delay = base.saturating_mul(factor).min(max);
    // Add jitter: 50-100% of delay
    let jitter = delay / 2 + (rand_simple() % (delay / 2 + 1));
    Duration::from_millis(jitter)
}

/// Cheap, non-cryptographic jitter source: the sub-second nanoseconds of the
/// current wall-clock time (always below 1_000_000_000).
///
/// A clock set before the Unix epoch yields 0.
fn rand_simple() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::from(since_epoch.subsec_nanos())
}

fn parse_message(text: &str) -> Option<WsMessage> {
    let raw: RawWsMessage = serde_json::from_str(text).ok()?;

    match raw.msg_type.as_str() {
        "subscribed" => Some(WsMessage::Subscribed {
            subscription_id: raw.subscription_id.unwrap_or_default(),
            subscription_type: raw.subscription_type.unwrap_or_default(),
            warnings: raw.warnings.unwrap_or_default(),
        }),
        "unsubscribed" => Some(WsMessage::Unsubscribed {
            subscriber_id: raw.subscriber_id.unwrap_or_default(),
        }),
        "heartbeat" => Some(WsMessage::Heartbeat {
            ts: raw.ts.unwrap_or(0),
        }),
        "pong" => None, // silently ignore
        "error" => Some(WsMessage::Error {
            code: raw.code,
            message: raw.message.unwrap_or_default(),
        }),
        "snapshot" => Some(WsMessage::Snapshot(raw.events.unwrap_or_default())),
        "price_feed" => {
            // Price feed has data field with the event
            if let Some(data) = raw.data {
                if let Ok(pf) = serde_json::from_value::<PriceFeedEvent>(data) {
                    return Some(WsMessage::PriceFeed(pf));
                }
            }
            None
        }
        _ => {
            // Event message: { type: "settlement", timestamp: ..., data: {...} }
            if let Some(data) = raw.data {
                if let Ok(event) = serde_json::from_value::<PolyNodeEvent>(data) {
                    return Some(WsMessage::Event(event));
                }
            }
            None
        }
    }
}