Skip to main content

polynode/orderbook/
stream.rs

1use std::time::Duration;
2use futures_util::{SinkExt, StreamExt};
3use tokio::sync::mpsc;
4use tokio_tungstenite::tungstenite::Message;
5
6use crate::error::{Error, Result};
7use crate::types::orderbook::{ObMessage, OrderbookUpdate, RawObMessage};
8use crate::ws::codec::decode_frame;
9
/// Options for the orderbook stream.
///
/// Start from `ObStreamOptions::default()` and override individual fields
/// with struct-update syntax when only a few settings differ.
#[derive(Debug, Clone)]
pub struct ObStreamOptions {
    /// Request zlib-compressed frames (adds `&compress=zlib` to the connect URL).
    pub compress: bool,
    /// Reconnect automatically after a failed connect or a dropped connection.
    pub auto_reconnect: bool,
    /// Limit on reconnect attempts since the last successful connect;
    /// `None` means retry without limit.
    pub max_reconnect_attempts: Option<u32>,
    /// Starting reconnect delay; it doubles with every further attempt.
    pub initial_backoff: Duration,
    /// Ceiling for the exponentially growing reconnect delay.
    pub max_backoff: Duration,
}
19
20impl Default for ObStreamOptions {
21    fn default() -> Self {
22        Self {
23            compress: true,
24            auto_reconnect: true,
25            max_reconnect_attempts: None,
26            initial_backoff: Duration::from_secs(1),
27            max_backoff: Duration::from_secs(30),
28        }
29    }
30}
31
/// Control requests sent from an `ObStream` handle to its background task.
enum Command {
    /// Subscribe to these token IDs; the task also remembers them as the set
    /// to replay after a reconnect.
    Subscribe(Vec<String>),
    /// Unsubscribe from everything and forget the remembered token IDs.
    Unsubscribe,
    /// Send a WebSocket close frame and stop the task.
    Close,
}
37
38/// A real-time orderbook stream from ob.polynode.dev.
39pub struct ObStream {
40    rx: mpsc::Receiver<Result<ObMessage>>,
41    cmd_tx: mpsc::Sender<Command>,
42    _handle: tokio::task::JoinHandle<()>,
43}
44
45impl ObStream {
46    pub(crate) async fn connect(
47        api_key: &str,
48        ob_url: &str,
49        options: ObStreamOptions,
50    ) -> Result<Self> {
51        let mut url = format!("{}?key={}", ob_url, api_key);
52        if options.compress {
53            url.push_str("&compress=zlib");
54        }
55
56        let (msg_tx, msg_rx) = mpsc::channel(4096);
57        let (cmd_tx, cmd_rx) = mpsc::channel(64);
58
59        let handle = tokio::spawn(ob_task(url, options, msg_tx, cmd_rx));
60
61        Ok(Self {
62            rx: msg_rx,
63            cmd_tx,
64            _handle: handle,
65        })
66    }
67
68    /// Receive the next message. Returns None when the stream is closed.
69    pub async fn next(&mut self) -> Option<Result<ObMessage>> {
70        self.rx.recv().await
71    }
72
73    /// Subscribe to orderbook updates for the given token IDs.
74    pub async fn subscribe(&self, token_ids: Vec<String>) -> Result<()> {
75        self.cmd_tx.send(Command::Subscribe(token_ids)).await
76            .map_err(|_| Error::Disconnected)
77    }
78
79    /// Unsubscribe from all markets.
80    pub async fn unsubscribe(&self) -> Result<()> {
81        self.cmd_tx.send(Command::Unsubscribe).await
82            .map_err(|_| Error::Disconnected)
83    }
84
85    /// Close the connection.
86    pub async fn close(self) -> Result<()> {
87        let _ = self.cmd_tx.send(Command::Close).await;
88        Ok(())
89    }
90}
91
92async fn ob_task(
93    url: String,
94    options: ObStreamOptions,
95    msg_tx: mpsc::Sender<Result<ObMessage>>,
96    mut cmd_rx: mpsc::Receiver<Command>,
97) {
98    let mut last_token_ids: Vec<String> = Vec::new();
99    let mut reconnect_attempts: u32 = 0;
100
101    'outer: loop {
102        let ws_stream = match tokio_tungstenite::connect_async(&url).await {
103            Ok((stream, _)) => {
104                reconnect_attempts = 0;
105                stream
106            }
107            Err(e) => {
108                let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
109                if !should_reconnect(&options, reconnect_attempts) {
110                    break;
111                }
112                let delay = backoff_delay(&options, reconnect_attempts);
113                reconnect_attempts += 1;
114                tokio::time::sleep(delay).await;
115                continue;
116            }
117        };
118
119        let (mut write, mut read) = ws_stream.split();
120
121        // Re-subscribe after reconnect
122        if !last_token_ids.is_empty() {
123            let msg = serde_json::json!({
124                "action": "subscribe",
125                "markets": last_token_ids
126            });
127            let msg_text = serde_json::to_string(&msg).unwrap();
128            if write.send(Message::Text(msg_text.into())).await.is_err() {
129                continue 'outer;
130            }
131        }
132
133        loop {
134            tokio::select! {
135                frame = read.next() => {
136                    match frame {
137                        Some(Ok(msg)) => {
138                            match decode_frame(msg) {
139                                Ok(Some(text)) => {
140                                    let messages = parse_ob_message(&text);
141                                    for m in messages {
142                                        if msg_tx.send(Ok(m)).await.is_err() {
143                                            break 'outer;
144                                        }
145                                    }
146                                }
147                                Ok(None) => {}
148                                Err(Error::ConnectionClosed) => break,
149                                Err(e) => {
150                                    let _ = msg_tx.send(Err(e)).await;
151                                }
152                            }
153                        }
154                        Some(Err(e)) => {
155                            let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
156                            break;
157                        }
158                        None => break,
159                    }
160                }
161                cmd = cmd_rx.recv() => {
162                    match cmd {
163                        Some(Command::Subscribe(ids)) => {
164                            last_token_ids = ids.clone();
165                            let msg = serde_json::json!({
166                                "action": "subscribe",
167                                "markets": ids
168                            });
169                            let msg_text = serde_json::to_string(&msg).unwrap();
170                            if write.send(Message::Text(msg_text.into())).await.is_err() {
171                                break;
172                            }
173                        }
174                        Some(Command::Unsubscribe) => {
175                            last_token_ids.clear();
176                            let msg = serde_json::json!({"action": "unsubscribe"});
177                            let msg_text = serde_json::to_string(&msg).unwrap();
178                            if write.send(Message::Text(msg_text.into())).await.is_err() {
179                                break;
180                            }
181                        }
182                        Some(Command::Close) | None => {
183                            let _ = write.send(Message::Close(None)).await;
184                            break 'outer;
185                        }
186                    }
187                }
188            }
189        }
190
191        if !should_reconnect(&options, reconnect_attempts) {
192            break;
193        }
194        let delay = backoff_delay(&options, reconnect_attempts);
195        reconnect_attempts += 1;
196        tokio::time::sleep(delay).await;
197    }
198}
199
200fn should_reconnect(options: &ObStreamOptions, attempts: u32) -> bool {
201    if !options.auto_reconnect {
202        return false;
203    }
204    match options.max_reconnect_attempts {
205        Some(max) => attempts < max,
206        None => true,
207    }
208}
209
210fn backoff_delay(options: &ObStreamOptions, attempts: u32) -> Duration {
211    let base = options.initial_backoff.as_millis() as u64;
212    let max = options.max_backoff.as_millis() as u64;
213    let delay = std::cmp::min(base * 2u64.pow(attempts), max);
214    let jitter = delay / 2 + (rand_simple() % (delay / 2 + 1));
215    Duration::from_millis(jitter)
216}
217
/// Cheap, non-cryptographic pseudo-random value used only for backoff jitter.
///
/// Returns the sub-second nanosecond part of the current wall-clock time, so
/// it is always below 1_000_000_000. Yields 0 if the clock reads earlier than
/// the Unix epoch.
fn rand_simple() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::from(since_epoch.subsec_nanos())
}
225
226/// Parse a raw JSON message into zero or more ObMessages.
227/// Batches and snapshot_batches are flattened into individual updates.
228fn parse_ob_message(text: &str) -> Vec<ObMessage> {
229    let raw: RawObMessage = match serde_json::from_str(text) {
230        Ok(r) => r,
231        Err(_) => return vec![],
232    };
233
234    // Error messages have "error" field instead of "type"
235    if let Some(error) = raw.error {
236        return vec![ObMessage::Error {
237            error,
238            message: raw.message.unwrap_or_default(),
239        }];
240    }
241
242    let msg_type = match raw.msg_type {
243        Some(ref t) => t.as_str(),
244        None => return vec![],
245    };
246
247    match msg_type {
248        "subscribed" => vec![ObMessage::Subscribed {
249            markets: raw.markets.unwrap_or(0),
250        }],
251        "unsubscribed" => vec![ObMessage::Unsubscribed],
252        "snapshots_done" => vec![ObMessage::SnapshotsDone {
253            total: raw.total.unwrap_or(0),
254        }],
255        "snapshot_batch" => {
256            let mut out = Vec::new();
257            if let Some(snapshots) = raw.snapshots {
258                for val in snapshots {
259                    if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
260                        out.push(ObMessage::Update(update));
261                    }
262                }
263            }
264            out
265        }
266        "batch" => {
267            let mut out = Vec::new();
268            if let Some(updates) = raw.updates {
269                for val in updates {
270                    if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
271                        out.push(ObMessage::Update(update));
272                    }
273                }
274            }
275            out
276        }
277        "pong" => vec![],
278        _ => vec![],
279    }
280}