1use std::collections::{HashMap, VecDeque};
14
15use serde::{Deserialize, Serialize};
16
17use super::detector::RegimeDetector;
18use super::hmm::{HMMConfig, HMMRegimeDetector};
19use super::types::{MarketRegime, RegimeConfidence, RegimeConfig};
20
21use crate::error::IndicatorError;
22use crate::indicator::{Indicator, IndicatorOutput};
23use crate::registry::param_usize;
24use crate::types::Candle;
25
26#[derive(Debug, Clone)]
32pub struct EnsembleIndicator {
33 pub ensemble_cfg: EnsembleConfig,
34 pub indicator_cfg: RegimeConfig,
35}
36
37impl EnsembleIndicator {
38 pub fn new(ensemble_cfg: EnsembleConfig, indicator_cfg: RegimeConfig) -> Self {
39 Self {
40 ensemble_cfg,
41 indicator_cfg,
42 }
43 }
44 pub fn with_defaults() -> Self {
45 Self::new(EnsembleConfig::default(), RegimeConfig::default())
46 }
47}
48
49fn regime_id_from(r: MarketRegime) -> f64 {
50 use super::types::TrendDirection;
51 match r {
52 MarketRegime::MeanReverting => 1.0,
53 MarketRegime::Volatile => 2.0,
54 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
55 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
56 MarketRegime::Uncertain => 0.0,
57 }
58}
59
60impl Indicator for EnsembleIndicator {
61 fn name(&self) -> &'static str {
62 "EnsembleRegime"
63 }
64 fn required_len(&self) -> usize {
65 let adx_warmup =
68 self.indicator_cfg.adx_period * 2 + self.indicator_cfg.regime_stability_bars;
69 let ema_warmup = self.indicator_cfg.ema_long_period;
70 let bb_warmup = self.indicator_cfg.bb_period;
71 let hmm_warmup = HMMConfig::default().min_observations + 1;
74 adx_warmup.max(ema_warmup).max(bb_warmup).max(hmm_warmup)
75 }
76 fn required_columns(&self) -> &[&'static str] {
77 &["high", "low", "close"]
78 }
79
80 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
81 self.check_len(candles)?;
82 let mut det =
83 EnsembleRegimeDetector::new(self.ensemble_cfg.clone(), self.indicator_cfg.clone());
84 let n = candles.len();
85 let mut conf = vec![f64::NAN; n];
86 let mut agree = vec![f64::NAN; n];
87 let mut regime = vec![f64::NAN; n];
88 for (i, c) in candles.iter().enumerate() {
89 let res = det.update(c.high, c.low, c.close);
90 conf[i] = res.confidence;
91 agree[i] = if res.methods_agree { 1.0 } else { 0.0 };
92 regime[i] = regime_id_from(res.regime);
93 }
94 Ok(IndicatorOutput::from_pairs([
95 ("ensemble_conf", conf),
96 ("ensemble_agree", agree),
97 ("ensemble_regime", regime),
98 ]))
99 }
100}
101
102pub fn factory<S: ::std::hash::BuildHasher>(
105 params: &HashMap<String, String, S>,
106) -> Result<Box<dyn Indicator>, IndicatorError> {
107 let adx_period = param_usize(params, "adx_period", 14)?;
108 let bb_period = param_usize(params, "bb_period", 20)?;
109 let indicator_cfg = RegimeConfig {
110 adx_period,
111 bb_period,
112 ..RegimeConfig::default()
113 };
114 Ok(Box::new(EnsembleIndicator::new(
115 EnsembleConfig::default(),
116 indicator_cfg,
117 )))
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct EnsembleConfig {
123 pub indicator_weight: f64,
125 pub hmm_weight: f64,
127 pub agreement_threshold: f64,
129 pub require_hmm_warmup: bool,
131 pub agreement_confidence_boost: f64,
133 pub disagreement_confidence_penalty: f64,
135}
136
137impl Default for EnsembleConfig {
138 fn default() -> Self {
139 Self {
140 indicator_weight: 0.6, hmm_weight: 0.4,
142 agreement_threshold: 0.5,
143 require_hmm_warmup: true,
144 agreement_confidence_boost: 0.15,
145 disagreement_confidence_penalty: 0.2,
146 }
147 }
148}
149
150impl EnsembleConfig {
151 pub fn balanced() -> Self {
153 Self {
154 indicator_weight: 0.5,
155 hmm_weight: 0.5,
156 ..Default::default()
157 }
158 }
159
160 pub fn hmm_focused() -> Self {
162 Self {
163 indicator_weight: 0.3,
164 hmm_weight: 0.7,
165 agreement_threshold: 0.6,
166 ..Default::default()
167 }
168 }
169
170 pub fn indicator_focused() -> Self {
172 Self {
173 indicator_weight: 0.7,
174 hmm_weight: 0.3,
175 agreement_threshold: 0.4,
176 ..Default::default()
177 }
178 }
179}
180
181#[derive(Debug, Clone)]
183pub struct EnsembleResult {
184 pub regime: MarketRegime,
186 pub confidence: f64,
188 pub methods_agree: bool,
190 pub indicator_result: RegimeConfidence,
192 pub hmm_result: RegimeConfidence,
194 pub indicator_regime: MarketRegime,
196 pub hmm_regime: MarketRegime,
197}
198
199impl EnsembleResult {
200 pub fn to_regime_confidence(&self) -> RegimeConfidence {
202 RegimeConfidence::new(self.regime, self.confidence)
203 }
204}
205
206impl std::fmt::Display for EnsembleResult {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 write!(
209 f,
210 "Ensemble: {} (conf: {:.0}%, agree: {})",
211 self.regime,
212 self.confidence * 100.0,
213 if self.methods_agree { "✓" } else { "✗" }
214 )
215 }
216}
217
218#[derive(Debug)]
240pub struct EnsembleRegimeDetector {
241 config: EnsembleConfig,
242
243 indicator_detector: RegimeDetector,
245
246 hmm_detector: HMMRegimeDetector,
248
249 current_regime: MarketRegime,
251
252 agreement_history: VecDeque<bool>,
254}
255
256impl EnsembleRegimeDetector {
257 pub fn new(ensemble_config: EnsembleConfig, indicator_config: RegimeConfig) -> Self {
259 Self {
260 config: ensemble_config,
261 indicator_detector: RegimeDetector::new(indicator_config),
262 hmm_detector: HMMRegimeDetector::crypto_optimized(),
263 current_regime: MarketRegime::Uncertain,
264 agreement_history: VecDeque::with_capacity(100),
265 }
266 }
267
268 pub fn default_config() -> Self {
270 Self::new(EnsembleConfig::default(), RegimeConfig::crypto_optimized())
271 }
272
273 pub fn balanced() -> Self {
275 Self::new(EnsembleConfig::balanced(), RegimeConfig::crypto_optimized())
276 }
277
278 pub fn indicator_focused() -> Self {
280 Self::new(
281 EnsembleConfig::indicator_focused(),
282 RegimeConfig::crypto_optimized(),
283 )
284 }
285
286 pub fn hmm_focused() -> Self {
288 Self::new(
289 EnsembleConfig::hmm_focused(),
290 RegimeConfig::crypto_optimized(),
291 )
292 }
293
294 pub fn update(&mut self, high: f64, low: f64, close: f64) -> EnsembleResult {
299 let indicator_result = self.indicator_detector.update(high, low, close);
301 let hmm_result = self.hmm_detector.update_ohlc(high, low, close);
302
303 let indicator_regime = indicator_result.regime;
305 let hmm_regime = hmm_result.regime;
306
307 let hmm_ready = self.hmm_detector.is_ready();
309
310 let methods_agree = Self::regimes_agree(indicator_regime, hmm_regime);
312
313 self.agreement_history.push_back(methods_agree);
315 if self.agreement_history.len() > 100 {
316 self.agreement_history.pop_front();
317 }
318
319 let (regime, confidence) = if self.config.require_hmm_warmup && !hmm_ready {
321 (indicator_regime, indicator_result.confidence)
323 } else {
324 self.combine_results(
325 indicator_regime,
326 indicator_result.confidence,
327 hmm_regime,
328 hmm_result.confidence,
329 methods_agree,
330 )
331 };
332
333 self.current_regime = regime;
334
335 EnsembleResult {
336 regime,
337 confidence,
338 methods_agree,
339 indicator_result,
340 hmm_result,
341 indicator_regime,
342 hmm_regime,
343 }
344 }
345
346 fn regimes_agree(r1: MarketRegime, r2: MarketRegime) -> bool {
348 matches!(
349 (r1, r2),
350 (MarketRegime::Trending(_), MarketRegime::Trending(_))
351 | (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
352 | (MarketRegime::Volatile, MarketRegime::Volatile)
353 | (MarketRegime::Uncertain, MarketRegime::Uncertain)
354 )
355 }
356
357 fn regimes_agree_direction(r1: MarketRegime, r2: MarketRegime) -> bool {
359 match (r1, r2) {
360 (MarketRegime::Trending(d1), MarketRegime::Trending(d2)) => d1 == d2,
361 (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
362 | (MarketRegime::Volatile, MarketRegime::Volatile)
363 | (MarketRegime::Uncertain, MarketRegime::Uncertain) => true,
364 _ => false,
365 }
366 }
367
368 fn combine_results(
370 &self,
371 indicator_regime: MarketRegime,
372 indicator_conf: f64,
373 hmm_regime: MarketRegime,
374 hmm_conf: f64,
375 agree: bool,
376 ) -> (MarketRegime, f64) {
377 let w_ind = self.config.indicator_weight;
378 let w_hmm = self.config.hmm_weight;
379
380 let mut combined_conf = w_ind * indicator_conf + w_hmm * hmm_conf;
382
383 if agree {
385 combined_conf += self.config.agreement_confidence_boost;
387
388 if Self::regimes_agree_direction(indicator_regime, hmm_regime) {
390 combined_conf += 0.05;
391 }
392 } else {
393 combined_conf -= self.config.disagreement_confidence_penalty;
395 }
396
397 combined_conf = combined_conf.clamp(0.0, 1.0);
398
399 let regime = if agree {
401 indicator_regime
403 } else if combined_conf < self.config.agreement_threshold {
404 MarketRegime::Uncertain
406 } else {
407 if w_ind >= w_hmm {
409 indicator_regime
410 } else {
411 hmm_regime
412 }
413 };
414
415 (regime, combined_conf)
416 }
417
418 pub fn current_regime(&self) -> MarketRegime {
424 self.current_regime
425 }
426
427 pub fn agreement_rate(&self) -> f64 {
429 if self.agreement_history.is_empty() {
430 return 0.0;
431 }
432 let agrees = self.agreement_history.iter().filter(|&&a| a).count();
433 agrees as f64 / self.agreement_history.len() as f64
434 }
435
436 pub fn is_ready(&self) -> bool {
441 self.indicator_detector.is_ready()
442 && (!self.config.require_hmm_warmup || self.hmm_detector.is_ready())
443 }
444
445 pub fn indicator_ready(&self) -> bool {
447 self.indicator_detector.is_ready()
448 }
449
450 pub fn hmm_ready(&self) -> bool {
452 self.hmm_detector.is_ready()
453 }
454
455 pub fn hmm_state_probabilities(&self) -> &[f64] {
457 self.hmm_detector.state_probabilities()
458 }
459
460 pub fn expected_regime_duration(&self) -> f64 {
462 self.hmm_detector
463 .expected_regime_duration(self.hmm_detector.current_state_index())
464 }
465
466 pub fn status(&self) -> EnsembleStatus {
468 EnsembleStatus {
469 current_regime: self.current_regime,
470 indicator_ready: self.indicator_detector.is_ready(),
471 hmm_ready: self.hmm_detector.is_ready(),
472 agreement_rate: self.agreement_rate(),
473 hmm_state_probs: self.hmm_detector.state_probabilities().to_vec(),
474 expected_duration: self.expected_regime_duration(),
475 }
476 }
477
478 pub fn indicator_detector(&self) -> &RegimeDetector {
480 &self.indicator_detector
481 }
482
483 pub fn hmm_detector(&self) -> &HMMRegimeDetector {
485 &self.hmm_detector
486 }
487
488 pub fn config(&self) -> &EnsembleConfig {
490 &self.config
491 }
492}
493
494#[derive(Debug, Clone)]
496pub struct EnsembleStatus {
497 pub current_regime: MarketRegime,
498 pub indicator_ready: bool,
499 pub hmm_ready: bool,
500 pub agreement_rate: f64,
501 pub hmm_state_probs: Vec<f64>,
502 pub expected_duration: f64,
503}
504
505impl std::fmt::Display for EnsembleStatus {
506 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507 write!(
508 f,
509 "Regime: {} | Agreement: {:.1}% | HMM Ready: {} | Expected Duration: {:.1} bars",
510 self.current_regime,
511 self.agreement_rate * 100.0,
512 self.hmm_ready,
513 self.expected_duration
514 )
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrendDirection;

    /// Shorthand for a bullish trending regime.
    const BULL: MarketRegime = MarketRegime::Trending(TrendDirection::Bullish);
    /// Shorthand for a bearish trending regime.
    const BEAR: MarketRegime = MarketRegime::Trending(TrendDirection::Bearish);

    /// Feeds `bars` candles starting from a close of 100, multiplying the
    /// close by `step` each bar and deriving high/low from the close via the
    /// given multipliers. Returns the final close.
    fn feed_trend(
        detector: &mut EnsembleRegimeDetector,
        bars: usize,
        step: f64,
        high_mult: f64,
        low_mult: f64,
    ) -> f64 {
        let mut close = 100.0;
        for _ in 0..bars {
            close *= step;
            detector.update(close * high_mult, close * low_mult, close);
        }
        close
    }

    /// Builds an `EnsembleResult` whose per-method regimes mirror the regimes
    /// of the per-method results.
    fn result_with(
        regime: MarketRegime,
        confidence: f64,
        methods_agree: bool,
        indicator: (MarketRegime, f64),
        hmm: (MarketRegime, f64),
    ) -> EnsembleResult {
        EnsembleResult {
            regime,
            confidence,
            methods_agree,
            indicator_result: RegimeConfidence::new(indicator.0, indicator.1),
            hmm_result: RegimeConfidence::new(hmm.0, hmm.1),
            indicator_regime: indicator.0,
            hmm_regime: hmm.0,
        }
    }

    /// A fresh detector is not ready and starts as `Uncertain`.
    #[test]
    fn test_ensemble_creation() {
        let detector = EnsembleRegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
    }

    /// The balanced preset weights both methods equally.
    #[test]
    fn test_balanced_creation() {
        let detector = EnsembleRegimeDetector::balanced();
        assert!(!detector.is_ready());
        assert_eq!(detector.config().indicator_weight, 0.5);
        assert_eq!(detector.config().hmm_weight, 0.5);
    }

    /// The indicator-focused preset favours the indicator weight.
    #[test]
    fn test_indicator_focused_creation() {
        let detector = EnsembleRegimeDetector::indicator_focused();
        let cfg = detector.config();
        assert!(cfg.indicator_weight > cfg.hmm_weight);
    }

    /// The HMM-focused preset favours the HMM weight.
    #[test]
    fn test_hmm_focused_creation() {
        let detector = EnsembleRegimeDetector::hmm_focused();
        let cfg = detector.config();
        assert!(cfg.hmm_weight > cfg.indicator_weight);
    }

    /// Same-category pairs agree, including trends of opposite direction.
    #[test]
    fn test_regimes_agree_same_category() {
        let agreeing = [
            (BULL, BEAR),
            (MarketRegime::MeanReverting, MarketRegime::MeanReverting),
            (MarketRegime::Volatile, MarketRegime::Volatile),
            (MarketRegime::Uncertain, MarketRegime::Uncertain),
        ];
        for (first, second) in agreeing {
            assert!(EnsembleRegimeDetector::regimes_agree(first, second));
        }
    }

    /// Pairs from different categories never agree.
    #[test]
    fn test_regimes_disagree_different_category() {
        let disagreeing = [
            (BULL, MarketRegime::MeanReverting),
            (MarketRegime::Volatile, BEAR),
            (MarketRegime::Uncertain, MarketRegime::MeanReverting),
        ];
        for (first, second) in disagreeing {
            assert!(!EnsembleRegimeDetector::regimes_agree(first, second));
        }
    }

    /// Direction-aware agreement also requires matching trend directions.
    #[test]
    fn test_regimes_agree_direction() {
        assert!(EnsembleRegimeDetector::regimes_agree_direction(BULL, BULL));
        assert!(!EnsembleRegimeDetector::regimes_agree_direction(BULL, BEAR));
        assert!(EnsembleRegimeDetector::regimes_agree_direction(
            MarketRegime::MeanReverting,
            MarketRegime::MeanReverting
        ));
        assert!(!EnsembleRegimeDetector::regimes_agree_direction(
            BULL,
            MarketRegime::MeanReverting
        ));
    }

    /// With no updates recorded the agreement rate is zero.
    #[test]
    fn test_agreement_rate_empty() {
        let detector = EnsembleRegimeDetector::default_config();
        assert_eq!(detector.agreement_rate(), 0.0);
    }

    /// After a zig-zag price series the agreement rate stays within `[0, 1]`.
    #[test]
    fn test_agreement_rate_tracked() {
        let mut detector = EnsembleRegimeDetector::default_config();

        let mut close = 100.0;
        for bar in 0..50 {
            let step = if bar % 2 == 0 { 1.01 } else { 0.99 };
            close *= step;
            detector.update(close * 1.01, close * 0.99, close);
        }

        let rate = detector.agreement_rate();
        assert!(
            (0.0..=1.0).contains(&rate),
            "Agreement rate should be in [0, 1]: {rate}"
        );
    }

    /// A steady up-trend yields a non-trivial agreement rate and a bounded
    /// confidence.
    #[test]
    fn test_bull_market_agreement() {
        let mut detector = EnsembleRegimeDetector::default_config();

        let close = feed_trend(&mut detector, 300, 1.005, 1.002, 0.998);
        let result = detector.update(close * 1.002, close * 0.998, close);

        assert!(
            detector.agreement_rate() > 0.2,
            "Agreement rate should be > 0.2 in consistent bull market: {}",
            detector.agreement_rate()
        );

        assert!(
            (0.0..=1.0).contains(&result.confidence),
            "Confidence should be in [0, 1]: {}",
            result.confidence
        );
    }

    /// The agreeing result prints regime, percentage and a check mark.
    #[test]
    fn test_ensemble_result_display() {
        let result = result_with(BULL, 0.85, true, (BULL, 0.8), (BULL, 0.9));

        let display = format!("{result}");
        assert!(display.contains("Trending (Bullish)"));
        assert!(display.contains("85%"));
        assert!(display.contains("✓"));
    }

    /// The disagreeing result prints a cross mark.
    #[test]
    fn test_ensemble_result_disagreement_display() {
        let result = result_with(
            MarketRegime::Uncertain,
            0.3,
            false,
            (BULL, 0.6),
            (MarketRegime::MeanReverting, 0.5),
        );

        let display = format!("{result}");
        assert!(display.contains("✗"));
    }

    /// Conversion keeps the final regime and confidence.
    #[test]
    fn test_ensemble_to_regime_confidence() {
        let result = result_with(
            MarketRegime::MeanReverting,
            0.72,
            true,
            (MarketRegime::MeanReverting, 0.7),
            (MarketRegime::MeanReverting, 0.75),
        );

        let rc = result.to_regime_confidence();
        assert_eq!(rc.regime, MarketRegime::MeanReverting);
        assert!((rc.confidence - 0.72).abs() < f64::EPSILON);
    }

    /// The status line shows regime, agreement percentage and HMM readiness.
    #[test]
    fn test_status_display() {
        let status = EnsembleStatus {
            current_regime: MarketRegime::Volatile,
            indicator_ready: true,
            hmm_ready: false,
            agreement_rate: 0.65,
            hmm_state_probs: vec![0.3, 0.3, 0.4],
            expected_duration: 8.5,
        };

        let display = format!("{status}");
        assert!(display.contains("Volatile"));
        assert!(display.contains("65.0%"));
        assert!(display.contains("false"));
    }

    /// Nothing is ready at construction; the indicator side becomes ready
    /// after enough bars.
    #[test]
    fn test_ready_state() {
        let mut detector = EnsembleRegimeDetector::default_config();

        assert!(!detector.is_ready());
        assert!(!detector.indicator_ready());
        assert!(!detector.hmm_ready());

        feed_trend(&mut detector, 300, 1.001, 1.01, 0.99);

        assert!(detector.indicator_ready());
    }

    /// HMM state probabilities are exposed, three in number, and sum to one.
    #[test]
    fn test_hmm_state_probabilities_accessible() {
        let mut detector = EnsembleRegimeDetector::default_config();
        feed_trend(&mut detector, 100, 1.001, 1.01, 0.99);

        let probs = detector.hmm_state_probabilities();
        assert_eq!(probs.len(), 3, "Should have 3 HMM states");

        let sum: f64 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "HMM state probs should sum to 1.0: {sum}"
        );
    }

    /// The expected regime duration is positive even before any update.
    #[test]
    fn test_expected_regime_duration() {
        let detector = EnsembleRegimeDetector::default_config();
        let duration = detector.expected_regime_duration();
        assert!(duration > 0.0, "Duration should be > 0: {duration}");
    }

    /// Both underlying detectors are reachable and start out not ready.
    #[test]
    fn test_detector_accessors() {
        let detector = EnsembleRegimeDetector::default_config();

        assert!(!detector.indicator_detector().is_ready());
        assert!(!detector.hmm_detector().is_ready());
    }

    /// With equal inputs, agreement yields a higher confidence than
    /// disagreement.
    #[test]
    fn test_combine_results_agreement_boosts_confidence() {
        let detector = EnsembleRegimeDetector::default_config();

        let (_, conf_agree) = detector.combine_results(BULL, 0.7, BULL, 0.7, true);
        let (_, conf_disagree) =
            detector.combine_results(BULL, 0.7, MarketRegime::MeanReverting, 0.7, false);

        assert!(
            conf_agree > conf_disagree,
            "Agreement should boost confidence: agree={conf_agree} vs disagree={conf_disagree}"
        );
    }

    /// Disagreement with a confidence below the threshold collapses to
    /// `Uncertain`.
    #[test]
    fn test_combine_results_disagreement_returns_uncertain_at_low_conf() {
        let config = EnsembleConfig {
            agreement_threshold: 0.8,
            disagreement_confidence_penalty: 0.5,
            ..Default::default()
        };
        let detector = EnsembleRegimeDetector::new(config, RegimeConfig::default());

        let (regime, _) =
            detector.combine_results(BULL, 0.4, MarketRegime::MeanReverting, 0.4, false);

        assert_eq!(
            regime,
            MarketRegime::Uncertain,
            "Low confidence + disagreement should produce Uncertain"
        );
    }
}