1use std::pin::Pin;
36use std::task::{Context, Poll};
37
38use futures::Stream;
39use tokio::sync::mpsc;
40use tokio_stream::wrappers::UnboundedReceiverStream;
41
42use crate::ws::types::*;
43
44pub struct WsStream {
49 inner: UnboundedReceiverStream<WsMessage>,
50}
51
52impl WsStream {
53 pub fn new(receiver: mpsc::UnboundedReceiver<WsMessage>) -> Self {
55 Self {
56 inner: UnboundedReceiverStream::new(receiver),
57 }
58 }
59
60 pub fn orderbooks(self) -> impl Stream<Item = Box<WsStreamMessage<OrderbookData>>> {
62 futures::StreamExt::filter_map(self, |msg| async {
63 match msg {
64 WsMessage::Orderbook(ob) => Some(ob),
65 _ => None,
66 }
67 })
68 }
69
70 pub fn trades(self) -> impl Stream<Item = Box<WsStreamMessage<Vec<TradeData>>>> {
72 futures::StreamExt::filter_map(self, |msg| async {
73 match msg {
74 WsMessage::Trade(t) => Some(t),
75 _ => None,
76 }
77 })
78 }
79
80 pub fn tickers(self) -> impl Stream<Item = Box<WsStreamMessage<TickerData>>> {
82 futures::StreamExt::filter_map(self, |msg| async {
83 match msg {
84 WsMessage::Ticker(t) => Some(t),
85 _ => None,
86 }
87 })
88 }
89
90 pub fn klines(self) -> impl Stream<Item = Box<WsStreamMessage<Vec<KlineData>>>> {
92 futures::StreamExt::filter_map(self, |msg| async {
93 match msg {
94 WsMessage::Kline(k) => Some(k),
95 _ => None,
96 }
97 })
98 }
99
100 pub fn liquidations(self) -> impl Stream<Item = Box<WsStreamMessage<LiquidationData>>> {
102 futures::StreamExt::filter_map(self, |msg| async {
103 match msg {
104 WsMessage::Liquidation(l) => Some(l),
105 _ => None,
106 }
107 })
108 }
109
110 pub fn operation_responses(self) -> impl Stream<Item = WsOperationResponse> {
112 futures::StreamExt::filter_map(self, |msg| async {
113 match msg {
114 WsMessage::OperationResponse(r) => Some(r),
115 _ => None,
116 }
117 })
118 }
119
120 pub fn positions(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<PositionData>>>> {
122 futures::StreamExt::filter_map(self, |msg| async {
123 match msg {
124 WsMessage::Position(p) => Some(p),
125 _ => None,
126 }
127 })
128 }
129
130 pub fn orders(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<OrderData>>>> {
132 futures::StreamExt::filter_map(self, |msg| async {
133 match msg {
134 WsMessage::Order(o) => Some(o),
135 _ => None,
136 }
137 })
138 }
139
140 pub fn executions(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<ExecutionData>>>> {
142 futures::StreamExt::filter_map(self, |msg| async {
143 match msg {
144 WsMessage::Execution(e) => Some(e),
145 _ => None,
146 }
147 })
148 }
149
150 pub fn executions_fast(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<ExecutionFastData>>>> {
152 futures::StreamExt::filter_map(self, |msg| async {
153 match msg {
154 WsMessage::ExecutionFast(e) => Some(e),
155 _ => None,
156 }
157 })
158 }
159
160 pub fn wallets(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<WalletData>>>> {
162 futures::StreamExt::filter_map(self, |msg| async {
163 match msg {
164 WsMessage::Wallet(w) => Some(w),
165 _ => None,
166 }
167 })
168 }
169
170 pub fn greeks(self) -> impl Stream<Item = Box<WsPrivateMessage<Vec<GreeksData>>>> {
172 futures::StreamExt::filter_map(self, |msg| async {
173 match msg {
174 WsMessage::Greeks(g) => Some(g),
175 _ => None,
176 }
177 })
178 }
179
180 pub fn for_symbol(self, symbol: impl Into<String>) -> impl Stream<Item = WsMessage> {
185 let symbol = symbol.into();
186 futures::StreamExt::filter(self, move |msg| {
187 let matches = match msg {
188 WsMessage::Orderbook(ob) => ob.data.symbol == symbol,
189 WsMessage::Trade(t) => t.data.first().map(|d| d.symbol == symbol).unwrap_or(false),
190 WsMessage::Ticker(t) => t.data.symbol == symbol,
191 WsMessage::Kline(k) => k
192 .topic
193 .split('.')
194 .last()
195 .map(|s| s == symbol)
196 .unwrap_or(false),
197 WsMessage::Liquidation(l) => l.data.symbol == symbol,
198 _ => true, };
200 std::future::ready(matches)
201 })
202 }
203
204 pub fn for_topic_prefix(self, prefix: impl Into<String>) -> impl Stream<Item = WsMessage> {
208 let prefix = prefix.into();
209 futures::StreamExt::filter(self, move |msg| {
210 let matches = match msg {
211 WsMessage::Orderbook(ob) => ob.topic.starts_with(&prefix),
212 WsMessage::Trade(t) => t.topic.starts_with(&prefix),
213 WsMessage::Ticker(t) => t.topic.starts_with(&prefix),
214 WsMessage::Kline(k) => k.topic.starts_with(&prefix),
215 WsMessage::Liquidation(l) => l.topic.starts_with(&prefix),
216 _ => true, };
218 std::future::ready(matches)
219 })
220 }
221}
222
223impl Stream for WsStream {
224 type Item = WsMessage;
225
226 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227 Pin::new(&mut self.inner).poll_next(cx)
228 }
229
230 fn size_hint(&self) -> (usize, Option<usize>) {
231 self.inner.size_hint()
232 }
233}
234
235pub trait IntoWsStream {
237 fn into_stream(self) -> WsStream;
239}
240
241impl IntoWsStream for mpsc::UnboundedReceiver<WsMessage> {
242 fn into_stream(self) -> WsStream {
243 WsStream::new(self)
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Builds a single-trade snapshot message for `symbol`.
    fn make_trade_message(symbol: &str) -> WsMessage {
        let trade = TradeData {
            timestamp: 1234567890000,
            symbol: symbol.to_string(),
            side: "Buy".to_string(),
            size: "0.1".to_string(),
            price: "50000".to_string(),
            tick_direction: "ZeroPlusTick".to_string(),
            trade_id: "test-123".to_string(),
            is_block_trade: false,
        };
        WsMessage::Trade(Box::new(WsStreamMessage {
            topic: format!("publicTrade.{}", symbol),
            update_type: "snapshot".to_string(),
            ts: 1234567890000,
            data: vec![trade],
            cts: None,
        }))
    }

    /// Builds an empty orderbook snapshot message for `symbol`.
    fn make_orderbook_message(symbol: &str) -> WsMessage {
        let book = OrderbookData {
            symbol: symbol.to_string(),
            bids: vec![],
            asks: vec![],
            update_id: 1,
            seq: None,
        };
        WsMessage::Orderbook(Box::new(WsStreamMessage {
            topic: format!("orderbook.50.{}", symbol),
            update_type: "snapshot".to_string(),
            ts: 1234567890000,
            data: book,
            cts: None,
        }))
    }

    /// Sends a trade message for `symbol`, panicking if the channel is closed.
    fn send_trade(tx: &mpsc::UnboundedSender<WsMessage>, symbol: &str) {
        if let Err(err) = tx.send(make_trade_message(symbol)) {
            panic!("Failed to send trade message: {}", err);
        }
    }

    /// Sends an orderbook message for `symbol`, panicking if the channel is closed.
    fn send_orderbook(tx: &mpsc::UnboundedSender<WsMessage>, symbol: &str) {
        if let Err(err) = tx.send(make_orderbook_message(symbol)) {
            panic!("Failed to send orderbook message: {}", err);
        }
    }

    #[tokio::test]
    async fn test_ws_stream_basic() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        send_trade(&tx, "BTCUSDT");
        send_orderbook(&tx, "BTCUSDT");
        // Dropping the sender closes the channel so the stream terminates.
        drop(tx);

        let received: Vec<WsMessage> = stream.collect().await;
        assert_eq!(received.len(), 2);
    }

    #[tokio::test]
    async fn test_ws_stream_trades_filter() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        send_trade(&tx, "BTCUSDT");
        send_orderbook(&tx, "BTCUSDT");
        send_trade(&tx, "ETHUSDT");
        drop(tx);

        let trades: Vec<_> = stream.trades().collect().await;
        assert_eq!(trades.len(), 2);
        let symbols: Vec<&str> = trades.iter().map(|t| t.data[0].symbol.as_str()).collect();
        assert_eq!(symbols, ["BTCUSDT", "ETHUSDT"]);
    }

    #[tokio::test]
    async fn test_ws_stream_orderbooks_filter() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        send_trade(&tx, "BTCUSDT");
        send_orderbook(&tx, "BTCUSDT");
        send_orderbook(&tx, "ETHUSDT");
        drop(tx);

        let orderbooks: Vec<_> = stream.orderbooks().collect().await;
        assert_eq!(orderbooks.len(), 2);
    }

    #[tokio::test]
    async fn test_ws_stream_for_symbol() {
        let (tx, rx) = mpsc::unbounded_channel();
        let stream = WsStream::new(rx);

        for symbol in ["BTCUSDT", "ETHUSDT"] {
            send_trade(&tx, symbol);
            send_orderbook(&tx, symbol);
        }
        drop(tx);

        let btc_messages: Vec<_> = stream.for_symbol("BTCUSDT").collect().await;
        assert_eq!(btc_messages.len(), 2);
    }

    #[tokio::test]
    async fn test_into_ws_stream() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut stream = rx.into_stream();

        send_trade(&tx, "BTCUSDT");
        drop(tx);

        let first = stream.next().await;
        assert!(matches!(first, Some(WsMessage::Trade(_))));
    }
}