1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4 time::Duration,
5};
6
7use futures_util::{SinkExt, Stream, StreamExt};
8use tokio::{net::TcpStream, time::interval};
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
11use super::{
12 auth::ApiCredentials,
13 error::WebSocketError,
14 market::MarketMessage,
15 subscription::{ChannelType, MarketSubscription, UserSubscription, WS_MARKET_URL, WS_USER_URL},
16 user::UserMessage,
17 Channel,
18};
19
/// Largest number of asset/market IDs a single connection is allowed to subscribe to.
///
/// The `connect_*` constructors in this module compare the requested ID count against this
/// value and return an `InvalidMessage` error before opening a socket when it is exceeded.
const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 500;
22
23pub struct WebSocket {
46 inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
47 channel_type: ChannelType,
48}
49
50impl WebSocket {
51 pub async fn connect_market(asset_ids: Vec<String>) -> Result<Self, WebSocketError> {
72 if asset_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
73 return Err(WebSocketError::InvalidMessage(format!(
74 "Too many subscriptions ({}), max {}",
75 asset_ids.len(),
76 MAX_SUBSCRIPTIONS_PER_CONNECTION
77 )));
78 }
79 let (mut ws, _) = connect_async(WS_MARKET_URL).await?;
80
81 let subscription = MarketSubscription::new(asset_ids);
82 let msg = serde_json::to_string(&subscription)?;
83 ws.send(Message::Text(msg.into())).await?;
84
85 Ok(Self {
86 inner: ws,
87 channel_type: ChannelType::Market,
88 })
89 }
90
91 pub async fn connect_user(
114 market_ids: Vec<String>,
115 credentials: ApiCredentials,
116 ) -> Result<Self, WebSocketError> {
117 if market_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
118 return Err(WebSocketError::InvalidMessage(format!(
119 "Too many subscriptions ({}), max {}",
120 market_ids.len(),
121 MAX_SUBSCRIPTIONS_PER_CONNECTION
122 )));
123 }
124 let (mut ws, _) = connect_async(WS_USER_URL).await?;
125
126 let subscription = UserSubscription::new(market_ids, credentials);
127 let msg = serde_json::to_string(&subscription)?;
128 ws.send(Message::Text(msg.into())).await?;
129
130 Ok(Self {
131 inner: ws,
132 channel_type: ChannelType::User,
133 })
134 }
135
136 pub async fn ping(&mut self) -> Result<(), WebSocketError> {
140 self.inner.send(Message::Text("PING".into())).await?;
141 Ok(())
142 }
143
144 pub async fn close(&mut self) -> Result<(), WebSocketError> {
146 self.inner.close(None).await?;
147 Ok(())
148 }
149
150 pub fn channel_type(&self) -> ChannelType {
152 self.channel_type
153 }
154
155 fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
157 if text == "PONG" || text == "{}" || text.is_empty() {
159 return Ok(None);
160 }
161
162 if !text.contains("event_type") {
164 tracing::trace!("Skipping non-event message: {}", text);
165 return Ok(None);
166 }
167
168 match self.channel_type {
169 ChannelType::Market => {
170 let msg = MarketMessage::from_json(text)?;
171 Ok(Some(Channel::Market(msg)))
172 }
173 ChannelType::User => {
174 let msg = UserMessage::from_json(text)?;
175 Ok(Some(Channel::User(msg)))
176 }
177 }
178 }
179}
180
181impl Stream for WebSocket {
182 type Item = Result<Channel, WebSocketError>;
183
184 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185 loop {
186 match Pin::new(&mut self.inner).poll_next(cx) {
187 Poll::Ready(Some(Ok(msg))) => match msg {
188 Message::Text(text) => match self.parse_message(&text) {
189 Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
190 Ok(None) => continue, Err(e) => return Poll::Ready(Some(Err(e))),
192 },
193 Message::Binary(data) => {
194 if let Ok(text) = String::from_utf8(data.to_vec()) {
196 match self.parse_message(&text) {
197 Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
198 Ok(None) => continue,
199 Err(e) => return Poll::Ready(Some(Err(e))),
200 }
201 }
202 continue;
203 }
204 Message::Ping(_) | Message::Pong(_) => continue,
205 Message::Close(_) => return Poll::Ready(None),
206 Message::Frame(_) => continue,
207 },
208 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
209 Poll::Ready(None) => return Poll::Ready(None),
210 Poll::Pending => return Poll::Pending,
211 }
212 }
213 }
214}
215
/// Collects endpoint URLs and keepalive settings before opening a [`WebSocketWithPing`].
pub struct WebSocketBuilder {
    /// Delay between text `PING` frames; `None` lets the connect methods use their default.
    ping_interval: Option<Duration>,
    /// Endpoint dialed by the market-channel connect method.
    market_url: String,
    /// Endpoint dialed by the user-channel connect method.
    user_url: String,
}
222
223impl Default for WebSocketBuilder {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
229impl WebSocketBuilder {
230 pub fn new() -> Self {
232 Self {
233 market_url: WS_MARKET_URL.to_string(),
234 user_url: WS_USER_URL.to_string(),
235 ping_interval: None,
236 }
237 }
238
239 pub fn market_url(mut self, url: impl Into<String>) -> Self {
241 self.market_url = url.into();
242 self
243 }
244
245 pub fn user_url(mut self, url: impl Into<String>) -> Self {
247 self.user_url = url.into();
248 self
249 }
250
251 pub fn ping_interval(mut self, interval: Duration) -> Self {
256 self.ping_interval = Some(interval);
257 self
258 }
259
260 pub async fn connect_market(
262 self,
263 asset_ids: Vec<String>,
264 ) -> Result<WebSocketWithPing, WebSocketError> {
265 if asset_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
266 return Err(WebSocketError::InvalidMessage(format!(
267 "Too many subscriptions ({}), max {}",
268 asset_ids.len(),
269 MAX_SUBSCRIPTIONS_PER_CONNECTION
270 )));
271 }
272 let (mut ws, _) = connect_async(&self.market_url).await?;
273
274 let subscription = MarketSubscription::new(asset_ids);
275 let msg = serde_json::to_string(&subscription)?;
276 ws.send(Message::Text(msg.into())).await?;
277
278 Ok(WebSocketWithPing {
279 inner: ws,
280 channel_type: ChannelType::Market,
281 ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
282 })
283 }
284
285 pub async fn connect_user(
287 self,
288 market_ids: Vec<String>,
289 credentials: ApiCredentials,
290 ) -> Result<WebSocketWithPing, WebSocketError> {
291 if market_ids.len() > MAX_SUBSCRIPTIONS_PER_CONNECTION {
292 return Err(WebSocketError::InvalidMessage(format!(
293 "Too many subscriptions ({}), max {}",
294 market_ids.len(),
295 MAX_SUBSCRIPTIONS_PER_CONNECTION
296 )));
297 }
298 let (mut ws, _) = connect_async(&self.user_url).await?;
299
300 let subscription = UserSubscription::new(market_ids, credentials);
301 let msg = serde_json::to_string(&subscription)?;
302 ws.send(Message::Text(msg.into())).await?;
303
304 Ok(WebSocketWithPing {
305 inner: ws,
306 channel_type: ChannelType::User,
307 ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
308 })
309 }
310}
311
312pub struct WebSocketWithPing {
317 inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
318 channel_type: ChannelType,
319 ping_interval: Duration,
320}
321
322impl WebSocketWithPing {
323 pub async fn run<F, Fut>(mut self, mut handler: F) -> Result<(), WebSocketError>
356 where
357 F: FnMut(Channel) -> Fut,
358 Fut: std::future::Future<Output = Result<(), WebSocketError>>,
359 {
360 let mut ping_interval = interval(self.ping_interval);
361
362 loop {
363 tokio::select! {
364 _ = ping_interval.tick() => {
365 self.inner.send(Message::Text("PING".into())).await?;
366 }
367 msg = self.inner.next() => {
368 match msg {
369 Some(Ok(Message::Text(text))) => {
370 if text.as_str() == "PONG" {
371 continue;
372 }
373 let channel = self.parse_message(&text)?;
374 if let Some(channel) = channel {
375 handler(channel).await?;
376 }
377 }
378 Some(Ok(Message::Binary(data))) => {
379 if let Ok(text) = String::from_utf8(data.to_vec()) {
380 if text == "PONG" {
381 continue;
382 }
383 let channel = self.parse_message(&text)?;
384 if let Some(channel) = channel {
385 handler(channel).await?;
386 }
387 }
388 }
389 Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) | Some(Ok(Message::Frame(_))) => continue,
390 Some(Ok(Message::Close(_))) => return Ok(()),
391 Some(Err(e)) => return Err(e.into()),
392 None => return Ok(()),
393 }
394 }
395 }
396 }
397 }
398
399 pub fn channel_type(&self) -> ChannelType {
401 self.channel_type
402 }
403
404 fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
406 if text == "PONG" || text == "{}" || text.is_empty() {
408 return Ok(None);
409 }
410
411 if !text.contains("event_type") {
413 tracing::trace!("Skipping non-event message: {}", text);
414 return Ok(None);
415 }
416
417 match self.channel_type {
418 ChannelType::Market => {
419 let msg = MarketMessage::from_json(text)?;
420 Ok(Some(Channel::Market(msg)))
421 }
422 ChannelType::User => {
423 let msg = UserMessage::from_json(text)?;
424 Ok(Some(Channel::User(msg)))
425 }
426 }
427 }
428}