// hyperliquid_rust_sdk_abrkn/ws/robust/subs.rs

1use super::stream::Stream;
2use crate::{BaseUrl, Message, Subscription, SubscriptionSendData};
3use anyhow::Result;
4use log::{debug, error, trace};
5use serde::Serialize;
6use std::sync::{atomic::AtomicU32, Arc};
7use tokio::{
8    spawn,
9    sync::{mpsc, oneshot, RwLock},
10    task::JoinHandle,
11};
12
/// The websocket subscription type from `ws_manager`, used as the topic of a sub.
type Topic = super::super::ws_manager::Subscription;
14
// NOTE: Subs are released through `Token`: its `Drop` impl sends a
// `Command::Unsubscribe` over the command channel so the subs manager removes the sub.
/// Identifier the manager assigns to each registered sub.
pub type SubId = u32;
18
/// A single registered subscriber.
pub struct Sub {
    /// Unique id allocated from `State::id_counter`.
    pub id: SubId,
    /// Key used to route incoming messages to this sub. Computed by
    /// `get_topic_key_for_subscription` and compared in `run` against the identifier
    /// returned by `WsManager::get_identifier`.
    pub topic_key: String,
    /// The original subscription. It is sent in the `unsubscribe` request when the
    /// last sub with this `topic_key` is removed.
    pub topic: Topic,
    /// Receives a clone of every incoming message whose identifier matches `topic_key`.
    pub tx: mpsc::UnboundedSender<Message>,
}
25
/// Payload of an `unsubscribe` request pushed onto the stream's outbox.
#[derive(Serialize, Debug)]
pub struct Unsubscribe {
    /// Request method; set to `"unsubscribe"` where this struct is built in `remove`.
    pub method: String,
    /// The subscription being cancelled.
    pub subscription: Topic,
}
31
/// Requests sent from `Subs` / `Token` to the `run` loop over the command channel.
enum Command {
    /// Register `tx` for `subscription`. The new sub's id is sent back on `reply_tx`.
    Subscribe {
        subscription: Subscription,
        tx: mpsc::UnboundedSender<Message>,
        reply_tx: oneshot::Sender<SubId>,
    },
    /// Remove the sub with the given id.
    Unsubscribe(SubId),
}
40
/// Registry of subs, created by and used from the `run` loop.
#[derive(Clone)]
pub struct State {
    /// Source of `SubId`s; incremented once per call to `add`.
    id_counter: Arc<AtomicU32>,
    /// Currently registered subs, in insertion order.
    subs: Arc<RwLock<Vec<Sub>>>,
}
46
47fn get_topic_key_for_subscription(topic: &Topic) -> String {
48    match topic {
49        Subscription::UserEvents { user: _ } => "userEvents".to_string(),
50        Subscription::OrderUpdates { user: _ } => "orderUpdates".to_string(),
51        Subscription::UserFills { user: _ } => "userFills".to_string(),
52        _ => serde_json::to_string(topic).expect("Failed to convert subscription to string"),
53    }
54}
55
56async fn run(
57    outbox_tx: mpsc::Sender<serde_json::Value>,
58    mut inbox_rx: mpsc::Receiver<Message>,
59    mut command_rx: mpsc::Receiver<Command>,
60) -> Result<()> {
61    let state = State {
62        subs: Arc::new(RwLock::new(Vec::new())),
63        id_counter: Arc::new(AtomicU32::new(0)),
64    };
65
66    loop {
67        tokio::select! {
68            message = inbox_rx.recv() => {
69                match message {
70                    Some(message) => {
71                        let topic = super::super::WsManager::get_identifier(&message)?;
72                            debug!("Received message for topic: {}", topic);
73
74                            for sub in
75                                state.subs.read().await
76                                .iter()
77                                .filter(|s| s.topic_key == topic)
78                            {
79                                trace!("Sending message to sub ID={}", sub.id);
80
81                                if let Err(e) = sub.tx.send(message.clone()) {
82                                    error!(
83                                        "Failed to send message for topic {} to sub {}: {}",
84                                        topic, sub.id, e
85                                    );
86                                }
87                        }
88                    }
89                    None => {
90                        trace!("Inbox receiver closed");
91                        break Ok(());
92                    }
93                }
94            },
95            command = command_rx.recv() => {
96                match command {
97                    Some(Command::Subscribe { subscription, tx, reply_tx }) => {
98                        trace!("Received subscribe command for topic: {:?}", &subscription);
99                        let id = add(&state, outbox_tx.clone(), subscription, tx).await?;
100
101                        if let Err(e) = reply_tx.send(id) {
102                            trace!("Failed to send reply for subscribe command: {}", e);
103                        }
104                    },
105                    Some(Command::Unsubscribe(id)) => {
106                        remove(&state, outbox_tx.clone(), id).await?;
107                    },
108                    None => {}
109                }
110            },
111        }
112    }
113}
114
115async fn add(
116    state: &State,
117    outbox_tx: mpsc::Sender<serde_json::Value>,
118    topic: Topic,
119    tx: mpsc::UnboundedSender<Message>,
120) -> Result<SubId> {
121    let id = state
122        .id_counter
123        .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
124
125    let topic_key = get_topic_key_for_subscription(&topic);
126
127    let sub = Sub {
128        id,
129        topic: topic.clone(),
130        topic_key: topic_key.clone(),
131        tx,
132    };
133
134    // NOTE: The mutex is held for the remainder of this function
135    let mut subs = state.subs.write().await;
136
137    debug!("Adding sub with id: {} ({})", id, topic_key);
138
139    if !subs.iter().any(|s| s.topic_key == topic_key) {
140        debug!("First subscription for this topic, sending subscribe command");
141
142        outbox_tx
143            .send(
144                serde_json::to_value(SubscriptionSendData {
145                    method: "subscribe",
146                    subscription: &serde_json::to_value(topic).unwrap(),
147                })
148                .unwrap(),
149            )
150            .await?;
151    }
152
153    subs.push(sub);
154
155    Ok(id)
156}
157
158async fn remove(
159    state: &State,
160    outbox_tx: mpsc::Sender<serde_json::Value>,
161    sub_id: SubId,
162) -> Result<()> {
163    // Locked for the duration of this function
164    let mut subs = state.subs.write().await;
165
166    let (topic, topic_key) = subs
167        .iter()
168        .find(|s| s.id == sub_id)
169        .map(|s| (s.topic.clone(), s.topic_key.clone()))
170        .unwrap();
171
172    debug!("Removing sub with id: {} ({})", sub_id, topic_key);
173
174    subs.retain(|s| s.id != sub_id);
175
176    // Send unsub if no subs have topic_key of token.topic_key anymore
177    if !subs.iter().any(|s| s.topic_key == topic_key) {
178        debug!(
179            "Last subscriber removed. Sending unsubscribe for topic: {}",
180            topic_key
181        );
182
183        outbox_tx
184            .send(
185                serde_json::to_value(Unsubscribe {
186                    method: "unsubscribe".to_string(),
187                    subscription: topic,
188                })
189                .unwrap(),
190            )
191            .await?;
192    }
193
194    Ok(())
195}
196
/// Handle to the subscription manager, created by `Subs::start`.
pub struct Subs {
    /// The websocket stream. The manager loop is given a clone of its `outbox_tx`,
    /// and `cancel` forwards to it.
    stream: Stream,
    /// Sends `Command`s to the `run` loop.
    command_tx: mpsc::Sender<Command>,
}
201
/// Guard returned by `Subs::add` for one registered sub.
///
/// Dropping it sends `Command::Unsubscribe(id)` to the manager.
pub struct Token {
    /// Id of the sub this token owns.
    id: SubId,
    /// Channel used by `Drop` to request the unsubscribe.
    command_tx: mpsc::Sender<Command>,
}
206
207impl Drop for Token {
208    fn drop(&mut self) {
209        let (id, command_tx) = (self.id, self.command_tx.clone());
210
211        trace!("Dropping Token with id: {}", self.id);
212
213        spawn(async move {
214            let _ = command_tx.send(Command::Unsubscribe(id)).await;
215        });
216    }
217}
218
219impl Subs {
220    pub fn start(base_url: &BaseUrl) -> (Self, JoinHandle<Result<()>>) {
221        let (inbox_tx, inbox_rx) = mpsc::channel(100);
222        let (command_tx, command_rx) = mpsc::channel(100);
223
224        let (stream, stream_handle) = Stream::connect(base_url, inbox_tx);
225
226        let run_handle = run(stream.outbox_tx.clone(), inbox_rx, command_rx);
227
228        let handle = spawn(async {
229            tokio::select! {
230                result = stream_handle => result.unwrap(),
231                result = run_handle => result,
232            }
233        });
234
235        (Self { stream, command_tx }, handle)
236    }
237
238    pub async fn add(&self, topic: Topic, tx: mpsc::UnboundedSender<Message>) -> Result<Token> {
239        let (reply_tx, reply_rx) = oneshot::channel();
240
241        self.command_tx
242            .send(Command::Subscribe {
243                subscription: topic,
244                tx,
245                reply_tx,
246            })
247            .await?;
248
249        let id = reply_rx.await.map_err(|e| anyhow::anyhow!(e))?;
250
251        Ok(Token {
252            id,
253            command_tx: self.command_tx.clone(),
254        })
255    }
256
257    pub async fn remove(&self, sub_id: SubId) -> Result<()> {
258        self.command_tx.send(Command::Unsubscribe(sub_id)).await?;
259
260        Ok(())
261    }
262
263    pub async fn cancel(&self) {
264        self.stream.cancel().await
265    }
266}