hyperliquid_rust_sdk_abrkn/ws/robust/
subs.rs1use super::stream::Stream;
2use crate::{BaseUrl, Message, Subscription, SubscriptionSendData};
3use anyhow::Result;
4use log::{debug, error, trace};
5use serde::Serialize;
6use std::sync::{atomic::AtomicU32, Arc};
7use tokio::{
8 spawn,
9 sync::{mpsc, oneshot, RwLock},
10 task::JoinHandle,
11};
12
13type Topic = super::super::ws_manager::Subscription;
14
15pub type SubId = u32;
18
19pub struct Sub {
20 pub id: SubId,
21 pub topic_key: String,
22 pub topic: Topic,
23 pub tx: mpsc::UnboundedSender<Message>,
24}
25
26#[derive(Serialize, Debug)]
27pub struct Unsubscribe {
28 pub method: String,
29 pub subscription: Topic,
30}
31
32enum Command {
33 Subscribe {
34 subscription: Subscription,
35 tx: mpsc::UnboundedSender<Message>,
36 reply_tx: oneshot::Sender<SubId>,
37 },
38 Unsubscribe(SubId),
39}
40
41#[derive(Clone)]
42pub struct State {
43 id_counter: Arc<AtomicU32>,
44 subs: Arc<RwLock<Vec<Sub>>>,
45}
46
47fn get_topic_key_for_subscription(topic: &Topic) -> String {
48 match topic {
49 Subscription::UserEvents { user: _ } => "userEvents".to_string(),
50 Subscription::OrderUpdates { user: _ } => "orderUpdates".to_string(),
51 Subscription::UserFills { user: _ } => "userFills".to_string(),
52 _ => serde_json::to_string(topic).expect("Failed to convert subscription to string"),
53 }
54}
55
56async fn run(
57 outbox_tx: mpsc::Sender<serde_json::Value>,
58 mut inbox_rx: mpsc::Receiver<Message>,
59 mut command_rx: mpsc::Receiver<Command>,
60) -> Result<()> {
61 let state = State {
62 subs: Arc::new(RwLock::new(Vec::new())),
63 id_counter: Arc::new(AtomicU32::new(0)),
64 };
65
66 loop {
67 tokio::select! {
68 message = inbox_rx.recv() => {
69 match message {
70 Some(message) => {
71 let topic = super::super::WsManager::get_identifier(&message)?;
72 debug!("Received message for topic: {}", topic);
73
74 for sub in
75 state.subs.read().await
76 .iter()
77 .filter(|s| s.topic_key == topic)
78 {
79 trace!("Sending message to sub ID={}", sub.id);
80
81 if let Err(e) = sub.tx.send(message.clone()) {
82 error!(
83 "Failed to send message for topic {} to sub {}: {}",
84 topic, sub.id, e
85 );
86 }
87 }
88 }
89 None => {
90 trace!("Inbox receiver closed");
91 break Ok(());
92 }
93 }
94 },
95 command = command_rx.recv() => {
96 match command {
97 Some(Command::Subscribe { subscription, tx, reply_tx }) => {
98 trace!("Received subscribe command for topic: {:?}", &subscription);
99 let id = add(&state, outbox_tx.clone(), subscription, tx).await?;
100
101 if let Err(e) = reply_tx.send(id) {
102 trace!("Failed to send reply for subscribe command: {}", e);
103 }
104 },
105 Some(Command::Unsubscribe(id)) => {
106 remove(&state, outbox_tx.clone(), id).await?;
107 },
108 None => {}
109 }
110 },
111 }
112 }
113}
114
115async fn add(
116 state: &State,
117 outbox_tx: mpsc::Sender<serde_json::Value>,
118 topic: Topic,
119 tx: mpsc::UnboundedSender<Message>,
120) -> Result<SubId> {
121 let id = state
122 .id_counter
123 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
124
125 let topic_key = get_topic_key_for_subscription(&topic);
126
127 let sub = Sub {
128 id,
129 topic: topic.clone(),
130 topic_key: topic_key.clone(),
131 tx,
132 };
133
134 let mut subs = state.subs.write().await;
136
137 debug!("Adding sub with id: {} ({})", id, topic_key);
138
139 if !subs.iter().any(|s| s.topic_key == topic_key) {
140 debug!("First subscription for this topic, sending subscribe command");
141
142 outbox_tx
143 .send(
144 serde_json::to_value(SubscriptionSendData {
145 method: "subscribe",
146 subscription: &serde_json::to_value(topic).unwrap(),
147 })
148 .unwrap(),
149 )
150 .await?;
151 }
152
153 subs.push(sub);
154
155 Ok(id)
156}
157
158async fn remove(
159 state: &State,
160 outbox_tx: mpsc::Sender<serde_json::Value>,
161 sub_id: SubId,
162) -> Result<()> {
163 let mut subs = state.subs.write().await;
165
166 let (topic, topic_key) = subs
167 .iter()
168 .find(|s| s.id == sub_id)
169 .map(|s| (s.topic.clone(), s.topic_key.clone()))
170 .unwrap();
171
172 debug!("Removing sub with id: {} ({})", sub_id, topic_key);
173
174 subs.retain(|s| s.id != sub_id);
175
176 if !subs.iter().any(|s| s.topic_key == topic_key) {
178 debug!(
179 "Last subscriber removed. Sending unsubscribe for topic: {}",
180 topic_key
181 );
182
183 outbox_tx
184 .send(
185 serde_json::to_value(Unsubscribe {
186 method: "unsubscribe".to_string(),
187 subscription: topic,
188 })
189 .unwrap(),
190 )
191 .await?;
192 }
193
194 Ok(())
195}
196
197pub struct Subs {
198 stream: Stream,
199 command_tx: mpsc::Sender<Command>,
200}
201
202pub struct Token {
203 id: SubId,
204 command_tx: mpsc::Sender<Command>,
205}
206
207impl Drop for Token {
208 fn drop(&mut self) {
209 let (id, command_tx) = (self.id, self.command_tx.clone());
210
211 trace!("Dropping Token with id: {}", self.id);
212
213 spawn(async move {
214 let _ = command_tx.send(Command::Unsubscribe(id)).await;
215 });
216 }
217}
218
219impl Subs {
220 pub fn start(base_url: &BaseUrl) -> (Self, JoinHandle<Result<()>>) {
221 let (inbox_tx, inbox_rx) = mpsc::channel(100);
222 let (command_tx, command_rx) = mpsc::channel(100);
223
224 let (stream, stream_handle) = Stream::connect(base_url, inbox_tx);
225
226 let run_handle = run(stream.outbox_tx.clone(), inbox_rx, command_rx);
227
228 let handle = spawn(async {
229 tokio::select! {
230 result = stream_handle => result.unwrap(),
231 result = run_handle => result,
232 }
233 });
234
235 (Self { stream, command_tx }, handle)
236 }
237
238 pub async fn add(&self, topic: Topic, tx: mpsc::UnboundedSender<Message>) -> Result<Token> {
239 let (reply_tx, reply_rx) = oneshot::channel();
240
241 self.command_tx
242 .send(Command::Subscribe {
243 subscription: topic,
244 tx,
245 reply_tx,
246 })
247 .await?;
248
249 let id = reply_rx.await.map_err(|e| anyhow::anyhow!(e))?;
250
251 Ok(Token {
252 id,
253 command_tx: self.command_tx.clone(),
254 })
255 }
256
257 pub async fn remove(&self, sub_id: SubId) -> Result<()> {
258 self.command_tx.send(Command::Unsubscribe(sub_id)).await?;
259
260 Ok(())
261 }
262
263 pub async fn cancel(&self) {
264 self.stream.cancel().await
265 }
266}