Skip to main content

predict_sdk/websocket/
client.rs

//! WebSocket client for Predict.fun real-time data
//!
//! Handles connection, subscriptions, heartbeat, and message parsing.
//! Uses `ws-reconnect-client` for low-level connection management.
5
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9use dashmap::DashSet;
10use futures_util::{SinkExt, StreamExt};
11use tokio::sync::Mutex;
12use tracing::{debug, error, info, warn};
13use ws_reconnect_client::{connect_with_retry, Message, WsConnectionConfig, WsReader, WsWriter};
14
15use super::types::{AssetPriceData, OrderbookData, PushMessage, RawWsMessage, WsMessage, WsRequest};
16use crate::api_types::PredictWalletEvent;
17use crate::errors::{Error, Result};
18
19/// WebSocket client for Predict.fun
20pub struct PredictWebSocket {
21    config: WsConnectionConfig,
22    subscribed_markets: DashSet<u64>,
23    writer: Arc<Mutex<Option<WsWriter>>>,
24    next_request_id: AtomicU64,
25}
26
27impl PredictWebSocket {
28    /// Create a new PredictWebSocket client
29    pub fn new(ws_url: String) -> Self {
30        // Configure for Predict.fun:
31        // - Disable automatic ping (Predict uses custom heartbeat)
32        // - Use reasonable retry settings
33        let config = WsConnectionConfig::new(ws_url)
34            .with_ping_interval(0) // Disable - we handle heartbeat manually
35            .with_retries(10)
36            .with_backoff(1000, 30_000);
37
38        Self {
39            config,
40            subscribed_markets: DashSet::new(),
41            writer: Arc::new(Mutex::new(None)),
42            next_request_id: AtomicU64::new(1),
43        }
44    }
45
46    /// Get the next unique request ID
47    fn next_id(&self) -> u64 {
48        self.next_request_id.fetch_add(1, Ordering::SeqCst)
49    }
50
51    /// Connect to the WebSocket and return a message stream
52    ///
53    /// The returned stream yields parsed WsMessage items.
54    /// Heartbeat messages are handled automatically.
55    pub async fn connect(&self) -> Result<PredictWsStream> {
56        info!("Connecting to Predict WebSocket: {}", self.config.url);
57
58        let (writer, reader) = connect_with_retry(&self.config)
59            .await
60            .map_err(|e| Error::Other(format!("WebSocket connection failed: {}", e)))?;
61
62        // Store the writer for sending messages
63        {
64            let mut w = self.writer.lock().await;
65            *w = Some(writer);
66        }
67
68        info!("Connected to Predict WebSocket");
69
70        Ok(PredictWsStream {
71            reader,
72            writer: self.writer.clone(),
73        })
74    }
75
76    /// Subscribe to orderbook updates for a market
77    pub async fn subscribe_orderbook(&self, market_id: u64) -> Result<()> {
78        let topic = format!("predictOrderbook/{}", market_id);
79        let request_id = self.next_id();
80
81        let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
82        self.send_request(&request).await?;
83
84        self.subscribed_markets.insert(market_id);
85        info!("Subscribed to orderbook for market {}", market_id);
86
87        Ok(())
88    }
89
90    /// Unsubscribe from orderbook updates for a market
91    pub async fn unsubscribe_orderbook(&self, market_id: u64) -> Result<()> {
92        let topic = format!("predictOrderbook/{}", market_id);
93        let request_id = self.next_id();
94
95        let request = WsRequest::unsubscribe(request_id, vec![topic]);
96        self.send_request(&request).await?;
97
98        self.subscribed_markets.remove(&market_id);
99        info!("Unsubscribed from orderbook for market {}", market_id);
100
101        Ok(())
102    }
103
104    /// Subscribe to asset price updates for a price feed
105    ///
106    /// The price_feed_id is typically a Pyth price feed ID (hex string).
107    /// Common feeds:
108    /// - BTC/USD: 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43
109    /// - ETH/USD: 0xff61491a931112ddf1bd8147cd1b641375f79f5825126d665480874634fd0ace
110    pub async fn subscribe_asset_price(&self, price_feed_id: &str) -> Result<()> {
111        let topic = format!("assetPriceUpdate/{}", price_feed_id);
112        let request_id = self.next_id();
113
114        let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
115        self.send_request(&request).await?;
116
117        info!("Subscribed to asset price for feed {}", price_feed_id);
118
119        Ok(())
120    }
121
122    /// Unsubscribe from asset price updates for a price feed
123    pub async fn unsubscribe_asset_price(&self, price_feed_id: &str) -> Result<()> {
124        let topic = format!("assetPriceUpdate/{}", price_feed_id);
125        let request_id = self.next_id();
126
127        let request = WsRequest::unsubscribe(request_id, vec![topic]);
128        self.send_request(&request).await?;
129
130        info!("Unsubscribed from asset price for feed {}", price_feed_id);
131
132        Ok(())
133    }
134
135    /// Subscribe to Polymarket chance updates for a market
136    ///
137    /// Receives chance/price data from Polymarket for cross-platform comparison.
138    pub async fn subscribe_polymarket_chance(&self, market_id: u64) -> Result<()> {
139        let topic = format!("polymarketChance/{}", market_id);
140        let request_id = self.next_id();
141
142        let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
143        self.send_request(&request).await?;
144
145        info!("Subscribed to Polymarket chance for market {}", market_id);
146
147        Ok(())
148    }
149
150    /// Subscribe to Kalshi chance updates for a market
151    ///
152    /// Receives chance/price data from Kalshi for cross-platform comparison.
153    pub async fn subscribe_kalshi_chance(&self, market_id: u64) -> Result<()> {
154        let topic = format!("kalshiChance/{}", market_id);
155        let request_id = self.next_id();
156
157        let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
158        self.send_request(&request).await?;
159
160        info!("Subscribed to Kalshi chance for market {}", market_id);
161
162        Ok(())
163    }
164
165    /// Subscribe to wallet events (order fills, cancellations, etc.)
166    ///
167    /// Requires a JWT token obtained from `PredictClient::authenticate()`.
168    /// Topic: `predictWalletEvents/{jwt}`
169    ///
170    /// Events include:
171    /// - orderAccepted: Order placed in orderbook
172    /// - orderTransactionSuccess: Order filled on-chain
173    /// - orderCancelled: Order cancelled
174    /// - orderTransactionFailed: On-chain transaction failed
175    pub async fn subscribe_wallet_events(&self, jwt: &str) -> Result<()> {
176        let topic = format!("predictWalletEvents/{}", jwt);
177        let request_id = self.next_id();
178
179        let request = WsRequest::subscribe(request_id, vec![topic]);
180        self.send_request(&request).await?;
181
182        info!("Subscribed to wallet events");
183
184        Ok(())
185    }
186
187    /// Unsubscribe from wallet events
188    pub async fn unsubscribe_wallet_events(&self, jwt: &str) -> Result<()> {
189        let topic = format!("predictWalletEvents/{}", jwt);
190        let request_id = self.next_id();
191
192        let request = WsRequest::unsubscribe(request_id, vec![topic]);
193        self.send_request(&request).await?;
194
195        info!("Unsubscribed from wallet events");
196
197        Ok(())
198    }
199
200    /// Send a heartbeat response with the given timestamp
201    pub async fn send_heartbeat(&self, timestamp: u64) -> Result<()> {
202        let request = WsRequest::heartbeat(timestamp);
203        self.send_request(&request).await?;
204        debug!("Sent heartbeat response: {}", timestamp);
205        Ok(())
206    }
207
208    /// Send a request to the WebSocket
209    async fn send_request(&self, request: &WsRequest) -> Result<()> {
210        let json = serde_json::to_string(request)
211            .map_err(|e| Error::Other(format!("Failed to serialize request: {}", e)))?;
212
213        let mut writer_guard = self.writer.lock().await;
214        let writer = writer_guard
215            .as_mut()
216            .ok_or_else(|| Error::Other("WebSocket not connected".to_string()))?;
217
218        writer
219            .send(Message::Text(json.into()))
220            .await
221            .map_err(|e| Error::Other(format!("Failed to send message: {}", e)))?;
222
223        Ok(())
224    }
225
226    /// Reconnect to the WebSocket server
227    ///
228    /// Clears the old writer and establishes a fresh connection.
229    /// Subscriptions must be re-sent after reconnecting.
230    pub async fn reconnect(&self) -> Result<PredictWsStream> {
231        {
232            let mut w = self.writer.lock().await;
233            *w = None;
234        }
235        self.connect().await
236    }
237
238    /// Get the connection config (for backoff settings)
239    pub fn config(&self) -> &WsConnectionConfig {
240        &self.config
241    }
242
243    /// Check if connected
244    pub async fn is_connected(&self) -> bool {
245        self.writer.lock().await.is_some()
246    }
247
248    /// Get list of subscribed market IDs
249    pub fn subscribed_markets(&self) -> Vec<u64> {
250        self.subscribed_markets.iter().map(|r| *r).collect()
251    }
252
253    /// Get a clone of the writer Arc for external heartbeat handling
254    pub fn writer(&self) -> Arc<Mutex<Option<WsWriter>>> {
255        self.writer.clone()
256    }
257}
258
259/// WebSocket message stream that yields parsed messages
260pub struct PredictWsStream {
261    reader: WsReader,
262    writer: Arc<Mutex<Option<WsWriter>>>,
263}
264
265impl PredictWsStream {
266    /// Get the next message from the stream
267    ///
268    /// Returns None if the connection is closed.
269    /// Automatically responds to heartbeat messages.
270    pub async fn next(&mut self) -> Option<Result<WsMessage>> {
271        loop {
272            match self.reader.next().await {
273                Some(Ok(Message::Text(text))) => {
274                    match self.parse_message(&text).await {
275                        Ok(Some(msg)) => return Some(Ok(msg)),
276                        Ok(None) => continue, // Heartbeat was handled, get next message
277                        Err(e) => return Some(Err(e)),
278                    }
279                }
280                Some(Ok(Message::Ping(data))) => {
281                    // Respond to ping with pong
282                    if let Err(e) = self.send_pong(data.to_vec()).await {
283                        warn!("Failed to send pong: {}", e);
284                    }
285                    continue;
286                }
287                Some(Ok(Message::Pong(_))) => {
288                    // Ignore pong messages
289                    continue;
290                }
291                Some(Ok(Message::Close(frame))) => {
292                    info!("WebSocket closed: {:?}", frame);
293                    return None;
294                }
295                Some(Ok(Message::Binary(_))) => {
296                    warn!("Received unexpected binary message");
297                    continue;
298                }
299                Some(Ok(Message::Frame(_))) => {
300                    // Raw frame, ignore
301                    continue;
302                }
303                Some(Err(e)) => {
304                    error!("WebSocket error: {}", e);
305                    return Some(Err(Error::Other(format!("WebSocket error: {}", e))));
306                }
307                None => {
308                    info!("WebSocket stream ended");
309                    return None;
310                }
311            }
312        }
313    }
314
315    /// Parse a text message and handle heartbeats automatically
316    async fn parse_message(&mut self, text: &str) -> Result<Option<WsMessage>> {
317        let raw: RawWsMessage = serde_json::from_str(text)
318            .map_err(|e| Error::Other(format!("Failed to parse message: {} - {}", e, text)))?;
319
320        let msg = WsMessage::try_from(raw)
321            .map_err(|e| Error::Other(format!("Failed to convert message: {}", e)))?;
322
323        // Handle heartbeat automatically
324        if let WsMessage::PushMessage(ref push) = msg {
325            if let Some(timestamp) = push.heartbeat_timestamp() {
326                self.send_heartbeat(timestamp).await?;
327                return Ok(None); // Don't yield heartbeat messages
328            }
329        }
330
331        Ok(Some(msg))
332    }
333
334    /// Send a heartbeat response
335    async fn send_heartbeat(&mut self, timestamp: u64) -> Result<()> {
336        let request = WsRequest::heartbeat(timestamp);
337        let json = serde_json::to_string(&request)
338            .map_err(|e| Error::Other(format!("Failed to serialize heartbeat: {}", e)))?;
339
340        let mut writer_guard = self.writer.lock().await;
341        if let Some(writer) = writer_guard.as_mut() {
342            writer
343                .send(Message::Text(json.into()))
344                .await
345                .map_err(|e| Error::Other(format!("Failed to send heartbeat: {}", e)))?;
346            debug!("Sent heartbeat response: {}", timestamp);
347        }
348
349        Ok(())
350    }
351
352    /// Send a pong response
353    async fn send_pong(&mut self, data: Vec<u8>) -> Result<()> {
354        let mut writer_guard = self.writer.lock().await;
355        if let Some(writer) = writer_guard.as_mut() {
356            writer
357                .send(Message::Pong(data.into()))
358                .await
359                .map_err(|e| Error::Other(format!("Failed to send pong: {}", e)))?;
360        }
361        Ok(())
362    }
363}
364
365/// Parse orderbook data from a push message
366pub fn parse_orderbook_update(push: &PushMessage) -> Result<OrderbookData> {
367    if !push.is_orderbook() {
368        return Err(Error::Other("Not an orderbook message".to_string()));
369    }
370
371    serde_json::from_value(push.data.clone())
372        .map_err(|e| Error::Other(format!("Failed to parse orderbook data: {}", e)))
373}
374
375/// Parse asset price data from a push message
376pub fn parse_asset_price_update(push: &PushMessage) -> Result<AssetPriceData> {
377    if !push.is_asset_price() {
378        return Err(Error::Other("Not an asset price message".to_string()));
379    }
380
381    serde_json::from_value(push.data.clone())
382        .map_err(|e| Error::Other(format!("Failed to parse asset price data: {}", e)))
383}
384
385/// Parse wallet event data from a push message
386///
387/// Wallet events are received on the `predictWalletEvents/{jwt}` topic.
388/// The data payload contains `type` (event type string) and event-specific fields.
389pub fn parse_wallet_event(push: &PushMessage) -> Result<PredictWalletEvent> {
390    if !push.is_wallet_event() {
391        return Err(Error::Other("Not a wallet event message".to_string()));
392    }
393
394    let event_type = push
395        .data
396        .get("type")
397        .and_then(|v| v.as_str())
398        .unwrap_or("")
399        .to_string();
400
401    let order_hash = push
402        .data
403        .get("orderHash")
404        .and_then(|v| v.as_str())
405        .unwrap_or("")
406        .to_string();
407
408    // orderId can be string or number in the WS payload.
409    // Predict WS sends BigInt format with trailing "n" (e.g., "4175379n") — strip it.
410    let order_id = push
411        .data
412        .get("orderId")
413        .map(|v| match v {
414            serde_json::Value::String(s) => s.strip_suffix('n').unwrap_or(s).to_string(),
415            serde_json::Value::Number(n) => n.to_string(),
416            _ => String::new(),
417        })
418        .unwrap_or_default();
419
420    let tx_hash = push
421        .data
422        .get("txHash")
423        .and_then(|v| v.as_str())
424        .map(|s| s.to_string());
425
426    let reason = push
427        .data
428        .get("reason")
429        .and_then(|v| v.as_str())
430        .map(|s| s.to_string());
431
432    // Parse details object (present on transaction events)
433    let details = push.data.get("details").map(|d| {
434        use crate::WalletEventDetails;
435        WalletEventDetails {
436            price: d.get("price").and_then(|v| v.as_str()).map(|s| s.to_string()),
437            quantity: d.get("quantity").and_then(|v| v.as_str()).map(|s| s.to_string()),
438            quantity_filled: d.get("quantityFilled").and_then(|v| v.as_str()).map(|s| s.to_string()),
439            outcome: d.get("outcome").and_then(|v| v.as_str()).map(|s| s.to_string()),
440            quote_type: d.get("quoteType").and_then(|v| v.as_str()).map(|s| s.to_string()),
441        }
442    }).unwrap_or_default();
443
444    match event_type.as_str() {
445        "orderAccepted" => Ok(PredictWalletEvent::OrderAccepted { order_hash, order_id }),
446        "orderNotAccepted" => Ok(PredictWalletEvent::OrderNotAccepted {
447            order_hash,
448            order_id,
449            reason,
450        }),
451        "orderExpired" => Ok(PredictWalletEvent::OrderExpired { order_hash, order_id }),
452        "orderCancelled" => Ok(PredictWalletEvent::OrderCancelled { order_hash, order_id }),
453        "orderTransactionSubmitted" => Ok(PredictWalletEvent::OrderTransactionSubmitted {
454            order_hash,
455            order_id,
456            tx_hash,
457            details,
458        }),
459        "orderTransactionSuccess" => Ok(PredictWalletEvent::OrderTransactionSuccess {
460            order_hash,
461            order_id,
462            tx_hash,
463            details,
464        }),
465        "orderTransactionFailed" => Ok(PredictWalletEvent::OrderTransactionFailed {
466            order_hash,
467            order_id,
468            tx_hash,
469            details,
470        }),
471        _ => Ok(PredictWalletEvent::Unknown {
472            event_type,
473            data: push.data.clone(),
474        }),
475    }
476}
477
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Endpoint used by the construction tests; no connection is ever opened.
    const WS_URL: &str = "wss://ws.predict.fun/ws";

    /// Wrap `data` in a push message on a wallet-events topic.
    fn wallet_push(data: serde_json::Value) -> PushMessage {
        PushMessage {
            topic: "predictWalletEvents/jwt123".to_string(),
            data,
        }
    }

    /// Parse `data` as a wallet event, panicking if the parser rejects it.
    fn wallet_event(data: serde_json::Value) -> PredictWalletEvent {
        parse_wallet_event(&wallet_push(data)).unwrap()
    }

    #[test]
    fn test_client_creation() {
        let client = PredictWebSocket::new(WS_URL.to_string());
        assert!(client.subscribed_markets().is_empty());
    }

    #[test]
    fn test_request_id_increment() {
        let client = PredictWebSocket::new(WS_URL.to_string());
        for expected in 1..=3u64 {
            assert_eq!(client.next_id(), expected);
        }
    }

    #[test]
    fn test_parse_order_accepted() {
        let event = wallet_event(json!({
            "type": "orderAccepted",
            "orderId": "4170746",
            "orderHash": "0xb5b5b676abcd"
        }));
        match event {
            PredictWalletEvent::OrderAccepted { order_hash, order_id } => {
                assert_eq!(order_hash.as_str(), "0xb5b5b676abcd");
                assert_eq!(order_id.as_str(), "4170746");
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_submitted() {
        // Numeric orderId must be rendered as its decimal string.
        let event = wallet_event(json!({
            "type": "orderTransactionSubmitted",
            "orderId": 4170746,
            "orderHash": "0xb5b5b676abcd",
            "txHash": "0xdeadbeef"
        }));
        match event {
            PredictWalletEvent::OrderTransactionSubmitted { order_hash, order_id, tx_hash, .. } => {
                assert_eq!(order_hash.as_str(), "0xb5b5b676abcd");
                assert_eq!(order_id.as_str(), "4170746");
                assert_eq!(tx_hash.as_deref(), Some("0xdeadbeef"));
            }
            other => panic!("Expected OrderTransactionSubmitted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_success() {
        let event = wallet_event(json!({
            "type": "orderTransactionSuccess",
            "orderId": "4170746",
            "txHash": "0xdeadbeef"
        }));
        match event {
            PredictWalletEvent::OrderTransactionSuccess { order_hash, order_id, tx_hash, .. } => {
                assert!(order_hash.is_empty()); // no orderHash in payload
                assert_eq!(order_id.as_str(), "4170746");
                assert_eq!(tx_hash.as_deref(), Some("0xdeadbeef"));
            }
            other => panic!("Expected OrderTransactionSuccess, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_not_accepted() {
        let event = wallet_event(json!({
            "type": "orderNotAccepted",
            "orderId": "123",
            "orderHash": "0xabc",
            "reason": "insufficient balance"
        }));
        match event {
            PredictWalletEvent::OrderNotAccepted { order_hash, order_id, reason } => {
                assert_eq!(order_hash.as_str(), "0xabc");
                assert_eq!(order_id.as_str(), "123");
                assert_eq!(reason.as_deref(), Some("insufficient balance"));
            }
            other => panic!("Expected OrderNotAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_unknown_event_type() {
        let event = wallet_event(json!({
            "type": "newEventType",
            "foo": "bar"
        }));
        match event {
            PredictWalletEvent::Unknown { event_type, .. } => {
                assert_eq!(event_type.as_str(), "newEventType");
            }
            other => panic!("Expected Unknown, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_missing_type_field() {
        // If data has no "type" field at all, should produce Unknown with empty event_type
        let event = wallet_event(json!({ "orderId": "123" }));
        match event {
            PredictWalletEvent::Unknown { event_type, .. } => {
                assert!(event_type.is_empty());
            }
            other => panic!("Expected Unknown, got {:?}", other),
        }
    }

    #[test]
    fn test_bigint_order_id_suffix_stripped() {
        // Predict WS sends orderId as BigInt string with trailing "n"
        let event = wallet_event(json!({
            "type": "orderAccepted",
            "orderId": "4175379n"
        }));
        match event {
            PredictWalletEvent::OrderAccepted { order_id, order_hash } => {
                assert_eq!(order_id.as_str(), "4175379"); // "n" stripped
                assert!(order_hash.is_empty()); // not in WS payload
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_details_from_production_payload() {
        // Real production payload structure from log1.log
        let event = wallet_event(json!({
            "type": "orderTransactionSuccess",
            "orderId": "4170746n",
            "timestamp": 1769952855099u64,
            "details": {
                "categorySlug": "btc-usd-up-down-2026-02-01-08-30-15-minutes",
                "marketQuestion": "BTC/USD Up or Down - February 1, 8:30-8:45AM ET",
                "outcome": "YES",
                "price": "0.290",
                "quantity": "5.000",
                "quantityFilled": "5.000",
                "quoteType": "ASK",
                "strategyType": "LIMIT",
                "value": "1.45",
                "valueFilled": "1.45"
            }
        }));
        match event {
            PredictWalletEvent::OrderTransactionSuccess { order_id, details, .. } => {
                assert_eq!(order_id.as_str(), "4170746");
                assert_eq!(details.price.as_deref(), Some("0.290"));
                assert_eq!(details.quantity.as_deref(), Some("5.000"));
                assert_eq!(details.quantity_filled.as_deref(), Some("5.000"));
                assert_eq!(details.outcome.as_deref(), Some("YES"));
                assert_eq!(details.quote_type.as_deref(), Some("ASK"));
            }
            other => panic!("Expected OrderTransactionSuccess, got {:?}", other),
        }
    }

    #[test]
    fn test_non_wallet_event_rejected() {
        // A wallet-looking payload on an orderbook topic must be refused.
        let push = PushMessage {
            topic: "predictOrderbook/123".to_string(),
            data: json!({"type": "orderAccepted"}),
        };
        assert!(parse_wallet_event(&push).is_err());
    }
}