use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use async_trait::async_trait;
use chrono::Utc;
use rustrade::{
Bot, BotConfig, Brain, Candle, CircuitBreakerConfig, Decision, Exchange, ExchangeClient,
MarketDataEvent, Order, Position, Result, SessionPnlConfig, SignalType, SizingConfig, Symbol,
};
/// Test brain that carries a single, fixed signal.
///
/// Its `Brain` implementation in this file converts `signal` into the
/// corresponding decision for every event it receives.
struct FixedSignalBrain {
/// Signal that is turned into a decision on each `on_event` call.
signal: SignalType,
}
#[async_trait]
impl Brain for FixedSignalBrain {
    /// Static identifier reported for this brain.
    fn name(&self) -> &str {
        "fixed"
    }

    /// Produces the decision that corresponds to the stored signal.
    ///
    /// The incoming event and the current position are both ignored, so the
    /// answer is the same for every call.
    async fn on_event(&self, _e: &MarketDataEvent, _p: &Position) -> Result<Decision> {
        let decision = match self.signal {
            SignalType::Close => Decision::close(),
            SignalType::Sell => Decision::sell(1.0),
            SignalType::Buy => Decision::buy(1.0),
            SignalType::Hold => Decision::hold(),
        };
        Ok(decision)
    }
}
/// In-memory exchange stub used by the tests in this file.
///
/// It counts order placements and position closes and serves positions from
/// a local map instead of contacting a real venue.
struct CountingExchange {
/// Incremented once per `place_order` call.
placed: Arc<AtomicU64>,
/// Incremented once per `close_position` call.
closed: Arc<AtomicU64>,
/// Positions handed out by `get_position`; symbols without an entry read as `Position::FLAT`.
positions: Mutex<HashMap<Symbol, Position>>,
}
impl CountingExchange {
    /// Creates an exchange with zeroed counters and no stored positions.
    ///
    /// Returns `(exchange, placed, closed)`, where the two counters are the
    /// same atomics the exchange increments, so callers can observe activity
    /// after the exchange itself has been moved elsewhere.
    fn new() -> (Arc<Self>, Arc<AtomicU64>, Arc<AtomicU64>) {
        let placed_counter = Arc::new(AtomicU64::new(0));
        let closed_counter = Arc::new(AtomicU64::new(0));
        let exchange = Self {
            placed: Arc::clone(&placed_counter),
            closed: Arc::clone(&closed_counter),
            positions: Mutex::new(HashMap::new()),
        };
        (Arc::new(exchange), placed_counter, closed_counter)
    }

    /// Stores `pos` as the position for `sym`, replacing any earlier entry.
    ///
    /// Panics if the positions mutex is poisoned.
    fn set_position(&self, sym: Symbol, pos: Position) {
        let mut table = self.positions.lock().unwrap();
        table.insert(sym, pos);
    }
}
#[async_trait]
impl ExchangeClient for CountingExchange {
    /// Static identifier reported for this exchange.
    fn name(&self) -> &str {
        "counting"
    }

    /// Bumps the `placed` counter and returns an id of the form `ord-<count>`.
    async fn place_order(&self, _o: &Order) -> Result<String> {
        let before = self.placed.fetch_add(1, Ordering::SeqCst);
        let n = before + 1;
        Ok(format!("ord-{n}"))
    }

    /// No orders are tracked, so this always reports zero cancellations.
    async fn cancel_all(&self, _s: &Symbol) -> Result<usize> {
        Ok(0)
    }

    /// Bumps the `closed` counter and returns an id of the form `close-<count>`.
    async fn close_position(&self, _s: &Symbol, _p: &Position) -> Result<String> {
        let before = self.closed.fetch_add(1, Ordering::SeqCst);
        let n = before + 1;
        Ok(format!("close-{n}"))
    }

    /// Looks up the stored position for `s`, falling back to `Position::FLAT`.
    ///
    /// Panics if the positions mutex is poisoned.
    async fn get_position(&self, s: &Symbol) -> Result<Position> {
        let table = self.positions.lock().unwrap();
        let current = table.get(s).copied().unwrap_or(Position::FLAT);
        Ok(current)
    }

    /// Always reports a zero balance, whatever the requested currency.
    async fn get_balance(&self, _c: &str) -> Result<f64> {
        Ok(0.0)
    }
}
/// Builds a candle event for `symbol` on the `"test"` exchange.
///
/// Open, high, low and close are all set to `close`, volume is `1.0`, and the
/// candle time is the current wall-clock time in milliseconds.
fn candle_event(symbol: &str, close: f64) -> MarketDataEvent {
    let flat_candle = Candle {
        time: Utc::now().timestamp_millis(),
        open: close,
        high: close,
        low: close,
        close,
        volume: 1.0,
    };
    MarketDataEvent::Candle {
        exchange: Exchange::from("test"),
        symbol: Symbol::from(symbol),
        candle: flat_candle,
    }
}
/// Polls `f` every 20 ms until it returns `true`.
///
/// # Panics
///
/// Panics with `timed out: {msg}` when `f` is still `false` at a check made
/// after `timeout` has elapsed.
async fn wait_until<F>(mut f: F, timeout: Duration, msg: &str)
where
    F: FnMut() -> bool,
{
    let deadline = tokio::time::Instant::now() + timeout;
    let poll_interval = Duration::from_millis(20);
    loop {
        // Condition first: a predicate that is already true never sleeps.
        if f() {
            return;
        }
        if tokio::time::Instant::now() > deadline {
            panic!("timed out: {msg}");
        }
        tokio::time::sleep(poll_interval).await;
    }
}
/// A brain that always says `Buy` should lead to one order being placed on
/// the exchange after a single candle is published.
#[tokio::test(start_paused = true)]
async fn happy_path_buy_places_order() {
    let (exchange, placed, _) = CountingExchange::new();
    let brain = Arc::new(FixedSignalBrain {
        signal: SignalType::Buy,
    });
    let sizing = SizingConfig {
        margin_per_trade: 100.0,
        leverage: 1,
        max_contracts: 10,
    };
    let config = BotConfig::builder()
        .name("happy")
        .symbol("BTCUSDT")
        .without_signal_handler()
        .shutdown_timeout(Duration::from_secs(2))
        .sizing_config(sizing)
        .build()
        .unwrap();
    let bot = Bot::new(config, exchange, vec![brain]).unwrap();

    // Grab the handle and the bus before the bot is moved into its task.
    let handle = bot.handle();
    let bus = bot.market_data_bus().clone();
    let runner = tokio::spawn(async move { bot.run_until_shutdown().await });

    // Let the spawned bot run for a moment before feeding it data.
    tokio::time::sleep(Duration::from_millis(50)).await;
    bus.publish(candle_event("BTCUSDT", 100.0));

    let order_seen = || placed.load(Ordering::SeqCst) == 1;
    wait_until(order_seen, Duration::from_secs(2), "buy order never placed").await;

    handle.shutdown();
    let _ = tokio::time::timeout(Duration::from_secs(3), runner).await;
}
/// After a losing trade outcome is recorded against a session PnL config
/// with `loss_limit: -1.0`, repeated `Buy` decisions must not reach the
/// exchange.
#[tokio::test(start_paused = true)]
async fn session_halt_blocks_buy_order() {
    let (exchange, placed, _closed) = CountingExchange::new();
    let brain = Arc::new(FixedSignalBrain {
        signal: SignalType::Buy,
    });
    let config = BotConfig::builder()
        .name("halt")
        .symbol("BTCUSDT")
        .without_signal_handler()
        .shutdown_timeout(Duration::from_secs(2))
        .session_pnl_config(SessionPnlConfig { loss_limit: -1.0 })
        .build()
        .unwrap();
    let bot = Bot::new(config, exchange, vec![brain]).unwrap();

    // Record the loss before the bot starts running.
    let handle = bot.handle();
    let traded_symbol = Symbol::from("BTCUSDT");
    handle.record_trade_outcome(&traded_symbol, -10.0, 0.0).await;

    let bus = bot.market_data_bus().clone();
    let runner = tokio::spawn(async move { bot.run_until_shutdown().await });
    tokio::time::sleep(Duration::from_millis(50)).await;

    // Several candles, so that every one of them has to be rejected.
    for _ in 0..3 {
        bus.publish(candle_event("BTCUSDT", 100.0));
    }
    tokio::time::sleep(Duration::from_millis(200)).await;

    let orders = placed.load(Ordering::SeqCst);
    assert_eq!(orders, 0, "session halt must block every buy order");

    handle.shutdown();
    let _ = tokio::time::timeout(Duration::from_secs(3), runner).await;
}
/// With a circuit breaker configured with `loss_limit: 1` and one losing
/// outcome already recorded, a `Buy` decision must not place an order.
#[tokio::test(start_paused = true)]
async fn circuit_breaker_blocks_buy_order() {
    let (exchange, placed, _closed) = CountingExchange::new();
    let brain = Arc::new(FixedSignalBrain {
        signal: SignalType::Buy,
    });
    let breaker = CircuitBreakerConfig {
        loss_limit: 1,
        window_secs: 3600,
        cooldown_secs: 3600,
    };
    let config = BotConfig::builder()
        .name("breaker")
        .symbol("BTCUSDT")
        .without_signal_handler()
        .shutdown_timeout(Duration::from_secs(2))
        .circuit_breaker_config(breaker)
        .build()
        .unwrap();
    let bot = Bot::new(config, exchange, vec![brain]).unwrap();

    // Record the loss before the bot starts running.
    let handle = bot.handle();
    let traded_symbol = Symbol::from("BTCUSDT");
    handle.record_trade_outcome(&traded_symbol, -5.0, 0.0).await;

    let bus = bot.market_data_bus().clone();
    let runner = tokio::spawn(async move { bot.run_until_shutdown().await });
    tokio::time::sleep(Duration::from_millis(50)).await;

    bus.publish(candle_event("BTCUSDT", 100.0));
    tokio::time::sleep(Duration::from_millis(150)).await;

    let orders = placed.load(Ordering::SeqCst);
    assert_eq!(orders, 0, "tripped circuit breaker must block buy orders");

    handle.shutdown();
    let _ = tokio::time::timeout(Duration::from_secs(3), runner).await;
}
/// A `Buy` decision with a tiny margin (0.01) against a 50,000 price must
/// not place an order.
// NOTE(review): presumably the sizer rounds this down to zero contracts, as
// the assertion message suggests — confirm against the sizing implementation.
#[tokio::test(start_paused = true)]
async fn sizer_zero_blocks_buy_order() {
    let (exchange, placed, _closed) = CountingExchange::new();
    let brain = Arc::new(FixedSignalBrain {
        signal: SignalType::Buy,
    });
    let sizing = SizingConfig {
        margin_per_trade: 0.01,
        leverage: 1,
        max_contracts: 10,
    };
    let config = BotConfig::builder()
        .name("sizer-zero")
        .symbol("BTCUSDT")
        .without_signal_handler()
        .shutdown_timeout(Duration::from_secs(2))
        .sizing_config(sizing)
        .build()
        .unwrap();
    let bot = Bot::new(config, exchange, vec![brain]).unwrap();

    let handle = bot.handle();
    let bus = bot.market_data_bus().clone();
    let runner = tokio::spawn(async move { bot.run_until_shutdown().await });
    tokio::time::sleep(Duration::from_millis(50)).await;

    bus.publish(candle_event("BTCUSDT", 50_000.0));
    tokio::time::sleep(Duration::from_millis(150)).await;

    let orders = placed.load(Ordering::SeqCst);
    assert_eq!(orders, 0, "sizer returning 0 must block the order");

    handle.shutdown();
    let _ = tokio::time::timeout(Duration::from_secs(3), runner).await;
}
/// With `close_positions_on_shutdown(true)` and an open position on the
/// exchange, shutting the bot down must call `close_position` exactly once.
#[tokio::test(start_paused = true)]
async fn close_positions_on_shutdown_invokes_close() {
    let (exchange, _placed, closed) = CountingExchange::new();
    let open_position = Position {
        qty: 3.0,
        entry_price: Some(100.0),
        unrealised_pnl: 0.0,
    };
    exchange.set_position(Symbol::from("BTCUSDT"), open_position);

    // A `Hold` brain keeps the bot from trading on its own.
    let brain = Arc::new(FixedSignalBrain {
        signal: SignalType::Hold,
    });
    let config = BotConfig::builder()
        .name("close-on-shutdown")
        .symbol("BTCUSDT")
        .without_signal_handler()
        .shutdown_timeout(Duration::from_secs(2))
        .close_positions_on_shutdown(true)
        .build()
        .unwrap();
    let bot = Bot::new(config, exchange, vec![brain]).unwrap();

    let handle = bot.handle();
    let runner = tokio::spawn(async move { bot.run_until_shutdown().await });
    tokio::time::sleep(Duration::from_millis(100)).await;

    handle.shutdown();
    let _ = tokio::time::timeout(Duration::from_secs(3), runner).await;

    // Checked only after the bot task has been given time to finish.
    let closes = closed.load(Ordering::SeqCst);
    assert_eq!(
        closes, 1,
        "exchange.close_position should fire once for the open position"
    );
}
/// A `Close` decision while the exchange reports an open position must
/// result in one order being placed.
///
/// Only the order count is checked here; the order's contents are not
/// inspected by this test.
#[tokio::test(start_paused = true)]
async fn close_decision_emits_reduce_only_order_against_position() {
    let (exchange, placed, _closed) = CountingExchange::new();
    let open_position = Position {
        qty: 5.0,
        entry_price: Some(100.0),
        unrealised_pnl: 0.0,
    };
    exchange.set_position(Symbol::from("BTCUSDT"), open_position);

    let brain = Arc::new(FixedSignalBrain {
        signal: SignalType::Close,
    });
    let config = BotConfig::builder()
        .name("close-decision")
        .symbol("BTCUSDT")
        .without_signal_handler()
        .shutdown_timeout(Duration::from_secs(2))
        .build()
        .unwrap();
    let bot = Bot::new(config, exchange, vec![brain]).unwrap();

    let handle = bot.handle();
    let bus = bot.market_data_bus().clone();
    let runner = tokio::spawn(async move { bot.run_until_shutdown().await });
    tokio::time::sleep(Duration::from_millis(50)).await;

    bus.publish(candle_event("BTCUSDT", 100.0));
    let order_seen = || placed.load(Ordering::SeqCst) == 1;
    wait_until(
        order_seen,
        Duration::from_secs(2),
        "close-decision order never placed",
    )
    .await;

    handle.shutdown();
    let _ = tokio::time::timeout(Duration::from_secs(3), runner).await;
}
/// A `Close` decision while the exchange has no stored position (so
/// `get_position` yields `Position::FLAT`) must not place any order.
#[tokio::test(start_paused = true)]
async fn close_decision_on_flat_position_is_silent_noop() {
    let (exchange, placed, _closed) = CountingExchange::new();
    let brain = Arc::new(FixedSignalBrain {
        signal: SignalType::Close,
    });
    let config = BotConfig::builder()
        .name("close-when-flat")
        .symbol("BTCUSDT")
        .without_signal_handler()
        .shutdown_timeout(Duration::from_secs(2))
        .build()
        .unwrap();
    let bot = Bot::new(config, exchange, vec![brain]).unwrap();

    let bus = bot.market_data_bus().clone();
    let handle = bot.handle();
    let runner = tokio::spawn(async move { bot.run_until_shutdown().await });
    tokio::time::sleep(Duration::from_millis(50)).await;

    bus.publish(candle_event("BTCUSDT", 100.0));
    tokio::time::sleep(Duration::from_millis(150)).await;

    let orders = placed.load(Ordering::SeqCst);
    assert_eq!(
        orders, 0,
        "Close against a flat position must not place an order"
    );

    handle.shutdown();
    let _ = tokio::time::timeout(Duration::from_secs(3), runner).await;
}