use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::Decimal;
use crate::types::error::{MMError, MMResult};
use super::connector::{
BookLevel, ExchangeConnector, Fill, MarketDataStream, OrderBookSnapshot, OrderId, OrderRequest,
OrderResponse, OrderStatus, Side,
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MockConfig {
pub latency_ms: u64,
pub failure_rate: f64,
pub initial_balances: HashMap<String, Decimal>,
pub default_depth: usize,
pub base_price: Decimal,
pub spread: Decimal,
}
impl Default for MockConfig {
fn default() -> Self {
let mut balances = HashMap::new();
balances.insert("USD".to_string(), Decimal::from(100_000));
balances.insert("BTC".to_string(), Decimal::from(10));
Self {
latency_ms: 0,
failure_rate: 0.0,
initial_balances: balances,
default_depth: 10,
base_price: Decimal::from(50_000),
spread: Decimal::from_str_exact("0.001").unwrap(),
}
}
}
impl MockConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_latency_ms(mut self, latency_ms: u64) -> Self {
self.latency_ms = latency_ms;
self
}
#[must_use]
pub fn with_failure_rate(mut self, rate: f64) -> Self {
self.failure_rate = rate.clamp(0.0, 1.0);
self
}
#[must_use]
pub fn with_base_price(mut self, price: Decimal) -> Self {
self.base_price = price;
self
}
#[must_use]
pub fn with_spread(mut self, spread: Decimal) -> Self {
self.spread = spread;
self
}
#[must_use]
pub fn with_balance(mut self, asset: impl Into<String>, balance: Decimal) -> Self {
self.initial_balances.insert(asset.into(), balance);
self
}
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct OrderState {
request: OrderRequest,
status: OrderStatus,
order_id: OrderId,
timestamp: u64,
}
#[derive(Debug)]
pub struct MockExchangeConnector {
config: MockConfig,
orders: RwLock<HashMap<String, OrderState>>,
balances: RwLock<HashMap<String, Decimal>>,
order_counter: AtomicU64,
current_time: AtomicU64,
}
impl MockExchangeConnector {
#[must_use]
pub fn new(config: MockConfig) -> Self {
let balances = RwLock::new(config.initial_balances.clone());
Self {
config,
orders: RwLock::new(HashMap::new()),
balances,
order_counter: AtomicU64::new(1),
current_time: AtomicU64::new(1_000_000),
}
}
#[must_use]
pub fn with_defaults() -> Self {
Self::new(MockConfig::default())
}
#[must_use]
pub fn config(&self) -> &MockConfig {
&self.config
}
pub fn set_time(&self, time: u64) {
self.current_time.store(time, Ordering::SeqCst);
}
pub fn advance_time(&self, delta: u64) {
self.current_time.fetch_add(delta, Ordering::SeqCst);
}
#[must_use]
pub fn current_time(&self) -> u64 {
self.current_time.load(Ordering::SeqCst)
}
pub fn set_balance(&self, asset: &str, balance: Decimal) {
let mut balances = self.balances.write().unwrap();
balances.insert(asset.to_string(), balance);
}
fn next_order_id(&self) -> OrderId {
let id = self.order_counter.fetch_add(1, Ordering::SeqCst);
OrderId::new(format!("mock-{}", id))
}
async fn simulate_latency(&self) {
let _ = self.config.latency_ms;
}
fn should_fail(&self) -> bool {
if self.config.failure_rate > 0.0 {
let counter = self.order_counter.load(Ordering::SeqCst);
let threshold = (self.config.failure_rate * 100.0) as u64;
counter % 100 < threshold
} else {
false
}
}
fn generate_orderbook(&self, symbol: &str, depth: usize) -> OrderBookSnapshot {
let base_price = self.config.base_price;
let half_spread = base_price * self.config.spread / Decimal::from(2);
let mut bids = Vec::with_capacity(depth);
let mut asks = Vec::with_capacity(depth);
let tick_size = base_price * Decimal::from_str_exact("0.0001").unwrap();
for i in 0..depth {
let offset = tick_size * Decimal::from(i);
let bid_price = base_price - half_spread - offset;
let bid_qty =
Decimal::from(1) + Decimal::from(i) * Decimal::from_str_exact("0.5").unwrap();
bids.push(BookLevel::new(bid_price, bid_qty));
let ask_price = base_price + half_spread + offset;
let ask_qty =
Decimal::from(1) + Decimal::from(i) * Decimal::from_str_exact("0.5").unwrap();
asks.push(BookLevel::new(ask_price, ask_qty));
}
OrderBookSnapshot {
symbol: symbol.to_string(),
bids,
asks,
timestamp: self.current_time(),
}
}
fn simulate_market_fill(&self, request: &OrderRequest) -> (OrderStatus, Decimal) {
let orderbook = self.generate_orderbook(&request.symbol, 1);
let fill_price = match request.side {
Side::Buy => orderbook.best_ask().unwrap_or(self.config.base_price),
Side::Sell => orderbook.best_bid().unwrap_or(self.config.base_price),
};
let status = OrderStatus::Filled {
filled_qty: request.quantity,
avg_price: fill_price,
};
(status, fill_price)
}
#[must_use]
pub fn open_order_count(&self) -> usize {
let orders = self.orders.read().unwrap();
orders.values().filter(|o| o.status.is_open()).count()
}
#[must_use]
pub fn all_order_ids(&self) -> Vec<OrderId> {
let orders = self.orders.read().unwrap();
orders.values().map(|o| o.order_id.clone()).collect()
}
}
#[async_trait]
impl ExchangeConnector for MockExchangeConnector {
async fn submit_order(&self, request: OrderRequest) -> MMResult<OrderResponse> {
self.simulate_latency().await;
if self.should_fail() {
return Err(MMError::InvalidMarketState(
"simulated exchange failure".to_string(),
));
}
let order_id = self.next_order_id();
let timestamp = self.current_time();
let status = if request.order_type.is_market() {
let (status, _) = self.simulate_market_fill(&request);
status
} else {
OrderStatus::Open {
filled_qty: Decimal::ZERO,
}
};
let state = OrderState {
request: request.clone(),
status: status.clone(),
order_id: order_id.clone(),
timestamp,
};
{
let mut orders = self.orders.write().unwrap();
orders.insert(order_id.as_str().to_string(), state);
}
Ok(OrderResponse {
order_id,
client_order_id: request.client_order_id,
status,
timestamp,
})
}
async fn cancel_order(&self, order_id: &OrderId) -> MMResult<OrderResponse> {
self.simulate_latency().await;
if self.should_fail() {
return Err(MMError::InvalidMarketState(
"simulated exchange failure".to_string(),
));
}
let mut orders = self.orders.write().unwrap();
let state = orders
.get_mut(order_id.as_str())
.ok_or_else(|| MMError::InvalidMarketState(format!("order not found: {}", order_id)))?;
if state.status.is_terminal() {
return Err(MMError::InvalidMarketState(format!(
"order already terminal: {}",
order_id
)));
}
let filled_qty = state.status.filled_qty();
state.status = OrderStatus::Cancelled { filled_qty };
Ok(OrderResponse {
order_id: order_id.clone(),
client_order_id: state.request.client_order_id.clone(),
status: state.status.clone(),
timestamp: self.current_time(),
})
}
async fn modify_order(
&self,
order_id: &OrderId,
new_price: Option<Decimal>,
new_quantity: Option<Decimal>,
) -> MMResult<OrderResponse> {
self.simulate_latency().await;
let cancelled = self.cancel_order(order_id).await?;
let original_request = {
let orders = self.orders.read().unwrap();
let state = orders.get(order_id.as_str()).ok_or_else(|| {
MMError::InvalidMarketState(format!("order not found: {}", order_id))
})?;
state.request.clone()
};
let new_request = OrderRequest {
price: new_price.or(original_request.price),
quantity: new_quantity.unwrap_or(original_request.quantity),
..original_request
};
let mut response = self.submit_order(new_request).await?;
response.client_order_id = cancelled.client_order_id;
Ok(response)
}
async fn get_order_status(&self, order_id: &OrderId) -> MMResult<OrderResponse> {
self.simulate_latency().await;
let orders = self.orders.read().unwrap();
let state = orders
.get(order_id.as_str())
.ok_or_else(|| MMError::InvalidMarketState(format!("order not found: {}", order_id)))?;
Ok(OrderResponse {
order_id: order_id.clone(),
client_order_id: state.request.client_order_id.clone(),
status: state.status.clone(),
timestamp: self.current_time(),
})
}
async fn get_open_orders(&self, symbol: &str) -> MMResult<Vec<OrderResponse>> {
self.simulate_latency().await;
let orders = self.orders.read().unwrap();
let open_orders: Vec<OrderResponse> = orders
.values()
.filter(|state| state.request.symbol == symbol && state.status.is_open())
.map(|state| OrderResponse {
order_id: state.order_id.clone(),
client_order_id: state.request.client_order_id.clone(),
status: state.status.clone(),
timestamp: self.current_time(),
})
.collect();
Ok(open_orders)
}
async fn cancel_all_orders(&self, symbol: &str) -> MMResult<Vec<OrderResponse>> {
self.simulate_latency().await;
let order_ids: Vec<OrderId> = {
let orders = self.orders.read().unwrap();
orders
.values()
.filter(|state| state.request.symbol == symbol && state.status.is_open())
.map(|state| state.order_id.clone())
.collect()
};
let mut responses = Vec::new();
for order_id in order_ids {
if let Ok(response) = self.cancel_order(&order_id).await {
responses.push(response);
}
}
Ok(responses)
}
async fn get_orderbook(&self, symbol: &str, depth: usize) -> MMResult<OrderBookSnapshot> {
self.simulate_latency().await;
if self.should_fail() {
return Err(MMError::InvalidMarketState(
"simulated exchange failure".to_string(),
));
}
Ok(self.generate_orderbook(symbol, depth))
}
async fn get_balance(&self, asset: &str) -> MMResult<Decimal> {
self.simulate_latency().await;
let balances = self.balances.read().unwrap();
Ok(*balances.get(asset).unwrap_or(&Decimal::ZERO))
}
}
#[async_trait]
impl MarketDataStream for MockExchangeConnector {
async fn subscribe_orderbook(&self, _symbol: &str) -> MMResult<()> {
Ok(())
}
async fn subscribe_trades(&self, _symbol: &str) -> MMResult<()> {
Ok(())
}
async fn next_orderbook_update(&self) -> MMResult<OrderBookSnapshot> {
Ok(self.generate_orderbook("BTC-USD", self.config.default_depth))
}
async fn next_trade(&self) -> MMResult<Fill> {
Ok(Fill {
order_id: OrderId::new("mock-trade"),
trade_id: format!("trade-{}", self.current_time()),
price: self.config.base_price,
quantity: Decimal::from_str_exact("0.1").unwrap(),
side: Side::Buy,
timestamp: self.current_time(),
fee: Decimal::from_str_exact("0.001").unwrap(),
fee_currency: "USD".to_string(),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dec;
#[tokio::test]
async fn test_submit_limit_order() {
let connector = MockExchangeConnector::with_defaults();
let request = OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1));
let response = connector.submit_order(request).await.unwrap();
assert!(response.status.is_open());
assert!(!response.order_id.as_str().is_empty());
}
#[tokio::test]
async fn test_submit_market_order() {
let connector = MockExchangeConnector::with_defaults();
let request = OrderRequest::market_buy("BTC-USD", dec!(0.1));
let response = connector.submit_order(request).await.unwrap();
assert!(response.status.is_terminal());
match response.status {
OrderStatus::Filled { filled_qty, .. } => {
assert_eq!(filled_qty, dec!(0.1));
}
_ => panic!("expected filled status"),
}
}
#[tokio::test]
async fn test_cancel_order() {
let connector = MockExchangeConnector::with_defaults();
let request = OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1));
let response = connector.submit_order(request).await.unwrap();
let cancelled = connector.cancel_order(&response.order_id).await.unwrap();
assert!(cancelled.status.is_terminal());
match cancelled.status {
OrderStatus::Cancelled { filled_qty } => {
assert_eq!(filled_qty, Decimal::ZERO);
}
_ => panic!("expected cancelled status"),
}
}
#[tokio::test]
async fn test_cancel_nonexistent_order() {
let connector = MockExchangeConnector::with_defaults();
let result = connector.cancel_order(&OrderId::new("nonexistent")).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_get_order_status() {
let connector = MockExchangeConnector::with_defaults();
let request = OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1));
let response = connector.submit_order(request).await.unwrap();
let status = connector
.get_order_status(&response.order_id)
.await
.unwrap();
assert!(status.status.is_open());
}
#[tokio::test]
async fn test_get_open_orders() {
let connector = MockExchangeConnector::with_defaults();
connector
.submit_order(OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1)))
.await
.unwrap();
connector
.submit_order(OrderRequest::limit_sell(
"BTC-USD",
dec!(51000.0),
dec!(0.1),
))
.await
.unwrap();
connector
.submit_order(OrderRequest::limit_buy("ETH-USD", dec!(3000.0), dec!(1.0)))
.await
.unwrap();
let btc_orders = connector.get_open_orders("BTC-USD").await.unwrap();
assert_eq!(btc_orders.len(), 2);
let eth_orders = connector.get_open_orders("ETH-USD").await.unwrap();
assert_eq!(eth_orders.len(), 1);
}
#[tokio::test]
async fn test_cancel_all_orders() {
let connector = MockExchangeConnector::with_defaults();
connector
.submit_order(OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1)))
.await
.unwrap();
connector
.submit_order(OrderRequest::limit_sell(
"BTC-USD",
dec!(51000.0),
dec!(0.1),
))
.await
.unwrap();
let cancelled = connector.cancel_all_orders("BTC-USD").await.unwrap();
assert_eq!(cancelled.len(), 2);
let open = connector.get_open_orders("BTC-USD").await.unwrap();
assert!(open.is_empty());
}
#[tokio::test]
async fn test_modify_order() {
let connector = MockExchangeConnector::with_defaults();
let request = OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1));
let response = connector.submit_order(request).await.unwrap();
let modified = connector
.modify_order(&response.order_id, Some(dec!(49000.0)), None)
.await
.unwrap();
assert!(modified.status.is_open());
assert_ne!(modified.order_id, response.order_id); }
#[tokio::test]
async fn test_get_orderbook() {
let connector = MockExchangeConnector::with_defaults();
let orderbook = connector.get_orderbook("BTC-USD", 5).await.unwrap();
assert_eq!(orderbook.symbol, "BTC-USD");
assert_eq!(orderbook.bids.len(), 5);
assert_eq!(orderbook.asks.len(), 5);
assert!(orderbook.best_bid().is_some());
assert!(orderbook.best_ask().is_some());
assert!(orderbook.spread().unwrap() > Decimal::ZERO);
}
#[tokio::test]
async fn test_get_balance() {
let connector = MockExchangeConnector::with_defaults();
let usd_balance = connector.get_balance("USD").await.unwrap();
assert_eq!(usd_balance, Decimal::from(100_000));
let btc_balance = connector.get_balance("BTC").await.unwrap();
assert_eq!(btc_balance, Decimal::from(10));
let unknown = connector.get_balance("UNKNOWN").await.unwrap();
assert_eq!(unknown, Decimal::ZERO);
}
#[tokio::test]
async fn test_set_balance() {
let connector = MockExchangeConnector::with_defaults();
connector.set_balance("ETH", dec!(100.0));
let balance = connector.get_balance("ETH").await.unwrap();
assert_eq!(balance, dec!(100.0));
}
#[tokio::test]
async fn test_time_management() {
let connector = MockExchangeConnector::with_defaults();
let initial_time = connector.current_time();
connector.advance_time(1000);
assert_eq!(connector.current_time(), initial_time + 1000);
connector.set_time(5000);
assert_eq!(connector.current_time(), 5000);
}
#[tokio::test]
async fn test_client_order_id() {
let connector = MockExchangeConnector::with_defaults();
let request = OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1))
.with_client_order_id("my-order-1");
let response = connector.submit_order(request).await.unwrap();
assert_eq!(response.client_order_id, Some("my-order-1".to_string()));
}
#[tokio::test]
async fn test_market_data_stream() {
let connector = MockExchangeConnector::with_defaults();
connector.subscribe_orderbook("BTC-USD").await.unwrap();
connector.subscribe_trades("BTC-USD").await.unwrap();
let orderbook = connector.next_orderbook_update().await.unwrap();
assert!(!orderbook.bids.is_empty());
let trade = connector.next_trade().await.unwrap();
assert!(trade.quantity > Decimal::ZERO);
}
#[test]
fn test_mock_config_builder() {
let config = MockConfig::new()
.with_latency_ms(50)
.with_failure_rate(0.1)
.with_base_price(dec!(60000.0))
.with_spread(dec!(0.002))
.with_balance("ETH", dec!(50.0));
assert_eq!(config.latency_ms, 50);
assert_eq!(config.failure_rate, 0.1);
assert_eq!(config.base_price, dec!(60000.0));
assert_eq!(config.spread, dec!(0.002));
assert_eq!(config.initial_balances.get("ETH"), Some(&dec!(50.0)));
}
#[tokio::test]
async fn test_open_order_count() {
let connector = MockExchangeConnector::with_defaults();
assert_eq!(connector.open_order_count(), 0);
connector
.submit_order(OrderRequest::limit_buy("BTC-USD", dec!(50000.0), dec!(0.1)))
.await
.unwrap();
assert_eq!(connector.open_order_count(), 1);
connector
.submit_order(OrderRequest::market_buy("BTC-USD", dec!(0.1)))
.await
.unwrap();
assert_eq!(connector.open_order_count(), 1);
}
}