//! Source file: `sandbox_quant/binance/ws.rs`

1use anyhow::Result;
2use futures_util::{SinkExt, StreamExt};
3use std::time::Duration;
4use tokio::sync::{mpsc, watch};
5use tokio_tungstenite::tungstenite;
6use tungstenite::error::{Error as WsError, ProtocolError, UrlError};
7use tungstenite::protocol::frame::coding::CloseCode;
8
9use super::types::BinanceTradeEvent;
10use crate::event::{AppEvent, LogDomain, LogLevel, LogRecord, WsConnectionStatus};
11use crate::model::tick::Tick;
12
/// Exponential backoff schedule used between reconnection attempts.
///
/// Every call to `next_delay` hands out the current delay and then multiplies it by
/// `factor` for the following call. No delay handed out ever exceeds `max`.
struct ExponentialBackoff {
    /// Delay the next call to `next_delay` will return (before clamping to `max`).
    current: Duration,
    /// Delay restored by `reset`.
    initial: Duration,
    /// Upper bound for every delay handed out.
    max: Duration,
    /// Multiplier applied to the delay after each call to `next_delay`.
    factor: f64,
}

impl ExponentialBackoff {
    /// Create a schedule that starts at `initial`, grows by `factor`, and is capped at `max`.
    fn new(initial: Duration, max: Duration, factor: f64) -> Self {
        Self {
            current: initial,
            initial,
            max,
            factor,
        }
    }

    /// Return the delay to wait now and advance the schedule for the next call.
    ///
    /// The returned delay is clamped to `max`, so an `initial` above `max` cannot leak out.
    /// This never panics: if the scaled value is negative, NaN, infinite, or at/above the
    /// cap, the next delay simply becomes `max`.
    fn next_delay(&mut self) -> Duration {
        let delay = self.current.min(self.max);
        let scaled = delay.as_secs_f64() * self.factor;
        // `Duration::from_secs_f64` panics on negative, NaN, or overflowing input, so it is
        // only called when `scaled` is known to lie in `[0, max)`. NaN fails both comparisons.
        self.current = if scaled >= 0.0 && scaled < self.max.as_secs_f64() {
            Duration::from_secs_f64(scaled)
        } else {
            self.max
        };
        delay
    }

    /// Restart the schedule from the initial delay.
    fn reset(&mut self) {
        self.current = self.initial;
    }
}
43
/// WebSocket client configuration holding the two endpoint URLs it can connect to.
#[derive(Debug, Clone)]
pub struct BinanceWsClient {
    /// Base WebSocket URL used for non-futures instruments.
    spot_url: String,
    /// Base WebSocket URL used for futures instruments.
    futures_url: String,
}
49
50impl BinanceWsClient {
51    /// Create a new WebSocket client.
52    ///
53    /// `ws_base_url` — e.g. `wss://stream.testnet.binance.vision/ws`
54    pub fn new(ws_base_url: &str, futures_ws_base_url: &str) -> Self {
55        Self {
56            spot_url: ws_base_url.to_string(),
57            futures_url: futures_ws_base_url.to_string(),
58        }
59    }
60
61    /// Connect and run the WebSocket loop with automatic reconnection.
62    /// Sends WsStatus events through `status_tx` and ticks through `tick_tx`.
63    pub async fn connect_and_run(
64        &self,
65        tick_tx: mpsc::Sender<Tick>,
66        status_tx: mpsc::Sender<AppEvent>,
67        mut symbol_rx: watch::Receiver<String>,
68        mut shutdown: watch::Receiver<bool>,
69    ) -> Result<()> {
70        let mut backoff =
71            ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60), 2.0);
72        let mut attempt: u32 = 0;
73
74        loop {
75            attempt += 1;
76            let instrument = symbol_rx.borrow().clone();
77            let (symbol, is_futures) = parse_instrument_symbol(&instrument);
78            let streams = vec![format!("{}@trade", symbol.to_lowercase())];
79            let ws_url = if is_futures {
80                &self.futures_url
81            } else {
82                &self.spot_url
83            };
84            match self
85                .connect_once(
86                    ws_url,
87                    &streams,
88                    &instrument,
89                    &tick_tx,
90                    &status_tx,
91                    &mut symbol_rx,
92                    &mut shutdown,
93                )
94                .await
95            {
96                Ok(()) => {
97                    // Clean shutdown requested
98                    let _ = status_tx
99                        .send(AppEvent::WsStatus(WsConnectionStatus::Disconnected))
100                        .await;
101                    break;
102                }
103                Err(e) => {
104                    let _ = status_tx
105                        .send(AppEvent::WsStatus(WsConnectionStatus::Disconnected))
106                        .await;
107                    tracing::warn!(attempt, error = %e, "WS connection attempt failed");
108                    let _ = status_tx
109                        .send(AppEvent::LogRecord(
110                            ws_log(
111                                LogLevel::Warn,
112                                "connect.fail",
113                                &instrument,
114                                format!("attempt={} error={}", attempt, e),
115                            ),
116                        ))
117                        .await;
118
119                    let delay = backoff.next_delay();
120                    let _ = status_tx
121                        .send(AppEvent::WsStatus(WsConnectionStatus::Reconnecting {
122                            attempt,
123                            delay_ms: delay.as_millis() as u64,
124                        }))
125                        .await;
126
127                    tokio::select! {
128                        _ = tokio::time::sleep(delay) => continue,
129                        _ = shutdown.changed() => {
130                            let _ = status_tx
131                                .send(AppEvent::LogRecord(ws_log(
132                                    LogLevel::Info,
133                                    "shutdown.during_reconnect",
134                                    &instrument,
135                                    "shutdown signal received during reconnect wait".to_string(),
136                                )))
137                                .await;
138                            break;
139                        }
140                    }
141                }
142            }
143        }
144        Ok(())
145    }
146
147    async fn connect_once(
148        &self,
149        ws_url: &str,
150        streams: &[String],
151        display_symbol: &str,
152        tick_tx: &mpsc::Sender<Tick>,
153        status_tx: &mpsc::Sender<AppEvent>,
154        symbol_rx: &mut watch::Receiver<String>,
155        shutdown: &mut watch::Receiver<bool>,
156    ) -> Result<()> {
157        let _ = status_tx
158            .send(AppEvent::LogRecord(ws_log(
159                LogLevel::Info,
160                "connect.start",
161                display_symbol,
162                format!("url={}", ws_url),
163            )))
164            .await;
165
166        let (ws_stream, resp) = tokio_tungstenite::connect_async(ws_url)
167            .await
168            .map_err(|e| {
169                let detail = format_ws_error(&e);
170                let _ = status_tx.try_send(AppEvent::LogRecord(ws_log(
171                    LogLevel::Warn,
172                    "connect.detail",
173                    display_symbol,
174                    detail.clone(),
175                )));
176                anyhow::anyhow!("WebSocket connect failed: {}", detail)
177            })?;
178
179        tracing::debug!(status = %resp.status(), "WebSocket HTTP upgrade response");
180
181        let (mut write, mut read) = ws_stream.split();
182
183        // Send SUBSCRIBE message per Binance WebSocket API spec
184        let subscribe_msg = serde_json::json!({
185            "method": "SUBSCRIBE",
186            "params": streams,
187            "id": 1
188        });
189        write
190            .send(tungstenite::Message::Text(subscribe_msg.to_string()))
191            .await
192            .map_err(|e| {
193                let detail = format_ws_error(&e);
194                anyhow::anyhow!("Failed to send SUBSCRIBE: {}", detail)
195            })?;
196
197        let _ = status_tx
198            .send(AppEvent::LogRecord(ws_log(
199                LogLevel::Info,
200                "subscribe.ok",
201                display_symbol,
202                format!("streams={}", streams.join(",")),
203            )))
204            .await;
205
206        // Send Connected AFTER successful subscription
207        let _ = status_tx
208            .send(AppEvent::WsStatus(WsConnectionStatus::Connected))
209            .await;
210
211        loop {
212            tokio::select! {
213                msg = read.next() => {
214                    match msg {
215                        Some(Ok(tungstenite::Message::Text(text))) => {
216                            self.handle_text_message(&text, display_symbol, tick_tx, status_tx).await;
217                        }
218                        Some(Ok(tungstenite::Message::Ping(_))) => {
219                            // tokio-tungstenite handles pong automatically
220                        }
221                        Some(Ok(tungstenite::Message::Close(frame))) => {
222                            let detail = match &frame {
223                                Some(cf) => format!(
224                                    "Server closed: code={} reason=\"{}\"",
225                                    format_close_code(&cf.code),
226                                    cf.reason
227                                ),
228                                None => "Server closed: no close frame".to_string(),
229                            };
230                            let _ = status_tx
231                                .send(AppEvent::LogRecord(ws_log(
232                                    LogLevel::Warn,
233                                    "server.closed",
234                                    display_symbol,
235                                    detail.clone(),
236                                )))
237                                .await;
238                            return Err(anyhow::anyhow!("{}", detail));
239                        }
240                        Some(Ok(other)) => {
241                            tracing::trace!(msg_type = ?other, "Unhandled WS message type");
242                        }
243                        Some(Err(e)) => {
244                            let detail = format_ws_error(&e);
245                            let _ = status_tx
246                                .send(AppEvent::LogRecord(ws_log(
247                                    LogLevel::Warn,
248                                    "read.error",
249                                    display_symbol,
250                                    detail.clone(),
251                                )))
252                                .await;
253                            return Err(anyhow::anyhow!("WebSocket read error: {}", detail));
254                        }
255                        None => {
256                            return Err(anyhow::anyhow!(
257                                "WebSocket stream ended unexpectedly (connection dropped)"
258                            ));
259                        }
260                    }
261                }
262                _ = shutdown.changed() => {
263                    // Send UNSUBSCRIBE before closing
264                    let unsub_msg = serde_json::json!({
265                        "method": "UNSUBSCRIBE",
266                        "params": streams,
267                        "id": 2
268                    });
269                    let _ = write
270                        .send(tungstenite::Message::Text(unsub_msg.to_string()))
271                        .await;
272                    let _ = write.send(tungstenite::Message::Close(None)).await;
273                    return Ok(());
274                }
275                _ = symbol_rx.changed() => {
276                    let _ = write.send(tungstenite::Message::Close(None)).await;
277                    // In multi-worker mode, symbol channel closure means this worker is being retired.
278                    let _ = status_tx
279                        .send(AppEvent::LogRecord(ws_log(
280                            LogLevel::Info,
281                            "worker.retired",
282                            display_symbol,
283                            "symbol channel closed".to_string(),
284                        )))
285                        .await;
286                    return Ok(());
287                }
288            }
289        }
290    }
291
292    async fn handle_text_message(
293        &self,
294        text: &str,
295        display_symbol: &str,
296        tick_tx: &mpsc::Sender<Tick>,
297        status_tx: &mpsc::Sender<AppEvent>,
298    ) {
299        // Skip subscription confirmation responses like {"result":null,"id":1}
300        if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
301            if val.get("result").is_some() && val.get("id").is_some() {
302                tracing::debug!(id = %val["id"], "Subscription response received");
303                return;
304            }
305        }
306
307        match serde_json::from_str::<BinanceTradeEvent>(text) {
308            Ok(event) => {
309                let tick = Tick {
310                    symbol: display_symbol.to_string(),
311                    price: event.price,
312                    qty: event.qty,
313                    timestamp_ms: event.event_time,
314                    is_buyer_maker: event.is_buyer_maker,
315                    trade_id: event.trade_id,
316                };
317                if tick_tx.try_send(tick).is_err() {
318                    tracing::warn!("Tick channel full, dropping tick");
319                    let _ = status_tx.try_send(AppEvent::TickDropped);
320                }
321            }
322            Err(e) => {
323                tracing::debug!(error = %e, raw = %text, "Failed to parse WS message");
324                let _ = status_tx
325                    .send(AppEvent::LogRecord(ws_log(
326                        LogLevel::Debug,
327                        "parse.skip",
328                        display_symbol,
329                        format!("payload={}", &text[..text.len().min(80)]),
330                    )))
331                    .await;
332            }
333        }
334    }
335}
336
/// Split an instrument label into its upper-cased symbol and a futures flag.
///
/// Surrounding whitespace is ignored. A label ending in the exact suffix ` (FUT)` yields
/// the part before the suffix and `true`; any other label yields the whole trimmed label
/// and `false`.
fn parse_instrument_symbol(instrument: &str) -> (String, bool) {
    const FUTURES_SUFFIX: &str = " (FUT)";

    let label = instrument.trim();
    match label.strip_suffix(FUTURES_SUFFIX) {
        Some(base) => (base.to_ascii_uppercase(), true),
        None => (label.to_ascii_uppercase(), false),
    }
}
344
345fn ws_log(level: LogLevel, event: &'static str, symbol: &str, msg: String) -> LogRecord {
346    let mut record = LogRecord::new(level, LogDomain::Ws, event, msg);
347    record.symbol = Some(symbol.to_string());
348    record
349}
350
351/// Format a tungstenite WebSocket error into a detailed, human-readable string.
352fn format_ws_error(err: &WsError) -> String {
353    match err {
354        WsError::ConnectionClosed => "Connection closed normally".to_string(),
355        WsError::AlreadyClosed => "Attempted operation on already-closed connection".to_string(),
356        WsError::Io(io_err) => {
357            format!("IO error [kind={}]: {}", io_err.kind(), io_err)
358        }
359        WsError::Tls(tls_err) => format!("TLS error: {}", tls_err),
360        WsError::Capacity(cap_err) => format!("Capacity error: {}", cap_err),
361        WsError::Protocol(proto_err) => {
362            let detail = match proto_err {
363                ProtocolError::ResetWithoutClosingHandshake => {
364                    "connection reset without closing handshake (server may have dropped)"
365                }
366                ProtocolError::SendAfterClosing => "tried to send after close frame",
367                ProtocolError::ReceivedAfterClosing => "received data after close frame",
368                ProtocolError::HandshakeIncomplete => "handshake incomplete",
369                _ => "",
370            };
371            if detail.is_empty() {
372                format!("Protocol error: {}", proto_err)
373            } else {
374                format!("Protocol error: {} ({})", proto_err, detail)
375            }
376        }
377        WsError::WriteBufferFull(_) => "Write buffer full (backpressure)".to_string(),
378        WsError::Utf8 => "UTF-8 encoding error in frame data".to_string(),
379        WsError::AttackAttempt => "Attack attempt detected by WebSocket library".to_string(),
380        WsError::Url(url_err) => {
381            let hint = match url_err {
382                UrlError::TlsFeatureNotEnabled => "TLS feature not compiled in",
383                UrlError::NoHostName => "no host name in URL",
384                UrlError::UnableToConnect(addr) => {
385                    return format!(
386                        "URL error: unable to connect to {} (DNS/network failure?)",
387                        addr
388                    );
389                }
390                UrlError::UnsupportedUrlScheme => "only ws:// or wss:// are supported",
391                UrlError::EmptyHostName => "empty host name in URL",
392                UrlError::NoPathOrQuery => "no path/query in URL",
393            };
394            format!("URL error: {} — {}", url_err, hint)
395        }
396        WsError::Http(resp) => {
397            let status = resp.status();
398            let body_preview = resp
399                .body()
400                .as_ref()
401                .and_then(|b| std::str::from_utf8(b).ok())
402                .unwrap_or("")
403                .chars()
404                .take(200)
405                .collect::<String>();
406            format!(
407                "HTTP error: status={} ({}), body=\"{}\"",
408                status.as_u16(),
409                status.canonical_reason().unwrap_or("unknown"),
410                body_preview
411            )
412        }
413        WsError::HttpFormat(e) => format!("HTTP format error: {}", e),
414    }
415}
416
417/// Format a WebSocket close code into a readable string with numeric value.
418fn format_close_code(code: &CloseCode) -> String {
419    let (num, label) = match code {
420        CloseCode::Normal => (1000, "Normal"),
421        CloseCode::Away => (1001, "Going Away"),
422        CloseCode::Protocol => (1002, "Protocol Error"),
423        CloseCode::Unsupported => (1003, "Unsupported Data"),
424        CloseCode::Status => (1005, "No Status"),
425        CloseCode::Abnormal => (1006, "Abnormal Closure"),
426        CloseCode::Invalid => (1007, "Invalid Payload"),
427        CloseCode::Policy => (1008, "Policy Violation"),
428        CloseCode::Size => (1009, "Message Too Big"),
429        CloseCode::Extension => (1010, "Extension Required"),
430        CloseCode::Error => (1011, "Internal Error"),
431        CloseCode::Restart => (1012, "Service Restart"),
432        CloseCode::Again => (1013, "Try Again Later"),
433        CloseCode::Tls => (1015, "TLS Handshake Failure"),
434        CloseCode::Reserved(n) => (*n, "Reserved"),
435        CloseCode::Iana(n) => (*n, "IANA"),
436        CloseCode::Library(n) => (*n, "Library"),
437        CloseCode::Bad(n) => (*n, "Bad"),
438    };
439    format!("{} ({})", num, label)
440}