polynode/orderbook/
stream.rs1use std::time::Duration;
2use futures_util::{SinkExt, StreamExt};
3use tokio::sync::mpsc;
4use tokio_tungstenite::tungstenite::Message;
5
6use crate::error::{Error, Result};
7use crate::types::orderbook::{ObMessage, OrderbookUpdate, RawObMessage};
8use crate::ws::codec::decode_frame;
9
/// Configuration for an orderbook stream connection.
///
/// The `Default` implementation enables compression and unlimited automatic
/// reconnects, with a backoff starting at 1 s and capped at 30 s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ObStreamOptions {
    /// Request zlib-compressed frames (adds `compress=zlib` to the connection URL).
    pub compress: bool,
    /// Reconnect automatically after a failed connect or a dropped connection.
    pub auto_reconnect: bool,
    /// Maximum number of consecutive reconnect attempts; `None` means unlimited.
    pub max_reconnect_attempts: Option<u32>,
    /// Base delay of the exponential reconnect backoff.
    pub initial_backoff: Duration,
    /// Upper bound on the reconnect delay; the jittered delay never exceeds it.
    pub max_backoff: Duration,
}
19
20impl Default for ObStreamOptions {
21 fn default() -> Self {
22 Self {
23 compress: true,
24 auto_reconnect: true,
25 max_reconnect_attempts: None,
26 initial_backoff: Duration::from_secs(1),
27 max_backoff: Duration::from_secs(30),
28 }
29 }
30}
31
/// Control messages sent from an `ObStream` handle to its background task.
#[derive(Debug)]
enum Command {
    /// Subscribe to the given token ids; the list is remembered for replay after reconnects.
    Subscribe(Vec<String>),
    /// Drop the current subscription.
    Unsubscribe,
    /// Send a close frame (when connected) and stop the background task.
    Close,
}
37
38pub struct ObStream {
40 rx: mpsc::Receiver<Result<ObMessage>>,
41 cmd_tx: mpsc::Sender<Command>,
42 _handle: tokio::task::JoinHandle<()>,
43}
44
45impl ObStream {
46 pub(crate) async fn connect(
47 api_key: &str,
48 ob_url: &str,
49 options: ObStreamOptions,
50 ) -> Result<Self> {
51 let mut url = format!("{}?key={}", ob_url, api_key);
52 if options.compress {
53 url.push_str("&compress=zlib");
54 }
55
56 let (msg_tx, msg_rx) = mpsc::channel(4096);
57 let (cmd_tx, cmd_rx) = mpsc::channel(64);
58
59 let handle = tokio::spawn(ob_task(url, options, msg_tx, cmd_rx));
60
61 Ok(Self {
62 rx: msg_rx,
63 cmd_tx,
64 _handle: handle,
65 })
66 }
67
68 pub async fn next(&mut self) -> Option<Result<ObMessage>> {
70 self.rx.recv().await
71 }
72
73 pub async fn subscribe(&self, token_ids: Vec<String>) -> Result<()> {
75 self.cmd_tx.send(Command::Subscribe(token_ids)).await
76 .map_err(|_| Error::Disconnected)
77 }
78
79 pub async fn unsubscribe(&self) -> Result<()> {
81 self.cmd_tx.send(Command::Unsubscribe).await
82 .map_err(|_| Error::Disconnected)
83 }
84
85 pub async fn close(self) -> Result<()> {
87 let _ = self.cmd_tx.send(Command::Close).await;
88 Ok(())
89 }
90}
91
92async fn ob_task(
93 url: String,
94 options: ObStreamOptions,
95 msg_tx: mpsc::Sender<Result<ObMessage>>,
96 mut cmd_rx: mpsc::Receiver<Command>,
97) {
98 let mut last_token_ids: Vec<String> = Vec::new();
99 let mut reconnect_attempts: u32 = 0;
100
101 'outer: loop {
102 let ws_stream = match tokio_tungstenite::connect_async(&url).await {
103 Ok((stream, _)) => {
104 reconnect_attempts = 0;
105 stream
106 }
107 Err(e) => {
108 let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
109 if !should_reconnect(&options, reconnect_attempts) {
110 break;
111 }
112 let delay = backoff_delay(&options, reconnect_attempts);
113 reconnect_attempts += 1;
114 tokio::time::sleep(delay).await;
115 continue;
116 }
117 };
118
119 let (mut write, mut read) = ws_stream.split();
120
121 if !last_token_ids.is_empty() {
123 let msg = serde_json::json!({
124 "action": "subscribe",
125 "markets": last_token_ids
126 });
127 let msg_text = serde_json::to_string(&msg).unwrap();
128 if write.send(Message::Text(msg_text.into())).await.is_err() {
129 continue 'outer;
130 }
131 }
132
133 loop {
134 tokio::select! {
135 frame = read.next() => {
136 match frame {
137 Some(Ok(msg)) => {
138 match decode_frame(msg) {
139 Ok(Some(text)) => {
140 let messages = parse_ob_message(&text);
141 for m in messages {
142 if msg_tx.send(Ok(m)).await.is_err() {
143 break 'outer;
144 }
145 }
146 }
147 Ok(None) => {}
148 Err(Error::ConnectionClosed) => break,
149 Err(e) => {
150 let _ = msg_tx.send(Err(e)).await;
151 }
152 }
153 }
154 Some(Err(e)) => {
155 let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
156 break;
157 }
158 None => break,
159 }
160 }
161 cmd = cmd_rx.recv() => {
162 match cmd {
163 Some(Command::Subscribe(ids)) => {
164 last_token_ids = ids.clone();
165 let msg = serde_json::json!({
166 "action": "subscribe",
167 "markets": ids
168 });
169 let msg_text = serde_json::to_string(&msg).unwrap();
170 if write.send(Message::Text(msg_text.into())).await.is_err() {
171 break;
172 }
173 }
174 Some(Command::Unsubscribe) => {
175 last_token_ids.clear();
176 let msg = serde_json::json!({"action": "unsubscribe"});
177 let msg_text = serde_json::to_string(&msg).unwrap();
178 if write.send(Message::Text(msg_text.into())).await.is_err() {
179 break;
180 }
181 }
182 Some(Command::Close) | None => {
183 let _ = write.send(Message::Close(None)).await;
184 break 'outer;
185 }
186 }
187 }
188 }
189 }
190
191 if !should_reconnect(&options, reconnect_attempts) {
192 break;
193 }
194 let delay = backoff_delay(&options, reconnect_attempts);
195 reconnect_attempts += 1;
196 tokio::time::sleep(delay).await;
197 }
198}
199
200fn should_reconnect(options: &ObStreamOptions, attempts: u32) -> bool {
201 if !options.auto_reconnect {
202 return false;
203 }
204 match options.max_reconnect_attempts {
205 Some(max) => attempts < max,
206 None => true,
207 }
208}
209
210fn backoff_delay(options: &ObStreamOptions, attempts: u32) -> Duration {
211 let base = options.initial_backoff.as_millis() as u64;
212 let max = options.max_backoff.as_millis() as u64;
213 let delay = std::cmp::min(base * 2u64.pow(attempts), max);
214 let jitter = delay / 2 + (rand_simple() % (delay / 2 + 1));
215 Duration::from_millis(jitter)
216}
217
/// Cheap, non-cryptographic pseudo-random value used only for backoff jitter.
///
/// Returns the sub-second nanosecond part of the current wall-clock time, so
/// the result is always below 1_000_000_000. A clock set before the Unix
/// epoch yields 0.
fn rand_simple() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::from(since_epoch.subsec_nanos())
}
225
226fn parse_ob_message(text: &str) -> Vec<ObMessage> {
229 let raw: RawObMessage = match serde_json::from_str(text) {
230 Ok(r) => r,
231 Err(_) => return vec![],
232 };
233
234 if let Some(error) = raw.error {
236 return vec![ObMessage::Error {
237 error,
238 message: raw.message.unwrap_or_default(),
239 }];
240 }
241
242 let msg_type = match raw.msg_type {
243 Some(ref t) => t.as_str(),
244 None => return vec![],
245 };
246
247 match msg_type {
248 "subscribed" => vec![ObMessage::Subscribed {
249 markets: raw.markets.unwrap_or(0),
250 }],
251 "unsubscribed" => vec![ObMessage::Unsubscribed],
252 "snapshots_done" => vec![ObMessage::SnapshotsDone {
253 total: raw.total.unwrap_or(0),
254 }],
255 "snapshot_batch" => {
256 let mut out = Vec::new();
257 if let Some(snapshots) = raw.snapshots {
258 for val in snapshots {
259 if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
260 out.push(ObMessage::Update(update));
261 }
262 }
263 }
264 out
265 }
266 "batch" => {
267 let mut out = Vec::new();
268 if let Some(updates) = raw.updates {
269 for val in updates {
270 if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
271 out.push(ObMessage::Update(update));
272 }
273 }
274 }
275 out
276 }
277 "pong" => vec![],
278 _ => vec![],
279 }
280}