1use super::primitives::{ADX, ATR, BollingerBands, BollingerBandsValues, EMA};
11use super::types::{
12 MarketRegime, RecommendedStrategy, RegimeConfidence, RegimeConfig, TrendDirection,
13};
14use std::collections::VecDeque;
15
/// Streaming market-regime detector.
///
/// Each call to `update` feeds one bar (high, low, close) into an ADX, an ATR,
/// Bollinger Bands and two EMAs, scores the resulting snapshot, and tracks the
/// regime that survives the stability filter.
#[derive(Debug)]
pub struct RegimeDetector {
    // Indicator periods and scoring/stability thresholds.
    config: RegimeConfig,

    // ADX indicator, fed high/low/close on every update.
    adx: ADX,
    // ATR indicator, fed high/low/close on every update.
    atr: ATR,
    // `atr_avg`: EMA over the ATR outputs, used as the baseline for ATR expansion.
    // `bb`: Bollinger Bands fed the close on every update.
    atr_avg: EMA, bb: BollingerBands,
    // Short EMA of the close.
    ema_short: EMA,
    // Long EMA of the close.
    ema_long: EMA,

    // Regime currently held after stability filtering.
    current_regime: MarketRegime,
    // Regimes that have been replaced, oldest first; `update` trims it to 20 entries.
    regime_history: VecDeque<MarketRegime>,
    // Updates since the held regime last changed (reset to 0 on a change).
    bars_in_regime: usize,

    // Close of the most recent bar passed to `update`.
    last_close: Option<f64>,
}
59
60impl RegimeDetector {
61 pub fn new(config: RegimeConfig) -> Self {
63 Self {
64 adx: ADX::new(config.adx_period),
65 atr: ATR::new(config.atr_period),
66 atr_avg: EMA::new(50), bb: BollingerBands::new(config.bb_period, config.bb_std_dev),
68 ema_short: EMA::new(config.ema_short_period),
69 ema_long: EMA::new(config.ema_long_period),
70 current_regime: MarketRegime::Uncertain,
71 regime_history: VecDeque::with_capacity(20),
72 bars_in_regime: 0,
73 last_close: None,
74 config,
75 }
76 }
77
78 pub fn default_config() -> Self {
80 Self::new(RegimeConfig::default())
81 }
82
83 pub fn crypto_optimized() -> Self {
85 Self::new(RegimeConfig::crypto_optimized())
86 }
87
88 pub fn conservative() -> Self {
90 Self::new(RegimeConfig::conservative())
91 }
92
93 pub fn update(&mut self, high: f64, low: f64, close: f64) -> RegimeConfidence {
102 let adx_value = self.adx.update(high, low, close);
104 let atr_value = self.atr.update(high, low, close);
105 let bb_values = self.bb.update(close);
106 let ema_short = self.ema_short.update(close);
107 let ema_long = self.ema_long.update(close);
108
109 if let Some(atr) = atr_value {
111 self.atr_avg.update(atr);
112 }
113
114 self.last_close = Some(close);
115
116 if !self.is_ready() {
118 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
119 }
120
121 let (new_regime, confidence) = self.classify_regime(
123 adx_value.unwrap(),
124 atr_value.unwrap(),
125 bb_values.as_ref().unwrap(),
126 ema_short.unwrap(),
127 ema_long.unwrap(),
128 close,
129 );
130
131 let stable_regime = self.apply_stability_filter(new_regime, confidence);
133
134 if stable_regime != self.current_regime {
136 self.regime_history.push_back(self.current_regime);
137 if self.regime_history.len() > 20 {
138 self.regime_history.pop_front();
139 }
140 self.current_regime = stable_regime;
141 self.bars_in_regime = 0;
142 } else {
143 self.bars_in_regime += 1;
144 }
145
146 RegimeConfidence::with_metrics(
147 stable_regime,
148 confidence,
149 adx_value.unwrap(),
150 bb_values.as_ref().map_or(50.0, |b| b.width_percentile),
151 Self::calculate_trend_strength(ema_short.unwrap(), ema_long.unwrap(), close),
152 )
153 }
154
155 fn classify_regime(
164 &self,
165 adx: f64,
166 atr: f64,
167 bb: &BollingerBandsValues,
168 ema_short: f64,
169 ema_long: f64,
170 close: f64,
171 ) -> (MarketRegime, f64) {
172 let atr_expansion = if let Some(avg_atr) = self.atr_avg.value() {
174 atr / avg_atr
175 } else {
176 1.0
177 };
178
179 let mut trending_score: f64 = 0.0;
181 let mut ranging_score: f64 = 0.0;
182 let mut volatile_score: f64 = 0.0;
183
184 if adx >= self.config.adx_trending_threshold {
186 trending_score += 0.4;
187 } else if adx <= self.config.adx_ranging_threshold {
188 ranging_score += 0.3;
189 }
190
191 if bb.is_high_volatility(self.config.bb_width_volatility_threshold) {
193 volatile_score += 0.3;
194 }
195 if bb.is_squeeze(25.0) {
196 ranging_score += 0.2; }
198
199 if atr_expansion >= self.config.atr_expansion_threshold {
201 volatile_score += 0.3;
202 } else if atr_expansion < 0.8 {
203 ranging_score += 0.2; }
205
206 let ema_diff_pct = ((ema_short - ema_long) / ema_long).abs() * 100.0;
208 if ema_diff_pct > 2.0 {
209 trending_score += 0.3;
210 } else if ema_diff_pct < 1.0 {
211 ranging_score += 0.2;
212 }
213
214 let price_above_both = close > ema_short && close > ema_long;
216 let price_below_both = close < ema_short && close < ema_long;
217 if price_above_both || price_below_both {
218 trending_score += 0.2;
219 } else {
220 ranging_score += 0.2; }
222
223 let max_score = trending_score.max(ranging_score).max(volatile_score);
225 let confidence = max_score / 1.2; let regime = if volatile_score >= 0.5 && volatile_score >= trending_score {
228 MarketRegime::Volatile
229 } else if trending_score > ranging_score && trending_score > 0.3 {
230 let direction = if ema_short > ema_long && close > ema_long {
232 TrendDirection::Bullish
233 } else if ema_short < ema_long && close < ema_long {
234 TrendDirection::Bearish
235 } else if let Some(dir) = self.adx.trend_direction() {
236 dir
237 } else {
238 TrendDirection::Bullish };
240 MarketRegime::Trending(direction)
241 } else if ranging_score > 0.3 {
242 MarketRegime::MeanReverting
243 } else {
244 MarketRegime::Uncertain
245 };
246
247 (regime, confidence.min(1.0))
248 }
249
250 fn apply_stability_filter(&self, new_regime: MarketRegime, confidence: f64) -> MarketRegime {
257 if confidence < 0.4 {
259 return self.current_regime;
260 }
261
262 if self.bars_in_regime < self.config.min_regime_duration
264 && new_regime != self.current_regime
265 {
266 if confidence < 0.7 {
268 return self.current_regime;
269 }
270 }
271
272 let recent_count = self
274 .regime_history
275 .iter()
276 .rev()
277 .take(self.config.regime_stability_bars)
278 .filter(|&&r| {
279 matches!(
280 (&r, &new_regime),
281 (MarketRegime::Trending(_), MarketRegime::Trending(_))
282 | (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
283 | (MarketRegime::Volatile, MarketRegime::Volatile)
284 )
285 })
286 .count();
287
288 if recent_count < self.config.regime_stability_bars / 2 && confidence < 0.6 {
290 return self.current_regime;
291 }
292
293 new_regime
294 }
295
/// Scores trend strength in `[0, 1]` from EMA separation and price position.
///
/// The EMA gap, as a percentage of the long EMA's magnitude, is weighted by
/// where `close` sits: 1.0 above both EMAs, 0.7 below both, 0.5 in between.
/// A weighted gap of 5% or more saturates at 1.0.
///
/// # Returns
/// `0.0` when the relative gap cannot be computed (zero long EMA or
/// non-finite input); otherwise a value in `[0, 1]`.
fn calculate_trend_strength(ema_short: f64, ema_long: f64, close: f64) -> f64 {
    // Divide by the magnitude so a negative long EMA cannot flip the score negative.
    let ema_alignment = (ema_short - ema_long).abs() / ema_long.abs() * 100.0;

    // A zero long EMA gives NaN or infinity, and `f64::min` would turn NaN
    // into 1.0, so report "no measurable trend" instead.
    if !ema_alignment.is_finite() {
        return 0.0;
    }

    let price_position = if close > ema_short && close > ema_long {
        1.0
    } else if close < ema_short && close < ema_long {
        0.7
    } else {
        0.5
    };

    (ema_alignment * price_position / 5.0).min(1.0)
}
309
310 pub fn is_ready(&self) -> bool {
319 self.adx.is_ready()
320 && self.atr.is_ready()
321 && self.bb.is_ready()
322 && self.ema_short.is_ready()
323 && self.ema_long.is_ready()
324 }
325
/// Returns the regime the detector currently holds.
pub fn current_regime(&self) -> MarketRegime {
    self.current_regime
}
330
331 pub fn recommended_strategy(&self) -> RecommendedStrategy {
333 RecommendedStrategy::from(&self.current_regime)
334 }
335
/// Returns the number of updates since the held regime last changed.
pub fn bars_in_current_regime(&self) -> usize {
    self.bars_in_regime
}
340
/// Returns the ADX indicator's current value, or `None` if it has none yet.
pub fn adx_value(&self) -> Option<f64> {
    self.adx.value()
}
345
/// Returns the ATR indicator's current value, or `None` if it has none yet.
pub fn atr_value(&self) -> Option<f64> {
    self.atr.value()
}
350
/// Borrows the configuration the detector was built with.
pub fn config(&self) -> &RegimeConfig {
    &self.config
}
355
356 pub fn set_config(&mut self, config: RegimeConfig) {
358 *self = Self::new(config);
359 }
360
/// Borrows the queue of previously held regimes, oldest first.
pub fn regime_history(&self) -> &VecDeque<MarketRegime> {
    &self.regime_history
}
365
/// Returns the close of the most recent bar passed to `update`, if any.
pub fn last_close(&self) -> Option<f64> {
    self.last_close
}
370}
371
372#[cfg(test)]
377mod tests {
378 use super::*;
379
/// Builds `bars` synthetic `(high, low, close)` bars that drift by roughly
/// `trend_strength` per bar, starting from `start_price`.
///
/// Each step is `trend_strength` scaled by a jitter factor in `[0.9, 1.1)`.
/// The jitter comes from a fixed-seed linear congruential generator, so every
/// run sees exactly the same series and the tests built on it are
/// reproducible. High and low sit 0.5% above and below the close.
fn generate_trending_data(
    bars: usize,
    start_price: f64,
    trend_strength: f64,
) -> Vec<(f64, f64, f64)> {
    // 64-bit LCG constants (Knuth's MMIX generator).
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;
    // 2^53: scales the top 53 bits of the state into [0, 1).
    const UNIT_SCALE: f64 = 9_007_199_254_740_992.0;

    let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
    let mut price = start_price;
    let mut data = Vec::with_capacity(bars);

    for _ in 0..bars {
        state = state
            .wrapping_mul(LCG_MULTIPLIER)
            .wrapping_add(LCG_INCREMENT);
        let unit = (state >> 11) as f64 / UNIT_SCALE;

        let change = trend_strength * (1.0 + (unit - 0.5) * 0.2);
        price += change;

        let high = price + price * 0.005;
        let low = price - price * 0.005;
        let close = price;

        data.push((high, low, close));
    }

    data
}
402
/// Builds `bars` synthetic `(high, low, close)` bars oscillating around
/// `center_price`.
///
/// The close follows a sine wave (phase `i * 0.5`) with amplitude
/// `center_price * range_pct / 100`; high and low sit 0.2% above and below it.
fn generate_ranging_data(
    bars: usize,
    center_price: f64,
    range_pct: f64,
) -> Vec<(f64, f64, f64)> {
    (0..bars)
        .map(|i| {
            let offset = (i as f64 * 0.5).sin() * center_price * range_pct / 100.0;
            let close = center_price + offset;
            let band = close * 0.002;
            (close + band, close - band, close)
        })
        .collect()
}
424
/// Builds `bars` synthetic `(high, low, close)` bars that alternate between
/// 5% above and 5% below `center_price`, with high and low 3% either side of
/// the close.
// No test in this module calls this fixture, hence the lint override.
#[allow(dead_code)]
fn generate_volatile_data(bars: usize, center_price: f64) -> Vec<(f64, f64, f64)> {
    (0..bars)
        .map(|i| {
            let swing = if i % 2 == 0 { 1.05 } else { 0.95 };
            let close = center_price * swing;
            (close * 1.03, close * 0.97, close)
        })
        .collect()
}
443
/// A freshly built detector is not ready, holds `Uncertain`, and has counted no bars.
#[test]
fn test_detector_creation() {
    let fresh = RegimeDetector::default_config();

    assert!(!fresh.is_ready());
    assert_eq!(fresh.current_regime(), MarketRegime::Uncertain);
    assert_eq!(fresh.bars_in_current_regime(), 0);
}
451
/// The crypto preset starts cold and exposes its expected ADX threshold and short EMA period.
#[test]
fn test_crypto_optimized_creation() {
    let detector = RegimeDetector::crypto_optimized();
    assert!(!detector.is_ready());

    let cfg = detector.config();
    assert_eq!(cfg.adx_trending_threshold, 20.0);
    assert_eq!(cfg.ema_short_period, 21);
}
459
/// The conservative preset exposes its expected ADX threshold and minimum regime duration.
#[test]
fn test_conservative_creation() {
    let detector = RegimeDetector::conservative();
    let cfg = detector.config();

    assert_eq!(cfg.adx_trending_threshold, 30.0);
    assert_eq!(cfg.min_regime_duration, 10);
}
466
/// Ten bars are not enough to warm up: every result is `Uncertain` with zero confidence.
#[test]
fn test_warmup_returns_uncertain() {
    let mut detector = RegimeDetector::default_config();

    (0..10).map(|i| 100.0 + i as f64).for_each(|price| {
        let outcome = detector.update(price + 1.0, price - 1.0, price);
        assert_eq!(outcome.regime, MarketRegime::Uncertain);
        assert_eq!(outcome.confidence, 0.0);
    });

    assert!(!detector.is_ready());
}
481
/// A steady 300-bar uptrend must end in a `Trending` regime.
#[test]
fn test_trending_detection() {
    let mut detector = RegimeDetector::default_config();

    // Track the regime reported by the latest bar on which the detector was ready.
    let last_regime = generate_trending_data(300, 100.0, 0.5).into_iter().fold(
        MarketRegime::Uncertain,
        |latest, (high, low, close)| {
            let outcome = detector.update(high, low, close);
            if detector.is_ready() {
                outcome.regime
            } else {
                latest
            }
        },
    );

    assert!(
        matches!(last_regime, MarketRegime::Trending(_)),
        "Expected Trending regime, got: {last_regime:?}"
    );
}
502
/// A steady 300-bar uptrend must be classified as a bullish trend.
#[test]
fn test_trending_bullish_direction() {
    let mut detector = RegimeDetector::default_config();

    // Track the regime reported by the latest bar on which the detector was ready.
    let last_regime = generate_trending_data(300, 100.0, 0.5).into_iter().fold(
        MarketRegime::Uncertain,
        |latest, (high, low, close)| {
            let outcome = detector.update(high, low, close);
            if detector.is_ready() {
                outcome.regime
            } else {
                latest
            }
        },
    );

    assert!(
        matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish)),
        "Expected Bullish trend, got: {last_regime:?}"
    );
}
523
/// If a steady 300-bar downtrend is classified as trending, the direction must be bearish.
#[test]
fn test_trending_bearish_direction() {
    let mut detector = RegimeDetector::default_config();

    // Track the regime reported by the latest bar on which the detector was ready.
    let last_regime = generate_trending_data(300, 200.0, -0.5).into_iter().fold(
        MarketRegime::Uncertain,
        |latest, (high, low, close)| {
            let outcome = detector.update(high, low, close);
            if detector.is_ready() {
                outcome.regime
            } else {
                latest
            }
        },
    );

    // Only the direction is checked; a non-trending outcome passes.
    let is_trending = matches!(last_regime, MarketRegime::Trending(_));
    if is_trending {
        assert!(
            matches!(last_regime, MarketRegime::Trending(TrendDirection::Bearish)),
            "Expected Bearish trend, got: {last_regime:?}"
        );
    }
}
547
/// Oscillating data must not end in a bullish trend.
#[test]
fn test_ranging_detection() {
    let mut detector = RegimeDetector::default_config();

    // Track the regime reported by the latest bar on which the detector was ready.
    let last_regime = generate_ranging_data(300, 100.0, 2.0).into_iter().fold(
        MarketRegime::Uncertain,
        |latest, (high, low, close)| {
            let outcome = detector.update(high, low, close);
            if detector.is_ready() {
                outcome.regime
            } else {
                latest
            }
        },
    );

    let is_bullish_trend = matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish));
    assert!(
        !is_bullish_trend,
        "Ranging data shouldn't produce strong bullish trend, got: {last_regime:?}"
    );
}
569
/// Every result, during warm-up and after, reports a confidence within `[0, 1]`.
#[test]
fn test_confidence_range() {
    let mut detector = RegimeDetector::default_config();
    let valid = 0.0..=1.0;

    for bar in generate_trending_data(300, 100.0, 0.5) {
        let (high, low, close) = bar;
        let outcome = detector.update(high, low, close);
        assert!(
            valid.contains(&outcome.confidence),
            "Confidence should be in [0, 1]: {}",
            outcome.confidence
        );
    }
}
585
/// Feeding a trend followed by a range never lets the history exceed 20 entries.
#[test]
fn test_regime_history_tracking() {
    let mut detector = RegimeDetector::default_config();

    // Trend first, then a range centred where the trend roughly ended.
    let bars = generate_trending_data(200, 100.0, 0.5)
        .into_iter()
        .chain(generate_ranging_data(200, 200.0, 1.0));

    for (high, low, close) in bars {
        detector.update(high, low, close);
    }

    let history_len = detector.regime_history().len();
    assert!(history_len <= 20, "History should be bounded");
}
608
/// When the detector ends up trending, it must recommend trend following.
#[test]
fn test_recommended_strategy() {
    let mut detector = RegimeDetector::default_config();

    generate_trending_data(300, 100.0, 0.5)
        .into_iter()
        .for_each(|(high, low, close)| {
            detector.update(high, low, close);
        });

    // Only checked when the regime is trending; other outcomes pass.
    let is_trending = matches!(detector.current_regime(), MarketRegime::Trending(_));
    if is_trending {
        assert_eq!(
            detector.recommended_strategy(),
            RecommendedStrategy::TrendFollowing
        );
    }
}
626
/// ADX and ATR accessors are empty on a fresh detector and populated after 300 bars.
#[test]
fn test_adx_atr_accessors() {
    let mut detector = RegimeDetector::default_config();

    assert!(detector.adx_value().is_none());
    assert!(detector.atr_value().is_none());

    generate_trending_data(300, 100.0, 0.5)
        .into_iter()
        .for_each(|(high, low, close)| {
            detector.update(high, low, close);
        });

    assert!(detector.adx_value().is_some());
    assert!(detector.atr_value().is_some());
}
645
/// Swapping the configuration on a warmed-up detector returns it to its initial state.
#[test]
fn test_set_config_resets_state() {
    let mut detector = RegimeDetector::default_config();

    generate_trending_data(300, 100.0, 0.5)
        .into_iter()
        .for_each(|(high, low, close)| {
            detector.update(high, low, close);
        });
    assert!(detector.is_ready());

    let replacement = RegimeConfig::crypto_optimized();
    detector.set_config(replacement);

    assert!(!detector.is_ready());
    assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
    assert_eq!(detector.bars_in_current_regime(), 0);
}
663
/// `last_close` is empty initially and mirrors the close of each bar fed in.
#[test]
fn test_last_close_tracking() {
    let mut detector = RegimeDetector::default_config();

    assert!(detector.last_close().is_none());

    for (high, low, close) in [(101.0, 99.0, 100.0), (106.0, 104.0, 105.0)] {
        detector.update(high, low, close);
        assert_eq!(detector.last_close(), Some(close));
    }
}
676
/// After 300 bars of a smooth linear rise the bar counter for the held regime is non-zero.
#[test]
fn test_bars_in_regime_increments() {
    let mut detector = RegimeDetector::default_config();

    (0..300)
        .map(|i| 100.0 + i as f64 * 0.3)
        .for_each(|price| {
            detector.update(price + 1.0, price - 1.0, price);
        });

    let bars_held = detector.bars_in_current_regime();
    assert!(
        bars_held > 0,
        "Should have been in current regime for multiple bars (regime: {:?})",
        detector.current_regime()
    );
}
694
/// Establishes a trend, then feeds three ranging bars through a detector
/// configured with a 10-bar minimum regime duration and a 5-bar stability window.
///
/// NOTE(review): the final assertion accepts every variant it lists, and the
/// regime captured before the ranging bars is not compared against anything,
/// so this test does not constrain what the stability filter does. Consider
/// asserting a concrete expectation once the intended outcome is confirmed.
#[test]
fn test_stability_filter_prevents_whipsaw() {
    let config = RegimeConfig {
        min_regime_duration: 10,
        regime_stability_bars: 5,
        ..RegimeConfig::default()
    };
    let mut detector = RegimeDetector::new(config);

    generate_trending_data(300, 100.0, 0.5)
        .into_iter()
        .for_each(|(high, low, close)| {
            detector.update(high, low, close);
        });

    // Captured for reference only; see the review note above.
    let _regime_before = detector.current_regime();

    generate_ranging_data(3, 250.0, 1.0)
        .into_iter()
        .for_each(|(high, low, close)| {
            detector.update(high, low, close);
        });

    let regime_after = detector.current_regime();
    let is_known_variant = matches!(
        regime_after,
        MarketRegime::Trending(_)
            | MarketRegime::MeanReverting
            | MarketRegime::Volatile
            | MarketRegime::Uncertain
    );

    assert!(
        is_known_variant,
        "Regime should be a valid variant: {regime_after:?}"
    );
}
733
/// After warm-up the returned metrics are populated and within their documented ranges.
#[test]
fn test_metrics_populated_after_warmup() {
    let mut detector = RegimeDetector::default_config();

    // Keep only the result of the final bar.
    let last_result = generate_trending_data(300, 100.0, 0.5).into_iter().fold(
        RegimeConfidence::default(),
        |_, (high, low, close)| detector.update(high, low, close),
    );

    assert!(last_result.adx_value > 0.0, "ADX should be > 0");
    assert!(
        (0.0..=100.0).contains(&last_result.bb_width_percentile),
        "BB width percentile should be in [0, 100]"
    );
    assert!(
        (0.0..=1.0).contains(&last_result.trend_strength),
        "Trend strength should be in [0, 1]"
    );
}
756}