metaflux_client/ws/client.rs
1//! WS client core — connect, send subscribe frames, dispatch inbound messages.
2//!
3//! The connection is managed by a background tokio task spawned by
4//! [`WsClient::connect`]. The task:
5//!
6//! 1. Opens a `wss://` connection.
7//! 2. Re-issues every active subscription on reconnect.
8//! 3. Sends `ping` frames at the configured interval.
9//! 4. Forwards inbound channel frames to the user via the
10//! `tokio::sync::broadcast` channel exposed by [`WsClient::messages`].
11//!
12//! On disconnect it reconnects with exponential backoff (capped). The user
13//! task continues to consume the broadcast — they will see new frames once
14//! reconnection succeeds.
15
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
19use std::time::Duration;
20
21use futures_util::{SinkExt, StreamExt};
22use serde_json::{Value, json};
23use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
24use tokio::task::JoinHandle;
25use tokio_tungstenite::tungstenite::Message;
26
27use crate::error::ClientError;
28use crate::types::order::{CancelOrder, Order, OrderResponse};
29use crate::wallet::{TypedTradingAction, TypedTradingDigest, Wallet};
30use crate::ws::subscriptions::{Subscription, WsMessage};
31
/// Knobs controlling heartbeat cadence, reconnect backoff, inbound buffering
/// and `post` timeouts of a [`WsClient`].
#[derive(Clone, Debug)]
pub struct WsConfig {
    /// Period between outbound `ping` frames. Default: 30 seconds.
    pub ping_interval: Duration,
    /// Delay before the first reconnect attempt after a disconnect; the
    /// delay doubles after each further failed attempt. Default: 250 ms.
    pub initial_backoff: Duration,
    /// Upper bound the doubling reconnect delay is clamped to.
    /// Default: 30 seconds.
    pub max_backoff: Duration,
    /// Number of inbound messages the broadcast channel buffers.
    /// Default: 1024.
    pub channel_capacity: usize,
    /// Maximum time a `post` request waits for its correlated response
    /// before failing with [`ClientError::WebSocket`]. Default: 10 seconds.
    pub post_timeout: Duration,
}
47
48impl Default for WsConfig {
49 fn default() -> Self {
50 Self {
51 ping_interval: Duration::from_secs(30),
52 initial_backoff: Duration::from_millis(250),
53 max_backoff: Duration::from_secs(30),
54 channel_capacity: 1024,
55 post_timeout: Duration::from_secs(10),
56 }
57 }
58}
59
60/// Internal control-plane commands to the background task.
61#[derive(Debug)]
62enum Command {
63 Subscribe(Subscription),
64 Unsubscribe(Subscription),
65 /// A correlated `post` request: the pre-serialized frame plus a one-shot
66 /// channel the background task completes with the matching `response`
67 /// object (`{type, payload}`) once the `{channel:"post"}` frame arrives.
68 Post {
69 id: u64,
70 frame: String,
71 reply: oneshot::Sender<Value>,
72 },
73 /// Drop a pending `post` whose caller gave up (timed out) so its entry
74 /// doesn't linger in the correlation map for the life of the connection.
75 CancelPost {
76 id: u64,
77 },
78 Shutdown,
79}
80
81/// Connected WebSocket client.
82///
83/// Cheap to clone — wraps `Arc`/channels internally. Drop the last clone to
84/// trigger shutdown.
85#[derive(Debug, Clone)]
86pub struct WsClient {
87 /// Inbound message broadcast.
88 inbound_tx: broadcast::Sender<WsMessage>,
89 /// Control-plane channel to the background task.
90 cmd_tx: mpsc::UnboundedSender<Command>,
91 /// Connection state flag (true while the background loop is running).
92 alive: Arc<AtomicBool>,
93 /// Active subscriptions; replayed on reconnect.
94 active: Arc<Mutex<Vec<Subscription>>>,
95 /// Monotonic id source for `post` request/response correlation.
96 post_id: Arc<AtomicU64>,
97 /// Per-request timeout for `post` calls.
98 post_timeout: Duration,
99}
100
101impl WsClient {
102 /// Connect to a WS endpoint with the default configuration.
103 ///
104 /// `url` should be a `wss://...` URL. Returns a [`WsClient`] handle as
105 /// soon as the initial connect succeeds.
106 ///
107 /// # Errors
108 /// [`ClientError::WebSocket`] on initial connect failure.
109 pub async fn connect(url: impl Into<String>) -> Result<Self, ClientError> {
110 Self::connect_with(url, WsConfig::default()).await
111 }
112
113 /// Connect with a custom [`WsConfig`].
114 ///
115 /// # Errors
116 /// See [`WsClient::connect`].
117 pub async fn connect_with(
118 url: impl Into<String>,
119 config: WsConfig,
120 ) -> Result<Self, ClientError> {
121 let url = url.into();
122 let (inbound_tx, _) = broadcast::channel(config.channel_capacity);
123 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
124 let alive = Arc::new(AtomicBool::new(true));
125 let active: Arc<Mutex<Vec<Subscription>>> = Arc::new(Mutex::new(Vec::new()));
126 let post_timeout = config.post_timeout;
127
128 // Quick connect-then-drop to validate the URL up front; the
129 // background task will reconnect from scratch.
130 let (probe, _) = tokio_tungstenite::connect_async(&url).await?;
131 drop(probe);
132
133 let task_state = TaskState {
134 url,
135 config,
136 inbound_tx: inbound_tx.clone(),
137 cmd_rx,
138 alive: alive.clone(),
139 active: active.clone(),
140 };
141 let _handle: JoinHandle<()> = tokio::spawn(run_background(task_state));
142
143 Ok(Self {
144 inbound_tx,
145 cmd_tx,
146 alive,
147 active,
148 post_id: Arc::new(AtomicU64::new(1)),
149 post_timeout,
150 })
151 }
152
153 /// Subscribe a stream. The channel is replayed on reconnect.
154 ///
155 /// # Errors
156 /// [`ClientError::WebSocket`] if the background task is gone.
157 pub async fn subscribe(&self, sub: Subscription) -> Result<(), ClientError> {
158 {
159 let mut g = self.active.lock().await;
160 if !g.contains(&sub) {
161 g.push(sub.clone());
162 }
163 }
164 self.cmd_tx
165 .send(Command::Subscribe(sub))
166 .map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
167 Ok(())
168 }
169
170 /// Unsubscribe a stream.
171 ///
172 /// # Errors
173 /// [`ClientError::WebSocket`] if the background task is gone.
174 pub async fn unsubscribe(&self, sub: Subscription) -> Result<(), ClientError> {
175 {
176 let mut g = self.active.lock().await;
177 g.retain(|s| s != &sub);
178 }
179 self.cmd_tx
180 .send(Command::Unsubscribe(sub))
181 .map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
182 Ok(())
183 }
184
185 /// Subscribe to L2 book updates for a market. Convenience wrapper.
186 ///
187 /// # Errors
188 /// See [`WsClient::subscribe`].
189 pub async fn subscribe_l2_book(
190 &self,
191 market: crate::types::MarketId,
192 ) -> Result<(), ClientError> {
193 self.subscribe(Subscription::L2Book {
194 coin: market.0.to_string(),
195 })
196 .await
197 }
198
199 /// Subscribe to public trades for a market.
200 ///
201 /// # Errors
202 /// See [`WsClient::subscribe`].
203 pub async fn subscribe_trades(
204 &self,
205 market: crate::types::MarketId,
206 ) -> Result<(), ClientError> {
207 self.subscribe(Subscription::Trades {
208 coin: market.0.to_string(),
209 })
210 .await
211 }
212
213 /// Subscribe to best-bid-best-offer ticks for a market.
214 ///
215 /// # Errors
216 /// See [`WsClient::subscribe`].
217 pub async fn subscribe_bbo(&self, market: crate::types::MarketId) -> Result<(), ClientError> {
218 self.subscribe(Subscription::Bbo {
219 coin: market.0.to_string(),
220 })
221 .await
222 }
223
224 /// Subscribe to per-market mark / oracle / funding context.
225 ///
226 /// # Errors
227 /// See [`WsClient::subscribe`].
228 pub async fn subscribe_active_asset_ctx(
229 &self,
230 market: crate::types::MarketId,
231 ) -> Result<(), ClientError> {
232 self.subscribe(Subscription::ActiveAssetCtx {
233 coin: market.0.to_string(),
234 })
235 .await
236 }
237
238 /// Subscribe to OHLCV candles for a market + interval token
239 /// (`"1m"`/`"5m"`/`"15m"`/`"1h"`/`"4h"`/`"1d"`).
240 ///
241 /// # Errors
242 /// See [`WsClient::subscribe`].
243 pub async fn subscribe_candles(
244 &self,
245 market: crate::types::MarketId,
246 interval: impl Into<String>,
247 ) -> Result<(), ClientError> {
248 self.subscribe(Subscription::Candles {
249 coin: market.0.to_string(),
250 interval: interval.into(),
251 })
252 .await
253 }
254
255 /// Subscribe to the global all-market mids stream.
256 ///
257 /// # Errors
258 /// See [`WsClient::subscribe`].
259 pub async fn subscribe_all_mids(&self) -> Result<(), ClientError> {
260 self.subscribe(Subscription::AllMids).await
261 }
262
263 /// Subscribe to per-user fills.
264 ///
265 /// # Errors
266 /// See [`WsClient::subscribe`].
267 pub async fn subscribe_fills(&self, user: crate::wallet::Address) -> Result<(), ClientError> {
268 self.subscribe(Subscription::Fills { user }).await
269 }
270
271 /// Subscribe to per-user order lifecycle updates.
272 ///
273 /// # Errors
274 /// See [`WsClient::subscribe`].
275 pub async fn subscribe_order_updates(
276 &self,
277 user: crate::wallet::Address,
278 ) -> Result<(), ClientError> {
279 self.subscribe(Subscription::OrderUpdates { user }).await
280 }
281
282 /// Subscribe to per-user account / margin events.
283 ///
284 /// # Errors
285 /// See [`WsClient::subscribe`].
286 pub async fn subscribe_user_events(
287 &self,
288 user: crate::wallet::Address,
289 ) -> Result<(), ClientError> {
290 self.subscribe(Subscription::UserEvents { user }).await
291 }
292
293 /// Subscribe to the per-user live account-state stream.
294 ///
295 /// # Errors
296 /// See [`WsClient::subscribe`].
297 pub async fn subscribe_account_state(
298 &self,
299 user: crate::wallet::Address,
300 ) -> Result<(), ClientError> {
301 self.subscribe(Subscription::AccountState { user }).await
302 }
303
304 /// Receive inbound channel frames.
305 ///
306 /// Each call returns a fresh [`broadcast::Receiver`] so multiple consumers
307 /// can subscribe to the same stream. Returns `None` once the task has
308 /// shut down.
309 #[must_use]
310 pub fn messages(&self) -> broadcast::Receiver<WsMessage> {
311 self.inbound_tx.subscribe()
312 }
313
314 // ---- `post` request/response (HL `post` method) ----
315
316 /// Issue a signed exchange action over the WebSocket `post` channel,
317 /// returning the node's action response payload.
318 ///
319 /// This is the WS analogue of [`crate::rest::exchange::Exchange::post_signed`]: the
320 /// action is signed with the SAME EIP-712 digest (recovered over the
321 /// compact JSON of the action object), wrapped as
322 /// `{"method":"post","id":N,"request":{"type":"action","payload":{signature,nonce,action}}}`,
323 /// and sent over the existing connection. The returned `Value` is the
324 /// `payload` of the node's `action` response (e.g. `{"accepted":true,…}`);
325 /// a malformed-request rejection surfaces as [`ClientError::WebSocket`].
326 ///
327 /// # Errors
328 /// - [`ClientError::Signature`] on signing failure.
329 /// - [`ClientError::WebSocket`] if the socket is down, the post times out,
330 /// or the node returns a post-level error frame.
331 pub async fn post_action(&self, wallet: &Wallet, action: Value) -> Result<Value, ClientError> {
332 let (nonce, signature) = crate::rest::exchange::sign_action(wallet, &action)?;
333 let payload = json!({ "signature": signature, "nonce": nonce, "action": action });
334 self.post_request("action", payload).await
335 }
336
337 /// Issue a TRADING action (order / cancel / …) over the WS `post` channel,
338 /// signed under the typed scheme. The 12 trading actions migrated to the
339 /// typed scheme (the node rejects them under the opaque envelope), so the WS
340 /// `post` path carries `sig_scheme:"typed"` alongside the structured digest.
341 async fn post_typed_trade(
342 &self,
343 wallet: &Wallet,
344 action: Value,
345 typed: TypedTradingAction<'_>,
346 ) -> Result<Value, ClientError> {
347 let nonce = crate::rest::exchange::next_nonce();
348 let digest =
349 TypedTradingDigest::new(typed, crate::rest::exchange::MTF_CHAIN_ID, nonce).digest()?;
350 let signature = wallet.sign_digest(&digest)?.to_hex();
351 let payload = json!({
352 "signature": signature,
353 "nonce": nonce,
354 "action": action,
355 "sig_scheme": "typed",
356 });
357 self.post_request("action", payload).await
358 }
359
360 /// Issue an `info` read over the WebSocket `post` channel, returning the
361 /// info response payload.
362 ///
363 /// The WS analogue of a `POST /info` call: `payload` is the usual
364 /// `{"type":"<info>",…}` body. Lets a subscriber multiplex one-off reads
365 /// over the same socket instead of opening a REST connection.
366 ///
367 /// # Errors
368 /// [`ClientError::WebSocket`] if the socket is down, the post times out, or
369 /// the node returns a post-level error frame.
370 pub async fn post_info(&self, payload: Value) -> Result<Value, ClientError> {
371 self.post_request("info", payload).await
372 }
373
374 /// Submit a limit / market / trigger order over the WS `post` channel,
375 /// decoding the typed [`OrderResponse`].
376 ///
377 /// Convenience wrapper over [`Self::post_action`] mirroring
378 /// [`crate::rest::exchange::Exchange::submit_order`]. The order's `owner` MUST equal
379 /// the wallet address.
380 ///
381 /// # Errors
382 /// - [`ClientError::Validation`] if `order.owner != wallet.address()`.
383 /// - [`ClientError::Decode`] if the response payload is not an
384 /// [`OrderResponse`].
385 /// - WebSocket / signature errors per [`Self::post_action`].
386 pub async fn submit_order(
387 &self,
388 wallet: &Wallet,
389 order: &Order,
390 ) -> Result<OrderResponse, ClientError> {
391 if order.owner != wallet.address() {
392 return Err(ClientError::Validation(format!(
393 "order.owner {} != wallet address {}",
394 order.owner,
395 wallet.address()
396 )));
397 }
398 let action = json!({ "type": "submit_order", "order": order });
399 let payload = self
400 .post_typed_trade(wallet, action, TypedTradingAction::SubmitOrder(order))
401 .await?;
402 Ok(serde_json::from_value(payload)?)
403 }
404
405 /// Cancel an order over the WS `post` channel.
406 ///
407 /// Convenience wrapper over [`Self::post_action`] mirroring
408 /// [`crate::rest::exchange::Exchange::cancel_order`].
409 ///
410 /// # Errors
411 /// - [`ClientError::Validation`] if `cancel.owner != wallet.address()`.
412 /// - WebSocket / signature errors per [`Self::post_action`].
413 pub async fn cancel_order(
414 &self,
415 wallet: &Wallet,
416 cancel: &CancelOrder,
417 ) -> Result<Value, ClientError> {
418 if cancel.owner != wallet.address() {
419 return Err(ClientError::Validation(format!(
420 "cancel.owner {} != wallet address {}",
421 cancel.owner,
422 wallet.address()
423 )));
424 }
425 let action = json!({ "type": "cancel_order", "cancel": cancel });
426 self.post_typed_trade(wallet, action, TypedTradingAction::CancelOrder(cancel))
427 .await
428 }
429
430 /// Core `post` machinery: assign a correlation id, ship the frame to the
431 /// background task, and await the matching response. Maps a node
432 /// `{"type":"error",…}` response to [`ClientError::WebSocket`]; returns the
433 /// inner `payload` on success.
434 async fn post_request(&self, request_type: &str, payload: Value) -> Result<Value, ClientError> {
435 let id = self.post_id.fetch_add(1, Ordering::Relaxed);
436 let frame = json!({
437 "method": "post",
438 "id": id,
439 "request": { "type": request_type, "payload": payload },
440 })
441 .to_string();
442
443 let (reply_tx, reply_rx) = oneshot::channel();
444 self.cmd_tx
445 .send(Command::Post {
446 id,
447 frame,
448 reply: reply_tx,
449 })
450 .map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
451
452 let response = match tokio::time::timeout(self.post_timeout, reply_rx).await {
453 Ok(Ok(resp)) => resp,
454 // Sender dropped => the connection cycled before the response
455 // arrived. A signed action is one-shot, so we surface the failure
456 // rather than silently retrying (which could double-submit).
457 Ok(Err(_)) => {
458 return Err(ClientError::WebSocket(
459 "ws post: connection closed before response".into(),
460 ));
461 }
462 Err(_) => {
463 // We gave up waiting; tell the background task to evict the
464 // pending entry so it can't leak on a long-lived connection.
465 // Best-effort: if the task is gone the entry dies with it.
466 let _ = self.cmd_tx.send(Command::CancelPost { id });
467 return Err(ClientError::WebSocket("ws post: timed out".into()));
468 }
469 };
470
471 // The node wraps every reply as `{type, payload}`; an error reply
472 // carries the message as a string payload.
473 if response.get("type").and_then(Value::as_str) == Some("error") {
474 let msg = response
475 .get("payload")
476 .and_then(Value::as_str)
477 .unwrap_or("unknown post error");
478 return Err(ClientError::WebSocket(format!("ws post error: {msg}")));
479 }
480 Ok(response.get("payload").cloned().unwrap_or(Value::Null))
481 }
482
483 /// True if the background reconnect task is still running.
484 #[must_use]
485 pub fn is_alive(&self) -> bool {
486 self.alive.load(Ordering::Acquire)
487 }
488
489 /// Initiate a graceful shutdown of the background task. Subsequent
490 /// `subscribe` calls will fail.
491 pub async fn shutdown(&self) {
492 let _ = self.cmd_tx.send(Command::Shutdown);
493 self.alive.store(false, Ordering::Release);
494 }
495}
496
497/// Internal task state.
498struct TaskState {
499 url: String,
500 config: WsConfig,
501 inbound_tx: broadcast::Sender<WsMessage>,
502 cmd_rx: mpsc::UnboundedReceiver<Command>,
503 alive: Arc<AtomicBool>,
504 active: Arc<Mutex<Vec<Subscription>>>,
505}
506
507/// The reconnect-with-backoff loop.
508async fn run_background(mut state: TaskState) {
509 let mut backoff = state.config.initial_backoff;
510 loop {
511 match run_connection(&mut state).await {
512 Ok(ConnectionExit::Shutdown) => break,
513 Ok(ConnectionExit::Recoverable) | Err(_) => {
514 tokio::time::sleep(backoff).await;
515 backoff = (backoff * 2).min(state.config.max_backoff);
516 // continue loop -> reconnect
517 }
518 }
519 }
520 state.alive.store(false, Ordering::Release);
521}
522
/// Why a single connection's run ended.
#[derive(Debug)]
enum ConnectionExit {
    /// A shutdown was requested (or the control channel closed); the caller
    /// must not reconnect.
    Shutdown,
    /// The socket dropped or errored; the caller should reconnect after
    /// backing off.
    Recoverable,
}
531
532async fn run_connection(state: &mut TaskState) -> Result<ConnectionExit, ClientError> {
533 let (stream, _) = tokio_tungstenite::connect_async(&state.url).await?;
534 let (mut sink, mut stream) = stream.split();
535
536 // Replay active subscriptions on (re)connect.
537 {
538 let subs = state.active.lock().await.clone();
539 for sub in &subs {
540 let frame = json!({"method": "subscribe", "subscription": sub});
541 sink.send(Message::Text(frame.to_string())).await?;
542 }
543 }
544
545 // In-flight `post` requests for this connection, keyed by correlation id.
546 // Dropped (with all reply senders) when the connection exits, so any
547 // caller awaiting a response on a dead socket unblocks with an error.
548 let mut pending: HashMap<u64, oneshot::Sender<Value>> = HashMap::new();
549
550 let mut ping_tick = tokio::time::interval(state.config.ping_interval);
551 ping_tick.tick().await; // consume the immediate first tick
552
553 loop {
554 tokio::select! {
555 cmd = state.cmd_rx.recv() => {
556 match cmd {
557 Some(Command::Subscribe(sub)) => {
558 let frame = json!({"method": "subscribe", "subscription": sub});
559 sink.send(Message::Text(frame.to_string())).await?;
560 }
561 Some(Command::Unsubscribe(sub)) => {
562 let frame = json!({"method": "unsubscribe", "subscription": sub});
563 sink.send(Message::Text(frame.to_string())).await?;
564 }
565 Some(Command::Post { id, frame, reply }) => {
566 // Send first; only track the reply once the frame is on
567 // the wire. A send failure propagates `Err` out of
568 // `run_connection` (which `run_background` treats as a
569 // recoverable reconnect) and drops `reply`, surfacing a
570 // disconnect to the caller.
571 sink.send(Message::Text(frame)).await?;
572 pending.insert(id, reply);
573 }
574 Some(Command::CancelPost { id }) => {
575 // Caller timed out; drop the dangling reply sender.
576 pending.remove(&id);
577 }
578 Some(Command::Shutdown) | None => {
579 let _ = sink.send(Message::Close(None)).await;
580 return Ok(ConnectionExit::Shutdown);
581 }
582 }
583 }
584 _ = ping_tick.tick() => {
585 let ping = json!({"method": "ping"});
586 if sink.send(Message::Text(ping.to_string())).await.is_err() {
587 return Ok(ConnectionExit::Recoverable);
588 }
589 }
590 frame = stream.next() => {
591 let Some(frame) = frame else {
592 return Ok(ConnectionExit::Recoverable);
593 };
594 match frame {
595 Ok(Message::Text(text)) => {
596 // A `{channel:"post"}` frame correlates by id back to the
597 // waiting caller; every other frame is a channel update
598 // for the broadcast.
599 match serde_json::from_str::<Value>(&text) {
600 Ok(v)
601 if v.get("channel").and_then(Value::as_str) == Some("post") =>
602 {
603 if let Some(id) =
604 v.pointer("/data/id").and_then(Value::as_u64)
605 {
606 if let Some(reply) = pending.remove(&id) {
607 let resp = v
608 .pointer("/data/response")
609 .cloned()
610 .unwrap_or(Value::Null);
611 let _ = reply.send(resp);
612 }
613 }
614 }
615 Ok(v) => {
616 // Unknown / future channels (and any frame whose
617 // `data` we can't type) fall back to `Unknown`
618 // instead of being dropped, so a forward-compat
619 // consumer still sees that a frame arrived.
620 let msg = serde_json::from_value::<WsMessage>(v)
621 .unwrap_or(WsMessage::Unknown);
622 let _ = state.inbound_tx.send(msg);
623 }
624 Err(_) => {}
625 }
626 }
627 Ok(Message::Binary(_) | Message::Pong(_) | Message::Ping(_)) => {
628 // Ignore non-text control frames; tungstenite handles
629 // pong automatically for ping.
630 }
631 Ok(Message::Close(_)) => {
632 return Ok(ConnectionExit::Recoverable);
633 }
634 Ok(Message::Frame(_)) => {
635 // Raw frame — ignore.
636 }
637 Err(_) => return Ok(ConnectionExit::Recoverable),
638 }
639 }
640 }
641 }
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every documented default of [`WsConfig`], including the `post`
    /// timeout.
    #[test]
    fn ws_config_default_values() {
        let c = WsConfig::default();
        assert_eq!(c.ping_interval, Duration::from_secs(30));
        assert_eq!(c.initial_backoff, Duration::from_millis(250));
        assert_eq!(c.max_backoff, Duration::from_secs(30));
        assert_eq!(c.channel_capacity, 1024);
        assert_eq!(c.post_timeout, Duration::from_secs(10));
    }
}