1use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9use dashmap::DashSet;
10use futures_util::{SinkExt, StreamExt};
11use tokio::sync::Mutex;
12use tracing::{debug, error, info, warn};
13use ws_reconnect_client::{connect_with_retry, Message, WsConnectionConfig, WsReader, WsWriter};
14
15use super::types::{AssetPriceData, OrderbookData, PushMessage, RawWsMessage, WsMessage, WsRequest};
16use crate::api_types::PredictWalletEvent;
17use crate::errors::{Error, Result};
18
19pub struct PredictWebSocket {
21 config: WsConnectionConfig,
22 subscribed_markets: DashSet<u64>,
23 writer: Arc<Mutex<Option<WsWriter>>>,
24 next_request_id: AtomicU64,
25}
26
27impl PredictWebSocket {
28 pub fn new(ws_url: String) -> Self {
30 let config = WsConnectionConfig::new(ws_url)
34 .with_ping_interval(0) .with_retries(10)
36 .with_backoff(1000, 30_000);
37
38 Self {
39 config,
40 subscribed_markets: DashSet::new(),
41 writer: Arc::new(Mutex::new(None)),
42 next_request_id: AtomicU64::new(1),
43 }
44 }
45
46 fn next_id(&self) -> u64 {
48 self.next_request_id.fetch_add(1, Ordering::SeqCst)
49 }
50
51 pub async fn connect(&self) -> Result<PredictWsStream> {
56 info!("Connecting to Predict WebSocket: {}", self.config.url);
57
58 let (writer, reader) = connect_with_retry(&self.config)
59 .await
60 .map_err(|e| Error::Other(format!("WebSocket connection failed: {}", e)))?;
61
62 {
64 let mut w = self.writer.lock().await;
65 *w = Some(writer);
66 }
67
68 info!("Connected to Predict WebSocket");
69
70 Ok(PredictWsStream {
71 reader,
72 writer: self.writer.clone(),
73 })
74 }
75
76 pub async fn subscribe_orderbook(&self, market_id: u64) -> Result<()> {
78 let topic = format!("predictOrderbook/{}", market_id);
79 let request_id = self.next_id();
80
81 let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
82 self.send_request(&request).await?;
83
84 self.subscribed_markets.insert(market_id);
85 info!("Subscribed to orderbook for market {}", market_id);
86
87 Ok(())
88 }
89
90 pub async fn unsubscribe_orderbook(&self, market_id: u64) -> Result<()> {
92 let topic = format!("predictOrderbook/{}", market_id);
93 let request_id = self.next_id();
94
95 let request = WsRequest::unsubscribe(request_id, vec![topic]);
96 self.send_request(&request).await?;
97
98 self.subscribed_markets.remove(&market_id);
99 info!("Unsubscribed from orderbook for market {}", market_id);
100
101 Ok(())
102 }
103
104 pub async fn subscribe_asset_price(&self, price_feed_id: &str) -> Result<()> {
111 let topic = format!("assetPriceUpdate/{}", price_feed_id);
112 let request_id = self.next_id();
113
114 let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
115 self.send_request(&request).await?;
116
117 info!("Subscribed to asset price for feed {}", price_feed_id);
118
119 Ok(())
120 }
121
122 pub async fn unsubscribe_asset_price(&self, price_feed_id: &str) -> Result<()> {
124 let topic = format!("assetPriceUpdate/{}", price_feed_id);
125 let request_id = self.next_id();
126
127 let request = WsRequest::unsubscribe(request_id, vec![topic]);
128 self.send_request(&request).await?;
129
130 info!("Unsubscribed from asset price for feed {}", price_feed_id);
131
132 Ok(())
133 }
134
135 pub async fn subscribe_polymarket_chance(&self, market_id: u64) -> Result<()> {
139 let topic = format!("polymarketChance/{}", market_id);
140 let request_id = self.next_id();
141
142 let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
143 self.send_request(&request).await?;
144
145 info!("Subscribed to Polymarket chance for market {}", market_id);
146
147 Ok(())
148 }
149
150 pub async fn subscribe_kalshi_chance(&self, market_id: u64) -> Result<()> {
154 let topic = format!("kalshiChance/{}", market_id);
155 let request_id = self.next_id();
156
157 let request = WsRequest::subscribe(request_id, vec![topic.clone()]);
158 self.send_request(&request).await?;
159
160 info!("Subscribed to Kalshi chance for market {}", market_id);
161
162 Ok(())
163 }
164
165 pub async fn subscribe_wallet_events(&self, jwt: &str) -> Result<()> {
176 let topic = format!("predictWalletEvents/{}", jwt);
177 let request_id = self.next_id();
178
179 let request = WsRequest::subscribe(request_id, vec![topic]);
180 self.send_request(&request).await?;
181
182 info!("Subscribed to wallet events");
183
184 Ok(())
185 }
186
187 pub async fn unsubscribe_wallet_events(&self, jwt: &str) -> Result<()> {
189 let topic = format!("predictWalletEvents/{}", jwt);
190 let request_id = self.next_id();
191
192 let request = WsRequest::unsubscribe(request_id, vec![topic]);
193 self.send_request(&request).await?;
194
195 info!("Unsubscribed from wallet events");
196
197 Ok(())
198 }
199
200 pub async fn send_heartbeat(&self, timestamp: u64) -> Result<()> {
202 let request = WsRequest::heartbeat(timestamp);
203 self.send_request(&request).await?;
204 debug!("Sent heartbeat response: {}", timestamp);
205 Ok(())
206 }
207
208 async fn send_request(&self, request: &WsRequest) -> Result<()> {
210 let json = serde_json::to_string(request)
211 .map_err(|e| Error::Other(format!("Failed to serialize request: {}", e)))?;
212
213 let mut writer_guard = self.writer.lock().await;
214 let writer = writer_guard
215 .as_mut()
216 .ok_or_else(|| Error::Other("WebSocket not connected".to_string()))?;
217
218 writer
219 .send(Message::Text(json.into()))
220 .await
221 .map_err(|e| Error::Other(format!("Failed to send message: {}", e)))?;
222
223 Ok(())
224 }
225
226 pub async fn reconnect(&self) -> Result<PredictWsStream> {
231 {
232 let mut w = self.writer.lock().await;
233 *w = None;
234 }
235 self.connect().await
236 }
237
238 pub fn config(&self) -> &WsConnectionConfig {
240 &self.config
241 }
242
243 pub async fn is_connected(&self) -> bool {
245 self.writer.lock().await.is_some()
246 }
247
248 pub fn subscribed_markets(&self) -> Vec<u64> {
250 self.subscribed_markets.iter().map(|r| *r).collect()
251 }
252
253 pub fn writer(&self) -> Arc<Mutex<Option<WsWriter>>> {
255 self.writer.clone()
256 }
257}
258
259pub struct PredictWsStream {
261 reader: WsReader,
262 writer: Arc<Mutex<Option<WsWriter>>>,
263}
264
265impl PredictWsStream {
266 pub async fn next(&mut self) -> Option<Result<WsMessage>> {
271 loop {
272 match self.reader.next().await {
273 Some(Ok(Message::Text(text))) => {
274 match self.parse_message(&text).await {
275 Ok(Some(msg)) => return Some(Ok(msg)),
276 Ok(None) => continue, Err(e) => return Some(Err(e)),
278 }
279 }
280 Some(Ok(Message::Ping(data))) => {
281 if let Err(e) = self.send_pong(data.to_vec()).await {
283 warn!("Failed to send pong: {}", e);
284 }
285 continue;
286 }
287 Some(Ok(Message::Pong(_))) => {
288 continue;
290 }
291 Some(Ok(Message::Close(frame))) => {
292 info!("WebSocket closed: {:?}", frame);
293 return None;
294 }
295 Some(Ok(Message::Binary(_))) => {
296 warn!("Received unexpected binary message");
297 continue;
298 }
299 Some(Ok(Message::Frame(_))) => {
300 continue;
302 }
303 Some(Err(e)) => {
304 error!("WebSocket error: {}", e);
305 return Some(Err(Error::Other(format!("WebSocket error: {}", e))));
306 }
307 None => {
308 info!("WebSocket stream ended");
309 return None;
310 }
311 }
312 }
313 }
314
315 async fn parse_message(&mut self, text: &str) -> Result<Option<WsMessage>> {
317 let raw: RawWsMessage = serde_json::from_str(text)
318 .map_err(|e| Error::Other(format!("Failed to parse message: {} - {}", e, text)))?;
319
320 let msg = WsMessage::try_from(raw)
321 .map_err(|e| Error::Other(format!("Failed to convert message: {}", e)))?;
322
323 if let WsMessage::PushMessage(ref push) = msg {
325 if let Some(timestamp) = push.heartbeat_timestamp() {
326 self.send_heartbeat(timestamp).await?;
327 return Ok(None); }
329 }
330
331 Ok(Some(msg))
332 }
333
334 async fn send_heartbeat(&mut self, timestamp: u64) -> Result<()> {
336 let request = WsRequest::heartbeat(timestamp);
337 let json = serde_json::to_string(&request)
338 .map_err(|e| Error::Other(format!("Failed to serialize heartbeat: {}", e)))?;
339
340 let mut writer_guard = self.writer.lock().await;
341 if let Some(writer) = writer_guard.as_mut() {
342 writer
343 .send(Message::Text(json.into()))
344 .await
345 .map_err(|e| Error::Other(format!("Failed to send heartbeat: {}", e)))?;
346 debug!("Sent heartbeat response: {}", timestamp);
347 }
348
349 Ok(())
350 }
351
352 async fn send_pong(&mut self, data: Vec<u8>) -> Result<()> {
354 let mut writer_guard = self.writer.lock().await;
355 if let Some(writer) = writer_guard.as_mut() {
356 writer
357 .send(Message::Pong(data.into()))
358 .await
359 .map_err(|e| Error::Other(format!("Failed to send pong: {}", e)))?;
360 }
361 Ok(())
362 }
363}
364
365pub fn parse_orderbook_update(push: &PushMessage) -> Result<OrderbookData> {
367 if !push.is_orderbook() {
368 return Err(Error::Other("Not an orderbook message".to_string()));
369 }
370
371 serde_json::from_value(push.data.clone())
372 .map_err(|e| Error::Other(format!("Failed to parse orderbook data: {}", e)))
373}
374
375pub fn parse_asset_price_update(push: &PushMessage) -> Result<AssetPriceData> {
377 if !push.is_asset_price() {
378 return Err(Error::Other("Not an asset price message".to_string()));
379 }
380
381 serde_json::from_value(push.data.clone())
382 .map_err(|e| Error::Other(format!("Failed to parse asset price data: {}", e)))
383}
384
385pub fn parse_wallet_event(push: &PushMessage) -> Result<PredictWalletEvent> {
390 if !push.is_wallet_event() {
391 return Err(Error::Other("Not a wallet event message".to_string()));
392 }
393
394 let event_type = push
395 .data
396 .get("type")
397 .and_then(|v| v.as_str())
398 .unwrap_or("")
399 .to_string();
400
401 let order_hash = push
402 .data
403 .get("orderHash")
404 .and_then(|v| v.as_str())
405 .unwrap_or("")
406 .to_string();
407
408 let order_id = push
411 .data
412 .get("orderId")
413 .map(|v| match v {
414 serde_json::Value::String(s) => s.strip_suffix('n').unwrap_or(s).to_string(),
415 serde_json::Value::Number(n) => n.to_string(),
416 _ => String::new(),
417 })
418 .unwrap_or_default();
419
420 let tx_hash = push
421 .data
422 .get("txHash")
423 .and_then(|v| v.as_str())
424 .map(|s| s.to_string());
425
426 let reason = push
427 .data
428 .get("reason")
429 .and_then(|v| v.as_str())
430 .map(|s| s.to_string());
431
432 let details = push.data.get("details").map(|d| {
434 use crate::WalletEventDetails;
435 WalletEventDetails {
436 price: d.get("price").and_then(|v| v.as_str()).map(|s| s.to_string()),
437 quantity: d.get("quantity").and_then(|v| v.as_str()).map(|s| s.to_string()),
438 quantity_filled: d.get("quantityFilled").and_then(|v| v.as_str()).map(|s| s.to_string()),
439 outcome: d.get("outcome").and_then(|v| v.as_str()).map(|s| s.to_string()),
440 quote_type: d.get("quoteType").and_then(|v| v.as_str()).map(|s| s.to_string()),
441 }
442 }).unwrap_or_default();
443
444 match event_type.as_str() {
445 "orderAccepted" => Ok(PredictWalletEvent::OrderAccepted { order_hash, order_id }),
446 "orderNotAccepted" => Ok(PredictWalletEvent::OrderNotAccepted {
447 order_hash,
448 order_id,
449 reason,
450 }),
451 "orderExpired" => Ok(PredictWalletEvent::OrderExpired { order_hash, order_id }),
452 "orderCancelled" => Ok(PredictWalletEvent::OrderCancelled { order_hash, order_id }),
453 "orderTransactionSubmitted" => Ok(PredictWalletEvent::OrderTransactionSubmitted {
454 order_hash,
455 order_id,
456 tx_hash,
457 details,
458 }),
459 "orderTransactionSuccess" => Ok(PredictWalletEvent::OrderTransactionSuccess {
460 order_hash,
461 order_id,
462 tx_hash,
463 details,
464 }),
465 "orderTransactionFailed" => Ok(PredictWalletEvent::OrderTransactionFailed {
466 order_hash,
467 order_id,
468 tx_hash,
469 details,
470 }),
471 _ => Ok(PredictWalletEvent::Unknown {
472 event_type,
473 data: push.data.clone(),
474 }),
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = PredictWebSocket::new("wss://ws.predict.fun/ws".to_string());
        assert!(client.subscribed_markets().is_empty());
    }

    #[test]
    fn test_request_id_increment() {
        let client = PredictWebSocket::new("wss://ws.predict.fun/ws".to_string());
        assert_eq!(client.next_id(), 1);
        assert_eq!(client.next_id(), 2);
        assert_eq!(client.next_id(), 3);
    }

    /// Builds a push message on a wallet-events topic with the given payload.
    fn wallet_push(data: serde_json::Value) -> PushMessage {
        PushMessage {
            topic: "predictWalletEvents/jwt123".to_string(),
            data,
        }
    }

    #[test]
    fn test_parse_order_accepted() {
        let push = wallet_push(serde_json::json!({
            "type": "orderAccepted",
            "orderId": "4170746",
            "orderHash": "0xb5b5b676abcd"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderAccepted { order_hash, order_id } => {
                assert_eq!(order_hash, "0xb5b5b676abcd");
                assert_eq!(order_id, "4170746");
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_submitted() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionSubmitted",
            "orderId": 4170746,
            "orderHash": "0xb5b5b676abcd",
            "txHash": "0xdeadbeef"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderTransactionSubmitted { order_hash, order_id, tx_hash, .. } => {
                assert_eq!(order_hash, "0xb5b5b676abcd");
                assert_eq!(order_id, "4170746");
                assert_eq!(tx_hash, Some("0xdeadbeef".to_string()));
            }
            other => panic!("Expected OrderTransactionSubmitted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_success() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionSuccess",
            "orderId": "4170746",
            "txHash": "0xdeadbeef"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderTransactionSuccess { order_hash, order_id, tx_hash, .. } => {
                // A missing orderHash falls back to an empty string.
                assert_eq!(order_hash, "");
                assert_eq!(order_id, "4170746");
                assert_eq!(tx_hash, Some("0xdeadbeef".to_string()));
            }
            other => panic!("Expected OrderTransactionSuccess, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_not_accepted() {
        let push = wallet_push(serde_json::json!({
            "type": "orderNotAccepted",
            "orderId": "123",
            "orderHash": "0xabc",
            "reason": "insufficient balance"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderNotAccepted { order_hash, order_id, reason } => {
                assert_eq!(order_hash, "0xabc");
                assert_eq!(order_id, "123");
                assert_eq!(reason, Some("insufficient balance".to_string()));
            }
            other => panic!("Expected OrderNotAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_expired_and_cancelled() {
        let expired = wallet_push(serde_json::json!({
            "type": "orderExpired",
            "orderId": "1",
            "orderHash": "0x1"
        }));
        match parse_wallet_event(&expired).unwrap() {
            PredictWalletEvent::OrderExpired { order_hash, order_id } => {
                assert_eq!(order_hash, "0x1");
                assert_eq!(order_id, "1");
            }
            other => panic!("Expected OrderExpired, got {:?}", other),
        }

        let cancelled = wallet_push(serde_json::json!({
            "type": "orderCancelled",
            "orderId": "2n",
            "orderHash": "0x2"
        }));
        match parse_wallet_event(&cancelled).unwrap() {
            PredictWalletEvent::OrderCancelled { order_hash, order_id } => {
                assert_eq!(order_hash, "0x2");
                assert_eq!(order_id, "2");
            }
            other => panic!("Expected OrderCancelled, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_order_transaction_failed_without_details() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionFailed",
            "orderId": 42,
            "orderHash": "0xfeed"
        }));
        match parse_wallet_event(&push).unwrap() {
            PredictWalletEvent::OrderTransactionFailed { order_hash, order_id, tx_hash, details } => {
                assert_eq!(order_hash, "0xfeed");
                assert_eq!(order_id, "42");
                assert_eq!(tx_hash, None);
                // A missing `details` object falls back to the default (all None).
                assert!(details.price.is_none());
                assert!(details.quantity.is_none());
                assert!(details.quantity_filled.is_none());
                assert!(details.outcome.is_none());
                assert!(details.quote_type.is_none());
            }
            other => panic!("Expected OrderTransactionFailed, got {:?}", other),
        }
    }

    #[test]
    fn test_non_scalar_order_id_is_empty() {
        let push = wallet_push(serde_json::json!({
            "type": "orderAccepted",
            "orderId": {"value": 1},
            "orderHash": "0xabc"
        }));
        match parse_wallet_event(&push).unwrap() {
            PredictWalletEvent::OrderAccepted { order_id, .. } => {
                assert_eq!(order_id, "");
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_unknown_event_type() {
        let push = wallet_push(serde_json::json!({
            "type": "newEventType",
            "foo": "bar"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::Unknown { event_type, .. } => {
                assert_eq!(event_type, "newEventType");
            }
            other => panic!("Expected Unknown, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_missing_type_field() {
        let push = wallet_push(serde_json::json!({
            "orderId": "123"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::Unknown { event_type, .. } => {
                assert_eq!(event_type, "");
            }
            other => panic!("Expected Unknown, got {:?}", other),
        }
    }

    #[test]
    fn test_bigint_order_id_suffix_stripped() {
        let push = wallet_push(serde_json::json!({
            "type": "orderAccepted",
            "orderId": "4175379n"
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderAccepted { order_id, order_hash } => {
                assert_eq!(order_id, "4175379");
                assert_eq!(order_hash, "");
            }
            other => panic!("Expected OrderAccepted, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_details_from_production_payload() {
        let push = wallet_push(serde_json::json!({
            "type": "orderTransactionSuccess",
            "orderId": "4170746n",
            "timestamp": 1769952855099u64,
            "details": {
                "categorySlug": "btc-usd-up-down-2026-02-01-08-30-15-minutes",
                "marketQuestion": "BTC/USD Up or Down - February 1, 8:30-8:45AM ET",
                "outcome": "YES",
                "price": "0.290",
                "quantity": "5.000",
                "quantityFilled": "5.000",
                "quoteType": "ASK",
                "strategyType": "LIMIT",
                "value": "1.45",
                "valueFilled": "1.45"
            }
        }));
        let event = parse_wallet_event(&push).unwrap();
        match event {
            PredictWalletEvent::OrderTransactionSuccess { order_id, details, .. } => {
                assert_eq!(order_id, "4170746");
                assert_eq!(details.price.as_deref(), Some("0.290"));
                assert_eq!(details.quantity.as_deref(), Some("5.000"));
                assert_eq!(details.quantity_filled.as_deref(), Some("5.000"));
                assert_eq!(details.outcome.as_deref(), Some("YES"));
                assert_eq!(details.quote_type.as_deref(), Some("ASK"));
            }
            other => panic!("Expected OrderTransactionSuccess, got {:?}", other),
        }
    }

    #[test]
    fn test_non_wallet_event_rejected() {
        let push = PushMessage {
            topic: "predictOrderbook/123".to_string(),
            data: serde_json::json!({"type": "orderAccepted"}),
        };
        assert!(parse_wallet_event(&push).is_err());
    }

    #[test]
    fn test_non_orderbook_push_rejected_by_orderbook_parser() {
        // NOTE(review): assumes `is_orderbook` keys off the topic, mirroring
        // how `is_wallet_event` rejects an orderbook topic above.
        let push = wallet_push(serde_json::json!({"bids": [], "asks": []}));
        assert!(parse_orderbook_update(&push).is_err());
    }
}