1use std::collections::{HashMap, VecDeque};
11
12use super::primitives::{ADX, ATR, BollingerBands, BollingerBandsValues, EMA};
13use super::types::{
14 MarketRegime, RecommendedStrategy, RegimeConfidence, RegimeConfig, TrendDirection,
15};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
30pub struct DetectorIndicator {
31 pub config: RegimeConfig,
32}
33
34impl DetectorIndicator {
35 pub fn new(config: RegimeConfig) -> Self {
36 Self { config }
37 }
38 pub fn with_defaults() -> Self {
39 Self::new(RegimeConfig::default())
40 }
41}
42
43fn regime_id(r: MarketRegime) -> f64 {
44 use super::types::TrendDirection;
45 match r {
46 MarketRegime::MeanReverting => 1.0,
47 MarketRegime::Volatile => 2.0,
48 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
49 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
50 MarketRegime::Uncertain => 0.0,
51 }
52}
53
54impl Indicator for DetectorIndicator {
55 fn name(&self) -> &'static str {
56 "RegimeDetector"
57 }
58 fn required_len(&self) -> usize {
59 let adx_warmup = self.config.adx_period * 2 + self.config.regime_stability_bars;
62 let ema_warmup = self.config.ema_long_period;
63 let bb_warmup = self.config.bb_period;
64 adx_warmup.max(ema_warmup).max(bb_warmup)
65 }
66 fn required_columns(&self) -> &[&'static str] {
67 &["high", "low", "close"]
68 }
69
70 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
71 self.check_len(candles)?;
72 let mut det = RegimeDetector::new(self.config.clone());
73 let n = candles.len();
74 let mut conf = vec![f64::NAN; n];
75 let mut regime = vec![f64::NAN; n];
76 for (i, c) in candles.iter().enumerate() {
77 let rc = det.update(c.high, c.low, c.close);
78 conf[i] = rc.confidence;
79 regime[i] = regime_id(rc.regime);
80 }
81 Ok(IndicatorOutput::from_pairs([
82 ("regime_conf", conf),
83 ("regime_id", regime),
84 ]))
85 }
86}
87
88pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
91 let adx_period = param_usize(params, "adx_period", 14)?;
92 let bb_period = param_usize(params, "bb_period", 20)?;
93 let config = RegimeConfig {
94 adx_period,
95 bb_period,
96 ..RegimeConfig::default()
97 };
98 Ok(Box::new(DetectorIndicator::new(config)))
99}
100
101#[derive(Debug)]
125pub struct RegimeDetector {
126 config: RegimeConfig,
127
128 adx: ADX,
130 atr: ATR,
131 atr_avg: EMA, bb: BollingerBands,
133 ema_short: EMA,
134 ema_long: EMA,
135
136 current_regime: MarketRegime,
138 regime_history: VecDeque<MarketRegime>,
139 bars_in_regime: usize,
140
141 last_close: Option<f64>,
143}
144
145impl RegimeDetector {
146 pub fn new(config: RegimeConfig) -> Self {
148 Self {
149 adx: ADX::new(config.adx_period),
150 atr: ATR::new(config.atr_period),
151 atr_avg: EMA::new(50), bb: BollingerBands::new(config.bb_period, config.bb_std_dev),
153 ema_short: EMA::new(config.ema_short_period),
154 ema_long: EMA::new(config.ema_long_period),
155 current_regime: MarketRegime::Uncertain,
156 regime_history: VecDeque::with_capacity(20),
157 bars_in_regime: 0,
158 last_close: None,
159 config,
160 }
161 }
162
163 pub fn default_config() -> Self {
165 Self::new(RegimeConfig::default())
166 }
167
168 pub fn crypto_optimized() -> Self {
170 Self::new(RegimeConfig::crypto_optimized())
171 }
172
173 pub fn conservative() -> Self {
175 Self::new(RegimeConfig::conservative())
176 }
177
178 pub fn update(&mut self, high: f64, low: f64, close: f64) -> RegimeConfidence {
187 let adx_value = self.adx.update(high, low, close);
189 let atr_value = self.atr.update(high, low, close);
190 let bb_values = self.bb.update(close);
191 let ema_short = self.ema_short.update(close);
192 let ema_long = self.ema_long.update(close);
193
194 if let Some(atr) = atr_value {
196 self.atr_avg.update(atr);
197 }
198
199 self.last_close = Some(close);
200
201 if !self.is_ready() {
203 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
204 }
205
206 let (new_regime, confidence) = self.classify_regime(
208 adx_value.unwrap(),
209 atr_value.unwrap(),
210 bb_values.as_ref().unwrap(),
211 ema_short.unwrap(),
212 ema_long.unwrap(),
213 close,
214 );
215
216 let stable_regime = self.apply_stability_filter(new_regime, confidence);
218
219 if stable_regime != self.current_regime {
221 self.regime_history.push_back(self.current_regime);
222 if self.regime_history.len() > 20 {
223 self.regime_history.pop_front();
224 }
225 self.current_regime = stable_regime;
226 self.bars_in_regime = 0;
227 } else {
228 self.bars_in_regime += 1;
229 }
230
231 RegimeConfidence::with_metrics(
232 stable_regime,
233 confidence,
234 adx_value.unwrap(),
235 bb_values.as_ref().map_or(50.0, |b| b.width_percentile),
236 Self::calculate_trend_strength(ema_short.unwrap(), ema_long.unwrap(), close),
237 )
238 }
239
240 fn classify_regime(
249 &self,
250 adx: f64,
251 atr: f64,
252 bb: &BollingerBandsValues,
253 ema_short: f64,
254 ema_long: f64,
255 close: f64,
256 ) -> (MarketRegime, f64) {
257 let atr_expansion = if let Some(avg_atr) = self.atr_avg.value() {
259 atr / avg_atr
260 } else {
261 1.0
262 };
263
264 let mut trending_score: f64 = 0.0;
266 let mut ranging_score: f64 = 0.0;
267 let mut volatile_score: f64 = 0.0;
268
269 if adx >= self.config.adx_trending_threshold {
271 trending_score += 0.4;
272 } else if adx <= self.config.adx_ranging_threshold {
273 ranging_score += 0.3;
274 }
275
276 if bb.is_high_volatility(self.config.bb_width_volatility_threshold) {
278 volatile_score += 0.3;
279 }
280 if bb.is_squeeze(25.0) {
281 ranging_score += 0.2; }
283
284 if atr_expansion >= self.config.atr_expansion_threshold {
286 volatile_score += 0.3;
287 } else if atr_expansion < 0.8 {
288 ranging_score += 0.2; }
290
291 let ema_diff_pct = ((ema_short - ema_long) / ema_long).abs() * 100.0;
293 if ema_diff_pct > 2.0 {
294 trending_score += 0.3;
295 } else if ema_diff_pct < 1.0 {
296 ranging_score += 0.2;
297 }
298
299 let price_above_both = close > ema_short && close > ema_long;
301 let price_below_both = close < ema_short && close < ema_long;
302 if price_above_both || price_below_both {
303 trending_score += 0.2;
304 } else {
305 ranging_score += 0.2; }
307
308 let max_score = trending_score.max(ranging_score).max(volatile_score);
310 let confidence = max_score / 1.2; let regime = if volatile_score >= 0.5 && volatile_score >= trending_score {
313 MarketRegime::Volatile
314 } else if trending_score > ranging_score && trending_score > 0.3 {
315 let direction = if ema_short > ema_long && close > ema_long {
317 TrendDirection::Bullish
318 } else if ema_short < ema_long && close < ema_long {
319 TrendDirection::Bearish
320 } else if let Some(dir) = self.adx.trend_direction() {
321 dir
322 } else {
323 TrendDirection::Bullish };
325 MarketRegime::Trending(direction)
326 } else if ranging_score > 0.3 {
327 MarketRegime::MeanReverting
328 } else {
329 MarketRegime::Uncertain
330 };
331
332 (regime, confidence.min(1.0))
333 }
334
335 fn apply_stability_filter(&self, new_regime: MarketRegime, confidence: f64) -> MarketRegime {
342 if confidence < 0.4 {
344 return self.current_regime;
345 }
346
347 if self.bars_in_regime < self.config.min_regime_duration
349 && new_regime != self.current_regime
350 {
351 if confidence < 0.7 {
353 return self.current_regime;
354 }
355 }
356
357 let recent_count = self
359 .regime_history
360 .iter()
361 .rev()
362 .take(self.config.regime_stability_bars)
363 .filter(|&&r| {
364 matches!(
365 (&r, &new_regime),
366 (MarketRegime::Trending(_), MarketRegime::Trending(_))
367 | (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
368 | (MarketRegime::Volatile, MarketRegime::Volatile)
369 )
370 })
371 .count();
372
373 if recent_count < self.config.regime_stability_bars / 2 && confidence < 0.6 {
375 return self.current_regime;
376 }
377
378 new_regime
379 }
380
381 fn calculate_trend_strength(ema_short: f64, ema_long: f64, close: f64) -> f64 {
383 let ema_alignment = (ema_short - ema_long).abs() / ema_long * 100.0;
384 let price_position = if close > ema_short && close > ema_long {
385 1.0
386 } else if close < ema_short && close < ema_long {
387 0.7
388 } else {
389 0.5
390 };
391
392 (ema_alignment * price_position / 5.0).min(1.0) }
394
395 pub fn is_ready(&self) -> bool {
404 self.adx.is_ready()
405 && self.atr.is_ready()
406 && self.bb.is_ready()
407 && self.ema_short.is_ready()
408 && self.ema_long.is_ready()
409 }
410
411 pub fn current_regime(&self) -> MarketRegime {
413 self.current_regime
414 }
415
416 pub fn recommended_strategy(&self) -> RecommendedStrategy {
418 RecommendedStrategy::from(&self.current_regime)
419 }
420
421 pub fn bars_in_current_regime(&self) -> usize {
423 self.bars_in_regime
424 }
425
426 pub fn adx_value(&self) -> Option<f64> {
428 self.adx.value()
429 }
430
431 pub fn atr_value(&self) -> Option<f64> {
433 self.atr.value()
434 }
435
436 pub fn config(&self) -> &RegimeConfig {
438 &self.config
439 }
440
441 pub fn set_config(&mut self, config: RegimeConfig) {
443 *self = Self::new(config);
444 }
445
446 pub fn regime_history(&self) -> &VecDeque<MarketRegime> {
448 &self.regime_history
449 }
450
451 pub fn last_close(&self) -> Option<f64> {
453 self.last_close
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny deterministic PRNG (64-bit LCG, MMIX constants).
    ///
    /// Keeps fixtures reproducible without depending on the `rand` crate.
    struct Lcg(u64);

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        /// Next value, uniform in `[0, 1)` (top 53 bits of the state).
        fn next_f64(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
    }

    /// Monotonic trend with ±10% deterministic jitter on each step.
    fn generate_trending_data(
        bars: usize,
        start_price: f64,
        trend_strength: f64,
    ) -> Vec<(f64, f64, f64)> {
        let mut rng = Lcg::new(0x5EED_1234_ABCD_0001);
        let mut price = start_price;

        (0..bars)
            .map(|_| {
                let change = trend_strength * (1.0 + (rng.next_f64() - 0.5) * 0.2);
                price += change;
                (price + price * 0.005, price - price * 0.005, price)
            })
            .collect()
    }

    /// Sinusoidal oscillation of `range_pct`% around `center_price`.
    fn generate_ranging_data(
        bars: usize,
        center_price: f64,
        range_pct: f64,
    ) -> Vec<(f64, f64, f64)> {
        (0..bars)
            .map(|i| {
                let offset = (i as f64 * 0.5).sin() * center_price * range_pct / 100.0;
                let price = center_price + offset;
                (price + price * 0.002, price - price * 0.002, price)
            })
            .collect()
    }

    /// Alternating ±5% swings with wide (±3%) bars.
    fn generate_volatile_data(bars: usize, center_price: f64) -> Vec<(f64, f64, f64)> {
        (0..bars)
            .map(|i| {
                let swing = if i % 2 == 0 { 1.05 } else { 0.95 };
                let price = center_price * swing;
                (price * 1.03, price * 0.97, price)
            })
            .collect()
    }

    #[test]
    fn test_volatile_regime_detection() {
        let mut detector = RegimeDetector::default_config();
        let mut last = RegimeConfidence::default();
        for (high, low, close) in generate_volatile_data(200, 100.0) {
            last = detector.update(high, low, close);
        }
        assert!(detector.is_ready());
        assert!(
            (0.0..=1.0).contains(&last.confidence),
            "Confidence should be in [0, 1]: {}",
            last.confidence
        );
        assert_eq!(
            last.regime,
            detector.current_regime(),
            "Returned regime should match the detector's committed regime"
        );
    }

    #[test]
    fn test_detector_creation() {
        let detector = RegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
        assert_eq!(detector.bars_in_current_regime(), 0);
    }

    #[test]
    fn test_crypto_optimized_creation() {
        let detector = RegimeDetector::crypto_optimized();
        assert!(!detector.is_ready());
        assert_eq!(detector.config().adx_trending_threshold, 20.0);
        assert_eq!(detector.config().ema_short_period, 21);
    }

    #[test]
    fn test_conservative_creation() {
        let detector = RegimeDetector::conservative();
        assert_eq!(detector.config().adx_trending_threshold, 30.0);
        assert_eq!(detector.config().min_regime_duration, 10);
    }

    #[test]
    fn test_warmup_returns_uncertain() {
        let mut detector = RegimeDetector::default_config();

        for i in 0..10 {
            let price = 100.0 + i as f64;
            let result = detector.update(price + 1.0, price - 1.0, price);
            assert_eq!(result.regime, MarketRegime::Uncertain);
            assert_eq!(result.confidence, 0.0);
        }

        assert!(!detector.is_ready());
    }

    #[test]
    fn test_trending_detection() {
        let mut detector = RegimeDetector::default_config();
        let data = generate_trending_data(300, 100.0, 0.5);

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in data {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        assert!(
            matches!(last_regime, MarketRegime::Trending(_)),
            "Expected Trending regime, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_trending_bullish_direction() {
        let mut detector = RegimeDetector::default_config();
        let data = generate_trending_data(300, 100.0, 0.5);

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in data {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        assert!(
            matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish)),
            "Expected Bullish trend, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_trending_bearish_direction() {
        let mut detector = RegimeDetector::default_config();
        let data = generate_trending_data(300, 200.0, -0.5);

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in data {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        if matches!(last_regime, MarketRegime::Trending(_)) {
            assert!(
                matches!(last_regime, MarketRegime::Trending(TrendDirection::Bearish)),
                "Expected Bearish trend, got: {last_regime:?}"
            );
        }
    }

    #[test]
    fn test_ranging_detection() {
        let mut detector = RegimeDetector::default_config();
        let data = generate_ranging_data(300, 100.0, 2.0);

        let mut last_regime = MarketRegime::Uncertain;
        for (high, low, close) in data {
            let result = detector.update(high, low, close);
            if detector.is_ready() {
                last_regime = result.regime;
            }
        }

        assert!(
            !matches!(last_regime, MarketRegime::Trending(TrendDirection::Bullish)),
            "Ranging data shouldn't produce strong bullish trend, got: {last_regime:?}"
        );
    }

    #[test]
    fn test_confidence_range() {
        let mut detector = RegimeDetector::default_config();
        let data = generate_trending_data(300, 100.0, 0.5);

        for (high, low, close) in data {
            let result = detector.update(high, low, close);
            assert!(
                (0.0..=1.0).contains(&result.confidence),
                "Confidence should be in [0, 1]: {}",
                result.confidence
            );
        }
    }

    #[test]
    fn test_regime_history_tracking() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(200, 100.0, 0.5) {
            detector.update(high, low, close);
        }
        for (high, low, close) in generate_ranging_data(200, 200.0, 1.0) {
            detector.update(high, low, close);
        }

        assert!(
            detector.regime_history().len() <= 20,
            "History should be bounded"
        );
    }

    #[test]
    fn test_recommended_strategy() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            detector.update(high, low, close);
        }

        if matches!(detector.current_regime(), MarketRegime::Trending(_)) {
            assert_eq!(
                detector.recommended_strategy(),
                RecommendedStrategy::TrendFollowing
            );
        }
    }

    #[test]
    fn test_adx_atr_accessors() {
        let mut detector = RegimeDetector::default_config();

        assert!(detector.adx_value().is_none());
        assert!(detector.atr_value().is_none());

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            detector.update(high, low, close);
        }

        assert!(detector.adx_value().is_some());
        assert!(detector.atr_value().is_some());
    }

    #[test]
    fn test_set_config_resets_state() {
        let mut detector = RegimeDetector::default_config();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            detector.update(high, low, close);
        }
        assert!(detector.is_ready());

        detector.set_config(RegimeConfig::crypto_optimized());
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
        assert_eq!(detector.bars_in_current_regime(), 0);
    }

    #[test]
    fn test_last_close_tracking() {
        let mut detector = RegimeDetector::default_config();

        assert!(detector.last_close().is_none());

        detector.update(101.0, 99.0, 100.0);
        assert_eq!(detector.last_close(), Some(100.0));

        detector.update(106.0, 104.0, 105.0);
        assert_eq!(detector.last_close(), Some(105.0));
    }

    #[test]
    fn test_bars_in_regime_increments() {
        let mut detector = RegimeDetector::default_config();

        for i in 0..300 {
            let price = 100.0 + i as f64 * 0.3;
            detector.update(price + 1.0, price - 1.0, price);
        }

        assert!(
            detector.bars_in_current_regime() > 0,
            "Should have been in current regime for multiple bars (regime: {:?})",
            detector.current_regime()
        );
    }

    #[test]
    fn test_stability_filter_prevents_whipsaw() {
        let mut detector = RegimeDetector::new(RegimeConfig {
            min_regime_duration: 10,
            regime_stability_bars: 5,
            ..RegimeConfig::default()
        });
        let min_duration = detector.config().min_regime_duration;

        let bars = generate_trending_data(300, 100.0, 0.5)
            .into_iter()
            .chain(generate_ranging_data(3, 250.0, 1.0));

        // Invariant of the stability filter: any regime switch needs
        // confidence >= 0.4. While the current regime is younger than
        // `min_regime_duration`, it needs confidence >= 0.7.
        for (high, low, close) in bars {
            let regime_before = detector.current_regime();
            let bars_before = detector.bars_in_current_regime();
            let result = detector.update(high, low, close);

            if result.regime != regime_before {
                assert!(
                    result.confidence >= 0.4,
                    "Switched {regime_before:?} -> {:?} on weak confidence {}",
                    result.regime,
                    result.confidence
                );
                if bars_before < min_duration {
                    assert!(
                        result.confidence >= 0.7,
                        "Whipsaw: switched {regime_before:?} -> {:?} after {bars_before} bars \
                         with confidence {}",
                        result.regime,
                        result.confidence
                    );
                }
            }
        }
    }

    #[test]
    fn test_metrics_populated_after_warmup() {
        let mut detector = RegimeDetector::default_config();
        let mut last_result = RegimeConfidence::default();

        for (high, low, close) in generate_trending_data(300, 100.0, 0.5) {
            last_result = detector.update(high, low, close);
        }

        assert!(last_result.adx_value > 0.0, "ADX should be > 0");
        assert!(
            last_result.bb_width_percentile >= 0.0 && last_result.bb_width_percentile <= 100.0,
            "BB width percentile should be in [0, 100]"
        );
        assert!(
            last_result.trend_strength >= 0.0 && last_result.trend_strength <= 1.0,
            "Trend strength should be in [0, 1]"
        );
    }
}