price_adapter/sources/binance/
websocket.rs1use crate::mappers::BandStaticMapper;
2use crate::sources::BandStableCoin;
3use crate::types::{Mapper, SettingResponse, Source, WebSocketSource, WebsocketMessage};
4use crate::{error::Error, types::PriceInfo};
5use futures_util::{stream::FusedStream, Stream, StreamExt};
6use price_adapter_raw::{
7 types::WebsocketMessage as WebsocketMessageRaw, BinanceWebsocket as BinanceWebsocketRaw,
8};
9use std::time::Duration;
10use std::{
11 collections::HashMap,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15};
16use tokio::sync::Mutex;
17use tokio::time::sleep;
18
19pub type DefaultBinanceWebsocket = BinanceWebsocket<BandStaticMapper, BandStableCoin>;
20
21pub struct BinanceWebsocket<M: Mapper, S: Source> {
23 mapper: M,
24 usdt_source: Arc<S>,
25 usdt_interval: Duration,
26
27 usdt_price: Arc<Mutex<Option<PriceInfo>>>,
28 raw: Arc<Mutex<BinanceWebsocketRaw>>,
29 mapping_back: HashMap<String, String>,
30 ended: bool,
31}
32
33impl<M: Mapper, S: Source> BinanceWebsocket<M, S> {
34 pub fn new(mapper: M, usdt_source: S, usdt_interval: Duration) -> Self {
36 Self {
37 mapper,
38 usdt_source: Arc::new(usdt_source),
39 usdt_interval,
40 usdt_price: Arc::new(Mutex::new(None)),
41 raw: Arc::new(Mutex::new(BinanceWebsocketRaw::new(
42 "wss://stream.binance.com:9443",
43 ))),
44 mapping_back: HashMap::new(),
45 ended: false,
46 }
47 }
48}
49
50#[async_trait::async_trait]
52impl<M: Mapper, S: Source> WebSocketSource for BinanceWebsocket<M, S> {
53 async fn connect(&mut self) -> Result<(), Error> {
55 let mut locked_raw = self.raw.lock().await;
56 if !locked_raw.is_connected() {
57 locked_raw.connect().await?;
58 }
59 drop(locked_raw);
60
61 let cloned_usdt_source = Arc::clone(&self.usdt_source);
62 let cloned_usdt_price = Arc::clone(&self.usdt_price);
63 let cloned_usdt_interval = self.usdt_interval;
64
65 tokio::spawn(async move {
66 loop {
67 let price_info = cloned_usdt_source.get_price("USDT").await;
68 let mut locked_usdt_price = cloned_usdt_price.lock().await;
69 if let Ok(price) = price_info {
70 *locked_usdt_price = Some(price);
71 } else {
72 *locked_usdt_price = None;
73 }
74 drop(locked_usdt_price);
75
76 sleep(cloned_usdt_interval).await;
77 }
78 });
79
80 Ok(())
81 }
82
83 async fn subscribe(&mut self, symbols: &[&str]) -> Result<u32, Error> {
85 let mapping = self.mapper.get_mapping().await?;
87
88 for (key, value) in mapping {
89 if let Some(pair) = value.as_str() {
90 self.mapping_back
91 .insert(pair.to_string().to_uppercase(), key.to_string());
92 }
93 }
94
95 let ids: Vec<&str> = symbols
96 .iter()
97 .filter_map(|&symbol| mapping.get(symbol))
98 .filter_map(|val| val.as_str())
99 .collect();
100
101 if ids.len() != symbols.len() {
102 return Err(Error::UnsupportedSymbol);
103 }
104
105 let mut locked_raw = self.raw.lock().await;
106 locked_raw
107 .subscribe(ids.as_slice())
108 .await
109 .map_err(Error::PriceAdapterRawError)
110 }
111
112 async fn unsubscribe(&mut self, symbols: &[&str]) -> Result<u32, Error> {
114 let ids: Vec<&str> = symbols
115 .iter()
116 .filter_map(|&symbol| self.mapping_back.get(symbol))
117 .map(|string_ref| string_ref.as_str())
118 .collect();
119
120 if ids.len() != symbols.len() {
121 return Err(Error::UnsupportedSymbol);
122 }
123
124 let mut locked_raw = self.raw.lock().await;
125 locked_raw
126 .unsubscribe(ids.as_slice())
127 .await
128 .map_err(Error::PriceAdapterRawError)
129 }
130
131 async fn is_connected(&self) -> bool {
133 let locked_raw = self.raw.lock().await;
134 locked_raw.is_connected()
135 }
136}
137
138impl DefaultBinanceWebsocket {
140 pub fn new_with_default() -> Result<Self, Error> {
142 let mapper = BandStaticMapper::from_source("binance")?;
143 let band_stable_coin = BandStableCoin::new();
144 Ok(Self::new(mapper, band_stable_coin, Duration::from_secs(5)))
145 }
146}
147
148impl<M: Mapper, S: Source> Stream for BinanceWebsocket<M, S> {
150 type Item = Result<WebsocketMessage, Error>;
151
152 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
153 if self.ended {
154 return Poll::Ready(None);
155 }
156
157 let Ok(mut locked_raw) = self.raw.try_lock() else {
158 cx.waker().wake_by_ref();
159 return Poll::Pending;
160 };
161
162 let clone_usdt_price = Arc::clone(&self.usdt_price);
163 let Ok(locked_usdt_price) = clone_usdt_price.try_lock() else {
164 cx.waker().wake_by_ref();
165 return Poll::Pending;
166 };
167
168 let Some(usdt_price) = &*locked_usdt_price else {
169 cx.waker().wake_by_ref();
170 return Poll::Pending;
171 };
172
173 match locked_raw.poll_next_unpin(cx) {
174 Poll::Ready(Some(message)) => match message {
175 Ok(WebsocketMessageRaw::PriceInfo(price_info_raw)) => {
176 tracing::trace!("received price info raw: {}", price_info_raw);
177 if let Some(symbol) = self.mapping_back.get(&price_info_raw.id) {
178 Poll::Ready(Some(Ok(WebsocketMessage::PriceInfo(PriceInfo {
179 symbol: symbol.to_string(),
180 price: price_info_raw.price / usdt_price.price,
181 timestamp: price_info_raw.timestamp,
182 }))))
183 } else {
184 tracing::trace!("received symbol doesn't match");
186 cx.waker().wake_by_ref();
187 Poll::Pending
188 }
189 }
190 Ok(WebsocketMessageRaw::SettingResponse(response)) => {
191 tracing::trace!("received setting response raw: {:?}", response);
192 Poll::Ready(Some(Ok(WebsocketMessage::SettingResponse(
193 SettingResponse {
194 data: response.data,
195 },
196 ))))
197 }
198 Err(err) => Poll::Ready(Some(Err(err.into()))),
199 },
200 Poll::Ready(None) => {
201 drop(locked_raw);
202 self.ended = true;
203 Poll::Ready(None)
204 }
205 Poll::Pending => Poll::Pending,
206 }
207 }
208}
209
210impl<M: Mapper, S: Source> FusedStream for BinanceWebsocket<M, S> {
212 fn is_terminated(&self) -> bool {
214 self.ended
215 }
216}