// Source file: predict_sdk/websocket/types.rs

1//! WebSocket message types for Predict.fun API
2//!
3//! Based on: https://dev.predict.fun/
4
5use rust_decimal::Decimal;
6use serde::{Deserialize, Serialize};
7
8/// WebSocket request sent by the client
9#[derive(Debug, Clone, Serialize)]
10#[serde(rename_all = "camelCase")]
11pub struct WsRequest {
12    /// Method: "subscribe", "unsubscribe", or "heartbeat"
13    pub method: String,
14    /// Unique request ID (required for subscribe/unsubscribe)
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub request_id: Option<u64>,
17    /// Topic parameters (e.g., ["predictOrderbook/123"])
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub params: Option<Vec<String>>,
20    /// Data payload (used for heartbeat timestamp echo)
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub data: Option<serde_json::Value>,
23}
24
25impl WsRequest {
26    /// Create a subscribe request
27    pub fn subscribe(request_id: u64, topics: Vec<String>) -> Self {
28        Self {
29            method: "subscribe".to_string(),
30            request_id: Some(request_id),
31            params: Some(topics),
32            data: None,
33        }
34    }
35
36    /// Create an unsubscribe request
37    pub fn unsubscribe(request_id: u64, topics: Vec<String>) -> Self {
38        Self {
39            method: "unsubscribe".to_string(),
40            request_id: Some(request_id),
41            params: Some(topics),
42            data: None,
43        }
44    }
45
46    /// Create a heartbeat response (echo server timestamp)
47    pub fn heartbeat(timestamp: u64) -> Self {
48        Self {
49            method: "heartbeat".to_string(),
50            request_id: None,
51            params: None,
52            data: Some(serde_json::Value::Number(timestamp.into())),
53        }
54    }
55}
56
57/// Raw WebSocket message from server (before parsing)
58#[derive(Debug, Clone, Deserialize)]
59pub struct RawWsMessage {
60    /// Message type: "R" for request response, "M" for push message
61    #[serde(rename = "type")]
62    pub msg_type: String,
63    /// Request ID (only for type "R")
64    #[serde(rename = "requestId")]
65    pub request_id: Option<u64>,
66    /// Success flag (only for type "R")
67    pub success: Option<bool>,
68    /// Topic string (only for type "M")
69    pub topic: Option<String>,
70    /// Data payload
71    pub data: Option<serde_json::Value>,
72    /// Error details (only for type "R" when success=false)
73    pub error: Option<WsError>,
74}
75
76/// Parsed WebSocket message
77#[derive(Debug, Clone)]
78pub enum WsMessage {
79    /// Response to a client request (subscribe, unsubscribe)
80    RequestResponse(RequestResponse),
81    /// Server-initiated push message (orderbook update, heartbeat)
82    PushMessage(PushMessage),
83}
84
85impl TryFrom<RawWsMessage> for WsMessage {
86    type Error = String;
87
88    fn try_from(raw: RawWsMessage) -> Result<Self, Self::Error> {
89        match raw.msg_type.as_str() {
90            "R" => Ok(WsMessage::RequestResponse(RequestResponse {
91                request_id: raw.request_id.ok_or("Missing request_id for type R")?,
92                success: raw.success.unwrap_or(false),
93                data: raw.data,
94                error: raw.error,
95            })),
96            "M" => Ok(WsMessage::PushMessage(PushMessage {
97                topic: raw.topic.ok_or("Missing topic for type M")?,
98                data: raw.data.unwrap_or(serde_json::Value::Null),
99            })),
100            other => Err(format!("Unknown message type: {}", other)),
101        }
102    }
103}
104
105/// Response to a client request
106#[derive(Debug, Clone)]
107pub struct RequestResponse {
108    /// The request ID this response is for
109    pub request_id: u64,
110    /// Whether the request succeeded
111    pub success: bool,
112    /// Response data (usually null for subscriptions)
113    pub data: Option<serde_json::Value>,
114    /// Error details if success=false
115    pub error: Option<WsError>,
116}
117
118/// Server-initiated push message
119#[derive(Debug, Clone)]
120pub struct PushMessage {
121    /// Topic string (e.g., "predictOrderbook/123", "heartbeat")
122    pub topic: String,
123    /// Message data
124    pub data: serde_json::Value,
125}
126
127impl PushMessage {
128    /// Check if this is a heartbeat message
129    pub fn is_heartbeat(&self) -> bool {
130        self.topic == "heartbeat"
131    }
132
133    /// Get the heartbeat timestamp if this is a heartbeat message
134    pub fn heartbeat_timestamp(&self) -> Option<u64> {
135        if self.is_heartbeat() {
136            self.data.as_u64()
137        } else {
138            None
139        }
140    }
141
142    /// Check if this is an orderbook update
143    pub fn is_orderbook(&self) -> bool {
144        self.topic.starts_with("predictOrderbook/")
145    }
146
147    /// Extract market ID from orderbook topic
148    pub fn orderbook_market_id(&self) -> Option<u64> {
149        if self.is_orderbook() {
150            self.topic
151                .strip_prefix("predictOrderbook/")
152                .and_then(|s| s.parse().ok())
153        } else {
154            None
155        }
156    }
157
158    /// Check if this is an asset price update
159    pub fn is_asset_price(&self) -> bool {
160        self.topic.starts_with("assetPriceUpdate/")
161    }
162
163    /// Extract price feed ID from asset price update topic
164    pub fn asset_price_feed_id(&self) -> Option<&str> {
165        if self.is_asset_price() {
166            self.topic.strip_prefix("assetPriceUpdate/")
167        } else {
168            None
169        }
170    }
171
172    /// Check if this is a Polymarket chance update
173    pub fn is_polymarket_chance(&self) -> bool {
174        self.topic.starts_with("polymarketChance/")
175    }
176
177    /// Extract market ID from Polymarket chance topic
178    pub fn polymarket_chance_market_id(&self) -> Option<u64> {
179        if self.is_polymarket_chance() {
180            self.topic
181                .strip_prefix("polymarketChance/")
182                .and_then(|s| s.parse().ok())
183        } else {
184            None
185        }
186    }
187
188    /// Check if this is a Kalshi chance update
189    pub fn is_kalshi_chance(&self) -> bool {
190        self.topic.starts_with("kalshiChance/")
191    }
192
193    /// Extract market ID from Kalshi chance topic
194    pub fn kalshi_chance_market_id(&self) -> Option<u64> {
195        if self.is_kalshi_chance() {
196            self.topic
197                .strip_prefix("kalshiChance/")
198                .and_then(|s| s.parse().ok())
199        } else {
200            None
201        }
202    }
203
204    /// Check if this is a wallet event
205    pub fn is_wallet_event(&self) -> bool {
206        self.topic.starts_with("predictWalletEvents/")
207    }
208}
209
210/// WebSocket error from server
211#[derive(Debug, Clone, Deserialize)]
212pub struct WsError {
213    /// Error code (e.g., "invalid_json", "invalid_topic")
214    pub code: String,
215    /// Human-readable error message
216    pub message: String,
217}
218
219/// Orderbook update data from predictOrderbook topic
220#[derive(Debug, Clone, Deserialize)]
221#[serde(rename_all = "camelCase")]
222pub struct OrderbookData {
223    /// Market ID
224    pub market_id: u64,
225    /// Bid orders (price, size)
226    pub bids: Vec<PriceLevel>,
227    /// Ask orders (price, size)
228    pub asks: Vec<PriceLevel>,
229    /// Update timestamp (milliseconds)
230    #[serde(default)]
231    pub timestamp: Option<u64>,
232}
233
234/// Asset price update data from assetPriceUpdate topic
235#[derive(Debug, Clone, Deserialize)]
236#[serde(rename_all = "camelCase")]
237pub struct AssetPriceData {
238    /// Current price
239    pub price: f64,
240    /// Pyth publish time (seconds since epoch)
241    pub publish_time: u64,
242    /// Server timestamp (milliseconds)
243    pub timestamp: u64,
244}
245
246/// A price level in the orderbook
247#[derive(Debug, Clone, Deserialize)]
248pub struct PriceLevel {
249    /// Price (as string from API, parsed to Decimal)
250    #[serde(deserialize_with = "deserialize_decimal")]
251    pub price: Decimal,
252    /// Size/quantity (as string from API, parsed to Decimal)
253    #[serde(deserialize_with = "deserialize_decimal")]
254    pub size: Decimal,
255}
256
257/// Deserialize a string or number to Decimal
258fn deserialize_decimal<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
259where
260    D: serde::Deserializer<'de>,
261{
262    use serde::de::Error;
263
264    #[derive(Deserialize)]
265    #[serde(untagged)]
266    enum StringOrNumber {
267        String(String),
268        Number(f64),
269    }
270
271    match StringOrNumber::deserialize(deserializer)? {
272        StringOrNumber::String(s) => s.parse().map_err(D::Error::custom),
273        StringOrNumber::Number(n) => Decimal::try_from(n).map_err(D::Error::custom),
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode a raw JSON frame and convert it into a typed message.
    fn parse(json: &str) -> Result<WsMessage, String> {
        let raw: RawWsMessage = serde_json::from_str(json).expect("frame should be valid JSON");
        WsMessage::try_from(raw)
    }

    /// Build a push message with the given topic and a null payload.
    fn push(topic: &str) -> PushMessage {
        PushMessage {
            topic: topic.to_string(),
            data: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_subscribe_request() {
        let req = WsRequest::subscribe(1, vec!["predictOrderbook/123".to_string()]);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"subscribe\""));
        assert!(json.contains("\"requestId\":1"));
        assert!(json.contains("predictOrderbook/123"));
    }

    #[test]
    fn test_unsubscribe_request() {
        let req = WsRequest::unsubscribe(7, vec!["predictOrderbook/123".to_string()]);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"unsubscribe\""));
        assert!(json.contains("\"requestId\":7"));
        assert!(json.contains("\"params\":[\"predictOrderbook/123\"]"));
        // `data` is None for topic requests and must be omitted, not sent as null.
        assert!(!json.contains("\"data\""));
    }

    #[test]
    fn test_heartbeat_request() {
        let req = WsRequest::heartbeat(1736696400000);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"heartbeat\""));
        assert!(json.contains("\"data\":1736696400000"));
        // Heartbeats carry neither a request ID nor topics.
        assert!(!json.contains("requestId"));
        assert!(!json.contains("params"));
    }

    #[test]
    fn test_parse_request_response() {
        let msg = parse(r#"{"type":"R","requestId":1,"success":true,"data":null}"#).unwrap();
        match msg {
            WsMessage::RequestResponse(resp) => {
                assert_eq!(resp.request_id, 1);
                assert!(resp.success);
                assert!(resp.error.is_none());
            }
            other => panic!("Expected RequestResponse, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_error_response() {
        let json = r#"{"type":"R","requestId":2,"success":false,"error":{"code":"invalid_topic","message":"bad topic"}}"#;
        match parse(json).unwrap() {
            WsMessage::RequestResponse(resp) => {
                assert_eq!(resp.request_id, 2);
                assert!(!resp.success);
                let err = resp.error.expect("error details should be present");
                assert_eq!(err.code, "invalid_topic");
                assert_eq!(err.message, "bad topic");
            }
            other => panic!("Expected RequestResponse, got {:?}", other),
        }
    }

    #[test]
    fn test_missing_success_defaults_to_false() {
        match parse(r#"{"type":"R","requestId":3}"#).unwrap() {
            WsMessage::RequestResponse(resp) => assert!(!resp.success),
            other => panic!("Expected RequestResponse, got {:?}", other),
        }
    }

    #[test]
    fn test_invalid_frames_are_rejected() {
        assert_eq!(
            parse(r#"{"type":"R","success":true}"#).unwrap_err(),
            "Missing request_id for type R"
        );
        assert_eq!(
            parse(r#"{"type":"M","data":1}"#).unwrap_err(),
            "Missing topic for type M"
        );
        assert_eq!(
            parse(r#"{"type":"X"}"#).unwrap_err(),
            "Unknown message type: X"
        );
    }

    #[test]
    fn test_parse_heartbeat_message() {
        let msg = parse(r#"{"type":"M","topic":"heartbeat","data":1736696400000}"#).unwrap();
        match msg {
            WsMessage::PushMessage(push) => {
                assert!(push.is_heartbeat());
                assert_eq!(push.heartbeat_timestamp(), Some(1736696400000));
            }
            other => panic!("Expected PushMessage, got {:?}", other),
        }
    }

    #[test]
    fn test_push_without_data_is_null() {
        match parse(r#"{"type":"M","topic":"predictOrderbook/1"}"#).unwrap() {
            WsMessage::PushMessage(push) => assert!(push.data.is_null()),
            other => panic!("Expected PushMessage, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_orderbook_topic() {
        let book = push("predictOrderbook/5614");
        assert!(book.is_orderbook());
        assert_eq!(book.orderbook_market_id(), Some(5614));
        // A non-numeric suffix is not a market ID.
        assert_eq!(push("predictOrderbook/abc").orderbook_market_id(), None);
        assert!(!push("heartbeat").is_orderbook());
    }

    #[test]
    fn test_other_topic_helpers() {
        let asset = push("assetPriceUpdate/abc123");
        assert!(asset.is_asset_price());
        assert_eq!(asset.asset_price_feed_id(), Some("abc123"));
        assert_eq!(asset.orderbook_market_id(), None);

        let poly = push("polymarketChance/42");
        assert!(poly.is_polymarket_chance());
        assert_eq!(poly.polymarket_chance_market_id(), Some(42));
        assert_eq!(poly.kalshi_chance_market_id(), None);

        let kalshi = push("kalshiChance/7");
        assert!(kalshi.is_kalshi_chance());
        assert_eq!(kalshi.kalshi_chance_market_id(), Some(7));

        assert!(push("predictWalletEvents/token").is_wallet_event());
        assert!(!push("heartbeat").is_wallet_event());

        // Numeric data on a non-heartbeat topic is not a heartbeat timestamp.
        let not_heartbeat = PushMessage {
            topic: "predictOrderbook/1".to_string(),
            data: serde_json::Value::from(5u64),
        };
        assert!(!not_heartbeat.is_heartbeat());
        assert_eq!(not_heartbeat.heartbeat_timestamp(), None);
    }

    #[test]
    fn test_parse_orderbook_data() {
        let json = r#"{"marketId":5614,"bids":[{"price":"0.55","size":"100"}],"asks":[{"price":0.5,"size":25}]}"#;
        let book: OrderbookData = serde_json::from_str(json).unwrap();
        assert_eq!(book.market_id, 5614);
        assert_eq!(book.bids.len(), 1);
        assert_eq!(book.bids[0].price, Decimal::new(55, 2));
        assert_eq!(book.bids[0].size, Decimal::new(100, 0));
        assert_eq!(book.asks[0].price, Decimal::new(5, 1));
        assert_eq!(book.asks[0].size, Decimal::new(25, 0));
        assert_eq!(book.timestamp, None);
    }

    #[test]
    fn test_invalid_price_level_is_rejected() {
        let result: Result<PriceLevel, _> =
            serde_json::from_str(r#"{"price":"not-a-number","size":"1"}"#);
        assert!(result.is_err());
        let result: Result<PriceLevel, _> = serde_json::from_str(r#"{"price":true,"size":"1"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_asset_price_data() {
        let json = r#"{"price":97000.5,"publishTime":1736696400,"timestamp":1736696400123}"#;
        let data: AssetPriceData = serde_json::from_str(json).unwrap();
        assert_eq!(data.price, 97000.5);
        assert_eq!(data.publish_time, 1736696400);
        assert_eq!(data.timestamp, 1736696400123);
    }
}