1use std::sync::Arc;
2
3use futures::StreamExt;
4use tokio::sync::{broadcast, Mutex};
5use tokio_tungstenite::{connect_async, tungstenite::Message};
6use tracing::{debug, warn};
7
8use px_core::error::WebSocketError;
9use px_core::models::{CryptoPrice, CryptoPriceSource};
10use px_core::websocket::{
11 AtomicWebSocketState, CryptoPriceStream, WebSocketState, WS_CRYPTO_PING_INTERVAL,
12 WS_MAX_RECONNECT_ATTEMPTS, WS_RECONNECT_BASE_DELAY, WS_RECONNECT_MAX_DELAY,
13};
14
/// Polymarket live-data WebSocket endpoint that serves the crypto price topics.
const CRYPTO_WS_URL: &str = "wss://ws-live-data.polymarket.com";
/// Capacity of the price broadcast channel (2^14 messages).
const BROADCAST_CAPACITY: usize = 1 << 14;
17
18#[derive(Debug, Clone)]
20struct Subscription {
21 source: CryptoPriceSource,
22 symbols: Vec<String>,
23}
24
25#[derive(serde::Deserialize)]
27struct Envelope {
28 topic: String,
29 #[serde(default)]
30 #[allow(dead_code)]
31 r#type: String,
32 #[allow(dead_code)]
33 #[serde(default)]
34 timestamp: Option<u64>,
35 payload: serde_json::Value,
36}
37
38#[derive(serde::Deserialize)]
40struct PricePayload {
41 symbol: String,
42 timestamp: u64,
43 value: f64,
44}
45
46fn topic_for_source(source: CryptoPriceSource) -> &'static str {
47 match source {
48 CryptoPriceSource::Binance => "crypto_prices",
49 CryptoPriceSource::Chainlink => "crypto_prices_chainlink",
50 }
51}
52
53fn source_from_topic(topic: &str) -> Option<CryptoPriceSource> {
54 match topic {
55 "crypto_prices" => Some(CryptoPriceSource::Binance),
56 "crypto_prices_chainlink" => Some(CryptoPriceSource::Chainlink),
57 _ => None,
58 }
59}
60
61fn build_subscribe_msg(source: CryptoPriceSource, symbols: &[String]) -> String {
62 let topic = topic_for_source(source);
63 if symbols.is_empty() {
64 return serde_json::json!({
65 "action": "subscribe",
66 "subscriptions": [{
67 "topic": topic,
68 "type": "*",
69 "filters": "",
70 }]
71 })
72 .to_string();
73 }
74 let subs: Vec<serde_json::Value> = symbols
75 .iter()
76 .map(|sym| {
77 let filter = serde_json::json!({ "symbol": sym }).to_string();
78 serde_json::json!({
79 "topic": topic,
80 "type": "*",
81 "filters": filter,
82 })
83 })
84 .collect();
85 serde_json::json!({
86 "action": "subscribe",
87 "subscriptions": subs,
88 })
89 .to_string()
90}
91
92fn build_unsubscribe_msg(source: CryptoPriceSource, symbols: &[String]) -> String {
93 let topic = topic_for_source(source);
94 if symbols.is_empty() {
95 return serde_json::json!({
96 "action": "unsubscribe",
97 "subscriptions": [{
98 "topic": topic,
99 "type": "*",
100 "filters": "",
101 }]
102 })
103 .to_string();
104 }
105 let subs: Vec<serde_json::Value> = symbols
106 .iter()
107 .map(|sym| {
108 let filter = serde_json::json!({ "symbol": sym }).to_string();
109 serde_json::json!({
110 "topic": topic,
111 "type": "*",
112 "filters": filter,
113 })
114 })
115 .collect();
116 serde_json::json!({
117 "action": "unsubscribe",
118 "subscriptions": subs,
119 })
120 .to_string()
121}
122
123pub struct CryptoPriceWebSocket {
128 state: Arc<AtomicWebSocketState>,
129 sender: broadcast::Sender<Result<CryptoPrice, WebSocketError>>,
130 write_tx: Arc<Mutex<Option<futures::channel::mpsc::UnboundedSender<Message>>>>,
131 shutdown_tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
132 subscriptions: Arc<Mutex<Vec<Subscription>>>,
133}
134
135impl CryptoPriceWebSocket {
136 pub fn new() -> Self {
137 let (sender, _) = broadcast::channel(BROADCAST_CAPACITY);
138 Self {
139 state: Arc::new(AtomicWebSocketState::new(WebSocketState::Disconnected)),
140 sender,
141 write_tx: Arc::new(Mutex::new(None)),
142 shutdown_tx: Arc::new(Mutex::new(None)),
143 subscriptions: Arc::new(Mutex::new(Vec::new())),
144 }
145 }
146
147 pub fn state(&self) -> WebSocketState {
148 self.state.load()
149 }
150
151 pub fn stream(&self) -> CryptoPriceStream {
152 let rx = self.sender.subscribe();
153 Box::pin(
154 tokio_stream::wrappers::BroadcastStream::new(rx)
155 .filter_map(|result| async move { result.ok() }),
156 )
157 }
158
159 pub async fn subscribe(
162 &self,
163 source: CryptoPriceSource,
164 symbols: &[String],
165 ) -> Result<(), WebSocketError> {
166 let msg = build_subscribe_msg(source, symbols);
167 let write_tx = self.write_tx.lock().await;
168 if let Some(ref tx) = *write_tx {
169 tx.unbounded_send(Message::Text(msg))
170 .map_err(|e| WebSocketError::Connection(e.to_string()))?;
171 } else {
172 return Err(WebSocketError::Connection("not connected".to_string()));
173 }
174
175 let mut subs = self.subscriptions.lock().await;
176 subs.push(Subscription {
177 source,
178 symbols: symbols.to_vec(),
179 });
180
181 Ok(())
182 }
183
184 pub async fn unsubscribe(
186 &self,
187 source: CryptoPriceSource,
188 symbols: &[String],
189 ) -> Result<(), WebSocketError> {
190 let msg = build_unsubscribe_msg(source, symbols);
191 let write_tx = self.write_tx.lock().await;
192 if let Some(ref tx) = *write_tx {
193 tx.unbounded_send(Message::Text(msg))
194 .map_err(|e| WebSocketError::Connection(e.to_string()))?;
195 }
196
197 let mut subs = self.subscriptions.lock().await;
198 subs.retain(|s| !(s.source == source && s.symbols == symbols));
199
200 Ok(())
201 }
202
203 pub async fn connect(&mut self) -> Result<(), WebSocketError> {
204 self.state.store(WebSocketState::Connecting);
205
206 let (ws_stream, _) = connect_async(CRYPTO_WS_URL)
207 .await
208 .map_err(|e| WebSocketError::Connection(e.to_string()))?;
209
210 let (write, read) = ws_stream.split();
211 let (tx, rx) = futures::channel::mpsc::unbounded::<Message>();
212
213 {
214 let mut write_tx = self.write_tx.lock().await;
215 *write_tx = Some(tx);
216 }
217
218 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
219 {
220 let mut stx = self.shutdown_tx.lock().await;
221 *stx = Some(shutdown_tx);
222 }
223
224 let state = self.state.clone();
225 let sender = self.sender.clone();
226 let write_tx_clone = self.write_tx.clone();
227 let subscriptions = self.subscriptions.clone();
228
229 tokio::spawn(async move {
230 let write_future = rx.map(Ok).forward(write);
231
232 let read_future = {
233 let sender = sender.clone();
234 let write_tx_clone = write_tx_clone.clone();
235 async move {
236 let mut read = read;
237 while let Some(msg) = read.next().await {
238 handle_message(msg, &sender, &write_tx_clone).await;
239 }
240 }
241 };
242
243 let ping_future = {
244 let write_tx_clone = write_tx_clone.clone();
245 async move {
246 let mut interval = tokio::time::interval(WS_CRYPTO_PING_INTERVAL);
247 loop {
248 interval.tick().await;
249 let tx = write_tx_clone.lock().await;
250 if let Some(ref tx) = *tx {
251 if tx.unbounded_send(Message::Text("PING".into())).is_err() {
252 break;
253 }
254 } else {
255 break;
256 }
257 }
258 }
259 };
260
261 tokio::select! {
262 _ = write_future => {},
263 _ = read_future => {},
264 _ = ping_future => {},
265 _ = shutdown_rx => {},
266 }
267
268 if state.load() == WebSocketState::Closed {
269 return;
270 }
271 state.store(WebSocketState::Disconnected);
272
273 let mut attempt = 1u32;
275 while attempt <= WS_MAX_RECONNECT_ATTEMPTS {
276 state.store(WebSocketState::Reconnecting);
277
278 let delay = calculate_reconnect_delay(attempt);
279 warn!(
280 attempt,
281 delay_ms = delay.as_millis() as u64,
282 "reconnecting crypto websocket"
283 );
284 tokio::time::sleep(delay).await;
285
286 match connect_async(CRYPTO_WS_URL).await {
287 Ok((new_ws, _)) => {
288 let (new_write, new_read) = new_ws.split();
289 let (new_tx, new_rx) = futures::channel::mpsc::unbounded::<Message>();
290
291 {
292 let mut wtx = write_tx_clone.lock().await;
293 *wtx = Some(new_tx);
294 }
295
296 state.store(WebSocketState::Connected);
297 attempt = 0;
298
299 {
301 let subs = subscriptions.lock().await;
302 let wtx = write_tx_clone.lock().await;
303 if let Some(ref tx) = *wtx {
304 for sub in subs.iter() {
305 let msg = build_subscribe_msg(sub.source, &sub.symbols);
306 let _ = tx.unbounded_send(Message::Text(msg));
307 }
308 }
309 }
310
311 let sender_clone = sender.clone();
312 let wtx_clone = write_tx_clone.clone();
313
314 let write_future = new_rx.map(Ok).forward(new_write);
315
316 let read_future = {
317 let sender = sender_clone;
318 let write_tx = wtx_clone.clone();
319 async move {
320 let mut read = new_read;
321 while let Some(msg) = read.next().await {
322 handle_message(msg, &sender, &write_tx).await;
323 }
324 }
325 };
326
327 let ping_future = {
328 let write_tx = wtx_clone;
329 async move {
330 let mut interval = tokio::time::interval(WS_CRYPTO_PING_INTERVAL);
331 loop {
332 interval.tick().await;
333 let tx = write_tx.lock().await;
334 if let Some(ref tx) = *tx {
335 if tx.unbounded_send(Message::Text("PING".into())).is_err()
336 {
337 break;
338 }
339 } else {
340 break;
341 }
342 }
343 }
344 };
345
346 tokio::select! {
347 _ = write_future => {},
348 _ = read_future => {},
349 _ = ping_future => {},
350 }
351
352 if state.load() == WebSocketState::Closed {
353 return;
354 }
355
356 attempt += 1;
357 }
358 Err(_) => {
359 attempt += 1;
360 }
361 }
362 }
363
364 state.store(WebSocketState::Disconnected);
365 });
366
367 self.state.store(WebSocketState::Connected);
368 Ok(())
369 }
370
371 pub async fn disconnect(&mut self) -> Result<(), WebSocketError> {
372 self.state.store(WebSocketState::Closed);
373 if let Some(tx) = self.shutdown_tx.lock().await.take() {
374 let _ = tx.send(());
375 }
376 Ok(())
377 }
378}
379
380impl Default for CryptoPriceWebSocket {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
386async fn handle_message(
387 msg: Result<Message, tokio_tungstenite::tungstenite::Error>,
388 sender: &broadcast::Sender<Result<CryptoPrice, WebSocketError>>,
389 write_tx: &Arc<Mutex<Option<futures::channel::mpsc::UnboundedSender<Message>>>>,
390) {
391 match msg {
392 Ok(Message::Text(text)) => {
393 if text == "PONG" {
395 return;
396 }
397
398 let envelope: Envelope = match serde_json::from_str(&text) {
399 Ok(e) => e,
400 Err(e) => {
401 debug!(raw = %text, error = %e, "skipping non-envelope message");
402 return;
403 }
404 };
405
406 let source = match source_from_topic(&envelope.topic) {
407 Some(s) => s,
408 None => {
409 debug!(topic = %envelope.topic, "skipping unknown topic");
410 return;
411 }
412 };
413
414 let payload: PricePayload = match serde_json::from_value(envelope.payload) {
415 Ok(p) => p,
416 Err(e) => {
417 debug!(error = %e, "skipping malformed price payload");
418 return;
419 }
420 };
421
422 let price = CryptoPrice {
423 symbol: payload.symbol,
424 timestamp: payload.timestamp,
425 value: payload.value,
426 source,
427 };
428
429 let _ = sender.send(Ok(price));
430 }
431 Ok(Message::Ping(data)) => {
432 if let Some(ref tx) = *write_tx.lock().await {
433 let _ = tx.unbounded_send(Message::Pong(data));
434 }
435 }
436 Ok(Message::Close(_)) | Err(_) => {}
437 _ => {}
438 }
439}
440
441fn calculate_reconnect_delay(attempt: u32) -> std::time::Duration {
442 let delay = WS_RECONNECT_BASE_DELAY.as_millis() as f64 * 1.5_f64.powi(attempt as i32);
443 let delay = delay.min(WS_RECONNECT_MAX_DELAY.as_millis() as f64) as u64;
444 std::time::Duration::from_millis(delay)
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parses a request string produced by the message builders.
    fn parse(msg: &str) -> serde_json::Value {
        serde_json::from_str(msg).expect("valid JSON")
    }

    /// Decodes the JSON document stored as a string in `subscriptions[idx].filters`.
    fn filters_at(request: &serde_json::Value, idx: usize) -> serde_json::Value {
        let raw = request["subscriptions"][idx]["filters"]
            .as_str()
            .expect("filters should be a string");
        serde_json::from_str(raw).expect("filters should be valid JSON")
    }

    #[test]
    fn deserialize_binance_envelope() {
        let data = json!({
            "topic": "crypto_prices",
            "type": "update",
            "timestamp": 1700000000,
            "payload": {
                "symbol": "btcusdt",
                "timestamp": 1700000000u64,
                "value": 43250.5
            }
        });

        let envelope: Envelope = serde_json::from_value(data).expect("should deserialize");
        assert_eq!(envelope.topic, "crypto_prices");

        let source = source_from_topic(&envelope.topic).unwrap();
        assert_eq!(source, CryptoPriceSource::Binance);

        let payload: PricePayload =
            serde_json::from_value(envelope.payload).expect("should deserialize payload");
        assert_eq!(payload.symbol, "btcusdt");
        assert_eq!(payload.timestamp, 1700000000);
        assert!((payload.value - 43250.5).abs() < f64::EPSILON);
    }

    #[test]
    fn deserialize_chainlink_envelope() {
        let data = json!({
            "topic": "crypto_prices_chainlink",
            "type": "update",
            "timestamp": 1700000001,
            "payload": {
                "symbol": "eth/usd",
                "timestamp": 1700000001u64,
                "value": 2250.75
            }
        });

        let envelope: Envelope = serde_json::from_value(data).expect("should deserialize");
        assert_eq!(envelope.topic, "crypto_prices_chainlink");

        let source = source_from_topic(&envelope.topic).unwrap();
        assert_eq!(source, CryptoPriceSource::Chainlink);

        let payload: PricePayload =
            serde_json::from_value(envelope.payload).expect("should deserialize payload");
        assert_eq!(payload.symbol, "eth/usd");
        assert!((payload.value - 2250.75).abs() < f64::EPSILON);
    }

    #[test]
    fn envelope_optional_fields_default() {
        let envelope: Envelope = serde_json::from_value(json!({
            "topic": "crypto_prices",
            "payload": {}
        }))
        .expect("type and timestamp are optional");
        assert!(envelope.r#type.is_empty());
        assert!(envelope.timestamp.is_none());
    }

    #[test]
    fn payload_without_value_is_rejected() {
        let result = serde_json::from_value::<PricePayload>(json!({
            "symbol": "btcusdt",
            "timestamp": 1700000000u64
        }));
        assert!(result.is_err());
    }

    #[test]
    fn serialize_binance_subscribe() {
        let parsed = parse(&build_subscribe_msg(
            CryptoPriceSource::Binance,
            &["btcusdt".into(), "ethusdt".into()],
        ));
        assert_eq!(parsed["action"], "subscribe");
        assert_eq!(parsed["subscriptions"].as_array().unwrap().len(), 2);
        assert_eq!(parsed["subscriptions"][0]["topic"], "crypto_prices");
        assert_eq!(parsed["subscriptions"][0]["type"], "*");
        assert_eq!(filters_at(&parsed, 0)["symbol"], "btcusdt");
        assert_eq!(filters_at(&parsed, 1)["symbol"], "ethusdt");
    }

    #[test]
    fn serialize_chainlink_subscribe() {
        let parsed = parse(&build_subscribe_msg(
            CryptoPriceSource::Chainlink,
            &["eth/usd".into()],
        ));
        assert_eq!(parsed["action"], "subscribe");
        assert_eq!(
            parsed["subscriptions"][0]["topic"],
            "crypto_prices_chainlink"
        );
        assert_eq!(parsed["subscriptions"][0]["type"], "*");
        assert_eq!(filters_at(&parsed, 0)["symbol"], "eth/usd");
    }

    #[test]
    fn serialize_binance_subscribe_all() {
        let parsed = parse(&build_subscribe_msg(CryptoPriceSource::Binance, &[]));
        assert_eq!(parsed["subscriptions"][0]["type"], "*");
        assert_eq!(parsed["subscriptions"][0]["filters"], "");
    }

    #[test]
    fn serialize_unsubscribe() {
        let parsed = parse(&build_unsubscribe_msg(
            CryptoPriceSource::Binance,
            &["btcusdt".into()],
        ));
        assert_eq!(parsed["action"], "unsubscribe");
        assert_eq!(parsed["subscriptions"][0]["topic"], "crypto_prices");
        assert_eq!(filters_at(&parsed, 0)["symbol"], "btcusdt");
    }

    #[test]
    fn serialize_chainlink_unsubscribe_all() {
        let parsed = parse(&build_unsubscribe_msg(CryptoPriceSource::Chainlink, &[]));
        assert_eq!(parsed["action"], "unsubscribe");
        let subs = parsed["subscriptions"].as_array().unwrap();
        assert_eq!(subs.len(), 1);
        assert_eq!(subs[0]["topic"], "crypto_prices_chainlink");
        assert_eq!(subs[0]["type"], "*");
        assert_eq!(subs[0]["filters"], "");
    }

    #[test]
    fn ping_is_not_valid_price() {
        let result = serde_json::from_str::<Envelope>("PING");
        assert!(result.is_err());
    }

    #[test]
    fn unknown_topic_returns_none() {
        assert!(source_from_topic("unknown_topic").is_none());
    }

    #[test]
    fn topic_round_trip() {
        assert_eq!(
            source_from_topic(topic_for_source(CryptoPriceSource::Binance)),
            Some(CryptoPriceSource::Binance)
        );
        assert_eq!(
            source_from_topic(topic_for_source(CryptoPriceSource::Chainlink)),
            Some(CryptoPriceSource::Chainlink)
        );
    }

    #[test]
    fn reconnect_delay_is_monotonic_and_capped() {
        let delays: Vec<_> = (1..=20).map(calculate_reconnect_delay).collect();
        assert!(delays.windows(2).all(|pair| pair[0] <= pair[1]));

        let far = calculate_reconnect_delay(1_000);
        assert!(far <= WS_RECONNECT_MAX_DELAY);
        assert!(delays[0] <= far);
    }
}
583}