use futures_util::{SinkExt, StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{interval, Duration};
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
use tracing::{debug, error, info, warn};
use crate::auth::{generate_ws_signature, get_timestamp};
use crate::config::WsConfig;
use crate::error::{BybitError, Result};
use crate::websocket::models::*;
/// Concrete stream type used by this client: a WebSocket over a TCP socket
/// that may or may not be TLS-wrapped. This is the type whose read half is
/// handed to the reader loop after `connect_async` succeeds.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Shared, thread-safe handler that receives each dispatched [`WsMessage`] by value.
type Callback = Arc<dyn Fn(WsMessage) + Send + Sync>;
/// Asynchronous WebSocket client.
///
/// Built with [`BybitWebSocket::public`] or [`BybitWebSocket::private`]; the
/// socket itself is opened by `connect`. Shared state lives behind
/// `Arc<RwLock<_>>` so the background tasks spawned by `connect` can use it.
pub struct BybitWebSocket {
/// Connection settings; the code reads `url`, `api_key`, `api_secret` and
/// `ping_interval` from it.
config: WsConfig,
/// Topics recorded by `subscribe` and pruned by `unsubscribe`.
subscriptions: Arc<RwLock<Vec<String>>>,
/// Topic name -> handler invoked for data messages on that topic.
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
/// Sending half of the outbound message queue; `None` until `connect`
/// has run and again after `disconnect`.
tx: Option<mpsc::Sender<Message>>,
/// Connection flag: set by `connect`, cleared by `disconnect` and by the
/// reader loop when it sees a Close frame or a stream error.
is_connected: Arc<RwLock<bool>>,
}
impl BybitWebSocket {
/// Creates a client for the unauthenticated stream at `url`.
///
/// The returned value starts with no subscriptions, no callbacks, no
/// outbound channel and the connection flag cleared; nothing is opened until
/// `connect` is called.
pub fn public(url: &str) -> Self {
    let config = WsConfig::public(url);
    let subscriptions = Arc::new(RwLock::new(Vec::new()));
    let callbacks = Arc::new(RwLock::new(HashMap::new()));
    let is_connected = Arc::new(RwLock::new(false));
    Self {
        config,
        subscriptions,
        callbacks,
        tx: None,
        is_connected,
    }
}
/// Creates a client for the authenticated stream at `url`.
///
/// The credentials are stored in the configuration; `connect` sends an auth
/// request whenever an API key is present. The returned value starts with no
/// subscriptions, no callbacks, no outbound channel and the connection flag
/// cleared.
pub fn private(api_key: &str, api_secret: &str, url: &str) -> Self {
    let config = WsConfig::private(api_key, api_secret).with_url(url);
    let subscriptions = Arc::new(RwLock::new(Vec::new()));
    let callbacks = Arc::new(RwLock::new(HashMap::new()));
    let is_connected = Arc::new(RwLock::new(false));
    Self {
        config,
        subscriptions,
        callbacks,
        tx: None,
        is_connected,
    }
}
pub async fn connect(&mut self) -> Result<()> {
let url = &self.config.url;
info!(url = %url, "Connecting to WebSocket");
let (ws_stream, _) = connect_async(url)
.await
.map_err(|e| BybitError::WebSocket(Box::new(e)))?;
let (write, read) = ws_stream.split();
let (tx, mut rx) = mpsc::channel::<Message>(100);
self.tx = Some(tx.clone());
*self.is_connected.write().await = true;
let write = Arc::new(tokio::sync::Mutex::new(write));
let write_clone = write.clone();
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
let mut w = write_clone.lock().await;
if let Err(e) = w.send(msg).await {
error!("Failed to send message: {}", e);
break;
}
}
});
if self.config.api_key.is_some() {
self.authenticate().await?;
}
let tx_ping = tx.clone();
let ping_interval = self.config.ping_interval;
tokio::spawn(async move {
let mut interval = interval(Duration::from_secs(ping_interval));
loop {
interval.tick().await;
let ping = WsPing::new();
let msg = serde_json::to_string(&ping).unwrap_or_default();
if tx_ping.send(Message::Text(msg)).await.is_err() {
break;
}
debug!("Ping sent");
}
});
let callbacks = self.callbacks.clone();
let is_connected = self.is_connected.clone();
let subscriptions = self.subscriptions.clone();
let config = self.config.clone();
let tx_reconnect = tx.clone();
tokio::spawn(async move {
Self::handle_messages(
read,
callbacks,
is_connected,
subscriptions,
config,
tx_reconnect,
)
.await;
});
info!("WebSocket connected");
Ok(())
}
/// Reader loop: consumes inbound frames until the stream ends, a Close frame
/// arrives, or the stream yields an error.
///
/// Text frames are parsed as JSON and classified in this order: pong, auth
/// response, subscription response, data message. Data messages are decoded
/// into [`WsMessage`] and handed to the callback registered for the exact
/// topic; if there is none, the first registered callback whose topic shares
/// the message topic's leading `.`-separated segment is used (map iteration
/// order decides between several candidates).
///
/// The connection flag is cleared whenever the loop exits, including the
/// case where the stream simply ends without a Close frame.
///
/// `_subscriptions`, `_config` and `_tx` are currently unused; `_tx` is a
/// named binding, so the sender clone stays alive until this function returns.
async fn handle_messages(
    mut read: futures_util::stream::SplitStream<WsStream>,
    callbacks: Arc<RwLock<HashMap<String, Callback>>>,
    is_connected: Arc<RwLock<bool>>,
    _subscriptions: Arc<RwLock<Vec<String>>>,
    _config: WsConfig,
    _tx: mpsc::Sender<Message>,
) {
    // Reads the boolean `success` field, treating a missing or non-boolean
    // value as failure.
    let succeeded = |json: &serde_json::Value| {
        json.get("success")
            .and_then(|v| v.as_bool())
            .unwrap_or(false)
    };

    while let Some(msg_result) = read.next().await {
        match msg_result {
            Ok(Message::Text(text)) => {
                let json: serde_json::Value = match serde_json::from_str(&text) {
                    Ok(v) => v,
                    Err(e) => {
                        // Log at most 200 bytes, backing up to a character
                        // boundary: slicing inside a multi-byte character
                        // would panic and kill the reader task.
                        let mut end = text.len().min(200);
                        while !text.is_char_boundary(end) {
                            end -= 1;
                        }
                        warn!("Failed to parse message: {}, text: {}", e, &text[..end]);
                        continue;
                    }
                };
                if is_pong(&json) {
                    debug!("Pong received");
                    continue;
                }
                if is_auth_response(&json) {
                    if succeeded(&json) {
                        info!("Authentication successful");
                    } else {
                        error!("Authentication failed: {:?}", json);
                    }
                    continue;
                }
                if is_subscription_response(&json) {
                    if succeeded(&json) {
                        debug!("Subscription successful");
                    } else {
                        warn!("Subscription failed: {:?}", json);
                    }
                    continue;
                }
                if is_data_message(&json) {
                    match serde_json::from_value::<WsMessage>(json) {
                        Ok(ws_msg) => {
                            // Pick the handler under the read lock, then
                            // release the lock before running user code so a
                            // slow callback cannot stall subscribe/unsubscribe.
                            let handler = {
                                let cbs = callbacks.read().await;
                                cbs.get(&ws_msg.topic).cloned().or_else(|| {
                                    cbs.iter()
                                        .find(|(topic, _)| {
                                            ws_msg.topic.starts_with(
                                                topic.split('.').next().unwrap_or(""),
                                            )
                                        })
                                        .map(|(_, cb)| Arc::clone(cb))
                                })
                            };
                            if let Some(handler) = handler {
                                handler(ws_msg);
                            }
                        }
                        Err(e) => {
                            warn!("Failed to decode data message: {}", e);
                        }
                    }
                }
            }
            Ok(Message::Ping(_)) => {
                debug!("Received ping frame");
            }
            Ok(Message::Close(_)) => {
                info!("WebSocket closed");
                break;
            }
            Err(e) => {
                error!("WebSocket error: {}", e);
                break;
            }
            _ => {}
        }
    }
    // Single exit point: covers Close, errors, and a stream that just ends.
    *is_connected.write().await = false;
}
/// Builds the `auth` request from the configured credentials and queues it
/// on the outbound channel.
///
/// The request arguments are, in order: the API key, an expiry value of
/// `get_timestamp() + 10000` (presumably milliseconds; confirm against
/// `get_timestamp`), and the signature computed over that expiry.
///
/// # Errors
///
/// * [`BybitError::Auth`] if the API key or the API secret is not set.
/// * [`BybitError::Parse`] if the request cannot be serialized.
/// * Whatever `send` returns if the message cannot be queued.
async fn authenticate(&self) -> Result<()> {
    let cfg = &self.config;
    let api_key = match cfg.api_key.as_ref() {
        Some(key) => key,
        None => return Err(BybitError::Auth("API key not set".into())),
    };
    let api_secret = match cfg.api_secret.as_ref() {
        Some(secret) => secret,
        None => return Err(BybitError::Auth("API secret not set".into())),
    };

    let expires = get_timestamp() + 10000;
    let signature = generate_ws_signature(api_secret, expires);

    let req_id = uuid::Uuid::new_v4().to_string();
    let args = vec![
        serde_json::Value::String(api_key.clone()),
        serde_json::Value::Number(expires.into()),
        serde_json::Value::String(signature),
    ];
    let request = WsAuthRequest {
        req_id,
        op: "auth".to_string(),
        args,
    };

    let payload =
        serde_json::to_string(&request).map_err(|e| BybitError::Parse(e.to_string()))?;
    self.send(payload).await?;
    info!("Authentication request sent");
    Ok(())
}
/// Registers `callback` for every topic in `topics`, records the topics, and
/// queues a `subscribe` request for them.
///
/// A topic that already has a callback gets it replaced. The local
/// bookkeeping is updated before the request is serialized and queued.
///
/// # Errors
///
/// [`BybitError::Parse`] if the request cannot be serialized, or whatever
/// `send` returns if it cannot be queued.
pub async fn subscribe<F>(&mut self, topics: Vec<String>, callback: F) -> Result<()>
where
    F: Fn(WsMessage) + Send + Sync + 'static,
{
    let shared: Callback = Arc::new(callback);

    // Each write guard is a temporary that is released at the end of its
    // statement, so the two locks are never held at the same time.
    self.callbacks
        .write()
        .await
        .extend(topics.iter().map(|topic| (topic.clone(), Arc::clone(&shared))));
    self.subscriptions
        .write()
        .await
        .extend(topics.iter().cloned());

    let request = WsRequest {
        req_id: uuid::Uuid::new_v4().to_string(),
        op: "subscribe".to_string(),
        args: topics,
    };
    let payload =
        serde_json::to_string(&request).map_err(|e| BybitError::Parse(e.to_string()))?;
    self.send(payload).await
}
/// Drops the callbacks and recorded subscriptions for `topics`, then queues
/// an `unsubscribe` request for them.
///
/// Topics that were never subscribed are ignored locally but still included
/// in the request.
///
/// # Errors
///
/// [`BybitError::Parse`] if the request cannot be serialized, or whatever
/// `send` returns if it cannot be queued.
pub async fn unsubscribe(&mut self, topics: Vec<String>) -> Result<()> {
    // Each write guard is a temporary released at the end of its statement.
    self.callbacks
        .write()
        .await
        .retain(|registered, _| !topics.contains(registered));
    self.subscriptions
        .write()
        .await
        .retain(|recorded| !topics.contains(recorded));

    let request = WsRequest {
        req_id: uuid::Uuid::new_v4().to_string(),
        op: "unsubscribe".to_string(),
        args: topics,
    };
    let payload =
        serde_json::to_string(&request).map_err(|e| BybitError::Parse(e.to_string()))?;
    self.send(payload).await
}
/// Queues `msg` as a text frame on the outbound channel.
///
/// # Errors
///
/// Returns [`BybitError::WebSocket`] wrapping `AlreadyClosed` when there is
/// no outbound channel (the client was never connected or has been
/// disconnected) or when the channel's receiving side is gone. Reporting the
/// missing channel as an error keeps callers from believing a request was
/// sent when it was silently dropped.
async fn send(&self, msg: String) -> Result<()> {
    let closed = || {
        BybitError::WebSocket(Box::new(
            tokio_tungstenite::tungstenite::Error::AlreadyClosed,
        ))
    };
    match &self.tx {
        Some(tx) => tx
            .send(Message::Text(msg))
            .await
            .map_err(|_| closed()),
        None => Err(closed()),
    }
}
/// Returns the current value of the shared connection flag.
pub async fn is_connected(&self) -> bool {
    let flag = self.is_connected.read().await;
    *flag
}
pub async fn disconnect(&mut self) -> Result<()> {
*self.is_connected.write().await = false;
self.tx = None;
info!("WebSocket disconnected");
Ok(())
}
}