1use std::collections::{HashMap, VecDeque};
11
12use super::primitives::{ADX, ATR, BollingerBands, BollingerBandsValues, EMA};
13use super::types::{
14 MarketRegime, RecommendedStrategy, RegimeConfidence, RegimeConfig, TrendDirection,
15};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
/// [`Indicator`] adapter that runs a [`RegimeDetector`] over a candle series.
///
/// `calculate` emits two columns per candle: `regime_conf` (the classifier
/// confidence for that bar) and `regime_id` (a numeric regime code).
#[derive(Debug, Clone)]
pub struct DetectorIndicator {
    /// Configuration cloned into a fresh detector on every `calculate` call.
    pub config: RegimeConfig,
}
33
34impl DetectorIndicator {
35 pub fn new(config: RegimeConfig) -> Self {
36 Self { config }
37 }
38 pub fn with_defaults() -> Self {
39 Self::new(RegimeConfig::default())
40 }
41}
42
43fn regime_id(r: MarketRegime) -> f64 {
44 use super::types::TrendDirection;
45 match r {
46 MarketRegime::MeanReverting => 1.0,
47 MarketRegime::Volatile => 2.0,
48 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
49 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
50 MarketRegime::Uncertain => 0.0,
51 }
52}
53
54impl Indicator for DetectorIndicator {
55 fn name(&self) -> &'static str {
56 "RegimeDetector"
57 }
58 fn required_len(&self) -> usize {
59 let adx_warmup = self.config.adx_period * 2 + self.config.regime_stability_bars;
62 let ema_warmup = self.config.ema_long_period;
63 let bb_warmup = self.config.bb_period;
64 adx_warmup.max(ema_warmup).max(bb_warmup)
65 }
66 fn required_columns(&self) -> &[&'static str] {
67 &["high", "low", "close"]
68 }
69
70 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
71 self.check_len(candles)?;
72 let mut det = RegimeDetector::new(self.config.clone());
73 let n = candles.len();
74 let mut conf = vec![f64::NAN; n];
75 let mut regime = vec![f64::NAN; n];
76 for (i, c) in candles.iter().enumerate() {
77 let rc = det.update(c.high, c.low, c.close);
78 conf[i] = rc.confidence;
79 regime[i] = regime_id(rc.regime);
80 }
81 Ok(IndicatorOutput::from_pairs([
82 ("regime_conf", conf),
83 ("regime_id", regime),
84 ]))
85 }
86}
87
88pub fn factory<S: ::std::hash::BuildHasher>(
91 params: &HashMap<String, String, S>,
92) -> Result<Box<dyn Indicator>, IndicatorError> {
93 let adx_period = param_usize(params, "adx_period", 14)?;
94 let bb_period = param_usize(params, "bb_period", 20)?;
95 let config = RegimeConfig {
96 adx_period,
97 bb_period,
98 ..RegimeConfig::default()
99 };
100 Ok(Box::new(DetectorIndicator::new(config)))
101}
102
/// Streaming market-regime classifier.
///
/// Feed one bar at a time through [`RegimeDetector::update`]. Each bar is scored
/// as trending / ranging / volatile from ADX, ATR expansion, Bollinger band width
/// and EMA alignment, and a stability filter decides whether the reported regime
/// actually switches.
#[derive(Debug)]
pub struct RegimeDetector {
    /// Periods and thresholds used by every component below.
    config: RegimeConfig,

    /// ADX primitive (period `config.adx_period`).
    adx: ADX,
    /// ATR primitive (period `config.atr_period`).
    atr: ATR,
    /// EMA over ATR values; the baseline for the ATR expansion ratio.
    atr_avg: EMA,
    /// Bollinger Bands over closes (`config.bb_period`, `config.bb_std_dev`).
    bb: BollingerBands,
    /// Short EMA of closes (`config.ema_short_period`).
    ema_short: EMA,
    /// Long EMA of closes (`config.ema_long_period`).
    ema_long: EMA,

    /// Regime currently reported after stability filtering.
    current_regime: MarketRegime,
    /// Regimes that were left on each switch, oldest first, capped in length.
    regime_history: VecDeque<MarketRegime>,
    /// Bars elapsed since `current_regime` was entered (0 on the switch bar).
    bars_in_regime: usize,

    /// Close of the most recent bar passed to `update`, if any.
    last_close: Option<f64>,
}
146
147impl RegimeDetector {
148 pub fn new(config: RegimeConfig) -> Self {
150 Self {
151 adx: ADX::new(config.adx_period),
152 atr: ATR::new(config.atr_period),
153 atr_avg: EMA::new(50), bb: BollingerBands::new(config.bb_period, config.bb_std_dev),
155 ema_short: EMA::new(config.ema_short_period),
156 ema_long: EMA::new(config.ema_long_period),
157 current_regime: MarketRegime::Uncertain,
158 regime_history: VecDeque::with_capacity(20),
159 bars_in_regime: 0,
160 last_close: None,
161 config,
162 }
163 }
164
165 pub fn default_config() -> Self {
167 Self::new(RegimeConfig::default())
168 }
169
170 pub fn crypto_optimized() -> Self {
172 Self::new(RegimeConfig::crypto_optimized())
173 }
174
175 pub fn conservative() -> Self {
177 Self::new(RegimeConfig::conservative())
178 }
179
180 pub fn update(&mut self, high: f64, low: f64, close: f64) -> RegimeConfidence {
189 let adx_value = self.adx.update(high, low, close);
191 let atr_value = self.atr.update(high, low, close);
192 let bb_values = self.bb.update(close);
193 let ema_short = self.ema_short.update(close);
194 let ema_long = self.ema_long.update(close);
195
196 if let Some(atr) = atr_value {
198 self.atr_avg.update(atr);
199 }
200
201 self.last_close = Some(close);
202
203 if !self.is_ready() {
205 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
206 }
207
208 let (new_regime, confidence) = self.classify_regime(
210 adx_value.unwrap(),
211 atr_value.unwrap(),
212 bb_values.as_ref().unwrap(),
213 ema_short.unwrap(),
214 ema_long.unwrap(),
215 close,
216 );
217
218 let stable_regime = self.apply_stability_filter(new_regime, confidence);
220
221 if stable_regime != self.current_regime {
223 self.regime_history.push_back(self.current_regime);
224 if self.regime_history.len() > 20 {
225 self.regime_history.pop_front();
226 }
227 self.current_regime = stable_regime;
228 self.bars_in_regime = 0;
229 } else {
230 self.bars_in_regime += 1;
231 }
232
233 RegimeConfidence::with_metrics(
234 stable_regime,
235 confidence,
236 adx_value.unwrap(),
237 bb_values.as_ref().map_or(50.0, |b| b.width_percentile),
238 Self::calculate_trend_strength(ema_short.unwrap(), ema_long.unwrap(), close),
239 )
240 }
241
242 fn classify_regime(
251 &self,
252 adx: f64,
253 atr: f64,
254 bb: &BollingerBandsValues,
255 ema_short: f64,
256 ema_long: f64,
257 close: f64,
258 ) -> (MarketRegime, f64) {
259 let atr_expansion = if let Some(avg_atr) = self.atr_avg.value() {
261 atr / avg_atr
262 } else {
263 1.0
264 };
265
266 let mut trending_score: f64 = 0.0;
268 let mut ranging_score: f64 = 0.0;
269 let mut volatile_score: f64 = 0.0;
270
271 if adx >= self.config.adx_trending_threshold {
273 trending_score += 0.4;
274 } else if adx <= self.config.adx_ranging_threshold {
275 ranging_score += 0.3;
276 }
277
278 if bb.is_high_volatility(self.config.bb_width_volatility_threshold) {
280 volatile_score += 0.3;
281 }
282 if bb.is_squeeze(25.0) {
283 ranging_score += 0.2; }
285
286 if atr_expansion >= self.config.atr_expansion_threshold {
288 volatile_score += 0.3;
289 } else if atr_expansion < 0.8 {
290 ranging_score += 0.2; }
292
293 let ema_diff_pct = ((ema_short - ema_long) / ema_long).abs() * 100.0;
295 if ema_diff_pct > 2.0 {
296 trending_score += 0.3;
297 } else if ema_diff_pct < 1.0 {
298 ranging_score += 0.2;
299 }
300
301 let price_above_both = close > ema_short && close > ema_long;
303 let price_below_both = close < ema_short && close < ema_long;
304 if price_above_both || price_below_both {
305 trending_score += 0.2;
306 } else {
307 ranging_score += 0.2; }
309
310 let max_score = trending_score.max(ranging_score).max(volatile_score);
312 let confidence = max_score / 1.2; let regime = if volatile_score >= 0.5 && volatile_score >= trending_score {
315 MarketRegime::Volatile
316 } else if trending_score > ranging_score && trending_score > 0.3 {
317 let direction = if ema_short > ema_long && close > ema_long {
319 TrendDirection::Bullish
320 } else if ema_short < ema_long && close < ema_long {
321 TrendDirection::Bearish
322 } else if let Some(dir) = self.adx.trend_direction() {
323 dir
324 } else {
325 TrendDirection::Bullish };
327 MarketRegime::Trending(direction)
328 } else if ranging_score > 0.3 {
329 MarketRegime::MeanReverting
330 } else {
331 MarketRegime::Uncertain
332 };
333
334 (regime, confidence.min(1.0))
335 }
336
337 fn apply_stability_filter(&self, new_regime: MarketRegime, confidence: f64) -> MarketRegime {
344 if confidence < 0.4 {
346 return self.current_regime;
347 }
348
349 if self.bars_in_regime < self.config.min_regime_duration
351 && new_regime != self.current_regime
352 {
353 if confidence < 0.7 {
355 return self.current_regime;
356 }
357 }
358
359 let recent_count = self
361 .regime_history
362 .iter()
363 .rev()
364 .take(self.config.regime_stability_bars)
365 .filter(|&&r| {
366 matches!(
367 (&r, &new_regime),
368 (MarketRegime::Trending(_), MarketRegime::Trending(_))
369 | (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
370 | (MarketRegime::Volatile, MarketRegime::Volatile)
371 )
372 })
373 .count();
374
375 if recent_count < self.config.regime_stability_bars / 2 && confidence < 0.6 {
377 return self.current_regime;
378 }
379
380 new_regime
381 }
382
383 fn calculate_trend_strength(ema_short: f64, ema_long: f64, close: f64) -> f64 {
385 let ema_alignment = (ema_short - ema_long).abs() / ema_long * 100.0;
386 let price_position = if close > ema_short && close > ema_long {
387 1.0
388 } else if close < ema_short && close < ema_long {
389 0.7
390 } else {
391 0.5
392 };
393
394 (ema_alignment * price_position / 5.0).min(1.0) }
396
397 pub fn is_ready(&self) -> bool {
406 self.adx.is_ready()
407 && self.atr.is_ready()
408 && self.bb.is_ready()
409 && self.ema_short.is_ready()
410 && self.ema_long.is_ready()
411 }
412
413 pub fn current_regime(&self) -> MarketRegime {
415 self.current_regime
416 }
417
418 pub fn recommended_strategy(&self) -> RecommendedStrategy {
420 RecommendedStrategy::from(&self.current_regime)
421 }
422
423 pub fn bars_in_current_regime(&self) -> usize {
425 self.bars_in_regime
426 }
427
428 pub fn adx_value(&self) -> Option<f64> {
430 self.adx.value()
431 }
432
433 pub fn atr_value(&self) -> Option<f64> {
435 self.atr.value()
436 }
437
438 pub fn config(&self) -> &RegimeConfig {
440 &self.config
441 }
442
443 pub fn set_config(&mut self, config: RegimeConfig) {
445 *self = Self::new(config);
446 }
447
448 pub fn regime_history(&self) -> &VecDeque<MarketRegime> {
450 &self.regime_history
451 }
452
453 pub fn last_close(&self) -> Option<f64> {
455 self.last_close
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed seed so every run sees the same synthetic price series.
    const NOISE_SEED: u64 = 0x5EED_0000_C0FF_EE01;

    /// Advances a SplitMix64 state and returns a value in `[0, 1)`.
    ///
    /// Used instead of an external RNG so the synthetic data — and every
    /// assertion that depends on it — is reproducible from run to run.
    fn next_unit(state: &mut u64) -> f64 {
        *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = *state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        (z >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Strictly monotone series: each bar moves by `trend_strength` ±10%, with
    /// a ±0.5% high/low band around the close.
    fn generate_trending_data(
        bars: usize,
        start_price: f64,
        trend_strength: f64,
    ) -> Vec<(f64, f64, f64)> {
        let mut rng = NOISE_SEED;
        let mut price = start_price;
        let mut data = Vec::with_capacity(bars);

        for _ in 0..bars {
            let change = trend_strength * (1.0 + (next_unit(&mut rng) - 0.5) * 0.2);
            price += change;
            data.push((price + price * 0.005, price - price * 0.005, price));
        }

        data
    }

    /// Sine oscillation of `range_pct` percent around `center_price`, with a
    /// ±0.2% high/low band.
    fn generate_ranging_data(
        bars: usize,
        center_price: f64,
        range_pct: f64,
    ) -> Vec<(f64, f64, f64)> {
        (0..bars)
            .map(|i| {
                let offset = (i as f64 * 0.5).sin() * center_price * range_pct / 100.0;
                let price = center_price + offset;
                (price + price * 0.002, price - price * 0.002, price)
            })
            .collect()
    }

    /// Close alternating between +5% and -5% of `center_price`, with ±3% wicks.
    fn generate_volatile_data(bars: usize, center_price: f64) -> Vec<(f64, f64, f64)> {
        (0..bars)
            .map(|i| {
                let swing = if i % 2 == 0 { 1.05 } else { 0.95 };
                let price = center_price * swing;
                (price * 1.03, price * 0.97, price)
            })
            .collect()
    }

    /// Feeds every bar of `data` into `detector`, discarding the results.
    fn feed(detector: &mut RegimeDetector, data: &[(f64, f64, f64)]) {
        for &(high, low, close) in data {
            detector.update(high, low, close);
        }
    }

    /// Feeds `data` and returns the regime reported on the last bar for which
    /// the detector was ready (`Uncertain` if it never became ready).
    fn last_ready_regime(detector: &mut RegimeDetector, data: &[(f64, f64, f64)]) -> MarketRegime {
        let mut last = MarketRegime::Uncertain;
        for &(high, low, close) in data {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last = result.regime;
            }
        }
        last
    }

    #[test]
    fn test_synthetic_data_is_deterministic() {
        assert_eq!(
            generate_trending_data(50, 100.0, 0.5),
            generate_trending_data(50, 100.0, 0.5)
        );
    }

    #[test]
    fn test_volatile_regime_detection() {
        let mut detector = RegimeDetector::default_config();
        for (high, low, close) in generate_volatile_data(200, 100.0) {
            let result = detector.update(high, low, close);
            assert!(
                (0.0..=1.0).contains(&result.confidence),
                "Confidence should be in [0, 1]: {}",
                result.confidence
            );
        }
        assert!(detector.is_ready());
        // The exact regime for a strict two-bar oscillation depends on the
        // Bollinger percentile and ATR baseline, so check invariants instead.
        assert!(detector.atr_value().is_some());
        assert!(detector.regime_history().len() <= 20);
    }

    #[test]
    fn test_detector_creation() {
        let detector = RegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
        assert_eq!(detector.bars_in_current_regime(), 0);
    }

    #[test]
    fn test_crypto_optimized_creation() {
        let detector = RegimeDetector::crypto_optimized();
        assert!(!detector.is_ready());
        assert_eq!(detector.config().adx_trending_threshold, 20.0);
        assert_eq!(detector.config().ema_short_period, 21);
    }

    #[test]
    fn test_conservative_creation() {
        let detector = RegimeDetector::conservative();
        assert_eq!(detector.config().adx_trending_threshold, 30.0);
        assert_eq!(detector.config().min_regime_duration, 10);
    }

    #[test]
    fn test_warmup_returns_uncertain() {
        let mut detector = RegimeDetector::default_config();

        for i in 0..10 {
            let price = 100.0 + i as f64;
            let result = detector.update(price + 1.0, price - 1.0, price);
            assert_eq!(result.regime, MarketRegime::Uncertain);
            assert_eq!(result.confidence, 0.0);
        }

        assert!(!detector.is_ready());
    }

    #[test]
    fn test_trending_detection() {
        let mut detector = RegimeDetector::default_config();
        let last_regime =
            last_ready_regime(&mut detector, &generate_trending_data(300, 100.0, 0.5));

        assert!(
            matches!(last_regime, MarketRegime::Trending(_)),
            "Expected Trending regime, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_trending_bullish_direction() {
        let mut detector = RegimeDetector::default_config();
        let last_regime =
            last_ready_regime(&mut detector, &generate_trending_data(300, 100.0, 0.5));

        assert!(
            matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish)),
            "Expected Bullish trend, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_trending_bearish_direction() {
        let mut detector = RegimeDetector::default_config();
        let last_regime =
            last_ready_regime(&mut detector, &generate_trending_data(300, 200.0, -0.5));

        // Only the direction is checked: a falling series may also be classified
        // as another regime, but if it is Trending it must be Bearish.
        if matches!(last_regime, MarketRegime::Trending(_)) {
            assert!(
                matches!(last_regime, MarketRegime::Trending(TrendDirection::Bearish)),
                "Expected Bearish trend, got: {last_regime:?}"
            );
        }
    }

    #[test]
    fn test_ranging_detection() {
        let mut detector = RegimeDetector::default_config();
        let last_regime =
            last_ready_regime(&mut detector, &generate_ranging_data(300, 100.0, 2.0));

        assert!(
            !matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish)),
            "Ranging data shouldn't produce strong bullish trend, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_confidence_range() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            let result = detector.update(high, low, close);
            assert!(
                (0.0..=1.0).contains(&result.confidence),
                "Confidence should be in [0, 1]: {}",
                result.confidence
            );
        }
    }

    #[test]
    fn test_regime_history_tracking() {
        let mut detector = RegimeDetector::default_config();

        feed(&mut detector, &generate_trending_data(200, 100.0, 0.5));
        feed(&mut detector, &generate_ranging_data(200, 200.0, 1.0));

        assert!(
            detector.regime_history().len() <= 20,
            "History should be bounded"
        );
    }

    #[test]
    fn test_recommended_strategy() {
        let mut detector = RegimeDetector::default_config();
        feed(&mut detector, &generate_trending_data(300, 100.0, 0.5));

        if matches!(detector.current_regime(), MarketRegime::Trending(_)) {
            assert_eq!(
                detector.recommended_strategy(),
                RecommendedStrategy::TrendFollowing
            );
        }
    }

    #[test]
    fn test_adx_atr_accessors() {
        let mut detector = RegimeDetector::default_config();

        assert!(detector.adx_value().is_none());
        assert!(detector.atr_value().is_none());

        feed(&mut detector, &generate_trending_data(300, 100.0, 0.5));

        assert!(detector.adx_value().is_some());
        assert!(detector.atr_value().is_some());
    }

    #[test]
    fn test_set_config_resets_state() {
        let mut detector = RegimeDetector::default_config();
        feed(&mut detector, &generate_trending_data(300, 100.0, 0.5));
        assert!(detector.is_ready());

        detector.set_config(RegimeConfig::crypto_optimized());
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
        assert_eq!(detector.bars_in_current_regime(), 0);
    }

    #[test]
    fn test_last_close_tracking() {
        let mut detector = RegimeDetector::default_config();

        assert!(detector.last_close().is_none());

        detector.update(101.0, 99.0, 100.0);
        assert_eq!(detector.last_close(), Some(100.0));

        detector.update(106.0, 104.0, 105.0);
        assert_eq!(detector.last_close(), Some(105.0));
    }

    #[test]
    fn test_bars_in_regime_increments() {
        let mut detector = RegimeDetector::default_config();

        for i in 0..300 {
            let price = 100.0 + i as f64 * 0.3;
            detector.update(price + 1.0, price - 1.0, price);
        }

        assert!(
            detector.bars_in_current_regime() > 0,
            "Should have been in current regime for multiple bars (regime: {:?})",
            detector.current_regime()
        );
    }

    #[test]
    fn test_stability_filter_prevents_whipsaw() {
        let mut detector = RegimeDetector::new(RegimeConfig {
            min_regime_duration: 10,
            regime_stability_bars: 5,
            ..RegimeConfig::default()
        });

        feed(&mut detector, &generate_trending_data(300, 100.0, 0.5));
        let regime_before = detector.current_regime();

        feed(&mut detector, &generate_ranging_data(3, 250.0, 1.0));
        let regime_after = detector.current_regime();

        // A switch during the 3-bar burst resets the counter to 0, after which
        // at most 2 more bars can be counted.
        assert!(
            regime_after == regime_before || detector.bars_in_current_regime() < 3,
            "a regime switch during the burst must reset the bar counter \
             (before: {regime_before:?}, after: {regime_after:?})"
        );
    }

    #[test]
    fn test_metrics_populated_after_warmup() {
        let mut detector = RegimeDetector::default_config();
        let mut last_result = RegimeConfidence::default();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            last_result = detector.update(high, low, close);
        }

        assert!(last_result.adx_value > 0.0, "ADX should be > 0");
        assert!(
            last_result.bb_width_percentile >= 0.0 && last_result.bb_width_percentile <= 100.0,
            "BB width percentile should be in [0, 100]"
        );
        assert!(
            last_result.trend_strength >= 0.0 && last_result.trend_strength <= 1.0,
            "Trend strength should be in [0, 1]"
        );
    }
}