use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use serde_json::{Value, json};
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::Message;
use crate::error::ClientError;
use crate::types::order::{CancelOrder, Order, OrderResponse};
use crate::wallet::{TypedTradingAction, TypedTradingDigest, Wallet};
use crate::ws::subscriptions::{Subscription, WsMessage};
/// Tuning knobs for the background connection task spawned by
/// [`WsClient::connect_with`].
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Period between application-level `{"method":"ping"}` frames.
    pub ping_interval: Duration,
    /// Delay before the first reconnect attempt after a session ends.
    pub initial_backoff: Duration,
    /// Cap applied to the reconnect delay as it doubles.
    pub max_backoff: Duration,
    /// Capacity of the broadcast channel that fans inbound messages out to
    /// [`WsClient::messages`] receivers.
    pub channel_capacity: usize,
    /// How long a post request waits for its response before giving up.
    pub post_timeout: Duration,
}

impl Default for WsConfig {
    /// 30 s pings, a 250 ms → 30 s reconnect backoff, a 1024-slot message
    /// channel and a 10 s post timeout.
    fn default() -> Self {
        let ping_interval = Duration::from_secs(30);
        let post_timeout = Duration::from_secs(10);
        let (initial_backoff, max_backoff) =
            (Duration::from_millis(250), Duration::from_secs(30));
        WsConfig {
            ping_interval,
            initial_backoff,
            max_backoff,
            channel_capacity: 1024,
            post_timeout,
        }
    }
}
/// Requests sent from [`WsClient`] handles to the background connection task.
///
/// Commands are only consumed while a session is connected (inside
/// `run_connection`); anything sent while disconnected waits in the channel.
#[derive(Debug)]
enum Command {
/// Send a subscribe frame for this subscription on the live socket.
Subscribe(Subscription),
/// Send an unsubscribe frame for this subscription on the live socket.
Unsubscribe(Subscription),
/// Send the pre-serialized post `frame`; the response whose `data.id`
/// equals `id` is delivered through `reply`.
Post {
id: u64,
frame: String,
reply: oneshot::Sender<Value>,
},
/// Forget pending post `id` (its caller timed out); a late response for it
/// is then discarded.
CancelPost {
id: u64,
},
/// Send a close frame (best effort) and stop the background task.
Shutdown,
}
/// Cloneable handle to a WebSocket connection owned by a background task.
///
/// Clones share the command channel, the inbound broadcast channel, the
/// active-subscription list and the post-id counter.
#[derive(Debug, Clone)]
pub struct WsClient {
// Fan-out sender for decoded inbound messages; `messages()` subscribes to it.
inbound_tx: broadcast::Sender<WsMessage>,
// Command channel to the background connection task.
cmd_tx: mpsc::UnboundedSender<Command>,
// Cleared by `shutdown()` and when the background task exits.
alive: Arc<AtomicBool>,
// Subscriptions replayed as subscribe frames on every (re)connect.
active: Arc<Mutex<Vec<Subscription>>>,
// Source of ids used to match post requests with their responses.
post_id: Arc<AtomicU64>,
// How long `post_request` waits for a response.
post_timeout: Duration,
}
impl WsClient {
pub async fn connect(url: impl Into<String>) -> Result<Self, ClientError> {
Self::connect_with(url, WsConfig::default()).await
}
pub async fn connect_with(
url: impl Into<String>,
config: WsConfig,
) -> Result<Self, ClientError> {
let url = url.into();
let (inbound_tx, _) = broadcast::channel(config.channel_capacity);
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let alive = Arc::new(AtomicBool::new(true));
let active: Arc<Mutex<Vec<Subscription>>> = Arc::new(Mutex::new(Vec::new()));
let post_timeout = config.post_timeout;
let (probe, _) = tokio_tungstenite::connect_async(&url).await?;
drop(probe);
let task_state = TaskState {
url,
config,
inbound_tx: inbound_tx.clone(),
cmd_rx,
alive: alive.clone(),
active: active.clone(),
};
let _handle: JoinHandle<()> = tokio::spawn(run_background(task_state));
Ok(Self {
inbound_tx,
cmd_tx,
alive,
active,
post_id: Arc::new(AtomicU64::new(1)),
post_timeout,
})
}
pub async fn subscribe(&self, sub: Subscription) -> Result<(), ClientError> {
{
let mut g = self.active.lock().await;
if !g.contains(&sub) {
g.push(sub.clone());
}
}
self.cmd_tx
.send(Command::Subscribe(sub))
.map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
Ok(())
}
pub async fn unsubscribe(&self, sub: Subscription) -> Result<(), ClientError> {
{
let mut g = self.active.lock().await;
g.retain(|s| s != &sub);
}
self.cmd_tx
.send(Command::Unsubscribe(sub))
.map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
Ok(())
}
pub async fn subscribe_l2_book(
&self,
market: crate::types::MarketId,
) -> Result<(), ClientError> {
self.subscribe(Subscription::L2Book {
coin: market.0.to_string(),
})
.await
}
pub async fn subscribe_trades(
&self,
market: crate::types::MarketId,
) -> Result<(), ClientError> {
self.subscribe(Subscription::Trades {
coin: market.0.to_string(),
})
.await
}
pub async fn subscribe_bbo(&self, market: crate::types::MarketId) -> Result<(), ClientError> {
self.subscribe(Subscription::Bbo {
coin: market.0.to_string(),
})
.await
}
pub async fn subscribe_active_asset_ctx(
&self,
market: crate::types::MarketId,
) -> Result<(), ClientError> {
self.subscribe(Subscription::ActiveAssetCtx {
coin: market.0.to_string(),
})
.await
}
pub async fn subscribe_candles(
&self,
market: crate::types::MarketId,
interval: impl Into<String>,
) -> Result<(), ClientError> {
self.subscribe(Subscription::Candles {
coin: market.0.to_string(),
interval: interval.into(),
})
.await
}
pub async fn subscribe_all_mids(&self) -> Result<(), ClientError> {
self.subscribe(Subscription::AllMids).await
}
pub async fn subscribe_fills(&self, user: crate::wallet::Address) -> Result<(), ClientError> {
self.subscribe(Subscription::Fills { user }).await
}
pub async fn subscribe_order_updates(
&self,
user: crate::wallet::Address,
) -> Result<(), ClientError> {
self.subscribe(Subscription::OrderUpdates { user }).await
}
pub async fn subscribe_user_events(
&self,
user: crate::wallet::Address,
) -> Result<(), ClientError> {
self.subscribe(Subscription::UserEvents { user }).await
}
pub async fn subscribe_account_state(
&self,
user: crate::wallet::Address,
) -> Result<(), ClientError> {
self.subscribe(Subscription::AccountState { user }).await
}
#[must_use]
pub fn messages(&self) -> broadcast::Receiver<WsMessage> {
self.inbound_tx.subscribe()
}
pub async fn post_action(&self, wallet: &Wallet, action: Value) -> Result<Value, ClientError> {
let (nonce, signature) = crate::rest::exchange::sign_action(wallet, &action)?;
let payload = json!({ "signature": signature, "nonce": nonce, "action": action });
self.post_request("action", payload).await
}
async fn post_typed_trade(
&self,
wallet: &Wallet,
action: Value,
typed: TypedTradingAction<'_>,
) -> Result<Value, ClientError> {
let nonce = crate::rest::exchange::next_nonce();
let digest =
TypedTradingDigest::new(typed, crate::rest::exchange::MTF_CHAIN_ID, nonce).digest()?;
let signature = wallet.sign_digest(&digest)?.to_hex();
let payload = json!({
"signature": signature,
"nonce": nonce,
"action": action,
"sig_scheme": "typed",
});
self.post_request("action", payload).await
}
pub async fn post_info(&self, payload: Value) -> Result<Value, ClientError> {
self.post_request("info", payload).await
}
pub async fn submit_order(
&self,
wallet: &Wallet,
order: &Order,
) -> Result<OrderResponse, ClientError> {
if order.owner != wallet.address() {
return Err(ClientError::Validation(format!(
"order.owner {} != wallet address {}",
order.owner,
wallet.address()
)));
}
let action = json!({ "type": "submit_order", "order": order });
let payload = self
.post_typed_trade(wallet, action, TypedTradingAction::SubmitOrder(order))
.await?;
Ok(serde_json::from_value(payload)?)
}
pub async fn cancel_order(
&self,
wallet: &Wallet,
cancel: &CancelOrder,
) -> Result<Value, ClientError> {
if cancel.owner != wallet.address() {
return Err(ClientError::Validation(format!(
"cancel.owner {} != wallet address {}",
cancel.owner,
wallet.address()
)));
}
let action = json!({ "type": "cancel_order", "cancel": cancel });
self.post_typed_trade(wallet, action, TypedTradingAction::CancelOrder(cancel))
.await
}
async fn post_request(&self, request_type: &str, payload: Value) -> Result<Value, ClientError> {
let id = self.post_id.fetch_add(1, Ordering::Relaxed);
let frame = json!({
"method": "post",
"id": id,
"request": { "type": request_type, "payload": payload },
})
.to_string();
let (reply_tx, reply_rx) = oneshot::channel();
self.cmd_tx
.send(Command::Post {
id,
frame,
reply: reply_tx,
})
.map_err(|_| ClientError::WebSocket("ws task is dead".into()))?;
let response = match tokio::time::timeout(self.post_timeout, reply_rx).await {
Ok(Ok(resp)) => resp,
Ok(Err(_)) => {
return Err(ClientError::WebSocket(
"ws post: connection closed before response".into(),
));
}
Err(_) => {
let _ = self.cmd_tx.send(Command::CancelPost { id });
return Err(ClientError::WebSocket("ws post: timed out".into()));
}
};
if response.get("type").and_then(Value::as_str) == Some("error") {
let msg = response
.get("payload")
.and_then(Value::as_str)
.unwrap_or("unknown post error");
return Err(ClientError::WebSocket(format!("ws post error: {msg}")));
}
Ok(response.get("payload").cloned().unwrap_or(Value::Null))
}
#[must_use]
pub fn is_alive(&self) -> bool {
self.alive.load(Ordering::Acquire)
}
pub async fn shutdown(&self) {
let _ = self.cmd_tx.send(Command::Shutdown);
self.alive.store(false, Ordering::Release);
}
}
/// State owned by the background connection task (`run_background` /
/// `run_connection`); the `Arc` fields are shared with the `WsClient` handles.
struct TaskState {
// Endpoint passed to `connect_async` on every (re)connect.
url: String,
// Ping interval and reconnect backoff settings.
config: WsConfig,
// Where decoded non-post inbound messages are broadcast.
inbound_tx: broadcast::Sender<WsMessage>,
// Receiving end of the client command channel; yields `None` once every
// `WsClient` handle has been dropped.
cmd_rx: mpsc::UnboundedReceiver<Command>,
// Shared liveness flag; stored `false` when the task exits.
alive: Arc<AtomicBool>,
// Shared subscription list, replayed as subscribe frames after each connect.
active: Arc<Mutex<Vec<Subscription>>>,
}
/// Supervisor loop of the background task: keeps (re)establishing the
/// connection until shutdown.
///
/// The delay between attempts starts at `config.initial_backoff` and doubles
/// after each ended session, capped at `config.max_backoff`. A session that
/// stayed up for at least `max_backoff` counts as healthy and resets the
/// delay. A connection that drops after running for hours therefore
/// reconnects quickly instead of inheriting the delay built up by earlier
/// failures.
///
/// The loop stops in two cases:
/// - `run_connection` reports `ConnectionExit::Shutdown`;
/// - the shared `alive` flag has been cleared by `WsClient::shutdown` while
///   the task is disconnected.
///
/// On exit the `alive` flag is cleared.
async fn run_background(mut state: TaskState) {
    let initial_backoff = state.config.initial_backoff;
    let max_backoff = state.config.max_backoff;
    let mut backoff = initial_backoff;
    loop {
        let session_started = tokio::time::Instant::now();
        match run_connection(&mut state).await {
            Ok(ConnectionExit::Shutdown) => break,
            Ok(ConnectionExit::Recoverable) | Err(_) => {
                // Nobody reads `cmd_rx` while disconnected, so a queued
                // `Command::Shutdown` would only be seen after a successful
                // reconnect, which never happens if the server stays down.
                if !state.alive.load(Ordering::Acquire) {
                    break;
                }
                // A session that outlived the largest backoff was healthy;
                // don't make its reconnect pay for failures that preceded it.
                if session_started.elapsed() >= max_backoff {
                    backoff = initial_backoff;
                }
                tokio::time::sleep(backoff).await;
                // `Duration * u32` panics on overflow; a huge user-supplied
                // `max_backoff` must not crash the task.
                backoff = backoff.saturating_mul(2).min(max_backoff);
                if !state.alive.load(Ordering::Acquire) {
                    break;
                }
            }
        }
    }
    state.alive.store(false, Ordering::Release);
}
/// Why a single `run_connection` session ended.
#[derive(Debug)]
enum ConnectionExit {
/// `Command::Shutdown` was received or every client handle was dropped;
/// `run_background` stops.
Shutdown,
/// The session was lost (server close, read error, failed ping);
/// `run_background` backs off and reconnects.
Recoverable,
}
/// Runs one WebSocket session.
///
/// The session connects, replays the active subscriptions, then multiplexes
/// client commands, keep-alive pings and inbound frames until it ends.
///
/// Returns `Ok(ConnectionExit::Shutdown)` on `Command::Shutdown` or once every
/// `WsClient` handle has been dropped. Returns `Ok(ConnectionExit::Recoverable)`
/// when any of these happens:
/// - the server closes the stream;
/// - a read fails;
/// - a ping cannot be sent.
///
/// In-flight post requests live in a session-local map. When the session
/// ends, their reply senders are dropped and waiting callers see a closed
/// channel.
///
/// # Errors
/// Returns the underlying error when connecting, or when sending a
/// subscription or post frame, fails. The caller treats this like
/// `Recoverable`.
async fn run_connection(state: &mut TaskState) -> Result<ConnectionExit, ClientError> {
    let (stream, _) = tokio_tungstenite::connect_async(&state.url).await?;
    let (mut sink, mut stream) = stream.split();
    // Replay the recorded subscriptions. The list is cloned so the mutex
    // guard is released before any network I/O.
    {
        let subs = state.active.lock().await.clone();
        for sub in &subs {
            let frame = json!({"method": "subscribe", "subscription": sub});
            sink.send(Message::Text(frame.to_string())).await?;
        }
    }
    // Post requests awaiting a response on this session, keyed by request id.
    let mut pending: HashMap<u64, oneshot::Sender<Value>> = HashMap::new();
    let mut ping_tick = tokio::time::interval(state.config.ping_interval);
    // After a stalled send, resume the normal cadence instead of firing a
    // burst of back-to-back catch-up pings (the default behavior).
    ping_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    // The first tick completes immediately; consume it so the first ping goes
    // out one full interval after connecting.
    ping_tick.tick().await;
    loop {
        tokio::select! {
            cmd = state.cmd_rx.recv() => {
                match cmd {
                    Some(Command::Subscribe(sub)) => {
                        let frame = json!({"method": "subscribe", "subscription": sub});
                        sink.send(Message::Text(frame.to_string())).await?;
                    }
                    Some(Command::Unsubscribe(sub)) => {
                        let frame = json!({"method": "unsubscribe", "subscription": sub});
                        sink.send(Message::Text(frame.to_string())).await?;
                    }
                    Some(Command::Post { id, frame, reply }) => {
                        // A closed receiver means the caller already gave up
                        // (timed out or was cancelled) while this command sat
                        // in the queue, e.g. during a reconnect. Sending it
                        // now would execute an action whose outcome nobody
                        // observes, so drop it instead.
                        if !reply.is_closed() {
                            sink.send(Message::Text(frame)).await?;
                            pending.insert(id, reply);
                        }
                    }
                    Some(Command::CancelPost { id }) => {
                        pending.remove(&id);
                    }
                    // `None`: every `WsClient` handle (command sender) is gone.
                    Some(Command::Shutdown) | None => {
                        let _ = sink.send(Message::Close(None)).await;
                        return Ok(ConnectionExit::Shutdown);
                    }
                }
            }
            _ = ping_tick.tick() => {
                let ping = json!({"method": "ping"});
                if sink.send(Message::Text(ping.to_string())).await.is_err() {
                    return Ok(ConnectionExit::Recoverable);
                }
            }
            frame = stream.next() => {
                let Some(frame) = frame else {
                    return Ok(ConnectionExit::Recoverable);
                };
                match frame {
                    Ok(Message::Text(text)) => {
                        match serde_json::from_str::<Value>(&text) {
                            // Post responses go to the waiting caller only.
                            Ok(v)
                                if v.get("channel").and_then(Value::as_str) == Some("post") =>
                            {
                                if let Some(id) =
                                    v.pointer("/data/id").and_then(Value::as_u64)
                                {
                                    // Unknown ids (cancelled or already answered) are ignored.
                                    if let Some(reply) = pending.remove(&id) {
                                        let resp = v
                                            .pointer("/data/response")
                                            .cloned()
                                            .unwrap_or(Value::Null);
                                        let _ = reply.send(resp);
                                    }
                                }
                            }
                            Ok(v) => {
                                let msg = serde_json::from_value::<WsMessage>(v)
                                    .unwrap_or(WsMessage::Unknown);
                                // An error here only means there are no receivers.
                                let _ = state.inbound_tx.send(msg);
                            }
                            // Frames that are not JSON are ignored.
                            Err(_) => {}
                        }
                    }
                    Ok(Message::Binary(_) | Message::Pong(_) | Message::Ping(_)) => {
                    }
                    Ok(Message::Close(_)) => {
                        return Ok(ConnectionExit::Recoverable);
                    }
                    Ok(Message::Frame(_)) => {
                    }
                    Err(_) => return Ok(ConnectionExit::Recoverable),
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every `WsConfig::default()` value, including `post_timeout`, which
    /// was previously unchecked.
    #[test]
    fn ws_config_default_values() {
        let c = WsConfig::default();
        assert_eq!(c.ping_interval, Duration::from_secs(30));
        assert_eq!(c.initial_backoff, Duration::from_millis(250));
        assert_eq!(c.max_backoff, Duration::from_secs(30));
        assert_eq!(c.channel_capacity, 1024);
        assert_eq!(c.post_timeout, Duration::from_secs(10));
    }

    /// The defaults must be usable as-is: a non-zero ping period and channel
    /// capacity (zero panics in tokio) and a backoff range that does not shrink.
    #[test]
    fn ws_config_default_is_consistent() {
        let c = WsConfig::default();
        assert!(!c.ping_interval.is_zero());
        assert!(c.channel_capacity > 0);
        assert!(c.initial_backoff <= c.max_backoff);
    }
}