use futures_util::{SinkExt, StreamExt};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use super::codec::decode_frame;
use super::subscription::Subscription;
use crate::error::{Error, Result};
use crate::types::events::{PolyNodeEvent, PriceFeedEvent};
use crate::types::ws_messages::{RawWsMessage, WsMessage};
/// Connection and reconnect settings for a [`WsStream`].
#[derive(Debug, Clone)]
pub struct StreamOptions {
    /// Append `compress=zlib` to the connection URL (presumably asks the
    /// server for zlib-compressed frames).
    pub compress: bool,
    /// Reconnect automatically after the connection fails or drops.
    pub auto_reconnect: bool,
    /// Maximum consecutive reconnect attempts; `None` means no limit.
    pub max_reconnect_attempts: Option<u32>,
    /// Delay before the first retry; doubled on every further attempt.
    pub initial_backoff: Duration,
    /// Upper bound on the backoff delay before jitter is applied.
    pub max_backoff: Duration,
}

impl Default for StreamOptions {
    /// Compression enabled, unlimited automatic reconnects, and a backoff
    /// that starts at 1 s and is capped at 30 s.
    fn default() -> Self {
        let initial_backoff = Duration::from_secs(1);
        let max_backoff = Duration::from_secs(30);
        Self {
            compress: true,
            auto_reconnect: true,
            max_reconnect_attempts: None,
            initial_backoff,
            max_backoff,
        }
    }
}
/// Requests passed from a [`WsStream`] handle to its background task.
enum Command {
    /// Send a subscribe request and remember it so it is replayed after a
    /// reconnect.
    Subscribe(Subscription),
    /// Send an unsubscribe request. With `None` the request carries no
    /// subscription id and every remembered subscription is forgotten.
    Unsubscribe(Option<String>),
    /// Send a close frame and stop the background task.
    Close,
}
/// Handle to a WebSocket connection that is driven by a spawned task.
///
/// Messages are read with [`WsStream::next`]; subscriptions are managed with
/// [`WsStream::subscribe`] and [`WsStream::unsubscribe`]. Connecting and
/// reconnecting happen inside the background task.
pub struct WsStream {
    /// Messages and errors forwarded by the background task.
    rx: mpsc::Receiver<Result<WsMessage>>,
    /// Commands delivered to the background task.
    cmd_tx: mpsc::Sender<Command>,
    /// Handle of the spawned background task.
    _handle: tokio::task::JoinHandle<()>,
}

impl WsStream {
    /// Spawns the background connection task and returns a handle to it.
    ///
    /// The connection is opened asynchronously by the spawned task, so this
    /// function does not report connection failures; those arrive as `Err`
    /// items from [`WsStream::next`].
    ///
    /// `api_key` is sent as the percent-encoded `key` query parameter and,
    /// when `options.compress` is set, `compress=zlib` is appended too. Both
    /// are joined with `&` if `ws_url` already carries a query string.
    pub(crate) async fn connect(
        api_key: &str,
        ws_url: &str,
        options: StreamOptions,
    ) -> Result<Self> {
        let mut url = String::with_capacity(ws_url.len() + api_key.len() + 24);
        url.push_str(ws_url);
        append_query_param(&mut url, "key", api_key);
        if options.compress {
            append_query_param(&mut url, "compress", "zlib");
        }
        let (msg_tx, msg_rx) = mpsc::channel(1024);
        let (cmd_tx, cmd_rx) = mpsc::channel(64);
        let handle = tokio::spawn(ws_task(url, options, msg_tx, cmd_rx));
        Ok(Self {
            rx: msg_rx,
            cmd_tx,
            _handle: handle,
        })
    }

    /// Waits for the next message or error from the connection.
    ///
    /// Returns `None` once the background task has exited (dropping its
    /// sender) and all buffered items have been consumed.
    pub async fn next(&mut self) -> Option<Result<WsMessage>> {
        self.rx.recv().await
    }

    /// Queues a subscription request for the background task.
    ///
    /// `Ok` only means the command was queued, not that the server accepted
    /// it.
    ///
    /// # Errors
    /// Returns [`Error::Disconnected`] if the background task has exited.
    pub async fn subscribe(&self, sub: Subscription) -> Result<()> {
        self.cmd_tx
            .send(Command::Subscribe(sub))
            .await
            .map_err(|_| Error::Disconnected)
    }

    /// Queues an unsubscribe request; `None` unsubscribes without an id and
    /// forgets every remembered subscription.
    ///
    /// # Errors
    /// Returns [`Error::Disconnected`] if the background task has exited.
    pub async fn unsubscribe(&self, subscription_id: Option<String>) -> Result<()> {
        self.cmd_tx
            .send(Command::Unsubscribe(subscription_id))
            .await
            .map_err(|_| Error::Disconnected)
    }

    /// Asks the background task to send a close frame and stop.
    ///
    /// Always returns `Ok`; if the task has already exited the request is
    /// silently dropped.
    pub async fn close(self) -> Result<()> {
        let _ = self.cmd_tx.send(Command::Close).await;
        Ok(())
    }
}

/// Appends `name=value` to `url`'s query string, percent-encoding `value`.
///
/// Starts the query with `?` when `url` has none yet, otherwise joins with
/// `&` (unless `url` already ends in `?` or `&`), so a base URL that already
/// carries parameters stays well-formed.
fn append_query_param(url: &mut String, name: &str, value: &str) {
    if !url.contains('?') {
        url.push('?');
    } else if !(url.ends_with('?') || url.ends_with('&')) {
        url.push('&');
    }
    url.push_str(name);
    url.push('=');
    push_percent_encoded(url, value);
}

/// Appends `value` to `out`, percent-encoding every byte outside the
/// RFC 3986 unreserved set (`A-Z a-z 0-9 - . _ ~`).
fn push_percent_encoded(out: &mut String, value: &str) {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
}
async fn ws_task(
url: String,
options: StreamOptions,
msg_tx: mpsc::Sender<Result<WsMessage>>,
mut cmd_rx: mpsc::Receiver<Command>,
) {
let mut active_subs: Vec<Subscription> = Vec::new();
let mut reconnect_attempts: u32 = 0;
'outer: loop {
let ws_stream = match tokio_tungstenite::connect_async(&url).await {
Ok((stream, _)) => {
reconnect_attempts = 0;
stream
}
Err(e) => {
let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
if !should_reconnect(&options, reconnect_attempts) {
break;
}
let delay = backoff_delay(&options, reconnect_attempts);
reconnect_attempts += 1;
tokio::time::sleep(delay).await;
continue;
}
};
let (mut write, mut read) = ws_stream.split();
for sub in &active_subs {
let msg_text = serde_json::to_string(&sub.to_message()).unwrap();
if write.send(Message::Text(msg_text.into())).await.is_err() {
continue 'outer;
}
}
loop {
tokio::select! {
frame = read.next() => {
match frame {
Some(Ok(msg)) => {
match decode_frame(msg) {
Ok(Some(text)) => {
if let Some(ws_msg) = parse_message(&text) {
if msg_tx.send(Ok(ws_msg)).await.is_err() {
break 'outer;
}
}
}
Ok(None) => {} Err(Error::ConnectionClosed) => break,
Err(e) => {
let _ = msg_tx.send(Err(e)).await;
}
}
}
Some(Err(e)) => {
let _ = msg_tx.send(Err(Error::WebSocket(e))).await;
break;
}
None => break, }
}
cmd = cmd_rx.recv() => {
match cmd {
Some(Command::Subscribe(sub)) => {
let msg_text = serde_json::to_string(&sub.to_message()).unwrap();
active_subs.push(sub);
if write.send(Message::Text(msg_text.into())).await.is_err() {
break;
}
}
Some(Command::Unsubscribe(id)) => {
let msg = if let Some(ref sid) = id {
serde_json::json!({"action": "unsubscribe", "subscription_id": sid})
} else {
active_subs.clear();
serde_json::json!({"action": "unsubscribe"})
};
let msg_text = serde_json::to_string(&msg).unwrap();
if write.send(Message::Text(msg_text.into())).await.is_err() {
break;
}
}
Some(Command::Close) | None => {
let _ = write.send(Message::Close(None)).await;
break 'outer;
}
}
}
}
}
if !should_reconnect(&options, reconnect_attempts) {
break;
}
let delay = backoff_delay(&options, reconnect_attempts);
reconnect_attempts += 1;
tokio::time::sleep(delay).await;
}
}
/// Reports whether another reconnect attempt is permitted.
///
/// Always `false` when `auto_reconnect` is disabled. Otherwise `true` unless
/// `max_reconnect_attempts` is set and `attempts` has already reached it.
fn should_reconnect(options: &StreamOptions, attempts: u32) -> bool {
    let budget_exhausted =
        matches!(options.max_reconnect_attempts, Some(limit) if attempts >= limit);
    options.auto_reconnect && !budget_exhausted
}
/// Computes the sleep before reconnect attempt number `attempts` (0-based).
///
/// The base delay is `initial_backoff * 2^attempts`, capped at `max_backoff`,
/// at millisecond resolution. The result is then jittered into
/// `[delay / 2, delay]` using [`rand_simple`], so that many clients do not
/// reconnect in lockstep.
///
/// The exponent and the multiplication saturate rather than overflow. With
/// unlimited reconnect attempts `attempts` can grow without bound, and a plain
/// `2u64.pow(attempts)` / `base * factor` would panic in debug builds or wrap
/// to a tiny delay in release builds.
fn backoff_delay(options: &StreamOptions, attempts: u32) -> Duration {
    // `as_millis` is u128; clamp instead of truncating with `as`.
    let base = u64::try_from(options.initial_backoff.as_millis()).unwrap_or(u64::MAX);
    let max = u64::try_from(options.max_backoff.as_millis()).unwrap_or(u64::MAX);
    let factor = 2u64.checked_pow(attempts).unwrap_or(u64::MAX);
    let delay = base.saturating_mul(factor).min(max);
    // `delay / 2 + 1` cannot overflow, and the sum stays below `u64::MAX`.
    let jitter = delay / 2 + (rand_simple() % (delay / 2 + 1));
    Duration::from_millis(jitter)
}
/// Cheap, non-cryptographic entropy source for backoff jitter.
///
/// Returns the sub-second nanosecond component of the current wall-clock
/// time (always below 1_000_000_000). Yields 0 if the clock reads earlier
/// than the Unix epoch.
fn rand_simple() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::from(since_epoch.subsec_nanos())
}
fn parse_message(text: &str) -> Option<WsMessage> {
let raw: RawWsMessage = serde_json::from_str(text).ok()?;
match raw.msg_type.as_str() {
"subscribed" => Some(WsMessage::Subscribed {
subscription_id: raw.subscription_id.unwrap_or_default(),
subscription_type: raw.subscription_type.unwrap_or_default(),
warnings: raw.warnings.unwrap_or_default(),
}),
"unsubscribed" => Some(WsMessage::Unsubscribed {
subscriber_id: raw.subscriber_id.unwrap_or_default(),
}),
"heartbeat" => Some(WsMessage::Heartbeat {
ts: raw.ts.unwrap_or(0),
}),
"pong" => None, "error" => Some(WsMessage::Error {
code: raw.code,
message: raw.message.unwrap_or_default(),
}),
"snapshot" => Some(WsMessage::Snapshot(raw.events.unwrap_or_default())),
"price_feed" => {
if let Some(data) = raw.data {
if let Ok(pf) = serde_json::from_value::<PriceFeedEvent>(data) {
return Some(WsMessage::PriceFeed(pf));
}
}
None
}
_ => {
if let Some(data) = raw.data {
if let Ok(event) = serde_json::from_value::<PolyNodeEvent>(data) {
return Some(WsMessage::Event(event));
}
}
None
}
}
}