1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4 time::Duration,
5};
6
7use futures_util::{SinkExt, Stream, StreamExt};
8use tokio::{net::TcpStream, time::interval};
9use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
10
11use super::{
12 auth::ApiCredentials,
13 error::WebSocketError,
14 market::MarketMessage,
15 subscription::{ChannelType, MarketSubscription, UserSubscription, WS_MARKET_URL, WS_USER_URL},
16 user::UserMessage,
17 Channel,
18};
19
20pub struct WebSocket {
43 inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
44 channel_type: ChannelType,
45}
46
47impl WebSocket {
48 pub async fn connect_market(asset_ids: Vec<String>) -> Result<Self, WebSocketError> {
69 let (mut ws, _) = connect_async(WS_MARKET_URL).await?;
70
71 let subscription = MarketSubscription::new(asset_ids);
72 let msg = serde_json::to_string(&subscription)?;
73 ws.send(Message::Text(msg.into())).await?;
74
75 Ok(Self {
76 inner: ws,
77 channel_type: ChannelType::Market,
78 })
79 }
80
81 pub async fn connect_user(
104 market_ids: Vec<String>,
105 credentials: ApiCredentials,
106 ) -> Result<Self, WebSocketError> {
107 let (mut ws, _) = connect_async(WS_USER_URL).await?;
108
109 let subscription = UserSubscription::new(market_ids, credentials);
110 let msg = serde_json::to_string(&subscription)?;
111 ws.send(Message::Text(msg.into())).await?;
112
113 Ok(Self {
114 inner: ws,
115 channel_type: ChannelType::User,
116 })
117 }
118
119 pub async fn ping(&mut self) -> Result<(), WebSocketError> {
123 self.inner.send(Message::Text("PING".into())).await?;
124 Ok(())
125 }
126
127 pub async fn close(&mut self) -> Result<(), WebSocketError> {
129 self.inner.close(None).await?;
130 Ok(())
131 }
132
133 pub fn channel_type(&self) -> ChannelType {
135 self.channel_type
136 }
137
138 fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
140 if text == "PONG" || text == "{}" || text.is_empty() {
142 return Ok(None);
143 }
144
145 if !text.contains("event_type") {
147 tracing::debug!("Skipping non-event message: {}", text);
148 return Ok(None);
149 }
150
151 match self.channel_type {
152 ChannelType::Market => {
153 let msg = MarketMessage::from_json(text)?;
154 Ok(Some(Channel::Market(msg)))
155 }
156 ChannelType::User => {
157 let msg = UserMessage::from_json(text)?;
158 Ok(Some(Channel::User(msg)))
159 }
160 }
161 }
162}
163
164impl Stream for WebSocket {
165 type Item = Result<Channel, WebSocketError>;
166
167 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
168 loop {
169 match Pin::new(&mut self.inner).poll_next(cx) {
170 Poll::Ready(Some(Ok(msg))) => match msg {
171 Message::Text(text) => match self.parse_message(&text) {
172 Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
173 Ok(None) => continue, Err(e) => return Poll::Ready(Some(Err(e))),
175 },
176 Message::Binary(data) => {
177 if let Ok(text) = String::from_utf8(data.to_vec()) {
179 match self.parse_message(&text) {
180 Ok(Some(channel)) => return Poll::Ready(Some(Ok(channel))),
181 Ok(None) => continue,
182 Err(e) => return Poll::Ready(Some(Err(e))),
183 }
184 }
185 continue;
186 }
187 Message::Ping(_) | Message::Pong(_) => continue,
188 Message::Close(_) => return Poll::Ready(None),
189 Message::Frame(_) => continue,
190 },
191 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e.into()))),
192 Poll::Ready(None) => return Poll::Ready(None),
193 Poll::Pending => return Poll::Pending,
194 }
195 }
196 }
197}
198
/// Configures endpoints and the keep-alive interval before connecting.
///
/// Connecting through the builder yields a [`WebSocketWithPing`].
pub struct WebSocketBuilder {
    /// Keep-alive interval requested by the caller; `None` means "not set".
    ping_interval: Option<Duration>,
    /// Endpoint used by `connect_market`.
    market_url: String,
    /// Endpoint used by `connect_user`.
    user_url: String,
}
205
206impl Default for WebSocketBuilder {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl WebSocketBuilder {
213 pub fn new() -> Self {
215 Self {
216 market_url: WS_MARKET_URL.to_string(),
217 user_url: WS_USER_URL.to_string(),
218 ping_interval: None,
219 }
220 }
221
222 pub fn market_url(mut self, url: impl Into<String>) -> Self {
224 self.market_url = url.into();
225 self
226 }
227
228 pub fn user_url(mut self, url: impl Into<String>) -> Self {
230 self.user_url = url.into();
231 self
232 }
233
234 pub fn ping_interval(mut self, interval: Duration) -> Self {
239 self.ping_interval = Some(interval);
240 self
241 }
242
243 pub async fn connect_market(
245 self,
246 asset_ids: Vec<String>,
247 ) -> Result<WebSocketWithPing, WebSocketError> {
248 let (mut ws, _) = connect_async(&self.market_url).await?;
249
250 let subscription = MarketSubscription::new(asset_ids);
251 let msg = serde_json::to_string(&subscription)?;
252 ws.send(Message::Text(msg.into())).await?;
253
254 Ok(WebSocketWithPing {
255 inner: ws,
256 channel_type: ChannelType::Market,
257 ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
258 })
259 }
260
261 pub async fn connect_user(
263 self,
264 market_ids: Vec<String>,
265 credentials: ApiCredentials,
266 ) -> Result<WebSocketWithPing, WebSocketError> {
267 let (mut ws, _) = connect_async(&self.user_url).await?;
268
269 let subscription = UserSubscription::new(market_ids, credentials);
270 let msg = serde_json::to_string(&subscription)?;
271 ws.send(Message::Text(msg.into())).await?;
272
273 Ok(WebSocketWithPing {
274 inner: ws,
275 channel_type: ChannelType::User,
276 ping_interval: self.ping_interval.unwrap_or(Duration::from_secs(10)),
277 })
278 }
279}
280
281pub struct WebSocketWithPing {
286 inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
287 channel_type: ChannelType,
288 ping_interval: Duration,
289}
290
291impl WebSocketWithPing {
292 pub async fn run<F, Fut>(mut self, mut handler: F) -> Result<(), WebSocketError>
325 where
326 F: FnMut(Channel) -> Fut,
327 Fut: std::future::Future<Output = Result<(), WebSocketError>>,
328 {
329 let mut ping_interval = interval(self.ping_interval);
330
331 loop {
332 tokio::select! {
333 _ = ping_interval.tick() => {
334 self.inner.send(Message::Text("PING".into())).await?;
335 }
336 msg = self.inner.next() => {
337 match msg {
338 Some(Ok(Message::Text(text))) => {
339 if text.as_str() == "PONG" {
340 continue;
341 }
342 let channel = self.parse_message(&text)?;
343 if let Some(channel) = channel {
344 handler(channel).await?;
345 }
346 }
347 Some(Ok(Message::Binary(data))) => {
348 if let Ok(text) = String::from_utf8(data.to_vec()) {
349 if text == "PONG" {
350 continue;
351 }
352 let channel = self.parse_message(&text)?;
353 if let Some(channel) = channel {
354 handler(channel).await?;
355 }
356 }
357 }
358 Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) | Some(Ok(Message::Frame(_))) => continue,
359 Some(Ok(Message::Close(_))) => return Ok(()),
360 Some(Err(e)) => return Err(e.into()),
361 None => return Ok(()),
362 }
363 }
364 }
365 }
366 }
367
368 pub fn channel_type(&self) -> ChannelType {
370 self.channel_type
371 }
372
373 fn parse_message(&self, text: &str) -> Result<Option<Channel>, WebSocketError> {
375 if text == "PONG" || text == "{}" || text.is_empty() {
377 return Ok(None);
378 }
379
380 if !text.contains("event_type") {
382 tracing::debug!("Skipping non-event message: {}", text);
383 return Ok(None);
384 }
385
386 match self.channel_type {
387 ChannelType::Market => {
388 let msg = MarketMessage::from_json(text)?;
389 Ok(Some(Channel::Market(msg)))
390 }
391 ChannelType::User => {
392 let msg = UserMessage::from_json(text)?;
393 Ok(Some(Channel::User(msg)))
394 }
395 }
396 }
397}