Skip to main content

predict_fun_sdk/
ws.rs

//! Predict.fun WebSocket client for real-time market data (BNB Chain).
//!
//! Connects to `wss://ws.predict.fun/ws` and provides:
//! - Orderbook snapshots (`predictOrderbook/{marketId}`)
//! - Asset price updates (`assetPriceUpdate/{feedId}`)
//! - Cross-venue chance data (`polymarketChance/{marketId}`, `kalshiChance/{marketId}`)
//! - Wallet event notifications (`predictWalletEvents/{jwt}`)
//!
//! # Protocol
//!
//! Custom JSON request/response protocol over WebSocket (not graphql-ws):
//! ```text
//! → {"requestId": 0, "method": "subscribe", "params": ["predictOrderbook/123"]}
//! ← {"type": "R", "requestId": 0, "success": true}
//! ← {"type": "M", "topic": "predictOrderbook/123", "data": {...}}
//! ```
//!
//! The server also pushes messages on the `heartbeat` topic; the client answers each
//! one with a `heartbeat` request carrying the same `data`.
//!
//! # Example
//!
//! ```rust,no_run
//! use predict_fun_sdk::ws::{PredictWsClient, PredictWsMessage, Topic};
//!
//! # async fn example() -> anyhow::Result<()> {
//! let (client, mut rx) = PredictWsClient::connect_mainnet().await?;
//!
//! client.subscribe(Topic::Orderbook { market_id: 45532 }).await?;
//! client.subscribe(Topic::AssetPrice { feed_id: 1 }).await?;
//!
//! while let Some(msg) = rx.recv().await {
//!     match msg {
//!         PredictWsMessage::Orderbook(ob) => {
//!             println!("OB market={}: {} bids, {} asks",
//!                 ob.market_id, ob.bids.len(), ob.asks.len());
//!         }
//!         PredictWsMessage::AssetPrice(p) => {
//!             println!("Price feed {}: ${:.2}", p.feed_id, p.price);
//!         }
//!         _ => {}
//!     }
//! }
//! # Ok(())
//! # }
//! ```
44
45use std::collections::{HashMap, HashSet};
46use std::sync::atomic::{AtomicU64, Ordering};
47use std::sync::Arc;
48
49use anyhow::{anyhow, Context, Result};
50use futures_util::{SinkExt, StreamExt};
51use serde::{Deserialize, Serialize};
52use serde_json::Value;
53use tokio::sync::{mpsc, oneshot, Mutex};
54use tokio::time::{self, Duration};
55use tokio_tungstenite::tungstenite::Message as WsMessage;
56use tracing::{debug, error, info, warn};
57
/// Production WebSocket endpoint (BNB Chain mainnet); the default for [`PredictWsConfig`].
pub const PREDICT_WS_MAINNET: &str = "wss://ws.predict.fun/ws";
/// Testnet WebSocket endpoint, used by [`PredictWsConfig::testnet`].
pub const PREDICT_WS_TESTNET: &str = "wss://ws.bnb.predict.fail/ws";

/// Mainnet GraphQL endpoint. Not used by the WebSocket client; kept for reference / future use.
pub const PREDICT_GQL_MAINNET: &str = "https://graphql.predict.fun/graphql";
/// Testnet GraphQL endpoint. Not used by the WebSocket client; kept for reference / future use.
pub const PREDICT_GQL_TESTNET: &str = "https://graphql.bnb.predict.fail/graphql";
65
66// ── Topic ──
67
/// Subscription topic for the predict.fun WebSocket feed.
///
/// `Debug` is implemented by hand (instead of derived) so that the JWT carried by
/// [`Topic::WalletEvents`] never ends up in `{:?}` log output. The format otherwise
/// matches what `#[derive(Debug)]` would produce.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Topic {
    /// Full orderbook snapshots on every change (sub-second).
    Orderbook { market_id: i64 },
    /// Oracle price updates (~3-5/sec per asset).
    AssetPrice { feed_id: i64 },
    /// Cross-venue Polymarket probability.
    PolymarketChance { market_id: i64 },
    /// Cross-venue Kalshi probability.
    KalshiChance { market_id: i64 },
    /// Wallet events (fills, settlements). Requires JWT.
    WalletEvents { jwt: String },
    /// Raw topic string for undocumented topics.
    Raw(String),
}

impl std::fmt::Debug for Topic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Topic::Orderbook { market_id } => f
                .debug_struct("Orderbook")
                .field("market_id", market_id)
                .finish(),
            Topic::AssetPrice { feed_id } => {
                f.debug_struct("AssetPrice").field("feed_id", feed_id).finish()
            }
            Topic::PolymarketChance { market_id } => f
                .debug_struct("PolymarketChance")
                .field("market_id", market_id)
                .finish(),
            Topic::KalshiChance { market_id } => f
                .debug_struct("KalshiChance")
                .field("market_id", market_id)
                .finish(),
            // The JWT is a bearer credential: never print it.
            Topic::WalletEvents { .. } => f
                .debug_struct("WalletEvents")
                .field("jwt", &"<redacted>")
                .finish(),
            // A raw wallet-events topic string embeds the same JWT.
            Topic::Raw(s) if s.starts_with("predictWalletEvents/") => f
                .debug_tuple("Raw")
                .field(&"predictWalletEvents/<redacted>")
                .finish(),
            Topic::Raw(s) => f.debug_tuple("Raw").field(s).finish(),
        }
    }
}
84
85impl Topic {
86    pub fn to_topic_string(&self) -> String {
87        match self {
88            Topic::Orderbook { market_id } => format!("predictOrderbook/{}", market_id),
89            Topic::AssetPrice { feed_id } => format!("assetPriceUpdate/{}", feed_id),
90            Topic::PolymarketChance { market_id } => format!("polymarketChance/{}", market_id),
91            Topic::KalshiChance { market_id } => format!("kalshiChance/{}", market_id),
92            Topic::WalletEvents { jwt } => format!("predictWalletEvents/{}", jwt),
93            Topic::Raw(s) => s.clone(),
94        }
95    }
96
97    pub fn from_topic_string(s: &str) -> Self {
98        if let Some(rest) = s.strip_prefix("predictOrderbook/") {
99            if let Ok(id) = rest.parse::<i64>() {
100                return Topic::Orderbook { market_id: id };
101            }
102        }
103        if let Some(rest) = s.strip_prefix("assetPriceUpdate/") {
104            if let Ok(id) = rest.parse::<i64>() {
105                return Topic::AssetPrice { feed_id: id };
106            }
107        }
108        if let Some(rest) = s.strip_prefix("polymarketChance/") {
109            if let Ok(id) = rest.parse::<i64>() {
110                return Topic::PolymarketChance { market_id: id };
111            }
112        }
113        if let Some(rest) = s.strip_prefix("kalshiChance/") {
114            if let Ok(id) = rest.parse::<i64>() {
115                return Topic::KalshiChance { market_id: id };
116            }
117        }
118        if let Some(rest) = s.strip_prefix("predictWalletEvents/") {
119            return Topic::WalletEvents { jwt: rest.to_string() };
120        }
121        Topic::Raw(s.to_string())
122    }
123}
124
125impl std::fmt::Display for Topic {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        write!(f, "{}", self.to_topic_string())
128    }
129}
130
131// ── Parsed message types ──
132
/// One orderbook price level, stored as a `(price, size)` tuple.
pub type Level = (f64, f64);
135
136/// Last settled order info.
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct LastOrderSettled {
139    pub id: String,
140    pub kind: String,
141    #[serde(rename = "marketId")]
142    pub market_id: i64,
143    pub outcome: String,
144    pub price: String,
145    pub side: String,
146}
147
148/// Full orderbook snapshot pushed on every change.
149#[derive(Debug, Clone)]
150pub struct OrderbookSnapshot {
151    pub market_id: i64,
152    pub bids: Vec<Level>,
153    pub asks: Vec<Level>,
154    pub version: u64,
155    pub update_timestamp_ms: u64,
156    pub order_count: u64,
157    pub last_order_settled: Option<LastOrderSettled>,
158}
159
160impl OrderbookSnapshot {
161    pub fn best_bid(&self) -> Option<f64> {
162        self.bids.first().map(|(p, _)| *p)
163    }
164    pub fn best_ask(&self) -> Option<f64> {
165        self.asks.first().map(|(p, _)| *p)
166    }
167    pub fn mid(&self) -> Option<f64> {
168        match (self.best_bid(), self.best_ask()) {
169            (Some(b), Some(a)) => Some((b + a) / 2.0),
170            _ => None,
171        }
172    }
173    pub fn spread(&self) -> Option<f64> {
174        match (self.best_bid(), self.best_ask()) {
175            (Some(b), Some(a)) => Some(a - b),
176            _ => None,
177        }
178    }
179}
180
/// Oracle price update pushed on an `assetPriceUpdate/{feedId}` topic.
#[derive(Debug, Clone)]
pub struct AssetPriceUpdate {
    /// Price feed id (taken from the topic string); see [`feeds`].
    pub feed_id: i64,
    /// Reported price.
    pub price: f64,
    /// Server `publishTime` field (`0` when absent). Units not specified by the API here —
    /// sample data looks like Unix seconds; confirm.
    pub publish_time: u64,
    /// Server `timestamp` field (`0` when absent); presumably Unix seconds — confirm.
    pub timestamp: u64,
}
189
190/// Cross-venue chance data (Polymarket or Kalshi).
191#[derive(Debug, Clone)]
192pub struct CrossVenueChance {
193    pub source: CrossVenueSource,
194    pub market_id: i64,
195    pub data: Value,
196}
197
/// External venue behind a [`CrossVenueChance`] message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossVenueSource {
    /// From a `polymarketChance/{marketId}` topic.
    Polymarket,
    /// From a `kalshiChance/{marketId}` topic.
    Kalshi,
}
203
204/// Wallet event notification (fills, settlements).
205#[derive(Debug, Clone)]
206pub struct WalletEvent {
207    pub data: Value,
208}
209
210/// Parsed WebSocket message.
211#[derive(Debug, Clone)]
212pub enum PredictWsMessage {
213    Orderbook(OrderbookSnapshot),
214    AssetPrice(AssetPriceUpdate),
215    CrossVenueChance(CrossVenueChance),
216    WalletEvent(WalletEvent),
217    /// Unparsed message (unknown topic or parse failure).
218    Raw { topic: String, data: Value },
219}
220
221// ── Wire protocol ──
222
223#[derive(Serialize)]
224struct WsRequest {
225    #[serde(rename = "requestId")]
226    request_id: u64,
227    method: String,
228    #[serde(skip_serializing_if = "Option::is_none")]
229    params: Option<Vec<String>>,
230    #[serde(skip_serializing_if = "Option::is_none")]
231    data: Option<Value>,
232}
233
234#[derive(Deserialize)]
235struct WsRawMessage {
236    #[serde(rename = "type")]
237    msg_type: String,
238    #[serde(rename = "requestId")]
239    request_id: Option<i64>,
240    success: Option<bool>,
241    error: Option<WsError>,
242    topic: Option<String>,
243    data: Option<Value>,
244}
245
246#[derive(Deserialize, Debug)]
247struct WsError {
248    code: String,
249    message: Option<String>,
250}
251
252impl std::fmt::Display for WsError {
253    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254        write!(f, "{}", self.code)?;
255        if let Some(msg) = &self.message {
256            write!(f, ": {}", msg)?;
257        }
258        Ok(())
259    }
260}
261
262// ── Client ──
263
264type WsSink = futures_util::stream::SplitSink<
265    tokio_tungstenite::WebSocketStream<
266        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
267    >,
268    WsMessage,
269>;
270type WsStream = futures_util::stream::SplitStream<
271    tokio_tungstenite::WebSocketStream<
272        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
273    >,
274>;
275
276type PendingResponse = oneshot::Sender<Result<()>>;
277
/// Connection settings for [`PredictWsClient`].
#[derive(Debug, Clone)]
pub struct PredictWsConfig {
    /// WebSocket endpoint URL.
    pub url: String,
    /// Capacity of the channel that delivers parsed messages to the caller.
    pub channel_buffer: usize,
    /// Heartbeat timeout in seconds, used by the read loop to detect a dead connection.
    pub heartbeat_timeout_secs: u64,
    /// Maximum consecutive reconnect attempts; `0` means retry forever.
    pub max_reconnect_attempts: u32,
    /// Upper bound, in seconds, on the exponential reconnect backoff.
    pub max_reconnect_backoff_secs: u64,
}
287
288impl Default for PredictWsConfig {
289    fn default() -> Self {
290        Self {
291            url: PREDICT_WS_MAINNET.to_string(),
292            channel_buffer: 1024,
293            heartbeat_timeout_secs: 60,
294            max_reconnect_attempts: 0, // infinite
295            max_reconnect_backoff_secs: 15,
296        }
297    }
298}
299
300impl PredictWsConfig {
301    pub fn mainnet() -> Self {
302        Self::default()
303    }
304    pub fn testnet() -> Self {
305        Self {
306            url: PREDICT_WS_TESTNET.to_string(),
307            ..Self::default()
308        }
309    }
310}
311
312/// Handle for interacting with the WebSocket connection.
313///
314/// Messages are received on the `mpsc::Receiver` returned from `connect`.
315#[derive(Clone)]
316pub struct PredictWsClient {
317    sink: Arc<Mutex<WsSink>>,
318    request_id: Arc<AtomicU64>,
319    pending: Arc<Mutex<HashMap<u64, PendingResponse>>>,
320    active_topics: Arc<Mutex<HashSet<String>>>,
321    config: PredictWsConfig,
322}
323
324impl PredictWsClient {
325    pub async fn connect_mainnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
326        Self::connect(PredictWsConfig::mainnet()).await
327    }
328
329    pub async fn connect_testnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
330        Self::connect(PredictWsConfig::testnet()).await
331    }
332
333    pub async fn connect(
334        config: PredictWsConfig,
335    ) -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
336        let (ws_stream, _) = tokio_tungstenite::connect_async(&config.url)
337            .await
338            .with_context(|| format!("failed to connect to {}", config.url))?;
339
340        info!("Connected to {}", config.url);
341
342        let (sink, stream) = ws_stream.split();
343        let (tx, rx) = mpsc::channel(config.channel_buffer);
344
345        let client = Self {
346            sink: Arc::new(Mutex::new(sink)),
347            request_id: Arc::new(AtomicU64::new(0)),
348            pending: Arc::new(Mutex::new(HashMap::new())),
349            active_topics: Arc::new(Mutex::new(HashSet::new())),
350            config,
351        };
352
353        let client_clone = client.clone();
354        tokio::spawn(async move {
355            client_clone.run_loop(stream, tx).await;
356        });
357
358        Ok((client, rx))
359    }
360
361    pub async fn subscribe(&self, topic: Topic) -> Result<()> {
362        let topic_str = topic.to_topic_string();
363        self.send_and_wait("subscribe", &topic_str).await?;
364
365        self.active_topics.lock().await.insert(topic_str.clone());
366        info!("Subscribed to {}", topic_str);
367        Ok(())
368    }
369
370    pub async fn unsubscribe(&self, topic: Topic) -> Result<()> {
371        let topic_str = topic.to_topic_string();
372        self.send_and_wait("unsubscribe", &topic_str).await?;
373
374        self.active_topics.lock().await.remove(&topic_str);
375        info!("Unsubscribed from {}", topic_str);
376        Ok(())
377    }
378
379    pub async fn active_topics(&self) -> Vec<String> {
380        self.active_topics.lock().await.iter().cloned().collect()
381    }
382
383    // ── Internal ──
384
385    async fn send_and_wait(&self, method: &str, topic: &str) -> Result<()> {
386        let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
387        let (resp_tx, resp_rx) = oneshot::channel();
388
389        self.pending.lock().await.insert(request_id, resp_tx);
390
391        let msg = WsRequest {
392            request_id,
393            method: method.to_string(),
394            params: Some(vec![topic.to_string()]),
395            data: None,
396        };
397        self.send_raw(&msg).await?;
398        debug!("{} {} (requestId={})", method, topic, request_id);
399
400        tokio::time::timeout(Duration::from_secs(10), resp_rx)
401            .await
402            .map_err(|_| anyhow!("{} timeout for {}", method, topic))?
403            .map_err(|_| anyhow!("{} channel closed for {}", method, topic))?
404    }
405
406    async fn send_heartbeat(&self, data: &Value) -> Result<()> {
407        let msg = WsRequest {
408            request_id: self.request_id.fetch_add(1, Ordering::Relaxed),
409            method: "heartbeat".to_string(),
410            params: None,
411            data: Some(data.clone()),
412        };
413        self.send_raw(&msg).await
414    }
415
416    async fn send_raw(&self, msg: &WsRequest) -> Result<()> {
417        let text = serde_json::to_string(msg).context("failed to serialize WS message")?;
418        self.sink
419            .lock()
420            .await
421            .send(WsMessage::Text(text))
422            .await
423            .context("failed to send WS message")
424    }
425
426    /// Main event loop: read messages, handle heartbeats, reconnect on disconnect.
427    /// Iterative (not recursive) — simpler and no stack growth.
428    async fn run_loop(&self, mut stream: WsStream, tx: mpsc::Sender<PredictWsMessage>) {
429        let heartbeat_timeout = Duration::from_secs(self.config.heartbeat_timeout_secs);
430        let mut attempt = 0u32;
431
432        loop {
433            // Read loop for current connection
434            let mut last_heartbeat = time::Instant::now();
435            let disconnected = loop {
436                tokio::select! {
437                    msg = stream.next() => {
438                        match msg {
439                            Some(Ok(WsMessage::Text(text))) => {
440                                if let Ok(raw) = serde_json::from_str::<WsRawMessage>(&text) {
441                                    self.handle_message(raw, &tx, &mut last_heartbeat).await;
442                                }
443                            }
444                            Some(Ok(WsMessage::Ping(data))) => {
445                                let _ = self.sink.lock().await.send(WsMessage::Pong(data)).await;
446                            }
447                            Some(Ok(WsMessage::Close(frame))) => {
448                                info!("WebSocket closed by server: {:?}", frame);
449                                break true;
450                            }
451                            Some(Err(e)) => {
452                                error!("WebSocket error: {}", e);
453                                break true;
454                            }
455                            None => {
456                                info!("WebSocket stream ended");
457                                break true;
458                            }
459                            _ => {} // Binary, Pong — ignore
460                        }
461                    }
462                    _ = time::sleep(heartbeat_timeout) => {
463                        if last_heartbeat.elapsed() > heartbeat_timeout {
464                            warn!("Heartbeat timeout ({}s)", self.config.heartbeat_timeout_secs);
465                            break true;
466                        }
467                    }
468                }
469            };
470
471            if !disconnected {
472                return;
473            }
474
475            // Reconnect with exponential backoff
476            let max = self.config.max_reconnect_attempts;
477            if max > 0 && attempt >= max {
478                error!("Max reconnect attempts ({}) reached", max);
479                return;
480            }
481
482            let backoff_secs = (2u64.pow(attempt.min(10))).min(self.config.max_reconnect_backoff_secs);
483            warn!("Reconnecting in {}s (attempt {})", backoff_secs, attempt + 1);
484            time::sleep(Duration::from_secs(backoff_secs)).await;
485            attempt += 1;
486
487            match tokio_tungstenite::connect_async(&self.config.url).await {
488                Ok((ws_stream, _)) => {
489                    info!("Reconnected to {}", self.config.url);
490                    let (new_sink, new_stream) = ws_stream.split();
491                    *self.sink.lock().await = new_sink;
492                    stream = new_stream;
493                    attempt = 0; // reset on success
494
495                    // Resubscribe to all active topics
496                    let topics: Vec<String> =
497                        self.active_topics.lock().await.iter().cloned().collect();
498                    for topic_str in &topics {
499                        let req_id = self.request_id.fetch_add(1, Ordering::Relaxed);
500                        let msg = WsRequest {
501                            request_id: req_id,
502                            method: "subscribe".to_string(),
503                            params: Some(vec![topic_str.clone()]),
504                            data: None,
505                        };
506                        if let Err(e) = self.send_raw(&msg).await {
507                            warn!("Failed to resubscribe to {}: {}", topic_str, e);
508                        }
509                    }
510                }
511                Err(e) => {
512                    error!("Reconnection failed: {}", e);
513                    // loop will retry
514                }
515            }
516        }
517    }
518
519    async fn handle_message(
520        &self,
521        raw: WsRawMessage,
522        tx: &mpsc::Sender<PredictWsMessage>,
523        last_heartbeat: &mut time::Instant,
524    ) {
525        match raw.msg_type.as_str() {
526            "R" => {
527                // Response to subscribe/unsubscribe
528                if let Some(req_id) = raw.request_id {
529                    if let Some(resp_tx) = self.pending.lock().await.remove(&(req_id as u64)) {
530                        let result = if raw.success.unwrap_or(false) {
531                            Ok(())
532                        } else {
533                            Err(anyhow!(
534                                "subscribe failed: {}",
535                                raw.error
536                                    .map(|e| e.to_string())
537                                    .unwrap_or_else(|| "unknown".into())
538                            ))
539                        };
540                        let _ = resp_tx.send(result);
541                    }
542                }
543            }
544            "M" => {
545                // Push message
546                let topic_str = match &raw.topic {
547                    Some(t) => t.as_str(),
548                    None => return,
549                };
550
551                // Heartbeat echo
552                if topic_str == "heartbeat" {
553                    *last_heartbeat = time::Instant::now();
554                    if let Some(data) = &raw.data {
555                        if let Err(e) = self.send_heartbeat(data).await {
556                            warn!("Heartbeat response failed: {}", e);
557                        }
558                    }
559                    return;
560                }
561
562                if let Some(data) = raw.data {
563                    let parsed = parse_push_message(topic_str, data);
564                    if tx.try_send(parsed).is_err() {
565                        warn!("Channel full, dropping message for {}", topic_str);
566                    }
567                }
568            }
569            _ => debug!("Unknown WS message type: {}", raw.msg_type),
570        }
571    }
572}
573
574// ── Message parsing ──
575
576fn parse_push_message(topic: &str, data: Value) -> PredictWsMessage {
577    if let Some(rest) = topic.strip_prefix("predictOrderbook/") {
578        if let Ok(market_id) = rest.parse::<i64>() {
579            if let Some(ob) = parse_orderbook(market_id, &data) {
580                return PredictWsMessage::Orderbook(ob);
581            }
582        }
583    }
584
585    if let Some(rest) = topic.strip_prefix("assetPriceUpdate/") {
586        if let Ok(feed_id) = rest.parse::<i64>() {
587            if let Some(price) = parse_asset_price(feed_id, &data) {
588                return PredictWsMessage::AssetPrice(price);
589            }
590        }
591    }
592
593    if let Some(rest) = topic.strip_prefix("polymarketChance/") {
594        if let Ok(id) = rest.parse::<i64>() {
595            return PredictWsMessage::CrossVenueChance(CrossVenueChance {
596                source: CrossVenueSource::Polymarket,
597                market_id: id,
598                data,
599            });
600        }
601    }
602
603    if let Some(rest) = topic.strip_prefix("kalshiChance/") {
604        if let Ok(id) = rest.parse::<i64>() {
605            return PredictWsMessage::CrossVenueChance(CrossVenueChance {
606                source: CrossVenueSource::Kalshi,
607                market_id: id,
608                data,
609            });
610        }
611    }
612
613    if topic.starts_with("predictWalletEvents/") {
614        return PredictWsMessage::WalletEvent(WalletEvent { data });
615    }
616
617    PredictWsMessage::Raw {
618        topic: topic.to_string(),
619        data,
620    }
621}
622
623fn parse_levels(val: &Value) -> Vec<Level> {
624    val.as_array()
625        .map(|arr| {
626            arr.iter()
627                .filter_map(|lvl| {
628                    let price = lvl.get(0)?.as_f64()?;
629                    let size = lvl.get(1)?.as_f64()?;
630                    Some((price, size))
631                })
632                .collect()
633        })
634        .unwrap_or_default()
635}
636
637fn parse_orderbook(market_id: i64, data: &Value) -> Option<OrderbookSnapshot> {
638    Some(OrderbookSnapshot {
639        market_id,
640        bids: parse_levels(data.get("bids")?),
641        asks: parse_levels(data.get("asks")?),
642        version: data.get("version")?.as_u64().unwrap_or(0),
643        update_timestamp_ms: data
644            .get("updateTimestampMs")
645            .and_then(|v| v.as_u64())
646            .unwrap_or(0),
647        order_count: data
648            .get("orderCount")
649            .and_then(|v| v.as_u64())
650            .unwrap_or(0),
651        last_order_settled: data
652            .get("lastOrderSettled")
653            .and_then(|v| serde_json::from_value(v.clone()).ok()),
654    })
655}
656
657fn parse_asset_price(feed_id: i64, data: &Value) -> Option<AssetPriceUpdate> {
658    Some(AssetPriceUpdate {
659        feed_id,
660        price: data.get("price")?.as_f64()?,
661        publish_time: data.get("publishTime").and_then(|v| v.as_u64()).unwrap_or(0),
662        timestamp: data.get("timestamp").and_then(|v| v.as_u64()).unwrap_or(0),
663    })
664}
665
666// ── Known asset feed IDs ──
667
/// Known asset price feed IDs for `Topic::AssetPrice`.
pub mod feeds {
    /// Bitcoin price feed.
    pub const BTC: i64 = 1;
    /// BNB price feed.
    /// Tentative — needs confirmation from predict.fun team.
    pub const BNB: i64 = 2;
    /// Ether price feed.
    pub const ETH: i64 = 4;
}
675
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Tolerance for comparing derived orderbook prices.
    const EPS: f64 = 1e-10;

    /// True when `actual` is within `tol` of `expected`.
    fn approx(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn topic_roundtrip() {
        let cases = [
            Topic::Orderbook { market_id: 123 },
            Topic::AssetPrice { feed_id: 1 },
            Topic::PolymarketChance { market_id: 456 },
            Topic::KalshiChance { market_id: 789 },
            Topic::WalletEvents {
                jwt: "abc123".to_string(),
            },
            Topic::Raw("custom/topic".to_string()),
        ];
        for topic in &cases {
            let wire = topic.to_topic_string();
            assert_eq!(*topic, Topic::from_topic_string(&wire), "Roundtrip failed: {}", wire);
        }
    }

    #[test]
    fn topic_display() {
        let cases = [
            (Topic::Orderbook { market_id: 42 }, "predictOrderbook/42"),
            (Topic::AssetPrice { feed_id: 1 }, "assetPriceUpdate/1"),
        ];
        for (topic, expected) in &cases {
            assert_eq!(topic.to_string(), *expected);
        }
    }

    #[test]
    fn parse_orderbook_snapshot() {
        let payload = json!({
            "asks": [[0.72, 15.0], [0.83, 5.88]],
            "bids": [[0.57, 15.0], [0.38, 2.63]],
            "version": 1,
            "updateTimestampMs": 1772898630219u64,
            "orderCount": 13,
            "lastOrderSettled": {
                "id": "20035648", "kind": "LIMIT", "marketId": 45532,
                "outcome": "No", "price": "0.60", "side": "Bid"
            }
        });

        let ob = parse_orderbook(45532, &payload).unwrap();
        assert_eq!(ob.market_id, 45532);
        assert_eq!((ob.bids.len(), ob.asks.len()), (2, 2));
        assert!(approx(ob.best_bid().unwrap(), 0.57, EPS));
        assert!(approx(ob.best_ask().unwrap(), 0.72, EPS));
        assert!(approx(ob.mid().unwrap(), 0.645, EPS));
        assert!(approx(ob.spread().unwrap(), 0.15, EPS));
        assert_eq!(ob.version, 1);
        assert_eq!(ob.order_count, 13);
        assert!(ob.last_order_settled.is_some());
    }

    #[test]
    fn parse_asset_price_update() {
        let payload = json!({
            "price": 67853.57751504,
            "publishTime": 1772898632u64,
            "timestamp": 1772898633u64
        });
        let update = parse_asset_price(1, &payload).unwrap();
        assert_eq!(update.feed_id, 1);
        assert!(approx(update.price, 67853.577, 1.0));
    }

    #[test]
    fn parse_push_dispatches() {
        let orderbook = parse_push_message(
            "predictOrderbook/123",
            json!({"asks": [], "bids": [], "version": 1, "updateTimestampMs": 0, "orderCount": 0}),
        );
        assert!(matches!(orderbook, PredictWsMessage::Orderbook(_)));

        let price = parse_push_message(
            "assetPriceUpdate/1",
            json!({"price": 100.0, "publishTime": 0, "timestamp": 0}),
        );
        assert!(matches!(price, PredictWsMessage::AssetPrice(_)));

        let poly = parse_push_message("polymarketChance/456", json!({"chance": 0.5}));
        assert!(matches!(poly, PredictWsMessage::CrossVenueChance(_)));

        let kalshi = parse_push_message("kalshiChance/789", json!({"chance": 0.3}));
        assert!(matches!(kalshi, PredictWsMessage::CrossVenueChance(_)));

        let wallet = parse_push_message("predictWalletEvents/jwt123", json!({"event": "fill"}));
        assert!(matches!(wallet, PredictWsMessage::WalletEvent(_)));

        let unknown = parse_push_message("unknown/topic", json!({"foo": "bar"}));
        assert!(matches!(unknown, PredictWsMessage::Raw { .. }));
    }

    #[test]
    fn orderbook_helpers_empty() {
        let ob = OrderbookSnapshot {
            market_id: 1,
            bids: Vec::new(),
            asks: Vec::new(),
            version: 0,
            update_timestamp_ms: 0,
            order_count: 0,
            last_order_settled: None,
        };
        assert!(ob.best_bid().is_none());
        assert!(ob.mid().is_none());
        assert!(ob.spread().is_none());
    }

    #[test]
    fn feed_id_constants() {
        assert_eq!([feeds::BTC, feeds::ETH, feeds::BNB], [1, 4, 2]);
    }

    #[test]
    fn config_defaults() {
        let cfg = PredictWsConfig::default();
        assert_eq!(cfg.url, PREDICT_WS_MAINNET);
        assert_eq!(cfg.channel_buffer, 1024);
        assert_eq!(cfg.heartbeat_timeout_secs, 60);
    }

    #[test]
    fn config_testnet() {
        let cfg = PredictWsConfig::testnet();
        assert_eq!(cfg.url, PREDICT_WS_TESTNET);
    }
}