use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use crate::error::{Error, Result};
use crate::types::orderbook::{ObMessage, OrderbookUpdate, RawObMessage};
use crate::ws::codec::decode_frame;
/// Connection settings for an orderbook [`ObStream`].
#[derive(Debug, Clone)]
pub struct ObStreamOptions {
    /// Request zlib compression by adding `compress=zlib` to the connection URL.
    pub compress: bool,
    /// Reconnect automatically after the connection fails or drops.
    pub auto_reconnect: bool,
    /// Limit on consecutive reconnect attempts; `None` means no limit.
    pub max_reconnect_attempts: Option<u32>,
    /// Nominal delay before the first reconnect attempt; doubled for each further attempt.
    pub initial_backoff: Duration,
    /// Cap on the nominal (pre-jitter) reconnect delay.
    pub max_backoff: Duration,
}

impl Default for ObStreamOptions {
    /// Compression on, unlimited automatic reconnects, and a reconnect backoff that starts
    /// at one second and is capped at thirty seconds.
    fn default() -> Self {
        let (initial_backoff, max_backoff) =
            (Duration::from_millis(1_000), Duration::from_millis(30_000));
        ObStreamOptions {
            max_reconnect_attempts: None,
            auto_reconnect: true,
            compress: true,
            initial_backoff,
            max_backoff,
        }
    }
}
/// Control messages sent from an [`ObStream`] handle to its background task.
enum Command {
    /// Send a WebSocket close frame and stop the background task.
    Close,
    /// Subscribe to these token ids; the task keeps the latest list for reconnects.
    Subscribe(Vec<String>),
    /// Forget the remembered subscription and send an unsubscribe request.
    Unsubscribe,
}
/// Client handle for the orderbook WebSocket stream.
///
/// Messages are produced by a background task spawned in [`ObStream::connect`] and read
/// with [`ObStream::next`]; subscription changes are forwarded to that task as commands.
// NOTE: fields are dropped in declaration order, so keep this order unless the
// shutdown sequence is deliberately being changed.
pub struct ObStream {
    // Receives decoded messages and errors from the background task.
    rx: mpsc::Receiver<Result<ObMessage>>,
    // Sends subscribe/unsubscribe/close commands to the background task.
    cmd_tx: mpsc::Sender<Command>,
    // Handle of the spawned task. It is held but never awaited; dropping a Tokio
    // JoinHandle detaches the task rather than cancelling it.
    _handle: tokio::task::JoinHandle<()>,
}
impl ObStream {
pub(crate) async fn connect(
api_key: &str,
ob_url: &str,
options: ObStreamOptions,
) -> Result<Self> {
let mut url = format!("{}?key={}", ob_url, api_key);
if options.compress {
url.push_str("&compress=zlib");
}
let (msg_tx, msg_rx) = mpsc::channel(4096);
let (cmd_tx, cmd_rx) = mpsc::channel(64);
let handle = tokio::spawn(ob_task(url, options, msg_tx, cmd_rx));
Ok(Self {
rx: msg_rx,
cmd_tx,
_handle: handle,
})
}
pub async fn next(&mut self) -> Option<Result<ObMessage>> {
self.rx.recv().await
}
pub async fn subscribe(&self, token_ids: Vec<String>) -> Result<()> {
self.cmd_tx
.send(Command::Subscribe(token_ids))
.await
.map_err(|_| Error::Disconnected)
}
pub async fn unsubscribe(&self) -> Result<()> {
self.cmd_tx
.send(Command::Unsubscribe)
.await
.map_err(|_| Error::Disconnected)
}
pub async fn close(self) -> Result<()> {
let _ = self.cmd_tx.send(Command::Close).await;
Ok(())
}
}
async fn ob_task(
url: String,
options: ObStreamOptions,
msg_tx: mpsc::Sender<Result<ObMessage>>,
mut cmd_rx: mpsc::Receiver<Command>,
) {
let mut last_token_ids: Vec<String> = Vec::new();
let mut reconnect_attempts: u32 = 0;
'outer: loop {
let ws_stream = match tokio_tungstenite::connect_async(&url).await {
Ok((stream, _)) => {
reconnect_attempts = 0;
stream
}
Err(e) => {
let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
if !should_reconnect(&options, reconnect_attempts) {
break;
}
let delay = backoff_delay(&options, reconnect_attempts);
reconnect_attempts += 1;
tokio::time::sleep(delay).await;
continue;
}
};
let (mut write, mut read) = ws_stream.split();
if !last_token_ids.is_empty() {
let msg = serde_json::json!({
"action": "subscribe",
"markets": last_token_ids
});
let msg_text = serde_json::to_string(&msg).unwrap();
if write.send(Message::Text(msg_text.into())).await.is_err() {
continue 'outer;
}
}
loop {
tokio::select! {
frame = read.next() => {
match frame {
Some(Ok(msg)) => {
match decode_frame(msg) {
Ok(Some(text)) => {
let messages = parse_ob_message(&text);
for m in messages {
if msg_tx.send(Ok(m)).await.is_err() {
break 'outer;
}
}
}
Ok(None) => {}
Err(Error::ConnectionClosed) => break,
Err(e) => {
let _ = msg_tx.send(Err(e)).await;
}
}
}
Some(Err(e)) => {
let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
break;
}
None => break,
}
}
cmd = cmd_rx.recv() => {
match cmd {
Some(Command::Subscribe(ids)) => {
last_token_ids = ids.clone();
let msg = serde_json::json!({
"action": "subscribe",
"markets": ids
});
let msg_text = serde_json::to_string(&msg).unwrap();
if write.send(Message::Text(msg_text.into())).await.is_err() {
break;
}
}
Some(Command::Unsubscribe) => {
last_token_ids.clear();
let msg = serde_json::json!({"action": "unsubscribe"});
let msg_text = serde_json::to_string(&msg).unwrap();
if write.send(Message::Text(msg_text.into())).await.is_err() {
break;
}
}
Some(Command::Close) | None => {
let _ = write.send(Message::Close(None)).await;
break 'outer;
}
}
}
}
}
if !should_reconnect(&options, reconnect_attempts) {
break;
}
let delay = backoff_delay(&options, reconnect_attempts);
reconnect_attempts += 1;
tokio::time::sleep(delay).await;
}
}
/// Decides whether another reconnect may be attempted after `attempts` previous tries.
///
/// Always `false` when `auto_reconnect` is off; otherwise `true` unless
/// `max_reconnect_attempts` is set and `attempts` has already reached it.
fn should_reconnect(options: &ObStreamOptions, attempts: u32) -> bool {
    options.auto_reconnect
        && !matches!(options.max_reconnect_attempts, Some(limit) if attempts >= limit)
}
/// Computes how long to wait before reconnect attempt number `attempts` (0-based).
///
/// The nominal delay is `initial_backoff * 2^attempts` milliseconds, capped at
/// `max_backoff`; the doubling saturates instead of overflowing, so any large `attempts`
/// simply yields the cap. The returned delay is jittered into roughly `[delay / 2, delay]`
/// using [`rand_simple`] so that many clients do not reconnect in lockstep.
fn backoff_delay(options: &ObStreamOptions, attempts: u32) -> Duration {
    // Clamp rather than truncate durations whose millisecond count exceeds u64.
    let to_millis = |d: Duration| d.as_millis().min(u128::from(u64::MAX)) as u64;
    let base = to_millis(options.initial_backoff);
    let max = to_millis(options.max_backoff);
    // Saturating arithmetic: once 2^attempts (or base * 2^attempts) no longer fits in a
    // u64 the product pins at u64::MAX and is then clamped to `max`. Plain `pow`/`*`
    // would panic in debug builds and wrap (possibly to a zero delay, i.e. a hot
    // reconnect loop) in release builds; reconnects are unlimited by default, so large
    // attempt counts are reachable.
    let delay = base.saturating_mul(2u64.saturating_pow(attempts)).min(max);
    let half = delay / 2;
    Duration::from_millis(half + rand_simple() % (half + 1))
}
/// Cheap, non-cryptographic jitter source used by [`backoff_delay`].
///
/// Returns the sub-second nanosecond part of the current wall-clock time (always below
/// 1_000_000_000), or `0` if the system clock reads earlier than the Unix epoch.
fn rand_simple() -> u64 {
    let elapsed = std::time::SystemTime::UNIX_EPOCH
        .elapsed()
        .unwrap_or_default();
    u64::from(elapsed.subsec_nanos())
}
fn parse_ob_message(text: &str) -> Vec<ObMessage> {
let raw: RawObMessage = match serde_json::from_str(text) {
Ok(r) => r,
Err(_) => return vec![],
};
if let Some(error) = raw.error {
return vec![ObMessage::Error {
error,
message: raw.message.unwrap_or_default(),
}];
}
let msg_type = match raw.msg_type {
Some(ref t) => t.as_str(),
None => return vec![],
};
match msg_type {
"subscribed" => vec![ObMessage::Subscribed {
markets: raw.markets.unwrap_or(0),
}],
"unsubscribed" => vec![ObMessage::Unsubscribed],
"snapshots_done" => vec![ObMessage::SnapshotsDone {
total: raw.total.unwrap_or(0),
}],
"snapshot_batch" => {
let mut out = Vec::new();
if let Some(snapshots) = raw.snapshots {
for val in snapshots {
if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
out.push(ObMessage::Update(update));
}
}
}
out
}
"batch" => {
let mut out = Vec::new();
if let Some(updates) = raw.updates {
for val in updates {
if let Ok(update) = serde_json::from_value::<OrderbookUpdate>(val) {
out.push(ObMessage::Update(update));
}
}
}
out
}
"pong" => vec![],
_ => vec![],
}
}