hyperliquid_rust_sdk_abrkn/ws/
ws_manager.rs

1use crate::{
2    prelude::*,
3    ws::message_types::{AllMids, Candle, L2Book, OrderUpdates, Trades, User},
4    Error, Notification, UserFills, UserFundings, UserNonFundingLedgerUpdates,
5};
6use futures_util::{stream::SplitSink, SinkExt, StreamExt};
7use log::{error, warn};
8use serde::{Deserialize, Serialize};
9use std::{
10    collections::HashMap,
11    sync::{
12        atomic::{AtomicBool, Ordering},
13        Arc,
14    },
15    time::Duration,
16};
17use tokio::{
18    net::TcpStream,
19    runtime::Runtime,
20    spawn,
21    sync::{mpsc::UnboundedSender, Mutex},
22    task::JoinHandle,
23    time,
24};
25use tokio_tungstenite::{
26    connect_async,
27    tungstenite::{self, protocol},
28    MaybeTlsStream, WebSocketStream,
29};
30
31use ethers::types::H160;
32
33#[derive(Debug)]
34struct SubscriptionData {
35    sending_channel: UnboundedSender<Message>,
36    subscription_id: u32,
37}
38pub(crate) struct WsManager {
39    stop_flag: Arc<AtomicBool>,
40    reader_handle: Option<JoinHandle<()>>,
41    ping_handle: Option<JoinHandle<()>>,
42    writer: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, protocol::Message>>>,
43    subscriptions: Arc<Mutex<HashMap<String, Vec<SubscriptionData>>>>,
44    subscription_id: u32,
45    subscription_identifiers: HashMap<u32, String>,
46}
47
48#[derive(Serialize, Deserialize, Debug, Clone)]
49#[serde(tag = "type")]
50#[serde(rename_all = "camelCase")]
51pub enum Subscription {
52    AllMids,
53    Trades { coin: String },
54    L2Book { coin: String },
55    UserEvents { user: H160 },
56    UserFills { user: H160 },
57    Candle { coin: String, interval: String },
58    OrderUpdates { user: H160 },
59    UserFundings { user: H160 },
60    UserNonFundingLedgerUpdates { user: H160 },
61    Notification { user: H160 },
62}
63
64#[derive(Deserialize, Clone, Debug)]
65#[serde(tag = "channel")]
66#[serde(rename_all = "camelCase")]
67pub enum Message {
68    NoData,
69    HyperliquidError(String),
70    AllMids(AllMids),
71    Trades(Trades),
72    L2Book(L2Book),
73    User(User),
74    UserFills(UserFills),
75    Candle(Candle),
76    SubscriptionResponse,
77    OrderUpdates(OrderUpdates),
78    UserFundings(UserFundings),
79    UserNonFundingLedgerUpdates(UserNonFundingLedgerUpdates),
80    Notification(Notification),
81    Pong,
82}
83
84#[derive(Serialize)]
85pub struct SubscriptionSendData<'a> {
86    pub method: &'static str,
87    pub subscription: &'a serde_json::Value,
88}
89
90#[derive(Serialize)]
91pub(crate) struct Ping {
92    pub(crate) method: &'static str,
93}
94
95impl WsManager {
96    const SEND_PING_INTERVAL: u64 = 50;
97
98    pub(crate) async fn new(url: String) -> Result<WsManager> {
99        let stop_flag = Arc::new(AtomicBool::new(false));
100
101        let (ws_stream, _) = connect_async(url.clone())
102            .await
103            .map_err(|e| Error::Websocket(e.to_string()))?;
104
105        let (writer, mut reader) = ws_stream.split();
106        let writer = Arc::new(Mutex::new(writer));
107
108        let subscriptions_map: HashMap<String, Vec<SubscriptionData>> = HashMap::new();
109        let subscriptions = Arc::new(Mutex::new(subscriptions_map));
110        let subscriptions_copy = Arc::clone(&subscriptions);
111
112        let reader_handle = {
113            let stop_flag = Arc::clone(&stop_flag);
114            let reader_fut = async move {
115                // TODO: reconnect
116                while !stop_flag.load(Ordering::Relaxed) {
117                    let data = reader.next().await;
118                    if let Err(err) =
119                        WsManager::parse_and_send_data(data, &subscriptions_copy).await
120                    {
121                        error!("Error processing data received by WS manager reader: {err}");
122                    }
123                }
124                warn!("ws message reader task stopped");
125            };
126            spawn(reader_fut)
127        };
128
129        let ping_handle = {
130            let stop_flag = Arc::clone(&stop_flag);
131            let writer = Arc::clone(&writer);
132            let ping_fut = async move {
133                while !stop_flag.load(Ordering::Relaxed) {
134                    match serde_json::to_string(&Ping { method: "ping" }) {
135                        Ok(payload) => {
136                            let mut writer = writer.lock().await;
137                            if let Err(err) = writer.send(protocol::Message::Text(payload)).await {
138                                error!("Error pinging server: {err}")
139                            }
140                        }
141                        Err(err) => error!("Error serializing ping message: {err}"),
142                    }
143                    time::sleep(Duration::from_secs(Self::SEND_PING_INTERVAL)).await;
144                }
145                warn!("ws ping task stopped");
146            };
147            spawn(ping_fut)
148        };
149
150        Ok(WsManager {
151            stop_flag,
152            reader_handle: Some(reader_handle),
153            ping_handle: Some(ping_handle),
154            writer,
155            subscriptions,
156            subscription_id: 0,
157            subscription_identifiers: HashMap::new(),
158        })
159    }
160
161    pub(crate) fn get_identifier(message: &Message) -> Result<String> {
162        match message {
163            Message::AllMids(_) => serde_json::to_string(&Subscription::AllMids)
164                .map_err(|e| Error::JsonParse(e.to_string())),
165            Message::User(_) => Ok("userEvents".to_string()),
166            Message::UserFills(_) => Ok("userFills".to_string()),
167            Message::Trades(trades) => {
168                if trades.data.is_empty() {
169                    Ok(String::default())
170                } else {
171                    serde_json::to_string(&Subscription::Trades {
172                        coin: trades.data[0].coin.clone(),
173                    })
174                    .map_err(|e| Error::JsonParse(e.to_string()))
175                }
176            }
177            Message::L2Book(l2_book) => serde_json::to_string(&Subscription::L2Book {
178                coin: l2_book.data.coin.clone(),
179            })
180            .map_err(|e| Error::JsonParse(e.to_string())),
181            Message::Candle(candle) => serde_json::to_string(&Subscription::Candle {
182                coin: candle.data.coin.clone(),
183                interval: candle.data.interval.clone(),
184            })
185            .map_err(|e| Error::JsonParse(e.to_string())),
186            Message::OrderUpdates(_) => Ok("orderUpdates".to_string()),
187            Message::UserFundings(_) => Ok("userFundings".to_string()),
188            Message::UserNonFundingLedgerUpdates(user_non_funding_ledger_updates) => {
189                serde_json::to_string(&Subscription::UserNonFundingLedgerUpdates {
190                    user: user_non_funding_ledger_updates.data.user,
191                })
192                .map_err(|e| Error::JsonParse(e.to_string()))
193            }
194            Message::Notification(_) => Ok("notification".to_string()),
195            Message::SubscriptionResponse | Message::Pong => Ok(String::default()),
196            Message::NoData => Ok("".to_string()),
197            Message::HyperliquidError(err) => Ok(format!("hyperliquid error: {err:?}")),
198        }
199    }
200
201    async fn parse_and_send_data(
202        data: Option<std::result::Result<protocol::Message, tungstenite::Error>>,
203        subscriptions: &Arc<Mutex<HashMap<String, Vec<SubscriptionData>>>>,
204    ) -> Result<()> {
205        let Some(data) = data else {
206            return WsManager::send_to_all_subscriptions(subscriptions, Message::NoData).await;
207        };
208
209        match data {
210            Ok(data) => match data.into_text() {
211                Ok(data) => {
212                    if !data.starts_with('{') {
213                        return Ok(());
214                    }
215                    let message = serde_json::from_str::<Message>(&data)
216                        .map_err(|e| Error::JsonParse(e.to_string()))?;
217                    let identifier = WsManager::get_identifier(&message)?;
218                    if identifier.is_empty() {
219                        return Ok(());
220                    }
221
222                    let mut subscriptions = subscriptions.lock().await;
223                    let mut res = Ok(());
224                    if let Some(subscription_datas) = subscriptions.get_mut(&identifier) {
225                        for subscription_data in subscription_datas {
226                            if let Err(e) = subscription_data
227                                .sending_channel
228                                .send(message.clone())
229                                .map_err(|e| Error::WsSend(e.to_string()))
230                            {
231                                res = Err(e);
232                            }
233                        }
234                    }
235                    res
236                }
237                Err(err) => {
238                    let error = Error::ReaderTextConversion(err.to_string());
239                    Ok(WsManager::send_to_all_subscriptions(
240                        subscriptions,
241                        Message::HyperliquidError(error.to_string()),
242                    )
243                    .await?)
244                }
245            },
246            Err(err) => {
247                let error = Error::GenericReader(err.to_string());
248                Ok(WsManager::send_to_all_subscriptions(
249                    subscriptions,
250                    Message::HyperliquidError(error.to_string()),
251                )
252                .await?)
253            }
254        }
255    }
256
257    async fn send_to_all_subscriptions(
258        subscriptions: &Arc<Mutex<HashMap<String, Vec<SubscriptionData>>>>,
259        message: Message,
260    ) -> Result<()> {
261        let mut subscriptions = subscriptions.lock().await;
262        let mut res = Ok(());
263        for subscription_datas in subscriptions.values_mut() {
264            for subscription_data in subscription_datas {
265                if let Err(e) = subscription_data
266                    .sending_channel
267                    .send(message.clone())
268                    .map_err(|e| Error::WsSend(e.to_string()))
269                {
270                    res = Err(e);
271                }
272            }
273        }
274        res
275    }
276
277    pub(crate) async fn add_subscription(
278        &mut self,
279        identifier: String,
280        sending_channel: UnboundedSender<Message>,
281    ) -> Result<u32> {
282        let mut subscriptions = self.subscriptions.lock().await;
283
284        let identifier_entry = if let Subscription::UserEvents { user: _ } =
285            serde_json::from_str::<Subscription>(&identifier)
286                .map_err(|e| Error::JsonParse(e.to_string()))?
287        {
288            "userEvents".to_string()
289        } else if let Subscription::OrderUpdates { user: _ } =
290            serde_json::from_str::<Subscription>(&identifier)
291                .map_err(|e| Error::JsonParse(e.to_string()))?
292        {
293            "orderUpdates".to_string()
294        } else {
295            identifier.clone()
296        };
297        let subscriptions = subscriptions
298            .entry(identifier_entry.clone())
299            .or_insert(Vec::new());
300
301        if !subscriptions.is_empty() && identifier_entry.eq("userEvents") {
302            return Err(Error::UserEvents);
303        }
304
305        if subscriptions.is_empty() {
306            let payload = serde_json::to_string(&SubscriptionSendData {
307                method: "subscribe",
308                subscription: &serde_json::from_str::<serde_json::Value>(&identifier)
309                    .map_err(|e| Error::JsonParse(e.to_string()))?,
310            })
311            .map_err(|e| Error::JsonParse(e.to_string()))?;
312
313            let mut writer = self.writer.lock().await;
314            writer
315                .send(protocol::Message::Text(payload))
316                .await
317                .map_err(|e| Error::Websocket(e.to_string()))?;
318        }
319
320        let subscription_id = self.subscription_id;
321        self.subscription_identifiers
322            .insert(subscription_id, identifier.clone());
323        subscriptions.push(SubscriptionData {
324            sending_channel,
325            subscription_id,
326        });
327
328        self.subscription_id += 1;
329        Ok(subscription_id)
330    }
331
332    pub(crate) async fn remove_subscription(&mut self, subscription_id: u32) -> Result<()> {
333        let identifier = self
334            .subscription_identifiers
335            .get(&subscription_id)
336            .ok_or(Error::SubscriptionNotFound)?
337            .clone();
338
339        let identifier_entry = if let Subscription::UserEvents { user: _ } =
340            serde_json::from_str::<Subscription>(&identifier)
341                .map_err(|e| Error::JsonParse(e.to_string()))?
342        {
343            "userEvents".to_string()
344        } else if let Subscription::OrderUpdates { user: _ } =
345            serde_json::from_str::<Subscription>(&identifier)
346                .map_err(|e| Error::JsonParse(e.to_string()))?
347        {
348            "orderUpdates".to_string()
349        } else {
350            identifier.clone()
351        };
352
353        self.subscription_identifiers.remove(&subscription_id);
354
355        let mut subscriptions = self.subscriptions.lock().await;
356
357        let subscriptions = subscriptions
358            .get_mut(&identifier_entry)
359            .ok_or(Error::SubscriptionNotFound)?;
360        let index = subscriptions
361            .iter()
362            .position(|subscription_data| subscription_data.subscription_id == subscription_id)
363            .ok_or(Error::SubscriptionNotFound)?;
364        subscriptions.remove(index);
365
366        if subscriptions.is_empty() {
367            let payload = serde_json::to_string(&SubscriptionSendData {
368                method: "unsubscribe",
369                subscription: &serde_json::from_str::<serde_json::Value>(&identifier)
370                    .map_err(|e| Error::JsonParse(e.to_string()))?,
371            })
372            .map_err(|e| Error::JsonParse(e.to_string()))?;
373
374            let mut writer = self.writer.lock().await;
375            writer
376                .send(protocol::Message::Text(payload))
377                .await
378                .map_err(|e| Error::Websocket(e.to_string()))?;
379        }
380        Ok(())
381    }
382}
383
384impl Drop for WsManager {
385    fn drop(&mut self) {
386        self.stop_flag.store(true, Ordering::Relaxed);
387
388        let rt = Runtime::new().unwrap();
389
390        if let Some(task) = self.reader_handle.take() {
391            rt.block_on(task).unwrap();
392        }
393
394        if let Some(task) = self.ping_handle.take() {
395            rt.block_on(task).unwrap();
396        }
397    }
398}