use crate::Decimal;
use crate::types::error::{MMError, MMResult};
use std::collections::VecDeque;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TriggerReason {
MaxDailyLoss,
VolatilitySpike,
ConsecutiveLosses,
RapidDrawdown,
Manual,
}
impl std::fmt::Display for TriggerReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::MaxDailyLoss => write!(f, "max daily loss exceeded"),
Self::VolatilitySpike => write!(f, "volatility spike detected"),
Self::ConsecutiveLosses => write!(f, "consecutive losses limit reached"),
Self::RapidDrawdown => write!(f, "rapid drawdown detected"),
Self::Manual => write!(f, "manually triggered"),
}
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CircuitBreakerState {
Active,
Triggered {
reason: TriggerReason,
triggered_at: u64,
},
Cooldown {
resume_at: u64,
},
}
impl CircuitBreakerState {
#[must_use]
pub fn is_active(&self) -> bool {
matches!(self, Self::Active)
}
#[must_use]
pub fn is_triggered(&self) -> bool {
matches!(self, Self::Triggered { .. })
}
#[must_use]
pub fn is_cooldown(&self) -> bool {
matches!(self, Self::Cooldown { .. })
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CircuitBreakerConfig {
pub max_daily_loss: Decimal,
pub max_volatility: Decimal,
pub max_consecutive_losses: u32,
pub rapid_drawdown_threshold: Decimal,
pub rapid_drawdown_window_ms: u64,
pub cooldown_duration_ms: u64,
}
impl CircuitBreakerConfig {
pub fn new(
max_daily_loss: Decimal,
max_volatility: Decimal,
max_consecutive_losses: u32,
rapid_drawdown_threshold: Decimal,
rapid_drawdown_window_ms: u64,
cooldown_duration_ms: u64,
) -> MMResult<Self> {
if max_daily_loss <= Decimal::ZERO {
return Err(MMError::InvalidConfiguration(
"max_daily_loss must be positive".to_string(),
));
}
if max_volatility <= Decimal::ZERO {
return Err(MMError::InvalidConfiguration(
"max_volatility must be positive".to_string(),
));
}
if max_consecutive_losses == 0 {
return Err(MMError::InvalidConfiguration(
"max_consecutive_losses must be greater than 0".to_string(),
));
}
if rapid_drawdown_threshold <= Decimal::ZERO || rapid_drawdown_threshold > Decimal::ONE {
return Err(MMError::InvalidConfiguration(
"rapid_drawdown_threshold must be between 0 and 1 (exclusive of 0)".to_string(),
));
}
if rapid_drawdown_window_ms == 0 {
return Err(MMError::InvalidConfiguration(
"rapid_drawdown_window_ms must be greater than 0".to_string(),
));
}
Ok(Self {
max_daily_loss,
max_volatility,
max_consecutive_losses,
rapid_drawdown_threshold,
rapid_drawdown_window_ms,
cooldown_duration_ms,
})
}
}
#[derive(Debug, Clone)]
struct EquitySnapshot {
equity: Decimal,
timestamp: u64,
}
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
config: CircuitBreakerConfig,
state: CircuitBreakerState,
daily_loss: Decimal,
consecutive_losses: u32,
current_volatility: Decimal,
equity_history: VecDeque<EquitySnapshot>,
peak_equity: Decimal,
current_equity: Decimal,
}
impl CircuitBreaker {
#[must_use]
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
config,
state: CircuitBreakerState::Active,
daily_loss: Decimal::ZERO,
consecutive_losses: 0,
current_volatility: Decimal::ZERO,
equity_history: VecDeque::new(),
peak_equity: Decimal::ZERO,
current_equity: Decimal::ZERO,
}
}
#[must_use]
pub fn with_initial_equity(
config: CircuitBreakerConfig,
initial_equity: Decimal,
timestamp: u64,
) -> Self {
let mut breaker = Self::new(config);
breaker.current_equity = initial_equity;
breaker.peak_equity = initial_equity;
breaker.equity_history.push_back(EquitySnapshot {
equity: initial_equity,
timestamp,
});
breaker
}
#[must_use]
pub fn state(&self) -> &CircuitBreakerState {
&self.state
}
#[must_use]
pub fn daily_loss(&self) -> Decimal {
self.daily_loss
}
#[must_use]
pub fn consecutive_losses(&self) -> u32 {
self.consecutive_losses
}
#[must_use]
pub fn is_trading_allowed(&self) -> bool {
self.state.is_active()
}
pub fn record_trade(&mut self, pnl: Decimal, timestamp: u64) -> CircuitBreakerState {
if !self.state.is_active() {
return self.state.clone();
}
if pnl < Decimal::ZERO {
self.daily_loss += pnl.abs();
self.consecutive_losses += 1;
} else {
self.consecutive_losses = 0;
}
self.current_equity += pnl;
if self.current_equity > self.peak_equity {
self.peak_equity = self.current_equity;
}
self.equity_history.push_back(EquitySnapshot {
equity: self.current_equity,
timestamp,
});
self.prune_equity_history(timestamp);
self.check_triggers(timestamp);
self.state.clone()
}
pub fn update_volatility(
&mut self,
volatility: Decimal,
timestamp: u64,
) -> CircuitBreakerState {
self.current_volatility = volatility;
if self.state.is_active() {
self.check_triggers(timestamp);
}
self.state.clone()
}
pub fn trigger_manual(&mut self, timestamp: u64) {
self.state = CircuitBreakerState::Triggered {
reason: TriggerReason::Manual,
triggered_at: timestamp,
};
}
pub fn start_cooldown(&mut self, timestamp: u64) {
if self.state.is_triggered() {
self.state = CircuitBreakerState::Cooldown {
resume_at: timestamp + self.config.cooldown_duration_ms,
};
}
}
pub fn check_cooldown(&mut self, current_time: u64) -> CircuitBreakerState {
if let CircuitBreakerState::Cooldown { resume_at } = self.state
&& current_time >= resume_at
{
self.state = CircuitBreakerState::Active;
}
self.state.clone()
}
pub fn reset(&mut self, initial_equity: Decimal, timestamp: u64) {
self.state = CircuitBreakerState::Active;
self.daily_loss = Decimal::ZERO;
self.consecutive_losses = 0;
self.current_volatility = Decimal::ZERO;
self.equity_history.clear();
self.peak_equity = initial_equity;
self.current_equity = initial_equity;
self.equity_history.push_back(EquitySnapshot {
equity: initial_equity,
timestamp,
});
}
#[must_use]
pub fn config(&self) -> &CircuitBreakerConfig {
&self.config
}
fn check_triggers(&mut self, timestamp: u64) {
if self.daily_loss >= self.config.max_daily_loss {
self.state = CircuitBreakerState::Triggered {
reason: TriggerReason::MaxDailyLoss,
triggered_at: timestamp,
};
return;
}
if self.current_volatility > self.config.max_volatility {
self.state = CircuitBreakerState::Triggered {
reason: TriggerReason::VolatilitySpike,
triggered_at: timestamp,
};
return;
}
if self.consecutive_losses >= self.config.max_consecutive_losses {
self.state = CircuitBreakerState::Triggered {
reason: TriggerReason::ConsecutiveLosses,
triggered_at: timestamp,
};
return;
}
if self.check_rapid_drawdown() {
self.state = CircuitBreakerState::Triggered {
reason: TriggerReason::RapidDrawdown,
triggered_at: timestamp,
};
}
}
fn check_rapid_drawdown(&self) -> bool {
if self.equity_history.is_empty() {
return false;
}
let window_peak = self
.equity_history
.iter()
.map(|s| s.equity)
.max()
.unwrap_or(self.current_equity);
if window_peak <= Decimal::ZERO {
return false;
}
let drawdown = (window_peak - self.current_equity) / window_peak;
drawdown >= self.config.rapid_drawdown_threshold
}
fn prune_equity_history(&mut self, current_time: u64) {
let cutoff = current_time.saturating_sub(self.config.rapid_drawdown_window_ms);
while let Some(front) = self.equity_history.front() {
if front.timestamp < cutoff {
self.equity_history.pop_front();
} else {
break;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dec;
fn default_config() -> CircuitBreakerConfig {
CircuitBreakerConfig::new(
dec!(1000.0), dec!(0.05), 5, dec!(0.10), 300_000, 60_000, )
.unwrap()
}
#[test]
fn test_config_valid() {
let config = default_config();
assert_eq!(config.max_daily_loss, dec!(1000.0));
assert_eq!(config.max_consecutive_losses, 5);
}
#[test]
fn test_config_invalid_max_daily_loss() {
let result =
CircuitBreakerConfig::new(dec!(0.0), dec!(0.05), 5, dec!(0.10), 300_000, 60_000);
assert!(result.is_err());
let result =
CircuitBreakerConfig::new(dec!(-100.0), dec!(0.05), 5, dec!(0.10), 300_000, 60_000);
assert!(result.is_err());
}
#[test]
fn test_config_invalid_max_volatility() {
let result =
CircuitBreakerConfig::new(dec!(1000.0), dec!(0.0), 5, dec!(0.10), 300_000, 60_000);
assert!(result.is_err());
}
#[test]
fn test_config_invalid_max_consecutive_losses() {
let result =
CircuitBreakerConfig::new(dec!(1000.0), dec!(0.05), 0, dec!(0.10), 300_000, 60_000);
assert!(result.is_err());
}
#[test]
fn test_config_invalid_rapid_drawdown_threshold() {
let result =
CircuitBreakerConfig::new(dec!(1000.0), dec!(0.05), 5, dec!(0.0), 300_000, 60_000);
assert!(result.is_err());
let result =
CircuitBreakerConfig::new(dec!(1000.0), dec!(0.05), 5, dec!(1.1), 300_000, 60_000);
assert!(result.is_err());
}
#[test]
fn test_config_invalid_rapid_drawdown_window() {
let result = CircuitBreakerConfig::new(dec!(1000.0), dec!(0.05), 5, dec!(0.10), 0, 60_000);
assert!(result.is_err());
}
#[test]
fn test_new_breaker_is_active() {
let breaker = CircuitBreaker::new(default_config());
assert!(breaker.is_trading_allowed());
assert!(breaker.state().is_active());
}
#[test]
fn test_with_initial_equity() {
let breaker = CircuitBreaker::with_initial_equity(default_config(), dec!(10000.0), 0);
assert!(breaker.is_trading_allowed());
assert_eq!(breaker.current_equity, dec!(10000.0));
assert_eq!(breaker.peak_equity, dec!(10000.0));
}
#[test]
fn test_record_winning_trade() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.record_trade(dec!(100.0), 1000);
assert!(breaker.is_trading_allowed());
assert_eq!(breaker.daily_loss(), dec!(0.0));
assert_eq!(breaker.consecutive_losses(), 0);
}
#[test]
fn test_record_losing_trade() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.record_trade(dec!(-100.0), 1000);
assert!(breaker.is_trading_allowed());
assert_eq!(breaker.daily_loss(), dec!(100.0));
assert_eq!(breaker.consecutive_losses(), 1);
}
#[test]
fn test_consecutive_losses_reset_on_win() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.record_trade(dec!(-100.0), 1000);
breaker.record_trade(dec!(-100.0), 2000);
assert_eq!(breaker.consecutive_losses(), 2);
breaker.record_trade(dec!(50.0), 3000);
assert_eq!(breaker.consecutive_losses(), 0);
}
#[test]
fn test_trigger_max_daily_loss() {
let config =
CircuitBreakerConfig::new(dec!(100.0), dec!(0.05), 10, dec!(0.10), 300_000, 60_000)
.unwrap();
let mut breaker = CircuitBreaker::new(config);
let state = breaker.record_trade(dec!(-100.0), 1000);
assert!(state.is_triggered());
assert!(!breaker.is_trading_allowed());
if let CircuitBreakerState::Triggered { reason, .. } = state {
assert_eq!(reason, TriggerReason::MaxDailyLoss);
} else {
panic!("Expected Triggered state");
}
}
#[test]
fn test_trigger_consecutive_losses() {
let config =
CircuitBreakerConfig::new(dec!(10000.0), dec!(0.05), 3, dec!(0.10), 300_000, 60_000)
.unwrap();
let mut breaker = CircuitBreaker::new(config);
breaker.record_trade(dec!(-10.0), 1000);
breaker.record_trade(dec!(-10.0), 2000);
assert!(breaker.is_trading_allowed());
let state = breaker.record_trade(dec!(-10.0), 3000);
assert!(state.is_triggered());
if let CircuitBreakerState::Triggered { reason, .. } = state {
assert_eq!(reason, TriggerReason::ConsecutiveLosses);
}
}
#[test]
fn test_trigger_volatility_spike() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.update_volatility(dec!(0.04), 1000);
assert!(breaker.is_trading_allowed());
let state = breaker.update_volatility(dec!(0.06), 2000);
assert!(state.is_triggered());
if let CircuitBreakerState::Triggered { reason, .. } = state {
assert_eq!(reason, TriggerReason::VolatilitySpike);
}
}
#[test]
fn test_trigger_rapid_drawdown() {
let config =
CircuitBreakerConfig::new(dec!(10000.0), dec!(1.0), 100, dec!(0.10), 300_000, 60_000)
.unwrap();
let mut breaker = CircuitBreaker::with_initial_equity(config, dec!(1000.0), 0);
let state = breaker.record_trade(dec!(-100.0), 1000);
assert!(state.is_triggered());
if let CircuitBreakerState::Triggered { reason, .. } = state {
assert_eq!(reason, TriggerReason::RapidDrawdown);
}
}
#[test]
fn test_manual_trigger() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.trigger_manual(1000);
assert!(!breaker.is_trading_allowed());
if let CircuitBreakerState::Triggered {
reason,
triggered_at,
} = breaker.state()
{
assert_eq!(*reason, TriggerReason::Manual);
assert_eq!(*triggered_at, 1000);
} else {
panic!("Expected Triggered state");
}
}
#[test]
fn test_cooldown_transition() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.trigger_manual(1000);
assert!(breaker.state().is_triggered());
breaker.start_cooldown(2000);
assert!(breaker.state().is_cooldown());
if let CircuitBreakerState::Cooldown { resume_at } = breaker.state() {
assert_eq!(*resume_at, 2000 + 60_000); }
}
#[test]
fn test_cooldown_expiry() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.trigger_manual(1000);
breaker.start_cooldown(2000);
breaker.check_cooldown(50_000);
assert!(breaker.state().is_cooldown());
breaker.check_cooldown(70_000);
assert!(breaker.state().is_active());
assert!(breaker.is_trading_allowed());
}
#[test]
fn test_reset() {
let mut breaker = CircuitBreaker::new(default_config());
breaker.record_trade(dec!(-500.0), 1000);
breaker.record_trade(dec!(-500.0), 2000);
breaker.trigger_manual(3000);
assert!(!breaker.is_trading_allowed());
assert_eq!(breaker.daily_loss(), dec!(1000.0));
breaker.reset(dec!(10000.0), 100_000);
assert!(breaker.is_trading_allowed());
assert_eq!(breaker.daily_loss(), dec!(0.0));
assert_eq!(breaker.consecutive_losses(), 0);
}
#[test]
fn test_no_recording_when_triggered() {
let config =
CircuitBreakerConfig::new(dec!(100.0), dec!(0.05), 10, dec!(0.10), 300_000, 60_000)
.unwrap();
let mut breaker = CircuitBreaker::new(config);
breaker.record_trade(dec!(-100.0), 1000); assert!(!breaker.is_trading_allowed());
let loss_before = breaker.daily_loss();
breaker.record_trade(dec!(-100.0), 2000);
assert_eq!(breaker.daily_loss(), loss_before);
}
#[test]
fn test_state_helper_methods() {
let active = CircuitBreakerState::Active;
assert!(active.is_active());
assert!(!active.is_triggered());
assert!(!active.is_cooldown());
let triggered = CircuitBreakerState::Triggered {
reason: TriggerReason::Manual,
triggered_at: 0,
};
assert!(!triggered.is_active());
assert!(triggered.is_triggered());
assert!(!triggered.is_cooldown());
let cooldown = CircuitBreakerState::Cooldown { resume_at: 1000 };
assert!(!cooldown.is_active());
assert!(!cooldown.is_triggered());
assert!(cooldown.is_cooldown());
}
#[test]
fn test_trigger_reason_display() {
assert_eq!(
TriggerReason::MaxDailyLoss.to_string(),
"max daily loss exceeded"
);
assert_eq!(
TriggerReason::VolatilitySpike.to_string(),
"volatility spike detected"
);
assert_eq!(
TriggerReason::ConsecutiveLosses.to_string(),
"consecutive losses limit reached"
);
assert_eq!(
TriggerReason::RapidDrawdown.to_string(),
"rapid drawdown detected"
);
assert_eq!(TriggerReason::Manual.to_string(), "manually triggered");
}
#[test]
fn test_equity_history_pruning() {
let config = CircuitBreakerConfig::new(
dec!(10000.0),
dec!(1.0),
100,
dec!(0.50),
100,
60_000, )
.unwrap();
let mut breaker = CircuitBreaker::with_initial_equity(config, dec!(1000.0), 0);
breaker.record_trade(dec!(-10.0), 50);
breaker.record_trade(dec!(-10.0), 100);
breaker.record_trade(dec!(-10.0), 200);
assert!(breaker.equity_history.len() <= 3);
}
#[cfg(feature = "serde")]
#[test]
fn test_serialization() {
let config = default_config();
let json = serde_json::to_string(&config).unwrap();
let deserialized: CircuitBreakerConfig = serde_json::from_str(&json).unwrap();
assert_eq!(config, deserialized);
let state = CircuitBreakerState::Triggered {
reason: TriggerReason::MaxDailyLoss,
triggered_at: 1000,
};
let json = serde_json::to_string(&state).unwrap();
let deserialized: CircuitBreakerState = serde_json::from_str(&json).unwrap();
assert_eq!(state, deserialized);
}
}