Skip to main content

bybit_client/ws/
stream.rs

//! Async Stream interface for WebSocket messages.
//!
//! This module provides a Stream-based interface for WebSocket messages,
//! enabling use of stream combinators like `.filter()`, `.map()`, etc.
//!
//! # Example
//!
//! ```no_run
//! use std::pin::pin;
//! use bybit_client::ws::{WsClient, WsChannel, WsStream};
//! use futures::StreamExt;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let (client, receiver) = WsClient::connect_public(WsChannel::PublicLinear).await?;
//!     client.subscribe(&["publicTrade.BTCUSDT", "tickers.BTCUSDT"]).await?;
//!
//!     // Create a stream from the receiver.
//!     let stream = WsStream::new(receiver);
//!
//!     // Use stream combinators to filter trade messages.
//!     // `StreamExt::next` requires an `Unpin` stream; pinning the filtered
//!     // stream with `pin!` satisfies that whatever the adapter type is.
//!     let mut trades = pin!(stream.trades());
//!
//!     while let Some(trade) = trades.next().await {
//!         for t in &trade.data {
//!             println!("Trade: {} {} @ {}", t.symbol, t.size, t.price);
//!         }
//!     }
//!
//!     Ok(())
//! }
//! ```
34
35use std::pin::Pin;
36use std::task::{Context, Poll};
37
38use futures::Stream;
39use tokio::sync::mpsc;
40use tokio_stream::wrappers::UnboundedReceiverStream;
41
42use crate::ws::types::*;
43
44/// A Stream wrapper for WebSocket messages.
45///
46/// This struct wraps an `mpsc::UnboundedReceiver<WsMessage>` and implements
47/// the `Stream` trait, enabling use of stream combinators.
48pub struct WsStream {
49    inner: UnboundedReceiverStream<WsMessage>,
50}
51
52impl WsStream {
53    /// Create a new WsStream from an unbounded receiver.
54    pub fn new(receiver: mpsc::UnboundedReceiver<WsMessage>) -> Self {
55        Self {
56            inner: UnboundedReceiverStream::new(receiver),
57        }
58    }
59
60    /// Create a filtered stream that only yields orderbook messages.
61    pub fn orderbooks(self) -> impl Stream<Item = Box<WsStreamMessage<OrderbookData>>> {
62        futures::StreamExt::filter_map(self, |msg| async {
63            match msg {
64                WsMessage::Orderbook(ob) => Some(ob),
65                _ => None,
66            }
67        })
68    }
69
70    /// Create a filtered stream that only yields trade messages.
71    pub fn trades(self) -> impl Stream<Item = Box<WsStreamMessage<Vec<TradeData>>>> {
72        futures::StreamExt::filter_map(self, |msg| async {
73            match msg {
74                WsMessage::Trade(t) => Some(t),
75                _ => None,
76            }
77        })
78    }
79
80    /// Create a filtered stream that only yields ticker messages.
81    pub fn tickers(self) -> impl Stream<Item = Box<WsStreamMessage<TickerData>>> {
82        futures::StreamExt::filter_map(self, |msg| async {
83            match msg {
84                WsMessage::Ticker(t) => Some(t),
85                _ => None,
86            }
87        })
88    }
89
90    /// Create a filtered stream that only yields kline messages.
91    pub fn klines(self) -> impl Stream<Item = Box<WsStreamMessage<Vec<KlineData>>>> {
92        futures::StreamExt::filter_map(self, |msg| async {
93            match msg {
94                WsMessage::Kline(k) => Some(k),
95                _ => None,
96            }
97        })
98    }
99
100    /// Create a filtered stream that only yields liquidation messages.
101    pub fn liquidations(self) -> impl Stream<Item = Box<WsStreamMessage<LiquidationData>>> {
102        futures::StreamExt::filter_map(self, |msg| async {
103            match msg {
104                WsMessage::Liquidation(l) => Some(l),
105                _ => None,
106            }
107        })
108    }
109
110    /// Create a filtered stream that only yields operation responses.
111    pub fn operation_responses(self) -> impl Stream<Item = WsOperationResponse> {
112        futures::StreamExt::filter_map(self, |msg| async {
113            match msg {
114                WsMessage::OperationResponse(r) => Some(r),
115                _ => None,
116            }
117        })
118    }
119
120    /// Create a filtered stream that only yields position updates (private stream).
121    pub fn positions(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<PositionData>>>> {
122        futures::StreamExt::filter_map(self, |msg| async {
123            match msg {
124                WsMessage::Position(p) => Some(p),
125                _ => None,
126            }
127        })
128    }
129
130    /// Create a filtered stream that only yields order updates (private stream).
131    pub fn orders(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<OrderData>>>> {
132        futures::StreamExt::filter_map(self, |msg| async {
133            match msg {
134                WsMessage::Order(o) => Some(o),
135                _ => None,
136            }
137        })
138    }
139
140    /// Create a filtered stream that only yields execution updates (private stream).
141    pub fn executions(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<ExecutionData>>>> {
142        futures::StreamExt::filter_map(self, |msg| async {
143            match msg {
144                WsMessage::Execution(e) => Some(e),
145                _ => None,
146            }
147        })
148    }
149
150    /// Create a filtered stream that only yields fast execution updates (private stream).
151    pub fn executions_fast(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<ExecutionFastData>>>> {
152        futures::StreamExt::filter_map(self, |msg| async {
153            match msg {
154                WsMessage::ExecutionFast(e) => Some(e),
155                _ => None,
156            }
157        })
158    }
159
160    /// Create a filtered stream that only yields wallet updates (private stream).
161    pub fn wallets(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<WalletData>>>> {
162        futures::StreamExt::filter_map(self, |msg| async {
163            match msg {
164                WsMessage::Wallet(w) => Some(w),
165                _ => None,
166            }
167        })
168    }
169
170    /// Create a filtered stream that only yields greeks updates (private stream, options).
171    pub fn greeks(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<GreeksData>>>> {
172        futures::StreamExt::filter_map(self, |msg| async {
173            match msg {
174                WsMessage::Greeks(g) => Some(g),
175                _ => None,
176            }
177        })
178    }
179
180    /// Filter messages for a specific symbol.
181    ///
182    /// This filters orderbook, trade, ticker, kline, and liquidation messages
183    /// that match the given symbol.
184    pub fn for_symbol(self, symbol: impl Into<String>) -> impl Stream<Item = WsMessage> {
185        let symbol = symbol.into();
186        futures::StreamExt::filter(self, move |msg| {
187            let matches = match msg {
188                WsMessage::Orderbook(ob) => ob.data.symbol == symbol,
189                WsMessage::Trade(t) => t.data.first().map(|d| d.symbol == symbol).unwrap_or(false),
190                WsMessage::Ticker(t) => t.data.symbol == symbol,
191                WsMessage::Kline(k) => k
192                    .topic
193                    .split('.')
194                    .last()
195                    .map(|s| s == symbol)
196                    .unwrap_or(false),
197                WsMessage::Liquidation(l) => l.data.symbol == symbol,
198                _ => true, // Pass through other message types
199            };
200            std::future::ready(matches)
201        })
202    }
203
204    /// Filter messages by topic prefix.
205    ///
206    /// This filters stream messages that have a topic starting with the given prefix.
207    pub fn for_topic_prefix(self, prefix: impl Into<String>) -> impl Stream<Item = WsMessage> {
208        let prefix = prefix.into();
209        futures::StreamExt::filter(self, move |msg| {
210            let matches = match msg {
211                WsMessage::Orderbook(ob) => ob.topic.starts_with(&prefix),
212                WsMessage::Trade(t) => t.topic.starts_with(&prefix),
213                WsMessage::Ticker(t) => t.topic.starts_with(&prefix),
214                WsMessage::Kline(k) => k.topic.starts_with(&prefix),
215                WsMessage::Liquidation(l) => l.topic.starts_with(&prefix),
216                _ => true, // Pass through other message types
217            };
218            std::future::ready(matches)
219        })
220    }
221}
222
223impl Stream for WsStream {
224    type Item = WsMessage;
225
226    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227        Pin::new(&mut self.inner).poll_next(cx)
228    }
229
230    fn size_hint(&self) -> (usize, Option<usize>) {
231        self.inner.size_hint()
232    }
233}
234
235/// Extension trait for converting receivers to streams.
236pub trait IntoWsStream {
237    /// Convert into a WsStream.
238    fn into_stream(self) -> WsStream;
239}
240
241impl IntoWsStream for mpsc::UnboundedReceiver<WsMessage> {
242    fn into_stream(self) -> WsStream {
243        WsStream::new(self)
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    fn make_trade_message(symbol: &str) -> WsMessage {
        WsMessage::Trade(Box::new(WsStreamMessage {
            topic: format!("publicTrade.{}", symbol),
            update_type: "snapshot".to_string(),
            ts: 1234567890000,
            data: vec![TradeData {
                timestamp: 1234567890000,
                symbol: symbol.to_string(),
                side: "Buy".to_string(),
                size: "0.1".to_string(),
                price: "50000".to_string(),
                tick_direction: "ZeroPlusTick".to_string(),
                trade_id: "test-123".to_string(),
                is_block_trade: false,
            }],
            cts: None,
        }))
    }

    fn make_orderbook_message(symbol: &str) -> WsMessage {
        WsMessage::Orderbook(Box::new(WsStreamMessage {
            topic: format!("orderbook.50.{}", symbol),
            update_type: "snapshot".to_string(),
            ts: 1234567890000,
            data: OrderbookData {
                symbol: symbol.to_string(),
                bids: vec![],
                asks: vec![],
                update_id: 1,
                seq: None,
            },
            cts: None,
        }))
    }

    /// Send `messages` in order, then drop the sender so the receiving stream terminates.
    fn send_all_and_close(tx: mpsc::UnboundedSender<WsMessage>, messages: Vec<WsMessage>) {
        for msg in messages {
            if let Err(err) = tx.send(msg) {
                panic!("Failed to send message: {}", err);
            }
        }
        drop(tx);
    }

    #[tokio::test]
    async fn test_ws_stream_basic() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut stream = WsStream::new(rx);

        send_all_and_close(
            tx,
            vec![
                make_trade_message("BTCUSDT"),
                make_orderbook_message("BTCUSDT"),
            ],
        );

        // Messages come out unfiltered and in send order, then the stream ends.
        assert!(matches!(stream.next().await, Some(WsMessage::Trade(_))));
        assert!(matches!(stream.next().await, Some(WsMessage::Orderbook(_))));
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_ws_stream_trades_filter() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        send_all_and_close(
            tx,
            vec![
                make_trade_message("BTCUSDT"),
                make_orderbook_message("BTCUSDT"),
                make_trade_message("ETHUSDT"),
            ],
        );

        let trades: Vec<_> = stream.trades().collect().await;
        assert_eq!(trades.len(), 2);
        assert_eq!(trades[0].data[0].symbol, "BTCUSDT");
        assert_eq!(trades[1].data[0].symbol, "ETHUSDT");
    }

    #[tokio::test]
    async fn test_ws_stream_orderbooks_filter() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        send_all_and_close(
            tx,
            vec![
                make_trade_message("BTCUSDT"),
                make_orderbook_message("BTCUSDT"),
                make_orderbook_message("ETHUSDT"),
            ],
        );

        let orderbooks: Vec<_> = stream.orderbooks().collect().await;
        let symbols: Vec<&str> = orderbooks
            .iter()
            .map(|ob| ob.data.symbol.as_str())
            .collect();
        assert_eq!(symbols, ["BTCUSDT", "ETHUSDT"]);
    }

    #[tokio::test]
    async fn test_ws_stream_for_symbol() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        send_all_and_close(
            tx,
            vec![
                make_trade_message("BTCUSDT"),
                make_orderbook_message("BTCUSDT"),
                make_trade_message("ETHUSDT"),
                make_orderbook_message("ETHUSDT"),
            ],
        );

        let btc_messages: Vec<_> = stream.for_symbol("BTCUSDT").collect().await;
        assert_eq!(btc_messages.len(), 2);
        assert!(matches!(
            &btc_messages[0],
            WsMessage::Trade(t) if t.data[0].symbol == "BTCUSDT"
        ));
        assert!(matches!(
            &btc_messages[1],
            WsMessage::Orderbook(ob) if ob.data.symbol == "BTCUSDT"
        ));
    }

    #[tokio::test]
    async fn test_ws_stream_for_topic_prefix() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        send_all_and_close(
            tx,
            vec![
                make_trade_message("BTCUSDT"),
                make_orderbook_message("BTCUSDT"),
                make_trade_message("ETHUSDT"),
            ],
        );

        let trades: Vec<_> = stream.for_topic_prefix("publicTrade.").collect().await;
        assert_eq!(trades.len(), 2);
        assert!(trades.iter().all(|msg| matches!(msg, WsMessage::Trade(_))));
    }

    #[tokio::test]
    async fn test_into_ws_stream() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut stream = rx.into_stream();

        send_all_and_close(tx, vec![make_trade_message("BTCUSDT")]);

        let msg = stream.next().await;
        assert!(matches!(msg, Some(WsMessage::Trade(_))));
        assert!(stream.next().await.is_none());
    }
}
385}