1use rust_decimal::Decimal;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize)]
10#[serde(rename_all = "camelCase")]
11pub struct WsRequest {
12 pub method: String,
14 #[serde(skip_serializing_if = "Option::is_none")]
16 pub request_id: Option<u64>,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub params: Option<Vec<String>>,
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub data: Option<serde_json::Value>,
23}
24
25impl WsRequest {
26 pub fn subscribe(request_id: u64, topics: Vec<String>) -> Self {
28 Self {
29 method: "subscribe".to_string(),
30 request_id: Some(request_id),
31 params: Some(topics),
32 data: None,
33 }
34 }
35
36 pub fn unsubscribe(request_id: u64, topics: Vec<String>) -> Self {
38 Self {
39 method: "unsubscribe".to_string(),
40 request_id: Some(request_id),
41 params: Some(topics),
42 data: None,
43 }
44 }
45
46 pub fn heartbeat(timestamp: u64) -> Self {
48 Self {
49 method: "heartbeat".to_string(),
50 request_id: None,
51 params: None,
52 data: Some(serde_json::Value::Number(timestamp.into())),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Deserialize)]
59pub struct RawWsMessage {
60 #[serde(rename = "type")]
62 pub msg_type: String,
63 #[serde(rename = "requestId")]
65 pub request_id: Option<u64>,
66 pub success: Option<bool>,
68 pub topic: Option<String>,
70 pub data: Option<serde_json::Value>,
72 pub error: Option<WsError>,
74}
75
76#[derive(Debug, Clone)]
78pub enum WsMessage {
79 RequestResponse(RequestResponse),
81 PushMessage(PushMessage),
83}
84
85impl TryFrom<RawWsMessage> for WsMessage {
86 type Error = String;
87
88 fn try_from(raw: RawWsMessage) -> Result<Self, Self::Error> {
89 match raw.msg_type.as_str() {
90 "R" => Ok(WsMessage::RequestResponse(RequestResponse {
91 request_id: raw.request_id.ok_or("Missing request_id for type R")?,
92 success: raw.success.unwrap_or(false),
93 data: raw.data,
94 error: raw.error,
95 })),
96 "M" => Ok(WsMessage::PushMessage(PushMessage {
97 topic: raw.topic.ok_or("Missing topic for type M")?,
98 data: raw.data.unwrap_or(serde_json::Value::Null),
99 })),
100 other => Err(format!("Unknown message type: {}", other)),
101 }
102 }
103}
104
105#[derive(Debug, Clone)]
107pub struct RequestResponse {
108 pub request_id: u64,
110 pub success: bool,
112 pub data: Option<serde_json::Value>,
114 pub error: Option<WsError>,
116}
117
118#[derive(Debug, Clone)]
120pub struct PushMessage {
121 pub topic: String,
123 pub data: serde_json::Value,
125}
126
127impl PushMessage {
128 pub fn is_heartbeat(&self) -> bool {
130 self.topic == "heartbeat"
131 }
132
133 pub fn heartbeat_timestamp(&self) -> Option<u64> {
135 if self.is_heartbeat() {
136 self.data.as_u64()
137 } else {
138 None
139 }
140 }
141
142 pub fn is_orderbook(&self) -> bool {
144 self.topic.starts_with("predictOrderbook/")
145 }
146
147 pub fn orderbook_market_id(&self) -> Option<u64> {
149 if self.is_orderbook() {
150 self.topic
151 .strip_prefix("predictOrderbook/")
152 .and_then(|s| s.parse().ok())
153 } else {
154 None
155 }
156 }
157
158 pub fn is_asset_price(&self) -> bool {
160 self.topic.starts_with("assetPriceUpdate/")
161 }
162
163 pub fn asset_price_feed_id(&self) -> Option<&str> {
165 if self.is_asset_price() {
166 self.topic.strip_prefix("assetPriceUpdate/")
167 } else {
168 None
169 }
170 }
171
172 pub fn is_polymarket_chance(&self) -> bool {
174 self.topic.starts_with("polymarketChance/")
175 }
176
177 pub fn polymarket_chance_market_id(&self) -> Option<u64> {
179 if self.is_polymarket_chance() {
180 self.topic
181 .strip_prefix("polymarketChance/")
182 .and_then(|s| s.parse().ok())
183 } else {
184 None
185 }
186 }
187
188 pub fn is_kalshi_chance(&self) -> bool {
190 self.topic.starts_with("kalshiChance/")
191 }
192
193 pub fn kalshi_chance_market_id(&self) -> Option<u64> {
195 if self.is_kalshi_chance() {
196 self.topic
197 .strip_prefix("kalshiChance/")
198 .and_then(|s| s.parse().ok())
199 } else {
200 None
201 }
202 }
203
204 pub fn is_wallet_event(&self) -> bool {
206 self.topic.starts_with("predictWalletEvents/")
207 }
208}
209
210#[derive(Debug, Clone, Deserialize)]
212pub struct WsError {
213 pub code: String,
215 pub message: String,
217}
218
219#[derive(Debug, Clone, Deserialize)]
221#[serde(rename_all = "camelCase")]
222pub struct OrderbookData {
223 pub market_id: u64,
225 pub bids: Vec<PriceLevel>,
227 pub asks: Vec<PriceLevel>,
229 #[serde(default)]
231 pub timestamp: Option<u64>,
232}
233
234#[derive(Debug, Clone, Deserialize)]
236#[serde(rename_all = "camelCase")]
237pub struct AssetPriceData {
238 pub price: f64,
240 pub publish_time: u64,
242 pub timestamp: u64,
244}
245
246#[derive(Debug, Clone, Deserialize)]
248pub struct PriceLevel {
249 #[serde(deserialize_with = "deserialize_decimal")]
251 pub price: Decimal,
252 #[serde(deserialize_with = "deserialize_decimal")]
254 pub size: Decimal,
255}
256
257fn deserialize_decimal<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
259where
260 D: serde::Deserializer<'de>,
261{
262 use serde::de::Error;
263
264 #[derive(Deserialize)]
265 #[serde(untagged)]
266 enum StringOrNumber {
267 String(String),
268 Number(f64),
269 }
270
271 match StringOrNumber::deserialize(deserializer)? {
272 StringOrNumber::String(s) => s.parse().map_err(D::Error::custom),
273 StringOrNumber::Number(n) => Decimal::try_from(n).map_err(D::Error::custom),
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a raw envelope (panicking on malformed JSON) and converts it.
    fn parse(json: &str) -> Result<WsMessage, String> {
        let raw: RawWsMessage = serde_json::from_str(json).expect("valid envelope JSON");
        WsMessage::try_from(raw)
    }

    /// Unwraps a `RequestResponse`, panicking on any other variant.
    fn expect_response(msg: WsMessage) -> RequestResponse {
        match msg {
            WsMessage::RequestResponse(resp) => resp,
            other => panic!("Expected RequestResponse, got {other:?}"),
        }
    }

    /// Unwraps a `PushMessage`, panicking on any other variant.
    fn expect_push(msg: WsMessage) -> PushMessage {
        match msg {
            WsMessage::PushMessage(push) => push,
            other => panic!("Expected PushMessage, got {other:?}"),
        }
    }

    /// Builds a push message with a null payload for topic-accessor tests.
    fn push_with_topic(topic: &str) -> PushMessage {
        PushMessage {
            topic: topic.to_string(),
            data: serde_json::Value::Null,
        }
    }

    #[test]
    fn test_subscribe_request() {
        let req = WsRequest::subscribe(1, vec!["predictOrderbook/123".to_string()]);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"subscribe\""));
        assert!(json.contains("\"requestId\":1"));
        assert!(json.contains("predictOrderbook/123"));
        assert_eq!(
            json,
            r#"{"method":"subscribe","requestId":1,"params":["predictOrderbook/123"]}"#
        );
    }

    #[test]
    fn test_unsubscribe_request() {
        let req = WsRequest::unsubscribe(2, vec!["kalshiChance/9".to_string()]);
        let json = serde_json::to_string(&req).unwrap();
        assert_eq!(
            json,
            r#"{"method":"unsubscribe","requestId":2,"params":["kalshiChance/9"]}"#
        );
    }

    #[test]
    fn test_heartbeat_request() {
        let req = WsRequest::heartbeat(1736696400000);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"method\":\"heartbeat\""));
        assert!(json.contains("1736696400000"));
        // No requestId / params are emitted for heartbeats.
        assert_eq!(json, r#"{"method":"heartbeat","data":1736696400000}"#);
    }

    #[test]
    fn test_parse_request_response() {
        let resp = expect_response(
            parse(r#"{"type":"R","requestId":1,"success":true,"data":null}"#).unwrap(),
        );
        assert_eq!(resp.request_id, 1);
        assert!(resp.success);
        assert!(resp.data.is_none());
        assert!(resp.error.is_none());
    }

    #[test]
    fn test_parse_request_response_defaults_success_to_false() {
        let resp = expect_response(parse(r#"{"type":"R","requestId":7}"#).unwrap());
        assert_eq!(resp.request_id, 7);
        assert!(!resp.success);
    }

    #[test]
    fn test_parse_request_response_with_error() {
        let json = r#"{"type":"R","requestId":2,"success":false,"error":{"code":"E1","message":"bad"}}"#;
        let resp = expect_response(parse(json).unwrap());
        let err = resp.error.expect("error payload");
        assert_eq!(err.code, "E1");
        assert_eq!(err.message, "bad");
    }

    #[test]
    fn test_parse_errors() {
        assert_eq!(
            parse(r#"{"type":"R","success":true}"#).unwrap_err(),
            "Missing request_id for type R"
        );
        assert_eq!(
            parse(r#"{"type":"M","data":1}"#).unwrap_err(),
            "Missing topic for type M"
        );
        assert_eq!(
            parse(r#"{"type":"X"}"#).unwrap_err(),
            "Unknown message type: X"
        );
    }

    #[test]
    fn test_parse_heartbeat_message() {
        let push = expect_push(
            parse(r#"{"type":"M","topic":"heartbeat","data":1736696400000}"#).unwrap(),
        );
        assert!(push.is_heartbeat());
        assert_eq!(push.heartbeat_timestamp(), Some(1736696400000));
    }

    #[test]
    fn test_push_without_data_is_null() {
        let push = expect_push(parse(r#"{"type":"M","topic":"predictOrderbook/1"}"#).unwrap());
        assert!(push.data.is_null());
        assert_eq!(push.heartbeat_timestamp(), None);
    }

    #[test]
    fn test_parse_orderbook_topic() {
        let push = push_with_topic("predictOrderbook/5614");
        assert!(push.is_orderbook());
        assert_eq!(push.orderbook_market_id(), Some(5614));
        assert!(!push.is_heartbeat());
    }

    #[test]
    fn test_other_topic_accessors() {
        let price = push_with_topic("assetPriceUpdate/feed123");
        assert!(price.is_asset_price());
        assert_eq!(price.asset_price_feed_id(), Some("feed123"));
        assert!(!price.is_orderbook());
        assert_eq!(price.orderbook_market_id(), None);

        let poly = push_with_topic("polymarketChance/42");
        assert!(poly.is_polymarket_chance());
        assert_eq!(poly.polymarket_chance_market_id(), Some(42));

        let kalshi = push_with_topic("kalshiChance/abc");
        assert!(kalshi.is_kalshi_chance());
        assert_eq!(kalshi.kalshi_chance_market_id(), None);

        let wallet = push_with_topic("predictWalletEvents/0xabc");
        assert!(wallet.is_wallet_event());
        assert!(!wallet.is_kalshi_chance());
    }

    #[test]
    fn test_price_level_accepts_strings_and_numbers() {
        let level: PriceLevel = serde_json::from_str(r#"{"price":"0.55","size":100}"#).unwrap();
        assert_eq!(level.price, Decimal::new(55, 2));
        assert_eq!(level.size, Decimal::new(100, 0));

        let level: PriceLevel = serde_json::from_str(r#"{"price":0.5,"size":"12.25"}"#).unwrap();
        assert_eq!(level.price, Decimal::new(5, 1));
        assert_eq!(level.size, Decimal::new(1225, 2));
    }

    #[test]
    fn test_price_level_rejects_invalid_values() {
        assert!(serde_json::from_str::<PriceLevel>(r#"{"price":"abc","size":1}"#).is_err());
        assert!(serde_json::from_str::<PriceLevel>(r#"{"price":true,"size":1}"#).is_err());
    }

    #[test]
    fn test_orderbook_data() {
        let json = r#"{"marketId":5614,"bids":[{"price":"0.4","size":"10"}],"asks":[]}"#;
        let book: OrderbookData = serde_json::from_str(json).unwrap();
        assert_eq!(book.market_id, 5614);
        assert_eq!(book.bids.len(), 1);
        assert_eq!(book.bids[0].price, Decimal::new(4, 1));
        assert!(book.asks.is_empty());
        assert_eq!(book.timestamp, None);

        let json = r#"{"marketId":1,"bids":[],"asks":[],"timestamp":123}"#;
        let book: OrderbookData = serde_json::from_str(json).unwrap();
        assert_eq!(book.timestamp, Some(123));
    }

    #[test]
    fn test_asset_price_data() {
        let json = r#"{"price":97000.5,"publishTime":1736696400,"timestamp":1736696400123}"#;
        let data: AssetPriceData = serde_json::from_str(json).unwrap();
        assert_eq!(data.price, 97000.5);
        assert_eq!(data.publish_time, 1736696400);
        assert_eq!(data.timestamp, 1736696400123);
    }
}