use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use dashmap::DashSet;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use ws_reconnect_client::{connect_with_retry, Message, WsConnectionConfig, WsReader, WsWriter};
use super::types::{AssetPriceData, OrderbookData, PushMessage, RawWsMessage, WsMessage, WsRequest};
use crate::api_types::PredictWalletEvent;
use crate::errors::{Error, Result};
/// Client for the Predict WebSocket API.
///
/// Holds the connection configuration and the write half of the socket behind a shared
/// `Arc<Mutex<Option<WsWriter>>>`. The read half is handed out as a [`PredictWsStream`] by
/// [`PredictWebSocket::connect`]; the stream receives a clone of the same writer handle so it
/// can answer pings and heartbeat pushes on its own.
pub struct PredictWebSocket {
    /// Connection settings (URL, retry and backoff policy) passed to `connect_with_retry`.
    config: WsConnectionConfig,
    /// Market ids for which a `predictOrderbook/{id}` subscribe request was sent successfully
    /// (and not since unsubscribed) through this client.
    subscribed_markets: DashSet<u64>,
    /// Write half of the current connection. `None` before the first successful `connect`
    /// and briefly during `reconnect`.
    writer: Arc<Mutex<Option<WsWriter>>>,
    /// Counter used to assign request ids; the first id handed out is 1.
    next_request_id: AtomicU64,
}

impl PredictWebSocket {
    /// Creates a client for `ws_url` without opening a connection.
    ///
    /// The connection config disables the client ping interval (`with_ping_interval(0)`),
    /// allows 10 retries, and uses a backoff between 1000 and 30 000 (presumably
    /// milliseconds — confirm against `WsConnectionConfig::with_backoff`).
    pub fn new(ws_url: String) -> Self {
        let config = WsConnectionConfig::new(ws_url)
            .with_ping_interval(0)
            .with_retries(10)
            .with_backoff(1000, 30_000);
        Self {
            config,
            subscribed_markets: DashSet::new(),
            writer: Arc::new(Mutex::new(None)),
            next_request_id: AtomicU64::new(1),
        }
    }

    /// Returns the next request id (1, 2, 3, ...).
    fn next_id(&self) -> u64 {
        self.next_request_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Opens the connection (retrying per `config`), installs the new writer, and returns
    /// the read side as a [`PredictWsStream`].
    ///
    /// # Errors
    /// Returns an error if `connect_with_retry` gives up.
    pub async fn connect(&self) -> Result<PredictWsStream> {
        info!("Connecting to Predict WebSocket: {}", self.config.url);
        let (writer, reader) = connect_with_retry(&self.config)
            .await
            .map_err(|e| Error::Other(format!("WebSocket connection failed: {}", e)))?;
        // Scope the guard so the lock is released before returning.
        {
            let mut w = self.writer.lock().await;
            *w = Some(writer);
        }
        info!("Connected to Predict WebSocket");
        Ok(PredictWsStream {
            reader,
            writer: self.writer.clone(),
        })
    }

    /// Subscribes to the orderbook topic `predictOrderbook/{market_id}`.
    ///
    /// The market is recorded in [`subscribed_markets`](Self::subscribed_markets) only after
    /// the request has been written successfully.
    pub async fn subscribe_orderbook(&self, market_id: u64) -> Result<()> {
        self.subscribe_topic(format!("predictOrderbook/{}", market_id))
            .await?;
        self.subscribed_markets.insert(market_id);
        info!("Subscribed to orderbook for market {}", market_id);
        Ok(())
    }

    /// Unsubscribes from `predictOrderbook/{market_id}` and stops tracking the market.
    pub async fn unsubscribe_orderbook(&self, market_id: u64) -> Result<()> {
        self.unsubscribe_topic(format!("predictOrderbook/{}", market_id))
            .await?;
        self.subscribed_markets.remove(&market_id);
        info!("Unsubscribed from orderbook for market {}", market_id);
        Ok(())
    }

    /// Subscribes to `assetPriceUpdate/{price_feed_id}`. Not tracked for resubscription.
    pub async fn subscribe_asset_price(&self, price_feed_id: &str) -> Result<()> {
        self.subscribe_topic(format!("assetPriceUpdate/{}", price_feed_id))
            .await?;
        info!("Subscribed to asset price for feed {}", price_feed_id);
        Ok(())
    }

    /// Unsubscribes from `assetPriceUpdate/{price_feed_id}`.
    pub async fn unsubscribe_asset_price(&self, price_feed_id: &str) -> Result<()> {
        self.unsubscribe_topic(format!("assetPriceUpdate/{}", price_feed_id))
            .await?;
        info!("Unsubscribed from asset price for feed {}", price_feed_id);
        Ok(())
    }

    /// Subscribes to `polymarketChance/{market_id}`. Not tracked for resubscription.
    pub async fn subscribe_polymarket_chance(&self, market_id: u64) -> Result<()> {
        self.subscribe_topic(format!("polymarketChance/{}", market_id))
            .await?;
        info!("Subscribed to Polymarket chance for market {}", market_id);
        Ok(())
    }

    /// Subscribes to `kalshiChance/{market_id}`. Not tracked for resubscription.
    pub async fn subscribe_kalshi_chance(&self, market_id: u64) -> Result<()> {
        self.subscribe_topic(format!("kalshiChance/{}", market_id))
            .await?;
        info!("Subscribed to Kalshi chance for market {}", market_id);
        Ok(())
    }

    /// Subscribes to wallet events for the account identified by `jwt`.
    ///
    /// The token is embedded in the topic (`predictWalletEvents/{jwt}`); it is intentionally
    /// kept out of the log line emitted here.
    pub async fn subscribe_wallet_events(&self, jwt: &str) -> Result<()> {
        self.subscribe_topic(format!("predictWalletEvents/{}", jwt))
            .await?;
        info!("Subscribed to wallet events");
        Ok(())
    }

    /// Unsubscribes from `predictWalletEvents/{jwt}`.
    pub async fn unsubscribe_wallet_events(&self, jwt: &str) -> Result<()> {
        self.unsubscribe_topic(format!("predictWalletEvents/{}", jwt))
            .await?;
        info!("Unsubscribed from wallet events");
        Ok(())
    }

    /// Sends a heartbeat reply carrying `timestamp`.
    pub async fn send_heartbeat(&self, timestamp: u64) -> Result<()> {
        let request = WsRequest::heartbeat(timestamp);
        self.send_request(&request).await?;
        debug!("Sent heartbeat response: {}", timestamp);
        Ok(())
    }

    /// Sends a subscribe request for a single `topic` with a fresh request id.
    async fn subscribe_topic(&self, topic: String) -> Result<()> {
        let request = WsRequest::subscribe(self.next_id(), vec![topic]);
        self.send_request(&request).await
    }

    /// Sends an unsubscribe request for a single `topic` with a fresh request id.
    async fn unsubscribe_topic(&self, topic: String) -> Result<()> {
        let request = WsRequest::unsubscribe(self.next_id(), vec![topic]);
        self.send_request(&request).await
    }

    /// Serializes `request` to JSON and writes it as a text frame.
    ///
    /// # Errors
    /// Fails if serialization fails, if no writer is installed ("WebSocket not connected"),
    /// or if the sink rejects the frame.
    async fn send_request(&self, request: &WsRequest) -> Result<()> {
        let json = serde_json::to_string(request)
            .map_err(|e| Error::Other(format!("Failed to serialize request: {}", e)))?;
        let mut writer_guard = self.writer.lock().await;
        let writer = writer_guard
            .as_mut()
            .ok_or_else(|| Error::Other("WebSocket not connected".to_string()))?;
        writer
            .send(Message::Text(json.into()))
            .await
            .map_err(|e| Error::Other(format!("Failed to send message: {}", e)))?;
        Ok(())
    }

    /// Drops the current writer and connects again, returning a new stream.
    ///
    /// Subscriptions are NOT replayed on the new connection. Callers that need orderbooks
    /// again can resubscribe using [`subscribed_markets`](Self::subscribed_markets).
    pub async fn reconnect(&self) -> Result<PredictWsStream> {
        {
            let mut w = self.writer.lock().await;
            *w = None;
        }
        self.connect().await
    }

    /// Returns the connection configuration.
    pub fn config(&self) -> &WsConnectionConfig {
        &self.config
    }

    /// Returns `true` if a writer is currently installed.
    ///
    /// Nothing in this file clears the writer when the server closes the socket, so this can
    /// still report `true` after the connection has died.
    pub async fn is_connected(&self) -> bool {
        self.writer.lock().await.is_some()
    }

    /// Returns a snapshot of the market ids with tracked orderbook subscriptions.
    pub fn subscribed_markets(&self) -> Vec<u64> {
        self.subscribed_markets.iter().map(|r| *r).collect()
    }

    /// Returns a clone of the shared writer handle.
    pub fn writer(&self) -> Arc<Mutex<Option<WsWriter>>> {
        self.writer.clone()
    }
}
pub struct PredictWsStream {
reader: WsReader,
writer: Arc<Mutex<Option<WsWriter>>>,
}
impl PredictWsStream {
pub async fn next(&mut self) -> Option<Result<WsMessage>> {
loop {
match self.reader.next().await {
Some(Ok(Message::Text(text))) => {
match self.parse_message(&text).await {
Ok(Some(msg)) => return Some(Ok(msg)),
Ok(None) => continue, Err(e) => return Some(Err(e)),
}
}
Some(Ok(Message::Ping(data))) => {
if let Err(e) = self.send_pong(data.to_vec()).await {
warn!("Failed to send pong: {}", e);
}
continue;
}
Some(Ok(Message::Pong(_))) => {
continue;
}
Some(Ok(Message::Close(frame))) => {
info!("WebSocket closed: {:?}", frame);
return None;
}
Some(Ok(Message::Binary(_))) => {
warn!("Received unexpected binary message");
continue;
}
Some(Ok(Message::Frame(_))) => {
continue;
}
Some(Err(e)) => {
error!("WebSocket error: {}", e);
return Some(Err(Error::Other(format!("WebSocket error: {}", e))));
}
None => {
info!("WebSocket stream ended");
return None;
}
}
}
}
async fn parse_message(&mut self, text: &str) -> Result<Option<WsMessage>> {
let raw: RawWsMessage = serde_json::from_str(text)
.map_err(|e| Error::Other(format!("Failed to parse message: {} - {}", e, text)))?;
let msg = WsMessage::try_from(raw)
.map_err(|e| Error::Other(format!("Failed to convert message: {}", e)))?;
if let WsMessage::PushMessage(ref push) = msg {
if let Some(timestamp) = push.heartbeat_timestamp() {
self.send_heartbeat(timestamp).await?;
return Ok(None); }
}
Ok(Some(msg))
}
async fn send_heartbeat(&mut self, timestamp: u64) -> Result<()> {
let request = WsRequest::heartbeat(timestamp);
let json = serde_json::to_string(&request)
.map_err(|e| Error::Other(format!("Failed to serialize heartbeat: {}", e)))?;
let mut writer_guard = self.writer.lock().await;
if let Some(writer) = writer_guard.as_mut() {
writer
.send(Message::Text(json.into()))
.await
.map_err(|e| Error::Other(format!("Failed to send heartbeat: {}", e)))?;
debug!("Sent heartbeat response: {}", timestamp);
}
Ok(())
}
async fn send_pong(&mut self, data: Vec<u8>) -> Result<()> {
let mut writer_guard = self.writer.lock().await;
if let Some(writer) = writer_guard.as_mut() {
writer
.send(Message::Pong(data.into()))
.await
.map_err(|e| Error::Other(format!("Failed to send pong: {}", e)))?;
}
Ok(())
}
}
pub fn parse_orderbook_update(push: &PushMessage) -> Result<OrderbookData> {
if !push.is_orderbook() {
return Err(Error::Other("Not an orderbook message".to_string()));
}
serde_json::from_value(push.data.clone())
.map_err(|e| Error::Other(format!("Failed to parse orderbook data: {}", e)))
}
pub fn parse_asset_price_update(push: &PushMessage) -> Result<AssetPriceData> {
if !push.is_asset_price() {
return Err(Error::Other("Not an asset price message".to_string()));
}
serde_json::from_value(push.data.clone())
.map_err(|e| Error::Other(format!("Failed to parse asset price data: {}", e)))
}
/// Converts a wallet-event push message into a typed [`PredictWalletEvent`].
///
/// Fields are read leniently:
/// - a missing or non-string `type` or `orderHash` becomes `""`;
/// - a missing `txHash` or `reason` becomes `None`;
/// - a missing `details` object falls back to `WalletEventDetails::default()`;
/// - `orderId` may be a JSON string or number (see `wallet_order_id`).
///
/// An unrecognised `type`, including a missing one, yields [`PredictWalletEvent::Unknown`]
/// with the raw payload rather than an error.
///
/// # Errors
/// Returns an error only if `push` is not a wallet-event message.
pub fn parse_wallet_event(push: &PushMessage) -> Result<PredictWalletEvent> {
    use crate::WalletEventDetails;

    if !push.is_wallet_event() {
        return Err(Error::Other("Not a wallet event message".to_string()));
    }
    let data = &push.data;

    let event_type = wallet_str_field(data, "type").unwrap_or_default();
    let order_hash = wallet_str_field(data, "orderHash").unwrap_or_default();
    let order_id = data.get("orderId").map(wallet_order_id).unwrap_or_default();
    let tx_hash = wallet_str_field(data, "txHash");
    let reason = wallet_str_field(data, "reason");
    let details = data
        .get("details")
        .map(|d| WalletEventDetails {
            price: wallet_str_field(d, "price"),
            quantity: wallet_str_field(d, "quantity"),
            quantity_filled: wallet_str_field(d, "quantityFilled"),
            outcome: wallet_str_field(d, "outcome"),
            quote_type: wallet_str_field(d, "quoteType"),
        })
        .unwrap_or_default();

    match event_type.as_str() {
        "orderAccepted" => Ok(PredictWalletEvent::OrderAccepted { order_hash, order_id }),
        "orderNotAccepted" => Ok(PredictWalletEvent::OrderNotAccepted {
            order_hash,
            order_id,
            reason,
        }),
        "orderExpired" => Ok(PredictWalletEvent::OrderExpired { order_hash, order_id }),
        "orderCancelled" => Ok(PredictWalletEvent::OrderCancelled { order_hash, order_id }),
        "orderTransactionSubmitted" => Ok(PredictWalletEvent::OrderTransactionSubmitted {
            order_hash,
            order_id,
            tx_hash,
            details,
        }),
        "orderTransactionSuccess" => Ok(PredictWalletEvent::OrderTransactionSuccess {
            order_hash,
            order_id,
            tx_hash,
            details,
        }),
        "orderTransactionFailed" => Ok(PredictWalletEvent::OrderTransactionFailed {
            order_hash,
            order_id,
            tx_hash,
            details,
        }),
        _ => Ok(PredictWalletEvent::Unknown {
            event_type,
            data: push.data.clone(),
        }),
    }
}

/// Returns `value[key]` as an owned string.
///
/// Returns `None` if `value` is not an object, if the key is absent, or if the value is not
/// a JSON string.
fn wallet_str_field(value: &serde_json::Value, key: &str) -> Option<String> {
    value
        .get(key)
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
}

/// Normalises an `orderId` value to a plain decimal string.
///
/// Strings have one trailing `n` removed; the tests show this form as a BigInt-style
/// suffix, e.g. `"4175379n"`. Numbers are rendered with their JSON representation. Any
/// other JSON type yields `""`.
fn wallet_order_id(value: &serde_json::Value) -> String {
    match value {
        serde_json::Value::String(s) => s.strip_suffix('n').unwrap_or(s).to_string(),
        serde_json::Value::Number(n) => n.to_string(),
        _ => String::new(),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for client construction, request-id allocation, and wallet-event parsing.
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = PredictWebSocket::new("wss://ws.predict.fun/ws".to_string());
        assert!(client.subscribed_markets().is_empty());
    }

    #[test]
    fn test_request_id_increment() {
        let client = PredictWebSocket::new("wss://ws.predict.fun/ws".to_string());
        assert_eq!(client.next_id(), 1);
        assert_eq!(client.next_id(), 2);
        assert_eq!(client.next_id(), 3);
    }

    /// Wraps `data` in a push message on a wallet-events topic.
    fn wallet_push(data: serde_json::Value) -> PushMessage {
        PushMessage {
            topic: "predictWalletEvents/jwt123".to_string(),
            data,
        }
    }

    #[test]
    fn test_parse_order_accepted() {
        let push = wallet_push(serde_json::json!({
            "type": "orderAccepted",
            "orderId": "4170746",
            "orderHash": "0xb5b5b676abcd"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderAccepted { order_hash, order_id } => {
                assert_eq!(order_hash, "0xb5b5b676abcd");
                assert_eq!(order_id, "4170746");
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_submitted() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionSubmitted",
            "orderId": 4170746,
            "orderHash": "0xb5b5b676abcd",
            "txHash": "0xdeadbeef"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderTransactionSubmitted { order_hash, order_id, tx_hash, .. } => {
                assert_eq!(order_hash, "0xb5b5b676abcd");
                assert_eq!(order_id, "4170746");
                assert_eq!(tx_hash, Some("0xdeadbeef".to_string()));
            }
            other => panic!("Expected OrderTransactionSubmitted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_success() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionSuccess",
            "orderId": "4170746",
            "txHash": "0xdeadbeef"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderTransactionSuccess { order_hash, order_id, tx_hash, .. } => {
                assert_eq!(order_hash, "");
                assert_eq!(order_id, "4170746");
                assert_eq!(tx_hash, Some("0xdeadbeef".to_string()));
            }
            other => panic!("Expected OrderTransactionSuccess, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_not_accepted() {
        let push = wallet_push(serde_json::json!({
            "type": "orderNotAccepted",
            "orderId": "123",
            "orderHash": "0xabc",
            "reason": "insufficient balance"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderNotAccepted { order_hash, order_id, reason } => {
                assert_eq!(order_hash, "0xabc");
                assert_eq!(order_id, "123");
                assert_eq!(reason, Some("insufficient balance".to_string()));
            }
            other => panic!("Expected OrderNotAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_expired() {
        let push = wallet_push(serde_json::json!({
            "type": "orderExpired",
            "orderId": "55",
            "orderHash": "0xexpired"
        }));
        match parse_wallet_event(&push).unwrap() {
            PredictWalletEvent::OrderExpired { order_hash, order_id } => {
                assert_eq!(order_hash, "0xexpired");
                assert_eq!(order_id, "55");
            }
            other => panic!("Expected OrderExpired, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_cancelled() {
        let push = wallet_push(serde_json::json!({
            "type": "orderCancelled",
            "orderId": 66,
            "orderHash": "0xcancelled"
        }));
        match parse_wallet_event(&push).unwrap() {
            PredictWalletEvent::OrderCancelled { order_hash, order_id } => {
                assert_eq!(order_hash, "0xcancelled");
                assert_eq!(order_id, "66");
            }
            other => panic!("Expected OrderCancelled, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_failed_without_details() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionFailed",
            "orderId": "77n",
            "orderHash": "0xfailed"
        }));
        match parse_wallet_event(&push).unwrap() {
            PredictWalletEvent::OrderTransactionFailed {
                order_hash,
                order_id,
                tx_hash,
                details,
            } => {
                assert_eq!(order_hash, "0xfailed");
                assert_eq!(order_id, "77");
                assert_eq!(tx_hash, None);
                // Missing `details` must fall back to the type's default value.
                let fallback = crate::WalletEventDetails::default();
                assert_eq!(details.price, fallback.price);
                assert_eq!(details.quantity, fallback.quantity);
                assert_eq!(details.quantity_filled, fallback.quantity_filled);
                assert_eq!(details.outcome, fallback.outcome);
                assert_eq!(details.quote_type, fallback.quote_type);
            }
            other => panic!("Expected OrderTransactionFailed, got {:?}", other),
        }
    }

    #[test]
    fn test_unexpected_order_id_type_yields_empty_string() {
        let push = wallet_push(serde_json::json!({
            "type": "orderAccepted",
            "orderId": true,
            "orderHash": "0x1"
        }));
        match parse_wallet_event(&push).unwrap() {
            PredictWalletEvent::OrderAccepted { order_hash, order_id } => {
                assert_eq!(order_hash, "0x1");
                assert_eq!(order_id, "");
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_unknown_event_type() {
        let push = wallet_push(serde_json::json!({
            "type": "newEventType",
            "foo": "bar"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::Unknown { event_type, .. } => {
                assert_eq!(event_type, "newEventType");
            }
            other => panic!("Expected Unknown, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_missing_type_field() {
        let push = wallet_push(serde_json::json!({
            "orderId": "123"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::Unknown { event_type, .. } => {
                assert_eq!(event_type, "");
            }
            other => panic!("Expected Unknown, got {:?}", other),
        }
    }

    #[test]
    fn test_bigint_order_id_suffix_stripped() {
        let push = wallet_push(serde_json::json!({
            "type": "orderAccepted",
            "orderId": "4175379n"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderAccepted { order_id, order_hash } => {
                assert_eq!(order_id, "4175379");
                assert_eq!(order_hash, "");
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_details_from_production_payload() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionSuccess",
            "orderId": "4170746n",
            "timestamp": 1769952855099u64,
            "details": {
                "categorySlug": "btc-usd-up-down-2026-02-01-08-30-15-minutes",
                "marketQuestion": "BTC/USD Up or Down - February 1, 8:30-8:45AM ET",
                "outcome": "YES",
                "price": "0.290",
                "quantity": "5.000",
                "quantityFilled": "5.000",
                "quoteType": "ASK",
                "strategyType": "LIMIT",
                "value": "1.45",
                "valueFilled": "1.45"
            }
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderTransactionSuccess { order_id, details, .. } => {
                assert_eq!(order_id, "4170746");
                assert_eq!(details.price.as_deref(), Some("0.290"));
                assert_eq!(details.quantity.as_deref(), Some("5.000"));
                assert_eq!(details.quantity_filled.as_deref(), Some("5.000"));
                assert_eq!(details.outcome.as_deref(), Some("YES"));
                assert_eq!(details.quote_type.as_deref(), Some("ASK"));
            }
            other => panic!("Expected OrderTransactionSuccess, got {:?}", other),
        }
    }

    #[test]
    fn test_non_wallet_event_rejected() {
        let push = PushMessage {
            topic: "predictOrderbook/123".to_string(),
            data: serde_json::json!({"type": "orderAccepted"}),
        };
        assert!(parse_wallet_event(&push).is_err());
    }
}