Skip to main content

sandbox_quant/binance/
ws.rs

1use anyhow::Result;
2use futures_util::{SinkExt, StreamExt};
3use std::time::Duration;
4use tokio::sync::{mpsc, watch};
5use tokio_tungstenite::tungstenite;
6use tungstenite::error::{Error as WsError, ProtocolError, UrlError};
7use tungstenite::protocol::frame::coding::CloseCode;
8
9use super::types::BinanceTradeEvent;
10use crate::event::{AppEvent, LogDomain, LogLevel, LogRecord, WsConnectionStatus};
11use crate::model::tick::Tick;
12
/// Exponential backoff for reconnection.
///
/// Every call to `next_delay` hands out the current delay and then multiplies it by
/// `factor` for the following call, never exceeding `max`.
struct ExponentialBackoff {
    /// Delay that the next call to `next_delay` will return.
    current: Duration,
    /// Delay restored by `reset`.
    initial: Duration,
    /// Upper bound for every delay handed out.
    max: Duration,
    /// Multiplier applied to the delay after each call to `next_delay`.
    factor: f64,
}

impl ExponentialBackoff {
    /// Creates a backoff that starts at `initial`, grows by `factor`, and is capped at `max`.
    fn new(initial: Duration, max: Duration, factor: f64) -> Self {
        Self {
            current: initial,
            initial,
            max,
            factor,
        }
    }

    /// Returns the delay to wait now and advances the internal state for the next call.
    ///
    /// The returned delay never exceeds `max`, even when `initial` was configured above it.
    /// If scaling yields a value that cannot be represented as a `Duration` (negative,
    /// NaN, or too large), the next delay falls back to `max` instead of panicking.
    fn next_delay(&mut self) -> Duration {
        let delay = self.current.min(self.max);
        let scaled = delay.as_secs_f64() * self.factor;
        // `try_from_secs_f64` rejects negative, NaN and overflowing inputs, all of which
        // would make the infallible `from_secs_f64` panic.
        self.current =
            Duration::try_from_secs_f64(scaled).map_or(self.max, |next| next.min(self.max));
        delay
    }

    /// Restores the delay to its initial value.
    fn reset(&mut self) {
        self.current = self.initial;
    }
}
43
/// WebSocket client that streams trades for one instrument from Binance.
///
/// Holds the two endpoints a connection may target; which one is used is decided per
/// connection attempt from the instrument name.
#[derive(Clone, Debug)]
pub struct BinanceWsClient {
    /// Endpoint used when the instrument is not tagged as a futures instrument.
    spot_url: String,
    /// Endpoint used when the instrument carries the ` (FUT)` suffix.
    futures_url: String,
}
49
50impl BinanceWsClient {
51    /// Create a new WebSocket client.
52    ///
53    /// `ws_base_url` — e.g. `wss://stream.testnet.binance.vision/ws`
54    pub fn new(ws_base_url: &str, futures_ws_base_url: &str) -> Self {
55        Self {
56            spot_url: ws_base_url.to_string(),
57            futures_url: futures_ws_base_url.to_string(),
58        }
59    }
60
61    /// Connect and run the WebSocket loop with automatic reconnection.
62    /// Sends WsStatus events through `status_tx` and ticks through `tick_tx`.
63    pub async fn connect_and_run(
64        &self,
65        tick_tx: mpsc::Sender<Tick>,
66        status_tx: mpsc::Sender<AppEvent>,
67        mut symbol_rx: watch::Receiver<String>,
68        mut shutdown: watch::Receiver<bool>,
69    ) -> Result<()> {
70        let mut backoff =
71            ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60), 2.0);
72        let mut attempt: u32 = 0;
73
74        loop {
75            attempt += 1;
76            let instrument = symbol_rx.borrow().clone();
77            let (symbol, is_futures) = parse_instrument_symbol(&instrument);
78            let streams = vec![format!("{}@trade", symbol.to_lowercase())];
79            let ws_url = if is_futures {
80                &self.futures_url
81            } else {
82                &self.spot_url
83            };
84            match self
85                .connect_once(
86                    ws_url,
87                    &streams,
88                    &instrument,
89                    &tick_tx,
90                    &status_tx,
91                    &mut symbol_rx,
92                    &mut shutdown,
93                )
94                .await
95            {
96                Ok(()) => {
97                    // Clean shutdown requested
98                    let _ = status_tx
99                        .send(AppEvent::WsStatus(WsConnectionStatus::Disconnected))
100                        .await;
101                    break;
102                }
103                Err(e) => {
104                    let _ = status_tx
105                        .send(AppEvent::WsStatus(WsConnectionStatus::Disconnected))
106                        .await;
107                    tracing::warn!(attempt, error = %e, "WS connection attempt failed");
108                    let _ = status_tx
109                        .send(AppEvent::LogRecord(ws_log(
110                            LogLevel::Warn,
111                            "connect.fail",
112                            &instrument,
113                            format!("attempt={} error={}", attempt, e),
114                        )))
115                        .await;
116
117                    let delay = backoff.next_delay();
118                    let _ = status_tx
119                        .send(AppEvent::WsStatus(WsConnectionStatus::Reconnecting {
120                            attempt,
121                            delay_ms: delay.as_millis() as u64,
122                        }))
123                        .await;
124
125                    tokio::select! {
126                        _ = tokio::time::sleep(delay) => continue,
127                        _ = shutdown.changed() => {
128                            let _ = status_tx
129                                .send(AppEvent::LogRecord(ws_log(
130                                    LogLevel::Info,
131                                    "shutdown.during_reconnect",
132                                    &instrument,
133                                    "shutdown signal received during reconnect wait".to_string(),
134                                )))
135                                .await;
136                            break;
137                        }
138                    }
139                }
140            }
141        }
142        Ok(())
143    }
144
145    async fn connect_once(
146        &self,
147        ws_url: &str,
148        streams: &[String],
149        display_symbol: &str,
150        tick_tx: &mpsc::Sender<Tick>,
151        status_tx: &mpsc::Sender<AppEvent>,
152        symbol_rx: &mut watch::Receiver<String>,
153        shutdown: &mut watch::Receiver<bool>,
154    ) -> Result<()> {
155        let _ = status_tx
156            .send(AppEvent::LogRecord(ws_log(
157                LogLevel::Info,
158                "connect.start",
159                display_symbol,
160                format!("url={}", ws_url),
161            )))
162            .await;
163
164        let (ws_stream, resp) = tokio_tungstenite::connect_async(ws_url)
165            .await
166            .map_err(|e| {
167                let detail = format_ws_error(&e);
168                let _ = status_tx.try_send(AppEvent::LogRecord(ws_log(
169                    LogLevel::Warn,
170                    "connect.detail",
171                    display_symbol,
172                    detail.clone(),
173                )));
174                anyhow::anyhow!("WebSocket connect failed: {}", detail)
175            })?;
176
177        tracing::debug!(status = %resp.status(), "WebSocket HTTP upgrade response");
178
179        let (mut write, mut read) = ws_stream.split();
180
181        // Send SUBSCRIBE message per Binance WebSocket API spec
182        let subscribe_msg = serde_json::json!({
183            "method": "SUBSCRIBE",
184            "params": streams,
185            "id": 1
186        });
187        write
188            .send(tungstenite::Message::Text(subscribe_msg.to_string()))
189            .await
190            .map_err(|e| {
191                let detail = format_ws_error(&e);
192                anyhow::anyhow!("Failed to send SUBSCRIBE: {}", detail)
193            })?;
194
195        let _ = status_tx
196            .send(AppEvent::LogRecord(ws_log(
197                LogLevel::Info,
198                "subscribe.ok",
199                display_symbol,
200                format!("streams={}", streams.join(",")),
201            )))
202            .await;
203
204        // Send Connected AFTER successful subscription
205        let _ = status_tx
206            .send(AppEvent::WsStatus(WsConnectionStatus::Connected))
207            .await;
208
209        loop {
210            tokio::select! {
211                msg = read.next() => {
212                    match msg {
213                        Some(Ok(tungstenite::Message::Text(text))) => {
214                            self.handle_text_message(&text, display_symbol, tick_tx, status_tx).await;
215                        }
216                        Some(Ok(tungstenite::Message::Ping(_))) => {
217                            // tokio-tungstenite handles pong automatically
218                        }
219                        Some(Ok(tungstenite::Message::Close(frame))) => {
220                            let detail = match &frame {
221                                Some(cf) => format!(
222                                    "Server closed: code={} reason=\"{}\"",
223                                    format_close_code(&cf.code),
224                                    cf.reason
225                                ),
226                                None => "Server closed: no close frame".to_string(),
227                            };
228                            let _ = status_tx
229                                .send(AppEvent::LogRecord(ws_log(
230                                    LogLevel::Warn,
231                                    "server.closed",
232                                    display_symbol,
233                                    detail.clone(),
234                                )))
235                                .await;
236                            return Err(anyhow::anyhow!("{}", detail));
237                        }
238                        Some(Ok(other)) => {
239                            tracing::trace!(msg_type = ?other, "Unhandled WS message type");
240                        }
241                        Some(Err(e)) => {
242                            let detail = format_ws_error(&e);
243                            let _ = status_tx
244                                .send(AppEvent::LogRecord(ws_log(
245                                    LogLevel::Warn,
246                                    "read.error",
247                                    display_symbol,
248                                    detail.clone(),
249                                )))
250                                .await;
251                            return Err(anyhow::anyhow!("WebSocket read error: {}", detail));
252                        }
253                        None => {
254                            return Err(anyhow::anyhow!(
255                                "WebSocket stream ended unexpectedly (connection dropped)"
256                            ));
257                        }
258                    }
259                }
260                _ = shutdown.changed() => {
261                    // Send UNSUBSCRIBE before closing
262                    let unsub_msg = serde_json::json!({
263                        "method": "UNSUBSCRIBE",
264                        "params": streams,
265                        "id": 2
266                    });
267                    let _ = write
268                        .send(tungstenite::Message::Text(unsub_msg.to_string()))
269                        .await;
270                    let _ = write.send(tungstenite::Message::Close(None)).await;
271                    return Ok(());
272                }
273                _ = symbol_rx.changed() => {
274                    let _ = write.send(tungstenite::Message::Close(None)).await;
275                    // In multi-worker mode, symbol channel closure means this worker is being retired.
276                    let _ = status_tx
277                        .send(AppEvent::LogRecord(ws_log(
278                            LogLevel::Info,
279                            "worker.retired",
280                            display_symbol,
281                            "symbol channel closed".to_string(),
282                        )))
283                        .await;
284                    return Ok(());
285                }
286            }
287        }
288    }
289
290    async fn handle_text_message(
291        &self,
292        text: &str,
293        display_symbol: &str,
294        tick_tx: &mpsc::Sender<Tick>,
295        status_tx: &mpsc::Sender<AppEvent>,
296    ) {
297        // Skip subscription confirmation responses like {"result":null,"id":1}
298        if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
299            if val.get("result").is_some() && val.get("id").is_some() {
300                tracing::debug!(id = %val["id"], "Subscription response received");
301                return;
302            }
303        }
304
305        match serde_json::from_str::<BinanceTradeEvent>(text) {
306            Ok(event) => {
307                let tick = Tick {
308                    symbol: display_symbol.to_string(),
309                    price: event.price,
310                    qty: event.qty,
311                    timestamp_ms: event.event_time,
312                    is_buyer_maker: event.is_buyer_maker,
313                    trade_id: event.trade_id,
314                };
315                if tick_tx.try_send(tick).is_err() {
316                    tracing::warn!("Tick channel full, dropping tick");
317                    let _ = status_tx.try_send(AppEvent::TickDropped);
318                }
319            }
320            Err(e) => {
321                tracing::debug!(error = %e, raw = %text, "Failed to parse WS message");
322                let _ = status_tx
323                    .send(AppEvent::LogRecord(ws_log(
324                        LogLevel::Debug,
325                        "parse.skip",
326                        display_symbol,
327                        format!("payload={}", &text[..text.len().min(80)]),
328                    )))
329                    .await;
330            }
331        }
332    }
333}
334
/// Splits an instrument label into its upper-cased symbol and a futures flag.
///
/// Surrounding whitespace is ignored. A trailing ` (FUT)` marker (matched
/// case-sensitively) flags the instrument as futures and is removed, along with any
/// extra whitespace left between the symbol and the marker. Without the marker the
/// whole trimmed label is the symbol.
fn parse_instrument_symbol(instrument: &str) -> (String, bool) {
    let trimmed = instrument.trim();
    match trimmed.strip_suffix(" (FUT)") {
        // `trim_end` keeps padding such as "BTCUSDT  (FUT)" from leaking a trailing
        // space into the symbol, which would produce a malformed stream name.
        Some(symbol) => (symbol.trim_end().to_ascii_uppercase(), true),
        None => (trimmed.to_ascii_uppercase(), false),
    }
}
342
343fn ws_log(level: LogLevel, event: &'static str, symbol: &str, msg: String) -> LogRecord {
344    let mut record = LogRecord::new(level, LogDomain::Ws, event, msg);
345    record.symbol = Some(symbol.to_string());
346    record
347}
348
349/// Format a tungstenite WebSocket error into a detailed, human-readable string.
350fn format_ws_error(err: &WsError) -> String {
351    match err {
352        WsError::ConnectionClosed => "Connection closed normally".to_string(),
353        WsError::AlreadyClosed => "Attempted operation on already-closed connection".to_string(),
354        WsError::Io(io_err) => {
355            format!("IO error [kind={}]: {}", io_err.kind(), io_err)
356        }
357        WsError::Tls(tls_err) => format!("TLS error: {}", tls_err),
358        WsError::Capacity(cap_err) => format!("Capacity error: {}", cap_err),
359        WsError::Protocol(proto_err) => {
360            let detail = match proto_err {
361                ProtocolError::ResetWithoutClosingHandshake => {
362                    "connection reset without closing handshake (server may have dropped)"
363                }
364                ProtocolError::SendAfterClosing => "tried to send after close frame",
365                ProtocolError::ReceivedAfterClosing => "received data after close frame",
366                ProtocolError::HandshakeIncomplete => "handshake incomplete",
367                _ => "",
368            };
369            if detail.is_empty() {
370                format!("Protocol error: {}", proto_err)
371            } else {
372                format!("Protocol error: {} ({})", proto_err, detail)
373            }
374        }
375        WsError::WriteBufferFull(_) => "Write buffer full (backpressure)".to_string(),
376        WsError::Utf8 => "UTF-8 encoding error in frame data".to_string(),
377        WsError::AttackAttempt => "Attack attempt detected by WebSocket library".to_string(),
378        WsError::Url(url_err) => {
379            let hint = match url_err {
380                UrlError::TlsFeatureNotEnabled => "TLS feature not compiled in",
381                UrlError::NoHostName => "no host name in URL",
382                UrlError::UnableToConnect(addr) => {
383                    return format!(
384                        "URL error: unable to connect to {} (DNS/network failure?)",
385                        addr
386                    );
387                }
388                UrlError::UnsupportedUrlScheme => "only ws:// or wss:// are supported",
389                UrlError::EmptyHostName => "empty host name in URL",
390                UrlError::NoPathOrQuery => "no path/query in URL",
391            };
392            format!("URL error: {} — {}", url_err, hint)
393        }
394        WsError::Http(resp) => {
395            let status = resp.status();
396            let body_preview = resp
397                .body()
398                .as_ref()
399                .and_then(|b| std::str::from_utf8(b).ok())
400                .unwrap_or("")
401                .chars()
402                .take(200)
403                .collect::<String>();
404            format!(
405                "HTTP error: status={} ({}), body=\"{}\"",
406                status.as_u16(),
407                status.canonical_reason().unwrap_or("unknown"),
408                body_preview
409            )
410        }
411        WsError::HttpFormat(e) => format!("HTTP format error: {}", e),
412    }
413}
414
415/// Format a WebSocket close code into a readable string with numeric value.
416fn format_close_code(code: &CloseCode) -> String {
417    let (num, label) = match code {
418        CloseCode::Normal => (1000, "Normal"),
419        CloseCode::Away => (1001, "Going Away"),
420        CloseCode::Protocol => (1002, "Protocol Error"),
421        CloseCode::Unsupported => (1003, "Unsupported Data"),
422        CloseCode::Status => (1005, "No Status"),
423        CloseCode::Abnormal => (1006, "Abnormal Closure"),
424        CloseCode::Invalid => (1007, "Invalid Payload"),
425        CloseCode::Policy => (1008, "Policy Violation"),
426        CloseCode::Size => (1009, "Message Too Big"),
427        CloseCode::Extension => (1010, "Extension Required"),
428        CloseCode::Error => (1011, "Internal Error"),
429        CloseCode::Restart => (1012, "Service Restart"),
430        CloseCode::Again => (1013, "Try Again Later"),
431        CloseCode::Tls => (1015, "TLS Handshake Failure"),
432        CloseCode::Reserved(n) => (*n, "Reserved"),
433        CloseCode::Iana(n) => (*n, "IANA"),
434        CloseCode::Library(n) => (*n, "Library"),
435        CloseCode::Bad(n) => (*n, "Bad"),
436    };
437    format!("{} ({})", num, label)
438}