Skip to main content

predict_fun_sdk/
ws.rs

1//! Predict.fun WebSocket client for real-time market data.
2//!
3//! Connects to `wss://ws.predict.fun/ws` and provides:
4//! - Orderbook snapshots (`predictOrderbook/{marketId}`)
5//! - Asset price updates (`assetPriceUpdate/{feedId}`)
6//! - Cross-venue chance data (`polymarketChance/{marketId}`, `kalshiChance/{marketId}`)
7//! - Wallet event notifications (`predictWalletEvents/{jwt}`)
8//!
9//! # Protocol
10//!
11//! Custom JSON RPC over WebSocket:
12//! ```text
13//! → {"requestId": 0, "method": "subscribe", "params": ["predictOrderbook/123"]}
14//! ← {"type": "R", "requestId": 0, "success": true}
15//! ← {"type": "M", "topic": "predictOrderbook/123", "data": {...}}
16//! ```
17//!
18//! Message types: `R` = Response (to a request), `M` = Push message (subscription data).
19//!
20//! # Example
21//!
22//! ```rust,no_run
23//! use predict_fun_sdk::ws::{PredictWsClient, PredictWsMessage, Topic};
24//!
25//! # async fn example() -> anyhow::Result<()> {
26//! let (client, mut rx) = PredictWsClient::connect_mainnet().await?;
27//!
28//! client.subscribe(Topic::Orderbook { market_id: 45532 }).await?;
29//! client.subscribe(Topic::AssetPrice { feed_id: 1 }).await?;
30//!
31//! while let Some(msg) = rx.recv().await {
32//!     match msg {
33//!         PredictWsMessage::Orderbook(ob) => {
34//!             println!("OB market={}: {} bids, {} asks",
35//!                 ob.market_id, ob.bids.len(), ob.asks.len());
36//!         }
37//!         PredictWsMessage::AssetPrice(p) => {
38//!             println!("Price feed {}: ${:.2}", p.feed_id, p.price);
39//!         }
40//!         _ => {}
41//!     }
42//! }
43//! # Ok(())
44//! # }
45//! ```
46
47use std::collections::HashMap;
48use std::sync::atomic::{AtomicU64, Ordering};
49use std::sync::Arc;
50
51use anyhow::{anyhow, Context, Result};
52use futures_util::{SinkExt, StreamExt};
53use serde::{Deserialize, Serialize};
54use serde_json::Value;
55use tokio::sync::{mpsc, oneshot, Mutex};
56use tokio::time::{self, Duration};
57use tokio_tungstenite::tungstenite::Message as WsMessage;
58use tracing::{debug, error, info, warn};
59
/// WebSocket endpoints.
///
/// Mainnet endpoint; this is the URL `PredictWsConfig::default()` connects to.
pub const PREDICT_WS_MAINNET: &str = "wss://ws.predict.fun/ws";
/// Testnet endpoint; this is the URL `PredictWsConfig::testnet()` connects to.
pub const PREDICT_WS_TESTNET: &str = "wss://ws.bnb.predict.fail/ws";

/// GraphQL endpoints (for reference / future use).
///
/// Mainnet GraphQL URL. Not referenced by the WebSocket code visible in this file.
pub const PREDICT_GQL_MAINNET: &str = "https://graphql.predict.fun/graphql";
/// Testnet GraphQL URL. Not referenced by the WebSocket code visible in this file.
pub const PREDICT_GQL_TESTNET: &str = "https://graphql.bnb.predict.fail/graphql";
67
68// ── Topic types ──
69
/// A channel on the predict.fun WebSocket feed that a client can subscribe to.
///
/// Each variant corresponds to a wire-format string of the shape
/// `name/argument`; use [`Topic::to_topic_string`] to produce it and
/// [`Topic::from_topic_string`] to go back.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Topic {
    /// Orderbook snapshots for one market.
    /// Wire format: `predictOrderbook/{market_id}`
    Orderbook { market_id: i64 },

    /// Oracle price updates for one asset feed.
    /// Wire format: `assetPriceUpdate/{feed_id}`
    /// Known feeds: 1=BTC, 4=ETH
    AssetPrice { feed_id: i64 },

    /// Polymarket chance/probability for a market (cross-venue reference).
    /// Wire format: `polymarketChance/{market_id}`
    PolymarketChance { market_id: i64 },

    /// Kalshi chance/probability for a market (cross-venue reference).
    /// Wire format: `kalshiChance/{market_id}`
    KalshiChance { market_id: i64 },

    /// Wallet events (fills, settlements). The JWT auth token is part of the
    /// topic string itself.
    /// Wire format: `predictWalletEvents/{jwt}`
    WalletEvents { jwt: String },

    /// Any other topic string, passed through verbatim.
    Raw(String),
}

impl Topic {
    /// Render this topic as the string sent in `subscribe`/`unsubscribe` params.
    pub fn to_topic_string(&self) -> String {
        // All numeric topics share the `name/id` layout.
        let numbered = |name: &str, id: &i64| [name, "/", id.to_string().as_str()].concat();

        match self {
            Self::Orderbook { market_id } => numbered("predictOrderbook", market_id),
            Self::AssetPrice { feed_id } => numbered("assetPriceUpdate", feed_id),
            Self::PolymarketChance { market_id } => numbered("polymarketChance", market_id),
            Self::KalshiChance { market_id } => numbered("kalshiChance", market_id),
            Self::WalletEvents { jwt } => ["predictWalletEvents/", jwt.as_str()].concat(),
            Self::Raw(raw) => raw.to_owned(),
        }
    }

    /// Interpret a wire-format topic string.
    ///
    /// Never fails: a string with an unknown name, or a known numeric name
    /// whose argument is not a valid `i64`, comes back as [`Topic::Raw`]
    /// holding the input unchanged.
    pub fn from_topic_string(s: &str) -> Self {
        // None of the known names contains '/', so splitting at the first one
        // is the same as stripping the `name/` prefix.
        let (name, arg) = match s.split_once('/') {
            Some(parts) => parts,
            None => return Self::Raw(s.to_owned()),
        };

        // The wallet topic takes everything after the slash, unparsed.
        if name == "predictWalletEvents" {
            return Self::WalletEvents {
                jwt: arg.to_owned(),
            };
        }

        match (name, arg.parse::<i64>().ok()) {
            ("predictOrderbook", Some(market_id)) => Self::Orderbook { market_id },
            ("assetPriceUpdate", Some(feed_id)) => Self::AssetPrice { feed_id },
            ("polymarketChance", Some(market_id)) => Self::PolymarketChance { market_id },
            ("kalshiChance", Some(market_id)) => Self::KalshiChance { market_id },
            _ => Self::Raw(s.to_owned()),
        }
    }
}

/// Displays the topic exactly as its wire-format string.
impl std::fmt::Display for Topic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.to_topic_string())
    }
}
147
148// ── Parsed message types ──
149
/// Orderbook level: `[price, size]`.
///
/// Element `.0` is the price and `.1` is the size, in the same order as the
/// two-element JSON array each level is parsed from.
pub type Level = (f64, f64);
152
/// Last settled order info from the orderbook snapshot.
///
/// Deserialized from the `lastOrderSettled` object of an orderbook push
/// message. All fields are required; if any is missing or has the wrong JSON
/// type, deserialization fails and the snapshot carries `None` instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LastOrderSettled {
    /// Order identifier, delivered as a string.
    pub id: String,
    /// Order kind as sent by the server (the test fixture in this file uses `"LIMIT"`).
    pub kind: String,
    /// Market the order belongs to (wire name `marketId`).
    #[serde(rename = "marketId")]
    pub market_id: i64,
    /// Outcome label (the test fixture in this file uses `"No"`).
    pub outcome: String,
    /// Price as a decimal string (e.g. `"0.60"`); not parsed into a number here.
    pub price: String,
    /// Order side as sent by the server (the test fixture in this file uses `"Bid"`).
    pub side: String,
}
164
/// Full orderbook snapshot pushed on every change.
///
/// Built by `parse_orderbook` from a `predictOrderbook/{market_id}` push message.
#[derive(Debug, Clone)]
pub struct OrderbookSnapshot {
    /// Market id, taken from the topic string rather than the payload.
    pub market_id: i64,
    /// Bid levels in the order the server sent them; `best_bid` treats the first as best.
    pub bids: Vec<Level>,
    /// Ask levels in the order the server sent them; `best_ask` treats the first as best.
    pub asks: Vec<Level>,
    /// Payload `version`; 0 when the field is absent or not an unsigned integer.
    pub version: u64,
    /// Payload `updateTimestampMs`; 0 when absent or not an unsigned integer.
    pub update_timestamp_ms: u64,
    /// Payload `orderCount`; 0 when absent or not an unsigned integer.
    pub order_count: u64,
    /// Payload `lastOrderSettled`, or `None` when absent or not deserializable.
    pub last_order_settled: Option<LastOrderSettled>,
}
176
177impl OrderbookSnapshot {
178    /// Best bid price, or None if empty.
179    pub fn best_bid(&self) -> Option<f64> {
180        self.bids.first().map(|(p, _)| *p)
181    }
182
183    /// Best ask price, or None if empty.
184    pub fn best_ask(&self) -> Option<f64> {
185        self.asks.first().map(|(p, _)| *p)
186    }
187
188    /// Midpoint price, or None if either side is empty.
189    pub fn mid(&self) -> Option<f64> {
190        match (self.best_bid(), self.best_ask()) {
191            (Some(b), Some(a)) => Some((b + a) / 2.0),
192            _ => None,
193        }
194    }
195
196    /// Spread (ask - bid), or None if either side is empty.
197    pub fn spread(&self) -> Option<f64> {
198        match (self.best_bid(), self.best_ask()) {
199            (Some(b), Some(a)) => Some(a - b),
200            _ => None,
201        }
202    }
203}
204
/// Asset price update from oracle feed.
///
/// Built by `parse_asset_price` from an `assetPriceUpdate/{feed_id}` push message.
#[derive(Debug, Clone)]
pub struct AssetPriceUpdate {
    /// Feed id, taken from the topic string.
    pub feed_id: i64,
    /// Payload `price`; the message is not parsed as a price update without it.
    pub price: f64,
    /// Payload `publishTime`; 0 when absent. Unit not established here
    /// (the test fixture looks like Unix seconds — TODO confirm).
    pub publish_time: u64,
    /// Payload `timestamp`; 0 when absent. Unit not established here
    /// (the test fixture looks like Unix seconds — TODO confirm).
    pub timestamp: u64,
}
213
/// Cross-venue chance data (Polymarket or Kalshi).
///
/// The payload schema is not modelled; `data` carries the raw JSON as received.
#[derive(Debug, Clone)]
pub struct CrossVenueChance {
    /// Which venue the data came from, derived from the topic name.
    pub source: CrossVenueSource,
    /// Market id parsed from the topic string.
    pub market_id: i64,
    /// Unparsed message payload.
    pub data: Value,
}
221
/// Source of cross-venue data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossVenueSource {
    /// Data delivered on a `polymarketChance/{marketId}` topic.
    Polymarket,
    /// Data delivered on a `kalshiChance/{marketId}` topic.
    Kalshi,
}
228
/// Wallet event notification.
///
/// Produced for any push message on a `predictWalletEvents/…` topic. The event
/// schema is not modelled; the JWT from the topic is not retained.
#[derive(Debug, Clone)]
pub struct WalletEvent {
    /// Unparsed message payload.
    pub data: Value,
}
234
/// Parsed WebSocket message from predict.fun.
///
/// This is the item type of the receiver returned by `PredictWsClient::connect`.
/// Request acknowledgements and heartbeats are handled internally and never
/// appear here.
#[derive(Debug, Clone)]
pub enum PredictWsMessage {
    /// Full orderbook snapshot.
    Orderbook(OrderbookSnapshot),
    /// Asset price update from oracle.
    AssetPrice(AssetPriceUpdate),
    /// Cross-venue chance data.
    CrossVenueChance(CrossVenueChance),
    /// Wallet event (fills, settlements).
    WalletEvent(WalletEvent),
    /// Unparsed push message (unknown topic or parse failure).
    /// `topic` is the wire-format topic string and `data` the raw payload.
    Raw { topic: String, data: Value },
}
249
250// ── Wire protocol types ──
251
/// Outgoing request frame, serialized to JSON and sent as a text message.
///
/// `params` and `data` are omitted from the JSON entirely when `None`.
#[derive(Serialize)]
struct WsRequest {
    /// Client-chosen id (wire name `requestId`) echoed back in the server's `R` response.
    #[serde(rename = "requestId")]
    request_id: u64,
    /// Method name; this file sends `subscribe`, `unsubscribe` and `heartbeat`.
    method: String,
    /// Topic strings for subscribe/unsubscribe requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Vec<String>>,
    /// Payload for heartbeat replies (the data of the server heartbeat, echoed back).
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<Value>,
}
262
/// Incoming frame as received from the server, before interpretation.
///
/// Only `type` is required; which of the optional fields are present depends
/// on the message type (`R` = response to a request, `M` = push message).
#[derive(Deserialize)]
struct WsRawMessage {
    /// Message type discriminator (wire name `type`).
    #[serde(rename = "type")]
    msg_type: String,
    /// For `R` messages: the id of the request being answered (wire name `requestId`).
    #[serde(rename = "requestId")]
    request_id: Option<i64>,
    /// For `R` messages: whether the request succeeded; a missing value is treated as failure.
    success: Option<bool>,
    /// For failed `R` messages: error details, if the server supplied any.
    error: Option<WsError>,
    /// For `M` messages: the topic the push belongs to (`heartbeat` for heartbeats).
    topic: Option<String>,
    /// For `M` messages: the payload.
    data: Option<Value>,
}
274
275#[derive(Deserialize, Debug, Clone)]
276struct WsError {
277    code: String,
278    message: Option<String>,
279}
280
281impl std::fmt::Display for WsError {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        write!(f, "{}", self.code)?;
284        if let Some(msg) = &self.message {
285            write!(f, ": {}", msg)?;
286        }
287        Ok(())
288    }
289}
290
291// ── Client ──
292
/// Write half of the split WebSocket stream; the client sends every outgoing
/// frame through it.
type WsSink = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >,
    WsMessage,
>;
299
/// Pending subscribe/unsubscribe response channel.
///
/// The caller awaiting the acknowledgement holds the receiving end; the read
/// loop sends `Ok(())` or the server-reported error when the matching `R`
/// message arrives.
type PendingResponse = oneshot::Sender<Result<()>>;
302
/// Configuration for the WebSocket client.
#[derive(Debug, Clone)]
pub struct PredictWsConfig {
    /// WebSocket URL. Defaults to mainnet.
    pub url: String,
    /// Capacity of the channel that delivers parsed incoming messages to the
    /// receiver returned by `connect`. When it is full, new messages are
    /// dropped rather than queued. Default: 1024.
    pub channel_buffer: usize,
    /// Heartbeat timeout in seconds. Default: 60.
    pub heartbeat_timeout_secs: u64,
    /// Maximum reconnection attempts. 0 = infinite. Default: 0.
    pub max_reconnect_attempts: u32,
    /// Maximum reconnection backoff in seconds. Default: 15.
    pub max_reconnect_backoff_secs: u64,
}
317
318impl Default for PredictWsConfig {
319    fn default() -> Self {
320        Self {
321            url: PREDICT_WS_MAINNET.to_string(),
322            channel_buffer: 1024,
323            heartbeat_timeout_secs: 60,
324            max_reconnect_attempts: 0,
325            max_reconnect_backoff_secs: 15,
326        }
327    }
328}
329
330impl PredictWsConfig {
331    pub fn mainnet() -> Self {
332        Self::default()
333    }
334
335    pub fn testnet() -> Self {
336        Self {
337            url: PREDICT_WS_TESTNET.to_string(),
338            ..Self::default()
339        }
340    }
341}
342
/// Handle for interacting with the WebSocket connection.
///
/// Subscribe/unsubscribe to topics. Messages are received on the `mpsc::Receiver`
/// returned from [`PredictWsClient::connect`].
///
/// Cloning is cheap: all state is behind `Arc`, so every clone (including the
/// one owned by the background read loop) talks to the same connection.
#[derive(Clone)]
pub struct PredictWsClient {
    /// Write half of the socket; swapped for a fresh one after a reconnect.
    sink: Arc<Mutex<WsSink>>,
    /// Source of request ids, shared by subscribe, unsubscribe and heartbeat replies.
    request_id: Arc<AtomicU64>,
    /// Requests awaiting an `R` response, keyed by request id.
    pending: Arc<Mutex<HashMap<u64, PendingResponse>>>,
    /// Wire-format strings of acknowledged subscriptions, replayed after a reconnect.
    active_topics: Arc<Mutex<Vec<String>>>,
    /// Settings the connection was created with.
    config: PredictWsConfig,
}
355
356impl PredictWsClient {
357    /// Connect to mainnet WebSocket and return (client_handle, message_receiver).
358    pub async fn connect_mainnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
359        Self::connect(PredictWsConfig::mainnet()).await
360    }
361
362    /// Connect to testnet WebSocket and return (client_handle, message_receiver).
363    pub async fn connect_testnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
364        Self::connect(PredictWsConfig::testnet()).await
365    }
366
367    /// Connect with custom configuration.
368    pub async fn connect(
369        config: PredictWsConfig,
370    ) -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
371        let (ws_stream, _) = tokio_tungstenite::connect_async(&config.url)
372            .await
373            .with_context(|| format!("failed to connect to {}", config.url))?;
374
375        info!("Connected to {}", config.url);
376
377        let (sink, stream) = ws_stream.split();
378        let (tx, rx) = mpsc::channel(config.channel_buffer);
379
380        let client = Self {
381            sink: Arc::new(Mutex::new(sink)),
382            request_id: Arc::new(AtomicU64::new(0)),
383            pending: Arc::new(Mutex::new(HashMap::new())),
384            active_topics: Arc::new(Mutex::new(Vec::new())),
385            config: config.clone(),
386        };
387
388        // Spawn the read loop
389        let client_clone = client.clone();
390        tokio::spawn(async move {
391            client_clone.read_loop(stream, tx).await
392        });
393
394        Ok((client, rx))
395    }
396
397    /// Subscribe to a topic. Returns Ok(()) when the server acknowledges.
398    pub async fn subscribe(&self, topic: Topic) -> Result<()> {
399        let topic_str = topic.to_topic_string();
400        let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
401
402        let (resp_tx, resp_rx) = oneshot::channel();
403        {
404            let mut pending = self.pending.lock().await;
405            pending.insert(request_id, resp_tx);
406        }
407
408        let msg = WsRequest {
409            request_id,
410            method: "subscribe".to_string(),
411            params: Some(vec![topic_str.clone()]),
412            data: None,
413        };
414
415        self.send_raw(&msg).await?;
416        debug!("Subscribing to {} (requestId={})", topic_str, request_id);
417
418        // Wait for server response
419        let result = tokio::time::timeout(Duration::from_secs(10), resp_rx)
420            .await
421            .map_err(|_| anyhow!("subscribe timeout for {}", topic_str))?
422            .map_err(|_| anyhow!("subscribe channel closed for {}", topic_str))??;
423
424        // Track active topic for reconnection
425        {
426            let mut topics = self.active_topics.lock().await;
427            if !topics.contains(&topic_str) {
428                topics.push(topic_str.clone());
429            }
430        }
431
432        info!("Subscribed to {}", topic_str);
433        Ok(result)
434    }
435
436    /// Unsubscribe from a topic.
437    pub async fn unsubscribe(&self, topic: Topic) -> Result<()> {
438        let topic_str = topic.to_topic_string();
439        let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
440
441        let (resp_tx, resp_rx) = oneshot::channel();
442        {
443            let mut pending = self.pending.lock().await;
444            pending.insert(request_id, resp_tx);
445        }
446
447        let msg = WsRequest {
448            request_id,
449            method: "unsubscribe".to_string(),
450            params: Some(vec![topic_str.clone()]),
451            data: None,
452        };
453
454        self.send_raw(&msg).await?;
455
456        tokio::time::timeout(Duration::from_secs(10), resp_rx)
457            .await
458            .map_err(|_| anyhow!("unsubscribe timeout for {}", topic_str))?
459            .map_err(|_| anyhow!("unsubscribe channel closed for {}", topic_str))??;
460
461        // Remove from active topics
462        {
463            let mut topics = self.active_topics.lock().await;
464            topics.retain(|t| t != &topic_str);
465        }
466
467        info!("Unsubscribed from {}", topic_str);
468        Ok(())
469    }
470
471    /// Get list of currently subscribed topics.
472    pub async fn active_topics(&self) -> Vec<String> {
473        self.active_topics.lock().await.clone()
474    }
475
476    /// Send a raw heartbeat response (called internally by the read loop).
477    async fn send_heartbeat(&self, data: &Value) -> Result<()> {
478        let msg = WsRequest {
479            request_id: self.request_id.fetch_add(1, Ordering::Relaxed),
480            method: "heartbeat".to_string(),
481            params: None,
482            data: Some(data.clone()),
483        };
484        self.send_raw(&msg).await
485    }
486
487    async fn send_raw(&self, msg: &WsRequest) -> Result<()> {
488        let text = serde_json::to_string(msg).context("failed to serialize WS message")?;
489        let mut sink = self.sink.lock().await;
490        sink.send(WsMessage::Text(text))
491            .await
492            .context("failed to send WS message")?;
493        Ok(())
494    }
495
496    /// Main read loop — processes incoming messages, dispatches to channel.
497    fn read_loop(
498        &self,
499        mut stream: futures_util::stream::SplitStream<
500            tokio_tungstenite::WebSocketStream<
501                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
502            >,
503        >,
504        tx: mpsc::Sender<PredictWsMessage>,
505    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
506        Box::pin(async move {
507        let heartbeat_timeout = Duration::from_secs(self.config.heartbeat_timeout_secs);
508        let mut last_heartbeat = time::Instant::now();
509
510        loop {
511            tokio::select! {
512                msg = stream.next() => {
513                    match msg {
514                        Some(Ok(WsMessage::Text(text))) => {
515                            match serde_json::from_str::<WsRawMessage>(&text) {
516                                Ok(raw) => self.handle_message(raw, &tx, &mut last_heartbeat).await,
517                                Err(e) => warn!("Failed to parse WS message: {} — raw: {}", e, &text[..text.len().min(200)]),
518                            }
519                        }
520                        Some(Ok(WsMessage::Ping(data))) => {
521                            let mut sink = self.sink.lock().await;
522                            let _ = sink.send(WsMessage::Pong(data)).await;
523                        }
524                        Some(Ok(WsMessage::Close(frame))) => {
525                            info!("WebSocket closed by server: {:?}", frame);
526                            break;
527                        }
528                        Some(Err(e)) => {
529                            error!("WebSocket error: {}", e);
530                            break;
531                        }
532                        None => {
533                            info!("WebSocket stream ended");
534                            break;
535                        }
536                        _ => {} // Binary, Pong, Frame — ignore
537                    }
538                }
539                _ = time::sleep(heartbeat_timeout) => {
540                    if last_heartbeat.elapsed() > heartbeat_timeout {
541                        warn!("Heartbeat timeout ({}s), closing connection",
542                            self.config.heartbeat_timeout_secs);
543                        break;
544                    }
545                }
546            }
547        }
548
549        // Attempt reconnection
550        self.try_reconnect(tx).await;
551        }) // end Box::pin
552    }
553
554    async fn handle_message(
555        &self,
556        raw: WsRawMessage,
557        tx: &mpsc::Sender<PredictWsMessage>,
558        last_heartbeat: &mut time::Instant,
559    ) {
560        match raw.msg_type.as_str() {
561            // Response to a request (subscribe/unsubscribe)
562            "R" => {
563                if let Some(req_id) = raw.request_id {
564                    let mut pending = self.pending.lock().await;
565                    if let Some(resp_tx) = pending.remove(&(req_id as u64)) {
566                        let result = if raw.success.unwrap_or(false) {
567                            Ok(())
568                        } else {
569                            let err_msg = raw
570                                .error
571                                .map(|e| e.to_string())
572                                .unwrap_or_else(|| "unknown error".to_string());
573                            Err(anyhow!("subscribe failed: {}", err_msg))
574                        };
575                        let _ = resp_tx.send(result);
576                    }
577                }
578            }
579            // Push message (subscription data)
580            "M" => {
581                let topic_str = match &raw.topic {
582                    Some(t) => t.as_str(),
583                    None => return,
584                };
585
586                // Handle heartbeat
587                if topic_str == "heartbeat" {
588                    *last_heartbeat = time::Instant::now();
589                    if let Some(data) = &raw.data {
590                        if let Err(e) = self.send_heartbeat(data).await {
591                            warn!("Failed to send heartbeat response: {}", e);
592                        }
593                    }
594                    return;
595                }
596
597                // Parse and dispatch
598                let data = match raw.data {
599                    Some(d) => d,
600                    None => return,
601                };
602
603                let parsed = parse_push_message(topic_str, data);
604                if tx.try_send(parsed).is_err() {
605                    warn!("Message channel full, dropping message for topic {}", topic_str);
606                }
607            }
608            other => {
609                debug!("Unknown WS message type: {}", other);
610            }
611        }
612    }
613
614    fn try_reconnect(
615        &self,
616        tx: mpsc::Sender<PredictWsMessage>,
617    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
618        Box::pin(async move {
619        let max_attempts = self.config.max_reconnect_attempts;
620        let max_backoff = self.config.max_reconnect_backoff_secs;
621        let mut attempt = 0u32;
622
623        loop {
624            if max_attempts > 0 && attempt >= max_attempts {
625                error!(
626                    "Max reconnection attempts ({}) reached, giving up",
627                    max_attempts
628                );
629                return;
630            }
631
632            let backoff = Duration::from_secs((2u64.pow(attempt.min(10))).min(max_backoff));
633            warn!(
634                "Reconnecting in {:?} (attempt {})",
635                backoff,
636                attempt + 1
637            );
638            time::sleep(backoff).await;
639            attempt += 1;
640
641            match tokio_tungstenite::connect_async(&self.config.url).await {
642                Ok((ws_stream, _)) => {
643                    info!("Reconnected to {}", self.config.url);
644                    let (new_sink, new_stream) = ws_stream.split();
645
646                    // Replace sink
647                    {
648                        let mut sink = self.sink.lock().await;
649                        *sink = new_sink;
650                    }
651
652                    // Resubscribe to all active topics
653                    let topics = self.active_topics.lock().await.clone();
654                    for topic_str in &topics {
655                        let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
656                        let msg = WsRequest {
657                            request_id,
658                            method: "subscribe".to_string(),
659                            params: Some(vec![topic_str.clone()]),
660                            data: None,
661                        };
662                        if let Err(e) = self.send_raw(&msg).await {
663                            warn!("Failed to resubscribe to {}: {}", topic_str, e);
664                        } else {
665                            debug!("Resubscribed to {}", topic_str);
666                        }
667                    }
668
669                    // Restart read loop
670                    self.read_loop(new_stream, tx).await;
671                    return;
672                }
673                Err(e) => {
674                    error!("Reconnection failed: {}", e);
675                }
676            }
677        }
678        }) // end Box::pin
679    }
680}
681
682// ── Message parsing ──
683
684fn parse_push_message(topic: &str, data: Value) -> PredictWsMessage {
685    if topic.starts_with("predictOrderbook/") {
686        if let Some(ob) = parse_orderbook(topic, &data) {
687            return PredictWsMessage::Orderbook(ob);
688        }
689    }
690
691    if topic.starts_with("assetPriceUpdate/") {
692        if let Some(price) = parse_asset_price(topic, &data) {
693            return PredictWsMessage::AssetPrice(price);
694        }
695    }
696
697    if topic.starts_with("polymarketChance/") {
698        if let Ok(id) = topic.strip_prefix("polymarketChance/").unwrap_or("").parse::<i64>() {
699            return PredictWsMessage::CrossVenueChance(CrossVenueChance {
700                source: CrossVenueSource::Polymarket,
701                market_id: id,
702                data,
703            });
704        }
705    }
706
707    if topic.starts_with("kalshiChance/") {
708        if let Ok(id) = topic.strip_prefix("kalshiChance/").unwrap_or("").parse::<i64>() {
709            return PredictWsMessage::CrossVenueChance(CrossVenueChance {
710                source: CrossVenueSource::Kalshi,
711                market_id: id,
712                data,
713            });
714        }
715    }
716
717    if topic.starts_with("predictWalletEvents/") {
718        return PredictWsMessage::WalletEvent(WalletEvent { data });
719    }
720
721    PredictWsMessage::Raw {
722        topic: topic.to_string(),
723        data,
724    }
725}
726
727fn parse_levels(val: &Value) -> Vec<Level> {
728    val.as_array()
729        .map(|arr| {
730            arr.iter()
731                .filter_map(|lvl| {
732                    let price = lvl.get(0).and_then(|v| v.as_f64())?;
733                    let size = lvl.get(1).and_then(|v| v.as_f64())?;
734                    Some((price, size))
735                })
736                .collect()
737        })
738        .unwrap_or_default()
739}
740
741fn parse_orderbook(topic: &str, data: &Value) -> Option<OrderbookSnapshot> {
742    let market_id = topic
743        .strip_prefix("predictOrderbook/")?
744        .parse::<i64>()
745        .ok()?;
746
747    let bids = parse_levels(data.get("bids")?);
748    let asks = parse_levels(data.get("asks")?);
749    let version = data.get("version").and_then(|v| v.as_u64()).unwrap_or(0);
750    let update_timestamp_ms = data
751        .get("updateTimestampMs")
752        .and_then(|v| v.as_u64())
753        .unwrap_or(0);
754    let order_count = data
755        .get("orderCount")
756        .and_then(|v| v.as_u64())
757        .unwrap_or(0);
758    let last_order_settled = data
759        .get("lastOrderSettled")
760        .and_then(|v| serde_json::from_value(v.clone()).ok());
761
762    Some(OrderbookSnapshot {
763        market_id,
764        bids,
765        asks,
766        version,
767        update_timestamp_ms,
768        order_count,
769        last_order_settled,
770    })
771}
772
773fn parse_asset_price(topic: &str, data: &Value) -> Option<AssetPriceUpdate> {
774    let feed_id = topic
775        .strip_prefix("assetPriceUpdate/")?
776        .parse::<i64>()
777        .ok()?;
778
779    let price = data.get("price").and_then(|v| v.as_f64())?;
780    let publish_time = data.get("publishTime").and_then(|v| v.as_u64()).unwrap_or(0);
781    let timestamp = data.get("timestamp").and_then(|v| v.as_u64()).unwrap_or(0);
782
783    Some(AssetPriceUpdate {
784        feed_id,
785        price,
786        publish_time,
787        timestamp,
788    })
789}
790
791// ── Known asset feed IDs ──
792
/// Known asset price feed IDs. Use with `Topic::AssetPrice { feed_id }`.
///
/// The values are plain `i64`s, matching the type of `feed_id`.
pub mod feeds {
    /// Bitcoin price feed.
    pub const BTC: i64 = 1;
    /// Ethereum price feed.
    pub const ETH: i64 = 4;
    /// BNB price feed (tentative — needs confirmation).
    pub const BNB: i64 = 2;
}
802
// Unit tests for the pure parts of this module: topic string conversion,
// payload parsing, snapshot helpers, constants and configuration defaults.
// Nothing here opens a network connection.
#[cfg(test)]
mod tests {
    use super::*;

    // Every topic variant survives to_topic_string -> from_topic_string unchanged.
    #[test]
    fn topic_roundtrip() {
        let topics = vec![
            Topic::Orderbook { market_id: 123 },
            Topic::AssetPrice { feed_id: 1 },
            Topic::PolymarketChance { market_id: 456 },
            Topic::KalshiChance { market_id: 789 },
            Topic::WalletEvents {
                jwt: "abc123".to_string(),
            },
            Topic::Raw("custom/topic".to_string()),
        ];

        for topic in topics {
            let s = topic.to_topic_string();
            let parsed = Topic::from_topic_string(&s);
            assert_eq!(topic, parsed, "Roundtrip failed for {}", s);
        }
    }

    // Display produces the wire-format string.
    #[test]
    fn topic_display() {
        assert_eq!(
            Topic::Orderbook { market_id: 42 }.to_string(),
            "predictOrderbook/42"
        );
        assert_eq!(
            Topic::AssetPrice { feed_id: 1 }.to_string(),
            "assetPriceUpdate/1"
        );
    }

    // A complete orderbook payload is parsed field by field, and the
    // best/mid/spread helpers read the first level of each side.
    #[test]
    fn parse_orderbook_snapshot() {
        let data = serde_json::json!({
            "asks": [[0.72, 15.0], [0.83, 5.88]],
            "bids": [[0.57, 15.0], [0.38, 2.63]],
            "marketId": 45532,
            "version": 1,
            "updateTimestampMs": 1772898630219u64,
            "orderCount": 13,
            "lastOrderSettled": {
                "id": "20035648",
                "kind": "LIMIT",
                "marketId": 45532,
                "outcome": "No",
                "price": "0.60",
                "side": "Bid"
            }
        });

        let ob = parse_orderbook("predictOrderbook/45532", &data).unwrap();
        assert_eq!(ob.market_id, 45532);
        assert_eq!(ob.bids.len(), 2);
        assert_eq!(ob.asks.len(), 2);
        assert!((ob.bids[0].0 - 0.57).abs() < 1e-10);
        assert!((ob.asks[0].0 - 0.72).abs() < 1e-10);
        assert_eq!(ob.version, 1);
        assert_eq!(ob.order_count, 13);
        assert!(ob.last_order_settled.is_some());
        assert!((ob.best_bid().unwrap() - 0.57).abs() < 1e-10);
        assert!((ob.best_ask().unwrap() - 0.72).abs() < 1e-10);
        assert!((ob.mid().unwrap() - 0.645).abs() < 1e-10);
        assert!((ob.spread().unwrap() - 0.15).abs() < 1e-10);
    }

    // A price payload yields the feed id from the topic and the three payload fields.
    #[test]
    fn parse_asset_price_update() {
        let data = serde_json::json!({
            "price": 67853.57751504,
            "publishTime": 1772898632u64,
            "timestamp": 1772898633u64
        });

        let price = parse_asset_price("assetPriceUpdate/1", &data).unwrap();
        assert_eq!(price.feed_id, 1);
        assert!((price.price - 67853.577).abs() < 1.0);
        assert_eq!(price.publish_time, 1772898632);
        assert_eq!(price.timestamp, 1772898633);
    }

    // Each topic family maps to its own message variant; anything else is Raw.
    #[test]
    fn parse_push_message_dispatches_correctly() {
        // Orderbook
        let ob_data = serde_json::json!({"asks": [], "bids": [], "version": 1, "updateTimestampMs": 0, "orderCount": 0});
        assert!(matches!(
            parse_push_message("predictOrderbook/123", ob_data),
            PredictWsMessage::Orderbook(_)
        ));

        // Asset price
        let price_data = serde_json::json!({"price": 100.0, "publishTime": 0, "timestamp": 0});
        assert!(matches!(
            parse_push_message("assetPriceUpdate/1", price_data),
            PredictWsMessage::AssetPrice(_)
        ));

        // Polymarket chance
        let chance_data = serde_json::json!({"chance": 0.5});
        assert!(matches!(
            parse_push_message("polymarketChance/456", chance_data),
            PredictWsMessage::CrossVenueChance(_)
        ));

        // Kalshi chance
        let kalshi_data = serde_json::json!({"chance": 0.3});
        assert!(matches!(
            parse_push_message("kalshiChance/789", kalshi_data),
            PredictWsMessage::CrossVenueChance(_)
        ));

        // Wallet events
        let wallet_data = serde_json::json!({"event": "fill"});
        assert!(matches!(
            parse_push_message("predictWalletEvents/jwt123", wallet_data),
            PredictWsMessage::WalletEvent(_)
        ));

        // Unknown
        let unknown_data = serde_json::json!({"foo": "bar"});
        assert!(matches!(
            parse_push_message("unknown/topic", unknown_data),
            PredictWsMessage::Raw { .. }
        ));
    }

    // With no levels on either side, every helper returns None.
    #[test]
    fn orderbook_snapshot_helpers_empty() {
        let ob = OrderbookSnapshot {
            market_id: 1,
            bids: vec![],
            asks: vec![],
            version: 0,
            update_timestamp_ms: 0,
            order_count: 0,
            last_order_settled: None,
        };
        assert!(ob.best_bid().is_none());
        assert!(ob.best_ask().is_none());
        assert!(ob.mid().is_none());
        assert!(ob.spread().is_none());
    }

    // Guards against accidental edits to the published feed ids.
    #[test]
    fn feed_id_constants() {
        assert_eq!(feeds::BTC, 1);
        assert_eq!(feeds::ETH, 4);
        assert_eq!(feeds::BNB, 2);
    }

    // Guards against accidental edits to the endpoint URLs.
    #[test]
    fn ws_endpoint_constants() {
        assert_eq!(PREDICT_WS_MAINNET, "wss://ws.predict.fun/ws");
        assert_eq!(PREDICT_WS_TESTNET, "wss://ws.bnb.predict.fail/ws");
    }

    // The default configuration targets mainnet with the documented defaults.
    #[test]
    fn config_defaults() {
        let config = PredictWsConfig::default();
        assert_eq!(config.url, PREDICT_WS_MAINNET);
        assert_eq!(config.channel_buffer, 1024);
        assert_eq!(config.heartbeat_timeout_secs, 60);
        assert_eq!(config.max_reconnect_attempts, 0);
    }

    // The testnet configuration only swaps the URL.
    #[test]
    fn config_testnet() {
        let config = PredictWsConfig::testnet();
        assert_eq!(config.url, PREDICT_WS_TESTNET);
    }
}