use std::sync::Arc;
use crate::connector_manager::ConnectorPool;
use crate::core::traits::MarketData;
use crate::core::types::{
AccountType, ExchangeError, ExchangeId, ExchangeResult, Kline, OrderBook, Price, Symbol,
Ticker,
};
/// Facade that routes market-data requests to the connectors held in a
/// [`ConnectorPool`].
///
/// The pool is stored behind an [`Arc`], so cloning the aggregator only clones
/// the handle and every clone shares the same pool.
#[derive(Clone)]
pub struct ConnectorAggregator {
/// Shared handle to the pool that requests are dispatched to.
pool: Arc<ConnectorPool>,
}
impl ConnectorAggregator {
pub fn new(pool: ConnectorPool) -> Self {
Self {
pool: Arc::new(pool),
}
}
pub fn pool(&self) -> &ConnectorPool {
&self.pool
}
pub fn available_exchanges(&self) -> Vec<ExchangeId> {
self.pool.ids()
}
pub async fn get_price(
&self,
id: ExchangeId,
symbol: Symbol,
account_type: AccountType,
) -> ExchangeResult<Price> {
let connector = self
.pool
.get(&id)
.ok_or_else(|| ExchangeError::NotFound(format!("Exchange {:?} not in pool", id)))?;
connector.get_price(symbol, account_type).await
}
pub async fn get_ticker(
&self,
id: ExchangeId,
symbol: Symbol,
account_type: AccountType,
) -> ExchangeResult<Ticker> {
let connector = self
.pool
.get(&id)
.ok_or_else(|| ExchangeError::NotFound(format!("Exchange {:?} not in pool", id)))?;
connector.get_ticker(symbol, account_type).await
}
pub async fn get_orderbook(
&self,
id: ExchangeId,
symbol: Symbol,
account_type: AccountType,
depth: Option<u16>,
) -> ExchangeResult<OrderBook> {
let connector = self
.pool
.get(&id)
.ok_or_else(|| ExchangeError::NotFound(format!("Exchange {:?} not in pool", id)))?;
connector.get_orderbook(symbol, depth, account_type).await
}
pub async fn get_klines(
&self,
id: ExchangeId,
symbol: Symbol,
interval: &str,
account_type: AccountType,
limit: Option<u16>,
end_time: Option<i64>,
) -> ExchangeResult<Vec<Kline>> {
let connector = self
.pool
.get(&id)
.ok_or_else(|| ExchangeError::NotFound(format!("Exchange {:?} not in pool", id)))?;
connector
.get_klines(symbol, interval, limit, account_type, end_time)
.await
}
pub async fn get_prices_multi(
&self,
ids: &[ExchangeId],
symbol: Symbol,
account_type: AccountType,
) -> ExchangeResult<std::collections::HashMap<ExchangeId, Price>> {
use futures_util::future::join_all;
let connectors: Vec<_> = ids
.iter()
.filter_map(|id| self.pool.get(id).map(|c| (*id, c)))
.collect();
if connectors.is_empty() {
return Err(ExchangeError::NotFound(
"No specified exchanges found in pool".to_string(),
));
}
let futures = connectors.into_iter().map(|(id, connector)| {
let sym = symbol.clone();
let acc_type = account_type;
async move {
connector
.get_price(sym, acc_type)
.await
.ok()
.map(|price| (id, price))
}
});
let results: Vec<Option<(ExchangeId, Price)>> = join_all(futures).await;
let prices: std::collections::HashMap<_, _> =
results.into_iter().flatten().collect();
if prices.is_empty() {
return Err(ExchangeError::NotFound(
"All exchange queries failed".to_string(),
));
}
Ok(prices)
}
pub async fn get_best_bid_ask(
&self,
ids: &[ExchangeId],
symbol: Symbol,
account_type: AccountType,
) -> ExchangeResult<BestBidAsk> {
use futures_util::future::join_all;
let connectors: Vec<_> = ids
.iter()
.filter_map(|id| self.pool.get(id).map(|c| (*id, c)))
.collect();
if connectors.is_empty() {
return Err(ExchangeError::NotFound(
"No specified exchanges found in pool".to_string(),
));
}
let futures = connectors.into_iter().map(|(id, connector)| {
let sym = symbol.clone();
let acc_type = account_type;
async move {
connector
.get_orderbook(sym, Some(1), acc_type)
.await
.ok()
.map(|ob| (id, ob))
}
});
let results: Vec<Option<(ExchangeId, OrderBook)>> = join_all(futures).await;
let orderbooks: Vec<_> = results.into_iter().flatten().collect();
if orderbooks.is_empty() {
return Err(ExchangeError::NotFound(
"All orderbook queries failed".to_string(),
));
}
let mut best_bid: Option<(f64, ExchangeId)> = None;
let mut best_ask: Option<(f64, ExchangeId)> = None;
for (id, ob) in orderbooks {
if let Some((bid_price, _)) = ob.bids.first() {
match best_bid {
None => best_bid = Some((*bid_price, id)),
Some((current_best, _)) if *bid_price > current_best => {
best_bid = Some((*bid_price, id))
}
_ => {}
}
}
if let Some((ask_price, _)) = ob.asks.first() {
match best_ask {
None => best_ask = Some((*ask_price, id)),
Some((current_best, _)) if *ask_price < current_best => {
best_ask = Some((*ask_price, id))
}
_ => {}
}
}
}
let (bid, bid_exchange) = best_bid.ok_or_else(|| {
ExchangeError::NotFound("No valid bids found in orderbooks".to_string())
})?;
let (ask, ask_exchange) = best_ask.ok_or_else(|| {
ExchangeError::NotFound("No valid asks found in orderbooks".to_string())
})?;
Ok(BestBidAsk {
bid,
bid_exchange,
ask,
ask_exchange,
spread: ask - bid,
spread_percent: ((ask - bid) / bid) * 100.0,
})
}
}
/// Builder that holds a [`ConnectorPool`] and turns it into a
/// [`ConnectorAggregator`].
pub struct ConnectorAggregatorBuilder {
/// Pool handed to the aggregator when `build` is called.
pool: ConnectorPool,
}
impl ConnectorAggregatorBuilder {
pub fn new() -> Self {
Self {
pool: ConnectorPool::new(),
}
}
pub fn with_pool(pool: ConnectorPool) -> Self {
Self { pool }
}
pub fn build(self) -> ConnectorAggregator {
ConnectorAggregator::new(self.pool)
}
}
impl Default for ConnectorAggregatorBuilder {
fn default() -> Self {
Self::new()
}
}
/// Best bid and best ask found across several exchanges, with the spread
/// between them.
///
/// The two sides may come from different exchanges.
#[derive(Debug, Clone)]
pub struct BestBidAsk {
/// Highest bid price found.
pub bid: f64,
/// Exchange that quoted `bid`.
pub bid_exchange: ExchangeId,
/// Lowest ask price found.
pub ask: f64,
/// Exchange that quoted `ask`.
pub ask_exchange: ExchangeId,
/// Absolute spread, `ask - bid`.
pub spread: f64,
/// Spread relative to the bid, as a percentage: `(ask - bid) / bid * 100`.
pub spread_percent: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connector_manager::AnyConnector;
    use crate::crypto::cex::okx::OkxConnector;

    /// Builds a public OKX connector wrapped as an `AnyConnector`; used as a
    /// stand-in connector for every pool entry in these tests.
    async fn create_mock_connector() -> Arc<AnyConnector> {
        let okx = OkxConnector::public(true).await.unwrap();
        Arc::new(AnyConnector::OKX(Arc::new(okx)))
    }

    /// Pool with two entries, registered under the Binance and KuCoin ids.
    async fn create_test_pool() -> ConnectorPool {
        let pool = ConnectorPool::new();
        for id in [ExchangeId::Binance, ExchangeId::KuCoin] {
            pool.insert(id, create_mock_connector().await);
        }
        pool
    }

    /// Aggregator over an empty pool.
    fn empty_aggregator() -> ConnectorAggregator {
        ConnectorAggregator::new(ConnectorPool::new())
    }

    /// Aggregator over the two-entry test pool.
    async fn pooled_aggregator() -> ConnectorAggregator {
        ConnectorAggregator::new(create_test_pool().await)
    }

    /// Symbol used by every request in these tests.
    fn btc_usdt() -> Symbol {
        Symbol::new("BTC", "USDT")
    }

    #[tokio::test]
    async fn test_new_aggregator() {
        let aggregator = empty_aggregator();
        assert!(aggregator.available_exchanges().is_empty());
    }

    #[tokio::test]
    async fn test_aggregator_with_pool() {
        let aggregator = pooled_aggregator().await;
        assert_eq!(aggregator.available_exchanges().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_new() {
        let builder = ConnectorAggregatorBuilder::new();
        let aggregator = builder.build();
        assert!(aggregator.available_exchanges().is_empty());
    }

    #[tokio::test]
    async fn test_builder_with_pool() {
        let builder = ConnectorAggregatorBuilder::with_pool(create_test_pool().await);
        let aggregator = builder.build();
        assert_eq!(aggregator.available_exchanges().len(), 2);
    }

    #[tokio::test]
    async fn test_builder_default() {
        let builder = ConnectorAggregatorBuilder::default();
        let aggregator = builder.build();
        assert!(aggregator.available_exchanges().is_empty());
    }

    #[tokio::test]
    async fn test_pool_access() {
        let aggregator = pooled_aggregator().await;
        assert_eq!(aggregator.pool().len(), 2);
    }

    #[tokio::test]
    async fn test_available_exchanges() {
        let aggregator = pooled_aggregator().await;
        let exchanges = aggregator.available_exchanges();
        assert_eq!(exchanges.len(), 2);
        assert!(exchanges.contains(&ExchangeId::Binance));
        assert!(exchanges.contains(&ExchangeId::KuCoin));
    }

    #[tokio::test]
    async fn test_get_price_exchange_not_in_pool() {
        let outcome = empty_aggregator()
            .get_price(ExchangeId::Binance, btc_usdt(), AccountType::Spot)
            .await;
        assert!(matches!(outcome, Err(ExchangeError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_get_ticker_exchange_not_in_pool() {
        let outcome = empty_aggregator()
            .get_ticker(ExchangeId::Binance, btc_usdt(), AccountType::Spot)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_get_orderbook_exchange_not_in_pool() {
        let outcome = empty_aggregator()
            .get_orderbook(ExchangeId::Binance, btc_usdt(), AccountType::Spot, Some(20))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_get_klines_exchange_not_in_pool() {
        let outcome = empty_aggregator()
            .get_klines(
                ExchangeId::Binance,
                btc_usdt(),
                "1h",
                AccountType::Spot,
                Some(100),
                None,
            )
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_get_prices_multi_no_exchanges() {
        let requested = [ExchangeId::Binance, ExchangeId::KuCoin];
        let outcome = empty_aggregator()
            .get_prices_multi(&requested, btc_usdt(), AccountType::Spot)
            .await;
        assert!(matches!(outcome, Err(ExchangeError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_get_prices_multi_empty_list() {
        let aggregator = pooled_aggregator().await;
        let outcome = aggregator
            .get_prices_multi(&[], btc_usdt(), AccountType::Spot)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_get_best_bid_ask_no_exchanges() {
        let outcome = empty_aggregator()
            .get_best_bid_ask(&[ExchangeId::Binance], btc_usdt(), AccountType::Spot)
            .await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_best_bid_ask_spread_calculation() {
        // Only checks that the fields hold what they were given; the values
        // are hand-computed, not produced by the aggregator.
        let best = BestBidAsk {
            bid: 50000.0,
            bid_exchange: ExchangeId::Binance,
            ask: 50100.0,
            ask_exchange: ExchangeId::KuCoin,
            spread: 100.0,
            spread_percent: 0.2,
        };
        assert_eq!(best.spread, 100.0);
        assert_eq!(best.spread_percent, 0.2);
    }

    #[tokio::test]
    async fn test_aggregator_clone() {
        let original = pooled_aggregator().await;
        let copy = original.clone();
        assert_eq!(original.available_exchanges().len(), 2);
        assert_eq!(copy.available_exchanges().len(), 2);
    }
}