1use std::collections::{HashMap, VecDeque};
14
15use serde::{Deserialize, Serialize};
16
17use super::detector::RegimeDetector;
18use super::hmm::{HMMConfig, HMMRegimeDetector};
19use super::types::{MarketRegime, RegimeConfidence, RegimeConfig};
20
21use crate::error::IndicatorError;
22use crate::indicator::{Indicator, IndicatorOutput};
23use crate::registry::param_usize;
24use crate::types::Candle;
25
26#[derive(Debug, Clone)]
32pub struct EnsembleIndicator {
33 pub ensemble_cfg: EnsembleConfig,
34 pub indicator_cfg: RegimeConfig,
35}
36
37impl EnsembleIndicator {
38 pub fn new(ensemble_cfg: EnsembleConfig, indicator_cfg: RegimeConfig) -> Self {
39 Self {
40 ensemble_cfg,
41 indicator_cfg,
42 }
43 }
44 pub fn with_defaults() -> Self {
45 Self::new(EnsembleConfig::default(), RegimeConfig::default())
46 }
47}
48
49fn regime_id_from(r: MarketRegime) -> f64 {
50 use super::types::TrendDirection;
51 match r {
52 MarketRegime::MeanReverting => 1.0,
53 MarketRegime::Volatile => 2.0,
54 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
55 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
56 MarketRegime::Uncertain => 0.0,
57 }
58}
59
60impl Indicator for EnsembleIndicator {
61 fn name(&self) -> &'static str {
62 "EnsembleRegime"
63 }
64 fn required_len(&self) -> usize {
65 let adx_warmup =
68 self.indicator_cfg.adx_period * 2 + self.indicator_cfg.regime_stability_bars;
69 let ema_warmup = self.indicator_cfg.ema_long_period;
70 let bb_warmup = self.indicator_cfg.bb_period;
71 let hmm_warmup = HMMConfig::default().min_observations + 1;
74 adx_warmup.max(ema_warmup).max(bb_warmup).max(hmm_warmup)
75 }
76 fn required_columns(&self) -> &[&'static str] {
77 &["high", "low", "close"]
78 }
79
80 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
81 self.check_len(candles)?;
82 let mut det =
83 EnsembleRegimeDetector::new(self.ensemble_cfg.clone(), self.indicator_cfg.clone());
84 let n = candles.len();
85 let mut conf = vec![f64::NAN; n];
86 let mut agree = vec![f64::NAN; n];
87 let mut regime = vec![f64::NAN; n];
88 for (i, c) in candles.iter().enumerate() {
89 let res = det.update(c.high, c.low, c.close);
90 conf[i] = res.confidence;
91 agree[i] = if res.methods_agree { 1.0 } else { 0.0 };
92 regime[i] = regime_id_from(res.regime);
93 }
94 Ok(IndicatorOutput::from_pairs([
95 ("ensemble_conf", conf),
96 ("ensemble_agree", agree),
97 ("ensemble_regime", regime),
98 ]))
99 }
100}
101
102pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
105 let adx_period = param_usize(params, "adx_period", 14)?;
106 let bb_period = param_usize(params, "bb_period", 20)?;
107 let indicator_cfg = RegimeConfig {
108 adx_period,
109 bb_period,
110 ..RegimeConfig::default()
111 };
112 Ok(Box::new(EnsembleIndicator::new(
113 EnsembleConfig::default(),
114 indicator_cfg,
115 )))
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct EnsembleConfig {
121 pub indicator_weight: f64,
123 pub hmm_weight: f64,
125 pub agreement_threshold: f64,
127 pub require_hmm_warmup: bool,
129 pub agreement_confidence_boost: f64,
131 pub disagreement_confidence_penalty: f64,
133}
134
135impl Default for EnsembleConfig {
136 fn default() -> Self {
137 Self {
138 indicator_weight: 0.6, hmm_weight: 0.4,
140 agreement_threshold: 0.5,
141 require_hmm_warmup: true,
142 agreement_confidence_boost: 0.15,
143 disagreement_confidence_penalty: 0.2,
144 }
145 }
146}
147
148impl EnsembleConfig {
149 pub fn balanced() -> Self {
151 Self {
152 indicator_weight: 0.5,
153 hmm_weight: 0.5,
154 ..Default::default()
155 }
156 }
157
158 pub fn hmm_focused() -> Self {
160 Self {
161 indicator_weight: 0.3,
162 hmm_weight: 0.7,
163 agreement_threshold: 0.6,
164 ..Default::default()
165 }
166 }
167
168 pub fn indicator_focused() -> Self {
170 Self {
171 indicator_weight: 0.7,
172 hmm_weight: 0.3,
173 agreement_threshold: 0.4,
174 ..Default::default()
175 }
176 }
177}
178
179#[derive(Debug, Clone)]
181pub struct EnsembleResult {
182 pub regime: MarketRegime,
184 pub confidence: f64,
186 pub methods_agree: bool,
188 pub indicator_result: RegimeConfidence,
190 pub hmm_result: RegimeConfidence,
192 pub indicator_regime: MarketRegime,
194 pub hmm_regime: MarketRegime,
195}
196
197impl EnsembleResult {
198 pub fn to_regime_confidence(&self) -> RegimeConfidence {
200 RegimeConfidence::new(self.regime, self.confidence)
201 }
202}
203
204impl std::fmt::Display for EnsembleResult {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 write!(
207 f,
208 "Ensemble: {} (conf: {:.0}%, agree: {})",
209 self.regime,
210 self.confidence * 100.0,
211 if self.methods_agree { "✓" } else { "✗" }
212 )
213 }
214}
215
216#[derive(Debug)]
238pub struct EnsembleRegimeDetector {
239 config: EnsembleConfig,
240
241 indicator_detector: RegimeDetector,
243
244 hmm_detector: HMMRegimeDetector,
246
247 current_regime: MarketRegime,
249
250 agreement_history: VecDeque<bool>,
252}
253
254impl EnsembleRegimeDetector {
255 pub fn new(ensemble_config: EnsembleConfig, indicator_config: RegimeConfig) -> Self {
257 Self {
258 config: ensemble_config,
259 indicator_detector: RegimeDetector::new(indicator_config),
260 hmm_detector: HMMRegimeDetector::crypto_optimized(),
261 current_regime: MarketRegime::Uncertain,
262 agreement_history: VecDeque::with_capacity(100),
263 }
264 }
265
266 pub fn default_config() -> Self {
268 Self::new(EnsembleConfig::default(), RegimeConfig::crypto_optimized())
269 }
270
271 pub fn balanced() -> Self {
273 Self::new(EnsembleConfig::balanced(), RegimeConfig::crypto_optimized())
274 }
275
276 pub fn indicator_focused() -> Self {
278 Self::new(
279 EnsembleConfig::indicator_focused(),
280 RegimeConfig::crypto_optimized(),
281 )
282 }
283
284 pub fn hmm_focused() -> Self {
286 Self::new(
287 EnsembleConfig::hmm_focused(),
288 RegimeConfig::crypto_optimized(),
289 )
290 }
291
292 pub fn update(&mut self, high: f64, low: f64, close: f64) -> EnsembleResult {
297 let indicator_result = self.indicator_detector.update(high, low, close);
299 let hmm_result = self.hmm_detector.update_ohlc(high, low, close);
300
301 let indicator_regime = indicator_result.regime;
303 let hmm_regime = hmm_result.regime;
304
305 let hmm_ready = self.hmm_detector.is_ready();
307
308 let methods_agree = Self::regimes_agree(indicator_regime, hmm_regime);
310
311 self.agreement_history.push_back(methods_agree);
313 if self.agreement_history.len() > 100 {
314 self.agreement_history.pop_front();
315 }
316
317 let (regime, confidence) = if self.config.require_hmm_warmup && !hmm_ready {
319 (indicator_regime, indicator_result.confidence)
321 } else {
322 self.combine_results(
323 indicator_regime,
324 indicator_result.confidence,
325 hmm_regime,
326 hmm_result.confidence,
327 methods_agree,
328 )
329 };
330
331 self.current_regime = regime;
332
333 EnsembleResult {
334 regime,
335 confidence,
336 methods_agree,
337 indicator_result,
338 hmm_result,
339 indicator_regime,
340 hmm_regime,
341 }
342 }
343
344 fn regimes_agree(r1: MarketRegime, r2: MarketRegime) -> bool {
346 matches!(
347 (r1, r2),
348 (MarketRegime::Trending(_), MarketRegime::Trending(_))
349 | (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
350 | (MarketRegime::Volatile, MarketRegime::Volatile)
351 | (MarketRegime::Uncertain, MarketRegime::Uncertain)
352 )
353 }
354
355 fn regimes_agree_direction(r1: MarketRegime, r2: MarketRegime) -> bool {
357 match (r1, r2) {
358 (MarketRegime::Trending(d1), MarketRegime::Trending(d2)) => d1 == d2,
359 (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
360 | (MarketRegime::Volatile, MarketRegime::Volatile)
361 | (MarketRegime::Uncertain, MarketRegime::Uncertain) => true,
362 _ => false,
363 }
364 }
365
366 fn combine_results(
368 &self,
369 indicator_regime: MarketRegime,
370 indicator_conf: f64,
371 hmm_regime: MarketRegime,
372 hmm_conf: f64,
373 agree: bool,
374 ) -> (MarketRegime, f64) {
375 let w_ind = self.config.indicator_weight;
376 let w_hmm = self.config.hmm_weight;
377
378 let mut combined_conf = w_ind * indicator_conf + w_hmm * hmm_conf;
380
381 if agree {
383 combined_conf += self.config.agreement_confidence_boost;
385
386 if Self::regimes_agree_direction(indicator_regime, hmm_regime) {
388 combined_conf += 0.05;
389 }
390 } else {
391 combined_conf -= self.config.disagreement_confidence_penalty;
393 }
394
395 combined_conf = combined_conf.clamp(0.0, 1.0);
396
397 let regime = if agree {
399 indicator_regime
401 } else if combined_conf < self.config.agreement_threshold {
402 MarketRegime::Uncertain
404 } else {
405 if w_ind >= w_hmm {
407 indicator_regime
408 } else {
409 hmm_regime
410 }
411 };
412
413 (regime, combined_conf)
414 }
415
416 pub fn current_regime(&self) -> MarketRegime {
422 self.current_regime
423 }
424
425 pub fn agreement_rate(&self) -> f64 {
427 if self.agreement_history.is_empty() {
428 return 0.0;
429 }
430 let agrees = self.agreement_history.iter().filter(|&&a| a).count();
431 agrees as f64 / self.agreement_history.len() as f64
432 }
433
434 pub fn is_ready(&self) -> bool {
439 self.indicator_detector.is_ready()
440 && (!self.config.require_hmm_warmup || self.hmm_detector.is_ready())
441 }
442
443 pub fn indicator_ready(&self) -> bool {
445 self.indicator_detector.is_ready()
446 }
447
448 pub fn hmm_ready(&self) -> bool {
450 self.hmm_detector.is_ready()
451 }
452
453 pub fn hmm_state_probabilities(&self) -> &[f64] {
455 self.hmm_detector.state_probabilities()
456 }
457
458 pub fn expected_regime_duration(&self) -> f64 {
460 self.hmm_detector
461 .expected_regime_duration(self.hmm_detector.current_state_index())
462 }
463
464 pub fn status(&self) -> EnsembleStatus {
466 EnsembleStatus {
467 current_regime: self.current_regime,
468 indicator_ready: self.indicator_detector.is_ready(),
469 hmm_ready: self.hmm_detector.is_ready(),
470 agreement_rate: self.agreement_rate(),
471 hmm_state_probs: self.hmm_detector.state_probabilities().to_vec(),
472 expected_duration: self.expected_regime_duration(),
473 }
474 }
475
476 pub fn indicator_detector(&self) -> &RegimeDetector {
478 &self.indicator_detector
479 }
480
481 pub fn hmm_detector(&self) -> &HMMRegimeDetector {
483 &self.hmm_detector
484 }
485
486 pub fn config(&self) -> &EnsembleConfig {
488 &self.config
489 }
490}
491
492#[derive(Debug, Clone)]
494pub struct EnsembleStatus {
495 pub current_regime: MarketRegime,
496 pub indicator_ready: bool,
497 pub hmm_ready: bool,
498 pub agreement_rate: f64,
499 pub hmm_state_probs: Vec<f64>,
500 pub expected_duration: f64,
501}
502
503impl std::fmt::Display for EnsembleStatus {
504 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505 write!(
506 f,
507 "Regime: {} | Agreement: {:.1}% | HMM Ready: {} | Expected Duration: {:.1} bars",
508 self.current_regime,
509 self.agreement_rate * 100.0,
510 self.hmm_ready,
511 self.expected_duration
512 )
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrendDirection;

    /// Shorthand for the two trending regimes used throughout the tests.
    const BULL: MarketRegime = MarketRegime::Trending(TrendDirection::Bullish);
    const BEAR: MarketRegime = MarketRegime::Trending(TrendDirection::Bearish);

    /// Feeds `bars` candles starting from a close of 100.0. Each bar first
    /// multiplies the close by `growth`, then reports `close * high_mult` and
    /// `close * low_mult` as the bar's range. Returns the final close.
    fn feed_drift(
        detector: &mut EnsembleRegimeDetector,
        bars: usize,
        growth: f64,
        high_mult: f64,
        low_mult: f64,
    ) -> f64 {
        let mut close = 100.0;
        for _ in 0..bars {
            close *= growth;
            detector.update(close * high_mult, close * low_mult, close);
        }
        close
    }

    /// Builds an `EnsembleResult` whose per-method regimes mirror the regimes
    /// inside the per-method results.
    fn result_with(
        regime: MarketRegime,
        confidence: f64,
        methods_agree: bool,
        indicator: (MarketRegime, f64),
        hmm: (MarketRegime, f64),
    ) -> EnsembleResult {
        EnsembleResult {
            regime,
            confidence,
            methods_agree,
            indicator_result: RegimeConfidence::new(indicator.0, indicator.1),
            hmm_result: RegimeConfidence::new(hmm.0, hmm.1),
            indicator_regime: indicator.0,
            hmm_regime: hmm.0,
        }
    }

    // --- construction ------------------------------------------------------

    #[test]
    fn test_ensemble_creation() {
        let detector = EnsembleRegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.current_regime(), MarketRegime::Uncertain);
    }

    #[test]
    fn test_balanced_creation() {
        let detector = EnsembleRegimeDetector::balanced();
        let cfg = detector.config();
        assert!(!detector.is_ready());
        assert_eq!(cfg.indicator_weight, 0.5);
        assert_eq!(cfg.hmm_weight, 0.5);
    }

    #[test]
    fn test_indicator_focused_creation() {
        let detector = EnsembleRegimeDetector::indicator_focused();
        let cfg = detector.config();
        assert!(cfg.indicator_weight > cfg.hmm_weight);
    }

    #[test]
    fn test_hmm_focused_creation() {
        let detector = EnsembleRegimeDetector::hmm_focused();
        let cfg = detector.config();
        assert!(cfg.hmm_weight > cfg.indicator_weight);
    }

    // --- agreement predicates ----------------------------------------------

    #[test]
    fn test_regimes_agree_same_category() {
        // Opposite trend directions still count as the same category.
        assert!(EnsembleRegimeDetector::regimes_agree(BULL, BEAR));
        assert!(EnsembleRegimeDetector::regimes_agree(
            MarketRegime::MeanReverting,
            MarketRegime::MeanReverting
        ));
        assert!(EnsembleRegimeDetector::regimes_agree(
            MarketRegime::Volatile,
            MarketRegime::Volatile
        ));
        assert!(EnsembleRegimeDetector::regimes_agree(
            MarketRegime::Uncertain,
            MarketRegime::Uncertain
        ));
    }

    #[test]
    fn test_regimes_disagree_different_category() {
        assert!(!EnsembleRegimeDetector::regimes_agree(
            BULL,
            MarketRegime::MeanReverting
        ));
        assert!(!EnsembleRegimeDetector::regimes_agree(
            MarketRegime::Volatile,
            BEAR
        ));
        assert!(!EnsembleRegimeDetector::regimes_agree(
            MarketRegime::Uncertain,
            MarketRegime::MeanReverting
        ));
    }

    #[test]
    fn test_regimes_agree_direction() {
        assert!(EnsembleRegimeDetector::regimes_agree_direction(BULL, BULL));
        assert!(!EnsembleRegimeDetector::regimes_agree_direction(BULL, BEAR));
        assert!(EnsembleRegimeDetector::regimes_agree_direction(
            MarketRegime::MeanReverting,
            MarketRegime::MeanReverting
        ));
        assert!(!EnsembleRegimeDetector::regimes_agree_direction(
            BULL,
            MarketRegime::MeanReverting
        ));
    }

    // --- agreement tracking ------------------------------------------------

    #[test]
    fn test_agreement_rate_empty() {
        let detector = EnsembleRegimeDetector::default_config();
        assert_eq!(detector.agreement_rate(), 0.0);
    }

    #[test]
    fn test_agreement_rate_tracked() {
        let mut detector = EnsembleRegimeDetector::default_config();

        // Zig-zag: +1% on even steps, -1% on odd steps.
        let mut close = 100.0;
        for step in 0..50 {
            let factor = if step % 2 == 0 { 1.01 } else { 0.99 };
            close *= factor;
            detector.update(close * 1.01, close * 0.99, close);
        }

        let rate = detector.agreement_rate();
        assert!(
            (0.0..=1.0).contains(&rate),
            "Agreement rate should be in [0, 1]: {rate}"
        );
    }

    #[test]
    fn test_bull_market_agreement() {
        let mut detector = EnsembleRegimeDetector::default_config();

        // Steady +0.5% per bar with a tight range.
        let close = feed_drift(&mut detector, 300, 1.005, 1.002, 0.998);

        // One more bar at the final price to capture a result.
        let result = detector.update(close * 1.002, close * 0.998, close);

        assert!(
            detector.agreement_rate() > 0.2,
            "Agreement rate should be > 0.2 in consistent bull market: {}",
            detector.agreement_rate()
        );

        assert!(
            (0.0..=1.0).contains(&result.confidence),
            "Confidence should be in [0, 1]: {}",
            result.confidence
        );
    }

    // --- formatting and conversion -----------------------------------------

    #[test]
    fn test_ensemble_result_display() {
        let result = result_with(BULL, 0.85, true, (BULL, 0.8), (BULL, 0.9));

        let rendered = format!("{result}");
        assert!(rendered.contains("Trending (Bullish)"));
        assert!(rendered.contains("85%"));
        assert!(rendered.contains("✓"));
    }

    #[test]
    fn test_ensemble_result_disagreement_display() {
        let result = result_with(
            MarketRegime::Uncertain,
            0.3,
            false,
            (BULL, 0.6),
            (MarketRegime::MeanReverting, 0.5),
        );

        let rendered = format!("{result}");
        assert!(rendered.contains("✗"));
    }

    #[test]
    fn test_ensemble_to_regime_confidence() {
        let result = result_with(
            MarketRegime::MeanReverting,
            0.72,
            true,
            (MarketRegime::MeanReverting, 0.7),
            (MarketRegime::MeanReverting, 0.75),
        );

        let converted = result.to_regime_confidence();
        assert_eq!(converted.regime, MarketRegime::MeanReverting);
        assert!((converted.confidence - 0.72).abs() < f64::EPSILON);
    }

    #[test]
    fn test_status_display() {
        let status = EnsembleStatus {
            current_regime: MarketRegime::Volatile,
            indicator_ready: true,
            hmm_ready: false,
            agreement_rate: 0.65,
            hmm_state_probs: vec![0.3, 0.3, 0.4],
            expected_duration: 8.5,
        };

        let rendered = format!("{status}");
        assert!(rendered.contains("Volatile"));
        assert!(rendered.contains("65.0%"));
        assert!(rendered.contains("false"));
    }

    // --- readiness and accessors -------------------------------------------

    #[test]
    fn test_ready_state() {
        let mut detector = EnsembleRegimeDetector::default_config();

        // Nothing is ready before any bar has been seen.
        assert!(!detector.is_ready());
        assert!(!detector.indicator_ready());
        assert!(!detector.hmm_ready());

        feed_drift(&mut detector, 300, 1.001, 1.01, 0.99);

        assert!(detector.indicator_ready());
    }

    #[test]
    fn test_hmm_state_probabilities_accessible() {
        let mut detector = EnsembleRegimeDetector::default_config();
        feed_drift(&mut detector, 100, 1.001, 1.01, 0.99);

        let probs = detector.hmm_state_probabilities();
        assert_eq!(probs.len(), 3, "Should have 3 HMM states");

        let sum: f64 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "HMM state probs should sum to 1.0: {sum}"
        );
    }

    #[test]
    fn test_expected_regime_duration() {
        let detector = EnsembleRegimeDetector::default_config();
        let duration = detector.expected_regime_duration();
        assert!(duration > 0.0, "Duration should be > 0: {duration}");
    }

    #[test]
    fn test_detector_accessors() {
        let detector = EnsembleRegimeDetector::default_config();

        assert!(!detector.indicator_detector().is_ready());
        assert!(!detector.hmm_detector().is_ready());
    }

    // --- fusion logic ------------------------------------------------------

    #[test]
    fn test_combine_results_agreement_boosts_confidence() {
        let detector = EnsembleRegimeDetector::default_config();

        let (_, conf_agree) = detector.combine_results(BULL, 0.7, BULL, 0.7, true);
        let (_, conf_disagree) =
            detector.combine_results(BULL, 0.7, MarketRegime::MeanReverting, 0.7, false);

        assert!(
            conf_agree > conf_disagree,
            "Agreement should boost confidence: agree={conf_agree} vs disagree={conf_disagree}"
        );
    }

    #[test]
    fn test_combine_results_disagreement_returns_uncertain_at_low_conf() {
        // High threshold plus a heavy penalty forces the Uncertain branch.
        let strict = EnsembleConfig {
            agreement_threshold: 0.8,
            disagreement_confidence_penalty: 0.5,
            ..Default::default()
        };
        let detector = EnsembleRegimeDetector::new(strict, RegimeConfig::default());

        let (regime, _) =
            detector.combine_results(BULL, 0.4, MarketRegime::MeanReverting, 0.4, false);

        assert_eq!(
            regime,
            MarketRegime::Uncertain,
            "Low confidence + disagreement should produce Uncertain"
        );
    }
}