use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
/// Outgoing request envelope.
///
/// Field names are serialized in camelCase (`request_id` is written as
/// `requestId`), and every optional field is left out of the output
/// entirely when it is `None`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct WsRequest {
    /// Name of the requested operation; the constructors in this file use
    /// `"subscribe"`, `"unsubscribe"` and `"heartbeat"`.
    pub method: String,

    /// Caller-chosen identifier for this request.
    // NOTE(review): presumably echoed back in the matching response so the
    // caller can correlate them; confirm against the server protocol.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request_id: Option<u64>,

    /// String parameters of the request; the subscribe/unsubscribe
    /// constructors store the topic names here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Vec<String>>,

    /// Arbitrary JSON payload; the heartbeat constructor stores its
    /// timestamp here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
impl WsRequest {
pub fn subscribe(request_id: u64, topics: Vec<String>) -> Self {
Self {
method: "subscribe".to_string(),
request_id: Some(request_id),
params: Some(topics),
data: None,
}
}
pub fn unsubscribe(request_id: u64, topics: Vec<String>) -> Self {
Self {
method: "unsubscribe".to_string(),
request_id: Some(request_id),
params: Some(topics),
data: None,
}
}
pub fn heartbeat(timestamp: u64) -> Self {
Self {
method: "heartbeat".to_string(),
request_id: None,
params: None,
data: Some(serde_json::Value::Number(timestamp.into())),
}
}
}
/// Incoming message exactly as it appears on the wire, before it is
/// classified into a [`WsMessage`].
///
/// Everything except the `type` discriminator is optional at this level;
/// which fields are actually required depends on the message type and is
/// enforced by the `TryFrom<RawWsMessage>` conversion.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RawWsMessage {
    /// Message discriminator, read from the `type` key (the explicit rename
    /// takes precedence over the container-level camelCase rule).
    #[serde(rename = "type")]
    pub msg_type: String,

    /// Request identifier, read from the `requestId` key.
    pub request_id: Option<u64>,

    /// Success flag, if the message carries one.
    pub success: Option<bool>,

    /// Topic name, if the message carries one.
    pub topic: Option<String>,

    /// Arbitrary JSON payload, if any.
    pub data: Option<serde_json::Value>,

    /// Structured error description, if any.
    pub error: Option<WsError>,
}
/// Classified incoming message, produced from a [`RawWsMessage`] via
/// `TryFrom`.
#[derive(Debug, Clone)]
pub enum WsMessage {
    /// Reply to a request (raw message type `"R"`).
    RequestResponse(RequestResponse),

    /// Message published on a topic (raw message type `"M"`).
    PushMessage(PushMessage),
}
impl TryFrom<RawWsMessage> for WsMessage {
type Error = String;
fn try_from(raw: RawWsMessage) -> Result<Self, Self::Error> {
match raw.msg_type.as_str() {
"R" => Ok(WsMessage::RequestResponse(RequestResponse {
request_id: raw.request_id.ok_or("Missing request_id for type R")?,
success: raw.success.unwrap_or(false),
data: raw.data,
error: raw.error,
})),
"M" => Ok(WsMessage::PushMessage(PushMessage {
topic: raw.topic.ok_or("Missing topic for type M")?,
data: raw.data.unwrap_or(serde_json::Value::Null),
})),
other => Err(format!("Unknown message type: {}", other)),
}
}
}
/// Reply to a previously issued request.
#[derive(Debug, Clone)]
pub struct RequestResponse {
    /// Identifier carried by the reply.
    // NOTE(review): presumably equal to the id of the originating request;
    // confirm against the server protocol.
    pub request_id: u64,

    /// Success flag of the reply. The conversion from `RawWsMessage` sets
    /// this to `false` when the flag is absent on the wire.
    pub success: bool,

    /// Optional JSON payload of the reply.
    pub data: Option<serde_json::Value>,

    /// Optional structured error description.
    pub error: Option<WsError>,
}
/// Message published on a topic.
#[derive(Debug, Clone)]
pub struct PushMessage {
    /// Full topic string, e.g. `"heartbeat"` or `"predictOrderbook/5614"`.
    pub topic: String,

    /// JSON payload. The conversion from `RawWsMessage` stores JSON `null`
    /// here when the wire message had no data.
    pub data: serde_json::Value,
}
impl PushMessage {
    /// Exact topic name of heartbeat messages.
    const HEARTBEAT_TOPIC: &'static str = "heartbeat";
    /// Topic prefix of orderbook messages; the remainder is the market id.
    const ORDERBOOK_PREFIX: &'static str = "predictOrderbook/";
    /// Topic prefix of asset price messages; the remainder is the feed id.
    const ASSET_PRICE_PREFIX: &'static str = "assetPriceUpdate/";
    /// Topic prefix of Polymarket chance messages; the remainder is the market id.
    const POLYMARKET_CHANCE_PREFIX: &'static str = "polymarketChance/";
    /// Topic prefix of Kalshi chance messages; the remainder is the market id.
    const KALSHI_CHANCE_PREFIX: &'static str = "kalshiChance/";
    /// Topic prefix of wallet event messages.
    const WALLET_EVENTS_PREFIX: &'static str = "predictWalletEvents/";

    /// Returns `true` when the topic is exactly `"heartbeat"`.
    pub fn is_heartbeat(&self) -> bool {
        self.topic == Self::HEARTBEAT_TOPIC
    }

    /// Returns the payload as a `u64` when this is a heartbeat message and
    /// the payload is a JSON number representable as `u64`; otherwise `None`.
    pub fn heartbeat_timestamp(&self) -> Option<u64> {
        if !self.is_heartbeat() {
            return None;
        }
        self.data.as_u64()
    }

    /// Returns `true` when the topic starts with `"predictOrderbook/"`.
    pub fn is_orderbook(&self) -> bool {
        self.topic.starts_with(Self::ORDERBOOK_PREFIX)
    }

    /// Returns the market id that follows the `"predictOrderbook/"` prefix,
    /// or `None` when the prefix is absent or the remainder is not a `u64`.
    pub fn orderbook_market_id(&self) -> Option<u64> {
        let id = self.topic.strip_prefix(Self::ORDERBOOK_PREFIX)?;
        id.parse().ok()
    }

    /// Returns `true` when the topic starts with `"assetPriceUpdate/"`.
    pub fn is_asset_price(&self) -> bool {
        self.topic.starts_with(Self::ASSET_PRICE_PREFIX)
    }

    /// Returns the text that follows the `"assetPriceUpdate/"` prefix, or
    /// `None` when the topic does not start with that prefix.
    pub fn asset_price_feed_id(&self) -> Option<&str> {
        self.topic.strip_prefix(Self::ASSET_PRICE_PREFIX)
    }

    /// Returns `true` when the topic starts with `"polymarketChance/"`.
    pub fn is_polymarket_chance(&self) -> bool {
        self.topic.starts_with(Self::POLYMARKET_CHANCE_PREFIX)
    }

    /// Returns the market id that follows the `"polymarketChance/"` prefix,
    /// or `None` when the prefix is absent or the remainder is not a `u64`.
    pub fn polymarket_chance_market_id(&self) -> Option<u64> {
        let id = self.topic.strip_prefix(Self::POLYMARKET_CHANCE_PREFIX)?;
        id.parse().ok()
    }

    /// Returns `true` when the topic starts with `"kalshiChance/"`.
    pub fn is_kalshi_chance(&self) -> bool {
        self.topic.starts_with(Self::KALSHI_CHANCE_PREFIX)
    }

    /// Returns the market id that follows the `"kalshiChance/"` prefix, or
    /// `None` when the prefix is absent or the remainder is not a `u64`.
    pub fn kalshi_chance_market_id(&self) -> Option<u64> {
        let id = self.topic.strip_prefix(Self::KALSHI_CHANCE_PREFIX)?;
        id.parse().ok()
    }

    /// Returns `true` when the topic starts with `"predictWalletEvents/"`.
    pub fn is_wallet_event(&self) -> bool {
        self.topic.starts_with(Self::WALLET_EVENTS_PREFIX)
    }
}
/// Structured error description attached to an incoming message.
#[derive(Debug, Clone, Deserialize)]
pub struct WsError {
    /// Error code, deserialized as a string.
    pub code: String,

    /// Error message text.
    pub message: String,
}
/// Orderbook payload: the bid and ask price levels of one market.
///
/// Keys are read in camelCase (`market_id` is read from `marketId`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OrderbookData {
    /// Identifier of the market the book belongs to.
    pub market_id: u64,

    /// Bid levels, in the order they appear in the payload.
    pub bids: Vec<PriceLevel>,

    /// Ask levels, in the order they appear in the payload.
    pub asks: Vec<PriceLevel>,

    /// Optional timestamp; `None` when the key is missing.
    // NOTE(review): unit (seconds vs. milliseconds) is not visible here.
    #[serde(default)]
    pub timestamp: Option<u64>,
}
/// Asset price payload.
///
/// Keys are read in camelCase (`publish_time` is read from `publishTime`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AssetPriceData {
    /// Price as a floating-point number.
    pub price: f64,

    /// Publish time of the price.
    // NOTE(review): unit and epoch are not visible here; confirm with the feed docs.
    pub publish_time: u64,

    /// Timestamp attached to the update.
    // NOTE(review): unit and epoch are not visible here; confirm with the feed docs.
    pub timestamp: u64,
}
/// One orderbook level: a price and the size available at it.
///
/// Both fields go through `deserialize_decimal`, so each may arrive either
/// as a string or as a number.
#[derive(Debug, Clone, Deserialize)]
pub struct PriceLevel {
    /// Price of the level.
    #[serde(deserialize_with = "deserialize_decimal")]
    pub price: Decimal,

    /// Size at this price.
    #[serde(deserialize_with = "deserialize_decimal")]
    pub size: Decimal,
}
fn deserialize_decimal<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrNumber {
String(String),
Number(f64),
}
match StringOrNumber::deserialize(deserializer)? {
StringOrNumber::String(s) => s.parse().map_err(D::Error::custom),
StringOrNumber::Number(n) => Decimal::try_from(n).map_err(D::Error::custom),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into a raw message and classifies it, panicking
    /// on any failure.
    fn parse_message(json: &str) -> WsMessage {
        let raw: RawWsMessage = serde_json::from_str(json).unwrap();
        WsMessage::try_from(raw).unwrap()
    }

    /// A subscribe request serializes its method, camelCase request id and
    /// topic.
    #[test]
    fn test_subscribe_request() {
        let topics = vec!["predictOrderbook/123".to_string()];
        let json = serde_json::to_string(&WsRequest::subscribe(1, topics)).unwrap();

        assert!(json.contains("\"method\":\"subscribe\""));
        assert!(json.contains("\"requestId\":1"));
        assert!(json.contains("predictOrderbook/123"));
    }

    /// A heartbeat request serializes its method and timestamp.
    #[test]
    fn test_heartbeat_request() {
        let json = serde_json::to_string(&WsRequest::heartbeat(1736696400000)).unwrap();

        assert!(json.contains("\"method\":\"heartbeat\""));
        assert!(json.contains("1736696400000"));
    }

    /// A type `R` message is classified as a request response.
    #[test]
    fn test_parse_request_response() {
        let msg = parse_message(r#"{"type":"R","requestId":1,"success":true,"data":null}"#);

        let resp = match msg {
            WsMessage::RequestResponse(resp) => resp,
            WsMessage::PushMessage(_) => panic!("Expected RequestResponse"),
        };
        assert_eq!(resp.request_id, 1);
        assert!(resp.success);
    }

    /// A type `M` message on the heartbeat topic exposes its timestamp.
    #[test]
    fn test_parse_heartbeat_message() {
        let msg = parse_message(r#"{"type":"M","topic":"heartbeat","data":1736696400000}"#);

        let push = match msg {
            WsMessage::PushMessage(push) => push,
            WsMessage::RequestResponse(_) => panic!("Expected PushMessage"),
        };
        assert!(push.is_heartbeat());
        assert_eq!(push.heartbeat_timestamp(), Some(1736696400000));
    }

    /// The market id is extracted from an orderbook topic.
    #[test]
    fn test_parse_orderbook_topic() {
        let push = PushMessage {
            topic: String::from("predictOrderbook/5614"),
            data: serde_json::Value::Null,
        };

        assert!(push.is_orderbook());
        assert_eq!(push.orderbook_market_id(), Some(5614));
    }
}