use rust_decimal::Decimal;
use std::collections::VecDeque;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
pub mod controller;
pub mod detector;
pub mod statistics;
pub mod volatility;
pub use controller::{RegimeController, RegimeSchedule, RegimeSegment, ScheduleInfo, TransitionState};
pub use detector::RegimeDetector;
pub use statistics::RollingStatistics;
pub use volatility::VolatilityRegimeDetector;
/// Classification of the prevailing market regime.
///
/// The unit variants map to fixed volatility and drift factors (see
/// [`MarketRegime::volatility_factor`] and [`MarketRegime::drift_factor`]),
/// while `Normal` carries its own parameters.
///
/// Equality is derived over the `f64` payload, so two `Normal` values are
/// equal only when every field matches, and a `NaN` field never compares
/// equal to anything (including itself).
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MarketRegime {
    /// Upward-trending market.
    Bull,
    /// Downward-trending market.
    Bear,
    /// Ranging market without a directional trend.
    Sideways,
    /// Volatility-based regime described by explicit parameters.
    Normal {
        /// Mean parameter of the regime. Not read by any method in this file.
        mean: f64,
        /// Standard deviation; used directly as the volatility factor.
        std_dev: f64,
        /// Optional directional bias; used as the drift factor (zero when absent).
        bias: Option<f64>,
    },
}
impl MarketRegime {
pub fn description(&self) -> &'static str {
match self {
MarketRegime::Bull => "Bull market - upward trend",
MarketRegime::Bear => "Bear market - downward trend",
MarketRegime::Sideways => "Sideways market - ranging",
MarketRegime::Normal { .. } => "Normal market - volatility-based regime",
}
}
pub fn volatility_factor(&self) -> Decimal {
match self {
MarketRegime::Bull => Decimal::new(8, 1), MarketRegime::Bear => Decimal::new(12, 1), MarketRegime::Sideways => Decimal::ONE, MarketRegime::Normal { std_dev, .. } => {
Decimal::try_from(*std_dev).unwrap_or(Decimal::ONE)
}
}
}
pub fn drift_factor(&self) -> Decimal {
match self {
MarketRegime::Bull => Decimal::new(5, 3), MarketRegime::Bear => Decimal::new(-5, 3), MarketRegime::Sideways => Decimal::ZERO, MarketRegime::Normal { bias, .. } => {
let drift = bias.unwrap_or(0.0);
Decimal::try_from(drift).unwrap_or(Decimal::ZERO)
}
}
}
}
/// Snapshot of the regime currently in force, together with when and at
/// what price it started.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RegimeState {
    /// Regime this state describes.
    pub current_regime: MarketRegime,
    /// Confidence attached to `current_regime`.
    ///
    /// NOTE(review): `should_transition` compares confidences against 0.3
    /// and 0.6, so this is presumably on a 0..=1 scale - confirm against the
    /// detector code.
    pub confidence: Decimal,
    /// Number of periods counted for this regime; starts at 1 and grows via
    /// `increment_duration`.
    pub duration: usize,
    /// Timestamp supplied when the regime started (units are not fixed in
    /// this file).
    pub start_timestamp: i64,
    /// Price supplied when the regime started.
    pub start_price: Decimal,
}
impl RegimeState {
    /// Creates the state for a regime that has just started.
    ///
    /// `duration` is initialised to 1; `timestamp` and `price` are stored as
    /// the regime's starting point.
    pub fn new(regime: MarketRegime, confidence: Decimal, timestamp: i64, price: Decimal) -> Self {
        Self {
            current_regime: regime,
            confidence,
            duration: 1,
            start_timestamp: timestamp,
            start_price: price,
        }
    }

    /// Extends the current regime by one period.
    pub fn increment_duration(&mut self) {
        self.duration += 1;
    }

    /// Decides whether the state should switch to `new_regime`.
    ///
    /// Returns `true` only when `new_regime` differs from the current regime
    /// and at least one of the following holds:
    ///
    /// * `new_confidence` is strictly greater than 0.6, or
    /// * the confidence stored for the current regime is strictly below 0.3.
    ///
    /// A candidate equal to the current regime never triggers a transition.
    pub fn should_transition(&self, new_regime: MarketRegime, new_confidence: Decimal) -> bool {
        // `&&` binds tighter than `||`, so the unparenthesised form
        // `differs && strong || weak` evaluates as `(differs && strong) || weak`
        // and reports a "transition" to the very same regime whenever the
        // stored confidence is low. Require an actual regime change first.
        if new_regime == self.current_regime {
            return false;
        }

        let candidate_is_strong = new_confidence > Decimal::new(6, 1); // > 0.6
        let current_is_weak = self.confidence < Decimal::new(3, 1); // < 0.3
        candidate_is_strong || current_is_weak
    }

    /// Switches to `new_regime`, restarting the duration at 1 and recording
    /// `timestamp` and `price` as the new starting point.
    pub fn transition(&mut self, new_regime: MarketRegime, confidence: Decimal, timestamp: i64, price: Decimal) {
        self.current_regime = new_regime;
        self.confidence = confidence;
        self.duration = 1;
        self.start_timestamp = timestamp;
        self.start_price = price;
    }
}
/// Tunable parameters for regime detection.
///
/// NOTE(review): none of these fields is read in this file; their exact
/// interpretation lives in the detector submodules and should be confirmed
/// there. Default values are provided by the `Default` implementation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RegimeConfig {
    /// Look-back length (default 20).
    pub lookback_period: usize,
    /// Threshold associated with the bull regime (default 0.02).
    pub bull_threshold: Decimal,
    /// Threshold associated with the bear regime (default -0.02).
    pub bear_threshold: Decimal,
    /// Minimum confidence (default 0.6).
    pub min_confidence: Decimal,
    /// Whether volatility is taken into account (default `true`).
    pub use_volatility: bool,
    /// Window length (default 20).
    pub window_size: usize,
}
impl Default for RegimeConfig {
fn default() -> Self {
Self {
lookback_period: 20,
bull_threshold: Decimal::new(2, 2), bear_threshold: Decimal::new(-2, 2), min_confidence: Decimal::new(6, 1), use_volatility: true,
window_size: 20,
}
}
}
/// Bounded history of recorded [`RegimeState`]s plus a running count of
/// regime changes.
#[derive(Debug, Clone)]
pub struct RegimeTracker {
    /// Recorded states, oldest at the front and newest at the back.
    pub history: VecDeque<RegimeState>,
    /// Maximum number of states `record` keeps in `history`.
    max_history: usize,
    /// Number of recorded states whose regime differed from the state
    /// recorded immediately before them. Not reduced when old entries are
    /// evicted from `history`.
    pub transitions: usize,
}
impl RegimeTracker {
    /// Creates an empty tracker that retains at most `max_history` states.
    ///
    /// Storage for `max_history` entries is reserved up front.
    pub fn new(max_history: usize) -> Self {
        let history = VecDeque::with_capacity(max_history);
        Self {
            history,
            max_history,
            transitions: 0,
        }
    }

    /// Appends `state` to the history.
    ///
    /// The transition counter is bumped when the most recent entry exists
    /// and holds a different regime. Afterwards the oldest entries are
    /// dropped until the history fits within `max_history` again.
    pub fn record(&mut self, state: RegimeState) {
        let regime_changed = match self.history.back() {
            Some(previous) => previous.current_regime != state.current_regime,
            None => false,
        };
        if regime_changed {
            self.transitions += 1;
        }

        self.history.push_back(state);

        // `history` is public, so it may hold more than one surplus entry.
        while self.history.len() > self.max_history {
            self.history.pop_front();
        }
    }

    /// Returns the most recently recorded state, if any.
    pub fn current(&self) -> Option<&RegimeState> {
        self.history.back()
    }

    /// Returns the mean `duration` over all retained states, or zero when
    /// nothing has been recorded.
    pub fn average_duration(&self) -> Decimal {
        let samples = self.history.len();
        if samples == 0 {
            return Decimal::ZERO;
        }

        let mut summed: usize = 0;
        for entry in &self.history {
            summed += entry.duration;
        }
        Decimal::from(summed) / Decimal::from(samples)
    }

    /// Returns the share of recorded periods spent in each regime as
    /// `(bull, bear, sideways)`.
    ///
    /// Each state contributes its `duration`. `Normal` states are counted in
    /// the sideways bucket. When no periods are recorded all three shares
    /// are zero.
    pub fn regime_distribution(&self) -> (Decimal, Decimal, Decimal) {
        let (bull, bear, sideways) = self.history.iter().fold(
            (0usize, 0usize, 0usize),
            |(bull, bear, sideways), entry| match entry.current_regime {
                MarketRegime::Bull => (bull + entry.duration, bear, sideways),
                MarketRegime::Bear => (bull, bear + entry.duration, sideways),
                MarketRegime::Sideways | MarketRegime::Normal { .. } => {
                    (bull, bear, sideways + entry.duration)
                }
            },
        );

        // An empty history also lands here, since every bucket is then zero.
        let periods = bull + bear + sideways;
        if periods == 0 {
            return (Decimal::ZERO, Decimal::ZERO, Decimal::ZERO);
        }

        let denominator = Decimal::from(periods);
        let bull_share = Decimal::from(bull) / denominator;
        let bear_share = Decimal::from(bear) / denominator;
        let sideways_share = Decimal::from(sideways) / denominator;
        (bull_share, bear_share, sideways_share)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The unit regimes expose their fixed volatility and drift factors.
    #[test]
    fn test_market_regime_properties() {
        // (regime, expected volatility factor, expected drift factor)
        let expectations = [
            (MarketRegime::Bull, Decimal::new(8, 1), Decimal::new(5, 3)),
            (MarketRegime::Bear, Decimal::new(12, 1), Decimal::new(-5, 3)),
            (MarketRegime::Sideways, Decimal::ONE, Decimal::ZERO),
        ];

        for (regime, volatility, drift) in expectations {
            assert_eq!(regime.volatility_factor(), volatility);
            assert_eq!(regime.drift_factor(), drift);
        }
    }

    /// A fresh state starts at duration 1, counts periods, and only
    /// transitions to a different regime on a sufficiently confident signal.
    #[test]
    fn test_regime_state() {
        let initial_confidence = Decimal::new(8, 1);
        let initial_price = Decimal::new(100, 0);
        let mut state = RegimeState::new(MarketRegime::Bull, initial_confidence, 1000, initial_price);

        assert_eq!(state.current_regime, MarketRegime::Bull);
        assert_eq!(state.duration, 1);

        state.increment_duration();
        assert_eq!(state.duration, 2);

        let above_threshold = Decimal::new(7, 1);
        let below_threshold = Decimal::new(5, 1);
        assert!(state.should_transition(MarketRegime::Bear, above_threshold));
        assert!(!state.should_transition(MarketRegime::Bear, below_threshold));
    }

    /// Every change of regime between consecutive records counts once.
    #[test]
    fn test_regime_tracker() {
        let bull = RegimeState::new(MarketRegime::Bull, Decimal::new(8, 1), 1000, Decimal::new(100, 0));
        let bear = RegimeState::new(MarketRegime::Bear, Decimal::new(7, 1), 2000, Decimal::new(95, 0));
        let mut tracker = RegimeTracker::new(10);

        // (state to record, transitions expected afterwards)
        let steps = [(bull.clone(), 0), (bear, 1), (bull, 2)];
        for (state, expected_transitions) in steps {
            tracker.record(state);
            assert_eq!(tracker.transitions, expected_transitions);
        }
    }

    /// The default configuration carries the documented stock values.
    #[test]
    fn test_regime_config_default() {
        let RegimeConfig {
            lookback_period,
            bull_threshold,
            bear_threshold,
            min_confidence,
            use_volatility,
            ..
        } = RegimeConfig::default();

        assert_eq!(lookback_period, 20);
        assert_eq!(bull_threshold, Decimal::new(2, 2));
        assert_eq!(bear_threshold, Decimal::new(-2, 2));
        assert_eq!(min_confidence, Decimal::new(6, 1));
        assert!(use_volatility);
    }
}