use std::collections::HashSet;
use std::time::Duration;
use bon::bon;
use futures::channel::{mpsc, oneshot};
use futures::future::{self, Either, pending};
use futures::{FutureExt, StreamExt};
use futures_timer::Delay;
use thiserror::Error;
use tracing::{debug, info, warn};
use web_time::Instant;
use super::client::{WebsocketConfig, WebsocketHandle};
use super::models::ServerMessage;
use super::topics::Topic;
use crate::Client;
use crate::errors::WSErrors;
use crate::types::{ClientMessage, OrderParams, RequestId};
/// Errors returned by the command methods of [`ManagedWebsocket`]
/// (`subscribe`, `unsubscribe`, `order_place`, ...).
#[derive(Debug, Error)]
pub enum ManagedWsError {
/// The background task has exited (its command receiver is gone), so the
/// command was not queued.
#[error("managed websocket is stopped")]
Stopped,
/// The command queue is at capacity; the command was not queued and may be
/// retried later.
#[error("managed websocket command channel is full")]
Busy,
}
/// Why a reconnect cycle ended without producing a new connection.
#[derive(Debug, Error)]
enum ReconnectError {
/// Shutdown was signaled or the event receiver was dropped; the task exits
/// without emitting a `Disconnected` event.
#[error("managed websocket handle dropped")]
HandleDropped,
/// `max_retries` connection attempts were made without success; carries the
/// configured limit.
#[error("exhausted {0} reconnect attempts")]
RetriesExhausted(u32),
/// Replaying subscriptions on the new connection failed with an error other
/// than the connection closing (`WsStreamEnded` / `WsClosed` are retried).
#[error("subscription replay failed: {0}")]
ReplayFailed(#[source] WSErrors),
}
/// Events delivered by [`ManagedWebsocket::recv`].
#[derive(Debug)]
pub enum WsEvent {
/// A message received from the server.
Message(Box<ServerMessage>),
/// A reconnect cycle has started. Best effort: not delivered if the event
/// channel is full at that moment.
Reconnecting,
/// Reconnecting gave up (retries exhausted or subscription replay failed);
/// carries the reason. The background task stops afterwards. Best effort:
/// not delivered if the event channel is full.
Disconnected(String),
}
/// Lower bound applied to both the initial and the maximum reconnect backoff.
const MIN_BACKOFF: Duration = Duration::from_millis(10);
/// Buffer size of the command channel from the handle to the background task.
const CMD_CHANNEL_CAPACITY: usize = 256;
/// Reconnection, buffering and liveness settings for [`ManagedWebsocket`].
///
/// Build with `ManagedWsConfig::builder()`; unset fields take the defaults
/// listed below (`Option` fields default to `None`).
#[derive(bon::Builder, Clone, Debug)]
pub struct ManagedWsConfig {
/// Delay before the first attempt of a reconnect cycle, and the value the
/// backoff is reset to after a stable connection. Floored at 10 ms.
/// Default: 1 s.
#[builder(default = Duration::from_secs(1))]
pub initial_backoff: Duration,
/// Ceiling for the doubling backoff, floored at 10 ms. Each delay also adds
/// random jitter of up to half the current backoff on top. Default: 30 s.
#[builder(default = Duration::from_secs(30))]
pub max_backoff: Duration,
/// Maximum connection attempts per reconnect cycle before giving up and
/// reporting `WsEvent::Disconnected`; the count restarts every cycle.
/// `None` (default) retries forever; `Some(0)` gives up immediately.
pub max_retries: Option<u32>,
/// Capacity of the event channel read by `ManagedWebsocket::recv`. When it is
/// full, incoming server messages are dropped with a warning. Default: 10 000.
#[builder(default = 10_000)]
pub channel_capacity: usize,
/// Passed to the initial connection. On reconnect only its
/// `connection_timeout` is used (10 s when `None`).
/// NOTE(review): other settings are not re-applied on reconnect — confirm intended.
pub ws_config: Option<WebsocketConfig>,
/// Force a reconnect when no server message arrives for this long;
/// `Duration::ZERO` disables the check. Default: 60 s.
#[builder(default = Duration::from_secs(60))]
pub idle_timeout: Duration,
/// If the previous connection stayed up at least this long, the backoff is
/// reset to `initial_backoff` before reconnecting. Default: 30 s.
#[builder(default = Duration::from_secs(30))]
pub backoff_reset_after: Duration,
}
impl Default for ManagedWsConfig {
fn default() -> Self {
Self::builder().build()
}
}
/// Requests sent from a [`ManagedWebsocket`] handle to its background task.
enum WsCommand {
/// Subscribe to raw topic strings; only topics not already active are sent.
Subscribe(Vec<String>, Option<RequestId>),
/// Unsubscribe from raw topic strings; only currently active topics are sent.
Unsubscribe(Vec<String>, Option<RequestId>),
/// Send a message as-is; retried once after a reconnect if the send fails.
Send(ClientMessage),
}
/// Handle to a websocket connection that a background task keeps alive.
///
/// The task forwards server messages as [`WsEvent`]s (read them with
/// [`ManagedWebsocket::recv`]), reconnects with exponential backoff, and
/// replays active subscriptions after each reconnect. Dropping the handle, or
/// calling [`ManagedWebsocket::stop`], signals the task to exit.
pub struct ManagedWebsocket {
    event_rx: mpsc::Receiver<WsEvent>,
    // Exactly one sender, shared behind a lock. A futures `mpsc::Sender` only
    // rejects `try_send` once *that* sender has been parked, and every clone
    // starts unparked with its own guaranteed slot. Sending through a fresh
    // clone per call would therefore never report `Busy` and would let the
    // command queue grow without bound.
    cmd_tx: std::sync::Mutex<mpsc::Sender<WsCommand>>,
    // Never sent on; dropping it is the shutdown signal for the task.
    _shutdown_tx: oneshot::Sender<()>,
}

impl ManagedWebsocket {
    /// Connects with [`ManagedWsConfig::default`] and spawns the managed task.
    ///
    /// # Errors
    ///
    /// Returns the error from the initial connection attempt; it is not retried.
    ///
    #[cfg_attr(not(target_arch = "wasm32"), doc = "Uses [`tokio::spawn`] on native targets.")]
    #[cfg_attr(
        target_arch = "wasm32",
        doc = "Uses [`wasm_bindgen_futures::spawn_local`] on wasm targets."
    )]
    pub async fn connect(client: &Client) -> Result<ManagedWebsocket, WSErrors> {
        Self::connect_with(client, ManagedWsConfig::default()).await
    }

    /// Connects using `config` and spawns the background task that manages
    /// the connection.
    ///
    /// # Errors
    ///
    /// Returns the error from the initial connection attempt; it is not retried.
    pub async fn connect_with(
        client: &Client,
        config: ManagedWsConfig,
    ) -> Result<ManagedWebsocket, WSErrors> {
        let ws = client.connect_ws().maybe_config(config.ws_config.clone()).call().await?;
        let (event_tx, event_rx) = mpsc::channel(config.channel_capacity);
        let (cmd_tx, cmd_rx) = mpsc::channel(CMD_CHANNEL_CAPACITY);
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let inner = ManagedWsClient::from_client(client);
        spawn(async move {
            run_managed_ws(inner, ws, config, event_tx, cmd_rx, shutdown_rx).await;
        });
        Ok(ManagedWebsocket {
            event_rx,
            cmd_tx: std::sync::Mutex::new(cmd_tx),
            _shutdown_tx: shutdown_tx,
        })
    }

    /// Waits for the next event.
    ///
    /// Returns `None` once the background task has exited and every buffered
    /// event has been read.
    pub async fn recv(&mut self) -> Option<WsEvent> {
        self.event_rx.next().await
    }

    /// Queues a subscription to `topics`; topics that are already active are
    /// not re-sent. Returns without waiting for the wire send.
    ///
    /// # Errors
    ///
    /// [`ManagedWsError::Stopped`] if the background task has exited,
    /// [`ManagedWsError::Busy`] if the command queue is full.
    pub fn subscribe(
        &self,
        topics: impl IntoIterator<Item = Topic>,
        id: Option<RequestId>,
    ) -> Result<(), ManagedWsError> {
        let params: Vec<String> = topics.into_iter().map(|t| t.to_string()).collect();
        self.try_send_cmd(WsCommand::Subscribe(params, id))
    }

    /// Like [`ManagedWebsocket::subscribe`], but takes raw topic strings.
    ///
    /// # Errors
    ///
    /// Same as [`ManagedWebsocket::subscribe`].
    pub fn subscribe_raw(
        &self,
        topics: impl IntoIterator<Item = String>,
        id: Option<RequestId>,
    ) -> Result<(), ManagedWsError> {
        let params: Vec<String> = topics.into_iter().collect();
        self.try_send_cmd(WsCommand::Subscribe(params, id))
    }

    /// Queues an unsubscribe for `topics`; topics that are not active are not
    /// sent. Unsubscribed topics are no longer replayed after a reconnect.
    ///
    /// # Errors
    ///
    /// Same as [`ManagedWebsocket::subscribe`].
    pub fn unsubscribe(
        &self,
        topics: impl IntoIterator<Item = Topic>,
        id: Option<RequestId>,
    ) -> Result<(), ManagedWsError> {
        let params: Vec<String> = topics.into_iter().map(|t| t.to_string()).collect();
        self.try_send_cmd(WsCommand::Unsubscribe(params, id))
    }

    /// Like [`ManagedWebsocket::unsubscribe`], but takes raw topic strings.
    ///
    /// # Errors
    ///
    /// Same as [`ManagedWebsocket::subscribe`].
    pub fn unsubscribe_raw(
        &self,
        topics: impl IntoIterator<Item = String>,
        id: Option<RequestId>,
    ) -> Result<(), ManagedWsError> {
        let params: Vec<String> = topics.into_iter().collect();
        self.try_send_cmd(WsCommand::Unsubscribe(params, id))
    }

    /// Queues an `OrderPlace` message whose payload `tx` is sent as-is (see
    /// [`ManagedWebsocket::place_order`] to encode a signed transaction). If
    /// the send fails, the task reconnects and retries once.
    ///
    /// # Errors
    ///
    /// Same as [`ManagedWebsocket::subscribe`].
    pub fn order_place(
        &self,
        tx: impl Into<String>,
        id: Option<RequestId>,
    ) -> Result<(), ManagedWsError> {
        self.try_send_cmd(WsCommand::Send(ClientMessage::OrderPlace {
            id,
            params: OrderParams { tx: tx.into() },
        }))
    }

    /// Queues an `OrderCancel` message whose payload `tx` is sent as-is (see
    /// [`ManagedWebsocket::cancel_order`] to encode a signed transaction). If
    /// the send fails, the task reconnects and retries once.
    ///
    /// # Errors
    ///
    /// Same as [`ManagedWebsocket::subscribe`].
    pub fn order_cancel(
        &self,
        tx: impl Into<String>,
        id: Option<RequestId>,
    ) -> Result<(), ManagedWsError> {
        self.try_send_cmd(WsCommand::Send(ClientMessage::OrderCancel {
            id,
            params: OrderParams { tx: tx.into() },
        }))
    }

    /// Base64-encodes `signed` and queues it as an `OrderPlace` message.
    ///
    /// # Errors
    ///
    /// [`WSErrors::WsError`] if encoding fails or the command cannot be queued.
    pub fn place_order(
        &self,
        signed: &bullet_exchange_interface::transaction::Transaction,
        id: Option<RequestId>,
    ) -> Result<(), WSErrors> {
        let base64 =
            crate::Transaction::to_base64(signed).map_err(|e| WSErrors::WsError(e.to_string()))?;
        self.order_place(base64, id).map_err(|e| WSErrors::WsError(e.to_string()))
    }

    /// Base64-encodes `signed` and queues it as an `OrderCancel` message.
    ///
    /// # Errors
    ///
    /// [`WSErrors::WsError`] if encoding fails or the command cannot be queued.
    pub fn cancel_order(
        &self,
        signed: &bullet_exchange_interface::transaction::Transaction,
        id: Option<RequestId>,
    ) -> Result<(), WSErrors> {
        let base64 =
            crate::Transaction::to_base64(signed).map_err(|e| WSErrors::WsError(e.to_string()))?;
        self.order_cancel(base64, id).map_err(|e| WSErrors::WsError(e.to_string()))
    }

    /// Stops the managed connection. Consuming the handle drops the shutdown
    /// sender, which the background task observes as its signal to exit.
    pub fn stop(self) {
        drop(self);
    }

    /// Non-blocking enqueue of `cmd` on the single, bounded command sender.
    fn try_send_cmd(&self, cmd: WsCommand) -> Result<(), ManagedWsError> {
        // Poisoning can only come from a panic in another caller; the sender
        // itself is still usable, so recover the guard instead of panicking.
        let mut tx = self.cmd_tx.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        tx.try_send(cmd)
            .map_err(|e| if e.is_full() { ManagedWsError::Busy } else { ManagedWsError::Stopped })
    }
}
#[bon]
impl Client {
    /// Opens a [`ManagedWebsocket`] that reconnects automatically.
    ///
    /// Omitting `config` uses [`ManagedWsConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns the error from the initial connection attempt.
    #[builder]
    pub async fn connect_ws_managed(
        &self,
        config: Option<ManagedWsConfig>,
    ) -> Result<ManagedWebsocket, WSErrors> {
        let effective = config.unwrap_or_default();
        ManagedWebsocket::connect_with(self, effective).await
    }
}
/// The parts of [`Client`] the background task needs to open new connections.
///
/// The spawned task must be `'static` (see the bounds on `spawn`), so it owns
/// these values instead of borrowing the `Client`.
struct ManagedWsClient {
/// Clone of the client's `ws_client`.
ws_client: reqwest::Client,
/// Websocket URL captured from `Client::ws_url()`.
ws_url: String,
}
impl ManagedWsClient {
    /// Captures an owned HTTP client and websocket URL from `client`.
    fn from_client(client: &Client) -> Self {
        let ws_client = client.ws_client.clone();
        let ws_url = client.ws_url().to_string();
        Self { ws_client, ws_url }
    }

    /// Opens a fresh [`WebsocketHandle`] to the stored URL.
    ///
    /// Only `connection_timeout` is read from `ws_config`; without a config
    /// the timeout falls back to 10 seconds.
    async fn connect(
        &self,
        ws_config: &Option<WebsocketConfig>,
    ) -> Result<WebsocketHandle, WSErrors> {
        const FALLBACK_TIMEOUT: web_time::Duration = web_time::Duration::from_secs(10);
        let timeout = match ws_config {
            Some(cfg) => cfg.connection_timeout,
            None => FALLBACK_TIMEOUT,
        };
        WebsocketHandle::connect(&self.ws_client, &self.ws_url, timeout).await
    }
}
/// Runs `fut` as a detached Tokio task; its `JoinHandle` is discarded.
///
/// Like `tokio::spawn`, this panics when called outside a Tokio runtime.
#[cfg(not(target_arch = "wasm32"))]
fn spawn<F>(fut: F)
where
    F: std::future::Future<Output = ()> + Send + 'static,
{
    drop(tokio::spawn(fut));
}

/// Runs `fut` on the current thread's wasm executor; wasm futures need not be
/// `Send`.
#[cfg(target_arch = "wasm32")]
fn spawn<F>(fut: F)
where
    F: std::future::Future<Output = ()> + 'static,
{
    wasm_bindgen_futures::spawn_local(fut)
}
/// Backoff bookkeeping carried across reconnect cycles.
struct ReconnectState {
/// Base delay for the next reconnect attempt, before jitter is added.
backoff: Duration,
/// When the current connection was established; `None` while reconnecting.
connected_since: Option<Instant>,
}
/// Event loop of the background task spawned by `ManagedWebsocket::connect_with`.
///
/// Each round races the shutdown signal, the next server message, the next
/// queued [`WsCommand`], and (unless `config.idle_timeout` is zero) an idle
/// deadline measured from the last server message. It then:
/// - forwards server messages as [`WsEvent::Message`], dropping them with a
///   warning when the event channel is full;
/// - reconnects via `do_reconnect` on a receive error, when the idle deadline
///   passes, or when sending a `WsCommand::Send` message fails;
/// - tracks `active_topics` so subscriptions are replayed after reconnects.
///
/// Returns when shutdown is signaled, the event receiver is dropped, the
/// command channel closes, or `do_reconnect` reports that the task must stop.
async fn run_managed_ws(
client: ManagedWsClient,
mut ws: WebsocketHandle,
config: ManagedWsConfig,
mut event_tx: mpsc::Sender<WsEvent>,
mut cmd_rx: mpsc::Receiver<WsCommand>,
mut shutdown_rx: oneshot::Receiver<()>,
) {
// Topics currently subscribed; replayed after every successful reconnect.
let mut active_topics: HashSet<String> = HashSet::new();
// Time of the last server message (or last reconnect); drives the idle timeout.
let mut last_msg = Instant::now();
let mut state = ReconnectState {
backoff: config.initial_backoff.max(MIN_BACKOFF),
connected_since: Some(Instant::now()),
};
// Outcome of one select round. The select runs in its own block so the
// futures borrowing `ws` and `cmd_rx` are dropped before the handlers below
// use them again.
enum Branch {
Shutdown,
Recv(Result<Box<ServerMessage>, WSErrors>),
Cmd(Option<WsCommand>),
Idle,
}
loop {
// `None` disables the idle check (idle_timeout == 0); otherwise this is the
// time left before the connection counts as stale.
let idle_remaining = if config.idle_timeout.is_zero() {
None
} else {
Some(config.idle_timeout.saturating_sub(last_msg.elapsed()))
};
let branch = {
let recv_fut = ws.recv().fuse();
let cmd_fut = cmd_rx.next().fuse();
let idle_fut = match idle_remaining {
Some(d) => Either::Left(Delay::new(d)),
None => Either::Right(pending::<()>()),
}
.fuse();
futures::pin_mut!(recv_fut, cmd_fut, idle_fut);
futures::select! {
_ = (&mut shutdown_rx).fuse() => Branch::Shutdown,
r = recv_fut => Branch::Recv(r.map(Box::new)),
c = cmd_fut => Branch::Cmd(c),
_ = idle_fut => Branch::Idle,
}
};
match branch {
Branch::Shutdown => {
debug!("shutdown signaled, stopping managed ws");
return;
}
Branch::Recv(Ok(msg)) => {
last_msg = Instant::now();
// A full channel drops this message instead of waiting for the consumer;
// a closed channel means nobody is listening, so the task exits.
match event_tx.try_send(WsEvent::Message(msg)) {
Ok(()) => {}
Err(e) if e.is_full() => {
warn!("event channel full, dropping message — consumer too slow");
}
Err(_) => {
debug!("event receiver dropped, stopping managed ws");
return;
}
}
}
// Every receive error leads to a reconnect; only the log line depends on the kind.
Branch::Recv(Err(e)) => {
match &e {
WSErrors::WsClosed { code, reason } => {
warn!(?code, %reason, "WebSocket disconnected, reconnecting");
}
WSErrors::WsStreamEnded => {
warn!("WebSocket stream ended, reconnecting");
}
_ => {
warn!(?e, "WebSocket error, reconnecting");
}
}
if do_reconnect(
&client,
&config,
&active_topics,
&mut event_tx,
&mut ws,
&mut shutdown_rx,
&mut state,
)
.await
{
return;
}
// Fresh connection: restart the idle clock.
last_msg = Instant::now();
}
Branch::Idle => {
let elapsed = last_msg.elapsed();
warn!(?elapsed, "no server messages within idle timeout, forcing reconnect");
if do_reconnect(
&client,
&config,
&active_topics,
&mut event_tx,
&mut ws,
&mut shutdown_rx,
&mut state,
)
.await
{
return;
}
last_msg = Instant::now();
}
// Topics are recorded as active before sending, so a failed send is still
// replayed on the next reconnect; only topics not already active go on the
// wire. NOTE(review): a failed send here does not itself trigger a reconnect.
Branch::Cmd(Some(WsCommand::Subscribe(params, id))) => {
let new_params: Vec<String> =
params.into_iter().filter(|p| active_topics.insert(p.clone())).collect();
if new_params.is_empty() {
debug!("subscribe: all topics already active, skipping wire send");
} else if let Err(e) =
ws.send(ClientMessage::Subscribe { id, params: new_params }).await
{
debug!(?e, "subscribe send failed, will replay after reconnect");
}
}
// Topics are removed before sending, so they are not replayed after a
// reconnect even if this send fails.
Branch::Cmd(Some(WsCommand::Unsubscribe(params, id))) => {
let to_send: Vec<String> =
params.into_iter().filter(|p| active_topics.remove(p)).collect();
if to_send.is_empty() {
debug!("unsubscribe: no matching active topics, skipping wire send");
} else if let Err(e) =
ws.send(ClientMessage::Unsubscribe { id, params: to_send }).await
{
debug!(?e, "unsubscribe send failed");
}
}
// A copy is kept so the message can be retried once on the new connection.
// NOTE(review): assumes a failed send never reached the server; confirm a
// retry cannot duplicate an order. A second failure is only logged.
Branch::Cmd(Some(WsCommand::Send(msg))) => {
if let Err(e) = ws.send(msg.clone()).await {
warn!(?e, "failed to send order message, reconnecting");
if do_reconnect(
&client,
&config,
&active_topics,
&mut event_tx,
&mut ws,
&mut shutdown_rx,
&mut state,
)
.await
{
return;
}
if let Err(e) = ws.send(msg).await {
warn!(?e, "retry after reconnect also failed");
}
last_msg = Instant::now();
}
}
// The command sender was dropped together with the handle.
Branch::Cmd(None) => {
debug!("command channel closed, stopping managed ws");
return;
}
}
}
}
/// Runs one reconnect cycle and swaps the new connection into `ws`.
///
/// Emits [`WsEvent::Reconnecting`] (best effort). If the previous connection
/// stayed up for at least `config.backoff_reset_after`, it resets the backoff.
/// It then delegates the retry loop to `reconnect`.
///
/// Returns `false` after a successful reconnect. Returns `true` when the
/// caller must stop, which happens when:
/// - the event receiver is gone;
/// - shutdown was signaled;
/// - reconnecting failed terminally. In this case a best-effort
///   [`WsEvent::Disconnected`] carrying the reason is emitted first.
async fn do_reconnect(
    client: &ManagedWsClient,
    config: &ManagedWsConfig,
    active_topics: &HashSet<String>,
    event_tx: &mut mpsc::Sender<WsEvent>,
    ws: &mut WebsocketHandle,
    shutdown_rx: &mut oneshot::Receiver<()>,
    state: &mut ReconnectState,
) -> bool {
    // A full channel only costs this notification; a closed one means nobody
    // is listening any more.
    if let Err(send_err) = event_tx.try_send(WsEvent::Reconnecting)
        && !send_err.is_full()
    {
        return true;
    }

    // The connection is no longer live; clear its start time while checking
    // how long it lasted.
    if let Some(since) = state.connected_since.take() {
        let uptime = since.elapsed();
        if uptime >= config.backoff_reset_after {
            debug!(?uptime, "previous connection was stable; resetting backoff");
            state.backoff = config.initial_backoff.max(MIN_BACKOFF);
        }
    }

    let failure = match reconnect(client, config, active_topics, event_tx, shutdown_rx, state).await
    {
        Ok(fresh_ws) => {
            *ws = fresh_ws;
            state.connected_since = Some(Instant::now());
            info!("reconnected successfully");
            return false;
        }
        Err(failure) => failure,
    };
    if !matches!(failure, ReconnectError::HandleDropped) {
        let _ = event_tx.try_send(WsEvent::Disconnected(failure.to_string()));
    }
    true
}
/// Retries connecting until it succeeds, gives up, or the managed handle goes away.
///
/// Each attempt waits `state.backoff` plus random jitter of up to half of it,
/// connects, and replays every topic in `active_topics` as one subscribe
/// message. After a failed attempt the backoff doubles, capped at
/// `config.max_backoff` (itself floored at [`MIN_BACKOFF`]). The backoff is
/// also capped before the first attempt, so an `initial_backoff` larger than
/// `max_backoff` cannot push the base delay past the ceiling. All `Duration`
/// arithmetic saturates instead of panicking on overflow.
///
/// # Errors
///
/// - [`ReconnectError::HandleDropped`] when shutdown is signaled (checked
///   before each attempt and while waiting or connecting). Also returned when
///   the event receiver has been dropped (checked before each attempt).
/// - [`ReconnectError::RetriesExhausted`] once `config.max_retries` attempts
///   have been made.
/// - [`ReconnectError::ReplayFailed`] when replaying subscriptions fails for a
///   reason other than the new connection closing.
async fn reconnect(
    client: &ManagedWsClient,
    config: &ManagedWsConfig,
    active_topics: &HashSet<String>,
    event_tx: &mpsc::Sender<WsEvent>,
    shutdown_rx: &mut oneshot::Receiver<()>,
    state: &mut ReconnectState,
) -> Result<WebsocketHandle, ReconnectError> {
    let max_backoff = config.max_backoff.max(MIN_BACKOFF);
    // Callers only floor `initial_backoff` at MIN_BACKOFF, so it can exceed
    // the ceiling; clamp before the first delay.
    state.backoff = state.backoff.min(max_backoff);
    let mut attempts = 0u32;
    loop {
        if shutdown_observed(shutdown_rx) || event_tx.is_closed() {
            return Err(ReconnectError::HandleDropped);
        }
        attempts = attempts.saturating_add(1);
        if let Some(max) = config.max_retries
            && attempts > max
        {
            return Err(ReconnectError::RetriesExhausted(max));
        }
        let delay = jittered_delay(state.backoff);
        info!(attempt = attempts, delay = ?delay, backoff = ?state.backoff, "attempting reconnect");
        match future::select(Delay::new(delay), &mut *shutdown_rx).await {
            Either::Left(_) => {}
            Either::Right(_) => return Err(ReconnectError::HandleDropped),
        }
        let connect_fut = client.connect(&config.ws_config);
        let connect_result = match future::select(Box::pin(connect_fut), &mut *shutdown_rx).await {
            Either::Left((r, _)) => r,
            Either::Right(_) => return Err(ReconnectError::HandleDropped),
        };
        match connect_result {
            Ok(mut ws) => {
                if !active_topics.is_empty() {
                    let params: Vec<String> = active_topics.iter().cloned().collect();
                    debug!(count = params.len(), "replaying subscriptions");
                    if let Err(e) = ws.send(ClientMessage::Subscribe { id: None, params }).await {
                        // Losing the fresh connection mid-replay counts as a
                        // failed attempt; anything else is terminal.
                        if matches!(&e, WSErrors::WsStreamEnded | WSErrors::WsClosed { .. }) {
                            warn!(?e, "replay send lost connection, retrying");
                            state.backoff = next_backoff(state.backoff, max_backoff);
                            continue;
                        }
                        return Err(ReconnectError::ReplayFailed(e));
                    }
                }
                return Ok(ws);
            }
            Err(e) => {
                warn!(?e, attempt = attempts, "reconnect failed");
                state.backoff = next_backoff(state.backoff, max_backoff);
            }
        }
    }
}

/// Doubles `current` (saturating rather than panicking) and caps it at `ceiling`.
fn next_backoff(current: Duration, ceiling: Duration) -> Duration {
    current.saturating_mul(2).min(ceiling)
}

/// Returns `base` plus a random jitter in `0..=base/2` at millisecond
/// granularity; the sum saturates instead of overflowing.
fn jittered_delay(base: Duration) -> Duration {
    let half_ms = u64::try_from(base.as_millis() / 2).unwrap_or(u64::MAX);
    let jitter_ms = rand::random::<u64>() % half_ms.saturating_add(1);
    base.saturating_add(Duration::from_millis(jitter_ms))
}
/// Non-blocking check of the shutdown channel.
///
/// Returns `true` once a value was sent or the sender was dropped (the
/// handle is gone), and `false` while the channel is still pending.
fn shutdown_observed(rx: &mut oneshot::Receiver<()>) -> bool {
    match rx.try_recv() {
        Ok(None) => false,
        Ok(Some(())) | Err(_) => true,
    }
}