Skip to main content

bullet_rust_sdk/ws/
client.rs

1//! WebSocket client for real-time market data and order updates.
2//!
3//! This module provides a WebSocket client for connecting to the trading API's
4//! real-time data streams.
5//!
6//! # Features
7//!
8//! - **Protocol-level keepalive**: The server handles keepalive via WebSocket protocol-level
9//!   ping/pong frames (managed automatically by the transport).
10//! - **Cross-platform**: Works on both native Rust and WASM targets.
11//! - **Graceful error handling**: Parse failures return `ServerMessage::Unknown` with the error and
12//!   raw message text for debugging.
13//!
14//! # Example
15//!
16//! ```no_run
17//! use bullet_rust_sdk::Client;
18//! use bullet_rust_sdk::errors::WSErrors;
19//! use bullet_rust_sdk::types::ClientMessage;
20//!
21//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
22//! let api = Client::mainnet().await?;
23//!
24//! 'reconnect: loop {
25//!     let mut ws = api.connect_ws().call().await?;
26//!
27//!     ws.send(ClientMessage::Subscribe {
28//!         id: Some(1.into()),
29//!         params: vec!["BTC-USD@aggTrade".to_string()],
30//!     })
31//!     .await?;
32//!
33//!     loop {
34//!         match ws.recv().await {
35//!             Ok(msg) => println!("Received: {:?}", msg),
36//!             Err(WSErrors::WsClosed { code, reason }) => {
37//!                 eprintln!("Closed ({code:?}): {reason}");
38//!                 continue 'reconnect;
39//!             }
40//!             Err(WSErrors::WsStreamEnded) => {
41//!                 eprintln!("Connection lost");
42//!                 continue 'reconnect;
43//!             }
44//!             Err(e) => return Err(e.into()),
45//!         }
46//!     }
47//! }
48//! # }
49//! ```
50//!
51//! # Keepalive Behavior
52//!
53//! The server handles keepalive automatically using WebSocket protocol-level
54//! ping/pong frames. No application-level pings are needed.
55
56use std::ops::Deref;
57
58use bon::{Builder, bon};
59use futures::{FutureExt, SinkExt, StreamExt, select};
60use futures_timer::Delay;
61use tracing::{debug, warn};
62use web_time::Duration;
63
64use super::models::{ServerMessage, TaggedMessage};
65use super::topics::Topic;
66use crate::errors::WSErrors;
67use crate::types::{ClientMessage, OrderParams, RequestId};
68use crate::{Client, SDKResult};
69
/// Default connection timeout in seconds.
///
/// Used as the default value of [`WebsocketConfig::connection_timeout`] and as
/// the fallback in `wait_for_connected` when the caller-supplied duration cannot
/// be converted into a `std::time::Duration`.
const DEFAULT_CONNECTION_TIMEOUT_SECS: u64 = 10;
72
/// Handle to an active WebSocket connection.
///
/// Provides methods to send messages and receive responses.
///
/// # Thread Safety
///
/// `WebsocketHandle` is `Send` but **not `Sync`**. The underlying WebSocket
/// transport contains non-thread-safe internal buffers.
///
/// If you need to share the handle across async tasks, wrap it in a
/// [`tokio::sync::Mutex`]:
///
/// ```ignore
/// use std::sync::Arc;
/// use tokio::sync::Mutex;
///
/// let ws = Arc::new(Mutex::new(client.connect_ws().call().await?));
///
/// // Receiving task
/// let ws_recv = ws.clone();
/// tokio::spawn(async move {
///     loop {
///         let msg = ws_recv.lock().await.recv().await;
///         // handle msg ...
///     }
/// });
///
/// // Sending task
/// ws.lock().await.subscribe([Topic::agg_trade("BTC-USD")], None).await?;
/// ```
///
/// For high-throughput bots, a common pattern is to dedicate one task to the
/// WebSocket and use [`tokio::sync::mpsc`] channels to communicate with other
/// tasks — this avoids lock contention on the hot path.
///
/// # Extracting the Inner Socket
///
/// If you need direct access to the underlying `reqwest_websocket::WebSocket`,
/// use the `Deref` implementation. Note that `Deref` hands out a shared (`&`)
/// reference only; sending and receiving go through the `&mut self` methods on
/// this handle.
pub struct WebsocketHandle {
    /// The wrapped transport; `send` writes to it and `recv` reads from it.
    socket: reqwest_websocket::WebSocket,
}
115
116impl WebsocketHandle {
117    /// Connect to a WebSocket endpoint and wait for the server's handshake.
118    ///
119    /// Shared by `Client::connect_ws` and `ManagedWsClient::connect`.
120    pub(crate) async fn connect(
121        ws_client: &reqwest::Client,
122        ws_url: &str,
123        timeout: web_time::Duration,
124    ) -> SDKResult<Self, WSErrors> {
125        use reqwest_websocket::Upgrade;
126
127        let response: reqwest_websocket::UpgradeResponse =
128            ws_client.clone().get(ws_url).upgrade().send().await?;
129        let websocket = response.into_websocket().await?;
130        let mut handle = Self { socket: websocket };
131        handle.wait_for_connected(timeout).await?;
132        Ok(handle)
133    }
134}
135
/// Configuration for WebSocket connection behavior.
///
/// Construct it with a struct literal (as below), through the generated
/// `WebsocketConfig::builder()`, or with `WebsocketConfig::default()`, which
/// applies the documented defaults.
///
/// # Example
///
/// ```no_run
/// use bullet_rust_sdk::Client;
/// use bullet_rust_sdk::ws::WebsocketConfig;
/// use web_time::Duration;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let api = Client::mainnet().await?;
///
/// // Use a longer connection timeout
/// let config = WebsocketConfig { connection_timeout: Duration::from_secs(30) };
/// let mut ws = api.connect_ws().config(config).call().await?;
/// # Ok(())
/// # }
/// ```
#[derive(Builder, Clone, Debug)]
pub struct WebsocketConfig {
    /// How long to wait for the server's "connected" message during handshake.
    ///
    /// Default: 10 seconds
    #[builder(default = Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS))]
    pub connection_timeout: Duration,
}
162
163impl Default for WebsocketConfig {
164    fn default() -> Self {
165        Self::builder().build()
166    }
167}
168
169impl Deref for WebsocketHandle {
170    type Target = reqwest_websocket::WebSocket;
171
172    fn deref(&self) -> &Self::Target {
173        &self.socket
174    }
175}
176
177#[bon]
178impl Client {
179    /// Open a raw WebSocket connection.
180    ///
181    /// For production bots, prefer [`connect_ws_managed`](Client::connect_ws_managed)
182    /// which handles reconnection automatically.
183    #[builder]
184    pub async fn connect_ws(
185        &self,
186        config: Option<WebsocketConfig>,
187    ) -> SDKResult<WebsocketHandle, WSErrors> {
188        let config = config.unwrap_or_default();
189        WebsocketHandle::connect(&self.ws_client, self.ws_url(), config.connection_timeout).await
190    }
191}
192
193impl WebsocketHandle {
194    /// Wait for the server's "connected" status message.
195    ///
196    /// Called automatically during connection. Times out if no message received
197    /// within the specified timeout.
198    async fn wait_for_connected(&mut self, timeout: Duration) -> SDKResult<(), WSErrors> {
199        // Note: web_time::Duration is std::time::Duration on native, but different on WASM.
200        // The try_into() is needed for WASM compatibility.
201        #[allow(clippy::useless_conversion)]
202        let timeout = Delay::new(
203            timeout
204                .try_into()
205                .unwrap_or(std::time::Duration::from_secs(DEFAULT_CONNECTION_TIMEOUT_SECS)),
206        );
207
208        debug!("Waiting for connected message from websocket.");
209
210        select! {
211            result = self.recv().fuse() => {
212                match result? {
213                    ServerMessage::Tagged(TaggedMessage::Status(status))
214                        if status.status == "connected" =>
215                    {
216                        debug!("Successfully got connected message, continuing");
217                        Ok(())
218                    }
219                    other => Err(WSErrors::WsHandshakeFailed(format!("{other:?}"))),
220                }
221            }
222            _ = timeout.fuse() => {
223                Err(WSErrors::WsConnectionTimeout)
224            }
225        }
226    }
227
228    /// Send a message to the server.
229    ///
230    /// # Available Message Types
231    ///
232    /// - `ClientMessage::Subscribe` - Subscribe to market data streams
233    /// - `ClientMessage::Unsubscribe` - Unsubscribe from streams
234    /// - `ClientMessage::ListSubscriptions` - List active subscriptions
235    /// - `ClientMessage::Ping` - Manual ping (not needed for keepalive)
236    /// - `ClientMessage::OrderPlace` - Place an order
237    /// - `ClientMessage::OrderCancel` - Cancel an order
238    ///
239    /// # Example
240    ///
241    /// ```no_run
242    /// use bullet_rust_sdk::types::ClientMessage;
243    /// # use bullet_rust_sdk::Client;
244    ///
245    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
246    /// # let api = Client::mainnet().await?;
247    /// # let mut ws = api.connect_ws().call().await?;
248    /// // Subscribe to aggregated trades
249    /// ws.send(ClientMessage::Subscribe {
250    ///     id: Some(1.into()),
251    ///     params: vec!["BTC-USD@aggTrade".to_string()],
252    /// })
253    /// .await?;
254    ///
255    /// // Unsubscribe later
256    /// ws.send(ClientMessage::Unsubscribe {
257    ///     id: Some(2.into()),
258    ///     params: vec!["BTC-USD@aggTrade".to_string()],
259    /// })
260    /// .await?;
261    /// # Ok(())
262    /// # }
263    /// ```
264    pub async fn send(&mut self, msg: ClientMessage) -> SDKResult<(), WSErrors> {
265        let string_msg = serde_json::to_string(&msg)?;
266        self.socket.send(reqwest_websocket::Message::Text(string_msg)).await?;
267        Ok(())
268    }
269
270    /// Receive the next message from the server.
271    ///
272    /// # Errors
273    ///
274    /// - [`WSErrors::WsClosed`] - Server closed the connection (includes close code and reason)
275    /// - [`WSErrors::WsStreamEnded`] - Connection ended unexpectedly without a close frame
276    /// - [`WSErrors::WsUpgradeError`] - WebSocket protocol error
277    ///
278    /// # Parse Errors
279    ///
280    /// If a message cannot be parsed into a known [`ServerMessage`] variant,
281    /// it returns `ServerMessage::Unknown(error, raw_text)` instead of failing.
282    /// This allows you to log or debug unexpected message formats.
283    ///
284    /// # Example
285    ///
286    /// ```no_run
287    /// use bullet_rust_sdk::errors::WSErrors;
288    /// use bullet_rust_sdk::types::ClientMessage;
289    /// use bullet_rust_sdk::ws::models::{ServerMessage, TaggedMessage};
290    /// # use bullet_rust_sdk::Client;
291    ///
292    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
293    /// let api = Client::mainnet().await?;
294    ///
295    /// 'reconnect: loop {
296    ///     let mut ws = api.connect_ws().call().await?;
297    ///
298    ///     ws.send(ClientMessage::Subscribe {
299    ///         id: Some(1.into()),
300    ///         params: vec!["BTC-USD@aggTrade".to_string()],
301    ///     })
302    ///     .await?;
303    ///
304    ///     loop {
305    ///         match ws.recv().await {
306    ///             Ok(msg) => match msg {
307    ///                 ServerMessage::AggTrade(trade) => {
308    ///                     println!("Trade: {} @ {}", trade.symbol, trade.price);
309    ///                 }
310    ///                 ServerMessage::Tagged(TaggedMessage::Pong(_)) => {}
311    ///                 ServerMessage::Tagged(TaggedMessage::Error(err)) => {
312    ///                     eprintln!("Server error: {:?}", err);
313    ///                 }
314    ///                 _ => {}
315    ///             },
316    ///             Err(WSErrors::WsClosed { code, reason }) => {
317    ///                 eprintln!("Connection closed (code {:?}): {}", code, reason);
318    ///                 continue 'reconnect;
319    ///             }
320    ///             Err(WSErrors::WsStreamEnded) => {
321    ///                 eprintln!("Connection lost unexpectedly");
322    ///                 continue 'reconnect;
323    ///             }
324    ///             Err(e) => return Err(e.into()),
325    ///         }
326    ///     }
327    /// }
328    /// # }
329    /// ```
330    pub async fn recv(&mut self) -> SDKResult<ServerMessage, WSErrors> {
331        while let Some(msg) = self.socket.next().await {
332            let msg = msg?;
333
334            match msg {
335                reqwest_websocket::Message::Text(text) => {
336                    let server_msg = match serde_json::from_str::<ServerMessage>(&text) {
337                        Ok(v) => v,
338                        Err(e) => {
339                            warn!(?e, "Failed to parse ServerMessage, returning Unknown");
340                            ServerMessage::Unknown(e.to_string(), text)
341                        }
342                    };
343                    return Ok(server_msg);
344                }
345                reqwest_websocket::Message::Binary(data) => {
346                    let text = String::from_utf8_lossy(&data).to_string();
347                    let server_msg = match serde_json::from_slice::<ServerMessage>(&data) {
348                        Ok(v) => v,
349                        Err(e) => {
350                            warn!(?e, "Failed to parse ServerMessage, returning Unknown");
351                            ServerMessage::Unknown(e.to_string(), text)
352                        }
353                    };
354                    return Ok(server_msg);
355                }
356                reqwest_websocket::Message::Close { code, reason } => {
357                    return Err(WSErrors::WsClosed { code, reason });
358                }
359                _ => continue,
360            }
361        }
362
363        Err(WSErrors::WsStreamEnded)
364    }
365
366    /// Subscribe to one or more topics.
367    ///
368    /// # Arguments
369    ///
370    /// * `topics` - Topics to subscribe to
371    /// * `id` - Optional request ID for matching the server's response
372    ///
373    /// # Example
374    ///
375    /// ```no_run
376    /// use bullet_rust_sdk::types::RequestId;
377    /// use bullet_rust_sdk::{Client, OrderbookDepth, Topic};
378    ///
379    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
380    /// let api = Client::mainnet().await?;
381    /// let mut ws = api.connect_ws().call().await?;
382    ///
383    /// // Subscribe to multiple topics using type-safe builders
384    /// ws.subscribe(
385    ///     [
386    ///         Topic::agg_trade("BTC-USD"),
387    ///         Topic::depth("ETH-USD", OrderbookDepth::D10),
388    ///         Topic::book_ticker("SOL-USD"),
389    ///     ],
390    ///     Some(RequestId::new(1)),
391    /// )
392    /// .await?;
393    ///
394    /// // Now receive market data
395    /// loop {
396    ///     let msg = ws.recv().await?;
397    ///     println!("{:?}", msg);
398    /// }
399    /// # }
400    /// ```
401    pub async fn subscribe(
402        &mut self,
403        topics: impl IntoIterator<Item = Topic>,
404        id: Option<RequestId>,
405    ) -> SDKResult<(), WSErrors> {
406        self.send(ClientMessage::Subscribe {
407            id,
408            params: topics.into_iter().map(|t| t.to_string()).collect(),
409        })
410        .await
411    }
412
413    /// Unsubscribe from one or more topics.
414    ///
415    /// Unsubscribe is idempotent - unsubscribing from topics you're not
416    /// subscribed to will still succeed.
417    ///
418    /// # Arguments
419    ///
420    /// * `topics` - Topics to unsubscribe from
421    /// * `id` - Optional request ID for matching the server's response
422    ///
423    /// # Example
424    ///
425    /// ```no_run
426    /// use bullet_rust_sdk::Client;
427    /// use bullet_rust_sdk::types::RequestId;
428    ///
429    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
430    /// let api = Client::mainnet().await?;
431    /// let mut ws = api.connect_ws().call().await?;
432    ///
433    /// ws.list_subscriptions(Some(RequestId::new(1))).await?;
434    /// // Match response by request_id
435    /// # Ok(())
436    /// # }
437    /// ```
438    pub async fn list_subscriptions(&mut self, id: Option<RequestId>) -> SDKResult<(), WSErrors> {
439        self.send(ClientMessage::ListSubscriptions { id }).await
440    }
441
442    /// Place an order via WebSocket.
443    ///
444    /// # Arguments
445    ///
446    /// * `tx` - Base64-encoded raw transaction bytes
447    /// * `id` - Optional request ID for matching the server's response
448    ///
449    /// # Example
450    ///
451    /// ```no_run
452    /// use bullet_rust_sdk::Client;
453    /// use bullet_rust_sdk::types::RequestId;
454    ///
455    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
456    /// let api = Client::mainnet().await?;
457    /// let mut ws = api.connect_ws().call().await?;
458    ///
459    /// let tx_bytes = "base64_encoded_transaction";
460    /// ws.order_place(tx_bytes, Some(RequestId::new(1))).await?;
461    /// // Match response by request_id
462    /// # Ok(())
463    /// # }
464    /// ```
465    pub async fn order_place(
466        &mut self,
467        tx: impl Into<String>,
468        id: Option<RequestId>,
469    ) -> SDKResult<(), WSErrors> {
470        self.send(ClientMessage::OrderPlace { id, params: OrderParams { tx: tx.into() } }).await
471    }
472
473    /// Cancel an order via WebSocket.
474    ///
475    /// # Arguments
476    ///
477    /// * `tx` - Base64-encoded raw transaction bytes
478    /// * `id` - Optional request ID for matching the server's response
479    ///
480    /// # Example
481    ///
482    /// ```no_run
483    /// use bullet_rust_sdk::Client;
484    /// use bullet_rust_sdk::types::RequestId;
485    ///
486    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
487    /// let api = Client::mainnet().await?;
488    /// let mut ws = api.connect_ws().call().await?;
489    ///
490    /// let tx_bytes = "base64_encoded_cancel_transaction";
491    /// ws.order_cancel(tx_bytes, Some(RequestId::new(1))).await?;
492    /// // Match response by request_id
493    /// # Ok(())
494    /// # }
495    /// ```
496    pub async fn order_cancel(
497        &mut self,
498        tx: impl Into<String>,
499        id: Option<RequestId>,
500    ) -> SDKResult<(), WSErrors> {
501        self.send(ClientMessage::OrderCancel { id, params: OrderParams { tx: tx.into() } }).await
502    }
503
504    /// Place an order via WebSocket using a signed transaction.
505    ///
506    /// This is a convenience wrapper around [`order_place`](Self::order_place) that
507    /// handles base64 encoding internally.
508    ///
509    /// # Example
510    ///
511    /// ```ignore
512    /// use bullet_rust_sdk::{Client, Transaction};
513    ///
514    /// let signed = Transaction::builder()
515    ///     .call_message(call_msg)
516    ///     .client(&client)
517    ///     .build()?;
518    ///
519    /// ws.place_order(&signed, None).await?;
520    /// ```
521    pub async fn place_order(
522        &mut self,
523        signed: &bullet_exchange_interface::transaction::Transaction,
524        id: Option<RequestId>,
525    ) -> SDKResult<(), WSErrors> {
526        let base64 =
527            crate::Transaction::to_base64(signed).map_err(|e| WSErrors::WsError(e.to_string()))?;
528        self.order_place(base64, id).await
529    }
530
531    /// Cancel an order via WebSocket using a signed transaction.
532    ///
533    /// This is a convenience wrapper around [`order_cancel`](Self::order_cancel) that
534    /// handles base64 encoding internally.
535    ///
536    /// # Example
537    ///
538    /// ```ignore
539    /// use bullet_rust_sdk::{Client, Transaction};
540    ///
541    /// let signed = Transaction::builder()
542    ///     .call_message(cancel_msg)
543    ///     .client(&client)
544    ///     .build()?;
545    ///
546    /// ws.cancel_order(&signed, None).await?;
547    /// ```
548    pub async fn cancel_order(
549        &mut self,
550        signed: &bullet_exchange_interface::transaction::Transaction,
551        id: Option<RequestId>,
552    ) -> SDKResult<(), WSErrors> {
553        let base64 =
554            crate::Transaction::to_base64(signed).map_err(|e| WSErrors::WsError(e.to_string()))?;
555        self.order_cancel(base64, id).await
556    }
557}