use std::sync::Arc;
use dashmap::DashMap;
use crate::connector_manager::{ConnectorFactory, ConnectorPool, WebSocketPool};
use crate::core::traits::{CoreConnector, Credentials, WebSocketConnector};
use crate::core::types::{AccountType, ConnectorCapabilities, ExchangeError, ExchangeId, ExchangeResult};
#[derive(Clone)]
pub struct ExchangeHub {
rest: ConnectorPool,
ws: WebSocketPool,
rest_overrides: Arc<DashMap<ExchangeId, String>>,
}
impl Default for ExchangeHub {
fn default() -> Self {
Self {
rest: ConnectorPool::default(),
ws: WebSocketPool::default(),
rest_overrides: Arc::new(DashMap::new()),
}
}
}
impl ExchangeHub {
    /// Creates an empty hub: no connectors and no REST base-URL overrides.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a public (unauthenticated) REST connector for `id` and registers it.
    ///
    /// Any REST base-URL override set with [`Self::set_rest_base_override`] is
    /// forwarded to the factory.
    ///
    /// # Errors
    /// Propagates any error from `ConnectorFactory::create_public`; nothing is
    /// registered in that case.
    pub async fn connect_public(&self, id: ExchangeId, testnet: bool) -> ExchangeResult<()> {
        let override_url = self.get_rest_base_override(id);
        let conn = ConnectorFactory::create_public(id, testnet, override_url).await?;
        self.rest.insert(id, conn);
        Ok(())
    }

    /// Returns the REST connector registered for `id`, if any.
    pub fn rest(&self, id: ExchangeId) -> Option<Arc<dyn CoreConnector>> {
        self.rest.get(&id)
    }

    /// Sets the REST base-URL override used by subsequent `connect_*` calls for `id`.
    ///
    /// An empty `url` removes the override instead of storing an empty string.
    /// Already-registered connectors are not touched; reconnect to apply the URL.
    pub fn set_rest_base_override(&self, id: ExchangeId, url: String) {
        if url.is_empty() {
            self.rest_overrides.remove(&id);
        } else {
            self.rest_overrides.insert(id, url);
        }
    }

    /// Removes the REST base-URL override for `id`, if one is set.
    pub fn clear_rest_base_override(&self, id: ExchangeId) {
        self.rest_overrides.remove(&id);
    }

    /// Returns a copy of the REST base-URL override for `id`, if one is set.
    ///
    /// The map guard is released at the end of this call, so the result is safe
    /// to hold across `.await` points.
    pub fn get_rest_base_override(&self, id: ExchangeId) -> Option<String> {
        self.rest_overrides.get(&id).map(|entry| entry.value().clone())
    }

    /// Creates a public WebSocket connector for `(id, account_type)` and registers it.
    ///
    /// The REST base-URL override for `id` (if any) is forwarded to the factory.
    /// Presumably the WS connector uses it for auxiliary REST calls; TODO confirm.
    ///
    /// # Errors
    /// Propagates any error from `ConnectorFactory::create_websocket`; nothing is
    /// registered in that case.
    pub async fn connect_websocket(
        &self,
        id: ExchangeId,
        account_type: AccountType,
        testnet: bool,
    ) -> ExchangeResult<()> {
        let rest_override = self.get_rest_base_override(id);
        let ws = ConnectorFactory::create_websocket(id, account_type, testnet, rest_override).await?;
        self.ws.insert(id, account_type, ws);
        Ok(())
    }

    /// Creates an authenticated WebSocket connector for `(id, account_type)` and
    /// registers it. Not available on `wasm32`.
    ///
    /// # Errors
    /// Propagates any error from `ConnectorFactory::create_websocket_authenticated`;
    /// nothing is registered in that case.
    #[cfg(not(target_arch = "wasm32"))]
    pub async fn connect_websocket_with_credentials(
        &self,
        id: ExchangeId,
        account_type: AccountType,
        credentials: Credentials,
    ) -> ExchangeResult<()> {
        let rest_override = self.get_rest_base_override(id);
        let ws = ConnectorFactory::create_websocket_authenticated(
            id,
            account_type,
            credentials,
            rest_override,
        )
        .await?;
        self.ws.insert(id, account_type, ws);
        Ok(())
    }

    /// Returns the WebSocket connector registered for `(id, account_type)`, if any.
    pub fn ws(&self, id: ExchangeId, account_type: AccountType) -> Option<Arc<dyn WebSocketConnector>> {
        self.ws.get(id, account_type)
    }

    /// Connects REST for `id`, then opens one WebSocket per entry in
    /// `account_types` on a best-effort basis.
    ///
    /// WebSocket failures are deliberately ignored. The call still returns
    /// `Ok(())`, and only the account types that connected are registered.
    ///
    /// # Errors
    /// Returns an error only if the REST connector cannot be created; no
    /// WebSocket is attempted in that case.
    pub async fn connect_full(
        &self,
        id: ExchangeId,
        account_types: &[AccountType],
        testnet: bool,
    ) -> ExchangeResult<()> {
        let override_url = self.get_rest_base_override(id);
        let conn = ConnectorFactory::create_public(id, testnet, override_url.clone()).await?;
        self.rest.insert(id, conn);
        self.connect_websockets_best_effort(id, account_types, testnet, override_url)
            .await;
        Ok(())
    }

    /// Like [`Self::connect_full`], but refuses a REST connector whose
    /// `validation_status()` is `None`.
    ///
    /// The freshly built connector is checked *before* it is registered and
    /// before any WebSocket is opened. This has three consequences:
    /// - A rejected exchange never opens sockets that would be torn down at once.
    /// - Whatever was previously registered for `id` is left untouched on
    ///   failure, matching the behaviour when the factory itself fails.
    /// - The connector is not re-read from the pool, which another `&self`
    ///   caller could modify in between.
    ///
    /// # Errors
    /// - any error from `ConnectorFactory::create_public`;
    /// - `ExchangeError::NotValidated` if the connector has no validation stamp.
    pub async fn connect_full_validated(
        &self,
        id: ExchangeId,
        account_types: &[AccountType],
        testnet: bool,
    ) -> ExchangeResult<()> {
        let override_url = self.get_rest_base_override(id);
        let conn = ConnectorFactory::create_public(id, testnet, override_url.clone()).await?;
        if conn.validation_status().is_none() {
            return Err(ExchangeError::NotValidated(format!(
                "{:?} has no ValidationStamp — refusing connect_full_validated; \
                 use connect_full() to bypass",
                id
            )));
        }
        self.rest.insert(id, conn);
        self.connect_websockets_best_effort(id, account_types, testnet, override_url)
            .await;
        Ok(())
    }

    /// Opens one public WebSocket per account type and registers each success.
    /// Failures are ignored on purpose (best-effort, see [`Self::connect_full`]).
    async fn connect_websockets_best_effort(
        &self,
        id: ExchangeId,
        account_types: &[AccountType],
        testnet: bool,
        override_url: Option<String>,
    ) {
        for &at in account_types {
            if let Ok(ws) =
                ConnectorFactory::create_websocket(id, at, testnet, override_url.clone()).await
            {
                self.ws.insert(id, at, ws);
            }
        }
    }

    /// Returns the capabilities of the REST connector for `id`, or `None` if the
    /// exchange is not connected.
    pub fn capabilities(&self, id: ExchangeId) -> Option<ConnectorCapabilities> {
        self.rest(id).map(|c| c.capabilities())
    }

    /// Returns the maximum kline (candle) limit reported for `account_type` by the
    /// REST connector of `id`.
    ///
    /// Falls back to `default` when the exchange is not connected or the
    /// connector reports no limit.
    pub fn max_kline_limit_for(
        &self,
        id: ExchangeId,
        account_type: AccountType,
        default: u16,
    ) -> u16 {
        self.rest(id)
            .and_then(|c| c.market_data_capabilities(account_type).max_kline_limit)
            .unwrap_or(default)
    }

    /// Spot-market shorthand for [`Self::max_kline_limit_for`].
    ///
    /// Returns `default` when the exchange is not connected or reports no limit.
    pub fn max_kline_limit(&self, id: ExchangeId, default: u16) -> u16 {
        self.max_kline_limit_for(id, AccountType::Spot, default)
    }

    /// Returns the exchanges that have a registered REST connector.
    pub fn ids(&self) -> Vec<ExchangeId> {
        self.rest.ids()
    }

    /// Number of registered REST connectors.
    pub fn len_rest(&self) -> usize {
        self.rest.len()
    }

    /// Number of registered WebSocket connectors, as reported by the WS pool.
    pub fn len_ws(&self) -> usize {
        self.ws.len()
    }

    /// Total number of registered connectors: REST count plus WebSocket count.
    pub fn len(&self) -> usize {
        self.rest.len() + self.ws.len()
    }

    /// `true` when neither REST nor WebSocket connectors are registered.
    pub fn is_empty(&self) -> bool {
        self.rest.is_empty() && self.ws.is_empty()
    }

    /// Removes the REST connector for `id` and its WebSocket connectors for each
    /// account type listed below. REST base-URL overrides are kept.
    ///
    /// NOTE(review): the account-type list is hard-coded. If `AccountType` has
    /// (or gains) other variants, WebSockets registered under them are not
    /// removed. Confirm the list is exhaustive, or use a per-exchange removal
    /// on `WebSocketPool` if one exists.
    pub fn shutdown(&self, id: ExchangeId) {
        self.rest.remove(&id);
        for at in [
            AccountType::Spot,
            AccountType::Margin,
            AccountType::FuturesCross,
            AccountType::FuturesIsolated,
            AccountType::Earn,
            AccountType::Lending,
            AccountType::Options,
            AccountType::Convert,
        ] {
            self.ws.remove(id, at);
        }
    }

    /// Drops every REST and WebSocket connector. REST base-URL overrides are kept.
    pub fn clear(&self) {
        self.rest.clear();
        self.ws.clear();
    }

    /// `true` if a REST connector is registered for `id`; WebSocket state is not
    /// considered.
    pub fn is_connected(&self, id: ExchangeId) -> bool {
        self.rest.contains(&id)
    }

    /// Same as [`Self::ids`]: exchanges with a registered REST connector.
    pub fn list_connected(&self) -> Vec<ExchangeId> {
        self.rest.ids()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With nothing connected, `max_kline_limit` must echo the caller's fallback.
    #[test]
    fn max_kline_limit_returns_default_when_not_connected() {
        let empty_hub = ExchangeHub::new();
        let cases = [(ExchangeId::OKX, 1000), (ExchangeId::Binance, 500)];
        for (exchange, fallback) in cases {
            assert_eq!(empty_hub.max_kline_limit(exchange, fallback), fallback);
        }
    }

    /// Live check of the kline limits reported by real OKX and KuCoin connectors.
    /// Each exchange is connected immediately before its limit is asserted.
    #[tokio::test]
    #[ignore = "requires network access to OKX / KuCoin REST"]
    async fn max_kline_limit_live_okx_kucoin() {
        let live_hub = ExchangeHub::new();
        let expected = [(ExchangeId::OKX, 300), (ExchangeId::KuCoin, 1500)];
        for (exchange, limit) in expected {
            live_hub.connect_public(exchange, false).await.unwrap();
            assert_eq!(live_hub.max_kline_limit(exchange, 1000), limit);
        }
    }
}