Skip to main content

px_crypto/
lib.rs

1use std::sync::Arc;
2
3use futures::StreamExt;
4use tokio::sync::{broadcast, Mutex};
5use tokio_tungstenite::{connect_async, tungstenite::Message};
6use tracing::{debug, warn};
7
8use px_core::error::WebSocketError;
9use px_core::models::{CryptoPrice, CryptoPriceSource};
10use px_core::websocket::{
11    AtomicWebSocketState, CryptoPriceStream, WebSocketState, WS_CRYPTO_PING_INTERVAL,
12    WS_MAX_RECONNECT_ATTEMPTS, WS_RECONNECT_BASE_DELAY, WS_RECONNECT_MAX_DELAY,
13};
14
/// Endpoint of the live-data WebSocket that publishes crypto prices.
const CRYPTO_WS_URL: &str = "wss://ws-live-data.polymarket.com";
/// Capacity of the price broadcast channel; a receiver that falls further behind than this
/// loses the overflowed updates.
const BROADCAST_CAPACITY: usize = 1 << 14;
17
18/// A stored subscription for replay on reconnect.
19#[derive(Debug, Clone)]
20struct Subscription {
21    source: CryptoPriceSource,
22    symbols: Vec<String>,
23}
24
25/// Outer envelope from the WebSocket.
26#[derive(serde::Deserialize)]
27struct Envelope {
28    topic: String,
29    #[serde(default)]
30    #[allow(dead_code)]
31    r#type: String,
32    #[allow(dead_code)]
33    #[serde(default)]
34    timestamp: Option<u64>,
35    payload: serde_json::Value,
36}
37
38/// Payload shape: `{ symbol, timestamp, value }`.
39#[derive(serde::Deserialize)]
40struct PricePayload {
41    symbol: String,
42    timestamp: u64,
43    value: f64,
44}
45
46fn topic_for_source(source: CryptoPriceSource) -> &'static str {
47    match source {
48        CryptoPriceSource::Binance => "crypto_prices",
49        CryptoPriceSource::Chainlink => "crypto_prices_chainlink",
50    }
51}
52
53fn source_from_topic(topic: &str) -> Option<CryptoPriceSource> {
54    match topic {
55        "crypto_prices" => Some(CryptoPriceSource::Binance),
56        "crypto_prices_chainlink" => Some(CryptoPriceSource::Chainlink),
57        _ => None,
58    }
59}
60
61fn build_subscribe_msg(source: CryptoPriceSource, symbols: &[String]) -> String {
62    let topic = topic_for_source(source);
63    if symbols.is_empty() {
64        return serde_json::json!({
65            "action": "subscribe",
66            "subscriptions": [{
67                "topic": topic,
68                "type": "*",
69                "filters": "",
70            }]
71        })
72        .to_string();
73    }
74    let subs: Vec<serde_json::Value> = symbols
75        .iter()
76        .map(|sym| {
77            let filter = serde_json::json!({ "symbol": sym }).to_string();
78            serde_json::json!({
79                "topic": topic,
80                "type": "*",
81                "filters": filter,
82            })
83        })
84        .collect();
85    serde_json::json!({
86        "action": "subscribe",
87        "subscriptions": subs,
88    })
89    .to_string()
90}
91
92fn build_unsubscribe_msg(source: CryptoPriceSource, symbols: &[String]) -> String {
93    let topic = topic_for_source(source);
94    if symbols.is_empty() {
95        return serde_json::json!({
96            "action": "unsubscribe",
97            "subscriptions": [{
98                "topic": topic,
99                "type": "*",
100                "filters": "",
101            }]
102        })
103        .to_string();
104    }
105    let subs: Vec<serde_json::Value> = symbols
106        .iter()
107        .map(|sym| {
108            let filter = serde_json::json!({ "symbol": sym }).to_string();
109            serde_json::json!({
110                "topic": topic,
111                "type": "*",
112                "filters": filter,
113            })
114        })
115        .collect();
116    serde_json::json!({
117        "action": "unsubscribe",
118        "subscriptions": subs,
119    })
120    .to_string()
121}
122
123/// Streams real-time crypto prices from a WebSocket feed.
124///
125/// Supports Binance and Chainlink price sources. Requires explicit subscribe/unsubscribe
126/// messages and client-initiated PING every 5 seconds.
127pub struct CryptoPriceWebSocket {
128    state: Arc<AtomicWebSocketState>,
129    sender: broadcast::Sender<Result<CryptoPrice, WebSocketError>>,
130    write_tx: Arc<Mutex<Option<futures::channel::mpsc::UnboundedSender<Message>>>>,
131    shutdown_tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
132    subscriptions: Arc<Mutex<Vec<Subscription>>>,
133}
134
135impl CryptoPriceWebSocket {
136    pub fn new() -> Self {
137        let (sender, _) = broadcast::channel(BROADCAST_CAPACITY);
138        Self {
139            state: Arc::new(AtomicWebSocketState::new(WebSocketState::Disconnected)),
140            sender,
141            write_tx: Arc::new(Mutex::new(None)),
142            shutdown_tx: Arc::new(Mutex::new(None)),
143            subscriptions: Arc::new(Mutex::new(Vec::new())),
144        }
145    }
146
147    pub fn state(&self) -> WebSocketState {
148        self.state.load()
149    }
150
151    pub fn stream(&self) -> CryptoPriceStream {
152        let rx = self.sender.subscribe();
153        Box::pin(
154            tokio_stream::wrappers::BroadcastStream::new(rx)
155                .filter_map(|result| async move { result.ok() }),
156        )
157    }
158
159    /// Subscribe to crypto prices for the given source and symbols.
160    /// Empty symbols subscribes to all available symbols.
161    pub async fn subscribe(
162        &self,
163        source: CryptoPriceSource,
164        symbols: &[String],
165    ) -> Result<(), WebSocketError> {
166        let msg = build_subscribe_msg(source, symbols);
167        let write_tx = self.write_tx.lock().await;
168        if let Some(ref tx) = *write_tx {
169            tx.unbounded_send(Message::Text(msg))
170                .map_err(|e| WebSocketError::Connection(e.to_string()))?;
171        } else {
172            return Err(WebSocketError::Connection("not connected".to_string()));
173        }
174
175        let mut subs = self.subscriptions.lock().await;
176        subs.push(Subscription {
177            source,
178            symbols: symbols.to_vec(),
179        });
180
181        Ok(())
182    }
183
184    /// Unsubscribe from crypto prices for the given source and symbols.
185    pub async fn unsubscribe(
186        &self,
187        source: CryptoPriceSource,
188        symbols: &[String],
189    ) -> Result<(), WebSocketError> {
190        let msg = build_unsubscribe_msg(source, symbols);
191        let write_tx = self.write_tx.lock().await;
192        if let Some(ref tx) = *write_tx {
193            tx.unbounded_send(Message::Text(msg))
194                .map_err(|e| WebSocketError::Connection(e.to_string()))?;
195        }
196
197        let mut subs = self.subscriptions.lock().await;
198        subs.retain(|s| !(s.source == source && s.symbols == symbols));
199
200        Ok(())
201    }
202
203    pub async fn connect(&mut self) -> Result<(), WebSocketError> {
204        self.state.store(WebSocketState::Connecting);
205
206        let (ws_stream, _) = connect_async(CRYPTO_WS_URL)
207            .await
208            .map_err(|e| WebSocketError::Connection(e.to_string()))?;
209
210        let (write, read) = ws_stream.split();
211        let (tx, rx) = futures::channel::mpsc::unbounded::<Message>();
212
213        {
214            let mut write_tx = self.write_tx.lock().await;
215            *write_tx = Some(tx);
216        }
217
218        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
219        {
220            let mut stx = self.shutdown_tx.lock().await;
221            *stx = Some(shutdown_tx);
222        }
223
224        let state = self.state.clone();
225        let sender = self.sender.clone();
226        let write_tx_clone = self.write_tx.clone();
227        let subscriptions = self.subscriptions.clone();
228
229        tokio::spawn(async move {
230            let write_future = rx.map(Ok).forward(write);
231
232            let read_future = {
233                let sender = sender.clone();
234                let write_tx_clone = write_tx_clone.clone();
235                async move {
236                    let mut read = read;
237                    while let Some(msg) = read.next().await {
238                        handle_message(msg, &sender, &write_tx_clone).await;
239                    }
240                }
241            };
242
243            let ping_future = {
244                let write_tx_clone = write_tx_clone.clone();
245                async move {
246                    let mut interval = tokio::time::interval(WS_CRYPTO_PING_INTERVAL);
247                    loop {
248                        interval.tick().await;
249                        let tx = write_tx_clone.lock().await;
250                        if let Some(ref tx) = *tx {
251                            if tx.unbounded_send(Message::Text("PING".into())).is_err() {
252                                break;
253                            }
254                        } else {
255                            break;
256                        }
257                    }
258                }
259            };
260
261            tokio::select! {
262                _ = write_future => {},
263                _ = read_future => {},
264                _ = ping_future => {},
265                _ = shutdown_rx => {},
266            }
267
268            if state.load() == WebSocketState::Closed {
269                return;
270            }
271            state.store(WebSocketState::Disconnected);
272
273            // Auto-reconnect with exponential backoff
274            let mut attempt = 1u32;
275            while attempt <= WS_MAX_RECONNECT_ATTEMPTS {
276                state.store(WebSocketState::Reconnecting);
277
278                let delay = calculate_reconnect_delay(attempt);
279                warn!(
280                    attempt,
281                    delay_ms = delay.as_millis() as u64,
282                    "reconnecting crypto websocket"
283                );
284                tokio::time::sleep(delay).await;
285
286                match connect_async(CRYPTO_WS_URL).await {
287                    Ok((new_ws, _)) => {
288                        let (new_write, new_read) = new_ws.split();
289                        let (new_tx, new_rx) = futures::channel::mpsc::unbounded::<Message>();
290
291                        {
292                            let mut wtx = write_tx_clone.lock().await;
293                            *wtx = Some(new_tx);
294                        }
295
296                        state.store(WebSocketState::Connected);
297                        attempt = 0;
298
299                        // Replay stored subscriptions
300                        {
301                            let subs = subscriptions.lock().await;
302                            let wtx = write_tx_clone.lock().await;
303                            if let Some(ref tx) = *wtx {
304                                for sub in subs.iter() {
305                                    let msg = build_subscribe_msg(sub.source, &sub.symbols);
306                                    let _ = tx.unbounded_send(Message::Text(msg));
307                                }
308                            }
309                        }
310
311                        let sender_clone = sender.clone();
312                        let wtx_clone = write_tx_clone.clone();
313
314                        let write_future = new_rx.map(Ok).forward(new_write);
315
316                        let read_future = {
317                            let sender = sender_clone;
318                            let write_tx = wtx_clone.clone();
319                            async move {
320                                let mut read = new_read;
321                                while let Some(msg) = read.next().await {
322                                    handle_message(msg, &sender, &write_tx).await;
323                                }
324                            }
325                        };
326
327                        let ping_future = {
328                            let write_tx = wtx_clone;
329                            async move {
330                                let mut interval = tokio::time::interval(WS_CRYPTO_PING_INTERVAL);
331                                loop {
332                                    interval.tick().await;
333                                    let tx = write_tx.lock().await;
334                                    if let Some(ref tx) = *tx {
335                                        if tx.unbounded_send(Message::Text("PING".into())).is_err()
336                                        {
337                                            break;
338                                        }
339                                    } else {
340                                        break;
341                                    }
342                                }
343                            }
344                        };
345
346                        tokio::select! {
347                            _ = write_future => {},
348                            _ = read_future => {},
349                            _ = ping_future => {},
350                        }
351
352                        if state.load() == WebSocketState::Closed {
353                            return;
354                        }
355
356                        attempt += 1;
357                    }
358                    Err(_) => {
359                        attempt += 1;
360                    }
361                }
362            }
363
364            state.store(WebSocketState::Disconnected);
365        });
366
367        self.state.store(WebSocketState::Connected);
368        Ok(())
369    }
370
371    pub async fn disconnect(&mut self) -> Result<(), WebSocketError> {
372        self.state.store(WebSocketState::Closed);
373        if let Some(tx) = self.shutdown_tx.lock().await.take() {
374            let _ = tx.send(());
375        }
376        Ok(())
377    }
378}
379
380impl Default for CryptoPriceWebSocket {
381    fn default() -> Self {
382        Self::new()
383    }
384}
385
386async fn handle_message(
387    msg: Result<Message, tokio_tungstenite::tungstenite::Error>,
388    sender: &broadcast::Sender<Result<CryptoPrice, WebSocketError>>,
389    write_tx: &Arc<Mutex<Option<futures::channel::mpsc::UnboundedSender<Message>>>>,
390) {
391    match msg {
392        Ok(Message::Text(text)) => {
393            // Ignore PONG responses
394            if text == "PONG" {
395                return;
396            }
397
398            let envelope: Envelope = match serde_json::from_str(&text) {
399                Ok(e) => e,
400                Err(e) => {
401                    debug!(raw = %text, error = %e, "skipping non-envelope message");
402                    return;
403                }
404            };
405
406            let source = match source_from_topic(&envelope.topic) {
407                Some(s) => s,
408                None => {
409                    debug!(topic = %envelope.topic, "skipping unknown topic");
410                    return;
411                }
412            };
413
414            let payload: PricePayload = match serde_json::from_value(envelope.payload) {
415                Ok(p) => p,
416                Err(e) => {
417                    debug!(error = %e, "skipping malformed price payload");
418                    return;
419                }
420            };
421
422            let price = CryptoPrice {
423                symbol: payload.symbol,
424                timestamp: payload.timestamp,
425                value: payload.value,
426                source,
427            };
428
429            let _ = sender.send(Ok(price));
430        }
431        Ok(Message::Ping(data)) => {
432            if let Some(ref tx) = *write_tx.lock().await {
433                let _ = tx.unbounded_send(Message::Pong(data));
434            }
435        }
436        Ok(Message::Close(_)) | Err(_) => {}
437        _ => {}
438    }
439}
440
441fn calculate_reconnect_delay(attempt: u32) -> std::time::Duration {
442    let delay = WS_RECONNECT_BASE_DELAY.as_millis() as f64 * 1.5_f64.powi(attempt as i32);
443    let delay = delay.min(WS_RECONNECT_MAX_DELAY.as_millis() as f64) as u64;
444    std::time::Duration::from_millis(delay)
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parses a message produced by `build_subscribe_msg` / `build_unsubscribe_msg`.
    fn parse(msg: &str) -> serde_json::Value {
        serde_json::from_str(msg).expect("valid JSON")
    }

    /// Decodes the JSON object that subscription entry `index` carries, as a string, in its
    /// `filters` field.
    fn filters_at(parsed: &serde_json::Value, index: usize) -> serde_json::Value {
        let raw = parsed["subscriptions"][index]["filters"]
            .as_str()
            .expect("filters should be a string");
        serde_json::from_str(raw).expect("filters should be valid JSON")
    }

    #[test]
    fn deserialize_binance_envelope() {
        let data = json!({
            "topic": "crypto_prices",
            "type": "update",
            "timestamp": 1700000000,
            "payload": {
                "symbol": "btcusdt",
                "timestamp": 1700000000u64,
                "value": 43250.5
            }
        });

        let envelope: Envelope = serde_json::from_value(data).expect("should deserialize");
        assert_eq!(envelope.topic, "crypto_prices");
        assert_eq!(
            source_from_topic(&envelope.topic),
            Some(CryptoPriceSource::Binance)
        );

        let payload: PricePayload =
            serde_json::from_value(envelope.payload).expect("should deserialize payload");
        assert_eq!(payload.symbol, "btcusdt");
        assert_eq!(payload.timestamp, 1700000000);
        assert!((payload.value - 43250.5).abs() < f64::EPSILON);
    }

    #[test]
    fn deserialize_chainlink_envelope() {
        let data = json!({
            "topic": "crypto_prices_chainlink",
            "type": "update",
            "timestamp": 1700000001,
            "payload": {
                "symbol": "eth/usd",
                "timestamp": 1700000001u64,
                "value": 2250.75
            }
        });

        let envelope: Envelope = serde_json::from_value(data).expect("should deserialize");
        assert_eq!(envelope.topic, "crypto_prices_chainlink");
        assert_eq!(
            source_from_topic(&envelope.topic),
            Some(CryptoPriceSource::Chainlink)
        );

        let payload: PricePayload =
            serde_json::from_value(envelope.payload).expect("should deserialize payload");
        assert_eq!(payload.symbol, "eth/usd");
        assert!((payload.value - 2250.75).abs() < f64::EPSILON);
    }

    #[test]
    fn envelope_tolerates_missing_optional_fields() {
        let envelope: Envelope = serde_json::from_value(json!({
            "topic": "crypto_prices",
            "payload": { "symbol": "btcusdt", "timestamp": 1u64, "value": 1.0 }
        }))
        .expect("type and timestamp are optional");
        assert_eq!(envelope.topic, "crypto_prices");
        assert!(envelope.r#type.is_empty());
        assert!(envelope.timestamp.is_none());
    }

    #[test]
    fn serialize_binance_subscribe() {
        let parsed = parse(&build_subscribe_msg(
            CryptoPriceSource::Binance,
            &["btcusdt".into(), "ethusdt".into()],
        ));
        assert_eq!(parsed["action"], "subscribe");
        // One subscription entry per symbol
        assert_eq!(parsed["subscriptions"].as_array().unwrap().len(), 2);
        assert_eq!(parsed["subscriptions"][0]["topic"], "crypto_prices");
        assert_eq!(parsed["subscriptions"][0]["type"], "*");
        assert_eq!(filters_at(&parsed, 0)["symbol"], "btcusdt");
        assert_eq!(filters_at(&parsed, 1)["symbol"], "ethusdt");
    }

    #[test]
    fn serialize_chainlink_subscribe() {
        let parsed = parse(&build_subscribe_msg(
            CryptoPriceSource::Chainlink,
            &["eth/usd".into()],
        ));
        assert_eq!(parsed["action"], "subscribe");
        assert_eq!(
            parsed["subscriptions"][0]["topic"],
            "crypto_prices_chainlink"
        );
        assert_eq!(parsed["subscriptions"][0]["type"], "*");
        assert_eq!(filters_at(&parsed, 0)["symbol"], "eth/usd");
    }

    #[test]
    fn serialize_binance_subscribe_all() {
        let parsed = parse(&build_subscribe_msg(CryptoPriceSource::Binance, &[]));
        assert_eq!(parsed["subscriptions"][0]["type"], "*");
        assert_eq!(parsed["subscriptions"][0]["filters"], "");
    }

    #[test]
    fn serialize_unsubscribe() {
        let parsed = parse(&build_unsubscribe_msg(
            CryptoPriceSource::Binance,
            &["btcusdt".into()],
        ));
        assert_eq!(parsed["action"], "unsubscribe");
        assert_eq!(parsed["subscriptions"][0]["topic"], "crypto_prices");
        assert_eq!(filters_at(&parsed, 0)["symbol"], "btcusdt");
    }

    #[test]
    fn serialize_unsubscribe_all() {
        let parsed = parse(&build_unsubscribe_msg(CryptoPriceSource::Chainlink, &[]));
        assert_eq!(parsed["action"], "unsubscribe");
        assert_eq!(parsed["subscriptions"].as_array().unwrap().len(), 1);
        assert_eq!(
            parsed["subscriptions"][0]["topic"],
            "crypto_prices_chainlink"
        );
        assert_eq!(parsed["subscriptions"][0]["filters"], "");
    }

    #[test]
    fn subscribe_and_unsubscribe_differ_only_in_action() {
        let symbols = ["btcusdt".to_string(), "ethusdt".to_string()];
        let sub = parse(&build_subscribe_msg(CryptoPriceSource::Binance, &symbols));
        let unsub = parse(&build_unsubscribe_msg(CryptoPriceSource::Binance, &symbols));
        assert_eq!(sub["subscriptions"], unsub["subscriptions"]);
        assert_ne!(sub["action"], unsub["action"]);
    }

    #[test]
    fn ping_is_not_valid_price() {
        let result = serde_json::from_str::<Envelope>("PING");
        assert!(result.is_err());
    }

    #[test]
    fn unknown_topic_returns_none() {
        assert!(source_from_topic("unknown_topic").is_none());
    }

    #[test]
    fn topic_round_trip() {
        assert_eq!(
            source_from_topic(topic_for_source(CryptoPriceSource::Binance)),
            Some(CryptoPriceSource::Binance)
        );
        assert_eq!(
            source_from_topic(topic_for_source(CryptoPriceSource::Chainlink)),
            Some(CryptoPriceSource::Chainlink)
        );
    }

    #[test]
    fn reconnect_delay_is_monotonic_and_capped() {
        let mut previous = std::time::Duration::from_millis(0);
        for attempt in 1..=WS_MAX_RECONNECT_ATTEMPTS {
            let delay = calculate_reconnect_delay(attempt);
            assert!(delay >= previous, "delay shrank at attempt {attempt}");
            assert!(delay <= WS_RECONNECT_MAX_DELAY);
            previous = delay;
        }
        assert_eq!(
            calculate_reconnect_delay(1_000).as_millis(),
            WS_RECONNECT_MAX_DELAY.as_millis()
        );
    }
}