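//! polynode 0.7.1
//!
//! Rust SDK for the PolyNode API — real-time Polymarket data.
//!
//! This module provides the orderbook stream: [`ObStream`] connects to
//! `ob.polynode.dev` and yields [`ObMessage`] values.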
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;

use crate::error::{Error, Result};
use crate::types::orderbook::{ObMessage, OrderbookUpdate, RawObMessage};
use crate::ws::codec::decode_frame;

/// Options for the orderbook stream.
///
/// The `Default` configuration requests zlib compression and reconnects
/// automatically with no attempt limit. Its backoff starts at 1 second and is
/// capped at 30 seconds.
#[derive(Debug, Clone)]
pub struct ObStreamOptions {
    /// Ask the server for zlib-compressed frames (adds `compress=zlib` to the URL).
    pub compress: bool,
    /// Reconnect automatically after a failed connect or a dropped connection.
    pub auto_reconnect: bool,
    /// Maximum number of consecutive reconnect attempts; `None` means unlimited.
    pub max_reconnect_attempts: Option<u32>,
    /// Starting reconnect delay; the delay grows exponentially from here.
    pub initial_backoff: Duration,
    /// Upper bound on the reconnect delay.
    pub max_backoff: Duration,
}

impl Default for ObStreamOptions {
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(1_000);
        let max_backoff = Duration::from_millis(30_000);
        ObStreamOptions {
            compress: true,
            auto_reconnect: true,
            max_reconnect_attempts: None,
            initial_backoff,
            max_backoff,
        }
    }
}

/// Control messages sent from an `ObStream` handle to its background task.
enum Command {
    /// Subscribe to these token ids. The task also remembers the list and
    /// re-sends it after a reconnect.
    Subscribe(Vec<String>),
    /// Unsubscribe from all markets and forget the remembered token ids.
    Unsubscribe,
    /// Send a WebSocket Close frame and stop the background task.
    Close,
}

/// A real-time orderbook stream from ob.polynode.dev.
///
/// A spawned background task owns the WebSocket connection. This handle
/// receives the decoded messages and sends control commands to that task.
pub struct ObStream {
    /// Decoded messages (or errors) forwarded by the background task.
    rx: mpsc::Receiver<Result<ObMessage>>,
    /// Control channel into the background task.
    cmd_tx: mpsc::Sender<Command>,
    /// Handle of the spawned background task. It is stored but never awaited
    /// or aborted by this type.
    _handle: tokio::task::JoinHandle<()>,
}

impl ObStream {
    /// Opens an orderbook stream against `ob_url`, authenticating with `api_key`.
    ///
    /// The connection is established by a spawned background task, so this
    /// returns right away. Connection failures are reported later as `Err`
    /// items from [`next`](Self::next), not from this function.
    ///
    /// The API key is percent-encoded before it goes into the query string.
    /// If `ob_url` already carries a query string, the key is appended with
    /// `&` instead of a second `?`.
    pub(crate) async fn connect(
        api_key: &str,
        ob_url: &str,
        options: ObStreamOptions,
    ) -> Result<Self> {
        let separator = if ob_url.contains('?') { '&' } else { '?' };
        let mut url = format!("{}{}key={}", ob_url, separator, encode_query_value(api_key));
        if options.compress {
            url.push_str("&compress=zlib");
        }

        let (msg_tx, msg_rx) = mpsc::channel(4096);
        let (cmd_tx, cmd_rx) = mpsc::channel(64);

        let handle = tokio::spawn(ob_task(url, options, msg_tx, cmd_rx));

        Ok(Self {
            rx: msg_rx,
            cmd_tx,
            _handle: handle,
        })
    }

    /// Receives the next message.
    ///
    /// Returns `None` once the background task has exited and every buffered
    /// message has been consumed.
    pub async fn next(&mut self) -> Option<Result<ObMessage>> {
        self.rx.recv().await
    }

    /// Subscribes to orderbook updates for the given token IDs.
    ///
    /// The background task remembers this list, replacing any previous one,
    /// and re-sends it after every reconnect.
    ///
    /// # Errors
    /// Returns `Error::Disconnected` if the background task has already stopped.
    pub async fn subscribe(&self, token_ids: Vec<String>) -> Result<()> {
        self.cmd_tx
            .send(Command::Subscribe(token_ids))
            .await
            .map_err(|_| Error::Disconnected)
    }

    /// Unsubscribes from all markets and clears the remembered subscription.
    ///
    /// # Errors
    /// Returns `Error::Disconnected` if the background task has already stopped.
    pub async fn unsubscribe(&self) -> Result<()> {
        self.cmd_tx
            .send(Command::Unsubscribe)
            .await
            .map_err(|_| Error::Disconnected)
    }

    /// Closes the connection.
    ///
    /// This is best-effort. It asks the background task to send a WebSocket
    /// Close frame and stop. Send failures are ignored, it does not wait for
    /// the task to finish, and it always returns `Ok(())`.
    pub async fn close(self) -> Result<()> {
        let _ = self.cmd_tx.send(Command::Close).await;
        Ok(())
    }
}

/// Percent-encodes `value` for use as a URL query component (RFC 3986).
///
/// Unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through unchanged.
/// Every other byte of the UTF-8 encoding becomes `%XX`, so characters such
/// as `&`, `#`, `+` and `=` cannot break the query string.
fn encode_query_value(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}

async fn ob_task(
    url: String,
    options: ObStreamOptions,
    msg_tx: mpsc::Sender<Result<ObMessage>>,
    mut cmd_rx: mpsc::Receiver<Command>,
) {
    let mut last_token_ids: Vec<String> = Vec::new();
    let mut reconnect_attempts: u32 = 0;

    'outer: loop {
        let ws_stream = match tokio_tungstenite::connect_async(&url).await {
            Ok((stream, _)) => {
                reconnect_attempts = 0;
                stream
            }
            Err(e) => {
                let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
                if !should_reconnect(&options, reconnect_attempts) {
                    break;
                }
                let delay = backoff_delay(&options, reconnect_attempts);
                reconnect_attempts += 1;
                tokio::time::sleep(delay).await;
                continue;
            }
        };

        let (mut write, mut read) = ws_stream.split();

        // Re-subscribe after reconnect
        if !last_token_ids.is_empty() {
            let msg = serde_json::json!({
                "action": "subscribe",
                "markets": last_token_ids
            });
            let msg_text = serde_json::to_string(&msg).unwrap();
            if write.send(Message::Text(msg_text.into())).await.is_err() {
                continue 'outer;
            }
        }

        loop {
            tokio::select! {
                frame = read.next() => {
                    match frame {
                        Some(Ok(msg)) => {
                            match decode_frame(msg) {
                                Ok(Some(text)) => {
                                    let messages = parse_ob_message(&text);
                                    for m in messages {
                                        if msg_tx.send(Ok(m)).await.is_err() {
                                            break 'outer;
                                        }
                                    }
                                }
                                Ok(None) => {}
                                Err(Error::ConnectionClosed) => break,
                                Err(e) => {
                                    let _ = msg_tx.send(Err(e)).await;
                                }
                            }
                        }
                        Some(Err(e)) => {
                            let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
                            break;
                        }
                        None => break,
                    }
                }
                cmd = cmd_rx.recv() => {
                    match cmd {
                        Some(Command::Subscribe(ids)) => {
                            last_token_ids = ids.clone();
                            let msg = serde_json::json!({
                                "action": "subscribe",
                                "markets": ids
                            });
                            let msg_text = serde_json::to_string(&msg).unwrap();
                            if write.send(Message::Text(msg_text.into())).await.is_err() {
                                break;
                            }
                        }
                        Some(Command::Unsubscribe) => {
                            last_token_ids.clear();
                            let msg = serde_json::json!({"action": "unsubscribe"});
                            let msg_text = serde_json::to_string(&msg).unwrap();
                            if write.send(Message::Text(msg_text.into())).await.is_err() {
                                break;
                            }
                        }
                        Some(Command::Close) | None => {
                            let _ = write.send(Message::Close(None)).await;
                            break 'outer;
                        }
                    }
                }
            }
        }

        if !should_reconnect(&options, reconnect_attempts) {
            break;
        }
        let delay = backoff_delay(&options, reconnect_attempts);
        reconnect_attempts += 1;
        tokio::time::sleep(delay).await;
    }
}

/// Decides whether another reconnect attempt is allowed after `attempts`
/// consecutive attempts.
///
/// Always `false` when `auto_reconnect` is off. Otherwise `true`, unless
/// `max_reconnect_attempts` is set and `attempts` has reached it.
fn should_reconnect(options: &ObStreamOptions, attempts: u32) -> bool {
    let budget_exhausted = matches!(
        options.max_reconnect_attempts,
        Some(limit) if attempts >= limit
    );
    options.auto_reconnect && !budget_exhausted
}

/// Computes the wait before reconnect attempt number `attempts` (0-based).
///
/// The un-jittered delay is `initial_backoff * 2^attempts`, capped at
/// `max_backoff`. The returned value is a jittered point in
/// `[delay / 2, delay]` (millisecond resolution).
///
/// All arithmetic saturates. With unlimited reconnects `attempts` keeps
/// growing, and `base * 2^attempts` would otherwise overflow after about 54
/// attempts. In debug builds that panics. In release builds it wraps, which
/// can produce a zero delay and a tight reconnect loop.
fn backoff_delay(options: &ObStreamOptions, attempts: u32) -> Duration {
    let base = u64::try_from(options.initial_backoff.as_millis()).unwrap_or(u64::MAX);
    let max = u64::try_from(options.max_backoff.as_millis()).unwrap_or(u64::MAX);
    let factor = 2u64.checked_pow(attempts).unwrap_or(u64::MAX);
    let delay = base.saturating_mul(factor).min(max);
    // `delay / 2 + 1` is never zero, and the sum is at most `delay`, so this
    // cannot divide by zero or overflow.
    let jitter = delay / 2 + rand_simple() % (delay / 2 + 1);
    Duration::from_millis(jitter)
}

/// Cheap, non-cryptographic entropy used only to jitter reconnect backoff.
///
/// Returns the sub-second nanosecond part of the current wall-clock time
/// (always below 1_000_000_000). Returns 0 if the clock reads earlier than
/// the Unix epoch.
fn rand_simple() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    u64::from(since_epoch.subsec_nanos())
}

/// Parse a raw JSON message into zero or more ObMessages.
/// Batches and snapshot_batches are flattened into individual updates.
fn parse_ob_message(text: &str) -> Vec<ObMessage> {
    let raw: RawObMessage = match serde_json::from_str(text) {
        Ok(r) => r,
        Err(_) => return vec![],
    };

    // Error messages have "error" field instead of "type"
    if let Some(error) = raw.error {
        return vec![ObMessage::Error {
            error,
            message: raw.message.unwrap_or_default(),
        }];
    }

    let msg_type = match raw.msg_type {
        Some(ref t) => t.as_str(),
        None => return vec![],
    };

    match msg_type {
        "subscribed" => vec![ObMessage::Subscribed {
            markets: raw.markets.unwrap_or(0),
        }],
        "unsubscribed" => vec![ObMessage::Unsubscribed],
        "snapshots_done" => vec![ObMessage::SnapshotsDone {
            total: raw.total.unwrap_or(0),
        }],
        "snapshot_batch" => {
            let mut out = Vec::new();
            if let Some(snapshots) = raw.snapshots {
                for val in snapshots {
                    if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
                        out.push(ObMessage::Update(update));
                    }
                }
            }
            out
        }
        "batch" => {
            let mut out = Vec::new();
            if let Some(updates) = raw.updates {
                for val in updates {
                    if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
                        out.push(ObMessage::Update(update));
                    }
                }
            }
            out
        }
        "pong" => vec![],
        _ => vec![],
    }
}