use rust_decimal::Decimal;
use rust_decimal::prelude::FromPrimitive;
use crate::config::{GeneratorConfig, TrendDirection};
use crate::regimes::MarketRegime;
use std::collections::VecDeque;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// One contiguous stretch of a [`RegimeSchedule`] during which a single market regime applies.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RegimeSegment {
/// The regime in effect while this segment is current.
pub regime: MarketRegime,
/// Length of the segment in schedule steps: `RegimeSchedule::advance` moves past it on the
/// `max(duration, 1)`-th call made while it is current.
pub duration: usize,
/// Generator parameters for this segment; `RegimeController` takes its `trend_direction`,
/// `trend_strength` and `volatility` and layers them over its base config.
pub config: GeneratorConfig,
/// Optional number of steps over which `RegimeController` blends into `config` when this
/// segment is entered; `None` or `Some(0)` means switch immediately.
pub transition_duration: Option<usize>,
}
impl RegimeSegment {
    /// Builds a segment for `regime` lasting `duration` steps, parameterised with that
    /// regime's built-in defaults and no transition.
    pub fn new(regime: MarketRegime, duration: usize) -> Self {
        Self::with_config(regime, duration, Self::default_config_for_regime(regime))
    }

    /// Builds a segment that uses the caller-supplied `config` instead of the regime defaults.
    pub fn with_config(regime: MarketRegime, duration: usize, config: GeneratorConfig) -> Self {
        Self {
            transition_duration: None,
            regime,
            duration,
            config,
        }
    }

    /// Builder-style setter: asks the controller to blend into this segment over
    /// `transition_duration` steps instead of switching instantly.
    pub fn with_transition(self, transition_duration: usize) -> Self {
        Self {
            transition_duration: Some(transition_duration),
            ..self
        }
    }

    /// Default generator parameters for each regime, layered over `GeneratorConfig::default()`.
    fn default_config_for_regime(regime: MarketRegime) -> GeneratorConfig {
        let mut config = GeneratorConfig::default();
        // (trend direction, trend strength, volatility) per regime; e.g. Bull = (Bullish, 0.005, 0.015).
        let (direction, strength, volatility) = match regime {
            MarketRegime::Bull => (
                TrendDirection::Bullish,
                Decimal::new(5, 3),
                Decimal::new(15, 3),
            ),
            MarketRegime::Bear => (
                TrendDirection::Bearish,
                Decimal::new(7, 3),
                Decimal::new(25, 3),
            ),
            MarketRegime::Sideways => (
                TrendDirection::Sideways,
                Decimal::ZERO,
                Decimal::new(10, 3),
            ),
            // Parameters that cannot be represented as a Decimal fall back to fixed defaults.
            MarketRegime::Normal { std_dev, bias, .. } => (
                TrendDirection::Sideways,
                Decimal::try_from(bias.unwrap_or(0.0)).unwrap_or(Decimal::ZERO),
                Decimal::try_from(std_dev).unwrap_or(Decimal::new(15, 3)),
            ),
        };
        config.trend_direction = direction;
        config.trend_strength = strength;
        config.volatility = volatility;
        config
    }
}
/// An ordered sequence of [`RegimeSegment`]s advanced one step at a time, optionally repeating.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RegimeSchedule {
/// Segments not yet finished; the front is the current segment.
segments: VecDeque<RegimeSegment>,
/// Steps taken within the current segment; reset to 0 whenever a segment finishes.
current_segment_progress: usize,
/// Count of `advance` calls since construction or the last `reset`.
total_progress: usize,
/// When `true`, `segments` is refilled from `original_segments` after the last one finishes.
repeat: bool,
/// Every segment the schedule was built with or given via `add_segment`; the source for
/// `reset` and for repetition.
original_segments: Vec<RegimeSegment>,
}
impl RegimeSchedule {
    /// Creates a one-shot schedule that runs `segments` in order and then completes.
    pub fn new(segments: Vec<RegimeSegment>) -> Self {
        let original_segments = segments.clone();
        Self {
            segments: VecDeque::from(segments),
            current_segment_progress: 0,
            total_progress: 0,
            repeat: false,
            original_segments,
        }
    }

    /// Creates a schedule that restarts from its first segment after the last one finishes.
    pub fn repeating(segments: Vec<RegimeSegment>) -> Self {
        let mut schedule = Self::new(segments);
        schedule.repeat = true;
        schedule
    }

    /// Returns the segment currently in effect, or `None` once the schedule is exhausted.
    pub fn current_segment(&self) -> Option<&RegimeSegment> {
        self.segments.front()
    }

    /// Advances by one step and returns the segment in effect afterwards.
    ///
    /// A segment of `duration` d is left on the d-th call made while it is current (a
    /// zero-duration segment still occupies one call). When the last segment finishes, a
    /// repeating schedule refills itself from the original segment list.
    ///
    /// `total_progress` counts every call, including calls made after the schedule is exhausted.
    pub fn advance(&mut self) -> Option<&RegimeSegment> {
        self.total_progress += 1;
        let duration = match self.segments.front() {
            Some(segment) => segment.duration,
            // Exhausted: do not accumulate per-segment progress against a segment that does
            // not exist. Otherwise a segment appended later via `add_segment` would start
            // partially (or fully) consumed.
            None => return None,
        };
        self.current_segment_progress += 1;
        if self.current_segment_progress >= duration {
            self.segments.pop_front();
            self.current_segment_progress = 0;
            if self.segments.is_empty() && self.repeat {
                self.segments.extend(self.original_segments.iter().cloned());
            }
        }
        self.segments.front()
    }

    /// Fraction of the current segment that has elapsed, in `[0.0, 1.0)` while a segment with
    /// a positive duration is running. Returns `1.0` for a zero-duration segment or when the
    /// schedule is exhausted.
    pub fn current_segment_progress(&self) -> f64 {
        match self.segments.front() {
            Some(segment) if segment.duration > 0 => {
                self.current_segment_progress as f64 / segment.duration as f64
            }
            _ => 1.0,
        }
    }

    /// Number of `advance` calls since construction or the last `reset`.
    pub fn total_progress(&self) -> usize {
        self.total_progress
    }

    /// Returns `true` when no segment remains and none ever will. This covers a one-shot
    /// schedule that has run out, and a repeating schedule with nothing to repeat. The second
    /// case previously reported "not complete" forever.
    pub fn is_complete(&self) -> bool {
        self.segments.is_empty() && (!self.repeat || self.original_segments.is_empty())
    }

    /// Restores the full original segment list and zeroes both progress counters.
    pub fn reset(&mut self) {
        self.segments = VecDeque::from(self.original_segments.clone());
        self.current_segment_progress = 0;
        self.total_progress = 0;
    }

    /// Appends `segment` to the pending queue and to the list used by `reset` and repetition.
    pub fn add_segment(&mut self, segment: RegimeSegment) {
        self.segments.push_back(segment.clone());
        self.original_segments.push(segment);
    }

    /// Segments still queued, starting with the current one.
    pub fn remaining_segments(&self) -> Vec<&RegimeSegment> {
        self.segments.iter().collect()
    }

    /// Sum of all segment durations for a single pass (ignores repetition).
    pub fn total_duration(&self) -> usize {
        self.original_segments.iter().map(|s| s.duration).sum()
    }
}
/// An in-progress blend from one generator configuration to another over a fixed number of steps.
#[derive(Debug, Clone)]
pub struct TransitionState {
/// Configuration at the start of the blend; also supplies every field that is not interpolated.
pub from_config: GeneratorConfig,
/// Configuration the blend moves toward.
pub to_config: GeneratorConfig,
/// Fraction completed, `current_step / duration`, updated by `advance`.
pub progress: f64,
/// Total number of steps in the blend.
pub duration: usize,
/// Steps taken so far.
pub current_step: usize,
}
impl TransitionState {
pub fn new(from_config: GeneratorConfig, to_config: GeneratorConfig, duration: usize) -> Self {
Self {
from_config,
to_config,
progress: 0.0,
duration,
current_step: 0,
}
}
pub fn advance(&mut self) -> bool {
if self.current_step < self.duration {
self.current_step += 1;
self.progress = self.current_step as f64 / self.duration as f64;
true
} else {
false
}
}
pub fn interpolated_config(&self) -> GeneratorConfig {
let mut config = self.from_config.clone();
let from_strength = self.from_config.trend_strength;
let to_strength = self.to_config.trend_strength;
config.trend_strength = from_strength + (to_strength - from_strength) * Decimal::from_f64(self.progress).unwrap_or(Decimal::ZERO);
let from_vol = self.from_config.volatility;
let to_vol = self.to_config.volatility;
config.volatility = from_vol + (to_vol - from_vol) * Decimal::from_f64(self.progress).unwrap_or(Decimal::ZERO);
if self.progress >= 0.5 {
config.trend_direction = self.to_config.trend_direction;
}
config
}
pub fn is_complete(&self) -> bool {
self.current_step >= self.duration
}
}
/// Drives a [`RegimeSchedule`] and tracks the [`GeneratorConfig`] to use at each step,
/// including gradual transitions between segments.
#[derive(Debug)]
pub struct RegimeController {
/// The schedule being followed.
schedule: RegimeSchedule,
/// Configuration in effect for the current step.
current_config: GeneratorConfig,
/// Blend currently in progress, if any.
transition: Option<TransitionState>,
/// Configuration onto which each segment's direction, strength and volatility are layered.
base_config: GeneratorConfig,
}
impl RegimeController {
pub fn new(schedule: RegimeSchedule, base_config: GeneratorConfig) -> Self {
let current_config = if let Some(segment) = schedule.current_segment() {
Self::merge_configs(&base_config, &segment.config)
} else {
base_config.clone()
};
Self {
schedule,
current_config,
transition: None,
base_config,
}
}
pub fn current_config(&self) -> &GeneratorConfig {
&self.current_config
}
pub fn current_regime(&self) -> Option<MarketRegime> {
self.schedule.current_segment().map(|s| s.regime)
}
pub fn advance(&mut self) -> bool {
if let Some(ref mut transition) = self.transition {
if transition.advance() {
self.current_config = transition.interpolated_config();
if transition.is_complete() {
self.transition = None;
}
}
}
let previous_regime = self.current_regime();
let current_segment = self.schedule.advance();
let new_regime = current_segment.map(|s| s.regime);
if previous_regime != new_regime {
if let Some(segment) = current_segment {
let new_config = Self::merge_configs(&self.base_config, &segment.config);
if let Some(transition_duration) = segment.transition_duration {
if transition_duration > 0 {
self.transition = Some(TransitionState::new(
self.current_config.clone(),
new_config,
transition_duration,
));
} else {
self.current_config = new_config;
}
} else {
self.current_config = new_config;
}
}
}
!self.schedule.is_complete()
}
fn merge_configs(base: &GeneratorConfig, regime_specific: &GeneratorConfig) -> GeneratorConfig {
let mut config = base.clone();
config.trend_direction = regime_specific.trend_direction;
config.trend_strength = regime_specific.trend_strength;
config.volatility = regime_specific.volatility;
config
}
pub fn schedule_info(&self) -> ScheduleInfo {
ScheduleInfo {
current_regime: self.current_regime(),
current_segment_progress: self.schedule.current_segment_progress(),
total_progress: self.schedule.total_progress(),
is_complete: self.schedule.is_complete(),
remaining_segments: self.schedule.remaining_segments().len(),
in_transition: self.transition.is_some(),
}
}
pub fn set_schedule(&mut self, schedule: RegimeSchedule) {
self.schedule = schedule;
if let Some(segment) = self.schedule.current_segment() {
self.current_config = Self::merge_configs(&self.base_config, &segment.config);
}
self.transition = None;
}
pub fn add_segment(&mut self, segment: RegimeSegment) {
self.schedule.add_segment(segment);
}
pub fn reset(&mut self) {
self.schedule.reset();
self.transition = None;
if let Some(segment) = self.schedule.current_segment() {
self.current_config = Self::merge_configs(&self.base_config, &segment.config);
}
}
pub fn force_regime(&mut self, regime: MarketRegime, duration: usize, transition_duration: Option<usize>) {
let segment = RegimeSegment::new(regime, duration)
.with_transition(transition_duration.unwrap_or(0));
let new_schedule = RegimeSchedule::new(vec![segment]);
self.set_schedule(new_schedule);
}
}
/// Point-in-time summary of a [`RegimeController`]'s progress, as returned by `schedule_info`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ScheduleInfo {
/// Regime of the current segment, or `None` when the schedule has no current segment.
pub current_regime: Option<MarketRegime>,
/// Fraction of the current segment elapsed (see `RegimeSchedule::current_segment_progress`).
pub current_segment_progress: f64,
/// Total `advance` steps counted by the schedule.
pub total_progress: usize,
/// Whether the schedule has finished.
pub is_complete: bool,
/// Number of segments still queued, including the current one.
pub remaining_segments: usize,
/// Whether a configuration transition is in progress.
pub in_transition: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances `schedule` by `steps`, discarding the intermediate segments.
    fn advance_schedule(schedule: &mut RegimeSchedule, steps: usize) {
        (0..steps).for_each(|_| {
            schedule.advance();
        });
    }

    /// Regime of the schedule's current segment; panics if the schedule is exhausted.
    fn front_regime(schedule: &RegimeSchedule) -> MarketRegime {
        schedule
            .current_segment()
            .map(|segment| segment.regime)
            .expect("schedule should have a current segment")
    }

    #[test]
    fn test_regime_segment_creation() {
        let bull = RegimeSegment::new(MarketRegime::Bull, 100);
        assert_eq!(bull.regime, MarketRegime::Bull);
        assert_eq!(bull.duration, 100);
        assert_eq!(bull.config.trend_direction, TrendDirection::Bullish);
    }

    #[test]
    fn test_regime_schedule_basic() {
        let mut schedule = RegimeSchedule::new(vec![
            RegimeSegment::new(MarketRegime::Bull, 50),
            RegimeSegment::new(MarketRegime::Bear, 30),
        ]);
        assert_eq!(front_regime(&schedule), MarketRegime::Bull);
        // One step short of the 50-step boundary: still bullish.
        advance_schedule(&mut schedule, 49);
        assert_eq!(front_regime(&schedule), MarketRegime::Bull);
        // The 50th step finishes the bull segment.
        advance_schedule(&mut schedule, 1);
        assert_eq!(front_regime(&schedule), MarketRegime::Bear);
    }

    #[test]
    fn test_regime_schedule_completion() {
        let mut schedule = RegimeSchedule::new(vec![RegimeSegment::new(MarketRegime::Bull, 2)]);
        assert!(!schedule.is_complete());
        schedule.advance();
        assert!(!schedule.is_complete());
        schedule.advance();
        assert!(schedule.is_complete());
    }

    #[test]
    fn test_regime_schedule_repeating() {
        let mut schedule = RegimeSchedule::repeating(vec![
            RegimeSegment::new(MarketRegime::Bull, 1),
            RegimeSegment::new(MarketRegime::Bear, 1),
        ]);
        let expected = [MarketRegime::Bull, MarketRegime::Bear, MarketRegime::Bull];
        for (step, regime) in expected.iter().enumerate() {
            if step > 0 {
                schedule.advance();
            }
            assert_eq!(front_regime(&schedule), *regime);
        }
        // Wrapping around keeps the repeating schedule alive.
        assert!(!schedule.is_complete());
    }

    #[test]
    fn test_transition_state() {
        let start = GeneratorConfig::default();
        let mut target = GeneratorConfig::default();
        target.volatility = Decimal::new(2, 1);
        let mut transition = TransitionState::new(start, target, 4);
        assert_eq!(transition.progress, 0.0);
        transition.advance();
        assert_eq!(transition.progress, 0.25);
        // Assumes the default volatility is 0.02 (Decimal::new(2, 2)).
        let quarter = Decimal::new(25, 2);
        let expected = Decimal::new(2, 2) + (Decimal::new(2, 1) - Decimal::new(2, 2)) * quarter;
        assert_eq!(transition.interpolated_config().volatility, expected);
    }

    #[test]
    fn test_regime_controller_basic() {
        let schedule = RegimeSchedule::new(vec![
            RegimeSegment::new(MarketRegime::Bull, 3),
            RegimeSegment::new(MarketRegime::Bear, 2),
        ]);
        let mut controller = RegimeController::new(schedule, GeneratorConfig::default());
        assert_eq!(controller.current_regime(), Some(MarketRegime::Bull));
        assert_eq!(controller.current_config().trend_direction, TrendDirection::Bullish);
        (0..3).for_each(|_| {
            controller.advance();
        });
        assert_eq!(controller.current_regime(), Some(MarketRegime::Bear));
        assert_eq!(controller.current_config().trend_direction, TrendDirection::Bearish);
    }
}