use std::collections::HashSet;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use chrono::Utc;
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, SinkExt, stream::{SplitSink, SplitStream}};
use serde_json::{json, Value};
use tokio::sync::{broadcast, Mutex};
use std::sync::Mutex as StdMutex;
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
Credentials, AccountType,
ExchangeResult, ExchangeError,
ConnectionStatus, StreamEvent, StreamType, SubscriptionRequest,
};
use crate::core::types::{WebSocketResult, WebSocketError};
use crate::core::traits::WebSocketConnector;
use super::endpoints::DeribitUrls;
use super::auth::DeribitAuth;
use super::parser::DeribitParser;
/// Full-duplex WebSocket over a (possibly TLS-wrapped) TCP stream, as returned by `connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Write half of [`WsStream`] after `split()`; shared behind a mutex by the connector and its tasks.
type WsSink = SplitSink<WsStream, Message>;
/// Read half of [`WsStream`] after `split()`; owned by the message-loop task.
type WsReader = SplitStream<WsStream>;
/// Deribit JSON-RPC-over-WebSocket connector.
///
/// Mutable state is kept behind `Arc`s so it can be shared with the background
/// tasks spawned by `connect` (message loop, WebSocket ping task and
/// `public/test` heartbeat task).
pub struct DeribitWebSocket {
/// Built from the optional credentials; supplies the `public/auth` params. `None` = public-only.
auth: Option<DeribitAuth>,
/// Endpoint set chosen in `new` (`MAINNET` or `TESTNET`).
urls: DeribitUrls,
/// Stored for reference; not read elsewhere in this file.
_testnet: bool,
/// Stored for reference; not read elsewhere in this file.
_account_type: AccountType,
/// Current connection state; written by `connect`, `disconnect` and the message loop.
status: Arc<Mutex<ConnectionStatus>>,
/// Requests successfully sent through `subscribe` (removed by `unsubscribe`, cleared by `disconnect`).
subscriptions: Arc<Mutex<HashSet<SubscriptionRequest>>>,
/// Broadcast sender created per connection; `event_stream` subscribes to it. `None` when not connected.
event_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
/// Write half of the socket; `None` while disconnected.
ws_writer: Arc<Mutex<Option<WsSink>>>,
/// Next JSON-RPC request id (starts at 1); shared with the heartbeat task.
request_id: Arc<Mutex<u64>>,
/// `access_token` taken from the `public/auth` reply by the message loop.
access_token: Arc<Mutex<Option<String>>>,
/// When the most recent WebSocket ping was sent; used to compute the RTT on pong.
last_ping: Arc<Mutex<Instant>>,
/// Last measured ping → pong round trip in milliseconds (exposed via `ping_rtt_handle`).
ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
impl DeribitWebSocket {
/// Creates a connector that is not yet connected.
///
/// * `credentials` – when present, turned into a [`DeribitAuth`] so that
///   `connect` can send `public/auth`; construction errors are propagated.
/// * `testnet` – selects `DeribitUrls::TESTNET` instead of `DeribitUrls::MAINNET`.
/// * `account_type` – stored only.
///
/// Shared state starts empty: status `Disconnected`, no writer, no event
/// channel, no token, request ids starting at 1, ping RTT 0.
pub async fn new(
    credentials: Option<Credentials>,
    testnet: bool,
    account_type: AccountType,
) -> ExchangeResult<Self> {
    let auth = match credentials.as_ref() {
        Some(creds) => Some(DeribitAuth::new(creds)?),
        None => None,
    };

    Ok(Self {
        auth,
        urls: if testnet { DeribitUrls::TESTNET } else { DeribitUrls::MAINNET },
        _testnet: testnet,
        _account_type: account_type,
        status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
        subscriptions: Arc::new(Mutex::new(HashSet::new())),
        event_tx: Arc::new(StdMutex::new(None)),
        ws_writer: Arc::new(Mutex::new(None)),
        request_id: Arc::new(Mutex::new(1)),
        access_token: Arc::new(Mutex::new(None)),
        last_ping: Arc::new(Mutex::new(Instant::now())),
        ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
    })
}
/// Hands out the next JSON-RPC request id, post-incrementing the shared counter.
async fn next_id(&self) -> u64 {
    let mut counter = self.request_id.lock().await;
    let issued = *counter;
    *counter = issued + 1;
    issued
}
/// Wraps `method` and `params` in a JSON-RPC 2.0 request envelope carrying `id`.
fn build_request(&self, id: u64, method: &str, params: Value) -> Value {
    json!({ "jsonrpc": "2.0", "id": id, "method": method, "params": params })
}
/// Builds a JSON-RPC request for `method`/`params` and writes it as a text frame.
///
/// Returns the id assigned to the request. An id is consumed even if sending fails.
///
/// # Errors
/// * `ExchangeError::Network` when no writer is installed or the send fails.
/// * `ExchangeError::Parse` if the request cannot be serialized.
async fn send_request(&self, method: &str, params: Value) -> ExchangeResult<u64> {
    let id = self.next_id().await;
    let payload = self.build_request(id, method, params);

    let mut sink_slot = self.ws_writer.lock().await;
    let sink = match sink_slot.as_mut() {
        Some(sink) => sink,
        None => return Err(ExchangeError::Network("WebSocket not connected".to_string())),
    };

    let text = match serde_json::to_string(&payload) {
        Ok(text) => text,
        Err(e) => {
            return Err(ExchangeError::Parse(format!("Failed to serialize request: {}", e)))
        }
    };

    if let Err(e) = sink.send(Message::Text(text)).await {
        return Err(ExchangeError::Network(format!("Failed to send message: {}", e)));
    }
    Ok(id)
}
/// Sends a `public/auth` request built from the configured client credentials.
///
/// Only the request is sent here. Any `access_token` in the reply is stored
/// later by the message loop.
///
/// # Errors
/// * `ExchangeError::Auth` if the connector was created without credentials.
/// * `ExchangeError::Parse` if the auth params cannot be serialized.
/// * Any error from `send_request`.
async fn authenticate(&self) -> ExchangeResult<()> {
    let auth = match self.auth.as_ref() {
        Some(auth) => auth,
        None => return Err(ExchangeError::Auth("No credentials provided".to_string())),
    };

    let params_json = match serde_json::to_value(auth.client_credentials_params()) {
        Ok(value) => value,
        Err(e) => {
            return Err(ExchangeError::Parse(format!(
                "Failed to serialize auth params: {}",
                e
            )))
        }
    };

    self.send_request("public/auth", params_json)
        .await
        .map(|_request_id| ())
}
/// Sends a `public/subscribe` (or `private/subscribe` when `is_private`) request
/// for `channels`.
///
/// Only the request is written. The server's acknowledgement or error arrives on
/// the read side, so `Ok(())` means "sent", not "accepted".
async fn subscribe_channels(&self, channels: Vec<String>, is_private: bool) -> ExchangeResult<()> {
    self.send_channel_request("subscribe", channels, is_private).await
}

/// Sends a `public/unsubscribe` (or `private/unsubscribe` when `is_private`)
/// request for `channels`.
///
/// As with [`Self::subscribe_channels`], `Ok(())` means the request was sent.
async fn unsubscribe_channels(&self, channels: Vec<String>, is_private: bool) -> ExchangeResult<()> {
    self.send_channel_request("unsubscribe", channels, is_private).await
}

/// Shared body of the (un)subscribe helpers.
///
/// Sends `<scope>/<action>` with a `{"channels": [...]}` payload, where `scope`
/// is `private` or `public`.
async fn send_channel_request(
    &self,
    action: &str,
    channels: Vec<String>,
    is_private: bool,
) -> ExchangeResult<()> {
    let scope = if is_private { "private" } else { "public" };
    let method = format!("{}/{}", scope, action);
    self.send_request(&method, json!({ "channels": channels })).await?;
    Ok(())
}
/// Maps a [`SubscriptionRequest`] to the Deribit channel name(s) for it.
///
/// Private streams are account-wide and ignore the symbol:
/// * `OrderUpdate`    → `user.orders.any.any.raw`
/// * `BalanceUpdate`  → several `user.portfolio.<CCY>` channels joined by `,`
///   (callers split on `,` before sending)
/// * `PositionUpdate` → `user.changes.any.any.raw`
///
/// Market streams target the `<BASE>-PERPETUAL` instrument (base upper-cased):
/// ticker/trades/book use the `100ms` interval, and `Kline` uses
/// `chart.trades.<instrument>.<interval>`.
///
/// Returns an empty string for unsupported stream types, and also for market
/// streams whose symbol has an empty base. Those previously produced malformed
/// names such as `ticker..100ms`, which were sent to the server and recorded as
/// active subscriptions.
///
/// NOTE(review): `interval` is formatted via `Display`. Deribit chart channels
/// expect resolutions such as `1`, `60` or `1D`; confirm the formatting matches.
/// NOTE(review): only `<BASE>-PERPETUAL` instruments are addressed here.
fn build_channel_name(&self, request: &SubscriptionRequest) -> String {
    let stream_type = &request.stream_type;

    // Account-wide channels: no instrument needed.
    match stream_type {
        StreamType::OrderUpdate => return "user.orders.any.any.raw".to_string(),
        StreamType::BalanceUpdate => {
            return "user.portfolio.BTC,user.portfolio.ETH,user.portfolio.USDC,user.portfolio.USDT,user.portfolio.SOL".to_string();
        }
        StreamType::PositionUpdate => return "user.changes.any.any.raw".to_string(),
        _ => {}
    }

    // Market channels require an instrument; an empty base cannot name one.
    let base = &request.symbol.base;
    if base.is_empty() {
        return String::new();
    }
    let instrument = format!("{}-PERPETUAL", base.to_uppercase());

    match stream_type {
        StreamType::Ticker => format!("ticker.{}.100ms", instrument),
        StreamType::Trade => format!("trades.{}.100ms", instrument),
        StreamType::Orderbook | StreamType::OrderbookDelta => {
            format!("book.{}.100ms", instrument)
        }
        StreamType::Kline { interval } => format!("chart.trades.{}.{}", instrument, interval),
        _ => String::new(),
    }
}
/// Whether the request is an account stream (orders, balances, positions).
///
/// Account streams must be sent through the `private/*` (un)subscribe methods.
fn is_private_subscription(&self, request: &SubscriptionRequest) -> bool {
    let stream_type = &request.stream_type;
    matches!(
        stream_type,
        StreamType::OrderUpdate | StreamType::BalanceUpdate | StreamType::PositionUpdate
    )
}
fn start_message_loop(
mut reader: WsReader,
ws_writer: Arc<Mutex<Option<WsSink>>>,
event_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
status: Arc<Mutex<ConnectionStatus>>,
access_token: Arc<Mutex<Option<String>>>,
last_ping: Arc<Mutex<Instant>>,
ws_ping_rtt_ms: Arc<Mutex<u64>>,
) {
tokio::spawn(async move {
while let Some(msg) = reader.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
if let Some(result) = parsed.get("result") {
if let Some(token) = result.get("access_token") {
if let Some(token_str) = token.as_str() {
let mut token_guard = access_token.lock().await;
*token_guard = Some(token_str.to_string());
}
}
}
if let Some(method) = parsed.get("method") {
if method == "heartbeat" {
let is_test_request = parsed
.get("params")
.and_then(|p| p.get("type"))
.and_then(|t| t.as_str())
== Some("test_request");
if is_test_request {
let original_id = parsed
.get("id")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let response = json!({
"jsonrpc": "2.0",
"id": original_id,
"method": "public/test"
});
if let Ok(response_text) = serde_json::to_string(&response) {
let mut writer_guard = ws_writer.lock().await;
if let Some(ref mut writer) = *writer_guard {
let _ = writer.send(Message::Text(response_text)).await;
}
}
}
} else if method == "subscription" {
if let Some(event) = Self::parse_event(&parsed) {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Ok(event));
}
}
}
}
}
}
Ok(Message::Pong(_)) => {
let rtt = last_ping.lock().await.elapsed().as_millis() as u64;
*ws_ping_rtt_ms.lock().await = rtt;
}
Ok(Message::Close(_)) => {
*status.lock().await = ConnectionStatus::Disconnected;
break;
}
Err(e) => {
let tx_guard = event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let _ = tx.send(Err(WebSocketError::ConnectionError(format!("WebSocket error: {}", e))));
}
break;
}
_ => {}
}
}
let _ = event_tx.lock().unwrap().take();
*status.lock().await = ConnectionStatus::Disconnected;
});
}
fn parse_event(msg: &Value) -> Option<StreamEvent> {
let params = msg.get("params")?;
let channel = params.get("channel")?.as_str()?;
let data = params.get("data")?;
if channel.starts_with("ticker.") {
DeribitParser::parse_ws_ticker(data).ok().map(StreamEvent::Ticker)
} else if channel.starts_with("book.") {
DeribitParser::parse_ws_orderbook(data).ok()
} else if channel.starts_with("trades.") {
DeribitParser::parse_ws_trade(data).ok().map(StreamEvent::Trade)
} else if channel.starts_with("user.orders.") {
DeribitParser::parse_ws_order_update(data).ok().map(StreamEvent::OrderUpdate)
} else if channel.starts_with("user.portfolio.") {
let currency = channel.strip_prefix("user.portfolio.").unwrap_or("");
let balance_val = |key: &str| -> f64 {
data.get(key)
.and_then(|v| v.as_f64())
.unwrap_or(0.0)
};
let total = balance_val("equity");
let available = balance_val("available_funds");
let event = crate::core::BalanceUpdateEvent {
asset: currency.to_string(),
free: available,
locked: (total - available).max(0.0),
total,
delta: None,
reason: Some(crate::core::BalanceChangeReason::Other),
timestamp: Utc::now().timestamp_millis(),
};
Some(StreamEvent::BalanceUpdate(event))
} else {
None
}
}
fn start_ws_ping_task(
ws_writer: Arc<Mutex<Option<WsSink>>>,
last_ping: Arc<Mutex<Instant>>,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
interval.tick().await;
loop {
interval.tick().await;
let mut writer_guard = ws_writer.lock().await;
if let Some(ref mut writer) = *writer_guard {
*last_ping.lock().await = Instant::now();
if writer.send(Message::Ping(vec![])).await.is_err() {
break;
}
} else {
break;
}
}
});
}
fn start_heartbeat_task(
ws_writer: Arc<Mutex<Option<WsSink>>>,
request_id: Arc<Mutex<u64>>,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
interval.tick().await;
loop {
interval.tick().await;
let id = {
let mut id_guard = request_id.lock().await;
let current = *id_guard;
*id_guard += 1;
current
};
let test_msg = json!({
"jsonrpc": "2.0",
"id": id,
"method": "public/test"
});
if let Ok(msg_text) = serde_json::to_string(&test_msg) {
let mut writer_guard = ws_writer.lock().await;
if let Some(ref mut writer) = *writer_guard {
if writer.send(Message::Text(msg_text)).await.is_err() {
break;
}
} else {
break;
}
}
}
});
}
}
#[async_trait]
impl WebSocketConnector for DeribitWebSocket {
/// Opens the WebSocket, authenticates if credentials were supplied, and starts
/// the background tasks.
///
/// The tasks are the message loop, a WebSocket ping every 5 s and a
/// `public/test` heartbeat every 30 s. `_account_type` is ignored; the endpoint
/// was fixed by the `testnet` flag passed to `new`.
///
/// # Errors
/// * `WebSocketError::ConnectionError` if the WebSocket handshake fails.
/// * `WebSocketError::Auth` if the `public/auth` request cannot be sent.
///
/// On either error the status is reset to `Disconnected`; previously it stayed
/// stuck at `Connecting`. After an auth failure, the already-opened write half
/// is also closed and the new event channel dropped, so no half-open connection
/// lingers.
///
/// Only *sending* `public/auth` is checked here. The server's reply is consumed
/// by the message loop, so `Ok(())` does not prove the credentials were accepted.
async fn connect(&mut self, _account_type: AccountType) -> WebSocketResult<()> {
    *self.status.lock().await = ConnectionStatus::Connecting;

    let ws_url = self.urls.ws_url();
    let handshake = connect_async(ws_url)
        .await
        .map_err(|e| format!("Failed to connect: {}", e));
    let (ws_stream, _) = match handshake {
        Ok(pair) => pair,
        Err(reason) => {
            *self.status.lock().await = ConnectionStatus::Disconnected;
            return Err(WebSocketError::ConnectionError(reason));
        }
    };

    let (write, read) = ws_stream.split();
    *self.ws_writer.lock().await = Some(write);

    let (tx, _) = broadcast::channel(1000);
    *self.event_tx.lock().unwrap() = Some(tx);

    if self.auth.is_some() {
        // Convert the error to a String up front so nothing non-trivial is held across awaits.
        let auth_outcome = self
            .authenticate()
            .await
            .map_err(|e| format!("Authentication failed: {}", e));
        if let Err(reason) = auth_outcome {
            // Roll back so callers never observe a half-open connection.
            let writer = self.ws_writer.lock().await.take();
            if let Some(mut writer) = writer {
                let _ = writer.close().await;
            }
            *self.event_tx.lock().unwrap() = None;
            *self.status.lock().await = ConnectionStatus::Disconnected;
            return Err(WebSocketError::Auth(reason));
        }
    }

    Self::start_message_loop(
        read,
        self.ws_writer.clone(),
        self.event_tx.clone(),
        self.status.clone(),
        self.access_token.clone(),
        self.last_ping.clone(),
        self.ws_ping_rtt_ms.clone(),
    );
    Self::start_ws_ping_task(
        self.ws_writer.clone(),
        self.last_ping.clone(),
    );
    Self::start_heartbeat_task(
        self.ws_writer.clone(),
        self.request_id.clone(),
    );

    *self.status.lock().await = ConnectionStatus::Connected;
    Ok(())
}
/// Closes and removes the write half, marks the connector `Disconnected`, and
/// forgets every recorded subscription.
///
/// Close errors are ignored and this always returns `Ok(())`. The event channel
/// is not touched here; it is cleared by the message loop when the read side ends.
async fn disconnect(&mut self) -> WebSocketResult<()> {
    {
        let mut slot = self.ws_writer.lock().await;
        if let Some(mut sink) = slot.take() {
            let _ = sink.close().await;
        }
    }

    *self.status.lock().await = ConnectionStatus::Disconnected;
    self.subscriptions.lock().await.clear();
    Ok(())
}
/// Non-blocking snapshot of the connection status.
///
/// If the status lock is momentarily held elsewhere, this reports `Disconnected`.
fn connection_status(&self) -> ConnectionStatus {
    self.status
        .try_lock()
        .map(|guard| *guard)
        .unwrap_or(ConnectionStatus::Disconnected)
}
/// Subscribes to the Deribit channel(s) derived from `request` and records it.
///
/// # Errors
/// * `WebSocketError::NotConnected` unless the status reads `Connected`.
/// * `WebSocketError::Subscription` if the request maps to no channel, or if
///   the (un)subscribe request could not be sent.
///
/// The request is recorded once it has been sent; the server's acknowledgement
/// is not awaited.
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
    let connected = self.connection_status() == ConnectionStatus::Connected;
    if !connected {
        return Err(WebSocketError::NotConnected);
    }

    let joined = self.build_channel_name(&request);
    if joined.is_empty() {
        return Err(WebSocketError::Subscription("Unsupported stream type".to_string()));
    }

    // Some streams (balances) map to several comma-separated channels.
    let channels = joined
        .split(',')
        .map(str::trim)
        .map(String::from)
        .collect::<Vec<String>>();
    let is_private = self.is_private_subscription(&request);

    self.subscribe_channels(channels, is_private)
        .await
        .map_err(|e| WebSocketError::Subscription(format!("Subscribe failed: {}", e)))?;

    self.subscriptions.lock().await.insert(request);
    Ok(())
}
/// Unsubscribes from the channel(s) derived from `request` and drops it from the
/// recorded set.
///
/// A request that maps to no channel is a no-op (`Ok(())`), because `subscribe`
/// rejects such requests. Send failures become `WebSocketError::Subscription`.
async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
    let joined = self.build_channel_name(&request);
    if joined.is_empty() {
        return Ok(());
    }

    let channels = joined
        .split(',')
        .map(str::trim)
        .map(String::from)
        .collect::<Vec<String>>();
    let is_private = self.is_private_subscription(&request);

    self.unsubscribe_channels(channels, is_private)
        .await
        .map_err(|e| WebSocketError::Subscription(format!("Unsubscribe failed: {}", e)))?;

    self.subscriptions.lock().await.remove(&request);
    Ok(())
}
fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
let tx_guard = self.event_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let rx = tx.subscribe();
Box::pin(tokio_stream::wrappers::BroadcastStream::new(rx).map(|r| {
r.map_err(|e| WebSocketError::ConnectionError(format!("Broadcast error: {}", e)))
.and_then(|x| x)
}))
} else {
Box::pin(futures_util::stream::empty())
}
}
/// Non-blocking snapshot of the recorded subscriptions.
///
/// Returns an empty list if the lock is momentarily held elsewhere.
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
    self.subscriptions
        .try_lock()
        .map(|set| set.iter().cloned().collect::<Vec<_>>())
        .unwrap_or_default()
}
/// Shared handle to the last measured WebSocket ping round trip, in milliseconds.
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
    let handle = Arc::clone(&self.ws_ping_rtt_ms);
    Some(handle)
}
}