use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex as StdMutex, OnceLock};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, SinkExt};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tokio::sync::{broadcast, Mutex};
use tokio::time::sleep;
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
Credentials, AccountType, ExchangeResult,
ConnectionStatus, StreamEvent, StreamType, SubscriptionRequest,
};
use crate::core::types::{WebSocketResult, WebSocketError, OrderbookCapabilities};
use crate::core::traits::WebSocketConnector;
use crate::core::utils::SimpleRateLimiter;
use super::auth::UpbitAuth;
use super::endpoints::UpbitUrls;
use super::parser::UpbitParser;
/// Process-wide limiter for opening WebSocket connections, shared by every
/// `UpbitWebSocket` instance in this process.
///
/// Built with `SimpleRateLimiter::new(5, Duration::from_secs(10))` — presumably
/// five connection attempts per ten-second window; confirm against
/// `SimpleRateLimiter`.
static GLOBAL_WS_LIMITER: OnceLock<Arc<StdMutex<SimpleRateLimiter>>> = OnceLock::new();

/// Returns a shared handle to [`GLOBAL_WS_LIMITER`], creating it on first use.
fn get_global_ws_limiter() -> Arc<StdMutex<SimpleRateLimiter>> {
    let shared = GLOBAL_WS_LIMITER.get_or_init(|| {
        let limiter = SimpleRateLimiter::new(5, Duration::from_secs(10));
        Arc::new(StdMutex::new(limiter))
    });
    Arc::clone(shared)
}
/// Process-wide limiter for outgoing subscription requests, shared by every
/// `UpbitWebSocket` instance in this process.
///
/// Built with `SimpleRateLimiter::new(3, Duration::from_secs(1))` — presumably
/// three requests per second; confirm against `SimpleRateLimiter`.
static SUBSCRIPTION_LIMITER: OnceLock<Arc<StdMutex<SimpleRateLimiter>>> = OnceLock::new();

/// Returns a shared handle to [`SUBSCRIPTION_LIMITER`], creating it on first use.
fn get_subscription_limiter() -> Arc<StdMutex<SimpleRateLimiter>> {
    let shared = SUBSCRIPTION_LIMITER.get_or_init(|| {
        let limiter = SimpleRateLimiter::new(3, Duration::from_secs(1));
        Arc::new(StdMutex::new(limiter))
    });
    Arc::clone(shared)
}
/// Typed shape of one subscription request object (`ticket`, `type`, `codes`).
///
/// Not referenced anywhere in this file: `send_subscription` builds its payload
/// with `json!` instead. `dead_code` is allowed so that keeping this typed
/// description around does not produce a warning.
#[derive(Clone, Debug, Serialize)]
#[allow(dead_code)]
struct SubscriptionMessage {
    /// Request ticket; presumably identifies the client to Upbit.
    ticket: String,
    /// Data type to subscribe to, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    msg_type: String,
    /// Market codes the request covers.
    codes: Vec<String>,
}
/// Loose typed view of an incoming frame: optional `type`, `code` and `status`.
///
/// Not referenced anywhere in this file: `handle_message` inspects a raw
/// `serde_json::Value` instead. `dead_code` is allowed so that keeping this
/// typed description around does not produce a warning.
#[derive(Clone, Debug, Deserialize)]
#[allow(dead_code)]
struct IncomingMessage {
    /// Message kind, read from the JSON key `type`.
    #[serde(rename = "type")]
    msg_type: Option<String>,
    /// Market code the message refers to, when present.
    code: Option<String>,
    /// Status string, when present.
    status: Option<String>,
}
/// Concrete socket type produced by `connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// Broadcast sender that fans parsed events (or errors) out to subscribers.
type UpbitEventSender = broadcast::Sender<WebSocketResult<StreamEvent>>;

/// WebSocket client for Upbit's regional exchanges.
///
/// Every piece of mutable state sits behind an `Arc`, so the copy handed to
/// the background message loop shares it with the caller's handle.
pub struct UpbitWebSocket {
    /// Authentication helper built from credentials; `Some` selects the
    /// private endpoint on connect.
    auth: Option<UpbitAuth>,
    /// Endpoint set for the selected region.
    urls: UpbitUrls,
    /// Current connection state.
    status: Arc<Mutex<ConnectionStatus>>,
    /// Requests recorded by `subscribe`.
    subscriptions: Arc<Mutex<HashSet<SubscriptionRequest>>>,
    /// Event sender; `None` before `connect` and after the connection ends.
    broadcast_tx: Arc<StdMutex<Option<UpbitEventSender>>>,
    /// The open socket, shared by writers and the read loop.
    ws_stream: Arc<Mutex<Option<WsStream>>>,
    /// When the last keep-alive ping was sent.
    last_ping: Arc<Mutex<Instant>>,
    /// Most recent ping round-trip time, in milliseconds.
    ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
impl UpbitWebSocket {
pub async fn new(
credentials: Option<Credentials>,
region: &str,
) -> ExchangeResult<Self> {
let urls = match region {
"id" => UpbitUrls::INDONESIA,
"th" => UpbitUrls::THAILAND,
_ => UpbitUrls::SINGAPORE,
};
let auth = credentials
.as_ref()
.map(UpbitAuth::new)
.transpose()?;
Ok(Self {
auth,
urls,
status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
subscriptions: Arc::new(Mutex::new(HashSet::new())),
broadcast_tx: Arc::new(StdMutex::new(None)),
ws_stream: Arc::new(Mutex::new(None)),
last_ping: Arc::new(Mutex::new(Instant::now())),
ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
})
}
async fn subscription_rate_limit_wait() {
loop {
let can_subscribe = {
let limiter = get_subscription_limiter();
let mut guard = limiter.lock().expect("Mutex poisoned");
guard.try_acquire()
};
if can_subscribe {
return; }
let wait_time = {
let limiter = get_subscription_limiter();
let guard = limiter.lock().expect("Mutex poisoned");
guard.time_until_ready()
};
if wait_time > Duration::ZERO {
sleep(wait_time).await;
} else {
sleep(Duration::from_millis(100)).await;
}
}
}
async fn send_subscription(
&self,
msg_type: &str,
symbols: Vec<String>,
) -> WebSocketResult<()> {
Self::subscription_rate_limit_wait().await;
let mut ws_lock = self.ws_stream.lock().await;
let ws = ws_lock.as_mut()
.ok_or(WebSocketError::NotConnected)?;
let subscription = json!([
{"ticket": "upbit-connector"},
{
"type": msg_type,
"codes": symbols
},
{"format": "DEFAULT"}
]);
ws.send(Message::Text(subscription.to_string())).await
.map_err(|e| WebSocketError::SendError(e.to_string()))?;
Ok(())
}
async fn send_ping(&self) -> WebSocketResult<()> {
let mut ws_lock = self.ws_stream.lock().await;
let ws = ws_lock.as_mut()
.ok_or(WebSocketError::NotConnected)?;
ws.send(Message::Text("PING".to_string())).await
.map_err(|e| WebSocketError::SendError(e.to_string()))?;
*self.last_ping.lock().await = Instant::now();
ws.send(Message::Ping(vec![])).await
.map_err(|e| WebSocketError::SendError(e.to_string()))?;
Ok(())
}
async fn handle_message(&self, text: &str) -> Option<StreamEvent> {
let value: Value = match serde_json::from_str(text) {
Ok(v) => v,
Err(_) => return None,
};
if let Some(status) = value.get("status") {
if status == "UP" {
return None; }
}
let msg_type = value.get("type")
.or_else(|| value.get("ty"))
.and_then(|t| t.as_str())?;
match msg_type {
"ticker" => {
UpbitParser::parse_ws_ticker(&value)
.ok()
.map(StreamEvent::Ticker)
},
"trade" => {
UpbitParser::parse_ws_trade(&value)
.ok()
.map(StreamEvent::Trade)
},
"orderbook" => {
UpbitParser::parse_ws_orderbook(&value)
.ok()
.map(StreamEvent::OrderbookSnapshot)
},
_ => None,
}
}
async fn start_message_loop(&self) -> WebSocketResult<()> {
let ws_stream = self.ws_stream.clone();
let broadcast_tx = self.broadcast_tx.clone();
let status = self.status.clone();
let last_ping = self.last_ping.clone();
let ws_ping_rtt_ms = self.ws_ping_rtt_ms.clone();
let ws_clone = self.clone_for_loop();
tokio::spawn(async move {
loop {
{
let last = last_ping.lock().await;
if last.elapsed() > Duration::from_secs(30) {
drop(last);
if let Err(e) = ws_clone.send_ping().await {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(Err(e));
}
break;
}
}
}
let msg = {
let mut ws_lock = ws_stream.lock().await;
match ws_lock.as_mut() {
Some(ws) => {
match tokio::time::timeout(Duration::from_millis(100), ws.next()).await {
Ok(Some(msg)) => msg,
Ok(None) => break,
Err(_) => continue, }
},
None => break,
}
};
match msg {
Ok(Message::Text(text)) => {
if let Some(event) = ws_clone.handle_message(&text).await {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(Ok(event));
}
}
},
Ok(Message::Binary(data)) => {
if let Ok(text) = String::from_utf8(data) {
if let Some(event) = ws_clone.handle_message(&text).await {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(Ok(event));
}
}
}
},
Ok(Message::Pong(_)) => {
let rtt = last_ping.lock().await.elapsed().as_millis() as u64;
*ws_ping_rtt_ms.lock().await = rtt;
},
Ok(Message::Close(_)) => {
*status.lock().await = ConnectionStatus::Disconnected;
break;
},
Err(e) => {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
let _ = tx.send(Err(WebSocketError::ReceiveError(e.to_string())));
}
break;
},
_ => {},
}
}
let _ = broadcast_tx.lock().unwrap().take();
});
Ok(())
}
fn clone_for_loop(&self) -> Self {
Self {
auth: self.auth.clone(),
urls: self.urls.clone(),
status: self.status.clone(),
subscriptions: self.subscriptions.clone(),
broadcast_tx: self.broadcast_tx.clone(),
ws_stream: self.ws_stream.clone(),
last_ping: self.last_ping.clone(),
ws_ping_rtt_ms: self.ws_ping_rtt_ms.clone(),
}
}
}
#[async_trait]
impl WebSocketConnector for UpbitWebSocket {
async fn connect(&mut self, _account_type: AccountType) -> WebSocketResult<()> {
let limiter = get_global_ws_limiter();
loop {
let can_connect = {
let mut limiter_guard = limiter.lock().expect("Mutex poisoned");
limiter_guard.try_acquire()
};
if can_connect {
break;
}
let wait_time = {
let limiter_guard = limiter.lock().expect("Mutex poisoned");
limiter_guard.time_until_ready()
};
if wait_time > Duration::ZERO {
sleep(wait_time).await;
} else {
sleep(Duration::from_millis(100)).await;
}
}
let ws_url = if self.auth.is_some() {
self.urls.ws_private_url()
} else {
self.urls.ws_url().to_string()
};
let (ws_stream, _) = connect_async(&ws_url).await
.map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
*self.ws_stream.lock().await = Some(ws_stream);
*self.status.lock().await = ConnectionStatus::Connected;
*self.last_ping.lock().await = Instant::now();
let (tx, _) = broadcast::channel(1000);
*self.broadcast_tx.lock().unwrap() = Some(tx);
self.start_message_loop().await?;
Ok(())
}
async fn disconnect(&mut self) -> WebSocketResult<()> {
let mut ws_lock = self.ws_stream.lock().await;
if let Some(ws) = ws_lock.as_mut() {
ws.close(None).await
.map_err(|e: tokio_tungstenite::tungstenite::Error| WebSocketError::ConnectionError(e.to_string()))?;
}
*ws_lock = None;
*self.status.lock().await = ConnectionStatus::Disconnected;
let _ = self.broadcast_tx.lock().unwrap().take();
Ok(())
}
fn connection_status(&self) -> ConnectionStatus {
self.status.try_lock()
.map(|s| *s)
.unwrap_or(ConnectionStatus::Disconnected)
}
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
self.subscriptions.lock().await.insert(request.clone());
let upbit_symbol = format!("{}-{}", request.symbol.quote.to_uppercase(), request.symbol.base.to_uppercase());
match request.stream_type {
StreamType::Ticker => {
self.send_subscription("ticker", vec![upbit_symbol]).await?;
},
StreamType::Trade => {
self.send_subscription("trade", vec![upbit_symbol]).await?;
},
StreamType::Orderbook => {
self.send_subscription("orderbook", vec![upbit_symbol]).await?;
},
_ => {
return Err(WebSocketError::UnsupportedOperation(format!("Unsupported stream type: {:?}", request.stream_type)));
}
}
Ok(())
}
async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
self.subscriptions.lock().await.remove(&request);
Err(WebSocketError::UnsupportedOperation("Upbit doesn't support unsubscribe - reconnect required".to_string()))
}
fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
let tx_guard = self.broadcast_tx.lock().unwrap();
if let Some(ref tx) = *tx_guard {
let rx = tx.subscribe();
Box::pin(tokio_stream::wrappers::BroadcastStream::new(rx).filter_map(|r| async move {
r.ok()
}))
} else {
Box::pin(futures_util::stream::empty())
}
}
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
self.subscriptions.try_lock()
.map(|subs| subs.iter().cloned().collect())
.unwrap_or_default()
}
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
Some(self.ws_ping_rtt_ms.clone())
}
fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
OrderbookCapabilities {
ws_depths: &[1, 5, 15, 30],
ws_default_depth: Some(30),
rest_max_depth: Some(30),
rest_depth_values: &[1, 5, 15, 30],
supports_snapshot: true,
supports_delta: false,
update_speeds_ms: &[],
default_speed_ms: None,
ws_channels: &[],
checksum: None,
has_sequence: false,
has_prev_sequence: false,
supports_aggregation: true,
aggregation_levels: &[],
}
}
}