predict-sdk 0.1.0

Rust SDK for the Predict.fun prediction market: order building, EIP-712 signing, and real-time WebSocket data
Documentation
//! WebSocket message types for Predict.fun API
//!
//! Based on: https://dev.predict.fun/

use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};

/// WebSocket request sent by the client
///
/// Serialized with camelCase keys (so `request_id` becomes `requestId`).
/// Every optional field is left out of the serialized output entirely when it
/// is `None`, rather than being written as `null`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct WsRequest {
    /// Method: "subscribe", "unsubscribe", or "heartbeat"
    pub method: String,
    /// Unique request ID (required for subscribe/unsubscribe)
    ///
    /// Omitted from the serialized request when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<u64>,
    /// Topic parameters (e.g., ["predictOrderbook/123"])
    ///
    /// Omitted from the serialized request when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Vec<String>>,
    /// Data payload (used for heartbeat timestamp echo)
    ///
    /// Omitted from the serialized request when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}

impl WsRequest {
    /// Build a request that carries a request ID and a list of topics.
    ///
    /// Subscribe and unsubscribe requests differ only in their method name,
    /// so both public constructors delegate here.
    fn topic_request(method: &str, request_id: u64, topics: Vec<String>) -> Self {
        Self {
            method: method.to_owned(),
            request_id: Some(request_id),
            params: Some(topics),
            data: None,
        }
    }

    /// Build a `subscribe` request for `topics`, tagged with `request_id`.
    pub fn subscribe(request_id: u64, topics: Vec<String>) -> Self {
        Self::topic_request("subscribe", request_id, topics)
    }

    /// Build an `unsubscribe` request for `topics`, tagged with `request_id`.
    pub fn unsubscribe(request_id: u64, topics: Vec<String>) -> Self {
        Self::topic_request("unsubscribe", request_id, topics)
    }

    /// Build a `heartbeat` reply that echoes `timestamp` back as the data payload.
    ///
    /// Heartbeat requests carry neither a request ID nor topic parameters.
    pub fn heartbeat(timestamp: u64) -> Self {
        Self {
            method: String::from("heartbeat"),
            request_id: None,
            params: None,
            data: Some(serde_json::Value::from(timestamp)),
        }
    }
}

/// Raw WebSocket message from server (before parsing)
///
/// A flat view of every field a server message may carry; all fields other
/// than the message type are optional at this stage. Convert into
/// [`WsMessage`] with `TryFrom` to check that the fields required for the
/// given message type are actually present.
#[derive(Debug, Clone, Deserialize)]
pub struct RawWsMessage {
    /// Message type: "R" for request response, "M" for push message
    ///
    /// Read from the JSON key `type`.
    #[serde(rename = "type")]
    pub msg_type: String,
    /// Request ID (only for type "R")
    ///
    /// Read from the JSON key `requestId`.
    #[serde(rename = "requestId")]
    pub request_id: Option<u64>,
    /// Success flag (only for type "R")
    pub success: Option<bool>,
    /// Topic string (only for type "M")
    pub topic: Option<String>,
    /// Data payload
    pub data: Option<serde_json::Value>,
    /// Error details (only for type "R" when success=false)
    pub error: Option<WsError>,
}

/// Parsed WebSocket message
///
/// Produced from a [`RawWsMessage`] via `TryFrom`, which selects the variant
/// from the raw message type ("R" or "M").
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// Response to a client request (subscribe, unsubscribe)
    RequestResponse(RequestResponse),
    /// Server-initiated push message (orderbook update, heartbeat)
    PushMessage(PushMessage),
}

impl TryFrom<RawWsMessage> for WsMessage {
    type Error = String;

    /// Validate a raw server message and convert it into its typed form.
    ///
    /// * Type `"R"` becomes [`WsMessage::RequestResponse`]; a missing
    ///   `success` flag is treated as `false`.
    /// * Type `"M"` becomes [`WsMessage::PushMessage`]; a missing `data`
    ///   payload is replaced by JSON `null`.
    ///
    /// # Errors
    ///
    /// Returns a description string when:
    /// * a type `"R"` message has no request ID (if the server attached an
    ///   error payload, its code and message are included so they are not
    ///   lost along with the rejected message),
    /// * a type `"M"` message has no topic,
    /// * the message type is neither `"R"` nor `"M"`.
    fn try_from(raw: RawWsMessage) -> Result<Self, Self::Error> {
        match raw.msg_type.as_str() {
            "R" => {
                let request_id = match (raw.request_id, raw.error.as_ref()) {
                    (Some(id), _) => id,
                    // The server may report a failure it cannot tie to a request
                    // (e.g. a request it could not parse). Surface its details
                    // instead of silently discarding them.
                    (None, Some(err)) => {
                        return Err(format!(
                            "Missing request_id for type R (server error {}: {})",
                            err.code, err.message
                        ));
                    }
                    (None, None) => return Err("Missing request_id for type R".to_string()),
                };

                Ok(WsMessage::RequestResponse(RequestResponse {
                    request_id,
                    success: raw.success.unwrap_or(false),
                    data: raw.data,
                    error: raw.error,
                }))
            }
            "M" => Ok(WsMessage::PushMessage(PushMessage {
                topic: raw.topic.ok_or("Missing topic for type M")?,
                data: raw.data.unwrap_or(serde_json::Value::Null),
            })),
            other => Err(format!("Unknown message type: {}", other)),
        }
    }
}

/// Response to a client request
///
/// Built from a type "R" [`RawWsMessage`] by the `TryFrom` conversion into
/// [`WsMessage`].
#[derive(Debug, Clone)]
pub struct RequestResponse {
    /// The request ID this response is for
    pub request_id: u64,
    /// Whether the request succeeded
    ///
    /// The `TryFrom<RawWsMessage>` conversion sets this to `false` when the
    /// raw message carried no success flag.
    pub success: bool,
    /// Response data (usually null for subscriptions)
    pub data: Option<serde_json::Value>,
    /// Error details if success=false
    pub error: Option<WsError>,
}

/// Server-initiated push message
///
/// Built from a type "M" [`RawWsMessage`] by the `TryFrom` conversion into
/// [`WsMessage`]. The payload is kept as untyped JSON; the topic determines
/// how it should be interpreted.
#[derive(Debug, Clone)]
pub struct PushMessage {
    /// Topic string (e.g., "predictOrderbook/123", "heartbeat")
    pub topic: String,
    /// Message data
    ///
    /// The `TryFrom<RawWsMessage>` conversion stores JSON `null` here when the
    /// raw message carried no data payload.
    pub data: serde_json::Value,
}

impl PushMessage {
    /// Exact topic name used for heartbeat messages.
    const HEARTBEAT_TOPIC: &'static str = "heartbeat";
    /// Topic prefix for orderbook updates; a market ID follows it.
    const ORDERBOOK_PREFIX: &'static str = "predictOrderbook/";
    /// Topic prefix for asset price updates; a price feed ID follows it.
    const ASSET_PRICE_PREFIX: &'static str = "assetPriceUpdate/";
    /// Topic prefix for Polymarket chance updates; a market ID follows it.
    const POLYMARKET_CHANCE_PREFIX: &'static str = "polymarketChance/";
    /// Topic prefix for Kalshi chance updates; a market ID follows it.
    const KALSHI_CHANCE_PREFIX: &'static str = "kalshiChance/";
    /// Topic prefix for wallet events.
    const WALLET_EVENTS_PREFIX: &'static str = "predictWalletEvents/";

    /// Parse whatever follows `prefix` in the topic as a numeric ID.
    ///
    /// Returns `None` when the topic does not start with `prefix` or when the
    /// remainder is not a valid `u64`.
    fn numeric_suffix(&self, prefix: &str) -> Option<u64> {
        self.topic.strip_prefix(prefix)?.parse().ok()
    }

    /// Whether the topic is exactly the heartbeat topic.
    pub fn is_heartbeat(&self) -> bool {
        self.topic.as_str() == Self::HEARTBEAT_TOPIC
    }

    /// Timestamp carried by a heartbeat message.
    ///
    /// Returns `None` for non-heartbeat topics, and also when the payload is
    /// not an unsigned integer.
    pub fn heartbeat_timestamp(&self) -> Option<u64> {
        if !self.is_heartbeat() {
            return None;
        }
        self.data.as_u64()
    }

    /// Whether the topic belongs to the orderbook update family.
    pub fn is_orderbook(&self) -> bool {
        self.topic.starts_with(Self::ORDERBOOK_PREFIX)
    }

    /// Market ID embedded in an orderbook topic, if the topic is one and the
    /// ID parses as a `u64`.
    pub fn orderbook_market_id(&self) -> Option<u64> {
        self.numeric_suffix(Self::ORDERBOOK_PREFIX)
    }

    /// Whether the topic belongs to the asset price update family.
    pub fn is_asset_price(&self) -> bool {
        self.topic.starts_with(Self::ASSET_PRICE_PREFIX)
    }

    /// Price feed ID embedded in an asset price topic, borrowed from the topic
    /// string; `None` for any other topic.
    pub fn asset_price_feed_id(&self) -> Option<&str> {
        self.topic.strip_prefix(Self::ASSET_PRICE_PREFIX)
    }

    /// Whether the topic belongs to the Polymarket chance update family.
    pub fn is_polymarket_chance(&self) -> bool {
        self.topic.starts_with(Self::POLYMARKET_CHANCE_PREFIX)
    }

    /// Market ID embedded in a Polymarket chance topic, if the topic is one
    /// and the ID parses as a `u64`.
    pub fn polymarket_chance_market_id(&self) -> Option<u64> {
        self.numeric_suffix(Self::POLYMARKET_CHANCE_PREFIX)
    }

    /// Whether the topic belongs to the Kalshi chance update family.
    pub fn is_kalshi_chance(&self) -> bool {
        self.topic.starts_with(Self::KALSHI_CHANCE_PREFIX)
    }

    /// Market ID embedded in a Kalshi chance topic, if the topic is one and
    /// the ID parses as a `u64`.
    pub fn kalshi_chance_market_id(&self) -> Option<u64> {
        self.numeric_suffix(Self::KALSHI_CHANCE_PREFIX)
    }

    /// Whether the topic belongs to the wallet events family.
    pub fn is_wallet_event(&self) -> bool {
        self.topic.starts_with(Self::WALLET_EVENTS_PREFIX)
    }
}

/// WebSocket error from server
///
/// Both fields are required when deserializing: an error object lacking
/// either one fails to deserialize.
#[derive(Debug, Clone, Deserialize)]
pub struct WsError {
    /// Error code (e.g., "invalid_json", "invalid_topic")
    pub code: String,
    /// Human-readable error message
    pub message: String,
}

/// Orderbook update data from predictOrderbook topic
///
/// Deserialized with camelCase keys (so `market_id` is read from `marketId`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OrderbookData {
    /// Market ID
    pub market_id: u64,
    /// Bid orders (price, size)
    pub bids: Vec<PriceLevel>,
    /// Ask orders (price, size)
    pub asks: Vec<PriceLevel>,
    /// Update timestamp (milliseconds)
    ///
    /// `None` when the field is absent from the payload.
    #[serde(default)]
    pub timestamp: Option<u64>,
}

/// Asset price update data from assetPriceUpdate topic
///
/// Deserialized with camelCase keys (so `publish_time` is read from
/// `publishTime`). All three fields are required.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AssetPriceData {
    /// Current price
    pub price: f64,
    /// Pyth publish time (seconds since epoch)
    pub publish_time: u64,
    /// Server timestamp (milliseconds)
    pub timestamp: u64,
}

/// A price level in the orderbook
///
/// Both fields go through `deserialize_decimal`, so each may arrive as either
/// a JSON string or a JSON number.
#[derive(Debug, Clone, Deserialize)]
pub struct PriceLevel {
    /// Price (JSON string or number, converted to Decimal)
    #[serde(deserialize_with = "deserialize_decimal")]
    pub price: Decimal,
    /// Size/quantity (JSON string or number, converted to Decimal)
    #[serde(deserialize_with = "deserialize_decimal")]
    pub size: Decimal,
}

/// Deserialize a string or number to Decimal
fn deserialize_decimal<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::Error;

    #[derive(Deserialize)]
    #[serde(untagged)]
    enum StringOrNumber {
        String(String),
        Number(f64),
    }

    match StringOrNumber::deserialize(deserializer)? {
        StringOrNumber::String(s) => s.parse().map_err(D::Error::custom),
        StringOrNumber::Number(n) => Decimal::try_from(n).map_err(D::Error::custom),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize a client request to its JSON wire form.
    fn to_wire(request: &WsRequest) -> String {
        serde_json::to_string(request).unwrap()
    }

    /// Decode a JSON frame into a raw message and convert it to the typed form.
    fn parse_message(frame: &str) -> WsMessage {
        let raw = serde_json::from_str::<RawWsMessage>(frame).unwrap();
        WsMessage::try_from(raw).unwrap()
    }

    // A subscribe request must carry the method, the camelCase request ID and the topic.
    #[test]
    fn test_subscribe_request() {
        let topics = vec![String::from("predictOrderbook/123")];
        let wire = to_wire(&WsRequest::subscribe(1, topics));

        assert!(wire.contains(r#""method":"subscribe""#));
        assert!(wire.contains(r#""requestId":1"#));
        assert!(wire.contains("predictOrderbook/123"));
    }

    // A heartbeat request must carry the method and echo the timestamp.
    #[test]
    fn test_heartbeat_request() {
        let wire = to_wire(&WsRequest::heartbeat(1736696400000));

        assert!(wire.contains(r#""method":"heartbeat""#));
        assert!(wire.contains("1736696400000"));
    }

    // A type "R" frame must parse into a RequestResponse with its ID and success flag.
    #[test]
    fn test_parse_request_response() {
        let frame = r#"{"type":"R","requestId":1,"success":true,"data":null}"#;

        match parse_message(frame) {
            WsMessage::RequestResponse(resp) => {
                assert_eq!(resp.request_id, 1);
                assert!(resp.success);
            }
            WsMessage::PushMessage(_) => panic!("Expected RequestResponse"),
        }
    }

    // A type "M" heartbeat frame must parse into a PushMessage exposing its timestamp.
    #[test]
    fn test_parse_heartbeat_message() {
        let frame = r#"{"type":"M","topic":"heartbeat","data":1736696400000}"#;

        match parse_message(frame) {
            WsMessage::PushMessage(push) => {
                assert!(push.is_heartbeat());
                assert_eq!(push.heartbeat_timestamp(), Some(1736696400000));
            }
            WsMessage::RequestResponse(_) => panic!("Expected PushMessage"),
        }
    }

    // An orderbook topic must be recognised and yield its numeric market ID.
    #[test]
    fn test_parse_orderbook_topic() {
        let push = PushMessage {
            topic: String::from("predictOrderbook/5614"),
            data: serde_json::Value::Null,
        };

        assert!(push.is_orderbook());
        assert_eq!(push.orderbook_market_id(), Some(5614));
    }
}