use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicU8, Ordering},
Arc,
};
use std::time::Duration;
use async_trait::async_trait;
use futures_util::{SinkExt, Stream, StreamExt};
use serde_json::Value;
use tokio::sync::{broadcast, mpsc, RwLock as TokioRwLock};
use tokio::time::timeout;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{debug, trace, warn};
use crate::core::traits::Credentials;
use crate::core::types::{
AccountType, ConnectionStatus, StreamEvent, SubscriptionRequest, WebSocketError,
WebSocketResult,
};
use super::{
capability_provider::CapabilityProvider,
protocol::WsProtocol,
reconnect::{BackoffState, ReconnectConfig},
stream_kind::StreamKind,
stream_spec::StreamSpec,
support_level::SupportLevel,
};
/// Lifecycle state of the driver task, stored as its `u8` discriminant in an `AtomicU8` so
/// transport handles can read it without locking.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum TransportState {
    /// Not connected: the initial value, and the value stored once the driver stops.
    Disconnected = 0,
    /// Connection attempt while the backoff attempt counter is zero.
    Connecting = 1,
    /// Socket open, authentication (if any) done and subscriptions replayed.
    Connected = 2,
    /// Connection attempt after at least one failure.
    Reconnecting = 3,
}
impl TransportState {
    /// Decodes a discriminant previously produced by `state as u8`.
    ///
    /// Values outside `0..=3` (which the transport never stores) decode as
    /// [`TransportState::Disconnected`] instead of panicking.
    fn from_u8(raw: u8) -> Self {
        // Indexed by discriminant: position i holds the variant whose `as u8` value is i.
        const BY_DISCRIMINANT: [TransportState; 4] = [
            TransportState::Disconnected,
            TransportState::Connecting,
            TransportState::Connected,
            TransportState::Reconnecting,
        ];
        BY_DISCRIMINANT
            .get(usize::from(raw))
            .copied()
            .unwrap_or(TransportState::Disconnected)
    }
}
/// Requests sent from [`UniversalWsTransport`] handles to the background driver task over an
/// unbounded channel.
pub(super) enum TransportCmd {
    /// Add a stream to the active subscription set (and subscribe on the live socket, if any).
    Subscribe(StreamSpec),
    /// Remove a stream from the active subscription set (and unsubscribe on the live socket,
    /// if any).
    Unsubscribe(StreamSpec),
    /// Close the socket and stop the driver task.
    Shutdown,
}
/// Exchange-agnostic WebSocket transport whose framing and parsing come from a [`WsProtocol`].
///
/// Socket I/O happens in a background driver task spawned by
/// [`UniversalWsTransport::with_reconnect`]; this handle talks to it through [`TransportCmd`]s and
/// observes it through shared state. Clones share the same driver, connection state,
/// subscription set and event channel.
pub struct UniversalWsTransport<P: WsProtocol> {
    // Exchange-specific protocol logic, shared with the driver task.
    protocol: Arc<P>,
    // Account type this transport was constructed for.
    account_type: AccountType,
    // Testnet flag handed to the protocol when resolving the endpoint.
    testnet: bool,
    // Credentials used to authenticate the socket, if any.
    credentials: Option<Credentials>,
    // Reconnect/backoff settings; `connection_timeout_ms` also bounds the `connect` wait.
    reconnect_cfg: ReconnectConfig,
    // Current `TransportState` as its `u8` discriminant; written by the driver task.
    state: Arc<AtomicU8>,
    // Subscriptions recorded by the driver and replayed after every (re)connect.
    active_subs: Arc<TokioRwLock<HashSet<StreamSpec>>>,
    // Broadcast sender for parsed events and surfaced errors.
    event_tx: broadcast::Sender<WebSocketResult<StreamEvent>>,
    // Command channel into the driver task.
    cmd_tx: mpsc::UnboundedSender<TransportCmd>,
}
/// Produces another handle onto the same transport: the protocol, state, subscription set and
/// both channels are shared (reference-counted or channel-cloned), not duplicated.
///
/// Written by hand because `#[derive(Clone)]` would require `P: Clone`, which the `Arc<P>` makes
/// unnecessary.
impl<P: WsProtocol> Clone for UniversalWsTransport<P> {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field without updating this impl fails to compile.
        let Self {
            protocol,
            account_type,
            testnet,
            credentials,
            reconnect_cfg,
            state,
            active_subs,
            event_tx,
            cmd_tx,
        } = self;
        Self {
            protocol: Arc::clone(protocol),
            account_type: *account_type,
            testnet: *testnet,
            credentials: credentials.clone(),
            reconnect_cfg: reconnect_cfg.clone(),
            state: Arc::clone(state),
            active_subs: Arc::clone(active_subs),
            event_tx: event_tx.clone(),
            cmd_tx: cmd_tx.clone(),
        }
    }
}
impl<P: WsProtocol> UniversalWsTransport<P> {
pub fn new(
protocol: P,
account_type: AccountType,
testnet: bool,
credentials: Option<Credentials>,
) -> Self {
Self::with_reconnect(protocol, account_type, testnet, credentials, ReconnectConfig::default())
}
pub fn with_reconnect(
protocol: P,
account_type: AccountType,
testnet: bool,
credentials: Option<Credentials>,
reconnect_cfg: ReconnectConfig,
) -> Self {
let (event_tx, _) = broadcast::channel(4096);
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let state = Arc::new(AtomicU8::new(TransportState::Disconnected as u8));
let active_subs = Arc::new(TokioRwLock::new(HashSet::new()));
let transport = Self {
protocol: Arc::new(protocol),
account_type,
testnet,
credentials,
reconnect_cfg,
state: Arc::clone(&state),
active_subs: Arc::clone(&active_subs),
event_tx,
cmd_tx,
};
let driver = DriverTask {
protocol: Arc::clone(&transport.protocol),
account_type,
testnet,
credentials: transport.credentials.clone(),
reconnect_cfg: transport.reconnect_cfg.clone(),
state: Arc::clone(&state),
active_subs: Arc::clone(&active_subs),
event_tx: transport.event_tx.clone(),
cmd_rx,
http: reqwest::Client::new(),
};
tokio::spawn(driver.run());
transport
}
pub async fn connect(&self) -> WebSocketResult<()> {
self.cmd_tx
.send(TransportCmd::Subscribe(StreamSpec {
kind: StreamKind::Ticker, symbol: String::new(),
account_type: self.account_type,
depth: None,
speed_ms: None,
}))
.ok();
let deadline = tokio::time::Instant::now()
+ Duration::from_millis(self.reconnect_cfg.connection_timeout_ms + 2_000);
loop {
let s = TransportState::from_u8(self.state.load(Ordering::Acquire));
if s == TransportState::Connected {
return Ok(());
}
if tokio::time::Instant::now() > deadline {
return Err(WebSocketError::Timeout);
}
tokio::time::sleep(Duration::from_millis(50)).await; }
}
pub async fn disconnect(&self) -> WebSocketResult<()> {
self.cmd_tx.send(TransportCmd::Shutdown).ok();
Ok(())
}
pub async fn subscribe(&self, spec: StreamSpec) -> WebSocketResult<()> {
self.cmd_tx
.send(TransportCmd::Subscribe(spec))
.map_err(|_| WebSocketError::ProtocolError("transport shut down".into()))
}
pub async fn unsubscribe(&self, spec: StreamSpec) -> WebSocketResult<()> {
self.cmd_tx
.send(TransportCmd::Unsubscribe(spec))
.map_err(|_| WebSocketError::ProtocolError("transport shut down".into()))
}
pub fn event_stream(&self) -> impl Stream<Item = WebSocketResult<StreamEvent>> + Send {
let rx = self.event_tx.subscribe();
tokio_stream::wrappers::BroadcastStream::new(rx).map(|r| match r {
Ok(v) => v,
Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(n)) => {
Err(WebSocketError::ProtocolError(format!("receiver lagged by {n} events")))
}
})
}
pub fn connection_status(&self) -> ConnectionStatus {
match TransportState::from_u8(self.state.load(Ordering::Acquire)) {
TransportState::Disconnected => ConnectionStatus::Disconnected,
TransportState::Connecting => ConnectionStatus::Connecting,
TransportState::Connected => ConnectionStatus::Connected,
TransportState::Reconnecting => ConnectionStatus::Reconnecting,
}
}
pub fn active_subscriptions(&self) -> Vec<StreamSpec> {
match self.active_subs.try_read() {
Ok(guard) => guard.iter().cloned().collect(),
Err(_) => Vec::new(),
}
}
}
impl<P: WsProtocol> CapabilityProvider for UniversalWsTransport<P> {
    /// Classifies support for `kind` on `account`, first match wins:
    ///
    /// 1. a parser is registered in the protocol's topic registry → [`SupportLevel::Native`];
    /// 2. the protocol lists it as auth-only → [`SupportLevel::RequiresAuth`];
    /// 3. the protocol lists it as not offered → [`SupportLevel::UnsupportedByExchange`];
    /// 4. otherwise → [`SupportLevel::NotImplemented`].
    ///
    /// `RequiresAuth` is reported regardless of whether this transport holds credentials.
    fn supports(&self, kind: &StreamKind, account: AccountType) -> SupportLevel {
        let registry = self.protocol.topic_registry(account);
        if registry.supports(kind, account) {
            SupportLevel::Native
        } else if self.protocol.requires_auth_kinds(account).contains(kind) {
            SupportLevel::RequiresAuth
        } else if self.protocol.unsupported_by_exchange(account).contains(kind) {
            SupportLevel::UnsupportedByExchange
        } else {
            SupportLevel::NotImplemented
        }
    }
}
/// Exposes [`UniversalWsTransport`] through the crate-wide `WebSocketConnector` trait.
///
/// Each method forwards to the inherent method of the same name; `SubscriptionRequest` values
/// are converted to and from [`StreamSpec`] at this boundary.
#[async_trait]
impl<P: WsProtocol> crate::core::traits::WebSocketConnector for UniversalWsTransport<P> {
    /// Waits for the driver's connection. The account type was fixed at construction, so the
    /// argument is accepted for trait compatibility but not consulted.
    async fn connect(&self, account_type: AccountType) -> WebSocketResult<()> {
        let _: AccountType = account_type;
        UniversalWsTransport::connect(self).await
    }

    /// Requests driver shutdown.
    async fn disconnect(&self) -> WebSocketResult<()> {
        UniversalWsTransport::disconnect(self).await
    }

    /// Current connection status as reported by the driver.
    fn connection_status(&self) -> ConnectionStatus {
        UniversalWsTransport::connection_status(self)
    }

    /// Converts `request` to a [`StreamSpec`] (returning the conversion error, if any) and
    /// queues the subscription.
    async fn subscribe(&self, request: SubscriptionRequest) -> WebSocketResult<()> {
        UniversalWsTransport::subscribe(self, StreamSpec::try_from(request)?).await
    }

    /// Converts `request` to a [`StreamSpec`] (returning the conversion error, if any) and
    /// queues the unsubscription.
    async fn unsubscribe(&self, request: SubscriptionRequest) -> WebSocketResult<()> {
        UniversalWsTransport::unsubscribe(self, StreamSpec::try_from(request)?).await
    }

    /// Boxed, type-erased form of the inherent event stream.
    fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
        let events = UniversalWsTransport::event_stream(self);
        Box::pin(events)
    }

    /// Best-effort snapshot of recorded subscriptions, converted to `SubscriptionRequest`s.
    fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
        let specs = UniversalWsTransport::active_subscriptions(self);
        specs.into_iter().map(SubscriptionRequest::from).collect()
    }
}
/// State owned by the background task that performs all socket I/O for a
/// [`UniversalWsTransport`]; built and spawned in `UniversalWsTransport::with_reconnect`.
struct DriverTask<P: WsProtocol> {
    // Exchange-specific protocol logic (same `Arc` as the transport handles hold).
    protocol: Arc<P>,
    // Account type passed to the protocol for endpoint and topic-registry lookups.
    account_type: AccountType,
    // Testnet flag passed to `pre_connect_hook` and `endpoint`.
    testnet: bool,
    // When present, used to build an auth frame right after each connect.
    credentials: Option<Credentials>,
    // Seed for the `BackoffState` that paces connection attempts.
    reconnect_cfg: ReconnectConfig,
    // Shared `TransportState` (as `u8`), read by `connection_status`.
    state: Arc<AtomicU8>,
    // Shared subscription set; mutated on commands and replayed after each connect.
    active_subs: Arc<TokioRwLock<HashSet<StreamSpec>>>,
    // Where parsed events and surfaced errors are broadcast.
    event_tx: broadcast::Sender<WebSocketResult<StreamEvent>>,
    // Commands from the transport handles; yields `None` once every handle is dropped.
    cmd_rx: mpsc::UnboundedReceiver<TransportCmd>,
    // HTTP client handed to `pre_connect_hook`.
    http: reqwest::Client,
}
impl<P: WsProtocol> DriverTask<P> {
async fn run(mut self) {
let mut backoff = BackoffState::new(self.reconnect_cfg.clone());
let exchange = self.protocol.name();
loop {
let is_reconnect = backoff.attempt > 0;
self.state.store(
if is_reconnect {
TransportState::Reconnecting
} else {
TransportState::Connecting
} as u8,
Ordering::Release,
);
let url = match self
.protocol
.pre_connect_hook(&self.http, self.account_type, self.testnet)
.await
{
Ok(Some(dynamic_url)) => dynamic_url,
Ok(None) => self.protocol.endpoint(self.account_type, self.testnet),
Err(e) => {
warn!(target: "dig3::ws::connect", exchange, error = %e, "pre_connect_hook failed");
self.state
.store(TransportState::Reconnecting as u8, Ordering::Release);
let delay = backoff.next_delay();
tokio::time::sleep(delay).await;
continue;
}
};
debug!(target: "dig3::ws::connect", exchange, url = %url, "connecting");
let conn_timeout = backoff.connection_timeout();
let ws_result = timeout(conn_timeout, connect_async(url.as_str())).await;
let ws_stream = match ws_result {
Ok(Ok((stream, _))) => stream,
Ok(Err(e)) => {
warn!(target: "dig3::ws::connect", exchange, error = %e, "connection failed");
let _ = self
.event_tx
.send(Err(WebSocketError::ConnectionError(e.to_string())));
let delay = backoff.next_delay();
tokio::time::sleep(delay).await;
continue;
}
Err(_elapsed) => {
warn!(target: "dig3::ws::connect", exchange, "connection timed out");
let _ = self.event_tx.send(Err(WebSocketError::Timeout));
let delay = backoff.next_delay();
tokio::time::sleep(delay).await;
continue;
}
};
let (mut write_half, mut read_half) = ws_stream.split();
if let Some(creds) = &self.credentials {
if let Some(auth_result) = self.protocol.auth_frame(creds) {
match auth_result {
Err(e) => {
warn!(target: "dig3::ws::auth", exchange, error = %e, "auth frame build failed");
let delay = backoff.auth_failure_delay();
tokio::time::sleep(delay).await;
continue;
}
Ok(auth_msg) => {
if let Err(e) = write_half.send(auth_msg).await {
warn!(target: "dig3::ws::auth", exchange, error = %e, "auth frame send failed");
let delay = backoff.auth_failure_delay();
tokio::time::sleep(delay).await;
continue;
}
let ack_timeout = self.protocol.auth_ack_timeout();
let ack_ok = wait_for_auth_ack(
&mut read_half,
&*self.protocol,
ack_timeout,
exchange,
)
.await;
if !ack_ok {
warn!(target: "dig3::ws::auth", exchange, "auth ack not received");
let delay = backoff.auth_failure_delay();
tokio::time::sleep(delay).await;
continue;
}
debug!(target: "dig3::ws::auth", exchange, "auth ack received");
}
}
}
}
{
let subs = self.active_subs.read().await;
for spec in subs.iter() {
match self.protocol.subscribe_frame(spec) {
Ok(msg) => {
if let Err(e) = write_half.send(msg).await {
warn!(target: "dig3::ws::replay", exchange, error = %e, "replay send failed");
}
}
Err(e) => {
warn!(target: "dig3::ws::replay", exchange, error = %e, "subscribe_frame failed");
}
}
}
}
self.state
.store(TransportState::Connected as u8, Ordering::Release);
backoff.reset();
debug!(target: "dig3::ws::connect", exchange, "connected");
let mut ping_interval =
tokio::time::interval(self.protocol.ping_interval());
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
let exit = loop {
tokio::select! {
frame = read_half.next() => {
match frame {
Some(Ok(msg)) => {
match self.dispatch_message(msg, exchange).await {
Ok(true) => {} Ok(false) => break LoopExit::Shutdown, Err(e) => {
warn!(target: "dig3::ws::frame", exchange, error = %e, "frame error");
break LoopExit::Error;
}
}
}
Some(Err(e)) => {
warn!(target: "dig3::ws::frame", exchange, error = %e, "ws error");
break LoopExit::Error;
}
None => {
debug!(target: "dig3::ws::connect", exchange, "stream closed");
break LoopExit::Closed;
}
}
}
cmd = self.cmd_rx.recv() => {
match cmd {
Some(TransportCmd::Subscribe(spec)) => {
self.active_subs.write().await.insert(spec.clone());
match self.protocol.subscribe_frame(&spec) {
Ok(msg) => {
if let Err(e) = write_half.send(msg).await {
warn!(target: "dig3::ws", exchange, error = %e, "subscribe send failed");
}
}
Err(e) => {
warn!(target: "dig3::ws", exchange, error = %e, "subscribe_frame failed");
}
}
}
Some(TransportCmd::Unsubscribe(spec)) => {
self.active_subs.write().await.remove(&spec);
match self.protocol.unsubscribe_frame(&spec) {
Ok(msg) => {
if let Err(e) = write_half.send(msg).await {
warn!(target: "dig3::ws", exchange, error = %e, "unsubscribe send failed");
}
}
Err(e) => {
warn!(target: "dig3::ws", exchange, error = %e, "unsubscribe_frame failed");
}
}
}
Some(TransportCmd::Shutdown) => {
let _ = write_half.close().await;
self.state.store(TransportState::Disconnected as u8, Ordering::Release);
return;
}
None => {
break LoopExit::Closed;
}
}
}
_ = ping_interval.tick() => {
let msg = match self.protocol.ping_frame() {
Some(m) => m,
None => Message::Ping(vec![]),
};
if let Err(e) = write_half.send(msg).await {
warn!(target: "dig3::ws::ping", exchange, error = %e, "ping send failed");
break LoopExit::Error;
}
}
}
};
match exit {
LoopExit::Shutdown => {
self.state
.store(TransportState::Disconnected as u8, Ordering::Release);
return;
}
LoopExit::Closed | LoopExit::Error => {
if backoff.max_attempts() > 0 && backoff.attempt >= backoff.max_attempts() {
warn!(target: "dig3::ws::connect", exchange, "max reconnect attempts reached");
self.state
.store(TransportState::Disconnected as u8, Ordering::Release);
return;
}
let delay = backoff.next_delay();
tokio::time::sleep(delay).await;
}
}
}
}
async fn dispatch_message(
&self,
msg: Message,
exchange: &str,
) -> WebSocketResult<bool> {
let raw: Value = match msg {
Message::Text(text) => match serde_json::from_str(&text) {
Ok(v) => v,
Err(e) => {
warn!(target: "dig3::ws::frame", exchange, error = %e, "JSON parse failed");
return Ok(true);
}
},
Message::Binary(bytes) => match self.protocol.decode_binary(&bytes) {
Ok(v) => v,
Err(e) => {
warn!(target: "dig3::ws::frame", exchange, error = %e, "binary decode failed");
return Ok(true);
}
},
Message::Ping(data) => {
trace!(target: "dig3::ws::frame", exchange, kind = "Ping", len = data.len());
return Ok(true);
}
Message::Pong(_) => {
trace!(target: "dig3::ws::frame", exchange, kind = "Pong");
return Ok(true);
}
Message::Close(_) => {
debug!(target: "dig3::ws::connect", exchange, "received Close frame");
return Ok(true); }
Message::Frame(_) => {
return Ok(true);
}
};
trace!(
target: "dig3::ws::frame",
exchange,
payload_len = raw.to_string().len(),
"frame received"
);
if self.protocol.is_pong(&raw) {
return Ok(true);
}
if self.protocol.is_subscribe_ack(&raw) {
return Ok(true);
}
if self.credentials.is_some() && self.protocol.is_auth_ack(&raw) {
return Ok(true);
}
let topic_key = match self.protocol.extract_topic(&raw) {
None => return Ok(true), Some(k) => k,
};
let topic_str = topic_key.to_string();
let registry = self.protocol.topic_registry(self.account_type);
match registry.dispatch(&topic_key) {
Some(parser) => match parser(&raw) {
Ok(event) => {
let n_receivers = self.event_tx.receiver_count();
if n_receivers > 0 {
let _ = self.event_tx.send(Ok(event));
}
}
Err(e) => {
warn!(
target: "dig3::ws::parse",
exchange,
topic = %topic_str,
error = %e,
"parser failed"
);
let _ = self.event_tx.send(Err(e));
}
},
None => {
warn!(
target: "dig3::ws::unmatched",
exchange,
topic = %topic_str,
"no registered parser"
);
}
}
Ok(true)
}
}
/// Reads frames from `read_half` until `protocol` recognises an authentication acknowledgement.
///
/// Text frames are parsed as JSON. Binary frames are decoded with `protocol.decode_binary`,
/// matching how `DriverTask::dispatch_message` handles them, so protocols that deliver the ack
/// as a binary frame can still authenticate. Every other frame seen while waiting is discarded,
/// including data frames that arrive early.
///
/// Returns `true` only when an ack was seen. Returns `false` on a read error (logged), when the
/// stream ends, or when `ack_timeout` elapses first.
async fn wait_for_auth_ack<P: WsProtocol, S>(
    read_half: &mut S,
    protocol: &P,
    ack_timeout: Duration,
    exchange: &str,
) -> bool
where
    S: StreamExt<Item = Result<Message, tokio_tungstenite::tungstenite::Error>> + Unpin,
{
    let wait = async {
        while let Some(frame) = read_half.next().await {
            let payload: Option<Value> = match frame {
                Ok(Message::Text(text)) => serde_json::from_str(&text).ok(),
                Ok(Message::Binary(bytes)) => protocol.decode_binary(&bytes).ok(),
                Ok(_) => None,
                Err(e) => {
                    warn!(target: "dig3::ws::auth", exchange, error = %e, "error during auth ack wait");
                    return false;
                }
            };
            if let Some(value) = payload {
                if protocol.is_auth_ack(&value) {
                    return true;
                }
            }
        }
        false
    };
    timeout(ack_timeout, wait).await.unwrap_or(false)
}
/// Why the per-connection session loop in `DriverTask` ended.
enum LoopExit {
    /// Stop the driver for good instead of reconnecting.
    Shutdown,
    /// The session ended without a transport error (e.g. the server closed the stream); the
    /// driver backs off and reconnects.
    Closed,
    /// A read, write or frame-handling error ended the session; the driver backs off and
    /// reconnects.
    Error,
}
pub fn decode_binary_default(bytes: &[u8]) -> WebSocketResult<Value> {
use flate2::read::{DeflateDecoder, GzDecoder, ZlibDecoder};
use std::io::Read;
if bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b {
let mut decoder = GzDecoder::new(bytes);
let mut decompressed = String::new();
if decoder.read_to_string(&mut decompressed).is_ok() {
return serde_json::from_str(&decompressed)
.map_err(|e| WebSocketError::Parse(e.to_string()));
}
}
if !bytes.is_empty() && bytes[0] == 0x78 {
let mut decoder = ZlibDecoder::new(bytes);
let mut decompressed = String::new();
if decoder.read_to_string(&mut decompressed).is_ok() {
return serde_json::from_str(&decompressed)
.map_err(|e| WebSocketError::Parse(e.to_string()));
}
}
{
let mut decoder = DeflateDecoder::new(bytes);
let mut decompressed = String::new();
if decoder.read_to_string(&mut decompressed).is_ok() {
if let Ok(v) = serde_json::from_str(&decompressed) {
return Ok(v);
}
}
}
let text = std::str::from_utf8(bytes)
.map_err(|e| WebSocketError::Parse(format!("binary not valid UTF-8: {e}")))?;
serde_json::from_str(text).map_err(|e| WebSocketError::Parse(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// gzip-compresses `payload` in memory.
    fn gzip(payload: &[u8]) -> Vec<u8> {
        let mut enc = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
        enc.write_all(payload).expect("writing to a Vec cannot fail");
        enc.finish().expect("finishing a Vec-backed encoder cannot fail")
    }

    /// zlib-compresses `payload` in memory.
    fn zlib(payload: &[u8]) -> Vec<u8> {
        let mut enc = flate2::write::ZlibEncoder::new(Vec::new(), flate2::Compression::default());
        enc.write_all(payload).expect("writing to a Vec cannot fail");
        enc.finish().expect("finishing a Vec-backed encoder cannot fail")
    }

    /// Raw-DEFLATE-compresses `payload` (no zlib/gzip header) in memory.
    fn deflate(payload: &[u8]) -> Vec<u8> {
        let mut enc =
            flate2::write::DeflateEncoder::new(Vec::new(), flate2::Compression::default());
        enc.write_all(payload).expect("writing to a Vec cannot fail");
        enc.finish().expect("finishing a Vec-backed encoder cannot fail")
    }

    const SAMPLE: &[u8] = br#"{"type":"trade","symbol":"BTCUSDT"}"#;

    #[test]
    fn transport_state_roundtrip() {
        let states = [
            TransportState::Disconnected,
            TransportState::Connecting,
            TransportState::Connected,
            TransportState::Reconnecting,
        ];
        for s in states {
            assert_eq!(TransportState::from_u8(s as u8), s);
        }
    }

    #[test]
    fn transport_state_unknown_discriminant_is_disconnected() {
        assert_eq!(TransportState::from_u8(4), TransportState::Disconnected);
        assert_eq!(TransportState::from_u8(u8::MAX), TransportState::Disconnected);
    }

    #[test]
    fn decode_binary_plain_json() {
        let v = decode_binary_default(SAMPLE).unwrap();
        assert_eq!(v["type"], "trade");
    }

    #[test]
    fn decode_binary_gzip_json() {
        let v = decode_binary_default(&gzip(SAMPLE)).unwrap();
        assert_eq!(v["symbol"], "BTCUSDT");
    }

    #[test]
    fn decode_binary_zlib_json() {
        let v = decode_binary_default(&zlib(SAMPLE)).unwrap();
        assert_eq!(v["symbol"], "BTCUSDT");
    }

    #[test]
    fn decode_binary_raw_deflate_json() {
        let v = decode_binary_default(&deflate(SAMPLE)).unwrap();
        assert_eq!(v["type"], "trade");
    }

    #[test]
    fn decode_binary_rejects_garbage() {
        assert!(decode_binary_default(b"not json").is_err());
        assert!(decode_binary_default(&[0xff, 0xfe, 0xfd]).is_err());
        assert!(decode_binary_default(b"").is_err());
    }
}