hyperliquid_rust_sdk_abrkn/ws/
ws_manager.rs1use crate::{
2 prelude::*,
3 ws::message_types::{AllMids, Candle, L2Book, OrderUpdates, Trades, User},
4 Error, Notification, UserFills, UserFundings, UserNonFundingLedgerUpdates,
5};
6use futures_util::{stream::SplitSink, SinkExt, StreamExt};
7use log::{error, warn};
8use serde::{Deserialize, Serialize};
9use std::{
10 collections::HashMap,
11 sync::{
12 atomic::{AtomicBool, Ordering},
13 Arc,
14 },
15 time::Duration,
16};
17use tokio::{
18 net::TcpStream,
19 runtime::Runtime,
20 spawn,
21 sync::{mpsc::UnboundedSender, Mutex},
22 task::JoinHandle,
23 time,
24};
25use tokio_tungstenite::{
26 connect_async,
27 tungstenite::{self, protocol},
28 MaybeTlsStream, WebSocketStream,
29};
30
31use ethers::types::H160;
32
33#[derive(Debug)]
34struct SubscriptionData {
35 sending_channel: UnboundedSender<Message>,
36 subscription_id: u32,
37}
38pub(crate) struct WsManager {
39 stop_flag: Arc<AtomicBool>,
40 reader_handle: Option<JoinHandle<()>>,
41 ping_handle: Option<JoinHandle<()>>,
42 writer: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, protocol::Message>>>,
43 subscriptions: Arc<Mutex<HashMap<String, Vec<SubscriptionData>>>>,
44 subscription_id: u32,
45 subscription_identifiers: HashMap<u32, String>,
46}
47
48#[derive(Serialize, Deserialize, Debug, Clone)]
49#[serde(tag = "type")]
50#[serde(rename_all = "camelCase")]
51pub enum Subscription {
52 AllMids,
53 Trades { coin: String },
54 L2Book { coin: String },
55 UserEvents { user: H160 },
56 UserFills { user: H160 },
57 Candle { coin: String, interval: String },
58 OrderUpdates { user: H160 },
59 UserFundings { user: H160 },
60 UserNonFundingLedgerUpdates { user: H160 },
61 Notification { user: H160 },
62}
63
64#[derive(Deserialize, Clone, Debug)]
65#[serde(tag = "channel")]
66#[serde(rename_all = "camelCase")]
67pub enum Message {
68 NoData,
69 HyperliquidError(String),
70 AllMids(AllMids),
71 Trades(Trades),
72 L2Book(L2Book),
73 User(User),
74 UserFills(UserFills),
75 Candle(Candle),
76 SubscriptionResponse,
77 OrderUpdates(OrderUpdates),
78 UserFundings(UserFundings),
79 UserNonFundingLedgerUpdates(UserNonFundingLedgerUpdates),
80 Notification(Notification),
81 Pong,
82}
83
84#[derive(Serialize)]
85pub struct SubscriptionSendData<'a> {
86 pub method: &'static str,
87 pub subscription: &'a serde_json::Value,
88}
89
90#[derive(Serialize)]
91pub(crate) struct Ping {
92 pub(crate) method: &'static str,
93}
94
95impl WsManager {
96 const SEND_PING_INTERVAL: u64 = 50;
97
98 pub(crate) async fn new(url: String) -> Result<WsManager> {
99 let stop_flag = Arc::new(AtomicBool::new(false));
100
101 let (ws_stream, _) = connect_async(url.clone())
102 .await
103 .map_err(|e| Error::Websocket(e.to_string()))?;
104
105 let (writer, mut reader) = ws_stream.split();
106 let writer = Arc::new(Mutex::new(writer));
107
108 let subscriptions_map: HashMap<String, Vec<SubscriptionData>> = HashMap::new();
109 let subscriptions = Arc::new(Mutex::new(subscriptions_map));
110 let subscriptions_copy = Arc::clone(&subscriptions);
111
112 let reader_handle = {
113 let stop_flag = Arc::clone(&stop_flag);
114 let reader_fut = async move {
115 while !stop_flag.load(Ordering::Relaxed) {
117 let data = reader.next().await;
118 if let Err(err) =
119 WsManager::parse_and_send_data(data, &subscriptions_copy).await
120 {
121 error!("Error processing data received by WS manager reader: {err}");
122 }
123 }
124 warn!("ws message reader task stopped");
125 };
126 spawn(reader_fut)
127 };
128
129 let ping_handle = {
130 let stop_flag = Arc::clone(&stop_flag);
131 let writer = Arc::clone(&writer);
132 let ping_fut = async move {
133 while !stop_flag.load(Ordering::Relaxed) {
134 match serde_json::to_string(&Ping { method: "ping" }) {
135 Ok(payload) => {
136 let mut writer = writer.lock().await;
137 if let Err(err) = writer.send(protocol::Message::Text(payload)).await {
138 error!("Error pinging server: {err}")
139 }
140 }
141 Err(err) => error!("Error serializing ping message: {err}"),
142 }
143 time::sleep(Duration::from_secs(Self::SEND_PING_INTERVAL)).await;
144 }
145 warn!("ws ping task stopped");
146 };
147 spawn(ping_fut)
148 };
149
150 Ok(WsManager {
151 stop_flag,
152 reader_handle: Some(reader_handle),
153 ping_handle: Some(ping_handle),
154 writer,
155 subscriptions,
156 subscription_id: 0,
157 subscription_identifiers: HashMap::new(),
158 })
159 }
160
161 pub(crate) fn get_identifier(message: &Message) -> Result<String> {
162 match message {
163 Message::AllMids(_) => serde_json::to_string(&Subscription::AllMids)
164 .map_err(|e| Error::JsonParse(e.to_string())),
165 Message::User(_) => Ok("userEvents".to_string()),
166 Message::UserFills(_) => Ok("userFills".to_string()),
167 Message::Trades(trades) => {
168 if trades.data.is_empty() {
169 Ok(String::default())
170 } else {
171 serde_json::to_string(&Subscription::Trades {
172 coin: trades.data[0].coin.clone(),
173 })
174 .map_err(|e| Error::JsonParse(e.to_string()))
175 }
176 }
177 Message::L2Book(l2_book) => serde_json::to_string(&Subscription::L2Book {
178 coin: l2_book.data.coin.clone(),
179 })
180 .map_err(|e| Error::JsonParse(e.to_string())),
181 Message::Candle(candle) => serde_json::to_string(&Subscription::Candle {
182 coin: candle.data.coin.clone(),
183 interval: candle.data.interval.clone(),
184 })
185 .map_err(|e| Error::JsonParse(e.to_string())),
186 Message::OrderUpdates(_) => Ok("orderUpdates".to_string()),
187 Message::UserFundings(_) => Ok("userFundings".to_string()),
188 Message::UserNonFundingLedgerUpdates(user_non_funding_ledger_updates) => {
189 serde_json::to_string(&Subscription::UserNonFundingLedgerUpdates {
190 user: user_non_funding_ledger_updates.data.user,
191 })
192 .map_err(|e| Error::JsonParse(e.to_string()))
193 }
194 Message::Notification(_) => Ok("notification".to_string()),
195 Message::SubscriptionResponse | Message::Pong => Ok(String::default()),
196 Message::NoData => Ok("".to_string()),
197 Message::HyperliquidError(err) => Ok(format!("hyperliquid error: {err:?}")),
198 }
199 }
200
201 async fn parse_and_send_data(
202 data: Option<std::result::Result<protocol::Message, tungstenite::Error>>,
203 subscriptions: &Arc<Mutex<HashMap<String, Vec<SubscriptionData>>>>,
204 ) -> Result<()> {
205 let Some(data) = data else {
206 return WsManager::send_to_all_subscriptions(subscriptions, Message::NoData).await;
207 };
208
209 match data {
210 Ok(data) => match data.into_text() {
211 Ok(data) => {
212 if !data.starts_with('{') {
213 return Ok(());
214 }
215 let message = serde_json::from_str::<Message>(&data)
216 .map_err(|e| Error::JsonParse(e.to_string()))?;
217 let identifier = WsManager::get_identifier(&message)?;
218 if identifier.is_empty() {
219 return Ok(());
220 }
221
222 let mut subscriptions = subscriptions.lock().await;
223 let mut res = Ok(());
224 if let Some(subscription_datas) = subscriptions.get_mut(&identifier) {
225 for subscription_data in subscription_datas {
226 if let Err(e) = subscription_data
227 .sending_channel
228 .send(message.clone())
229 .map_err(|e| Error::WsSend(e.to_string()))
230 {
231 res = Err(e);
232 }
233 }
234 }
235 res
236 }
237 Err(err) => {
238 let error = Error::ReaderTextConversion(err.to_string());
239 Ok(WsManager::send_to_all_subscriptions(
240 subscriptions,
241 Message::HyperliquidError(error.to_string()),
242 )
243 .await?)
244 }
245 },
246 Err(err) => {
247 let error = Error::GenericReader(err.to_string());
248 Ok(WsManager::send_to_all_subscriptions(
249 subscriptions,
250 Message::HyperliquidError(error.to_string()),
251 )
252 .await?)
253 }
254 }
255 }
256
257 async fn send_to_all_subscriptions(
258 subscriptions: &Arc<Mutex<HashMap<String, Vec<SubscriptionData>>>>,
259 message: Message,
260 ) -> Result<()> {
261 let mut subscriptions = subscriptions.lock().await;
262 let mut res = Ok(());
263 for subscription_datas in subscriptions.values_mut() {
264 for subscription_data in subscription_datas {
265 if let Err(e) = subscription_data
266 .sending_channel
267 .send(message.clone())
268 .map_err(|e| Error::WsSend(e.to_string()))
269 {
270 res = Err(e);
271 }
272 }
273 }
274 res
275 }
276
277 pub(crate) async fn add_subscription(
278 &mut self,
279 identifier: String,
280 sending_channel: UnboundedSender<Message>,
281 ) -> Result<u32> {
282 let mut subscriptions = self.subscriptions.lock().await;
283
284 let identifier_entry = if let Subscription::UserEvents { user: _ } =
285 serde_json::from_str::<Subscription>(&identifier)
286 .map_err(|e| Error::JsonParse(e.to_string()))?
287 {
288 "userEvents".to_string()
289 } else if let Subscription::OrderUpdates { user: _ } =
290 serde_json::from_str::<Subscription>(&identifier)
291 .map_err(|e| Error::JsonParse(e.to_string()))?
292 {
293 "orderUpdates".to_string()
294 } else {
295 identifier.clone()
296 };
297 let subscriptions = subscriptions
298 .entry(identifier_entry.clone())
299 .or_insert(Vec::new());
300
301 if !subscriptions.is_empty() && identifier_entry.eq("userEvents") {
302 return Err(Error::UserEvents);
303 }
304
305 if subscriptions.is_empty() {
306 let payload = serde_json::to_string(&SubscriptionSendData {
307 method: "subscribe",
308 subscription: &serde_json::from_str::<serde_json::Value>(&identifier)
309 .map_err(|e| Error::JsonParse(e.to_string()))?,
310 })
311 .map_err(|e| Error::JsonParse(e.to_string()))?;
312
313 let mut writer = self.writer.lock().await;
314 writer
315 .send(protocol::Message::Text(payload))
316 .await
317 .map_err(|e| Error::Websocket(e.to_string()))?;
318 }
319
320 let subscription_id = self.subscription_id;
321 self.subscription_identifiers
322 .insert(subscription_id, identifier.clone());
323 subscriptions.push(SubscriptionData {
324 sending_channel,
325 subscription_id,
326 });
327
328 self.subscription_id += 1;
329 Ok(subscription_id)
330 }
331
332 pub(crate) async fn remove_subscription(&mut self, subscription_id: u32) -> Result<()> {
333 let identifier = self
334 .subscription_identifiers
335 .get(&subscription_id)
336 .ok_or(Error::SubscriptionNotFound)?
337 .clone();
338
339 let identifier_entry = if let Subscription::UserEvents { user: _ } =
340 serde_json::from_str::<Subscription>(&identifier)
341 .map_err(|e| Error::JsonParse(e.to_string()))?
342 {
343 "userEvents".to_string()
344 } else if let Subscription::OrderUpdates { user: _ } =
345 serde_json::from_str::<Subscription>(&identifier)
346 .map_err(|e| Error::JsonParse(e.to_string()))?
347 {
348 "orderUpdates".to_string()
349 } else {
350 identifier.clone()
351 };
352
353 self.subscription_identifiers.remove(&subscription_id);
354
355 let mut subscriptions = self.subscriptions.lock().await;
356
357 let subscriptions = subscriptions
358 .get_mut(&identifier_entry)
359 .ok_or(Error::SubscriptionNotFound)?;
360 let index = subscriptions
361 .iter()
362 .position(|subscription_data| subscription_data.subscription_id == subscription_id)
363 .ok_or(Error::SubscriptionNotFound)?;
364 subscriptions.remove(index);
365
366 if subscriptions.is_empty() {
367 let payload = serde_json::to_string(&SubscriptionSendData {
368 method: "unsubscribe",
369 subscription: &serde_json::from_str::<serde_json::Value>(&identifier)
370 .map_err(|e| Error::JsonParse(e.to_string()))?,
371 })
372 .map_err(|e| Error::JsonParse(e.to_string()))?;
373
374 let mut writer = self.writer.lock().await;
375 writer
376 .send(protocol::Message::Text(payload))
377 .await
378 .map_err(|e| Error::Websocket(e.to_string()))?;
379 }
380 Ok(())
381 }
382}
383
384impl Drop for WsManager {
385 fn drop(&mut self) {
386 self.stop_flag.store(true, Ordering::Relaxed);
387
388 let rt = Runtime::new().unwrap();
389
390 if let Some(task) = self.reader_handle.take() {
391 rt.block_on(task).unwrap();
392 }
393
394 if let Some(task) = self.ping_handle.take() {
395 rt.block_on(task).unwrap();
396 }
397 }
398}