use std::{
time::{Duration, SystemTime},
vec,
};
use ahash::AHashSet;
use eyre::{Result, bail};
use futures_util::{SinkExt as _, StreamExt as _};
use jiff::Timestamp;
use reqwest::Url;
use tokio::net::TcpStream;
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream,
tungstenite::{self, Bytes},
};
use tracing::instrument;
use crate::{ConstructAuthError, RetryConfig, UrlError, retry::ExponentialBackoff};
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Exchange-specific protocol logic plugged into a [`WsConnection`].
///
/// `WsConnection` owns the socket (connecting, reconnecting, keep-alive pings, timeouts) and
/// delegates the message-level decisions to the methods below.
pub trait WsHandler: std::fmt::Debug {
    /// Connection settings, read by [`WsConnection::try_new`] when the connection is built.
    ///
    /// Defaults to [`WsConfig::default`].
    fn config(&self) -> Result<WsConfig, UrlError> {
        Ok(WsConfig::default())
    }

    /// Messages to send right after every successful handshake, including reconnects.
    ///
    /// Defaults to sending nothing.
    fn handle_auth(&mut self) -> Result<Vec<tungstenite::Message>, WsError> {
        Ok(vec![])
    }

    /// Build the messages that subscribe to `topics`.
    ///
    /// NOTE(review): `WsConnection` does not call this itself in this module; presumably the
    /// caller sends the returned messages — confirm.
    fn handle_subscribe(&mut self, topics: AHashSet<Topic>) -> Result<Vec<tungstenite::Message>, WsError>;

    /// Classify one JSON text frame received from the server.
    ///
    /// Return [`ResponseOrContent::Response`] with messages to send back (the connection then
    /// keeps reading), or [`ResponseOrContent::Content`] to hand an event to the caller of
    /// [`WsConnection::next`]. An `Err` aborts that `next` call with the error.
    fn handle_jrpc(&mut self, jrpc: serde_json::Value) -> Result<ResponseOrContent, WsError>;
}
/// The verdict of [`WsHandler::handle_jrpc`] on one incoming JSON message.
#[derive(Debug, Clone)]
pub enum ResponseOrContent {
    /// Not meant for the caller: send these messages back to the server and keep reading.
    Response(Vec<tungstenite::Message>),
    /// An event to return from [`WsConnection::next`].
    Content(ContentEvent),
}
/// An event delivered to the caller of [`WsConnection::next`].
///
/// Instances are produced by the [`WsHandler`] implementation, which defines what each field
/// holds beyond its type.
#[derive(Debug, Clone)]
pub struct ContentEvent {
    /// Event payload as JSON.
    pub data: serde_json::Value,
    /// Topic the event belongs to, as named by the handler.
    pub topic: String,
    /// Timestamp attached by the handler.
    /// NOTE(review): unclear whether this is exchange-side event time or local receipt time.
    pub time: Timestamp,
    /// Event kind label, as named by the handler.
    pub event_type: String,
}
/// Pairs an event name with a function that converts a JSON value into `T`.
///
/// Equality and hashing look only at `event_name`. `Clone`, `Debug` and `Eq` are written by
/// hand instead of derived: `T` only appears inside a function pointer, which is `Copy` and
/// `Debug` for every `T`, whereas `#[derive]` would add needless `T: Clone` / `T: Debug` /
/// `T: Eq` bounds and make e.g. `TopicInterpreter<NonCloneType>` un-cloneable.
pub struct TopicInterpreter<T> {
    /// Event name this interpreter is keyed by.
    pub event_name: String,
    /// Converts a JSON value into `T`, or reports why it could not.
    pub interpret: fn(&serde_json::Value) -> Result<T, WsError>,
}

impl<T> Clone for TopicInterpreter<T> {
    fn clone(&self) -> Self {
        Self {
            event_name: self.event_name.clone(),
            // Function pointers are `Copy`.
            interpret: self.interpret,
        }
    }
}

impl<T> std::fmt::Debug for TopicInterpreter<T> {
    // Same output shape as the previous `#[derive(Debug)]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TopicInterpreter")
            .field("event_name", &self.event_name)
            .field("interpret", &self.interpret)
            .finish()
    }
}

impl<T> Eq for TopicInterpreter<T> {}
/// Hashes `event_name` only, so the hash agrees with equality (which also compares only
/// `event_name`); the `interpret` function pointer does not participate.
impl<T> std::hash::Hash for TopicInterpreter<T> {
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        std::hash::Hash::hash(&self.event_name, hasher);
    }
}
/// Two interpreters are equal when their `event_name`s are equal; the `interpret` function
/// pointer is not compared.
impl<T> PartialEq for TopicInterpreter<T> {
    fn eq(&self, rhs: &Self) -> bool {
        let (lhs_name, rhs_name) = (&self.event_name, &rhs.event_name);
        lhs_name.eq(rhs_name)
    }
}
/// A WebSocket client that opens its socket lazily and re-establishes it as needed, yielding
/// [`ContentEvent`]s through [`WsConnection::next`].
///
/// Protocol details (auth, message classification) are delegated to the [`WsHandler`] `H`.
#[derive(Debug)]
pub struct WsConnection<H: WsHandler> {
    /// Fully resolved endpoint: the suffix given to `try_new`, joined onto `base_url` if set.
    url: Url,
    /// Settings obtained from `handler.config()` at construction.
    config: WsConfig,
    /// Exchange-specific protocol logic.
    handler: H,
    /// The open socket; `None` before the first connect and whenever the connection is dropped.
    stream: Option<WsConnectionStream>,
    /// Delay schedule consulted before each connection attempt.
    backoff: ExponentialBackoff,
}
impl<H: WsHandler> WsConnection<H> {
#[allow(missing_docs)]
pub fn try_new(url_suffix: &str, handler: H) -> Result<Self, WsError> {
let config = handler.config()?;
let url = match &config.base_url {
Some(base_url) => base_url.join(url_suffix).map_err(UrlError::Parse)?,
None => Url::parse(url_suffix).map_err(UrlError::Parse)?,
};
let backoff = ExponentialBackoff::try_from(&config.reconnect).map_err(|e| WsError::Other(eyre::eyre!("Invalid reconnect backoff configuration: {e}")))?;
Ok(Self {
url,
config,
handler,
stream: None,
backoff,
})
}
pub async fn next(&mut self) -> Result<ContentEvent, WsError> {
if let Some(inner) = &self.stream
&& inner.connected_since + self.config.refresh_after < SystemTime::now()
{
tracing::info!("Refreshing connection, as `refresh_after` specified in WsConfig has elapsed ({:?})", self.config.refresh_after);
self.reconnect().await?;
}
if self.stream.is_none() {
self.connect().await?;
}
let json_rpc_value = loop {
let resp = {
let timeout = match self.stream.as_ref() {
Some(stream) => match stream.last_unanswered_communication {
Some(last_unanswered) => {
let now = SystemTime::now();
match last_unanswered + self.config.response_timeout > now {
true => self.config.response_timeout,
false => {
tracing::error!(
"Timeout for last unanswered communication ended before `.next()` was called. This likely indicates an implementation error on the clientside."
);
self.reconnect().await?;
continue;
}
}
}
None => self.config.message_timeout,
},
None => {
tracing::error!(
"UNEXPECTED: Stream is None at ws.rs:172 despite guard at line 163. \
Possible causes: (1) system hibernation/sleep caused stale state, \
(2) memory corruption, (3) logic bug in reconnection flow, \
(4) async cancellation. \
Backoff current delay: {:?}. Attempting to reconnect...",
self.backoff.current_delay()
);
self.connect().await?;
continue;
}
};
let timeout_handle = tokio::time::timeout(timeout, {
let stream = self.stream.as_mut().unwrap();
stream.next()
});
match timeout_handle.await {
Ok(Some(resp)) => {
self.stream.as_mut().unwrap().last_unanswered_communication = None;
resp
}
Ok(None) => {
tracing::warn!("tungstenite couldn't read from the stream. Restarting.");
self.reconnect().await?;
continue;
}
Err(timeout_error) => {
tracing::warn!("Message reception timed out after {timeout:?} seconds. // {timeout_error}");
{
let stream = self.stream.as_mut().unwrap();
match stream.last_unanswered_communication.is_some() {
true => self.reconnect().await?,
false => {
self.send(tungstenite::Message::Ping(Bytes::default())).await?;
continue;
}
}
}
continue;
}
}
};
match resp {
Ok(succ_resp) => match succ_resp {
tungstenite::Message::Text(text) => {
let value: serde_json::Value =
serde_json::from_str(&text).expect("API sent invalid JSON, which is completely unexpected. Disappointment is immeasurable and the day is ruined.");
tracing::trace!("{value:#?}"); break match { self.handler.handle_jrpc(value)? } {
ResponseOrContent::Response(messages) => {
self.send_all(messages).await?;
continue; }
ResponseOrContent::Content(content) => content,
};
}
tungstenite::Message::Binary(_) => {
panic!("Received binary. But exchanges are not smart enough to send this, what is happening");
}
tungstenite::Message::Ping(bytes) => {
self.send(tungstenite::Message::Pong(bytes)).await?; tracing::debug!("ponged");
continue;
}
tungstenite::Message::Pong(_) => {
tracing::info!("Received pong");
continue;
}
tungstenite::Message::Close(maybe_reason) => {
match maybe_reason {
Some(close_frame) => {
tracing::info!("Server closed connection; reason: {close_frame:?}");
}
None => {
tracing::info!("Server closed connection; no reason specified.");
}
}
self.stream = None;
self.reconnect().await?;
continue;
}
tungstenite::Message::Frame(_) => {
unreachable!("Can't get from reading");
}
},
Err(err) => match err {
tungstenite::Error::ConnectionClosed => {
tracing::error!("received `tungstenite::Error::ConnectionClosed` on polling. Will reconnect");
self.stream = None;
continue;
}
tungstenite::Error::AlreadyClosed => {
tracing::error!("received `tungstenite::Error::AlreadyClosed` from polling. Will reconnect");
self.stream = None;
continue;
}
tungstenite::Error::Io(e) => {
tracing::error!("received `tungstenite::Error::Io` from polling: {e:?}. Atm don't know valid cases of this happening given intact application state...");
self.stream = None;
continue;
}
tungstenite::Error::Tls(_tls_error) => todo!(),
tungstenite::Error::Capacity(capacity_error) => {
tracing::warn!("received `tungstenite::Error::Capacity` from polling: {capacity_error:?}. Skipping.");
continue;
}
tungstenite::Error::Protocol(protocol_error) => {
tracing::warn!("received `tungstenite::Error::Protocol` from polling: {protocol_error:?}. Will reconnect");
self.stream = None;
continue;
}
tungstenite::Error::WriteBufferFull(_) => unreachable!("can only get from writing"),
tungstenite::Error::Utf8(e) => panic!("received `tungstenite::Error::Utf8` from polling: {e:?}. Exchange is going crazy, aborting"),
tungstenite::Error::AttackAttempt => {
tracing::warn!("received `tungstenite::Error::AttackAttempt` from polling. Don't have a reason to trust detection 100%, so just reconnecting.");
self.stream = None;
continue;
}
tungstenite::Error::Url(_url_error) => todo!(),
tungstenite::Error::Http(_response) => todo!(),
tungstenite::Error::HttpFormat(_error) => todo!(),
},
}
};
Ok(json_rpc_value)
}
#[instrument(skip_all)]
async fn send_all(&mut self, messages: Vec<tungstenite::Message>) -> Result<(), tungstenite::Error> {
if let Some(inner) = &mut self.stream {
match messages.len() {
0 => return Ok(()),
1 => {
tracing::debug!("sending to server: {:#?}", &messages[0]);
inner.send(messages.into_iter().next().unwrap()).await?;
inner.last_unanswered_communication = Some(SystemTime::now());
}
_ => {
tracing::debug!("sending to server: {messages:#?}");
let mut message_stream = futures_util::stream::iter(messages).map(Ok);
inner.send_all(&mut message_stream).await?;
inner.last_unanswered_communication = Some(SystemTime::now());
}
};
Ok(())
} else {
Err(tungstenite::Error::ConnectionClosed)
}
}
async fn send(&mut self, message: tungstenite::Message) -> Result<(), tungstenite::Error> {
self.send_all(vec![message]).await }
async fn connect(&mut self) -> Result<(), WsError> {
tracing::info!("Connecting to {}...", self.url);
let delay = self.backoff.next_duration();
if !delay.is_zero() {
tracing::warn!(delay_ms = delay.as_millis(), "Reconnect backoff active. Likely indicative of a bad connection.");
tokio::time::sleep(delay).await;
}
let (stream, http_resp) = tokio_tungstenite::connect_async(self.url.as_str()).await?;
tracing::debug!("Ws handshake with server: {http_resp:#?}");
let now = SystemTime::now();
self.stream = Some(WsConnectionStream::new(stream, now));
let auth_messages = self.handler.handle_auth()?;
self.send_all(auth_messages).await?;
self.backoff.reset();
Ok(())
}
pub async fn reconnect(&mut self) -> Result<(), WsError> {
if let Some(stream) = self.stream.as_mut() {
tracing::info!("Dropping old connection before reconnecting...");
if let Err(e) = stream.send(tungstenite::Message::Close(None)).await {
tracing::debug!("Failed to send Close frame (connection likely already dead): {e}");
}
self.stream = None;
}
self.connect().await
}
}
/// Connection settings, obtained from [`WsHandler::config`].
///
/// The three timing fields are private. Outside this module they can only be changed through
/// the setters, which reject zero durations.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// NOTE(review): not read by the connection logic in this module; presumably consumed by handlers.
    pub auth: bool,
    /// When set, the suffix passed to `WsConnection::try_new` is joined onto this URL;
    /// otherwise that suffix must be a complete URL.
    pub base_url: Option<Url>,
    /// Backoff policy applied before each (re)connection attempt.
    pub reconnect: RetryConfig,
    /// Maximum connection age; once exceeded, the next `WsConnection::next` call reconnects.
    refresh_after: Duration,
    /// How long to wait for a message while nothing is awaiting an answer; on expiry a ping is sent.
    message_timeout: Duration,
    /// How long to wait for any message after sending something; on expiry the connection is re-established.
    response_timeout: Duration,
    /// NOTE(review): not read in this module; presumably the topics a handler subscribes to.
    pub topics: AHashSet<String>,
}
impl WsConfig {
    /// Replace the reconnect backoff policy.
    pub fn set_reconnect(&mut self, reconnect: RetryConfig) {
        self.reconnect = reconnect;
    }

    /// Set the maximum connection age after which `WsConnection::next` reconnects.
    ///
    /// # Errors
    /// If `refresh_after` is zero.
    pub fn set_refresh_after(&mut self, refresh_after: Duration) -> Result<()> {
        if refresh_after.is_zero() {
            bail!("refresh_after must be greater than 0");
        }
        self.refresh_after = refresh_after;
        Ok(())
    }

    /// Set how long to wait for a message before probing the connection with a ping.
    ///
    /// # Errors
    /// If `message_timeout` is zero.
    pub fn set_message_timeout(&mut self, message_timeout: Duration) -> Result<()> {
        if message_timeout.is_zero() {
            bail!("message_timeout must be greater than 0");
        }
        self.message_timeout = message_timeout;
        Ok(())
    }

    /// Set how long to wait for any message after sending something before the connection
    /// is re-established.
    ///
    /// # Errors
    /// If `response_timeout` is zero.
    pub fn set_response_timeout(&mut self, response_timeout: Duration) -> Result<()> {
        if response_timeout.is_zero() {
            bail!("response_timeout must be greater than 0");
        }
        self.response_timeout = response_timeout;
        Ok(())
    }

    /// Misspelled alias of [`Self::set_response_timeout`], kept so existing callers still compile.
    ///
    /// # Errors
    /// If `response_timeout` is zero.
    pub fn set_response_timout(&mut self, response_timeout: Duration) -> Result<()> {
        self.set_response_timeout(response_timeout)
    }
}
/// Errors produced by the WebSocket layer.
///
/// `derive_more::From` provides conversions from the variant payload types (e.g.
/// `tungstenite::Error`, `UrlError`), which is what lets `?` be used on those errors in this module.
#[derive(Debug, miette::Diagnostic, derive_more::Display, thiserror::Error, derive_more::From)]
pub enum WsError {
/// The connection definition itself is invalid (see [`WsDefinitionError`]).
#[diagnostic(transparent)]
Definition(WsDefinitionError),
/// Error reported by tungstenite while handshaking, sending, or receiving.
#[diagnostic(code(v_exchanges::ws::tungstenite), help("WebSocket protocol error. The connection may need to be reestablished."))]
Tungstenite(tungstenite::Error),
/// Failure constructing authentication. NOTE(review): produced outside this module — confirm.
#[diagnostic(transparent)]
Auth(ConstructAuthError),
/// A payload could not be parsed as JSON.
#[diagnostic(code(v_exchanges::ws::parse), help("Failed to parse WebSocket message. Check if the exchange API has changed."))]
Parse(serde_json::Error),
/// Subscription-related failure, described by the contained message.
#[diagnostic(code(v_exchanges::ws::subscription))]
Subscription(String),
/// The network connection failed.
#[diagnostic(code(v_exchanges::ws::network), help("Network connection failed. Check your internet connection."))]
NetworkConnection,
/// The endpoint URL could not be built or parsed.
#[diagnostic(transparent)]
Url(UrlError),
/// The server sent an event that was not expected; carries the raw JSON.
#[diagnostic(code(v_exchanges::ws::unexpected_event), help("Received an unexpected event from the WebSocket. This may indicate an API change."))]
UnexpectedEvent(serde_json::Value),
/// Any other failure, wrapped as an `eyre::Report`.
#[error(transparent)]
Other(eyre::Report),
}
/// Problems with how a WebSocket connection is configured.
#[derive(Debug, miette::Diagnostic, derive_more::Display, thiserror::Error)]
pub enum WsDefinitionError {
/// No WebSocket base URL is configured in `WsConfig`.
#[diagnostic(code(v_exchanges::ws::definition::missing_url), help("WebSocket base URL must be configured in WsConfig."))]
MissingUrl,
}
/// A subscription topic, as handed to [`WsHandler::handle_subscribe`].
///
/// NOTE(review): how each variant is turned into a subscription request is defined by the
/// handler — confirm against implementations.
#[derive(Clone, Debug, derive_more::Display, Eq, Hash, PartialEq, serde::Serialize)]
pub enum Topic {
/// A topic identified by a string.
String(String),
/// A topic described by an arbitrary JSON value.
Order(serde_json::Value),
}
/// An open WebSocket plus the timestamps `WsConnection` uses for refresh and timeout decisions.
///
/// Derefs (mutably) to the wrapped [`WsStream`], so `StreamExt`/`SinkExt` methods can be called
/// on it directly. The `derive_new` constructor is `new(stream, connected_since)`;
/// `last_unanswered_communication` starts out as `None`.
#[derive(Debug, derive_more::Deref, derive_more::DerefMut, derive_new::new)]
struct WsConnectionStream {
/// The underlying socket.
#[deref_mut]
#[deref]
stream: WsStream,
/// When this connection was established; compared against `WsConfig::refresh_after`.
connected_since: SystemTime,
/// When something was last sent that has not yet been followed by any incoming message;
/// reset to `None` whenever a message is received.
#[new(default)]
last_unanswered_communication: Option<SystemTime>,
}
impl Default for WsConfig {
    /// No auth, no base URL, no topics. Reconnect with exponential backoff from 1 s up to 30 s
    /// (factor 2, `jitter_ms` 500, `max_retries` `u32::MAX`, no elapsed-time cap). Refresh the
    /// connection every 12 h; message timeout 16 min; response timeout 2 min.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        const SECS_PER_HOUR: u64 = 60 * SECS_PER_MINUTE;

        let reconnect = RetryConfig {
            max_retries: u32::MAX,
            initial_delay_ms: 1_000,
            max_delay_ms: 30_000,
            backoff_factor: 2.0,
            jitter_ms: 500,
            immediate_first: false,
            max_elapsed_ms: None,
        };

        Self {
            auth: false,
            base_url: None,
            reconnect,
            refresh_after: Duration::from_secs(12 * SECS_PER_HOUR),
            message_timeout: Duration::from_secs(16 * SECS_PER_MINUTE),
            response_timeout: Duration::from_secs(2 * SECS_PER_MINUTE),
            topics: AHashSet::new(),
        }
    }
}