use std::sync::{Arc, OnceLock};
use rustrade_core::{Brain, Position, Signal, SignalBus, Symbol};
use rustrade_supervisor::{ServiceLifecycleSnapshot, Supervisor};
use tokio::sync::broadcast;
use tokio_util::sync::CancellationToken;
use crate::risk_state::{PositionCache, RiskPersister, RiskStateMap};
/// Shared, write-once slot holding the optional [`RiskPersister`].
///
/// `OnceLock` lets the persister be installed at most once after the slot is
/// shared; readers (e.g. `BotHandle::record_trade_outcome`) call `.get()` and
/// skip persistence while the slot is still empty.
pub(crate) type PersisterSlot = Arc<OnceLock<RiskPersister>>;
/// Aggregate health report for the whole bot, as produced by `BotHandle::health`.
#[derive(Debug, Clone)]
pub struct BotHealth {
    /// `true` when no supervised service is in the `Terminated` phase and every
    /// brain reported itself healthy.
    pub healthy: bool,
    /// Whether the bot's cancellation token had been cancelled when the report
    /// was taken.
    pub shutting_down: bool,
    /// Lifecycle snapshot of each supervised service, as returned by the supervisor.
    pub services: Vec<ServiceLifecycleSnapshot>,
    /// One health snapshot per configured brain.
    pub brains: Vec<BrainHealthSnapshot>,
}
/// Point-in-time health of a single brain, copied out of the brain's own
/// `health()` report together with its name.
#[derive(Debug, Clone)]
pub struct BrainHealthSnapshot {
    /// The brain's name, as returned by `Brain::name`.
    pub name: String,
    /// Health flag reported by the brain itself.
    pub healthy: bool,
    /// Event counter reported by the brain (semantics defined by the brain
    /// implementation — NOTE(review): presumably cumulative since start).
    pub events_processed: u64,
    /// Count of non-hold decisions reported by the brain.
    pub non_hold_decisions: u64,
    /// Free-form, brain-specific diagnostic payload.
    pub details: serde_json::Value,
}
/// Cloneable control/inspection handle for a running bot.
///
/// Exposes shutdown control, risk-state updates, the position cache, signal
/// subscription, tracked orders and health reporting. The shared-state fields
/// are presumably reference-counted so clones observe the same state
/// (`supervisor`, `brains`, `persister` are visibly `Arc`s; the other field
/// types are defined elsewhere — TODO confirm).
#[derive(Clone)]
pub struct BotHandle {
    // Cloned from the supervisor's token in `new`; cancelling it requests shutdown.
    cancel: CancellationToken,
    supervisor: Arc<Supervisor>,
    brains: Arc<Vec<Arc<dyn Brain>>>,
    // Per-symbol risk state (session PnL + circuit breaker), behind an async lock.
    risk: RiskStateMap,
    // Per-symbol cached positions, behind an async lock.
    positions: PositionCache,
    signals: SignalBus,
    // Optional persister; empty until installed.
    persister: PersisterSlot,
    order_tracker: crate::order_tracker::OrderTracker,
}
impl BotHandle {
    /// Assembles a handle from the bot's shared components.
    ///
    /// The shutdown token is cloned from `supervisor.cancel_token()`, so
    /// cancelling through this handle cancels the supervisor's token too.
    pub(crate) fn new(
        supervisor: Arc<Supervisor>,
        brains: Arc<Vec<Arc<dyn Brain>>>,
        risk: RiskStateMap,
        positions: PositionCache,
        signals: SignalBus,
        persister: PersisterSlot,
        order_tracker: crate::order_tracker::OrderTracker,
    ) -> Self {
        // Grab the token before `supervisor` is moved into the struct.
        let cancel = supervisor.cancel_token().clone();
        Self {
            cancel,
            supervisor,
            brains,
            risk,
            positions,
            signals,
            persister,
            order_tracker,
        }
    }

    /// Returns whatever `OrderTracker::snapshot` reports for the tracked orders.
    pub async fn tracked_orders(&self) -> Vec<crate::order_tracker::TrackedOrder> {
        self.order_tracker.snapshot().await
    }

    /// Requests shutdown by cancelling the shared cancellation token.
    pub fn shutdown(&self) {
        self.cancel.cancel();
    }

    /// `true` once the cancellation token has been cancelled.
    pub fn is_shutting_down(&self) -> bool {
        self.cancel.is_cancelled()
    }

    /// Waits until the cancellation token is cancelled.
    pub async fn await_shutdown(&self) {
        self.cancel.cancelled().await;
    }

    /// Applies a closed trade's result to `symbol`'s risk state.
    ///
    /// - Non-finite `gross_pnl` or `fee` is rejected with an error log and no
    ///   state change.
    /// - A symbol missing from the risk map is logged as a warning and ignored.
    /// - Otherwise the close is recorded in the session PnL, and the net result
    ///   (`gross_pnl - fee`) feeds the circuit breaker: positive is a win,
    ///   negative a loss, exactly zero neither.
    ///
    /// If a persister has been installed, the symbol's state is persisted after
    /// the write lock has been released.
    pub async fn record_trade_outcome(&self, symbol: &Symbol, gross_pnl: f64, fee: f64) {
        let inputs_are_finite = gross_pnl.is_finite() && fee.is_finite();
        if !inputs_are_finite {
            tracing::error!(
                symbol = %symbol,
                gross_pnl,
                fee,
                "record_trade_outcome: non-finite PnL rejected — risk state unchanged"
            );
            return;
        }

        // Scoped so the write guard is dropped before persisting below.
        {
            let mut states = self.risk.write().await;
            let state = match states.get_mut(symbol) {
                Some(state) => state,
                None => {
                    tracing::warn!(
                        symbol = %symbol,
                        "record_trade_outcome: symbol not in risk state map (was it configured?)"
                    );
                    return;
                }
            };

            state.session_pnl.record_close(gross_pnl, fee);

            let net_result = gross_pnl - fee;
            if net_result > 0.0 {
                state.circuit_breaker.record_win();
            } else if net_result < 0.0 {
                state.circuit_breaker.record_loss();
            }
        }

        let Some(persister) = self.persister.get() else {
            return;
        };
        persister.persist_symbol(&self.risk, symbol).await;
    }

    /// Cached position for `symbol`, or `Position::FLAT` when none is cached.
    pub async fn position(&self, symbol: &Symbol) -> Position {
        let cache = self.positions.read().await;
        cache.get(symbol).copied().unwrap_or(Position::FLAT)
    }

    /// Stores `position` as the cached position for `symbol`, replacing any
    /// previous entry.
    pub async fn set_position(&self, symbol: &Symbol, position: Position) {
        let mut cache = self.positions.write().await;
        cache.insert(symbol.clone(), position);
    }

    /// Returns a new receiver subscribed to the signal bus.
    pub fn subscribe_signals(&self) -> broadcast::Receiver<Signal> {
        self.signals.subscribe()
    }

    /// Number of subscribers currently reported by the signal bus.
    pub fn signal_subscriber_count(&self) -> usize {
        self.signals.subscriber_count()
    }

    /// Collects a health report for the supervisor's services and every brain.
    ///
    /// Brains are queried one after another. The bot is reported healthy unless
    /// some service is in the `Terminated` phase or some brain reports itself
    /// unhealthy.
    pub async fn health(&self) -> BotHealth {
        let services = self.supervisor.lifecycle_snapshots().await;

        let mut brain_reports = Vec::with_capacity(self.brains.len());
        for brain in self.brains.iter() {
            let report = brain.health().await;
            brain_reports.push(BrainHealthSnapshot {
                name: brain.name().to_string(),
                healthy: report.healthy,
                events_processed: report.events_processed,
                non_hold_decisions: report.non_hold_decisions,
                details: report.details,
            });
        }

        let some_service_terminated = services
            .iter()
            .any(|svc| matches!(svc.phase, rustrade_supervisor::ServicePhase::Terminated));
        let some_brain_unhealthy = brain_reports.iter().any(|report| !report.healthy);

        BotHealth {
            healthy: !(some_service_terminated || some_brain_unhealthy),
            shutting_down: self.is_shutting_down(),
            services,
            brains: brain_reports,
        }
    }
}
impl std::fmt::Debug for BotHandle {
    /// Compact diagnostic rendering that shows only the shutdown flag and the
    /// number of brains. The remaining fields are elided and rendered as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shutting_down = self.cancel.is_cancelled();
        let brain_count = self.brains.len();

        let mut out = f.debug_struct("BotHandle");
        out.field("shutting_down", &shutting_down);
        out.field("brain_count", &brain_count);
        out.finish_non_exhaustive()
    }
}