Skip to main content

metaflux_client/ws/
client.rs

1//! WS client core — connect, send subscribe frames, dispatch inbound messages.
2//!
3//! The connection is managed by a background tokio task spawned by
4//! [`WsClient::connect`]. The task:
5//!
6//! 1. Opens a `wss://` connection.
7//! 2. Re-issues every active subscription on reconnect.
8//! 3. Sends `ping` frames at the configured interval.
9//! 4. Forwards inbound channel frames to the user via the
10//!    `tokio::sync::broadcast` channel exposed by [`WsClient::messages`].
11//!
12//! On disconnect it reconnects with exponential backoff (capped). The user
13//! task continues to consume the broadcast — they will see new frames once
14//! reconnection succeeds.
15
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
19use std::time::Duration;
20
21use futures_util::{SinkExt, StreamExt};
22use serde_json::{Value, json};
23use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
24use tokio::task::JoinHandle;
25use tokio_tungstenite::tungstenite::Message;
26
27use crate::error::ClientError;
28use crate::types::order::{CancelOrder, Order, OrderResponse};
29use crate::wallet::{TypedTradingAction, TypedTradingDigest, Wallet};
30use crate::ws::subscriptions::{Subscription, WsMessage};
31
/// Tunable knobs for the WebSocket client.
///
/// Start from [`WsConfig::default`] and override individual fields before
/// passing the value to [`WsClient::connect_with`].
#[derive(Clone, Debug)]
pub struct WsConfig {
    /// How often a heartbeat `ping` frame is sent. Default: 30 seconds.
    pub ping_interval: Duration,
    /// Delay before the first reconnect attempt. Default: 250 ms.
    pub initial_backoff: Duration,
    /// Upper bound on the delay between reconnect attempts. Default: 30 seconds.
    pub max_backoff: Duration,
    /// Capacity of the inbound message broadcast channel. Default: 1024.
    pub channel_capacity: usize,
    /// How long a `post` request waits for its correlated response before
    /// failing with [`ClientError::WebSocket`]. Default: 10 seconds.
    pub post_timeout: Duration,
}

impl Default for WsConfig {
    /// Returns the documented defaults for every field.
    fn default() -> Self {
        let ping_interval = Duration::from_secs(30);
        let initial_backoff = Duration::from_millis(250);
        let max_backoff = Duration::from_secs(30);
        let channel_capacity = 1024;
        let post_timeout = Duration::from_secs(10);
        Self {
            ping_interval,
            initial_backoff,
            max_backoff,
            channel_capacity,
            post_timeout,
        }
    }
}
59
60/// Internal control-plane commands to the background task.
61#[derive(Debug)]
62enum Command {
63    Subscribe(Subscription),
64    Unsubscribe(Subscription),
65    /// A correlated `post` request: the pre-serialized frame plus a one-shot
66    /// channel the background task completes with the matching `response`
67    /// object (`{type, payload}`) once the `{channel:"post"}` frame arrives.
68    Post {
69        id: u64,
70        frame: String,
71        reply: oneshot::Sender<Value>,
72    },
73    /// Drop a pending `post` whose caller gave up (timed out) so its entry
74    /// doesn't linger in the correlation map for the life of the connection.
75    CancelPost {
76        id: u64,
77    },
78    Shutdown,
79}
80
81/// Connected WebSocket client.
82///
83/// Cheap to clone — wraps `Arc`/channels internally. Drop the last clone to
84/// trigger shutdown.
85#[derive(Debug, Clone)]
86pub struct WsClient {
87    /// Inbound message broadcast.
88    inbound_tx: broadcast::Sender<WsMessage>,
89    /// Control-plane channel to the background task.
90    cmd_tx: mpsc::UnboundedSender<Command>,
91    /// Connection state flag (true while the background loop is running).
92    alive: Arc<AtomicBool>,
93    /// Active subscriptions; replayed on reconnect.
94    active: Arc<Mutex<Vec<Subscription>>>,
95    /// Monotonic id source for `post` request/response correlation.
96    post_id: Arc<AtomicU64>,
97    /// Per-request timeout for `post` calls.
98    post_timeout: Duration,
99}
100
101impl WsClient {
102    /// Connect to a WS endpoint with the default configuration.
103    ///
104    /// `url` should be a `wss://...` URL. Returns a [`WsClient`] handle as
105    /// soon as the initial connect succeeds.
106    ///
107    /// # Errors
108    /// [`ClientError::WebSocket`] on initial connect failure.
109    pub async fn connect(url: impl Into<String>) -> Result<Self, ClientError> {
110        Self::connect_with(url, WsConfig::default()).await
111    }
112
113    /// Connect with a custom [`WsConfig`].
114    ///
115    /// # Errors
116    /// See [`WsClient::connect`].
117    pub async fn connect_with(
118        url: impl Into<String>,
119        config: WsConfig,
120    ) -> Result<Self, ClientError> {
121        let url = url.into();
122        let (inbound_tx, _) = broadcast::channel(config.channel_capacity);
123        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
124        let alive = Arc::new(AtomicBool::new(true));
125        let active: Arc<Mutex<Vec<Subscription>>> = Arc::new(Mutex::new(Vec::new()));
126        let post_timeout = config.post_timeout;
127
128        // Quick connect-then-drop to validate the URL up front; the
129        // background task will reconnect from scratch.
130        let (probe, _) = tokio_tungstenite::connect_async(&url).await?;
131        drop(probe);
132
133        let task_state = TaskState {
134            url,
135            config,
136            inbound_tx: inbound_tx.clone(),
137            cmd_rx,
138            alive: alive.clone(),
139            active: active.clone(),
140        };
141        let _handle: JoinHandle<()> = tokio::spawn(run_background(task_state));
142
143        Ok(Self {
144            inbound_tx,
145            cmd_tx,
146            alive,
147            active,
148            post_id: Arc::new(AtomicU64::new(1)),
149            post_timeout,
150        })
151    }
152
153    /// Subscribe a stream. The channel is replayed on reconnect.
154    ///
155    /// # Errors
156    /// [`ClientError::WebSocket`] if the background task is gone.
157    pub async fn subscribe(&self, sub: Subscription) -> Result<(), ClientError> {
158        {
159            let mut g = self.active.lock().await;
160            if !g.contains(&sub) {
161                g.push(sub.clone());
162            }
163        }
164        self.cmd_tx
165            .send(Command::Subscribe(sub))
166            .map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
167        Ok(())
168    }
169
170    /// Unsubscribe a stream.
171    ///
172    /// # Errors
173    /// [`ClientError::WebSocket`] if the background task is gone.
174    pub async fn unsubscribe(&self, sub: Subscription) -> Result<(), ClientError> {
175        {
176            let mut g = self.active.lock().await;
177            g.retain(|s| s != &sub);
178        }
179        self.cmd_tx
180            .send(Command::Unsubscribe(sub))
181            .map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
182        Ok(())
183    }
184
185    /// Subscribe to L2 book updates for a market. Convenience wrapper.
186    ///
187    /// # Errors
188    /// See [`WsClient::subscribe`].
189    pub async fn subscribe_l2_book(
190        &self,
191        market: crate::types::MarketId,
192    ) -> Result<(), ClientError> {
193        self.subscribe(Subscription::L2Book {
194            coin: market.0.to_string(),
195        })
196        .await
197    }
198
199    /// Subscribe to public trades for a market.
200    ///
201    /// # Errors
202    /// See [`WsClient::subscribe`].
203    pub async fn subscribe_trades(
204        &self,
205        market: crate::types::MarketId,
206    ) -> Result<(), ClientError> {
207        self.subscribe(Subscription::Trades {
208            coin: market.0.to_string(),
209        })
210        .await
211    }
212
213    /// Subscribe to best-bid-best-offer ticks for a market.
214    ///
215    /// # Errors
216    /// See [`WsClient::subscribe`].
217    pub async fn subscribe_bbo(&self, market: crate::types::MarketId) -> Result<(), ClientError> {
218        self.subscribe(Subscription::Bbo {
219            coin: market.0.to_string(),
220        })
221        .await
222    }
223
224    /// Subscribe to per-market mark / oracle / funding context.
225    ///
226    /// # Errors
227    /// See [`WsClient::subscribe`].
228    pub async fn subscribe_active_asset_ctx(
229        &self,
230        market: crate::types::MarketId,
231    ) -> Result<(), ClientError> {
232        self.subscribe(Subscription::ActiveAssetCtx {
233            coin: market.0.to_string(),
234        })
235        .await
236    }
237
238    /// Subscribe to OHLCV candles for a market + interval token
239    /// (`"1m"`/`"5m"`/`"15m"`/`"1h"`/`"4h"`/`"1d"`).
240    ///
241    /// # Errors
242    /// See [`WsClient::subscribe`].
243    pub async fn subscribe_candles(
244        &self,
245        market: crate::types::MarketId,
246        interval: impl Into<String>,
247    ) -> Result<(), ClientError> {
248        self.subscribe(Subscription::Candles {
249            coin: market.0.to_string(),
250            interval: interval.into(),
251        })
252        .await
253    }
254
255    /// Subscribe to the global all-market mids stream.
256    ///
257    /// # Errors
258    /// See [`WsClient::subscribe`].
259    pub async fn subscribe_all_mids(&self) -> Result<(), ClientError> {
260        self.subscribe(Subscription::AllMids).await
261    }
262
263    /// Subscribe to per-user fills.
264    ///
265    /// # Errors
266    /// See [`WsClient::subscribe`].
267    pub async fn subscribe_fills(&self, user: crate::wallet::Address) -> Result<(), ClientError> {
268        self.subscribe(Subscription::Fills { user }).await
269    }
270
271    /// Subscribe to per-user order lifecycle updates.
272    ///
273    /// # Errors
274    /// See [`WsClient::subscribe`].
275    pub async fn subscribe_order_updates(
276        &self,
277        user: crate::wallet::Address,
278    ) -> Result<(), ClientError> {
279        self.subscribe(Subscription::OrderUpdates { user }).await
280    }
281
282    /// Subscribe to per-user account / margin events.
283    ///
284    /// # Errors
285    /// See [`WsClient::subscribe`].
286    pub async fn subscribe_user_events(
287        &self,
288        user: crate::wallet::Address,
289    ) -> Result<(), ClientError> {
290        self.subscribe(Subscription::UserEvents { user }).await
291    }
292
293    /// Subscribe to the per-user live account-state stream.
294    ///
295    /// # Errors
296    /// See [`WsClient::subscribe`].
297    pub async fn subscribe_account_state(
298        &self,
299        user: crate::wallet::Address,
300    ) -> Result<(), ClientError> {
301        self.subscribe(Subscription::AccountState { user }).await
302    }
303
304    /// Receive inbound channel frames.
305    ///
306    /// Each call returns a fresh [`broadcast::Receiver`] so multiple consumers
307    /// can subscribe to the same stream. Returns `None` once the task has
308    /// shut down.
309    #[must_use]
310    pub fn messages(&self) -> broadcast::Receiver<WsMessage> {
311        self.inbound_tx.subscribe()
312    }
313
314    // ---- `post` request/response (HL `post` method) ----
315
316    /// Issue a signed exchange action over the WebSocket `post` channel,
317    /// returning the node's action response payload.
318    ///
319    /// This is the WS analogue of [`crate::rest::exchange::Exchange::post_signed`]: the
320    /// action is signed with the SAME EIP-712 digest (recovered over the
321    /// compact JSON of the action object), wrapped as
322    /// `{"method":"post","id":N,"request":{"type":"action","payload":{signature,nonce,action}}}`,
323    /// and sent over the existing connection. The returned `Value` is the
324    /// `payload` of the node's `action` response (e.g. `{"accepted":true,…}`);
325    /// a malformed-request rejection surfaces as [`ClientError::WebSocket`].
326    ///
327    /// # Errors
328    /// - [`ClientError::Signature`] on signing failure.
329    /// - [`ClientError::WebSocket`] if the socket is down, the post times out,
330    ///   or the node returns a post-level error frame.
331    pub async fn post_action(&self, wallet: &Wallet, action: Value) -> Result<Value, ClientError> {
332        let (nonce, signature) = crate::rest::exchange::sign_action(wallet, &action)?;
333        let payload = json!({ "signature": signature, "nonce": nonce, "action": action });
334        self.post_request("action", payload).await
335    }
336
337    /// Issue a TRADING action (order / cancel / …) over the WS `post` channel,
338    /// signed under the typed scheme. The 12 trading actions migrated to the
339    /// typed scheme (the node rejects them under the opaque envelope), so the WS
340    /// `post` path carries `sig_scheme:"typed"` alongside the structured digest.
341    async fn post_typed_trade(
342        &self,
343        wallet: &Wallet,
344        action: Value,
345        typed: TypedTradingAction<'_>,
346    ) -> Result<Value, ClientError> {
347        let nonce = crate::rest::exchange::next_nonce();
348        let digest =
349            TypedTradingDigest::new(typed, crate::rest::exchange::MTF_CHAIN_ID, nonce).digest()?;
350        let signature = wallet.sign_digest(&digest)?.to_hex();
351        let payload = json!({
352            "signature": signature,
353            "nonce": nonce,
354            "action": action,
355            "sig_scheme": "typed",
356        });
357        self.post_request("action", payload).await
358    }
359
360    /// Issue an `info` read over the WebSocket `post` channel, returning the
361    /// info response payload.
362    ///
363    /// The WS analogue of a `POST /info` call: `payload` is the usual
364    /// `{"type":"<info>",…}` body. Lets a subscriber multiplex one-off reads
365    /// over the same socket instead of opening a REST connection.
366    ///
367    /// # Errors
368    /// [`ClientError::WebSocket`] if the socket is down, the post times out, or
369    /// the node returns a post-level error frame.
370    pub async fn post_info(&self, payload: Value) -> Result<Value, ClientError> {
371        self.post_request("info", payload).await
372    }
373
374    /// Submit a limit / market / trigger order over the WS `post` channel,
375    /// decoding the typed [`OrderResponse`].
376    ///
377    /// Convenience wrapper over [`Self::post_action`] mirroring
378    /// [`crate::rest::exchange::Exchange::submit_order`]. The order's `owner` MUST equal
379    /// the wallet address.
380    ///
381    /// # Errors
382    /// - [`ClientError::Validation`] if `order.owner != wallet.address()`.
383    /// - [`ClientError::Decode`] if the response payload is not an
384    ///   [`OrderResponse`].
385    /// - WebSocket / signature errors per [`Self::post_action`].
386    pub async fn submit_order(
387        &self,
388        wallet: &Wallet,
389        order: &Order,
390    ) -> Result<OrderResponse, ClientError> {
391        if order.owner != wallet.address() {
392            return Err(ClientError::Validation(format!(
393                "order.owner {} != wallet address {}",
394                order.owner,
395                wallet.address()
396            )));
397        }
398        let action = json!({ "type": "submit_order", "order": order });
399        let payload = self
400            .post_typed_trade(wallet, action, TypedTradingAction::SubmitOrder(order))
401            .await?;
402        Ok(serde_json::from_value(payload)?)
403    }
404
405    /// Cancel an order over the WS `post` channel.
406    ///
407    /// Convenience wrapper over [`Self::post_action`] mirroring
408    /// [`crate::rest::exchange::Exchange::cancel_order`].
409    ///
410    /// # Errors
411    /// - [`ClientError::Validation`] if `cancel.owner != wallet.address()`.
412    /// - WebSocket / signature errors per [`Self::post_action`].
413    pub async fn cancel_order(
414        &self,
415        wallet: &Wallet,
416        cancel: &CancelOrder,
417    ) -> Result<Value, ClientError> {
418        if cancel.owner != wallet.address() {
419            return Err(ClientError::Validation(format!(
420                "cancel.owner {} != wallet address {}",
421                cancel.owner,
422                wallet.address()
423            )));
424        }
425        let action = json!({ "type": "cancel_order", "cancel": cancel });
426        self.post_typed_trade(wallet, action, TypedTradingAction::CancelOrder(cancel))
427            .await
428    }
429
430    /// Core `post` machinery: assign a correlation id, ship the frame to the
431    /// background task, and await the matching response. Maps a node
432    /// `{"type":"error",…}` response to [`ClientError::WebSocket`]; returns the
433    /// inner `payload` on success.
434    async fn post_request(&self, request_type: &str, payload: Value) -> Result<Value, ClientError> {
435        let id = self.post_id.fetch_add(1, Ordering::Relaxed);
436        let frame = json!({
437            "method": "post",
438            "id": id,
439            "request": { "type": request_type, "payload": payload },
440        })
441        .to_string();
442
443        let (reply_tx, reply_rx) = oneshot::channel();
444        self.cmd_tx
445            .send(Command::Post {
446                id,
447                frame,
448                reply: reply_tx,
449            })
450            .map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
451
452        let response = match tokio::time::timeout(self.post_timeout, reply_rx).await {
453            Ok(Ok(resp)) => resp,
454            // Sender dropped => the connection cycled before the response
455            // arrived. A signed action is one-shot, so we surface the failure
456            // rather than silently retrying (which could double-submit).
457            Ok(Err(_)) => {
458                return Err(ClientError::WebSocket(
459                    "ws post: connection closed before response".into(),
460                ));
461            }
462            Err(_) => {
463                // We gave up waiting; tell the background task to evict the
464                // pending entry so it can't leak on a long-lived connection.
465                // Best-effort: if the task is gone the entry dies with it.
466                let _ = self.cmd_tx.send(Command::CancelPost { id });
467                return Err(ClientError::WebSocket("ws post: timed out".into()));
468            }
469        };
470
471        // The node wraps every reply as `{type, payload}`; an error reply
472        // carries the message as a string payload.
473        if response.get("type").and_then(Value::as_str) == Some("error") {
474            let msg = response
475                .get("payload")
476                .and_then(Value::as_str)
477                .unwrap_or("unknown post error");
478            return Err(ClientError::WebSocket(format!("ws post error: {msg}")));
479        }
480        Ok(response.get("payload").cloned().unwrap_or(Value::Null))
481    }
482
483    /// True if the background reconnect task is still running.
484    #[must_use]
485    pub fn is_alive(&self) -> bool {
486        self.alive.load(Ordering::Acquire)
487    }
488
489    /// Initiate a graceful shutdown of the background task. Subsequent
490    /// `subscribe` calls will fail.
491    pub async fn shutdown(&self) {
492        let _ = self.cmd_tx.send(Command::Shutdown);
493        self.alive.store(false, Ordering::Release);
494    }
495}
496
497/// Internal task state.
498struct TaskState {
499    url: String,
500    config: WsConfig,
501    inbound_tx: broadcast::Sender<WsMessage>,
502    cmd_rx: mpsc::UnboundedReceiver<Command>,
503    alive: Arc<AtomicBool>,
504    active: Arc<Mutex<Vec<Subscription>>>,
505}
506
507/// The reconnect-with-backoff loop.
508async fn run_background(mut state: TaskState) {
509    let mut backoff = state.config.initial_backoff;
510    loop {
511        match run_connection(&mut state).await {
512            Ok(ConnectionExit::Shutdown) => break,
513            Ok(ConnectionExit::Recoverable) | Err(_) => {
514                tokio::time::sleep(backoff).await;
515                backoff = (backoff * 2).min(state.config.max_backoff);
516                // continue loop -> reconnect
517            }
518        }
519    }
520    state.alive.store(false, Ordering::Release);
521}
522
/// Why a single connection's run ended without an error.
#[derive(Debug)]
enum ConnectionExit {
    /// Shutdown was requested; the caller must not reconnect.
    Shutdown,
    /// The connection dropped or errored; the caller reconnects with backoff.
    Recoverable,
}
531
532async fn run_connection(state: &mut TaskState) -> Result<ConnectionExit, ClientError> {
533    let (stream, _) = tokio_tungstenite::connect_async(&state.url).await?;
534    let (mut sink, mut stream) = stream.split();
535
536    // Replay active subscriptions on (re)connect.
537    {
538        let subs = state.active.lock().await.clone();
539        for sub in &subs {
540            let frame = json!({"method": "subscribe", "subscription": sub});
541            sink.send(Message::Text(frame.to_string())).await?;
542        }
543    }
544
545    // In-flight `post` requests for this connection, keyed by correlation id.
546    // Dropped (with all reply senders) when the connection exits, so any
547    // caller awaiting a response on a dead socket unblocks with an error.
548    let mut pending: HashMap<u64, oneshot::Sender<Value>> = HashMap::new();
549
550    let mut ping_tick = tokio::time::interval(state.config.ping_interval);
551    ping_tick.tick().await; // consume the immediate first tick
552
553    loop {
554        tokio::select! {
555            cmd = state.cmd_rx.recv() => {
556                match cmd {
557                    Some(Command::Subscribe(sub)) => {
558                        let frame = json!({"method": "subscribe", "subscription": sub});
559                        sink.send(Message::Text(frame.to_string())).await?;
560                    }
561                    Some(Command::Unsubscribe(sub)) => {
562                        let frame = json!({"method": "unsubscribe", "subscription": sub});
563                        sink.send(Message::Text(frame.to_string())).await?;
564                    }
565                    Some(Command::Post { id, frame, reply }) => {
566                        // Send first; only track the reply once the frame is on
567                        // the wire. A send failure propagates `Err` out of
568                        // `run_connection` (which `run_background` treats as a
569                        // recoverable reconnect) and drops `reply`, surfacing a
570                        // disconnect to the caller.
571                        sink.send(Message::Text(frame)).await?;
572                        pending.insert(id, reply);
573                    }
574                    Some(Command::CancelPost { id }) => {
575                        // Caller timed out; drop the dangling reply sender.
576                        pending.remove(&id);
577                    }
578                    Some(Command::Shutdown) | None => {
579                        let _ = sink.send(Message::Close(None)).await;
580                        return Ok(ConnectionExit::Shutdown);
581                    }
582                }
583            }
584            _ = ping_tick.tick() => {
585                let ping = json!({"method": "ping"});
586                if sink.send(Message::Text(ping.to_string())).await.is_err() {
587                    return Ok(ConnectionExit::Recoverable);
588                }
589            }
590            frame = stream.next() => {
591                let Some(frame) = frame else {
592                    return Ok(ConnectionExit::Recoverable);
593                };
594                match frame {
595                    Ok(Message::Text(text)) => {
596                        // A `{channel:"post"}` frame correlates by id back to the
597                        // waiting caller; every other frame is a channel update
598                        // for the broadcast.
599                        match serde_json::from_str::<Value>(&text) {
600                            Ok(v)
601                                if v.get("channel").and_then(Value::as_str) == Some("post") =>
602                            {
603                                if let Some(id) =
604                                    v.pointer("/data/id").and_then(Value::as_u64)
605                                {
606                                    if let Some(reply) = pending.remove(&id) {
607                                        let resp = v
608                                            .pointer("/data/response")
609                                            .cloned()
610                                            .unwrap_or(Value::Null);
611                                        let _ = reply.send(resp);
612                                    }
613                                }
614                            }
615                            Ok(v) => {
616                                // Unknown / future channels (and any frame whose
617                                // `data` we can't type) fall back to `Unknown`
618                                // instead of being dropped, so a forward-compat
619                                // consumer still sees that a frame arrived.
620                                let msg = serde_json::from_value::<WsMessage>(v)
621                                    .unwrap_or(WsMessage::Unknown);
622                                let _ = state.inbound_tx.send(msg);
623                            }
624                            Err(_) => {}
625                        }
626                    }
627                    Ok(Message::Binary(_) | Message::Pong(_) | Message::Ping(_)) => {
628                        // Ignore non-text control frames; tungstenite handles
629                        // pong automatically for ping.
630                    }
631                    Ok(Message::Close(_)) => {
632                        return Ok(ConnectionExit::Recoverable);
633                    }
634                    Ok(Message::Frame(_)) => {
635                        // Raw frame — ignore.
636                    }
637                    Err(_) => return Ok(ConnectionExit::Recoverable),
638                }
639            }
640        }
641    }
642}
643
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default documented on `WsConfig`, so a silent change to
    /// `WsConfig::default` fails the test suite.
    #[test]
    fn ws_config_default_values() {
        let c = WsConfig::default();
        assert_eq!(c.ping_interval, Duration::from_secs(30));
        assert_eq!(c.initial_backoff, Duration::from_millis(250));
        assert_eq!(c.max_backoff, Duration::from_secs(30));
        assert_eq!(c.channel_capacity, 1024);
        assert_eq!(c.post_timeout, Duration::from_secs(10));
    }
}