1use std::sync::Arc;
8use std::time::Duration;
9
10use rustrade_core::{
11 Brain, CandleSource, Capability, Error, ExchangeClient, FillSource, MarketDataBus,
12 MarketSource, MetricsSink, NoopSink, Position, Result, SignalBus, StateStore, Symbol,
13};
14use rustrade_risk::{CircuitBreakerConfig, SessionPnlConfig, SizingConfig};
15use rustrade_supervisor::{Supervisor, SupervisorConfig};
16use tokio_util::sync::CancellationToken;
17
18use crate::execution::{ExecutionContext, ExecutionService};
19use crate::handle::BotHandle;
20use crate::order_tracker::{OrderReaperService, OrderTracker};
21use crate::risk_state::{
22 PositionCache, RiskPersister, RiskStateMap, build_position_cache, build_risk_state,
23};
24use crate::services::{CandlePollerService, FillRoutingService, MarketFeedService};
25
/// Slot count (1024) for the market-data broadcast bus when the builder's
/// `market_bus_capacity` is left unset.
const DEFAULT_MARKET_BUS_CAPACITY: usize = 1 << 10;
/// Slot count (256) for the signal broadcast bus when the builder's
/// `signal_bus_capacity` is left unset.
const DEFAULT_SIGNAL_BUS_CAPACITY: usize = 1 << 8;
28
/// Risk-management settings carried on [`BotConfig::risk`] and set through
/// the `*_config` methods of [`BotConfigBuilder`].
///
/// `Default` yields the default of each sub-config.
#[derive(Debug, Clone, Default)]
pub struct RiskConfig {
    /// Session P&L settings; handed to `build_risk_state` together with the
    /// symbol list in `Bot::new`. `BotConfigBuilder::build` rejects a NaN
    /// `loss_limit`.
    pub session_pnl: SessionPnlConfig,
    /// Circuit-breaker settings; also handed to `build_risk_state` in `Bot::new`.
    pub circuit_breaker: CircuitBreakerConfig,
    /// Position-sizing settings; wrapped in an `Arc` and placed in the
    /// `ExecutionContext` shared by every execution service.
    /// `BotConfigBuilder::build` requires a finite, non-negative
    /// `margin_per_trade`.
    pub sizing: SizingConfig,
}
42
/// Top-level configuration for a [`Bot`].
///
/// Construct it with [`BotConfig::builder`]. `BotConfigBuilder::build` fills in
/// defaults and rejects invalid values, so a config obtained that way has
/// passed those checks. The fields are public, so a caller that mutates them
/// afterwards bypasses validation.
#[derive(Debug, Clone)]
pub struct BotConfig {
    /// Bot identifier. It appears in start/exit log lines and is passed to
    /// `RiskPersister::new` alongside the state store (presumably as a
    /// namespace key — confirm in `risk_state`). Non-blank when built via the
    /// builder.
    pub name: String,
    /// Traded symbols. They seed the position cache and risk-state map, drive
    /// startup position prefetch, and are handed to the order reaper.
    /// Non-empty when built via the builder.
    pub symbols: Vec<Symbol>,
    /// Deadline forwarded to `SupervisorConfig::with_shutdown_timeout`.
    /// The builder defaults it to 30 s and rejects zero.
    pub shutdown_timeout: Duration,
    /// When `false`, the supervisor is configured with
    /// `without_signal_handler()`. The builder default is `true`.
    pub install_signal_handler: bool,
    /// Capacity of the market-data bus (builder default 1024; must be > 0).
    pub market_bus_capacity: usize,
    /// Capacity of the signal bus (builder default 256; must be > 0).
    pub signal_bus_capacity: usize,
    /// When `true`, every non-flat cached position is closed through the
    /// exchange after the supervisor stops. The builder default is `false`.
    pub close_positions_on_shutdown: bool,
    /// Risk-management settings.
    pub risk: RiskConfig,
}
102
103impl BotConfig {
104 pub fn builder() -> BotConfigBuilder {
106 BotConfigBuilder::default()
107 }
108}
109
/// Fluent builder for [`BotConfig`]; obtain one from [`BotConfig::builder`].
///
/// Fields left as `None` fall back to defaults in `build()`.
#[derive(Debug, Clone, Default)]
pub struct BotConfigBuilder {
    // Required; `build` rejects a missing or whitespace-only name.
    name: Option<String>,
    // Accumulated by `symbol` / `symbols`; must be non-empty at `build` time.
    symbols: Vec<Symbol>,
    // `build` defaults this to 30 s and rejects zero.
    shutdown_timeout: Option<Duration>,
    // Only ever set to `Some(false)` (by `without_signal_handler`); `build` defaults to `true`.
    install_signal_handler: Option<bool>,
    // `build` defaults this to `DEFAULT_MARKET_BUS_CAPACITY` and rejects zero.
    market_bus_capacity: Option<usize>,
    // `build` defaults this to `DEFAULT_SIGNAL_BUS_CAPACITY` and rejects zero.
    signal_bus_capacity: Option<usize>,
    // `build` defaults this to `false`.
    close_positions_on_shutdown: Option<bool>,
    // Starts at `RiskConfig::default()`; each `*_config` setter replaces one sub-config wholesale.
    risk: RiskConfig,
}
122
123impl BotConfigBuilder {
124 pub fn name(mut self, name: impl Into<String>) -> Self {
126 self.name = Some(name.into());
127 self
128 }
129
130 pub fn symbol(mut self, sym: impl Into<Symbol>) -> Self {
132 self.symbols.push(sym.into());
133 self
134 }
135
136 pub fn symbols<I, S>(mut self, syms: I) -> Self
138 where
139 I: IntoIterator<Item = S>,
140 S: Into<Symbol>,
141 {
142 self.symbols.extend(syms.into_iter().map(Into::into));
143 self
144 }
145
146 pub fn shutdown_timeout(mut self, dur: Duration) -> Self {
148 self.shutdown_timeout = Some(dur);
149 self
150 }
151
152 pub fn without_signal_handler(mut self) -> Self {
154 self.install_signal_handler = Some(false);
155 self
156 }
157
158 pub fn market_bus_capacity(mut self, cap: usize) -> Self {
160 self.market_bus_capacity = Some(cap);
161 self
162 }
163
164 pub fn signal_bus_capacity(mut self, cap: usize) -> Self {
166 self.signal_bus_capacity = Some(cap);
167 self
168 }
169
170 pub fn close_positions_on_shutdown(mut self, b: bool) -> Self {
173 self.close_positions_on_shutdown = Some(b);
174 self
175 }
176
177 pub fn session_pnl_config(mut self, cfg: SessionPnlConfig) -> Self {
179 self.risk.session_pnl = cfg;
180 self
181 }
182
183 pub fn circuit_breaker_config(mut self, cfg: CircuitBreakerConfig) -> Self {
185 self.risk.circuit_breaker = cfg;
186 self
187 }
188
189 pub fn sizing_config(mut self, cfg: SizingConfig) -> Self {
191 self.risk.sizing = cfg;
192 self
193 }
194
195 pub fn build(self) -> Result<BotConfig> {
198 let name = self
199 .name
200 .filter(|n| !n.trim().is_empty())
201 .ok_or_else(|| Error::config("BotConfig.name is required and must not be empty"))?;
202
203 if self.symbols.is_empty() {
204 return Err(Error::config(
205 "BotConfig.symbols must contain at least one Symbol — \
206 the position cache and risk-state map are pre-seeded per symbol",
207 ));
208 }
209
210 let market_bus_capacity = self
211 .market_bus_capacity
212 .unwrap_or(DEFAULT_MARKET_BUS_CAPACITY);
213 if market_bus_capacity == 0 {
214 return Err(Error::config(
215 "BotConfig.market_bus_capacity must be > 0 (broadcast channel cannot have 0 slots)",
216 ));
217 }
218
219 let signal_bus_capacity = self
220 .signal_bus_capacity
221 .unwrap_or(DEFAULT_SIGNAL_BUS_CAPACITY);
222 if signal_bus_capacity == 0 {
223 return Err(Error::config(
224 "BotConfig.signal_bus_capacity must be > 0 (broadcast channel cannot have 0 slots)",
225 ));
226 }
227
228 let shutdown_timeout = self.shutdown_timeout.unwrap_or(Duration::from_secs(30));
229 if shutdown_timeout.is_zero() {
230 return Err(Error::config(
231 "BotConfig.shutdown_timeout must be > 0 — drain needs a non-zero deadline",
232 ));
233 }
234
235 if self.risk.session_pnl.loss_limit.is_nan() {
236 return Err(Error::config(
237 "BotConfig.risk.session_pnl.loss_limit must not be NaN",
238 ));
239 }
240 if !self.risk.sizing.margin_per_trade.is_finite() || self.risk.sizing.margin_per_trade < 0.0
241 {
242 return Err(Error::config(
243 "BotConfig.risk.sizing.margin_per_trade must be a finite non-negative number",
244 ));
245 }
246
247 Ok(BotConfig {
248 name,
249 symbols: self.symbols,
250 shutdown_timeout,
251 install_signal_handler: self.install_signal_handler.unwrap_or(true),
252 market_bus_capacity,
253 signal_bus_capacity,
254 close_positions_on_shutdown: self.close_positions_on_shutdown.unwrap_or(false),
255 risk: self.risk,
256 })
257 }
258}
259
/// A trading bot assembled from an exchange adapter and one or more brains.
///
/// At run time it registers these services with a single [`Supervisor`]:
/// - one execution service per brain;
/// - optional market-feed, fill-routing, candle-polling and order-reaping services.
///
/// Create it with [`Bot::new`], attach optional pieces with the `with_*`
/// methods, then drive it with [`Bot::run_until_shutdown`].
pub struct Bot {
    // Validated configuration; read during construction, startup and shutdown.
    config: BotConfig,
    // Runs every service. Shared with the handle and, when configured, the external-cancel watcher.
    supervisor: Arc<Supervisor>,
    // Exchange adapter: order execution, position prefetch, shutdown closes, order reaping.
    exchange: Arc<dyn ExchangeClient>,
    // Strategies; non-empty (checked in `new`). Shared with the handle and the fill router.
    brains: Arc<Vec<Arc<dyn Brain>>>,
    // Market-data bus; given to execution services and candle pollers.
    market_bus: MarketDataBus,
    // Signal bus; given to execution services and the handle.
    signal_bus: SignalBus,
    // Per-symbol position cache seeded from `config.symbols` and refreshed by prefetch at startup.
    positions: PositionCache,
    // Per-symbol risk state seeded from `config.symbols` and the risk config.
    risk: RiskStateMap,
    // Metrics destination; `NoopSink` unless `with_metrics` replaces it.
    metrics: Arc<dyn MetricsSink>,
    // Optional backend from which risk state is restored at startup and persisted at exit.
    state_store: Option<Arc<dyn StateStore>>,
    // Shared with the handle; receives the `RiskPersister` at startup when a store is configured.
    persister_slot: crate::handle::PersisterSlot,
    // Cloned out by `handle()`.
    handle: BotHandle,
    // Optional caller-owned token whose cancellation triggers supervisor shutdown.
    external_cancel: Option<CancellationToken>,
    // Optional source for the market-feed service.
    market_source: Option<Arc<dyn MarketSource>>,
    // Optional source for the fill-routing service.
    fill_source: Option<Arc<dyn FillSource>>,
    // One entry per `with_candle_poller` call; each becomes a candle-poller service.
    candle_pollers: Vec<CandlePollerSpec>,
    // Shared with the handle; given to execution services only when order tracking is active.
    order_tracker: OrderTracker,
    // Set by `with_order_tracking`; honored only if the exchange supports `Capability::OrderTracking`.
    order_tracking: Option<OrderTrackingSpec>,
}
318
/// Order-tracking parameters recorded by `Bot::with_order_tracking` and
/// forwarded to the `OrderReaperService` when tracking is active.
struct OrderTrackingSpec {
    // TTL handed to the reaper for tracked resting orders.
    ttl: Duration,
    // Poll cadence handed to the reaper.
    poll_cadence: Duration,
}
324
/// One candle poller recorded by `Bot::with_candle_poller`. Each entry becomes
/// a `CandlePollerService` that publishes onto the market-data bus.
struct CandlePollerSpec {
    // Where candles are fetched from.
    source: Arc<dyn CandleSource>,
    // Symbol to poll.
    symbol: Symbol,
    // Candle interval requested from the source.
    interval: Duration,
    // How often the service polls; forwarded verbatim.
    poll_cadence: Duration,
    // Candle count limit forwarded to the service (presumably per request — confirm in `services`).
    limit: usize,
}
333
334impl Bot {
335 pub fn new(
340 config: BotConfig,
341 exchange: Arc<dyn ExchangeClient>,
342 brains: Vec<Arc<dyn Brain>>,
343 ) -> Result<Self> {
344 if brains.is_empty() {
345 return Err(Error::config(
346 "Bot::new requires at least one Brain — empty brain list",
347 ));
348 }
349
350 let supervisor = Arc::new(Supervisor::new(
351 SupervisorConfig::default()
352 .with_shutdown_timeout(config.shutdown_timeout)
353 .with_default_backoff(Default::default())
354 .pipe(|c| {
355 if config.install_signal_handler {
356 c
357 } else {
358 c.without_signal_handler()
359 }
360 }),
361 ));
362
363 let market_bus = MarketDataBus::with_capacity(config.market_bus_capacity);
364 let signal_bus = SignalBus::with_capacity(config.signal_bus_capacity);
365 let positions = build_position_cache(&config.symbols);
366 let risk = build_risk_state(
367 &config.symbols,
368 &config.risk.session_pnl,
369 &config.risk.circuit_breaker,
370 );
371
372 let brains = Arc::new(brains);
373 let persister_slot: crate::handle::PersisterSlot = Arc::new(std::sync::OnceLock::new());
374 let order_tracker = OrderTracker::new();
375 let handle = BotHandle::new(
376 supervisor.clone(),
377 brains.clone(),
378 risk.clone(),
379 positions.clone(),
380 signal_bus.clone(),
381 persister_slot.clone(),
382 order_tracker.clone(),
383 );
384
385 Ok(Self {
386 config,
387 supervisor,
388 exchange,
389 brains,
390 market_bus,
391 signal_bus,
392 positions,
393 risk,
394 metrics: Arc::new(NoopSink),
395 state_store: None,
396 persister_slot,
397 order_tracker,
398 handle,
399 external_cancel: None,
400 market_source: None,
401 fill_source: None,
402 candle_pollers: Vec::new(),
403 order_tracking: None,
404 })
405 }
406
407 pub fn with_metrics(mut self, sink: Arc<dyn MetricsSink>) -> Self {
411 self.metrics = sink;
412 self
413 }
414
415 pub fn with_state_store(mut self, store: Arc<dyn StateStore>) -> Self {
436 self.state_store = Some(store);
437 self
438 }
439
440 pub fn with_order_tracking(mut self, ttl: Duration, poll_cadence: Duration) -> Self {
457 self.order_tracking = Some(OrderTrackingSpec { ttl, poll_cadence });
458 self
459 }
460
461 pub fn with_candle_poller(
466 mut self,
467 source: Arc<dyn CandleSource>,
468 symbol: impl Into<Symbol>,
469 interval: Duration,
470 poll_cadence: Duration,
471 limit: usize,
472 ) -> Self {
473 self.candle_pollers.push(CandlePollerSpec {
474 source,
475 symbol: symbol.into(),
476 interval,
477 poll_cadence,
478 limit,
479 });
480 self
481 }
482
483 pub fn with_external_cancel(mut self, token: CancellationToken) -> Self {
492 self.external_cancel = Some(token);
493 self
494 }
495
496 pub fn with_market_source(mut self, source: Arc<dyn MarketSource>) -> Self {
501 self.market_source = Some(source);
502 self
503 }
504
505 pub fn with_fill_source(mut self, source: Arc<dyn FillSource>) -> Self {
510 self.fill_source = Some(source);
511 self
512 }
513
514 pub fn handle(&self) -> BotHandle {
518 self.handle.clone()
519 }
520
521 pub fn config(&self) -> &BotConfig {
523 &self.config
524 }
525
526 pub fn market_data_bus(&self) -> &MarketDataBus {
529 &self.market_bus
530 }
531
532 pub fn signal_bus(&self) -> &SignalBus {
537 &self.signal_bus
538 }
539
540 pub async fn run_until_shutdown(self) -> anyhow::Result<()> {
580 tracing::info!(
581 bot = %self.config.name,
582 brains = self.brains.len(),
583 symbols = self.config.symbols.len(),
584 exchange = %self.exchange.name(),
585 "rustrade Bot starting"
586 );
587
588 self.prefetch_positions().await;
590
591 let persister = self
595 .state_store
596 .clone()
597 .map(|store| RiskPersister::new(store, self.config.name.clone()));
598 if let Some(p) = &persister {
599 p.restore_into(&self.risk).await;
600 let _ = self.persister_slot.set(p.clone());
602 }
603
604 let order_tracking_active =
608 self.order_tracking.is_some() && self.exchange.supports(Capability::OrderTracking);
609 if self.order_tracking.is_some() && !order_tracking_active {
610 tracing::warn!(
611 exchange = %self.exchange.name(),
612 "order tracking requested but adapter lacks Capability::OrderTracking — \
613 resting orders will NOT be tracked or aged out"
614 );
615 }
616
617 let sizing = Arc::new(self.config.risk.sizing.clone());
618 let ctx = ExecutionContext {
619 exchange: self.exchange.clone(),
620 bus: self.market_bus.clone(),
621 signals: self.signal_bus.clone(),
622 positions: self.positions.clone(),
623 risk: self.risk.clone(),
624 sizing,
625 order_tracker: order_tracking_active.then(|| self.order_tracker.clone()),
626 };
627
628 for brain in self.brains.iter() {
629 let svc = ExecutionService::new(brain.clone(), ctx.clone());
630 self.supervisor.spawn_service(Box::new(svc));
631 }
632
633 if order_tracking_active {
634 let spec = self.order_tracking.as_ref().unwrap();
636 self.supervisor
637 .spawn_service(Box::new(OrderReaperService::new(
638 self.exchange.clone(),
639 self.order_tracker.clone(),
640 self.config.symbols.clone(),
641 spec.ttl,
642 spec.poll_cadence,
643 self.metrics.clone(),
644 )));
645 }
646
647 if let Some(source) = self.market_source.clone() {
648 self.supervisor
649 .spawn_service(Box::new(MarketFeedService::new(source)));
650 }
651
652 if let Some(source) = self.fill_source.clone() {
653 self.supervisor
654 .spawn_service(Box::new(FillRoutingService::new(
655 source,
656 self.brains.clone(),
657 self.exchange.clone(),
658 self.positions.clone(),
659 self.risk.clone(),
660 self.metrics.clone(),
661 persister.clone(),
662 )));
663 }
664
665 for spec in &self.candle_pollers {
666 self.supervisor
667 .spawn_service(Box::new(CandlePollerService::new(
668 spec.source.clone(),
669 spec.symbol.clone(),
670 spec.interval,
671 spec.poll_cadence,
672 spec.limit,
673 self.market_bus.clone(),
674 self.metrics.clone(),
675 )));
676 }
677
678 if let Some(external) = self.external_cancel.clone() {
681 let supervisor = self.supervisor.clone();
682 tokio::spawn(async move {
683 external.cancelled().await;
684 tracing::info!("external cancellation received; triggering bot shutdown");
685 supervisor.trigger_shutdown();
686 });
687 }
688
689 let run_result = self.supervisor.run_until_shutdown().await;
690
691 if self.config.close_positions_on_shutdown {
692 self.close_open_positions().await;
693 }
694
695 if let Some(p) = &persister {
698 p.persist_all(&self.risk).await;
699 }
700
701 for brain in self.brains.iter() {
702 let health = brain.health().await;
703 tracing::info!(
704 brain = %brain.name(),
705 healthy = health.healthy,
706 events = health.events_processed,
707 non_hold = health.non_hold_decisions,
708 "final brain health"
709 );
710 }
711
712 tracing::info!(bot = %self.config.name, "rustrade Bot exited");
713 run_result
714 }
715
716 async fn prefetch_positions(&self) {
717 for symbol in &self.config.symbols {
718 match self.exchange.get_position(symbol).await {
719 Ok(pos) => {
720 self.positions.write().await.insert(symbol.clone(), pos);
721 tracing::debug!(
722 symbol = %symbol,
723 qty = pos.qty,
724 "prefetched position from exchange"
725 );
726 }
727 Err(e) => {
728 tracing::warn!(
729 symbol = %symbol,
730 error = %e,
731 "failed to prefetch position; cache defaults to FLAT"
732 );
733 }
734 }
735 }
736 }
737
738 async fn close_open_positions(&self) {
739 let snapshot: Vec<(Symbol, Position)> = {
740 let map = self.positions.read().await;
741 map.iter()
742 .filter(|(_, p)| !p.is_flat())
743 .map(|(s, p)| (s.clone(), *p))
744 .collect()
745 };
746
747 if snapshot.is_empty() {
748 tracing::info!("close_positions_on_shutdown: no open positions");
749 return;
750 }
751
752 for (symbol, position) in snapshot {
753 match self.exchange.close_position(&symbol, &position).await {
754 Ok(order_id) => tracing::info!(
755 symbol = %symbol,
756 qty = position.qty,
757 order_id = %order_id,
758 "close_positions_on_shutdown: closed"
759 ),
760 Err(e) => tracing::error!(
761 symbol = %symbol,
762 qty = position.qty,
763 error = %e,
764 "close_positions_on_shutdown: failed (best-effort)"
765 ),
766 }
767 }
768 }
769}
770
/// Extension trait that passes a value through a closure.
///
/// This lets a conditional transformation stay inline in a method chain
/// (used when building the `SupervisorConfig`).
trait Pipe: Sized {
    /// Returns `func(self)`.
    fn pipe<F>(self, func: F) -> Self
    where
        F: FnOnce(Self) -> Self,
    {
        func(self)
    }
}

// Blanket impl: every sized type gets `.pipe(..)`.
impl<T> Pipe for T {}
779
// Unit tests for the config builder's validation and for `Bot` construction.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use rustrade_core::{Fill, MarketDataEvent, Order, Position};

    // Brain stub: answers every event with `Decision::hold()`.
    struct NoopBrain;
    #[async_trait]
    impl Brain for NoopBrain {
        fn name(&self) -> &str {
            "noop"
        }
        async fn on_event(
            &self,
            _e: &MarketDataEvent,
            _p: &Position,
        ) -> Result<rustrade_core::Decision> {
            Ok(rustrade_core::Decision::hold())
        }
    }

    // Exchange stub: every call succeeds with a fixed value; positions are always FLAT.
    struct NoopExchange;
    #[async_trait]
    impl ExchangeClient for NoopExchange {
        fn name(&self) -> &str {
            "noop"
        }
        async fn place_order(&self, _o: &Order) -> Result<String> {
            Ok("noop-1".into())
        }
        async fn cancel_all(&self, _s: &Symbol) -> Result<usize> {
            Ok(0)
        }
        async fn close_position(&self, _s: &Symbol, _p: &Position) -> Result<String> {
            Ok("noop-close".into())
        }
        async fn get_position(&self, _s: &Symbol) -> Result<Position> {
            Ok(Position::FLAT)
        }
        async fn get_balance(&self, _c: &str) -> Result<f64> {
            Ok(0.0)
        }
    }

    // Minimal valid config: one symbol, signal handler disabled so tests don't install one.
    fn cfg() -> BotConfig {
        BotConfig::builder()
            .name("test")
            .symbol("BTCUSDT")
            .without_signal_handler()
            .build()
            .unwrap()
    }

    #[test]
    fn builder_requires_name() {
        let err = BotConfig::builder().build().unwrap_err();
        assert!(matches!(err, Error::Config(_)), "got {err:?}");
    }

    #[test]
    fn builder_rejects_blank_name() {
        // Whitespace-only counts as blank (name is trimmed before the emptiness check).
        let err = BotConfig::builder().name(" ").build().unwrap_err();
        assert!(matches!(err, Error::Config(_)), "got {err:?}");
    }

    #[test]
    fn builder_rejects_zero_market_bus_capacity() {
        let err = BotConfig::builder()
            .name("x")
            .symbol("BTCUSDT")
            .market_bus_capacity(0)
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_rejects_zero_signal_bus_capacity() {
        let err = BotConfig::builder()
            .name("x")
            .symbol("BTCUSDT")
            .signal_bus_capacity(0)
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_rejects_empty_symbol_list() {
        let err = BotConfig::builder().name("x").build().unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_rejects_zero_shutdown_timeout() {
        let err = BotConfig::builder()
            .name("x")
            .symbol("BTCUSDT")
            .shutdown_timeout(Duration::ZERO)
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_rejects_nan_loss_limit() {
        let err = BotConfig::builder()
            .name("x")
            .symbol("BTCUSDT")
            .session_pnl_config(SessionPnlConfig {
                loss_limit: f64::NAN,
            })
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_rejects_non_finite_margin() {
        let err = BotConfig::builder()
            .name("x")
            .symbol("BTCUSDT")
            .sizing_config(SizingConfig {
                margin_per_trade: f64::INFINITY,
                leverage: 1,
                max_contracts: 1,
            })
            .build()
            .unwrap_err();
        assert!(matches!(err, Error::Config(_)));
    }

    #[test]
    fn builder_accumulates_symbols() {
        // `symbol` and `symbols` append in call order.
        let c = BotConfig::builder()
            .name("x")
            .symbol("A")
            .symbols(["B", "C"])
            .build()
            .unwrap();
        assert_eq!(c.symbols.len(), 3);
        assert_eq!(c.symbols[0], Symbol::new("A"));
        assert_eq!(c.symbols[2], Symbol::new("C"));
    }

    #[test]
    fn builder_accepts_risk_overrides() {
        let c = BotConfig::builder()
            .name("x")
            .symbol("BTCUSDT")
            .session_pnl_config(SessionPnlConfig { loss_limit: -123.0 })
            .sizing_config(SizingConfig {
                margin_per_trade: 250.0,
                leverage: 10,
                max_contracts: 5,
            })
            .build()
            .unwrap();
        assert_eq!(c.risk.session_pnl.loss_limit, -123.0);
        assert_eq!(c.risk.sizing.leverage, 10);
    }

    #[test]
    fn builder_has_separate_default_bus_capacities() {
        let c = BotConfig::builder()
            .name("x")
            .symbol("BTCUSDT")
            .build()
            .unwrap();
        assert_eq!(c.market_bus_capacity, 1024);
        assert_eq!(c.signal_bus_capacity, 256);
    }

    #[tokio::test]
    async fn bot_requires_at_least_one_brain() {
        // `Bot` is not `Debug`, so map the `Ok` side to a string before printing.
        match Bot::new(cfg(), Arc::new(NoopExchange), vec![]) {
            Err(Error::Config(_)) => {}
            other => panic!(
                "expected Error::Config for empty brain list, got {:?}",
                other.map(|_| "Ok(Bot)").map_err(|e| format!("Err({e})"))
            ),
        }
    }

    #[tokio::test]
    async fn bot_constructs_and_exposes_handle() {
        let bot = Bot::new(cfg(), Arc::new(NoopExchange), vec![Arc::new(NoopBrain)]).unwrap();
        let handle = bot.handle();
        assert!(!handle.is_shutting_down());
        assert_eq!(bot.config().name, "test");
        let h2 = handle.clone();
        assert!(!h2.is_shutting_down());
    }

    // Keeps the `Fill` import referenced; nothing else in this module uses it.
    #[allow(dead_code)]
    fn _noop_fill_compiles(_: &Fill) {}
}