price_adapter/sources/binance/websocket.rs

1use crate::mappers::BandStaticMapper;
2use crate::sources::BandStableCoin;
3use crate::types::{Mapper, SettingResponse, Source, WebSocketSource, WebsocketMessage};
4use crate::{error::Error, types::PriceInfo};
5use futures_util::{stream::FusedStream, Stream, StreamExt};
6use price_adapter_raw::{
7    types::WebsocketMessage as WebsocketMessageRaw, BinanceWebsocket as BinanceWebsocketRaw,
8};
9use std::time::Duration;
10use std::{
11    collections::HashMap,
12    pin::Pin,
13    sync::Arc,
14    task::{Context, Poll},
15};
16use tokio::sync::Mutex;
17use tokio::time::sleep;
18
/// [`BinanceWebsocket`] wired to the Band static symbol mapper and the Band
/// stable-coin source for the USDT price; build one with
/// `DefaultBinanceWebsocket::new_with_default()`.
pub type DefaultBinanceWebsocket = BinanceWebsocket<BandStaticMapper, BandStableCoin>;
20
/// Binance websocket price source whose price updates are converted using a
/// periodically refreshed USDT price.
///
/// Generic over the [`Mapper`] that translates symbols into Binance pair ids
/// and the [`Source`] that supplies the USDT price.
pub struct BinanceWebsocket<M: Mapper, S: Source> {
    /// Translates requested symbols into Binance pair ids.
    mapper: M,
    /// Source queried for the "USDT" price by the background refresher.
    usdt_source: Arc<S>,
    /// Pause between successive USDT price refreshes.
    usdt_interval: Duration,

    /// Latest USDT price; `None` before the first fetch and after a failed fetch.
    usdt_price: Arc<Mutex<Option<PriceInfo>>>,
    /// Underlying raw Binance websocket client.
    raw: Arc<Mutex<BinanceWebsocketRaw>>,
    /// Uppercased Binance pair id -> symbol, used to re-key incoming price updates.
    mapping_back: HashMap<String, String>,
    /// Set once the raw stream has yielded `None`.
    ended: bool,
}
32
33impl<M: Mapper, S: Source> BinanceWebsocket<M, S> {
34    /// Constructor for the `BinanceWebsocket` struct.
35    pub fn new(mapper: M, usdt_source: S, usdt_interval: Duration) -> Self {
36        Self {
37            mapper,
38            usdt_source: Arc::new(usdt_source),
39            usdt_interval,
40            usdt_price: Arc::new(Mutex::new(None)),
41            raw: Arc::new(Mutex::new(BinanceWebsocketRaw::new(
42                "wss://stream.binance.com:9443",
43            ))),
44            mapping_back: HashMap::new(),
45            ended: false,
46        }
47    }
48}
49
50// Implementing the WebSocketSource trait for BinanceWebsocket.
51#[async_trait::async_trait]
52impl<M: Mapper, S: Source> WebSocketSource for BinanceWebsocket<M, S> {
53    /// Asynchronous function to connect to the WebSocket.
54    async fn connect(&mut self) -> Result<(), Error> {
55        let mut locked_raw = self.raw.lock().await;
56        if !locked_raw.is_connected() {
57            locked_raw.connect().await?;
58        }
59        drop(locked_raw);
60
61        let cloned_usdt_source = Arc::clone(&self.usdt_source);
62        let cloned_usdt_price = Arc::clone(&self.usdt_price);
63        let cloned_usdt_interval = self.usdt_interval;
64
65        tokio::spawn(async move {
66            loop {
67                let price_info = cloned_usdt_source.get_price("USDT").await;
68                let mut locked_usdt_price = cloned_usdt_price.lock().await;
69                if let Ok(price) = price_info {
70                    *locked_usdt_price = Some(price);
71                } else {
72                    *locked_usdt_price = None;
73                }
74                drop(locked_usdt_price);
75
76                sleep(cloned_usdt_interval).await;
77            }
78        });
79
80        Ok(())
81    }
82
83    /// Asynchronous function to subscribe to symbols.
84    async fn subscribe(&mut self, symbols: &[&str]) -> Result<u32, Error> {
85        // Retrieve the symbol-to-id mapping from the provided mapper.
86        let mapping = self.mapper.get_mapping().await?;
87
88        for (key, value) in mapping {
89            if let Some(pair) = value.as_str() {
90                self.mapping_back
91                    .insert(pair.to_string().to_uppercase(), key.to_string());
92            }
93        }
94
95        let ids: Vec<&str> = symbols
96            .iter()
97            .filter_map(|&symbol| mapping.get(symbol))
98            .filter_map(|val| val.as_str())
99            .collect();
100
101        if ids.len() != symbols.len() {
102            return Err(Error::UnsupportedSymbol);
103        }
104
105        let mut locked_raw = self.raw.lock().await;
106        locked_raw
107            .subscribe(ids.as_slice())
108            .await
109            .map_err(Error::PriceAdapterRawError)
110    }
111
112    /// Asynchronous function to unsubscribe from symbols.
113    async fn unsubscribe(&mut self, symbols: &[&str]) -> Result<u32, Error> {
114        let ids: Vec<&str> = symbols
115            .iter()
116            .filter_map(|&symbol| self.mapping_back.get(symbol))
117            .map(|string_ref| string_ref.as_str())
118            .collect();
119
120        if ids.len() != symbols.len() {
121            return Err(Error::UnsupportedSymbol);
122        }
123
124        let mut locked_raw = self.raw.lock().await;
125        locked_raw
126            .unsubscribe(ids.as_slice())
127            .await
128            .map_err(Error::PriceAdapterRawError)
129    }
130
131    /// Check if the WebSocket is connected.
132    async fn is_connected(&self) -> bool {
133        let locked_raw = self.raw.lock().await;
134        locked_raw.is_connected()
135    }
136}
137
138// Implementing BinanceWebsocket for specific types (BandStaticMapper, BandStableCoin).
139impl DefaultBinanceWebsocket {
140    /// Constructor for creating a new BinanceWebsocket with default settings.
141    pub fn new_with_default() -> Result<Self, Error> {
142        let mapper = BandStaticMapper::from_source("binance")?;
143        let band_stable_coin = BandStableCoin::new();
144        Ok(Self::new(mapper, band_stable_coin, Duration::from_secs(5)))
145    }
146}
147
148// Implementing Stream for BinanceWebsocket.
149impl<M: Mapper, S: Source> Stream for BinanceWebsocket<M, S> {
150    type Item = Result<WebsocketMessage, Error>;
151
152    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
153        if self.ended {
154            return Poll::Ready(None);
155        }
156
157        let Ok(mut locked_raw) = self.raw.try_lock() else {
158            cx.waker().wake_by_ref();
159            return Poll::Pending;
160        };
161
162        let clone_usdt_price = Arc::clone(&self.usdt_price);
163        let Ok(locked_usdt_price) = clone_usdt_price.try_lock() else {
164            cx.waker().wake_by_ref();
165            return Poll::Pending;
166        };
167
168        let Some(usdt_price) = &*locked_usdt_price else {
169            cx.waker().wake_by_ref();
170            return Poll::Pending;
171        };
172
173        match locked_raw.poll_next_unpin(cx) {
174            Poll::Ready(Some(message)) => match message {
175                Ok(WebsocketMessageRaw::PriceInfo(price_info_raw)) => {
176                    tracing::trace!("received price info raw: {}", price_info_raw);
177                    if let Some(symbol) = self.mapping_back.get(&price_info_raw.id) {
178                        Poll::Ready(Some(Ok(WebsocketMessage::PriceInfo(PriceInfo {
179                            symbol: symbol.to_string(),
180                            price: price_info_raw.price / usdt_price.price,
181                            timestamp: price_info_raw.timestamp,
182                        }))))
183                    } else {
184                        // If symbol not found, wake up the waker and return Pending.
185                        tracing::trace!("received symbol doesn't match");
186                        cx.waker().wake_by_ref();
187                        Poll::Pending
188                    }
189                }
190                Ok(WebsocketMessageRaw::SettingResponse(response)) => {
191                    tracing::trace!("received setting response raw: {:?}", response);
192                    Poll::Ready(Some(Ok(WebsocketMessage::SettingResponse(
193                        SettingResponse {
194                            data: response.data,
195                        },
196                    ))))
197                }
198                Err(err) => Poll::Ready(Some(Err(err.into()))),
199            },
200            Poll::Ready(None) => {
201                drop(locked_raw);
202                self.ended = true;
203                Poll::Ready(None)
204            }
205            Poll::Pending => Poll::Pending,
206        }
207    }
208}
209
210// Implementing FusedStream for BinanceWebsocket.
211impl<M: Mapper, S: Source> FusedStream for BinanceWebsocket<M, S> {
212    /// Check if the stream is terminated.
213    fn is_terminated(&self) -> bool {
214        self.ended
215    }
216}