hyperliquid_rust_sdk_abrkn/ws/robust/stream.rs

1use crate::{ws::ws_manager::Ping, BaseUrl, Message};
2use anyhow::{anyhow, Context, Result};
3use futures_util::{
4    stream::{SplitSink, SplitStream},
5    SinkExt, StreamExt,
6};
7use log::{debug, trace};
8use serde::Serialize;
9use std::{sync::Arc, time::Duration};
10use tokio::{
11    net::TcpStream,
12    spawn,
13    sync::{mpsc, Mutex},
14    task::JoinHandle,
15    time::{interval, interval_at, Instant},
16};
17use tokio_tungstenite::{connect_async, tungstenite::protocol, MaybeTlsStream, WebSocketStream};
18
19type Socket = WebSocketStream<MaybeTlsStream<TcpStream>>;
20type Writer = SplitSink<Socket, protocol::Message>;
21type Reader = SplitStream<Socket>;
22
/// Period between the application-level `Ping { method: "ping" }` commands written by `stream`.
const PING_INTERVAL: Duration = Duration::from_secs(50);
/// How long `stream` waits for a `Message::Pong` before failing with "Pong timeout".
///
/// Must exceed [`PING_INTERVAL`]; otherwise the deadline could expire before the next ping
/// has even been sent.
const PONG_TIMEOUT: Duration = Duration::from_secs(60);
25
26pub async fn connect(base_url: &BaseUrl) -> Result<Socket> {
27    let url = format!("ws{}/ws", &BaseUrl::get_url(base_url)[4..]);
28
29    let (socket, _response) = connect_async(url).await.context("Failed to connect")?;
30
31    Ok(socket)
32}
33
34pub async fn send<C: Serialize>(writer: &mut Writer, command: C) -> Result<()> {
35    let serialized = serde_json::to_string(&command).context("Failed to serialize command")?;
36
37    trace!("--> {:?}", &serialized);
38
39    writer
40        .send(protocol::Message::Text(serialized))
41        .await
42        .context("Failed to send command")?;
43
44    Ok(())
45}
46
47// NOTE: Unknown message types are returned as None
48fn parse_message(message: protocol::Message) -> Result<Option<Message>> {
49    match message {
50        protocol::Message::Text(text) => {
51            trace!("<-- {:?}", &text);
52
53            let message = serde_json::from_str::<serde_json::Value>(&text)?;
54
55            match serde_json::from_value::<Message>(message) {
56                Ok(message) => Ok(Some(message)),
57                Err(e) => {
58                    debug!("Unhandled message: {}", e);
59
60                    Ok(None)
61                }
62            }
63        }
64        _ => Err(anyhow!("Unhandled message type: {:?}", message)),
65    }
66}
67
68pub async fn stream(
69    mut reader: Reader,
70    writer: Arc<Mutex<Writer>>,
71    tx: mpsc::Sender<Message>,
72    mut cancel_rx: mpsc::Receiver<()>,
73) -> Result<()> {
74    let mut ping_interval = interval(PING_INTERVAL);
75
76    let mut pong_interval = interval_at(Instant::now() + PONG_TIMEOUT, PONG_TIMEOUT);
77
78    loop {
79        tokio::select! {
80            message = reader.next() => match message {
81                None => {
82                    trace!("Reader stream ended");
83                    break Ok(());
84                },
85                Some(message) => match message {
86                    Err(e) => break Err(e.into()),
87                    Ok(message) => {
88                        let message = parse_message(message)?;
89
90                        if let Some(message) = message {
91                            if let Message::Pong = message {
92                                trace!("Pong received. Interval reset");
93
94                                pong_interval = interval_at(
95                                    Instant::now() + PONG_TIMEOUT,
96                                    PONG_TIMEOUT,
97                                );
98                            }
99
100                            tx.send(message).await.context("Failed to send message")?;
101                        }
102                    }
103                }
104            },
105            _ = ping_interval.tick() => {
106                send(&mut *writer.lock().await, Ping { method: "ping" }).await?;
107            },
108            // Handle pong timeout
109            _ = pong_interval.tick() => {
110                return Err(anyhow!("Pong timeout"));
111            },
112            _ = cancel_rx.recv() => {
113                trace!("Received cancel signal");
114                break Ok(());
115            }
116        }
117    }
118}
119
120pub async fn connect_and_stream(
121    base_url: &BaseUrl,
122    inbox_tx: mpsc::Sender<Message>,
123    mut outbox_rx: mpsc::Receiver<serde_json::Value>,
124    cancel_rx: mpsc::Receiver<()>,
125) -> Result<()> {
126    let socket = connect(base_url).await?;
127
128    let (writer, reader) = socket.split();
129    let writer = Arc::new(Mutex::new(writer));
130
131    tokio::select! {
132        result = stream(reader, writer.clone(), inbox_tx, cancel_rx) => result,
133        result = async {
134            while let Some(message) = outbox_rx.recv().await {
135                send(&mut *writer.lock().await, message).await?;
136            }
137
138            Ok(())
139        } =>
140            result
141
142    }
143}
144
145pub struct Stream {
146    pub outbox_tx: mpsc::Sender<serde_json::Value>,
147    cancel_tx: mpsc::Sender<()>,
148}
149
150impl Drop for Stream {
151    fn drop(&mut self) {
152        let cancel_tx = self.cancel_tx.clone();
153
154        spawn(async move {
155            let _ = cancel_tx.send(()).await;
156        });
157    }
158}
159
160impl Stream {
161    pub fn connect(
162        base_url: &BaseUrl,
163        inbox_tx: mpsc::Sender<Message>,
164    ) -> (Self, JoinHandle<Result<()>>) {
165        let (outbox_tx, outbox_rx) = mpsc::channel(100);
166        let (cancel_tx, cancel_rx) = mpsc::channel(1);
167
168        let handle = spawn({
169            let base_url = *base_url;
170
171            async move { connect_and_stream(&base_url, inbox_tx, outbox_rx, cancel_rx).await }
172        });
173
174        (
175            Self {
176                outbox_tx,
177                cancel_tx,
178            },
179            handle,
180        )
181    }
182
183    pub async fn send(&self, message: serde_json::Value) -> Result<()> {
184        self.outbox_tx
185            .send(message)
186            .await
187            .context("Failed to send message")
188    }
189
190    pub async fn cancel(&self) {
191        let _ = self.cancel_tx.send(()).await;
192    }
193}