use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time::{self, Duration};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tracing::{debug, error, info, warn};
/// WebSocket endpoint URL for mainnet; this is the `url` used by `PredictWsConfig::default()`.
pub const PREDICT_WS_MAINNET: &str = "wss://ws.predict.fun/ws";
/// WebSocket endpoint URL for testnet; this is the `url` used by `PredictWsConfig::testnet()`.
pub const PREDICT_WS_TESTNET: &str = "wss://ws.bnb.predict.fail/ws";
/// GraphQL endpoint URL for mainnet. Exported only; nothing in this file's visible code reads it.
pub const PREDICT_GQL_MAINNET: &str = "https://graphql.predict.fun/graphql";
/// GraphQL endpoint URL for testnet. Exported only; nothing in this file's visible code reads it.
pub const PREDICT_GQL_TESTNET: &str = "https://graphql.bnb.predict.fail/graphql";
/// A subscription topic, convertible to and from its wire string.
///
/// Every variant except [`Topic::Raw`] is rendered as `<prefix>/<suffix>`;
/// `Raw` carries an arbitrary topic string verbatim.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Topic {
    /// Rendered as `predictOrderbook/<market_id>`.
    Orderbook { market_id: i64 },
    /// Rendered as `assetPriceUpdate/<feed_id>`.
    AssetPrice { feed_id: i64 },
    /// Rendered as `polymarketChance/<market_id>`.
    PolymarketChance { market_id: i64 },
    /// Rendered as `kalshiChance/<market_id>`.
    KalshiChance { market_id: i64 },
    /// Rendered as `predictWalletEvents/<jwt>`; the token is part of the topic string.
    WalletEvents { jwt: String },
    /// Any topic string that does not match one of the known shapes.
    Raw(String),
}

impl Topic {
    /// Renders the topic as the string sent in subscribe/unsubscribe requests.
    pub fn to_topic_string(&self) -> String {
        let (prefix, suffix) = match self {
            // Raw topics are passed through untouched.
            Topic::Raw(raw) => return raw.clone(),
            Topic::Orderbook { market_id } => ("predictOrderbook", market_id.to_string()),
            Topic::AssetPrice { feed_id } => ("assetPriceUpdate", feed_id.to_string()),
            Topic::PolymarketChance { market_id } => ("polymarketChance", market_id.to_string()),
            Topic::KalshiChance { market_id } => ("kalshiChance", market_id.to_string()),
            Topic::WalletEvents { jwt } => ("predictWalletEvents", jwt.clone()),
        };
        format!("{}/{}", prefix, suffix)
    }

    /// Parses a wire string back into a [`Topic`].
    ///
    /// Never fails: a string with an unknown prefix, or a known numeric prefix
    /// whose suffix is not a valid `i64`, is returned as [`Topic::Raw`].
    pub fn from_topic_string(s: &str) -> Self {
        // Every known prefix is free of '/', so splitting on the first '/'
        // isolates it; the remainder (which may itself contain '/') is the id/JWT.
        if let Some((prefix, suffix)) = s.split_once('/') {
            let numeric = suffix.parse::<i64>();
            match (prefix, numeric) {
                ("predictOrderbook", Ok(market_id)) => return Topic::Orderbook { market_id },
                ("assetPriceUpdate", Ok(feed_id)) => return Topic::AssetPrice { feed_id },
                ("polymarketChance", Ok(market_id)) => {
                    return Topic::PolymarketChance { market_id }
                }
                ("kalshiChance", Ok(market_id)) => return Topic::KalshiChance { market_id },
                ("predictWalletEvents", _) => {
                    return Topic::WalletEvents {
                        jwt: suffix.to_string(),
                    }
                }
                _ => {}
            }
        }
        Topic::Raw(s.to_string())
    }
}

impl std::fmt::Display for Topic {
    /// Writes the same text as [`Topic::to_topic_string`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.to_topic_string();
        f.write_str(&rendered)
    }
}
/// One orderbook level as a pair of floats; `parse_levels` fills it as `(price, size)`.
pub type Level = (f64, f64);
/// The `lastOrderSettled` object of an orderbook payload, as deserialized by `parse_orderbook`.
///
/// All fields are required when deserializing; a payload missing any of them
/// fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LastOrderSettled {
/// Order identifier, kept as the string the server sent.
pub id: String,
/// Order kind string (the tests in this file use `"LIMIT"`).
pub kind: String,
/// Market identifier; read from the JSON key `marketId`.
#[serde(rename = "marketId")]
pub market_id: i64,
/// Outcome label string (the tests in this file use `"No"`).
pub outcome: String,
/// Price kept as a string, not parsed into a number here.
pub price: String,
/// Side string (the tests in this file use `"Bid"`).
pub side: String,
}
/// A parsed orderbook push message for one market.
#[derive(Debug, Clone)]
pub struct OrderbookSnapshot {
    /// Market id taken from the topic string.
    pub market_id: i64,
    /// Bid levels in the order received.
    pub bids: Vec<Level>,
    /// Ask levels in the order received.
    pub asks: Vec<Level>,
    /// Value of the payload's `version` field.
    pub version: u64,
    /// Value of the payload's `updateTimestampMs` field.
    pub update_timestamp_ms: u64,
    /// Value of the payload's `orderCount` field.
    pub order_count: u64,
    /// The payload's `lastOrderSettled` object, when present and well-formed.
    pub last_order_settled: Option<LastOrderSettled>,
}

impl OrderbookSnapshot {
    /// Price of the first bid level, or `None` when there are no bids.
    ///
    /// NOTE(review): treats the first entry as the best bid, i.e. assumes the
    /// feed delivers levels best-first — confirm against the API.
    pub fn best_bid(&self) -> Option<f64> {
        let &(price, _) = self.bids.first()?;
        Some(price)
    }

    /// Price of the first ask level, or `None` when there are no asks.
    ///
    /// NOTE(review): same best-first ordering assumption as [`Self::best_bid`].
    pub fn best_ask(&self) -> Option<f64> {
        let &(price, _) = self.asks.first()?;
        Some(price)
    }

    /// Arithmetic mean of best bid and best ask; `None` if either side is empty.
    pub fn mid(&self) -> Option<f64> {
        let bid = self.best_bid()?;
        let ask = self.best_ask()?;
        Some((bid + ask) / 2.0)
    }

    /// Best ask minus best bid; `None` if either side is empty.
    pub fn spread(&self) -> Option<f64> {
        self.best_bid()
            .zip(self.best_ask())
            .map(|(bid, ask)| ask - bid)
    }
}
/// A parsed `assetPriceUpdate/<feed_id>` push message (built by `parse_asset_price`).
#[derive(Debug, Clone)]
pub struct AssetPriceUpdate {
/// Feed id taken from the topic string.
pub feed_id: i64,
/// Value of the payload's `price` field.
pub price: f64,
/// Value of the payload's `publishTime` field, or 0 when absent/non-integer.
/// NOTE(review): unit is not established in this file (test data looks like epoch seconds) — confirm.
pub publish_time: u64,
/// Value of the payload's `timestamp` field, or 0 when absent/non-integer.
/// NOTE(review): unit is not established in this file — confirm.
pub timestamp: u64,
}
/// A push message from a `polymarketChance/<id>` or `kalshiChance/<id>` topic.
#[derive(Debug, Clone)]
pub struct CrossVenueChance {
/// Which of the two topic families the message came from.
pub source: CrossVenueSource,
/// Numeric id parsed from the topic string.
pub market_id: i64,
/// The message payload, passed through unparsed.
pub data: Value,
}
/// Origin tag for a [`CrossVenueChance`] message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CrossVenueSource {
/// Message arrived on a `polymarketChance/...` topic.
Polymarket,
/// Message arrived on a `kalshiChance/...` topic.
Kalshi,
}
/// A push message from a `predictWalletEvents/...` topic.
#[derive(Debug, Clone)]
pub struct WalletEvent {
/// The message payload, passed through unparsed.
pub data: Value,
}
/// A push message delivered to the receiver returned by `PredictWsClient::connect`.
///
/// Produced by `parse_push_message`, which picks the variant from the topic prefix.
#[derive(Debug, Clone)]
pub enum PredictWsMessage {
/// Parsed `predictOrderbook/<id>` payload.
Orderbook(OrderbookSnapshot),
/// Parsed `assetPriceUpdate/<id>` payload.
AssetPrice(AssetPriceUpdate),
/// Payload from a `polymarketChance/<id>` or `kalshiChance/<id>` topic.
CrossVenueChance(CrossVenueChance),
/// Payload from a `predictWalletEvents/...` topic.
WalletEvent(WalletEvent),
/// Fallback for unrecognised topics and for known topics whose payload failed to parse.
Raw { topic: String, data: Value },
}
/// Outgoing request frame, serialized to JSON by `send_raw`.
#[derive(Serialize)]
struct WsRequest {
/// Client-chosen id, serialized as `requestId`; responses are matched on it.
#[serde(rename = "requestId")]
request_id: u64,
/// Method name; this file sends `"subscribe"`, `"unsubscribe"` and `"heartbeat"`.
method: String,
/// Topic list for subscribe/unsubscribe; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<Vec<String>>,
/// Heartbeat payload echoed back to the server; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
data: Option<Value>,
}
/// Incoming frame as deserialized from JSON before dispatch in `handle_message`.
#[derive(Deserialize)]
struct WsRawMessage {
/// JSON `type` field; `handle_message` treats `"R"` as a request response and `"M"` as a push message.
#[serde(rename = "type")]
msg_type: String,
/// JSON `requestId`; used to find the waiting request for `"R"` frames.
#[serde(rename = "requestId")]
request_id: Option<i64>,
/// Outcome flag for `"R"` frames; a missing value is treated as failure.
success: Option<bool>,
/// Error details for failed `"R"` frames.
error: Option<WsError>,
/// Topic of an `"M"` frame (`"heartbeat"` is handled specially).
topic: Option<String>,
/// Payload of an `"M"` frame.
data: Option<Value>,
}
/// Error object carried by a failed response frame.
#[derive(Deserialize, Debug)]
struct WsError {
/// Error code string (required).
code: String,
/// Optional human-readable detail.
message: Option<String>,
}
impl std::fmt::Display for WsError {
    /// Formats as `code` alone, or `code: message` when a message is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.message.as_deref() {
            Some(detail) => write!(f, "{}: {}", self.code, detail),
            None => f.write_str(&self.code),
        }
    }
}
/// Write half of the split WebSocket stream (frames are sent through it).
type WsSink = futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
WsMessage,
>;
/// Read half of the split WebSocket stream (consumed by `run_loop`).
type WsStream = futures_util::stream::SplitStream<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
>;
/// Channel on which `handle_message` reports the outcome of a request to its waiter.
type PendingResponse = oneshot::Sender<Result<()>>;
/// Connection settings for [`PredictWsClient`].
#[derive(Debug, Clone)]
pub struct PredictWsConfig {
    /// WebSocket URL to connect (and reconnect) to.
    pub url: String,
    /// Capacity of the channel that carries push messages to the consumer.
    pub channel_buffer: usize,
    /// Seconds the read loop waits before treating the connection as dead.
    pub heartbeat_timeout_secs: u64,
    /// Upper bound on consecutive reconnect attempts; `0` disables the limit.
    pub max_reconnect_attempts: u32,
    /// Cap, in seconds, on the exponential reconnect backoff.
    pub max_reconnect_backoff_secs: u64,
}

impl Default for PredictWsConfig {
    /// Mainnet URL, 1024-slot channel, 60 s heartbeat timeout, unlimited
    /// reconnect attempts, 15 s backoff cap.
    fn default() -> Self {
        PredictWsConfig {
            url: String::from(PREDICT_WS_MAINNET),
            channel_buffer: 1024,
            heartbeat_timeout_secs: 60,
            max_reconnect_attempts: 0,
            max_reconnect_backoff_secs: 15,
        }
    }
}

impl PredictWsConfig {
    /// Settings for the mainnet endpoint; identical to [`Default::default`].
    pub fn mainnet() -> Self {
        <Self as Default>::default()
    }

    /// Default settings with the URL swapped for the testnet endpoint.
    pub fn testnet() -> Self {
        let defaults = Self::default();
        Self {
            url: String::from(PREDICT_WS_TESTNET),
            ..defaults
        }
    }
}
/// Handle to a WebSocket connection; cloning shares the same underlying state.
///
/// A clone of this handle is moved into the background task spawned by `connect`.
#[derive(Clone)]
pub struct PredictWsClient {
/// Write half of the socket; replaced in place after a reconnect.
sink: Arc<Mutex<WsSink>>,
/// Counter used to allocate `requestId` values.
request_id: Arc<AtomicU64>,
/// Requests awaiting a response, keyed by request id.
pending: Arc<Mutex<HashMap<u64, PendingResponse>>>,
/// Topic strings currently subscribed; re-sent after a reconnect.
active_topics: Arc<Mutex<HashSet<String>>>,
/// Settings the client was connected with.
config: PredictWsConfig,
}
impl PredictWsClient {
/// Connects with [`PredictWsConfig::mainnet`]; see [`Self::connect`].
pub async fn connect_mainnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
    let config = PredictWsConfig::mainnet();
    Self::connect(config).await
}
/// Connects with [`PredictWsConfig::testnet`]; see [`Self::connect`].
pub async fn connect_testnet() -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
    let config = PredictWsConfig::testnet();
    Self::connect(config).await
}
/// Opens the WebSocket at `config.url` and spawns the background read loop.
///
/// Returns the client handle together with the receiver on which parsed
/// push messages are delivered.
///
/// # Errors
///
/// Returns an error if the WebSocket connection cannot be established.
pub async fn connect(
    config: PredictWsConfig,
) -> Result<(Self, mpsc::Receiver<PredictWsMessage>)> {
    let (ws_stream, _) = tokio_tungstenite::connect_async(&config.url)
        .await
        .with_context(|| format!("failed to connect to {}", config.url))?;
    info!("Connected to {}", config.url);
    let (sink, stream) = ws_stream.split();
    // `mpsc::channel` panics on a capacity of 0, so a zero `channel_buffer`
    // is promoted to the smallest valid capacity instead of aborting here.
    let (tx, rx) = mpsc::channel(config.channel_buffer.max(1));
    let client = Self {
        sink: Arc::new(Mutex::new(sink)),
        request_id: Arc::new(AtomicU64::new(0)),
        pending: Arc::new(Mutex::new(HashMap::new())),
        active_topics: Arc::new(Mutex::new(HashSet::new())),
        config,
    };
    // The read loop owns its own clone of the handle; the task is detached.
    let client_clone = client.clone();
    tokio::spawn(async move {
        client_clone.run_loop(stream, tx).await;
    });
    Ok((client, rx))
}
/// Subscribes to `topic` and, once the server confirms, records it so it is
/// re-sent after a reconnect.
///
/// NOTE(review): the full topic string is written to the log at info level;
/// for `Topic::WalletEvents` that string contains the JWT.
///
/// # Errors
///
/// Propagates any failure from sending the request or waiting for its response.
pub async fn subscribe(&self, topic: Topic) -> Result<()> {
    let wire_name = topic.to_topic_string();
    self.send_and_wait("subscribe", &wire_name).await?;
    {
        let mut topics = self.active_topics.lock().await;
        topics.insert(wire_name.clone());
    }
    info!("Subscribed to {}", wire_name);
    Ok(())
}
/// Unsubscribes from `topic` and, once the server confirms, forgets it so it
/// is no longer re-sent after a reconnect.
///
/// # Errors
///
/// Propagates any failure from sending the request or waiting for its response;
/// in that case the topic stays in the active set.
pub async fn unsubscribe(&self, topic: Topic) -> Result<()> {
    let wire_name = topic.to_topic_string();
    self.send_and_wait("unsubscribe", &wire_name).await?;
    {
        let mut topics = self.active_topics.lock().await;
        topics.remove(&wire_name);
    }
    info!("Unsubscribed from {}", wire_name);
    Ok(())
}
/// Returns a copy of the currently recorded topic strings, in no particular order.
pub async fn active_topics(&self) -> Vec<String> {
    let topics = self.active_topics.lock().await;
    let mut names = Vec::with_capacity(topics.len());
    names.extend(topics.iter().cloned());
    names
}
/// Sends a `method` request for a single `topic` and waits up to 10 seconds
/// for the matching response.
///
/// The waiter is registered in `pending` before the request is sent so the
/// response cannot arrive before anyone is listening. If the send fails or
/// the wait times out, the registration is removed again so abandoned
/// entries do not accumulate in the map.
///
/// # Errors
///
/// Returns an error if the frame cannot be sent, if no response arrives
/// within the timeout, if the waiter is dropped without an answer, or if
/// the server reports the request as failed.
async fn send_and_wait(&self, method: &str, topic: &str) -> Result<()> {
    let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
    let (resp_tx, resp_rx) = oneshot::channel();
    self.pending.lock().await.insert(request_id, resp_tx);
    let msg = WsRequest {
        request_id,
        method: method.to_string(),
        params: Some(vec![topic.to_string()]),
        data: None,
    };
    if let Err(e) = self.send_raw(&msg).await {
        // Nothing was sent, so no response can ever claim this entry.
        self.pending.lock().await.remove(&request_id);
        return Err(e);
    }
    debug!("{} {} (requestId={})", method, topic, request_id);
    match tokio::time::timeout(Duration::from_secs(10), resp_rx).await {
        Ok(Ok(outcome)) => outcome,
        // The sender was dropped, so the entry is already gone from the map.
        Ok(Err(_)) => Err(anyhow!("{} channel closed for {}", method, topic)),
        Err(_) => {
            // Stop tracking a request nobody is waiting for any more.
            self.pending.lock().await.remove(&request_id);
            Err(anyhow!("{} timeout for {}", method, topic))
        }
    }
}
/// Replies to a server heartbeat by sending a `heartbeat` request that
/// carries a copy of the received payload.
async fn send_heartbeat(&self, data: &Value) -> Result<()> {
    let request_id = self.request_id.fetch_add(1, Ordering::Relaxed);
    let echoed = Some(data.clone());
    let reply = WsRequest {
        request_id,
        method: String::from("heartbeat"),
        params: None,
        data: echoed,
    };
    self.send_raw(&reply).await
}
/// Serializes `msg` to JSON and sends it as a text frame.
///
/// # Errors
///
/// Returns an error if serialization fails or the frame cannot be sent.
async fn send_raw(&self, msg: &WsRequest) -> Result<()> {
    let payload = serde_json::to_string(msg).context("failed to serialize WS message")?;
    let mut sink = self.sink.lock().await;
    let sent = sink.send(WsMessage::Text(payload)).await;
    sent.context("failed to send WS message")
}
/// Background task: reads frames, dispatches them, and reconnects on failure.
///
/// The inner loop ends in one of two ways:
/// * the connection is considered lost (close frame, stream error/end, or no
///   heartbeat within `heartbeat_timeout_secs`), in which case a reconnect
///   with exponential backoff is attempted and all recorded topics are
///   re-subscribed;
/// * the consumer dropped the message receiver, in which case the socket is
///   closed and the task exits, since nobody can observe further messages.
///
/// The task also exits once `max_reconnect_attempts` consecutive attempts have
/// failed (a value of 0 means no limit).
async fn run_loop(&self, mut stream: WsStream, tx: mpsc::Sender<PredictWsMessage>) {
    let heartbeat_timeout = Duration::from_secs(self.config.heartbeat_timeout_secs);
    let mut attempt = 0u32;
    loop {
        let mut last_heartbeat = time::Instant::now();
        let disconnected = loop {
            // Measure the timeout from the last heartbeat rather than from the
            // last frame of any kind; otherwise steady push traffic would keep
            // restarting the timer and a stalled heartbeat would go unnoticed.
            let until_heartbeat_deadline =
                heartbeat_timeout.saturating_sub(last_heartbeat.elapsed());
            tokio::select! {
                msg = stream.next() => {
                    match msg {
                        Some(Ok(WsMessage::Text(text))) => {
                            match serde_json::from_str::<WsRawMessage>(&text) {
                                Ok(raw) => {
                                    self.handle_message(raw, &tx, &mut last_heartbeat).await;
                                }
                                Err(e) => debug!("Ignoring unparseable WS frame: {}", e),
                            }
                        }
                        Some(Ok(WsMessage::Ping(data))) => {
                            // Best effort: a failed pong surfaces as a stream error later.
                            let _ = self.sink.lock().await.send(WsMessage::Pong(data)).await;
                        }
                        Some(Ok(WsMessage::Close(frame))) => {
                            info!("WebSocket closed by server: {:?}", frame);
                            break true;
                        }
                        Some(Err(e)) => {
                            error!("WebSocket error: {}", e);
                            break true;
                        }
                        None => {
                            info!("WebSocket stream ended");
                            break true;
                        }
                        _ => {}
                    }
                }
                _ = time::sleep(until_heartbeat_deadline) => {
                    warn!("Heartbeat timeout ({}s)", self.config.heartbeat_timeout_secs);
                    break true;
                }
                _ = tx.closed() => {
                    info!("Message receiver dropped; stopping WebSocket task");
                    break false;
                }
            }
        };
        if !disconnected {
            // Deliberate shutdown: close the socket politely and stop.
            let _ = self.sink.lock().await.close().await;
            return;
        }
        // Responses to requests sent on the dead connection will never arrive;
        // dropping the senders wakes their waiters immediately with an error
        // instead of leaving them to run into the request timeout.
        self.pending.lock().await.clear();
        let max = self.config.max_reconnect_attempts;
        if max > 0 && attempt >= max {
            error!("Max reconnect attempts ({}) reached", max);
            return;
        }
        // 1s, 2s, 4s, ... capped by the configured maximum; the exponent is
        // clamped so the power cannot overflow.
        let backoff_secs =
            (2u64.pow(attempt.min(10))).min(self.config.max_reconnect_backoff_secs);
        warn!("Reconnecting in {}s (attempt {})", backoff_secs, attempt + 1);
        time::sleep(Duration::from_secs(backoff_secs)).await;
        attempt += 1;
        match tokio_tungstenite::connect_async(&self.config.url).await {
            Ok((ws_stream, _)) => {
                info!("Reconnected to {}", self.config.url);
                let (new_sink, new_stream) = ws_stream.split();
                *self.sink.lock().await = new_sink;
                stream = new_stream;
                attempt = 0;
                // Snapshot the topics first so the lock is not held while sending.
                let topics: Vec<String> =
                    self.active_topics.lock().await.iter().cloned().collect();
                for topic_str in &topics {
                    let req_id = self.request_id.fetch_add(1, Ordering::Relaxed);
                    let msg = WsRequest {
                        request_id: req_id,
                        method: "subscribe".to_string(),
                        params: Some(vec![topic_str.clone()]),
                        data: None,
                    };
                    // Fire and forget: no waiter is registered for these ids.
                    if let Err(e) = self.send_raw(&msg).await {
                        warn!("Failed to resubscribe to {}: {}", topic_str, e);
                    }
                }
            }
            Err(e) => {
                error!("Reconnection failed: {}", e);
            }
        }
    }
}
async fn handle_message(
&self,
raw: WsRawMessage,
tx: &mpsc::Sender<PredictWsMessage>,
last_heartbeat: &mut time::Instant,
) {
match raw.msg_type.as_str() {
"R" => {
if let Some(req_id) = raw.request_id {
if let Some(resp_tx) = self.pending.lock().await.remove(&(req_id as u64)) {
let result = if raw.success.unwrap_or(false) {
Ok(())
} else {
Err(anyhow!(
"subscribe failed: {}",
raw.error
.map(|e| e.to_string())
.unwrap_or_else(|| "unknown".into())
))
};
let _ = resp_tx.send(result);
}
}
}
"M" => {
let topic_str = match &raw.topic {
Some(t) => t.as_str(),
None => return,
};
if topic_str == "heartbeat" {
*last_heartbeat = time::Instant::now();
if let Some(data) = &raw.data {
if let Err(e) = self.send_heartbeat(data).await {
warn!("Heartbeat response failed: {}", e);
}
}
return;
}
if let Some(data) = raw.data {
let parsed = parse_push_message(topic_str, data);
if tx.try_send(parsed).is_err() {
warn!("Channel full, dropping message for {}", topic_str);
}
}
}
_ => debug!("Unknown WS message type: {}", raw.msg_type),
}
}
}
/// Converts a push message into a typed [`PredictWsMessage`] based on its topic.
///
/// The topic is split at its first `/` into a kind and a suffix. Orderbook and
/// asset-price topics need a numeric suffix and a payload that parses; the two
/// chance topics need only a numeric suffix; wallet-event topics accept any
/// suffix. Everything else — including known kinds that fail those checks —
/// is returned as [`PredictWsMessage::Raw`] with the original topic and payload.
fn parse_push_message(topic: &str, data: Value) -> PredictWsMessage {
    if let Some((kind, suffix)) = topic.split_once('/') {
        let numeric_id = suffix.parse::<i64>().ok();
        match (kind, numeric_id) {
            ("predictOrderbook", Some(market_id)) => {
                if let Some(book) = parse_orderbook(market_id, &data) {
                    return PredictWsMessage::Orderbook(book);
                }
            }
            ("assetPriceUpdate", Some(feed_id)) => {
                if let Some(update) = parse_asset_price(feed_id, &data) {
                    return PredictWsMessage::AssetPrice(update);
                }
            }
            ("polymarketChance", Some(market_id)) => {
                return PredictWsMessage::CrossVenueChance(CrossVenueChance {
                    source: CrossVenueSource::Polymarket,
                    market_id,
                    data,
                });
            }
            ("kalshiChance", Some(market_id)) => {
                return PredictWsMessage::CrossVenueChance(CrossVenueChance {
                    source: CrossVenueSource::Kalshi,
                    market_id,
                    data,
                });
            }
            ("predictWalletEvents", _) => {
                return PredictWsMessage::WalletEvent(WalletEvent { data });
            }
            _ => {}
        }
    }
    // Unknown topic, or a known one whose id/payload did not parse.
    PredictWsMessage::Raw {
        topic: topic.to_string(),
        data,
    }
}
fn parse_levels(val: &Value) -> Vec<Level> {
val.as_array()
.map(|arr| {
arr.iter()
.filter_map(|lvl| {
let price = lvl.get(0)?.as_f64()?;
let size = lvl.get(1)?.as_f64()?;
Some((price, size))
})
.collect()
})
.unwrap_or_default()
}
fn parse_orderbook(market_id: i64, data: &Value) -> Option<OrderbookSnapshot> {
Some(OrderbookSnapshot {
market_id,
bids: parse_levels(data.get("bids")?),
asks: parse_levels(data.get("asks")?),
version: data.get("version")?.as_u64().unwrap_or(0),
update_timestamp_ms: data
.get("updateTimestampMs")
.and_then(|v| v.as_u64())
.unwrap_or(0),
order_count: data
.get("orderCount")
.and_then(|v| v.as_u64())
.unwrap_or(0),
last_order_settled: data
.get("lastOrderSettled")
.and_then(|v| serde_json::from_value(v.clone()).ok()),
})
}
fn parse_asset_price(feed_id: i64, data: &Value) -> Option<AssetPriceUpdate> {
Some(AssetPriceUpdate {
feed_id,
price: data.get("price")?.as_f64()?,
publish_time: data.get("publishTime").and_then(|v| v.as_u64()).unwrap_or(0),
timestamp: data.get("timestamp").and_then(|v| v.as_u64()).unwrap_or(0),
})
}
/// Named feed identifiers.
///
/// NOTE(review): presumably the `feed_id` values for `assetPriceUpdate/<feed_id>`
/// topics; the id-to-asset mapping is not established anywhere in this file — confirm
/// against the API documentation.
pub mod feeds {
/// Feed id labelled BTC.
pub const BTC: i64 = 1;
/// Feed id labelled ETH.
pub const ETH: i64 = 4;
/// Feed id labelled BNB.
pub const BNB: i64 = 2;
}
// Unit tests for the pure parts of this file (topic strings, payload parsing,
// config presets). Nothing here opens a network connection.
#[cfg(test)]
mod tests {
use super::*;
// Every variant must survive to_topic_string -> from_topic_string unchanged.
#[test]
fn topic_roundtrip() {
let topics = vec![
Topic::Orderbook { market_id: 123 },
Topic::AssetPrice { feed_id: 1 },
Topic::PolymarketChance { market_id: 456 },
Topic::KalshiChance { market_id: 789 },
Topic::WalletEvents {
jwt: "abc123".to_string(),
},
Topic::Raw("custom/topic".to_string()),
];
for topic in topics {
let s = topic.to_topic_string();
assert_eq!(topic, Topic::from_topic_string(&s), "Roundtrip failed: {}", s);
}
}
// Display must produce the wire string.
#[test]
fn topic_display() {
assert_eq!(Topic::Orderbook { market_id: 42 }.to_string(), "predictOrderbook/42");
assert_eq!(Topic::AssetPrice { feed_id: 1 }.to_string(), "assetPriceUpdate/1");
}
// A complete orderbook payload parses and the derived helpers agree with it.
#[test]
fn parse_orderbook_snapshot() {
let data = serde_json::json!({
"asks": [[0.72, 15.0], [0.83, 5.88]],
"bids": [[0.57, 15.0], [0.38, 2.63]],
"version": 1,
"updateTimestampMs": 1772898630219u64,
"orderCount": 13,
"lastOrderSettled": {
"id": "20035648", "kind": "LIMIT", "marketId": 45532,
"outcome": "No", "price": "0.60", "side": "Bid"
}
});
let ob = parse_orderbook(45532, &data).unwrap();
assert_eq!(ob.market_id, 45532);
assert_eq!(ob.bids.len(), 2);
assert_eq!(ob.asks.len(), 2);
// Float results are compared with a tolerance rather than exactly.
assert!((ob.best_bid().unwrap() - 0.57).abs() < 1e-10);
assert!((ob.best_ask().unwrap() - 0.72).abs() < 1e-10);
assert!((ob.mid().unwrap() - 0.645).abs() < 1e-10);
assert!((ob.spread().unwrap() - 0.15).abs() < 1e-10);
assert_eq!(ob.version, 1);
assert_eq!(ob.order_count, 13);
assert!(ob.last_order_settled.is_some());
}
// An asset-price payload parses into the expected feed id and price.
#[test]
fn parse_asset_price_update() {
let data = serde_json::json!({
"price": 67853.57751504,
"publishTime": 1772898632u64,
"timestamp": 1772898633u64
});
let price = parse_asset_price(1, &data).unwrap();
assert_eq!(price.feed_id, 1);
assert!((price.price - 67853.577).abs() < 1.0);
}
// Each topic prefix must be routed to its message variant; unknown topics fall back to Raw.
#[test]
fn parse_push_dispatches() {
let ob = serde_json::json!({"asks": [], "bids": [], "version": 1, "updateTimestampMs": 0, "orderCount": 0});
assert!(matches!(parse_push_message("predictOrderbook/123", ob), PredictWsMessage::Orderbook(_)));
let p = serde_json::json!({"price": 100.0, "publishTime": 0, "timestamp": 0});
assert!(matches!(parse_push_message("assetPriceUpdate/1", p), PredictWsMessage::AssetPrice(_)));
let c = serde_json::json!({"chance": 0.5});
assert!(matches!(parse_push_message("polymarketChance/456", c), PredictWsMessage::CrossVenueChance(_)));
let k = serde_json::json!({"chance": 0.3});
assert!(matches!(parse_push_message("kalshiChance/789", k), PredictWsMessage::CrossVenueChance(_)));
let w = serde_json::json!({"event": "fill"});
assert!(matches!(parse_push_message("predictWalletEvents/jwt123", w), PredictWsMessage::WalletEvent(_)));
let u = serde_json::json!({"foo": "bar"});
assert!(matches!(parse_push_message("unknown/topic", u), PredictWsMessage::Raw { .. }));
}
// With no levels on either side, every helper must return None.
#[test]
fn orderbook_helpers_empty() {
let ob = OrderbookSnapshot {
market_id: 1, bids: vec![], asks: vec![], version: 0,
update_timestamp_ms: 0, order_count: 0, last_order_settled: None,
};
assert!(ob.best_bid().is_none());
assert!(ob.mid().is_none());
assert!(ob.spread().is_none());
}
// Pins the numeric values of the feed id constants.
#[test]
fn feed_id_constants() {
assert_eq!(feeds::BTC, 1);
assert_eq!(feeds::ETH, 4);
assert_eq!(feeds::BNB, 2);
}
// The default config targets mainnet with the documented buffer and timeout.
#[test]
fn config_defaults() {
let c = PredictWsConfig::default();
assert_eq!(c.url, PREDICT_WS_MAINNET);
assert_eq!(c.channel_buffer, 1024);
assert_eq!(c.heartbeat_timeout_secs, 60);
}
// The testnet preset swaps in the testnet URL.
#[test]
fn config_testnet() {
assert_eq!(PredictWsConfig::testnet().url, PREDICT_WS_TESTNET);
}
}