1use std::collections::{HashMap, VecDeque};
11
12use super::primitives::{ADX, ATR, BollingerBands, BollingerBandsValues, EMA};
13use super::types::{
14 MarketRegime, RecommendedStrategy, RegimeConfidence, RegimeConfig, TrendDirection,
15};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
30pub struct DetectorIndicator {
31 pub config: RegimeConfig,
32}
33
34impl DetectorIndicator {
35 pub fn new(config: RegimeConfig) -> Self {
36 Self { config }
37 }
38 pub fn with_defaults() -> Self {
39 Self::new(RegimeConfig::default())
40 }
41}
42
43fn regime_id(r: MarketRegime) -> f64 {
44 use super::types::TrendDirection;
45 match r {
46 MarketRegime::MeanReverting => 1.0,
47 MarketRegime::Volatile => 2.0,
48 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
49 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
50 MarketRegime::Uncertain => 0.0,
51 }
52}
53
54impl Indicator for DetectorIndicator {
55 fn name(&self) -> &'static str {
56 "RegimeDetector"
57 }
58 fn required_len(&self) -> usize {
59 let adx_warmup = self.config.adx_period * 2 + self.config.regime_stability_bars;
62 let ema_warmup = self.config.ema_long_period;
63 let bb_warmup = self.config.bb_period;
64 adx_warmup.max(ema_warmup).max(bb_warmup)
65 }
66 fn required_columns(&self) -> &[&'static str] {
67 &["high", "low", "close"]
68 }
69
70 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
71 self.check_len(candles)?;
72 let mut det = RegimeDetector::new(self.config.clone());
73 let n = candles.len();
74 let mut conf = vec![f64::NAN; n];
75 let mut regime = vec![f64::NAN; n];
76 for (i, c) in candles.iter().enumerate() {
77 let rc = det.update(c.high, c.low, c.close);
78 conf[i] = rc.confidence;
79 regime[i] = regime_id(rc.regime);
80 }
81 Ok(IndicatorOutput::from_pairs([
82 ("regime_conf", conf),
83 ("regime_id", regime),
84 ]))
85 }
86}
87
88pub fn factory<S: ::std::hash::BuildHasher>(
91 params: &HashMap<String, String, S>,
92) -> Result<Box<dyn Indicator>, IndicatorError> {
93 let adx_period = param_usize(params, "adx_period", 14)?;
94 let bb_period = param_usize(params, "bb_period", 20)?;
95 let config = RegimeConfig {
96 adx_period,
97 bb_period,
98 ..RegimeConfig::default()
99 };
100 Ok(Box::new(DetectorIndicator::new(config)))
101}
102
103#[derive(Debug)]
127pub struct RegimeDetector {
128 config: RegimeConfig,
129
130 adx: ADX,
132 atr: ATR,
133 atr_avg: EMA, bb: BollingerBands,
135 ema_short: EMA,
136 ema_long: EMA,
137
138 current_regime: MarketRegime,
140 regime_history: VecDeque<MarketRegime>,
141 bars_in_regime: usize,
142
143 last_close: Option<f64>,
145}
146
147impl RegimeDetector {
148 pub fn new(config: RegimeConfig) -> Self {
150 Self {
151 adx: ADX::new(config.adx_period),
152 atr: ATR::new(config.atr_period),
153 atr_avg: EMA::new(50), bb: BollingerBands::new(config.bb_period, config.bb_std_dev),
155 ema_short: EMA::new(config.ema_short_period),
156 ema_long: EMA::new(config.ema_long_period),
157 current_regime: MarketRegime::Uncertain,
158 regime_history: VecDeque::with_capacity(20),
159 bars_in_regime: 0,
160 last_close: None,
161 config,
162 }
163 }
164
165 pub fn default_config() -> Self {
167 Self::new(RegimeConfig::default())
168 }
169
170 pub fn crypto_optimized() -> Self {
172 Self::new(RegimeConfig::crypto_optimized())
173 }
174
175 pub fn conservative() -> Self {
177 Self::new(RegimeConfig::conservative())
178 }
179
180 pub fn update(&mut self, high: f64, low: f64, close: f64) -> RegimeConfidence {
189 let adx_value = self.adx.update(high, low, close);
191 let atr_value = self.atr.update(high, low, close);
192 let bb_values = self.bb.update(close);
193 let ema_short = self.ema_short.update(close);
194 let ema_long = self.ema_long.update(close);
195
196 if let Some(atr) = atr_value {
198 self.atr_avg.update(atr);
199 }
200
201 self.last_close = Some(close);
202
203 let (Some(adx_value), Some(atr_value), Some(bb_values), Some(ema_short), Some(ema_long)) = (
207 adx_value,
208 atr_value,
209 bb_values.as_ref(),
210 ema_short,
211 ema_long,
212 ) else {
213 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
214 };
215 if !self.is_ready() {
216 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
217 }
218
219 let (new_regime, confidence) =
221 self.classify_regime(adx_value, atr_value, bb_values, ema_short, ema_long, close);
222
223 let stable_regime = self.apply_stability_filter(new_regime, confidence);
225
226 if stable_regime != self.current_regime {
228 self.regime_history.push_back(self.current_regime);
229 if self.regime_history.len() > 20 {
230 self.regime_history.pop_front();
231 }
232 self.current_regime = stable_regime;
233 self.bars_in_regime = 0;
234 } else {
235 self.bars_in_regime += 1;
236 }
237
238 RegimeConfidence::with_metrics(
239 stable_regime,
240 confidence,
241 adx_value,
242 bb_values.width_percentile,
243 Self::calculate_trend_strength(ema_short, ema_long, close),
244 )
245 }
246
247 fn classify_regime(
256 &self,
257 adx: f64,
258 atr: f64,
259 bb: &BollingerBandsValues,
260 ema_short: f64,
261 ema_long: f64,
262 close: f64,
263 ) -> (MarketRegime, f64) {
264 let atr_expansion = if let Some(avg_atr) = self.atr_avg.value() {
266 atr / avg_atr
267 } else {
268 1.0
269 };
270
271 let mut trending_score: f64 = 0.0;
273 let mut ranging_score: f64 = 0.0;
274 let mut volatile_score: f64 = 0.0;
275
276 if adx >= self.config.adx_trending_threshold {
278 trending_score += 0.4;
279 } else if adx <= self.config.adx_ranging_threshold {
280 ranging_score += 0.3;
281 }
282
283 if bb.is_high_volatility(self.config.bb_width_volatility_threshold) {
285 volatile_score += 0.3;
286 }
287 if bb.is_squeeze(25.0) {
288 ranging_score += 0.2; }
290
291 if atr_expansion >= self.config.atr_expansion_threshold {
293 volatile_score += 0.3;
294 } else if atr_expansion < 0.8 {
295 ranging_score += 0.2; }
297
298 let ema_diff_pct = ((ema_short - ema_long) / ema_long).abs() * 100.0;
300 if ema_diff_pct > 2.0 {
301 trending_score += 0.3;
302 } else if ema_diff_pct < 1.0 {
303 ranging_score += 0.2;
304 }
305
306 let price_above_both = close > ema_short && close > ema_long;
308 let price_below_both = close < ema_short && close < ema_long;
309 if price_above_both || price_below_both {
310 trending_score += 0.2;
311 } else {
312 ranging_score += 0.2; }
314
315 let max_score = trending_score.max(ranging_score).max(volatile_score);
317 let confidence = max_score / 1.2; let regime = if volatile_score >= 0.5 && volatile_score >= trending_score {
320 MarketRegime::Volatile
321 } else if trending_score > ranging_score && trending_score > 0.3 {
322 let direction = if ema_short > ema_long && close > ema_long {
324 TrendDirection::Bullish
325 } else if ema_short < ema_long && close < ema_long {
326 TrendDirection::Bearish
327 } else if let Some(dir) = self.adx.trend_direction() {
328 dir
329 } else {
330 TrendDirection::Bullish };
332 MarketRegime::Trending(direction)
333 } else if ranging_score > 0.3 {
334 MarketRegime::MeanReverting
335 } else {
336 MarketRegime::Uncertain
337 };
338
339 (regime, confidence.min(1.0))
340 }
341
342 fn apply_stability_filter(&self, new_regime: MarketRegime, confidence: f64) -> MarketRegime {
349 if confidence < 0.4 {
351 return self.current_regime;
352 }
353
354 if self.bars_in_regime < self.config.min_regime_duration
356 && new_regime != self.current_regime
357 {
358 if confidence < 0.7 {
360 return self.current_regime;
361 }
362 }
363
364 let recent_count = self
366 .regime_history
367 .iter()
368 .rev()
369 .take(self.config.regime_stability_bars)
370 .filter(|&&r| {
371 matches!(
372 (&r, &new_regime),
373 (MarketRegime::Trending(_), MarketRegime::Trending(_))
374 | (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
375 | (MarketRegime::Volatile, MarketRegime::Volatile)
376 )
377 })
378 .count();
379
380 if recent_count < self.config.regime_stability_bars / 2 && confidence < 0.6 {
382 return self.current_regime;
383 }
384
385 new_regime
386 }
387
388 fn calculate_trend_strength(ema_short: f64, ema_long: f64, close: f64) -> f64 {
390 let ema_alignment = (ema_short - ema_long).abs() / ema_long * 100.0;
391 let price_position = if close > ema_short && close > ema_long {
392 1.0
393 } else if close < ema_short && close < ema_long {
394 0.7
395 } else {
396 0.5
397 };
398
399 (ema_alignment * price_position / 5.0).min(1.0) }
401
402 pub fn is_ready(&self) -> bool {
411 self.adx.is_ready()
412 && self.atr.is_ready()
413 && self.bb.is_ready()
414 && self.ema_short.is_ready()
415 && self.ema_long.is_ready()
416 }
417
418 pub fn current_regime(&self) -> MarketRegime {
420 self.current_regime
421 }
422
423 pub fn recommended_strategy(&self) -> RecommendedStrategy {
425 RecommendedStrategy::from(&self.current_regime)
426 }
427
428 pub fn bars_in_current_regime(&self) -> usize {
430 self.bars_in_regime
431 }
432
433 pub fn adx_value(&self) -> Option<f64> {
435 self.adx.value()
436 }
437
438 pub fn atr_value(&self) -> Option<f64> {
440 self.atr.value()
441 }
442
443 pub fn config(&self) -> &RegimeConfig {
445 &self.config
446 }
447
448 pub fn set_config(&mut self, config: RegimeConfig) {
450 *self = Self::new(config);
451 }
452
453 pub fn regime_history(&self) -> &VecDeque<MarketRegime> {
455 &self.regime_history
456 }
457
458 pub fn last_close(&self) -> Option<f64> {
460 self.last_close
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed seed so every generated series — and every test outcome — is reproducible.
    const NOISE_SEED: u64 = 0x2545_F491_4F6C_DD1D;

    /// Tiny deterministic PRNG (64-bit LCG with Knuth's MMIX constants).
    ///
    /// Replaces an external RNG so test data is identical on every run and any
    /// failure can be reproduced exactly.
    struct Noise(u64);

    impl Noise {
        fn new() -> Self {
            Self(NOISE_SEED)
        }

        /// Returns the next sample in `[0, 1)`.
        fn next_unit(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            // The top 53 bits fit an f64 mantissa exactly.
            (self.0 >> 11) as f64 / (1_u64 << 53) as f64
        }
    }

    /// Steadily trending bars: each step is `trend_strength` with +/-10% jitter;
    /// high/low sit 0.5% either side of the close.
    fn generate_trending_data(
        bars: usize,
        start_price: f64,
        trend_strength: f64,
    ) -> Vec<(f64, f64, f64)> {
        let mut noise = Noise::new();
        let mut data = Vec::with_capacity(bars);
        let mut price = start_price;

        for _ in 0..bars {
            let change = trend_strength * (1.0 + (noise.next_unit() - 0.5) * 0.2);
            price += change;

            let high = price + price * 0.005;
            let low = price - price * 0.005;
            let close = price;

            data.push((high, low, close));
        }

        data
    }

    /// Sinusoidal oscillation of `range_pct` percent around `center_price`.
    fn generate_ranging_data(
        bars: usize,
        center_price: f64,
        range_pct: f64,
    ) -> Vec<(f64, f64, f64)> {
        let mut data = Vec::with_capacity(bars);

        for i in 0..bars {
            let offset = (i as f64 * 0.5).sin() * center_price * range_pct / 100.0;
            let price = center_price + offset;

            let high = price + price * 0.002;
            let low = price - price * 0.002;
            let close = price;

            data.push((high, low, close));
        }

        data
    }

    /// Bars alternating 5% above / below `center_price` with wide 3% ranges.
    fn generate_volatile_data(bars: usize, center_price: f64) -> Vec<(f64, f64, f64)> {
        let mut data = Vec::with_capacity(bars);

        for i in 0..bars {
            let swing = if i % 2 == 0 { 1.05 } else { 0.95 };
            let price = center_price * swing;

            let high = price * 1.03;
            let low = price * 0.97;
            let close = price;

            data.push((high, low, close));
        }

        data
    }

    /// Detector with a fast-reacting config, forced into `regime` for `bars_in_regime` bars.
    fn detector_in(regime: MarketRegime, bars_in_regime: usize) -> RegimeDetector {
        let mut detector = RegimeDetector::new(RegimeConfig {
            min_regime_duration: 10,
            regime_stability_bars: 5,
            ..RegimeConfig::default()
        });
        detector.current_regime = regime;
        detector.bars_in_regime = bars_in_regime;
        detector
    }

    #[test]
    fn test_volatile_regime_detection() {
        let mut detector = RegimeDetector::default_config();
        let mut last = RegimeConfidence::default();
        for (high, low, close) in generate_volatile_data(200, 100.0) {
            last = detector.update(high, low, close);
            assert!(
                (0.0..=1.0).contains(&last.confidence),
                "Confidence should be in [0, 1]: {}",
                last.confidence
            );
        }
        assert!(detector.is_ready());
        // Once warmed up, `update` reports exactly the regime the detector committed to.
        assert_eq!(last.regime, detector.current_regime());
    }

    #[test]
    fn test_detector_creation() {
        let detector = RegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
        assert_eq!(detector.bars_in_current_regime(), 0);
    }

    #[test]
    fn test_crypto_optimized_creation() {
        let detector = RegimeDetector::crypto_optimized();
        assert!(!detector.is_ready());
        assert_eq!(detector.config().adx_trending_threshold, 20.0);
        assert_eq!(detector.config().ema_short_period, 21);
    }

    #[test]
    fn test_conservative_creation() {
        let detector = RegimeDetector::conservative();
        assert_eq!(detector.config().adx_trending_threshold, 30.0);
        assert_eq!(detector.config().min_regime_duration, 10);
    }

    #[test]
    fn test_warmup_returns_uncertain() {
        let mut detector = RegimeDetector::default_config();

        for i in 0..10 {
            let price = 100.0 + i as f64;
            let result = detector.update(price + 1.0, price - 1.0, price);
            assert_eq!(result.regime, MarketRegime::Uncertain);
            assert_eq!(result.confidence, 0.0);
        }

        assert!(!detector.is_ready());
    }

    #[test]
    fn test_trending_detection() {
        let mut detector = RegimeDetector::default_config();

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        assert!(
            matches!(last_regime, MarketRegime::Trending(_)),
            "Expected Trending regime, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_trending_bullish_direction() {
        let mut detector = RegimeDetector::default_config();

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        assert!(
            matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish)),
            "Expected Bullish trend, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_trending_bearish_direction() {
        let mut detector = RegimeDetector::default_config();

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in generate_trending_data(300, 200.0, -0.5) {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        if matches!(last_regime, MarketRegime::Trending(_)) {
            assert!(
                matches!(last_regime, MarketRegime::Trending(TrendDirection::Bearish)),
                "Expected Bearish trend, got: {last_regime:?}"
            );
        }
    }

    #[test]
    fn test_ranging_detection() {
        let mut detector = RegimeDetector::default_config();

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in generate_ranging_data(300, 100.0, 2.0) {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        assert!(
            !matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish)),
            "Ranging data shouldn't produce strong bullish trend, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_confidence_range() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            let result = detector.update(high, low, close);
            assert!(
                (0.0..=1.0).contains(&result.confidence),
                "Confidence should be in [0, 1]: {}",
                result.confidence
            );
        }
    }

    #[test]
    fn test_regime_history_tracking() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(200, 100.0, 0.5) {
            detector.update(high, low, close);
        }
        for (high, low, close) in generate_ranging_data(200, 200.0, 1.0) {
            detector.update(high, low, close);
        }

        assert!(
            detector.regime_history().len() <= 20,
            "History should be bounded"
        );
    }

    #[test]
    fn test_recommended_strategy() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            detector.update(high, low, close);
        }

        if matches!(detector.current_regime(), MarketRegime::Trending(_)) {
            assert_eq!(
                detector.recommended_strategy(),
                RecommendedStrategy::TrendFollowing
            );
        }
    }

    #[test]
    fn test_adx_atr_accessors() {
        let mut detector = RegimeDetector::default_config();

        assert!(detector.adx_value().is_none());
        assert!(detector.atr_value().is_none());

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            detector.update(high, low, close);
        }

        assert!(detector.adx_value().is_some());
        assert!(detector.atr_value().is_some());
    }

    #[test]
    fn test_set_config_resets_state() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            detector.update(high, low, close);
        }
        assert!(detector.is_ready());

        detector.set_config(RegimeConfig::crypto_optimized());
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
        assert_eq!(detector.bars_in_current_regime(), 0);
    }

    #[test]
    fn test_last_close_tracking() {
        let mut detector = RegimeDetector::default_config();

        assert!(detector.last_close().is_none());

        detector.update(101.0, 99.0, 100.0);
        assert_eq!(detector.last_close(), Some(100.0));

        detector.update(106.0, 104.0, 105.0);
        assert_eq!(detector.last_close(), Some(105.0));
    }

    #[test]
    fn test_bars_in_regime_increments() {
        let mut detector = RegimeDetector::default_config();

        for i in 0..300 {
            let price = 100.0 + i as f64 * 0.3;
            detector.update(price + 1.0, price - 1.0, price);
        }

        assert!(
            detector.bars_in_current_regime() > 0,
            "Should have been in current regime for multiple bars (regime: {:?})",
            detector.current_regime()
        );
    }

    #[test]
    fn test_stability_filter_prevents_whipsaw() {
        let bullish = MarketRegime::Trending(TrendDirection::Bullish);
        let ranging = MarketRegime::MeanReverting;

        // Low-confidence readings never move the regime.
        let young = detector_in(bullish, 0);
        assert_eq!(young.apply_stability_filter(ranging, 0.3), bullish);

        // A young regime resists a moderately confident challenger...
        assert_eq!(young.apply_stability_filter(ranging, 0.65), bullish);
        // ...but yields to a strong one.
        assert_eq!(young.apply_stability_filter(ranging, 0.75), ranging);

        // A mature regime with no supporting history still needs confidence >= 0.6.
        let mature = detector_in(bullish, 10);
        assert_eq!(mature.apply_stability_filter(ranging, 0.5), bullish);
        assert_eq!(mature.apply_stability_filter(ranging, 0.65), ranging);

        // Recent history of the candidate regime lowers the bar.
        let mut seasoned = detector_in(bullish, 10);
        for _ in 0..3 {
            seasoned.regime_history.push_back(ranging);
        }
        assert_eq!(seasoned.apply_stability_filter(ranging, 0.5), ranging);
    }

    #[test]
    fn test_regime_id_codes() {
        assert_eq!(regime_id(MarketRegime::Uncertain), 0.0);
        assert_eq!(regime_id(MarketRegime::MeanReverting), 1.0);
        assert_eq!(regime_id(MarketRegime::Volatile), 2.0);
        assert_eq!(
            regime_id(MarketRegime::Trending(TrendDirection::Bullish)),
            3.0
        );
        assert_eq!(
            regime_id(MarketRegime::Trending(TrendDirection::Bearish)),
            4.0
        );
    }

    #[test]
    fn test_trend_strength_scaling() {
        // 1% EMA separation, price above both EMAs: 1.0 * 1.0 / 5.
        let strength = RegimeDetector::calculate_trend_strength(101.0, 100.0, 102.0);
        assert!((strength - 0.2).abs() < 1e-12, "got {strength}");

        // Price between the EMAs uses the 0.5 weighting.
        let strength = RegimeDetector::calculate_trend_strength(101.0, 100.0, 100.5);
        assert!((strength - 0.1).abs() < 1e-12, "got {strength}");

        // Large separations saturate at 1.0.
        assert_eq!(
            RegimeDetector::calculate_trend_strength(120.0, 100.0, 130.0),
            1.0
        );
    }

    #[test]
    fn test_metrics_populated_after_warmup() {
        let mut detector = RegimeDetector::default_config();

        let mut last_result = RegimeConfidence::default();
        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            last_result = detector.update(high, low, close);
        }

        assert!(last_result.adx_value > 0.0, "ADX should be > 0");
        assert!(
            last_result.bb_width_percentile >= 0.0 && last_result.bb_width_percentile <= 100.0,
            "BB width percentile should be in [0, 100]"
        );
        assert!(
            last_result.trend_strength >= 0.0 && last_result.trend_strength <= 1.0,
            "Trend strength should be in [0, 1]"
        );
    }
}