use crate::rest::errors::{BybitError, BybitResult};
use crate::ws::messages::{WsMessage, WsRequest};
use futures_util::stream::SplitSink;
use futures_util::{SinkExt, Stream, StreamExt};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, sleep};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
/// Concrete stream type produced by `connect_async`: a WebSocket over a TCP connection that
/// may or may not be wrapped in TLS. The connection code splits it into a read half and a
/// `SplitSink` write half.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Connection attempts allowed since the last successful connect before the background
/// task gives up (the counter resets to zero whenever a connection is established).
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
/// Base reconnect backoff in milliseconds; it is multiplied by `2^attempt` (exponent capped at 6).
const RECONNECT_BASE_DELAY_MS: u64 = 500;
/// Ceiling for the reconnect backoff delay, in milliseconds (30 seconds).
const RECONNECT_MAX_DELAY_MS: u64 = 30 * 1_000;
/// Interval, in seconds, between JSON ping requests sent on an open connection.
const PING_INTERVAL_SECS: u64 = 20;
/// Credentials from the most recent authenticate command, kept so the connection loop can
/// re-send them after a reconnect.
///
/// Note: there is intentionally no `Debug` derive here; keep it that way so the API key and
/// signature cannot leak into logs via `{:?}`.
#[derive(Clone)]
struct AuthParams {
    /// API key supplied by the caller.
    api_key: String,
    /// Expiry value the signature was computed for (presumably a millisecond epoch
    /// timestamp — TODO confirm against the signing code).
    expires: u64,
    /// Signature supplied by the caller; it is replayed unchanged on every reconnect.
    signature: String,
}
/// Handle to a WebSocket connection that is driven by a background Tokio task.
///
/// Subscribe/unsubscribe/authenticate requests are forwarded to that task over a channel.
/// Parsed server messages are consumed by polling this value as a [`Stream`]. On reconnect
/// the task re-sends the stored credentials and the recorded topic list.
pub struct WsClient {
    /// Channel used to send commands to the background task.
    command_tx: mpsc::UnboundedSender<Command>,
    /// Channel on which the background task delivers parsed server messages.
    message_rx: mpsc::UnboundedReceiver<WsMessage>,
    /// Handle of the background task; `None` once the client has been closed.
    _handle: Option<tokio::task::JoinHandle<()>>,
    /// URL this client was created with.
    url: String,
    /// Topics the caller is subscribed to, shared with the background task so it can
    /// re-subscribe after a reconnect.
    subscribed_topics: Arc<Mutex<Vec<String>>>,
}
/// Request sent from a `WsClient` to its background connection task.
enum Command {
    /// Subscribe to these topics on the live connection.
    Subscribe(Vec<String>),
    /// Unsubscribe from these topics on the live connection.
    Unsubscribe(Vec<String>),
    /// Send an auth request. The task also stores these credentials and replays them after
    /// each reconnect.
    Authenticate {
        /// API key supplied by the caller.
        api_key: String,
        /// Expiry value the signature was computed for.
        expires: u64,
        /// Signature supplied by the caller.
        signature: String,
    },
}
impl WsClient {
    /// Create a client for `url` and spawn its background connection task.
    ///
    /// This returns immediately: the WebSocket handshake happens inside the spawned task, so
    /// an unreachable host is not reported here. Connection failures are logged and retried
    /// with backoff. Once the task gives up, the message stream ends.
    ///
    /// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
    ///
    /// # Errors
    /// Currently never returns `Err`. The `BybitResult` return type is part of the public API.
    pub async fn connect(url: &str) -> BybitResult<Self> {
        let (command_tx, command_rx) = mpsc::unbounded_channel();
        let (message_tx, message_rx) = mpsc::unbounded_channel();
        let subscribed_topics = Arc::new(Mutex::new(Vec::new()));
        let topics = Arc::clone(&subscribed_topics);
        let url_owned = url.to_string();
        let handle = tokio::spawn(async move {
            run_connection_loop(&url_owned, command_rx, message_tx, topics).await;
        });
        Ok(WsClient {
            command_tx,
            message_rx,
            _handle: Some(handle),
            url: url.to_string(),
            subscribed_topics,
        })
    }

    /// Subscribe to `topics`.
    ///
    /// The topics are first recorded, skipping duplicates, so the background task can
    /// re-subscribe to them after a reconnect. A subscribe request is then queued for the
    /// live connection.
    ///
    /// # Errors
    /// Returns `BybitError::Internal` if the background task is no longer receiving commands
    /// (for example after [`WsClient::close`], or once the task has stopped).
    pub async fn subscribe(&self, topics: Vec<String>) -> BybitResult<()> {
        {
            let mut stored = self.subscribed_topics.lock().await;
            for t in &topics {
                if !stored.contains(t) {
                    stored.push(t.clone());
                }
            }
        }
        self.command_tx
            .send(Command::Subscribe(topics))
            .map_err(|e| BybitError::Internal(format!("Subscribe channel closed: {}", e)))?;
        Ok(())
    }

    /// Unsubscribe from `topics`.
    ///
    /// The topics are removed from the recorded list so they are not re-subscribed after a
    /// reconnect, and an unsubscribe request is queued for the live connection.
    ///
    /// # Errors
    /// Returns `BybitError::Internal` if the background task is no longer receiving commands.
    pub async fn unsubscribe(&self, topics: Vec<String>) -> BybitResult<()> {
        {
            let mut stored = self.subscribed_topics.lock().await;
            stored.retain(|t| !topics.contains(t));
        }
        self.command_tx
            .send(Command::Unsubscribe(topics))
            .map_err(|e| BybitError::Internal(format!("Unsubscribe channel closed: {}", e)))?;
        Ok(())
    }

    /// Queue an auth request with pre-computed credentials.
    ///
    /// The background task remembers the credentials and replays them after every reconnect.
    /// This call does not wait for the server's reply. Any reply presumably arrives on the
    /// message stream, if it parses as a `WsMessage`.
    ///
    /// # Errors
    /// Returns `BybitError::Internal` if the background task is no longer receiving commands.
    pub async fn authenticate(
        &self,
        api_key: &str,
        expires: u64,
        signature: &str,
    ) -> BybitResult<()> {
        self.command_tx
            .send(Command::Authenticate {
                api_key: api_key.to_string(),
                expires,
                signature: signature.to_string(),
            })
            .map_err(|e| BybitError::Internal(format!("Auth channel closed: {}", e)))?;
        Ok(())
    }

    /// URL this client was created with.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Stop the background connection task.
    ///
    /// Afterwards `subscribe`, `unsubscribe` and `authenticate` return an error. The message
    /// stream ends once any already-buffered messages have been drained.
    pub fn close(&mut self) {
        // Swap in a sender whose receiver is already dropped, so later commands fail fast
        // instead of queueing for a task that no longer exists.
        self.command_tx = mpsc::unbounded_channel().0;
        // Dropping a `JoinHandle` only detaches its task. Without an explicit abort, the
        // connection loop would keep running and reconnecting in the background.
        if let Some(handle) = self._handle.take() {
            handle.abort();
        }
    }
}
impl Drop for WsClient {
    /// Abort the background connection task so it does not outlive the client.
    fn drop(&mut self) {
        // Dropping a `JoinHandle` merely detaches the task. Without this abort, the
        // connection loop would keep running and reconnecting with nobody left to read
        // its messages.
        if let Some(handle) = self._handle.take() {
            handle.abort();
        }
    }
}
/// Yields parsed server messages in arrival order. The stream ends (`None`) once the
/// background task has dropped its sender and every buffered message has been consumed.
impl Stream for WsClient {
    type Item = WsMessage;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Every field of `WsClient` is `Unpin`, so the pin can be unwrapped to a plain `&mut`.
        let client = self.get_mut();
        client.message_rx.poll_recv(cx)
    }
}
async fn run_connection_loop(
url: &str,
mut command_rx: mpsc::UnboundedReceiver<Command>,
message_tx: mpsc::UnboundedSender<WsMessage>,
subscribed_topics: Arc<Mutex<Vec<String>>>,
) {
let mut auth_params: Option<AuthParams> = None;
let mut attempt = 0;
loop {
if attempt > 0 {
let delay_ms =
(RECONNECT_BASE_DELAY_MS * 2_u64.pow(attempt.min(6))).min(RECONNECT_MAX_DELAY_MS);
log::warn!(
"Reconnecting in {}ms (attempt {}/{})...",
delay_ms,
attempt,
MAX_RECONNECT_ATTEMPTS
);
sleep(Duration::from_millis(delay_ms)).await;
}
if attempt >= MAX_RECONNECT_ATTEMPTS {
log::error!("Max reconnect attempts reached. Giving up.");
break;
}
match connect_async(url).await {
Ok((ws_stream, _)) => {
log::info!("WebSocket connected to {}", url);
attempt = 0;
let (ws_write, ws_read) = ws_stream.split();
let ws_write = Arc::new(Mutex::new(ws_write));
if let Some(ref auth) = auth_params {
let req = WsRequest::auth(&auth.api_key, auth.expires, &auth.signature);
send_command(&ws_write, &req).await;
}
{
let topics = subscribed_topics.lock().await;
if !topics.is_empty() {
let req = WsRequest::subscribe(topics.clone());
send_command(&ws_write, &req).await;
}
}
run_connection(
ws_read,
ws_write,
&mut command_rx,
&message_tx,
&mut auth_params,
)
.await;
}
Err(e) => {
log::error!("Connection failed: {}", e);
}
}
attempt += 1;
}
}
async fn send_command(writer: &Arc<Mutex<SplitSink<WsStream, Message>>>, req: &WsRequest) {
if let Ok(json) = serde_json::to_string(req) {
if let Ok(mut w) = writer.try_lock() {
let _ = w.send(Message::Text(json.into())).await;
}
}
}
async fn run_connection(
mut ws_read: futures_util::stream::SplitStream<WsStream>,
ws_write: Arc<Mutex<SplitSink<WsStream, Message>>>,
command_rx: &mut mpsc::UnboundedReceiver<Command>,
message_tx: &mpsc::UnboundedSender<WsMessage>,
auth_params: &mut Option<AuthParams>,
) {
let mut ping_interval = interval(Duration::from_secs(PING_INTERVAL_SECS));
loop {
tokio::select! {
cmd = command_rx.recv() => {
match cmd {
Some(Command::Subscribe(topics)) => {
let req = WsRequest::subscribe(topics);
send_command(&ws_write, &req).await;
}
Some(Command::Unsubscribe(topics)) => {
let req = WsRequest::unsubscribe(topics);
send_command(&ws_write, &req).await;
}
Some(Command::Authenticate { api_key, expires, signature }) => {
*auth_params = Some(AuthParams {
api_key: api_key.clone(),
expires,
signature: signature.clone(),
});
let req = WsRequest::auth(&api_key, expires, &signature);
send_command(&ws_write, &req).await;
}
None => {
break;
}
}
}
_ = ping_interval.tick() => {
let ping = WsRequest::ping();
send_command(&ws_write, &ping).await;
}
msg = ws_read.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<WsMessage>(&text) {
Ok(parsed) => {
if message_tx.send(parsed).is_err() {
break; }
}
Err(e) => {
log::warn!("Failed to parse WS message: {} -- raw: {}", e, text);
}
}
}
Some(Ok(Message::Ping(data))) => {
if let Ok(mut writer) = ws_write.try_lock() {
let _ = writer.send(Message::Pong(data)).await;
}
}
Some(Ok(Message::Close(frame))) => {
log::info!(
"WebSocket closed by server: {:?}",
frame.map(|f| f.reason.to_string())
);
break;
}
Some(Err(e)) => {
log::error!("WebSocket error: {}", e);
break;
}
None => {
log::info!("WebSocket stream ended");
break;
}
_ => {} }
}
}
}
log::info!("WebSocket connection handler exited");
}