1use std::collections::HashMap;
48use std::sync::atomic::{AtomicU64, Ordering};
49use std::sync::Arc;
50
51use anyhow::{anyhow, Context, Result};
52use futures_util::{SinkExt, StreamExt};
53use serde::{Deserialize, Serialize};
54use serde_json::Value;
55use tokio::sync::{mpsc, oneshot, Mutex};
56use tokio::time::{self, Duration};
57use tokio_tungstenite::tungstenite::Message as WsMessage;
58use tracing::{debug, error, info, warn};
59
/// Websocket endpoint of the production (mainnet) deployment.
pub const PREDICT_WS_MAINNET: &str = "wss://ws.predict.fun/ws";
/// GraphQL endpoint of the production (mainnet) deployment.
pub const PREDICT_GQL_MAINNET: &str = "https://graphql.predict.fun/graphql";

/// Websocket endpoint of the test deployment.
pub const PREDICT_WS_TESTNET: &str = "wss://ws.bnb.predict.fail/ws";
/// GraphQL endpoint of the test deployment.
pub const PREDICT_GQL_TESTNET: &str = "https://graphql.bnb.predict.fail/graphql";
67
/// A subscription topic of the Predict websocket API.
///
/// `Topic::to_topic_string` renders the wire form (e.g. `predictOrderbook/123`) and
/// `Topic::from_topic_string` parses it back. `Debug` is implemented by hand so the
/// wallet JWT is never printed.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Topic {
    /// `predictOrderbook/{market_id}`
    Orderbook { market_id: i64 },
    /// `assetPriceUpdate/{feed_id}`
    AssetPrice { feed_id: i64 },
    /// `polymarketChance/{market_id}`
    PolymarketChance { market_id: i64 },
    /// `kalshiChance/{market_id}`
    KalshiChance { market_id: i64 },
    /// `predictWalletEvents/{jwt}` — the JWT is a credential and must not be logged.
    WalletEvents { jwt: String },
    /// Any other topic string, sent verbatim.
    Raw(String),
}

// Hand-written instead of derived: `{:?}` output (logs, panic and assert messages)
// must never contain the wallet JWT. All other variants render exactly as
// `#[derive(Debug)]` would.
impl std::fmt::Debug for Topic {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Topic::Orderbook { market_id } => f
                .debug_struct("Orderbook")
                .field("market_id", market_id)
                .finish(),
            Topic::AssetPrice { feed_id } => f
                .debug_struct("AssetPrice")
                .field("feed_id", feed_id)
                .finish(),
            Topic::PolymarketChance { market_id } => f
                .debug_struct("PolymarketChance")
                .field("market_id", market_id)
                .finish(),
            Topic::KalshiChance { market_id } => f
                .debug_struct("KalshiChance")
                .field("market_id", market_id)
                .finish(),
            Topic::WalletEvents { .. } => f
                .debug_struct("WalletEvents")
                .field("jwt", &format_args!("<redacted>"))
                .finish(),
            Topic::Raw(topic) => f.debug_tuple("Raw").field(topic).finish(),
        }
    }
}
97
98impl Topic {
99 pub fn to_topic_string(&self) -> String {
101 match self {
102 Topic::Orderbook { market_id } => format!("predictOrderbook/{}", market_id),
103 Topic::AssetPrice { feed_id } => format!("assetPriceUpdate/{}", feed_id),
104 Topic::PolymarketChance { market_id } => format!("polymarketChance/{}", market_id),
105 Topic::KalshiChance { market_id } => format!("kalshiChance/{}", market_id),
106 Topic::WalletEvents { jwt } => format!("predictWalletEvents/{}", jwt),
107 Topic::Raw(s) => s.clone(),
108 }
109 }
110
111 pub fn from_topic_string(s: &str) -> Self {
113 if let Some(rest) = s.strip_prefix("predictOrderbook/") {
114 if let Ok(id) = rest.parse::<i64>() {
115 return Topic::Orderbook { market_id: id };
116 }
117 }
118 if let Some(rest) = s.strip_prefix("assetPriceUpdate/") {
119 if let Ok(id) = rest.parse::<i64>() {
120 return Topic::AssetPrice { feed_id: id };
121 }
122 }
123 if let Some(rest) = s.strip_prefix("polymarketChance/") {
124 if let Ok(id) = rest.parse::<i64>() {
125 return Topic::PolymarketChance { market_id: id };
126 }
127 }
128 if let Some(rest) = s.strip_prefix("kalshiChance/") {
129 if let Ok(id) = rest.parse::<i64>() {
130 return Topic::KalshiChance { market_id: id };
131 }
132 }
133 if let Some(rest) = s.strip_prefix("predictWalletEvents/") {
134 return Topic::WalletEvents {
135 jwt: rest.to_string(),
136 };
137 }
138 Topic::Raw(s.to_string())
139 }
140}
141
142impl std::fmt::Display for Topic {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 write!(f, "{}", self.to_topic_string())
145 }
146}
147
/// One orderbook price level as `(price, size)`, built from the `[price, size]`
/// pairs read by `parse_levels`.
pub type Level = (f64, f64);
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct LastOrderSettled {
156 pub id: String,
157 pub kind: String,
158 #[serde(rename = "marketId")]
159 pub market_id: i64,
160 pub outcome: String,
161 pub price: String,
162 pub side: String,
163}
164
/// Decoded payload of a `predictOrderbook/{market_id}` push message.
#[derive(Debug, Clone)]
pub struct OrderbookSnapshot {
    /// Market id taken from the topic suffix.
    pub market_id: i64,
    /// Bid levels in the order received; `best_bid` reads the first entry.
    pub bids: Vec<Level>,
    /// Ask levels in the order received; `best_ask` reads the first entry.
    pub asks: Vec<Level>,
    /// `version` field of the payload (0 when absent).
    pub version: u64,
    /// `updateTimestampMs` field of the payload (0 when absent).
    pub update_timestamp_ms: u64,
    /// `orderCount` field of the payload (0 when absent).
    pub order_count: u64,
    /// `lastOrderSettled` object, when present and well-formed.
    pub last_order_settled: Option<LastOrderSettled>,
}
176
177impl OrderbookSnapshot {
178 pub fn best_bid(&self) -> Option<f64> {
180 self.bids.first().map(|(p, _)| *p)
181 }
182
183 pub fn best_ask(&self) -> Option<f64> {
185 self.asks.first().map(|(p, _)| *p)
186 }
187
188 pub fn mid(&self) -> Option<f64> {
190 match (self.best_bid(), self.best_ask()) {
191 (Some(b), Some(a)) => Some((b + a) / 2.0),
192 _ => None,
193 }
194 }
195
196 pub fn spread(&self) -> Option<f64> {
198 match (self.best_bid(), self.best_ask()) {
199 (Some(b), Some(a)) => Some(a - b),
200 _ => None,
201 }
202 }
203}
204
/// Decoded payload of an `assetPriceUpdate/{feed_id}` push message.
#[derive(Debug, Clone)]
pub struct AssetPriceUpdate {
    /// Feed id taken from the topic suffix (see the `feeds` module for known ids).
    pub feed_id: i64,
    /// `price` field of the payload (required; the message is not decoded without it).
    pub price: f64,
    /// `publishTime` field (0 when absent). NOTE(review): the test sample looks like Unix seconds — confirm.
    pub publish_time: u64,
    /// `timestamp` field (0 when absent). NOTE(review): units not established here — confirm.
    pub timestamp: u64,
}
213
/// Decoded `polymarketChance/{id}` or `kalshiChance/{id}` push message.
///
/// Only the topic is interpreted; the payload is passed through as raw JSON.
#[derive(Debug, Clone)]
pub struct CrossVenueChance {
    /// Which topic family the message arrived on.
    pub source: CrossVenueSource,
    /// Market id taken from the topic suffix.
    pub market_id: i64,
    /// Unparsed payload.
    pub data: Value,
}

/// Topic family of a `CrossVenueChance` message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossVenueSource {
    /// Arrived on a `polymarketChance/…` topic.
    Polymarket,
    /// Arrived on a `kalshiChance/…` topic.
    Kalshi,
}
228
/// Decoded `predictWalletEvents/{jwt}` push message; the payload is kept as raw JSON.
#[derive(Debug, Clone)]
pub struct WalletEvent {
    /// Unparsed payload.
    pub data: Value,
}

/// A push message as delivered through the receiver returned by `PredictWsClient::connect`.
#[derive(Debug, Clone)]
pub enum PredictWsMessage {
    /// `predictOrderbook/…` payload that decoded successfully.
    Orderbook(OrderbookSnapshot),
    /// `assetPriceUpdate/…` payload that decoded successfully.
    AssetPrice(AssetPriceUpdate),
    /// `polymarketChance/…` or `kalshiChance/…` message with a numeric id.
    CrossVenueChance(CrossVenueChance),
    /// `predictWalletEvents/…` message.
    WalletEvent(WalletEvent),
    /// Any other topic, or a known topic whose id or payload failed to decode.
    Raw { topic: String, data: Value },
}
249
250#[derive(Serialize)]
253struct WsRequest {
254 #[serde(rename = "requestId")]
255 request_id: u64,
256 method: String,
257 #[serde(skip_serializing_if = "Option::is_none")]
258 params: Option<Vec<String>>,
259 #[serde(skip_serializing_if = "Option::is_none")]
260 data: Option<Value>,
261}
262
263#[derive(Deserialize)]
264struct WsRawMessage {
265 #[serde(rename = "type")]
266 msg_type: String,
267 #[serde(rename = "requestId")]
268 request_id: Option<i64>,
269 success: Option<bool>,
270 error: Option<WsError>,
271 topic: Option<String>,
272 data: Option<Value>,
273}
274
/// `error` object attached to an unsuccessful `"R"` reply.
#[derive(Deserialize, Debug, Clone)]
struct WsError {
    /// Error code reported by the server.
    code: String,
    /// Optional human-readable detail.
    message: Option<String>,
}
280
281impl std::fmt::Display for WsError {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 write!(f, "{}", self.code)?;
284 if let Some(msg) = &self.message {
285 write!(f, ": {}", msg)?;
286 }
287 Ok(())
288 }
289}
290
291type WsSink = futures_util::stream::SplitSink<
294 tokio_tungstenite::WebSocketStream<
295 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
296 >,
297 WsMessage,
298>;
299
300type PendingResponse = oneshot::Sender<Result<()>>;
302
/// Connection settings for `PredictWsClient`.
#[derive(Debug, Clone)]
pub struct PredictWsConfig {
    /// Websocket endpoint used for the initial connection and for every reconnect.
    pub url: String,
    /// Capacity of the channel that delivers decoded push messages; messages arriving
    /// while it is full are dropped with a warning.
    pub channel_buffer: usize,
    /// Heartbeat timeout in seconds, used by the reader task to decide the connection
    /// is dead (it then reconnects).
    pub heartbeat_timeout_secs: u64,
    /// Reconnection attempts before giving up; `0` means retry forever.
    pub max_reconnect_attempts: u32,
    /// Cap, in seconds, on the exponential reconnect backoff (1s, 2s, 4s, …).
    pub max_reconnect_backoff_secs: u64,
}
317
318impl Default for PredictWsConfig {
319 fn default() -> Self {
320 Self {
321 url: PREDICT_WS_MAINNET.to_string(),
322 channel_buffer: 1024,
323 heartbeat_timeout_secs: 60,
324 max_reconnect_attempts: 0,
325 max_reconnect_backoff_secs: 15,
326 }
327 }
328}
329
330impl PredictWsConfig {
331 pub fn mainnet() -> Self {
332 Self::default()
333 }
334
335 pub fn testnet() -> Self {
336 Self {
337 url: PREDICT_WS_TESTNET.to_string(),
338 ..Self::default()
339 }
340 }
341}
342
/// Handle to a Predict websocket connection.
///
/// All mutable state sits behind `Arc`s, so clones share one connection. The reader
/// task spawned by `connect` holds one clone and the caller gets another.
#[derive(Clone)]
pub struct PredictWsClient {
    /// Write half of the socket; swapped in place when the connection is re-established.
    sink: Arc<Mutex<WsSink>>,
    /// Counter that hands out `requestId`s for outgoing requests.
    request_id: Arc<AtomicU64>,
    /// Waiters for `"R"` replies, keyed by the `requestId` they were sent with.
    pending: Arc<Mutex<HashMap<u64, PendingResponse>>>,
    /// Wire-form topics acknowledged by `subscribe`; re-sent after a reconnect.
    active_topics: Arc<Mutex<Vec<String>>>,
    /// Settings the client was created with.
    config: PredictWsConfig,
}
355
356impl PredictWsClient {
357 pub async fn connect_mainnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
359 Self::connect(PredictWsConfig::mainnet()).await
360 }
361
362 pub async fn connect_testnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
364 Self::connect(PredictWsConfig::testnet()).await
365 }
366
367 pub async fn connect(
369 config: PredictWsConfig,
370 ) -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
371 let (ws_stream, _) = tokio_tungstenite::connect_async(&config.url)
372 .await
373 .with_context(|| format!("failed to connect to {}", config.url))?;
374
375 info!("Connected to {}", config.url);
376
377 let (sink, stream) = ws_stream.split();
378 let (tx, rx) = mpsc::channel(config.channel_buffer);
379
380 let client = Self {
381 sink: Arc::new(Mutex::new(sink)),
382 request_id: Arc::new(AtomicU64::new(0)),
383 pending: Arc::new(Mutex::new(HashMap::new())),
384 active_topics: Arc::new(Mutex::new(Vec::new())),
385 config: config.clone(),
386 };
387
388 let client_clone = client.clone();
390 tokio::spawn(async move {
391 client_clone.read_loop(stream, tx).await
392 });
393
394 Ok((client, rx))
395 }
396
397 pub async fn subscribe(&self, topic: Topic) -> Result<()> {
399 let topic_str = topic.to_topic_string();
400 let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
401
402 let (resp_tx, resp_rx) = oneshot::channel();
403 {
404 let mut pending = self.pending.lock().await;
405 pending.insert(request_id, resp_tx);
406 }
407
408 let msg = WsRequest {
409 request_id,
410 method: "subscribe".to_string(),
411 params: Some(vec![topic_str.clone()]),
412 data: None,
413 };
414
415 self.send_raw(&msg).await?;
416 debug!("Subscribing to {} (requestId={})", topic_str, request_id);
417
418 let result = tokio::time::timeout(Duration::from_secs(10), resp_rx)
420 .await
421 .map_err(|_| anyhow!("subscribe timeout for {}", topic_str))?
422 .map_err(|_| anyhow!("subscribe channel closed for {}", topic_str))??;
423
424 {
426 let mut topics = self.active_topics.lock().await;
427 if !topics.contains(&topic_str) {
428 topics.push(topic_str.clone());
429 }
430 }
431
432 info!("Subscribed to {}", topic_str);
433 Ok(result)
434 }
435
436 pub async fn unsubscribe(&self, topic: Topic) -> Result<()> {
438 let topic_str = topic.to_topic_string();
439 let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
440
441 let (resp_tx, resp_rx) = oneshot::channel();
442 {
443 let mut pending = self.pending.lock().await;
444 pending.insert(request_id, resp_tx);
445 }
446
447 let msg = WsRequest {
448 request_id,
449 method: "unsubscribe".to_string(),
450 params: Some(vec![topic_str.clone()]),
451 data: None,
452 };
453
454 self.send_raw(&msg).await?;
455
456 tokio::time::timeout(Duration::from_secs(10), resp_rx)
457 .await
458 .map_err(|_| anyhow!("unsubscribe timeout for {}", topic_str))?
459 .map_err(|_| anyhow!("unsubscribe channel closed for {}", topic_str))??;
460
461 {
463 let mut topics = self.active_topics.lock().await;
464 topics.retain(|t| t != &topic_str);
465 }
466
467 info!("Unsubscribed from {}", topic_str);
468 Ok(())
469 }
470
471 pub async fn active_topics(&self) -> Vec<String> {
473 self.active_topics.lock().await.clone()
474 }
475
476 async fn send_heartbeat(&self, data: &Value) -> Result<()> {
478 let msg = WsRequest {
479 request_id: self.request_id.fetch_add(1, Ordering::Relaxed),
480 method: "heartbeat".to_string(),
481 params: None,
482 data: Some(data.clone()),
483 };
484 self.send_raw(&msg).await
485 }
486
487 async fn send_raw(&self, msg: &WsRequest) -> Result<()> {
488 let text = serde_json::to_string(msg).context("failed to serialize WS message")?;
489 let mut sink = self.sink.lock().await;
490 sink.send(WsMessage::Text(text))
491 .await
492 .context("failed to send WS message")?;
493 Ok(())
494 }
495
496 fn read_loop(
498 &self,
499 mut stream: futures_util::stream::SplitStream<
500 tokio_tungstenite::WebSocketStream<
501 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
502 >,
503 >,
504 tx: mpsc::Sender<PredictWsMessage>,
505 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
506 Box::pin(async move {
507 let heartbeat_timeout = Duration::from_secs(self.config.heartbeat_timeout_secs);
508 let mut last_heartbeat = time::Instant::now();
509
510 loop {
511 tokio::select! {
512 msg = stream.next() => {
513 match msg {
514 Some(Ok(WsMessage::Text(text))) => {
515 match serde_json::from_str::<WsRawMessage>(&text) {
516 Ok(raw) => self.handle_message(raw, &tx, &mut last_heartbeat).await,
517 Err(e) => warn!("Failed to parse WS message: {} — raw: {}", e, &text[..text.len().min(200)]),
518 }
519 }
520 Some(Ok(WsMessage::Ping(data))) => {
521 let mut sink = self.sink.lock().await;
522 let _ = sink.send(WsMessage::Pong(data)).await;
523 }
524 Some(Ok(WsMessage::Close(frame))) => {
525 info!("WebSocket closed by server: {:?}", frame);
526 break;
527 }
528 Some(Err(e)) => {
529 error!("WebSocket error: {}", e);
530 break;
531 }
532 None => {
533 info!("WebSocket stream ended");
534 break;
535 }
536 _ => {} }
538 }
539 _ = time::sleep(heartbeat_timeout) => {
540 if last_heartbeat.elapsed() > heartbeat_timeout {
541 warn!("Heartbeat timeout ({}s), closing connection",
542 self.config.heartbeat_timeout_secs);
543 break;
544 }
545 }
546 }
547 }
548
549 self.try_reconnect(tx).await;
551 }) }
553
554 async fn handle_message(
555 &self,
556 raw: WsRawMessage,
557 tx: &mpsc::Sender<PredictWsMessage>,
558 last_heartbeat: &mut time::Instant,
559 ) {
560 match raw.msg_type.as_str() {
561 "R" => {
563 if let Some(req_id) = raw.request_id {
564 let mut pending = self.pending.lock().await;
565 if let Some(resp_tx) = pending.remove(&(req_id as u64)) {
566 let result = if raw.success.unwrap_or(false) {
567 Ok(())
568 } else {
569 let err_msg = raw
570 .error
571 .map(|e| e.to_string())
572 .unwrap_or_else(|| "unknown error".to_string());
573 Err(anyhow!("subscribe failed: {}", err_msg))
574 };
575 let _ = resp_tx.send(result);
576 }
577 }
578 }
579 "M" => {
581 let topic_str = match &raw.topic {
582 Some(t) => t.as_str(),
583 None => return,
584 };
585
586 if topic_str == "heartbeat" {
588 *last_heartbeat = time::Instant::now();
589 if let Some(data) = &raw.data {
590 if let Err(e) = self.send_heartbeat(data).await {
591 warn!("Failed to send heartbeat response: {}", e);
592 }
593 }
594 return;
595 }
596
597 let data = match raw.data {
599 Some(d) => d,
600 None => return,
601 };
602
603 let parsed = parse_push_message(topic_str, data);
604 if tx.try_send(parsed).is_err() {
605 warn!("Message channel full, dropping message for topic {}", topic_str);
606 }
607 }
608 other => {
609 debug!("Unknown WS message type: {}", other);
610 }
611 }
612 }
613
614 fn try_reconnect(
615 &self,
616 tx: mpsc::Sender<PredictWsMessage>,
617 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> {
618 Box::pin(async move {
619 let max_attempts = self.config.max_reconnect_attempts;
620 let max_backoff = self.config.max_reconnect_backoff_secs;
621 let mut attempt = 0u32;
622
623 loop {
624 if max_attempts > 0 && attempt >= max_attempts {
625 error!(
626 "Max reconnection attempts ({}) reached, giving up",
627 max_attempts
628 );
629 return;
630 }
631
632 let backoff = Duration::from_secs((2u64.pow(attempt.min(10))).min(max_backoff));
633 warn!(
634 "Reconnecting in {:?} (attempt {})",
635 backoff,
636 attempt + 1
637 );
638 time::sleep(backoff).await;
639 attempt += 1;
640
641 match tokio_tungstenite::connect_async(&self.config.url).await {
642 Ok((ws_stream, _)) => {
643 info!("Reconnected to {}", self.config.url);
644 let (new_sink, new_stream) = ws_stream.split();
645
646 {
648 let mut sink = self.sink.lock().await;
649 *sink = new_sink;
650 }
651
652 let topics = self.active_topics.lock().await.clone();
654 for topic_str in &topics {
655 let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
656 let msg = WsRequest {
657 request_id,
658 method: "subscribe".to_string(),
659 params: Some(vec![topic_str.clone()]),
660 data: None,
661 };
662 if let Err(e) = self.send_raw(&msg).await {
663 warn!("Failed to resubscribe to {}: {}", topic_str, e);
664 } else {
665 debug!("Resubscribed to {}", topic_str);
666 }
667 }
668
669 self.read_loop(new_stream, tx).await;
671 return;
672 }
673 Err(e) => {
674 error!("Reconnection failed: {}", e);
675 }
676 }
677 }
678 }) }
680}
681
682fn parse_push_message(topic: &str, data: Value) -> PredictWsMessage {
685 if topic.starts_with("predictOrderbook/") {
686 if let Some(ob) = parse_orderbook(topic, &data) {
687 return PredictWsMessage::Orderbook(ob);
688 }
689 }
690
691 if topic.starts_with("assetPriceUpdate/") {
692 if let Some(price) = parse_asset_price(topic, &data) {
693 return PredictWsMessage::AssetPrice(price);
694 }
695 }
696
697 if topic.starts_with("polymarketChance/") {
698 if let Ok(id) = topic.strip_prefix("polymarketChance/").unwrap_or("").parse::<i64>() {
699 return PredictWsMessage::CrossVenueChance(CrossVenueChance {
700 source: CrossVenueSource::Polymarket,
701 market_id: id,
702 data,
703 });
704 }
705 }
706
707 if topic.starts_with("kalshiChance/") {
708 if let Ok(id) = topic.strip_prefix("kalshiChance/").unwrap_or("").parse::<i64>() {
709 return PredictWsMessage::CrossVenueChance(CrossVenueChance {
710 source: CrossVenueSource::Kalshi,
711 market_id: id,
712 data,
713 });
714 }
715 }
716
717 if topic.starts_with("predictWalletEvents/") {
718 return PredictWsMessage::WalletEvent(WalletEvent { data });
719 }
720
721 PredictWsMessage::Raw {
722 topic: topic.to_string(),
723 data,
724 }
725}
726
727fn parse_levels(val: &Value) -> Vec<Level> {
728 val.as_array()
729 .map(|arr| {
730 arr.iter()
731 .filter_map(|lvl| {
732 let price = lvl.get(0).and_then(|v| v.as_f64())?;
733 let size = lvl.get(1).and_then(|v| v.as_f64())?;
734 Some((price, size))
735 })
736 .collect()
737 })
738 .unwrap_or_default()
739}
740
741fn parse_orderbook(topic: &str, data: &Value) -> Option<OrderbookSnapshot> {
742 let market_id = topic
743 .strip_prefix("predictOrderbook/")?
744 .parse::<i64>()
745 .ok()?;
746
747 let bids = parse_levels(data.get("bids")?);
748 let asks = parse_levels(data.get("asks")?);
749 let version = data.get("version").and_then(|v| v.as_u64()).unwrap_or(0);
750 let update_timestamp_ms = data
751 .get("updateTimestampMs")
752 .and_then(|v| v.as_u64())
753 .unwrap_or(0);
754 let order_count = data
755 .get("orderCount")
756 .and_then(|v| v.as_u64())
757 .unwrap_or(0);
758 let last_order_settled = data
759 .get("lastOrderSettled")
760 .and_then(|v| serde_json::from_value(v.clone()).ok());
761
762 Some(OrderbookSnapshot {
763 market_id,
764 bids,
765 asks,
766 version,
767 update_timestamp_ms,
768 order_count,
769 last_order_settled,
770 })
771}
772
773fn parse_asset_price(topic: &str, data: &Value) -> Option<AssetPriceUpdate> {
774 let feed_id = topic
775 .strip_prefix("assetPriceUpdate/")?
776 .parse::<i64>()
777 .ok()?;
778
779 let price = data.get("price").and_then(|v| v.as_f64())?;
780 let publish_time = data.get("publishTime").and_then(|v| v.as_u64()).unwrap_or(0);
781 let timestamp = data.get("timestamp").and_then(|v| v.as_u64()).unwrap_or(0);
782
783 Some(AssetPriceUpdate {
784 feed_id,
785 price,
786 publish_time,
787 timestamp,
788 })
789}
790
/// Known asset-price feed ids, for use as `feed_id` in `Topic::AssetPrice`.
pub mod feeds {
    /// Bitcoin.
    pub const BTC: i64 = 1;
    /// BNB.
    pub const BNB: i64 = 2;
    /// Ether.
    pub const ETH: i64 = 4;
}
802
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn topic_roundtrip() {
        let topics = vec![
            Topic::Orderbook { market_id: 123 },
            Topic::AssetPrice { feed_id: 1 },
            Topic::PolymarketChance { market_id: 456 },
            Topic::KalshiChance { market_id: 789 },
            Topic::WalletEvents {
                jwt: "abc123".to_string(),
            },
            Topic::Raw("custom/topic".to_string()),
        ];

        for topic in topics {
            let s = topic.to_topic_string();
            let parsed = Topic::from_topic_string(&s);
            assert_eq!(topic, parsed, "Roundtrip failed for {}", s);
        }
    }

    #[test]
    fn topic_with_malformed_id_falls_back_to_raw() {
        assert_eq!(
            Topic::from_topic_string("predictOrderbook/abc"),
            Topic::Raw("predictOrderbook/abc".to_string())
        );
        assert_eq!(
            Topic::from_topic_string("assetPriceUpdate/"),
            Topic::Raw("assetPriceUpdate/".to_string())
        );
    }

    #[test]
    fn topic_display() {
        assert_eq!(
            Topic::Orderbook { market_id: 42 }.to_string(),
            "predictOrderbook/42"
        );
        assert_eq!(
            Topic::AssetPrice { feed_id: 1 }.to_string(),
            "assetPriceUpdate/1"
        );
    }

    #[test]
    fn parse_orderbook_snapshot() {
        let data = serde_json::json!({
            "asks": [[0.72, 15.0], [0.83, 5.88]],
            "bids": [[0.57, 15.0], [0.38, 2.63]],
            "marketId": 45532,
            "version": 1,
            "updateTimestampMs": 1772898630219u64,
            "orderCount": 13,
            "lastOrderSettled": {
                "id": "20035648",
                "kind": "LIMIT",
                "marketId": 45532,
                "outcome": "No",
                "price": "0.60",
                "side": "Bid"
            }
        });

        let ob = parse_orderbook("predictOrderbook/45532", &data).unwrap();
        assert_eq!(ob.market_id, 45532);
        assert_eq!(ob.bids.len(), 2);
        assert_eq!(ob.asks.len(), 2);
        assert!((ob.bids[0].0 - 0.57).abs() < 1e-10);
        assert!((ob.asks[0].0 - 0.72).abs() < 1e-10);
        assert_eq!(ob.version, 1);
        assert_eq!(ob.order_count, 13);
        assert!(ob.last_order_settled.is_some());
        assert!((ob.best_bid().unwrap() - 0.57).abs() < 1e-10);
        assert!((ob.best_ask().unwrap() - 0.72).abs() < 1e-10);
        assert!((ob.mid().unwrap() - 0.645).abs() < 1e-10);
        assert!((ob.spread().unwrap() - 0.15).abs() < 1e-10);
    }

    #[test]
    fn parse_orderbook_rejects_missing_side_or_bad_id() {
        assert!(parse_orderbook("predictOrderbook/1", &serde_json::json!({"asks": []})).is_none());
        assert!(parse_orderbook(
            "predictOrderbook/x",
            &serde_json::json!({"asks": [], "bids": []})
        )
        .is_none());
    }

    #[test]
    fn parse_levels_skips_malformed_entries() {
        let levels = parse_levels(&serde_json::json!([[0.5, 1.0], ["0.4", 2.0], [0.3]]));
        assert_eq!(levels, vec![(0.5, 1.0)]);
        assert!(parse_levels(&serde_json::Value::Null).is_empty());
    }

    #[test]
    fn parse_asset_price_update() {
        let data = serde_json::json!({
            "price": 67853.57751504,
            "publishTime": 1772898632u64,
            "timestamp": 1772898633u64
        });

        let price = parse_asset_price("assetPriceUpdate/1", &data).unwrap();
        assert_eq!(price.feed_id, 1);
        assert!((price.price - 67853.577).abs() < 1.0);
        assert_eq!(price.publish_time, 1772898632);
        assert_eq!(price.timestamp, 1772898633);
    }

    #[test]
    fn parse_asset_price_requires_price() {
        assert!(
            parse_asset_price("assetPriceUpdate/1", &serde_json::json!({"timestamp": 1})).is_none()
        );
    }

    #[test]
    fn parse_push_message_dispatches_correctly() {
        let ob_data = serde_json::json!({"asks": [], "bids": [], "version": 1, "updateTimestampMs": 0, "orderCount": 0});
        assert!(matches!(
            parse_push_message("predictOrderbook/123", ob_data),
            PredictWsMessage::Orderbook(_)
        ));

        let price_data = serde_json::json!({"price": 100.0, "publishTime": 0, "timestamp": 0});
        assert!(matches!(
            parse_push_message("assetPriceUpdate/1", price_data),
            PredictWsMessage::AssetPrice(_)
        ));

        let chance_data = serde_json::json!({"chance": 0.5});
        assert!(matches!(
            parse_push_message("polymarketChance/456", chance_data),
            PredictWsMessage::CrossVenueChance(_)
        ));

        let kalshi_data = serde_json::json!({"chance": 0.3});
        assert!(matches!(
            parse_push_message("kalshiChance/789", kalshi_data),
            PredictWsMessage::CrossVenueChance(_)
        ));

        let wallet_data = serde_json::json!({"event": "fill"});
        assert!(matches!(
            parse_push_message("predictWalletEvents/jwt123", wallet_data),
            PredictWsMessage::WalletEvent(_)
        ));

        let unknown_data = serde_json::json!({"foo": "bar"});
        assert!(matches!(
            parse_push_message("unknown/topic", unknown_data),
            PredictWsMessage::Raw { .. }
        ));
    }

    #[test]
    fn undecodable_known_topic_falls_back_to_raw() {
        assert!(matches!(
            parse_push_message("predictOrderbook/1", serde_json::json!({"asks": []})),
            PredictWsMessage::Raw { .. }
        ));
        assert!(matches!(
            parse_push_message("kalshiChance/abc", serde_json::json!({})),
            PredictWsMessage::Raw { .. }
        ));
    }

    #[test]
    fn ws_request_wire_format() {
        let req = WsRequest {
            request_id: 7,
            method: "subscribe".to_string(),
            params: Some(vec!["predictOrderbook/1".to_string()]),
            data: None,
        };
        assert_eq!(
            serde_json::to_value(&req).unwrap(),
            serde_json::json!({"requestId": 7, "method": "subscribe", "params": ["predictOrderbook/1"]})
        );
    }

    #[test]
    fn ws_raw_message_parses_error_reply() {
        let raw: WsRawMessage = serde_json::from_str(
            r#"{"type":"R","requestId":3,"success":false,"error":{"code":"invalid_topic","message":"nope"}}"#,
        )
        .unwrap();
        assert_eq!(raw.msg_type, "R");
        assert_eq!(raw.request_id, Some(3));
        assert_eq!(raw.success, Some(false));
        assert_eq!(raw.error.unwrap().to_string(), "invalid_topic: nope");
        assert!(raw.topic.is_none());
        assert!(raw.data.is_none());
    }

    #[test]
    fn orderbook_snapshot_helpers_empty() {
        let ob = OrderbookSnapshot {
            market_id: 1,
            bids: vec![],
            asks: vec![],
            version: 0,
            update_timestamp_ms: 0,
            order_count: 0,
            last_order_settled: None,
        };
        assert!(ob.best_bid().is_none());
        assert!(ob.best_ask().is_none());
        assert!(ob.mid().is_none());
        assert!(ob.spread().is_none());
    }

    #[test]
    fn feed_id_constants() {
        assert_eq!(feeds::BTC, 1);
        assert_eq!(feeds::ETH, 4);
        assert_eq!(feeds::BNB, 2);
    }

    #[test]
    fn ws_endpoint_constants() {
        assert_eq!(PREDICT_WS_MAINNET, "wss://ws.predict.fun/ws");
        assert_eq!(PREDICT_WS_TESTNET, "wss://ws.bnb.predict.fail/ws");
    }

    #[test]
    fn config_defaults() {
        let config = PredictWsConfig::default();
        assert_eq!(config.url, PREDICT_WS_MAINNET);
        assert_eq!(config.channel_buffer, 1024);
        assert_eq!(config.heartbeat_timeout_secs, 60);
        assert_eq!(config.max_reconnect_attempts, 0);
        assert_eq!(config.max_reconnect_backoff_secs, 15);
    }

    #[test]
    fn config_testnet() {
        let config = PredictWsConfig::testnet();
        assert_eq!(config.url, PREDICT_WS_TESTNET);
    }
}