use std::sync::Arc;
use std::time::Duration;
use rustrade_core::{
Brain, CandleSource, Capability, Error, ExchangeClient, FillSource, MarketDataBus,
MarketSource, MetricsSink, NoopSink, Position, Result, SignalBus, StateStore, Symbol,
};
use rustrade_risk::{CircuitBreakerConfig, SessionPnlConfig, SizingConfig};
use rustrade_supervisor::{Supervisor, SupervisorConfig};
use tokio_util::sync::CancellationToken;
use crate::execution::{ExecutionContext, ExecutionService};
use crate::handle::BotHandle;
use crate::order_tracker::{OrderReaperService, OrderTracker};
use crate::risk_state::{
PositionCache, RiskPersister, RiskStateMap, build_position_cache, build_risk_state,
};
use crate::services::{CandlePollerService, FillRoutingService, MarketFeedService};
/// Slot count (1024) used for the market-data broadcast bus when the builder
/// is not given an explicit `market_bus_capacity`.
const DEFAULT_MARKET_BUS_CAPACITY: usize = 1 << 10;
/// Slot count (256) used for the signal broadcast bus when the builder is not
/// given an explicit `signal_bus_capacity`.
const DEFAULT_SIGNAL_BUS_CAPACITY: usize = 1 << 8;
/// Risk-management settings for a bot.
///
/// `session_pnl` and `circuit_breaker` seed the per-symbol risk-state map;
/// `sizing` is shared by every execution service. Each part falls back to its
/// own `Default` implementation.
#[derive(Clone, Debug, Default)]
pub struct RiskConfig {
    /// Session profit-and-loss limit settings.
    pub session_pnl: SessionPnlConfig,
    /// Circuit-breaker settings.
    pub circuit_breaker: CircuitBreakerConfig,
    /// Position-sizing settings.
    pub sizing: SizingConfig,
}
/// Validated, immutable configuration for a [`Bot`].
///
/// Obtain one through [`BotConfig::builder`]; `build()` enforces the
/// invariants documented on each field.
#[derive(Clone, Debug)]
pub struct BotConfig {
    /// Non-blank bot name; logged at start/exit and handed to the risk
    /// persister when a state store is configured.
    pub name: String,
    /// Symbols the bot trades (at least one); the position cache and the
    /// risk-state map are pre-seeded per symbol.
    pub symbols: Vec<Symbol>,
    /// Non-zero drain deadline given to the supervisor (builder default: 30 s).
    pub shutdown_timeout: Duration,
    /// Whether the supervisor installs its own signal handler (builder default: `true`).
    pub install_signal_handler: bool,
    /// Slot count of the market-data broadcast bus (non-zero).
    pub market_bus_capacity: usize,
    /// Slot count of the signal broadcast bus (non-zero).
    pub signal_bus_capacity: usize,
    /// When `true`, every non-flat cached position is closed (best-effort)
    /// after the supervisor stops (builder default: `false`).
    pub close_positions_on_shutdown: bool,
    /// Risk-management settings.
    pub risk: RiskConfig,
}
impl BotConfig {
pub fn builder() -> BotConfigBuilder {
BotConfigBuilder::default()
}
}
/// Incremental builder for [`BotConfig`].
///
/// Every `Option` field that is left unset is replaced by its default inside
/// `build()`; `symbols` accumulates across `symbol`/`symbols` calls.
#[derive(Clone, Debug, Default)]
pub struct BotConfigBuilder {
    /// Required bot name.
    name: Option<String>,
    /// Symbols collected so far, in insertion order.
    symbols: Vec<Symbol>,
    /// Optional shutdown deadline override.
    shutdown_timeout: Option<Duration>,
    /// Optional signal-handler toggle (only ever set to `Some(false)`).
    install_signal_handler: Option<bool>,
    /// Optional market-data bus capacity override.
    market_bus_capacity: Option<usize>,
    /// Optional signal bus capacity override.
    signal_bus_capacity: Option<usize>,
    /// Optional close-on-shutdown toggle.
    close_positions_on_shutdown: Option<bool>,
    /// Risk settings, starting from `RiskConfig::default()`.
    risk: RiskConfig,
}
impl BotConfigBuilder {
    /// Drain deadline used when [`BotConfigBuilder::shutdown_timeout`] is never called.
    const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);

    /// Sets the bot name. Required; it must contain a non-whitespace character.
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Appends a single symbol to trade.
    pub fn symbol(mut self, sym: impl Into<Symbol>) -> Self {
        self.symbols.push(sym.into());
        self
    }

    /// Appends several symbols, keeping their iteration order.
    pub fn symbols<I, S>(mut self, syms: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<Symbol>,
    {
        self.symbols.extend(syms.into_iter().map(Into::into));
        self
    }

    /// Overrides the supervisor drain deadline (default 30 s; must be non-zero).
    pub fn shutdown_timeout(mut self, dur: Duration) -> Self {
        self.shutdown_timeout = Some(dur);
        self
    }

    /// Stops the supervisor from installing its own signal handler
    /// (it is installed by default).
    pub fn without_signal_handler(mut self) -> Self {
        self.install_signal_handler = Some(false);
        self
    }

    /// Overrides the market-data bus capacity (default 1024; must be non-zero).
    pub fn market_bus_capacity(mut self, cap: usize) -> Self {
        self.market_bus_capacity = Some(cap);
        self
    }

    /// Overrides the signal bus capacity (default 256; must be non-zero).
    pub fn signal_bus_capacity(mut self, cap: usize) -> Self {
        self.signal_bus_capacity = Some(cap);
        self
    }

    /// Chooses whether open positions are closed after shutdown (default `false`).
    pub fn close_positions_on_shutdown(mut self, b: bool) -> Self {
        self.close_positions_on_shutdown = Some(b);
        self
    }

    /// Replaces the session-PnL risk settings.
    pub fn session_pnl_config(mut self, cfg: SessionPnlConfig) -> Self {
        self.risk.session_pnl = cfg;
        self
    }

    /// Replaces the circuit-breaker risk settings.
    pub fn circuit_breaker_config(mut self, cfg: CircuitBreakerConfig) -> Self {
        self.risk.circuit_breaker = cfg;
        self
    }

    /// Replaces the position-sizing settings.
    pub fn sizing_config(mut self, cfg: SizingConfig) -> Self {
        self.risk.sizing = cfg;
        self
    }

    /// Validates the accumulated settings and produces a [`BotConfig`].
    ///
    /// Duplicate symbols are collapsed, keeping the first occurrence and the
    /// original order. Without this, per-symbol work (position prefetch, order
    /// reaping, risk seeding) would run more than once for the same instrument.
    ///
    /// # Errors
    /// Returns `Error::config` when:
    /// - the name is missing or blank;
    /// - no symbol was given;
    /// - either bus capacity is 0;
    /// - the shutdown timeout is zero;
    /// - `session_pnl.loss_limit` is NaN;
    /// - `sizing.margin_per_trade` is negative or not finite.
    pub fn build(self) -> Result<BotConfig> {
        let name = self
            .name
            .filter(|n| !n.trim().is_empty())
            .ok_or_else(|| Error::config("BotConfig.name is required and must not be empty"))?;

        if self.symbols.is_empty() {
            return Err(Error::config(
                "BotConfig.symbols must contain at least one Symbol — \
                 the position cache and risk-state map are pre-seeded per symbol",
            ));
        }

        let market_bus_capacity = self
            .market_bus_capacity
            .unwrap_or(DEFAULT_MARKET_BUS_CAPACITY);
        if market_bus_capacity == 0 {
            return Err(Error::config(
                "BotConfig.market_bus_capacity must be > 0 (broadcast channel cannot have 0 slots)",
            ));
        }

        let signal_bus_capacity = self
            .signal_bus_capacity
            .unwrap_or(DEFAULT_SIGNAL_BUS_CAPACITY);
        if signal_bus_capacity == 0 {
            return Err(Error::config(
                "BotConfig.signal_bus_capacity must be > 0 (broadcast channel cannot have 0 slots)",
            ));
        }

        let shutdown_timeout = self
            .shutdown_timeout
            .unwrap_or(Self::DEFAULT_SHUTDOWN_TIMEOUT);
        if shutdown_timeout.is_zero() {
            return Err(Error::config(
                "BotConfig.shutdown_timeout must be > 0 — drain needs a non-zero deadline",
            ));
        }

        if self.risk.session_pnl.loss_limit.is_nan() {
            return Err(Error::config(
                "BotConfig.risk.session_pnl.loss_limit must not be NaN",
            ));
        }

        let margin = self.risk.sizing.margin_per_trade;
        if !margin.is_finite() || margin < 0.0 {
            return Err(Error::config(
                "BotConfig.risk.sizing.margin_per_trade must be a finite non-negative number",
            ));
        }

        // Order-preserving dedup. A linear `contains` is fine: symbol lists are
        // short, and it only needs `PartialEq` on `Symbol`.
        let mut symbols: Vec<Symbol> = Vec::with_capacity(self.symbols.len());
        for sym in self.symbols {
            if !symbols.contains(&sym) {
                symbols.push(sym);
            }
        }

        Ok(BotConfig {
            name,
            symbols,
            shutdown_timeout,
            install_signal_handler: self.install_signal_handler.unwrap_or(true),
            market_bus_capacity,
            signal_bus_capacity,
            close_positions_on_shutdown: self.close_positions_on_shutdown.unwrap_or(false),
            risk: self.risk,
        })
    }
}
/// A fully wired trading bot: the supervisor, the shared buses and caches, and
/// the optional data sources and services to spawn in
/// [`Bot::run_until_shutdown`].
///
/// Field order is kept as-is (it fixes drop order).
pub struct Bot {
    /// Validated configuration.
    config: BotConfig,
    /// Supervisor that owns every spawned service.
    supervisor: Arc<Supervisor>,
    /// Exchange adapter used for orders, positions and capability checks.
    exchange: Arc<dyn ExchangeClient>,
    /// Strategy brains; one execution service is spawned per brain.
    brains: Arc<Vec<Arc<dyn Brain>>>,
    /// Market-data broadcast bus.
    market_bus: MarketDataBus,
    /// Signal broadcast bus.
    signal_bus: SignalBus,
    /// Per-symbol cached positions.
    positions: PositionCache,
    /// Per-symbol risk state.
    risk: RiskStateMap,
    /// Metrics sink (defaults to `NoopSink`).
    metrics: Arc<dyn MetricsSink>,
    /// Optional persistence backend for risk state.
    state_store: Option<Arc<dyn StateStore>>,
    /// Slot through which the handle observes the risk persister once it exists.
    persister_slot: crate::handle::PersisterSlot,
    /// Control handle shared with callers.
    handle: BotHandle,
    /// Optional external token that triggers shutdown when cancelled.
    external_cancel: Option<CancellationToken>,
    /// Optional live market feed.
    market_source: Option<Arc<dyn MarketSource>>,
    /// Optional fill stream.
    fill_source: Option<Arc<dyn FillSource>>,
    /// Registered candle pollers.
    candle_pollers: Vec<CandlePollerSpec>,
    /// Resting-order tracker.
    order_tracker: OrderTracker,
    /// Order-tracking request, if any.
    order_tracking: Option<OrderTrackingSpec>,
}
/// Parameters captured by `Bot::with_order_tracking`. They are forwarded to
/// `OrderReaperService` when the exchange supports order tracking.
struct OrderTrackingSpec {
    /// Age limit handed to the reaper.
    ttl: Duration,
    /// How often the reaper runs.
    poll_cadence: Duration,
}
/// One `Bot::with_candle_poller` registration. Each field is passed verbatim
/// to `CandlePollerService::new`.
struct CandlePollerSpec {
    /// Where candles are fetched from.
    source: Arc<dyn CandleSource>,
    /// Symbol to poll.
    symbol: Symbol,
    /// Candle interval (presumably the bar width — confirm in `CandlePollerService`).
    interval: Duration,
    /// How often to poll.
    poll_cadence: Duration,
    /// Per-request limit (presumably a max candle count — confirm).
    limit: usize,
}
impl Bot {
    /// Assembles a bot from a validated [`BotConfig`], an exchange adapter and
    /// at least one brain.
    ///
    /// Builds the following; nothing is spawned until [`Bot::run_until_shutdown`]:
    /// - the supervisor, honouring `shutdown_timeout` and `install_signal_handler`;
    /// - both broadcast buses;
    /// - the per-symbol position cache and risk-state map;
    /// - the [`BotHandle`] that shares them.
    ///
    /// # Errors
    /// Returns `Error::config` when `brains` is empty.
    pub fn new(
        config: BotConfig,
        exchange: Arc<dyn ExchangeClient>,
        brains: Vec<Arc<dyn Brain>>,
    ) -> Result<Self> {
        if brains.is_empty() {
            return Err(Error::config(
                "Bot::new requires at least one Brain — empty brain list",
            ));
        }
        let supervisor = Arc::new(Supervisor::new(
            SupervisorConfig::default()
                .with_shutdown_timeout(config.shutdown_timeout)
                .with_default_backoff(Default::default())
                .pipe(|c| {
                    if config.install_signal_handler {
                        c
                    } else {
                        c.without_signal_handler()
                    }
                }),
        ));
        let market_bus = MarketDataBus::with_capacity(config.market_bus_capacity);
        let signal_bus = SignalBus::with_capacity(config.signal_bus_capacity);
        let positions = build_position_cache(&config.symbols);
        let risk = build_risk_state(
            &config.symbols,
            &config.risk.session_pnl,
            &config.risk.circuit_breaker,
        );
        let brains = Arc::new(brains);
        let persister_slot: crate::handle::PersisterSlot = Arc::new(std::sync::OnceLock::new());
        let order_tracker = OrderTracker::new();
        // The handle shares clones of the same supervisor, brains, risk map,
        // position cache, signal bus, persister slot and order tracker.
        let handle = BotHandle::new(
            supervisor.clone(),
            brains.clone(),
            risk.clone(),
            positions.clone(),
            signal_bus.clone(),
            persister_slot.clone(),
            order_tracker.clone(),
        );
        Ok(Self {
            config,
            supervisor,
            exchange,
            brains,
            market_bus,
            signal_bus,
            positions,
            risk,
            metrics: Arc::new(NoopSink),
            state_store: None,
            persister_slot,
            order_tracker,
            handle,
            external_cancel: None,
            market_source: None,
            fill_source: None,
            candle_pollers: Vec::new(),
            order_tracking: None,
        })
    }

    /// Replaces the metrics sink (defaults to `NoopSink`).
    pub fn with_metrics(mut self, sink: Arc<dyn MetricsSink>) -> Self {
        self.metrics = sink;
        self
    }

    /// Enables risk-state persistence. Stored state is restored into the risk
    /// map at startup, and all risk state is persisted again on exit.
    pub fn with_state_store(mut self, store: Arc<dyn StateStore>) -> Self {
        self.state_store = Some(store);
        self
    }

    /// Requests resting-order tracking with the given age limit and reaper
    /// cadence.
    ///
    /// Ignored, with a startup warning, if the exchange does not report
    /// `Capability::OrderTracking`.
    pub fn with_order_tracking(mut self, ttl: Duration, poll_cadence: Duration) -> Self {
        self.order_tracking = Some(OrderTrackingSpec { ttl, poll_cadence });
        self
    }

    /// Registers a candle poller for `symbol`.
    ///
    /// At startup one `CandlePollerService` is spawned per registration; each
    /// receives the market-data bus and the metrics sink.
    pub fn with_candle_poller(
        mut self,
        source: Arc<dyn CandleSource>,
        symbol: impl Into<Symbol>,
        interval: Duration,
        poll_cadence: Duration,
        limit: usize,
    ) -> Self {
        self.candle_pollers.push(CandlePollerSpec {
            source,
            symbol: symbol.into(),
            interval,
            poll_cadence,
            limit,
        });
        self
    }

    /// Lets an external token trigger supervisor shutdown when cancelled.
    pub fn with_external_cancel(mut self, token: CancellationToken) -> Self {
        self.external_cancel = Some(token);
        self
    }

    /// Sets a live market source; a `MarketFeedService` is spawned for it.
    pub fn with_market_source(mut self, source: Arc<dyn MarketSource>) -> Self {
        self.market_source = Some(source);
        self
    }

    /// Sets a fill source; a `FillRoutingService` is spawned for it.
    pub fn with_fill_source(mut self, source: Arc<dyn FillSource>) -> Self {
        self.fill_source = Some(source);
        self
    }

    /// Returns a clone of the bot's control handle.
    pub fn handle(&self) -> BotHandle {
        self.handle.clone()
    }

    /// Returns the validated configuration.
    pub fn config(&self) -> &BotConfig {
        &self.config
    }

    /// Returns the market-data bus.
    pub fn market_data_bus(&self) -> &MarketDataBus {
        &self.market_bus
    }

    /// Returns the signal bus.
    pub fn signal_bus(&self) -> &SignalBus {
        &self.signal_bus
    }

    /// Runs the bot until the supervisor shuts down, then performs teardown.
    ///
    /// Startup, in order:
    /// 1. Prefetch positions.
    /// 2. Restore persisted risk state, if a store is configured.
    /// 3. Spawn one execution service per brain.
    /// 4. Spawn the optional order reaper, market feed, fill router and candle pollers.
    /// 5. Bridge the external cancel token, if any.
    ///
    /// After the supervisor returns:
    /// 1. Optionally close open positions.
    /// 2. Persist risk state.
    /// 3. Log final brain health.
    ///
    /// # Errors
    /// Returns the supervisor's own result unchanged.
    pub async fn run_until_shutdown(self) -> anyhow::Result<()> {
        tracing::info!(
            bot = %self.config.name,
            brains = self.brains.len(),
            symbols = self.config.symbols.len(),
            exchange = %self.exchange.name(),
            "rustrade Bot starting"
        );
        self.prefetch_positions().await;

        let persister = self
            .state_store
            .clone()
            .map(|store| RiskPersister::new(store, self.config.name.clone()));
        if let Some(p) = &persister {
            p.restore_into(&self.risk).await;
            // OnceLock: a second `set` fails; that outcome is intentionally ignored.
            let _ = self.persister_slot.set(p.clone());
        }

        // Tracking is active only when it was requested AND the adapter
        // advertises the capability; otherwise warn and continue untracked.
        let tracking_spec = match &self.order_tracking {
            Some(spec) if self.exchange.supports(Capability::OrderTracking) => Some(spec),
            Some(_) => {
                tracing::warn!(
                    exchange = %self.exchange.name(),
                    "order tracking requested but adapter lacks Capability::OrderTracking — \
                     resting orders will NOT be tracked or aged out"
                );
                None
            }
            None => None,
        };
        let order_tracking_active = tracking_spec.is_some();

        let ctx = ExecutionContext {
            exchange: self.exchange.clone(),
            bus: self.market_bus.clone(),
            signals: self.signal_bus.clone(),
            positions: self.positions.clone(),
            risk: self.risk.clone(),
            sizing: Arc::new(self.config.risk.sizing.clone()),
            order_tracker: order_tracking_active.then(|| self.order_tracker.clone()),
        };
        for brain in self.brains.iter() {
            let svc = ExecutionService::new(brain.clone(), ctx.clone());
            self.supervisor.spawn_service(Box::new(svc));
        }

        if let Some(spec) = tracking_spec {
            self.supervisor
                .spawn_service(Box::new(OrderReaperService::new(
                    self.exchange.clone(),
                    self.order_tracker.clone(),
                    self.config.symbols.clone(),
                    spec.ttl,
                    spec.poll_cadence,
                    self.metrics.clone(),
                )));
        }

        if let Some(source) = self.market_source.clone() {
            self.supervisor
                .spawn_service(Box::new(MarketFeedService::new(source)));
        }

        if let Some(source) = self.fill_source.clone() {
            self.supervisor
                .spawn_service(Box::new(FillRoutingService::new(
                    source,
                    self.brains.clone(),
                    self.exchange.clone(),
                    self.positions.clone(),
                    self.risk.clone(),
                    self.metrics.clone(),
                    persister.clone(),
                )));
        }

        for spec in &self.candle_pollers {
            self.supervisor
                .spawn_service(Box::new(CandlePollerService::new(
                    spec.source.clone(),
                    spec.symbol.clone(),
                    spec.interval,
                    spec.poll_cadence,
                    spec.limit,
                    self.market_bus.clone(),
                    self.metrics.clone(),
                )));
        }

        // Bridge the external token into a supervisor shutdown. Keep the
        // JoinHandle so the watcher can be aborted once the supervisor returns.
        // Otherwise the detached task would outlive the bot: it would wait on the
        // token forever, or poke an already-stopped supervisor if the token
        // fires later.
        let external_watcher = self.external_cancel.clone().map(|external| {
            let supervisor = self.supervisor.clone();
            tokio::spawn(async move {
                external.cancelled().await;
                tracing::info!("external cancellation received; triggering bot shutdown");
                supervisor.trigger_shutdown();
            })
        });

        let run_result = self.supervisor.run_until_shutdown().await;

        if let Some(watcher) = external_watcher {
            watcher.abort();
        }

        if self.config.close_positions_on_shutdown {
            self.close_open_positions().await;
        }
        if let Some(p) = &persister {
            p.persist_all(&self.risk).await;
        }
        for brain in self.brains.iter() {
            let health = brain.health().await;
            tracing::info!(
                brain = %brain.name(),
                healthy = health.healthy,
                events = health.events_processed,
                non_hold = health.non_hold_decisions,
                "final brain health"
            );
        }
        tracing::info!(bot = %self.config.name, "rustrade Bot exited");
        run_result
    }

    /// Seeds the position cache from the exchange, one symbol at a time.
    ///
    /// A failed fetch is logged and leaves that symbol as `build_position_cache`
    /// seeded it (the warning assumes FLAT).
    async fn prefetch_positions(&self) {
        for symbol in &self.config.symbols {
            match self.exchange.get_position(symbol).await {
                Ok(pos) => {
                    self.positions.write().await.insert(symbol.clone(), pos);
                    tracing::debug!(
                        symbol = %symbol,
                        qty = pos.qty,
                        "prefetched position from exchange"
                    );
                }
                Err(e) => {
                    tracing::warn!(
                        symbol = %symbol,
                        error = %e,
                        "failed to prefetch position; cache defaults to FLAT"
                    );
                }
            }
        }
    }

    /// Best-effort close of every non-flat cached position.
    ///
    /// The read lock is released (snapshot taken) before any exchange call.
    /// Failures are logged, never propagated.
    async fn close_open_positions(&self) {
        let snapshot: Vec<(Symbol, Position)> = {
            let map = self.positions.read().await;
            map.iter()
                .filter(|(_, p)| !p.is_flat())
                .map(|(s, p)| (s.clone(), *p))
                .collect()
        };
        if snapshot.is_empty() {
            tracing::info!("close_positions_on_shutdown: no open positions");
            return;
        }
        for (symbol, position) in snapshot {
            match self.exchange.close_position(&symbol, &position).await {
                Ok(order_id) => tracing::info!(
                    symbol = %symbol,
                    qty = position.qty,
                    order_id = %order_id,
                    "close_positions_on_shutdown: closed"
                ),
                Err(e) => tracing::error!(
                    symbol = %symbol,
                    qty = position.qty,
                    error = %e,
                    "close_positions_on_shutdown: failed (best-effort)"
                ),
            }
        }
    }
}
/// Minimal combinator: `value.pipe(f)` evaluates to `f(value)`. It lets a
/// conditional transformation sit inside a method chain.
trait Pipe: Sized {
    fn pipe<F>(self, f: F) -> Self
    where
        F: FnOnce(Self) -> Self,
    {
        f(self)
    }
}

/// Blanket implementation: every sized type gets `.pipe()`.
impl<T> Pipe for T {}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use rustrade_core::{Fill, MarketDataEvent, Order, Position};

    /// Brain that never trades: every event yields `Decision::hold()`.
    struct NoopBrain;

    #[async_trait]
    impl Brain for NoopBrain {
        fn name(&self) -> &str {
            "noop"
        }

        async fn on_event(
            &self,
            _event: &MarketDataEvent,
            _position: &Position,
        ) -> Result<rustrade_core::Decision> {
            Ok(rustrade_core::Decision::hold())
        }
    }

    /// Exchange stub: every call succeeds with a fixed, inert answer.
    struct NoopExchange;

    #[async_trait]
    impl ExchangeClient for NoopExchange {
        fn name(&self) -> &str {
            "noop"
        }

        async fn place_order(&self, _order: &Order) -> Result<String> {
            Ok("noop-1".into())
        }

        async fn cancel_all(&self, _symbol: &Symbol) -> Result<usize> {
            Ok(0)
        }

        async fn close_position(&self, _symbol: &Symbol, _position: &Position) -> Result<String> {
            Ok("noop-close".into())
        }

        async fn get_position(&self, _symbol: &Symbol) -> Result<Position> {
            Ok(Position::FLAT)
        }

        async fn get_balance(&self, _currency: &str) -> Result<f64> {
            Ok(0.0)
        }
    }

    /// Builder pre-filled with a name and the single symbol "BTCUSDT".
    fn named_with_btc(name: &str) -> BotConfigBuilder {
        BotConfig::builder().name(name).symbol("BTCUSDT")
    }

    /// Valid config used by the `Bot` construction tests.
    fn cfg() -> BotConfig {
        named_with_btc("test")
            .without_signal_handler()
            .build()
            .unwrap()
    }

    /// True when `err` is the configuration-error variant.
    fn is_config_err(err: &Error) -> bool {
        matches!(err, Error::Config(_))
    }

    #[test]
    fn builder_requires_name() {
        let err = BotConfig::builder().build().unwrap_err();
        assert!(is_config_err(&err), "got {err:?}");
    }

    #[test]
    fn builder_rejects_blank_name() {
        let err = BotConfig::builder().name(" ").build().unwrap_err();
        assert!(is_config_err(&err), "got {err:?}");
    }

    #[test]
    fn builder_rejects_zero_market_bus_capacity() {
        let err = named_with_btc("x")
            .market_bus_capacity(0)
            .build()
            .unwrap_err();
        assert!(is_config_err(&err));
    }

    #[test]
    fn builder_rejects_zero_signal_bus_capacity() {
        let err = named_with_btc("x")
            .signal_bus_capacity(0)
            .build()
            .unwrap_err();
        assert!(is_config_err(&err));
    }

    #[test]
    fn builder_rejects_empty_symbol_list() {
        let err = BotConfig::builder().name("x").build().unwrap_err();
        assert!(is_config_err(&err));
    }

    #[test]
    fn builder_rejects_zero_shutdown_timeout() {
        let err = named_with_btc("x")
            .shutdown_timeout(Duration::ZERO)
            .build()
            .unwrap_err();
        assert!(is_config_err(&err));
    }

    #[test]
    fn builder_rejects_nan_loss_limit() {
        let nan_limit = SessionPnlConfig {
            loss_limit: f64::NAN,
        };
        let err = named_with_btc("x")
            .session_pnl_config(nan_limit)
            .build()
            .unwrap_err();
        assert!(is_config_err(&err));
    }

    #[test]
    fn builder_rejects_non_finite_margin() {
        let infinite_margin = SizingConfig {
            margin_per_trade: f64::INFINITY,
            leverage: 1,
            max_contracts: 1,
        };
        let err = named_with_btc("x")
            .sizing_config(infinite_margin)
            .build()
            .unwrap_err();
        assert!(is_config_err(&err));
    }

    #[test]
    fn builder_accumulates_symbols() {
        let built = BotConfig::builder()
            .name("x")
            .symbol("A")
            .symbols(["B", "C"])
            .build()
            .unwrap();
        let syms = &built.symbols;
        assert_eq!(syms.len(), 3);
        assert_eq!(syms[0], Symbol::new("A"));
        assert_eq!(syms[2], Symbol::new("C"));
    }

    #[test]
    fn builder_accepts_risk_overrides() {
        let pnl = SessionPnlConfig { loss_limit: -123.0 };
        let sizing = SizingConfig {
            margin_per_trade: 250.0,
            leverage: 10,
            max_contracts: 5,
        };
        let built = named_with_btc("x")
            .session_pnl_config(pnl)
            .sizing_config(sizing)
            .build()
            .unwrap();
        assert_eq!(built.risk.session_pnl.loss_limit, -123.0);
        assert_eq!(built.risk.sizing.leverage, 10);
    }

    #[test]
    fn builder_has_separate_default_bus_capacities() {
        let built = named_with_btc("x").build().unwrap();
        assert_eq!(built.market_bus_capacity, 1024);
        assert_eq!(built.signal_bus_capacity, 256);
    }

    #[tokio::test]
    async fn bot_requires_at_least_one_brain() {
        let outcome = Bot::new(cfg(), Arc::new(NoopExchange), vec![]);
        if !matches!(&outcome, Err(Error::Config(_))) {
            panic!(
                "expected Error::Config for empty brain list, got {:?}",
                outcome.map(|_| "Ok(Bot)").map_err(|e| format!("Err({e})"))
            );
        }
    }

    #[tokio::test]
    async fn bot_constructs_and_exposes_handle() {
        let bot = Bot::new(cfg(), Arc::new(NoopExchange), vec![Arc::new(NoopBrain)]).unwrap();
        let first = bot.handle();
        assert!(!first.is_shutting_down());
        assert_eq!(bot.config().name, "test");
        let second = first.clone();
        assert!(!second.is_shutting_down());
    }

    // Compile-only reference that keeps the `Fill` import in use.
    #[allow(dead_code)]
    fn _noop_fill_compiles(_: &Fill) {}
}