1use super::detector::RegimeDetector;
14use super::hmm::HMMRegimeDetector;
15use super::types::{MarketRegime, RegimeConfidence, RegimeConfig};
16use serde::{Deserialize, Serialize};
17use std::collections::VecDeque;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct EnsembleConfig {
22 pub indicator_weight: f64,
24 pub hmm_weight: f64,
26 pub agreement_threshold: f64,
28 pub require_hmm_warmup: bool,
30 pub agreement_confidence_boost: f64,
32 pub disagreement_confidence_penalty: f64,
34}
35
36impl Default for EnsembleConfig {
37 fn default() -> Self {
38 Self {
39 indicator_weight: 0.6, hmm_weight: 0.4,
41 agreement_threshold: 0.5,
42 require_hmm_warmup: true,
43 agreement_confidence_boost: 0.15,
44 disagreement_confidence_penalty: 0.2,
45 }
46 }
47}
48
49impl EnsembleConfig {
50 pub fn balanced() -> Self {
52 Self {
53 indicator_weight: 0.5,
54 hmm_weight: 0.5,
55 ..Default::default()
56 }
57 }
58
59 pub fn hmm_focused() -> Self {
61 Self {
62 indicator_weight: 0.3,
63 hmm_weight: 0.7,
64 agreement_threshold: 0.6,
65 ..Default::default()
66 }
67 }
68
69 pub fn indicator_focused() -> Self {
71 Self {
72 indicator_weight: 0.7,
73 hmm_weight: 0.3,
74 agreement_threshold: 0.4,
75 ..Default::default()
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
82pub struct EnsembleResult {
83 pub regime: MarketRegime,
85 pub confidence: f64,
87 pub methods_agree: bool,
89 pub indicator_result: RegimeConfidence,
91 pub hmm_result: RegimeConfidence,
93 pub indicator_regime: MarketRegime,
95 pub hmm_regime: MarketRegime,
96}
97
98impl EnsembleResult {
99 pub fn to_regime_confidence(&self) -> RegimeConfidence {
101 RegimeConfidence::new(self.regime, self.confidence)
102 }
103}
104
105impl std::fmt::Display for EnsembleResult {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 write!(
108 f,
109 "Ensemble: {} (conf: {:.0}%, agree: {})",
110 self.regime,
111 self.confidence * 100.0,
112 if self.methods_agree { "✓" } else { "✗" }
113 )
114 }
115}
116
117#[derive(Debug)]
139pub struct EnsembleRegimeDetector {
140 config: EnsembleConfig,
141
142 indicator_detector: RegimeDetector,
144
145 hmm_detector: HMMRegimeDetector,
147
148 current_regime: MarketRegime,
150
151 agreement_history: VecDeque<bool>,
153}
154
155impl EnsembleRegimeDetector {
156 pub fn new(ensemble_config: EnsembleConfig, indicator_config: RegimeConfig) -> Self {
158 Self {
159 config: ensemble_config,
160 indicator_detector: RegimeDetector::new(indicator_config),
161 hmm_detector: HMMRegimeDetector::crypto_optimized(),
162 current_regime: MarketRegime::Uncertain,
163 agreement_history: VecDeque::with_capacity(100),
164 }
165 }
166
167 pub fn default_config() -> Self {
169 Self::new(EnsembleConfig::default(), RegimeConfig::crypto_optimized())
170 }
171
172 pub fn balanced() -> Self {
174 Self::new(EnsembleConfig::balanced(), RegimeConfig::crypto_optimized())
175 }
176
177 pub fn indicator_focused() -> Self {
179 Self::new(
180 EnsembleConfig::indicator_focused(),
181 RegimeConfig::crypto_optimized(),
182 )
183 }
184
185 pub fn hmm_focused() -> Self {
187 Self::new(
188 EnsembleConfig::hmm_focused(),
189 RegimeConfig::crypto_optimized(),
190 )
191 }
192
193 pub fn update(&mut self, high: f64, low: f64, close: f64) -> EnsembleResult {
198 let indicator_result = self.indicator_detector.update(high, low, close);
200 let hmm_result = self.hmm_detector.update_ohlc(high, low, close);
201
202 let indicator_regime = indicator_result.regime;
204 let hmm_regime = hmm_result.regime;
205
206 let hmm_ready = self.hmm_detector.is_ready();
208
209 let methods_agree = Self::regimes_agree(indicator_regime, hmm_regime);
211
212 self.agreement_history.push_back(methods_agree);
214 if self.agreement_history.len() > 100 {
215 self.agreement_history.pop_front();
216 }
217
218 let (regime, confidence) = if self.config.require_hmm_warmup && !hmm_ready {
220 (indicator_regime, indicator_result.confidence)
222 } else {
223 self.combine_results(
224 indicator_regime,
225 indicator_result.confidence,
226 hmm_regime,
227 hmm_result.confidence,
228 methods_agree,
229 )
230 };
231
232 self.current_regime = regime;
233
234 EnsembleResult {
235 regime,
236 confidence,
237 methods_agree,
238 indicator_result,
239 hmm_result,
240 indicator_regime,
241 hmm_regime,
242 }
243 }
244
245 fn regimes_agree(r1: MarketRegime, r2: MarketRegime) -> bool {
247 matches!(
248 (r1, r2),
249 (MarketRegime::Trending(_), MarketRegime::Trending(_))
250 | (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
251 | (MarketRegime::Volatile, MarketRegime::Volatile)
252 | (MarketRegime::Uncertain, MarketRegime::Uncertain)
253 )
254 }
255
256 fn regimes_agree_direction(r1: MarketRegime, r2: MarketRegime) -> bool {
258 match (r1, r2) {
259 (MarketRegime::Trending(d1), MarketRegime::Trending(d2)) => d1 == d2,
260 (MarketRegime::MeanReverting, MarketRegime::MeanReverting)
261 | (MarketRegime::Volatile, MarketRegime::Volatile)
262 | (MarketRegime::Uncertain, MarketRegime::Uncertain) => true,
263 _ => false,
264 }
265 }
266
267 fn combine_results(
269 &self,
270 indicator_regime: MarketRegime,
271 indicator_conf: f64,
272 hmm_regime: MarketRegime,
273 hmm_conf: f64,
274 agree: bool,
275 ) -> (MarketRegime, f64) {
276 let w_ind = self.config.indicator_weight;
277 let w_hmm = self.config.hmm_weight;
278
279 let mut combined_conf = w_ind * indicator_conf + w_hmm * hmm_conf;
281
282 if agree {
284 combined_conf += self.config.agreement_confidence_boost;
286
287 if Self::regimes_agree_direction(indicator_regime, hmm_regime) {
289 combined_conf += 0.05;
290 }
291 } else {
292 combined_conf -= self.config.disagreement_confidence_penalty;
294 }
295
296 combined_conf = combined_conf.clamp(0.0, 1.0);
297
298 let regime = if agree {
300 indicator_regime
302 } else if combined_conf < self.config.agreement_threshold {
303 MarketRegime::Uncertain
305 } else {
306 if w_ind >= w_hmm {
308 indicator_regime
309 } else {
310 hmm_regime
311 }
312 };
313
314 (regime, combined_conf)
315 }
316
317 pub fn current_regime(&self) -> MarketRegime {
323 self.current_regime
324 }
325
326 pub fn agreement_rate(&self) -> f64 {
328 if self.agreement_history.is_empty() {
329 return 0.0;
330 }
331 let agrees = self.agreement_history.iter().filter(|&&a| a).count();
332 agrees as f64 / self.agreement_history.len() as f64
333 }
334
335 pub fn is_ready(&self) -> bool {
340 self.indicator_detector.is_ready()
341 && (!self.config.require_hmm_warmup || self.hmm_detector.is_ready())
342 }
343
344 pub fn indicator_ready(&self) -> bool {
346 self.indicator_detector.is_ready()
347 }
348
349 pub fn hmm_ready(&self) -> bool {
351 self.hmm_detector.is_ready()
352 }
353
354 pub fn hmm_state_probabilities(&self) -> &[f64] {
356 self.hmm_detector.state_probabilities()
357 }
358
359 pub fn expected_regime_duration(&self) -> f64 {
361 self.hmm_detector
362 .expected_regime_duration(self.hmm_detector.current_state_index())
363 }
364
365 pub fn status(&self) -> EnsembleStatus {
367 EnsembleStatus {
368 current_regime: self.current_regime,
369 indicator_ready: self.indicator_detector.is_ready(),
370 hmm_ready: self.hmm_detector.is_ready(),
371 agreement_rate: self.agreement_rate(),
372 hmm_state_probs: self.hmm_detector.state_probabilities().to_vec(),
373 expected_duration: self.expected_regime_duration(),
374 }
375 }
376
377 pub fn indicator_detector(&self) -> &RegimeDetector {
379 &self.indicator_detector
380 }
381
382 pub fn hmm_detector(&self) -> &HMMRegimeDetector {
384 &self.hmm_detector
385 }
386
387 pub fn config(&self) -> &EnsembleConfig {
389 &self.config
390 }
391}
392
393#[derive(Debug, Clone)]
395pub struct EnsembleStatus {
396 pub current_regime: MarketRegime,
397 pub indicator_ready: bool,
398 pub hmm_ready: bool,
399 pub agreement_rate: f64,
400 pub hmm_state_probs: Vec<f64>,
401 pub expected_duration: f64,
402}
403
404impl std::fmt::Display for EnsembleStatus {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 write!(
407 f,
408 "Regime: {} | Agreement: {:.1}% | HMM Ready: {} | Expected Duration: {:.1} bars",
409 self.current_regime,
410 self.agreement_rate * 100.0,
411 self.hmm_ready,
412 self.expected_duration
413 )
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrendDirection;

    /// Shorthand for a bullish trending regime.
    fn bullish() -> MarketRegime {
        MarketRegime::Trending(TrendDirection::Bullish)
    }

    /// Shorthand for a bearish trending regime.
    fn bearish() -> MarketRegime {
        MarketRegime::Trending(TrendDirection::Bearish)
    }

    /// Feeds `bars` synthetic bars into `ensemble`, starting from `start`.
    ///
    /// Before bar `i` the price is multiplied by `step(i)`. The bar's high and
    /// low are `price * high_mult` and `price * low_mult`. Returns the final
    /// close price.
    fn drive(
        ensemble: &mut EnsembleRegimeDetector,
        start: f64,
        bars: usize,
        high_mult: f64,
        low_mult: f64,
        step: impl Fn(usize) -> f64,
    ) -> f64 {
        let mut price = start;
        for i in 0..bars {
            price *= step(i);
            ensemble.update(price * high_mult, price * low_mult, price);
        }
        price
    }

    #[test]
    fn test_ensemble_creation() {
        let fresh = EnsembleRegimeDetector::default_config();
        assert_eq!(fresh.current_regime(), MarketRegime::Uncertain);
        assert!(!fresh.is_ready());
    }

    #[test]
    fn test_balanced_creation() {
        let ensemble = EnsembleRegimeDetector::balanced();
        assert!(!ensemble.is_ready());

        let cfg = ensemble.config();
        assert_eq!(cfg.indicator_weight, 0.5);
        assert_eq!(cfg.hmm_weight, 0.5);
    }

    #[test]
    fn test_indicator_focused_creation() {
        let ensemble = EnsembleRegimeDetector::indicator_focused();
        let cfg = ensemble.config();
        assert!(cfg.indicator_weight > cfg.hmm_weight);
    }

    #[test]
    fn test_hmm_focused_creation() {
        let ensemble = EnsembleRegimeDetector::hmm_focused();
        let cfg = ensemble.config();
        assert!(cfg.hmm_weight > cfg.indicator_weight);
    }

    #[test]
    fn test_regimes_agree_same_category() {
        // Category-level agreement ignores trend direction.
        let same_category = [
            (bullish(), bearish()),
            (MarketRegime::MeanReverting, MarketRegime::MeanReverting),
            (MarketRegime::Volatile, MarketRegime::Volatile),
            (MarketRegime::Uncertain, MarketRegime::Uncertain),
        ];
        for (a, b) in same_category {
            assert!(EnsembleRegimeDetector::regimes_agree(a, b), "{a:?} vs {b:?}");
        }
    }

    #[test]
    fn test_regimes_disagree_different_category() {
        let different_category = [
            (bullish(), MarketRegime::MeanReverting),
            (MarketRegime::Volatile, bearish()),
            (MarketRegime::Uncertain, MarketRegime::MeanReverting),
        ];
        for (a, b) in different_category {
            assert!(!EnsembleRegimeDetector::regimes_agree(a, b), "{a:?} vs {b:?}");
        }
    }

    #[test]
    fn test_regimes_agree_direction() {
        let cases = [
            (bullish(), bullish(), true),
            (bullish(), bearish(), false),
            (MarketRegime::MeanReverting, MarketRegime::MeanReverting, true),
            (bullish(), MarketRegime::MeanReverting, false),
        ];
        for (a, b, expected) in cases {
            assert_eq!(
                EnsembleRegimeDetector::regimes_agree_direction(a, b),
                expected,
                "{a:?} vs {b:?}"
            );
        }
    }

    #[test]
    fn test_agreement_rate_empty() {
        let ensemble = EnsembleRegimeDetector::default_config();
        assert_eq!(ensemble.agreement_rate(), 0.0);
    }

    #[test]
    fn test_agreement_rate_tracked() {
        let mut ensemble = EnsembleRegimeDetector::default_config();

        // Zig-zag prices: up on even bars, down on odd bars.
        drive(&mut ensemble, 100.0, 50, 1.01, 0.99, |i| {
            if i % 2 == 0 {
                1.01
            } else {
                0.99
            }
        });

        let rate = ensemble.agreement_rate();
        assert!(
            (0.0..=1.0).contains(&rate),
            "Agreement rate should be in [0, 1]: {rate}"
        );
    }

    #[test]
    fn test_bull_market_agreement() {
        let mut ensemble = EnsembleRegimeDetector::default_config();

        // Steady 0.5% rise per bar with a tight high/low band.
        let price = drive(&mut ensemble, 100.0, 300, 1.002, 0.998, |_| 1.005);
        let result = ensemble.update(price * 1.002, price * 0.998, price);

        let rate = ensemble.agreement_rate();
        assert!(
            rate > 0.2,
            "Agreement rate should be > 0.2 in consistent bull market: {}",
            rate
        );

        assert!(
            (0.0..=1.0).contains(&result.confidence),
            "Confidence should be in [0, 1]: {}",
            result.confidence
        );
    }

    #[test]
    fn test_ensemble_result_display() {
        let result = EnsembleResult {
            regime: bullish(),
            confidence: 0.85,
            methods_agree: true,
            indicator_result: RegimeConfidence::new(bullish(), 0.8),
            hmm_result: RegimeConfidence::new(bullish(), 0.9),
            indicator_regime: bullish(),
            hmm_regime: bullish(),
        };

        let display = result.to_string();
        for needle in ["Trending (Bullish)", "85%", "✓"] {
            assert!(display.contains(needle), "{needle:?} missing from {display:?}");
        }
    }

    #[test]
    fn test_ensemble_result_disagreement_display() {
        let result = EnsembleResult {
            regime: MarketRegime::Uncertain,
            confidence: 0.3,
            methods_agree: false,
            indicator_result: RegimeConfidence::new(bullish(), 0.6),
            hmm_result: RegimeConfidence::new(MarketRegime::MeanReverting, 0.5),
            indicator_regime: bullish(),
            hmm_regime: MarketRegime::MeanReverting,
        };

        assert!(result.to_string().contains("✗"));
    }

    #[test]
    fn test_ensemble_to_regime_confidence() {
        let mean_rev = MarketRegime::MeanReverting;
        let result = EnsembleResult {
            regime: mean_rev,
            confidence: 0.72,
            methods_agree: true,
            indicator_result: RegimeConfidence::new(mean_rev, 0.7),
            hmm_result: RegimeConfidence::new(mean_rev, 0.75),
            indicator_regime: mean_rev,
            hmm_regime: mean_rev,
        };

        let rc = result.to_regime_confidence();
        assert_eq!(rc.regime, MarketRegime::MeanReverting);
        assert!((rc.confidence - 0.72).abs() < f64::EPSILON);
    }

    #[test]
    fn test_status_display() {
        let status = EnsembleStatus {
            current_regime: MarketRegime::Volatile,
            indicator_ready: true,
            hmm_ready: false,
            agreement_rate: 0.65,
            hmm_state_probs: vec![0.3, 0.3, 0.4],
            expected_duration: 8.5,
        };

        let display = status.to_string();
        for needle in ["Volatile", "65.0%", "false"] {
            assert!(display.contains(needle), "{needle:?} missing from {display:?}");
        }
    }

    #[test]
    fn test_ready_state() {
        let mut ensemble = EnsembleRegimeDetector::default_config();

        assert!(!ensemble.is_ready());
        assert!(!ensemble.indicator_ready());
        assert!(!ensemble.hmm_ready());

        drive(&mut ensemble, 100.0, 300, 1.01, 0.99, |_| 1.001);

        assert!(ensemble.indicator_ready());
    }

    #[test]
    fn test_hmm_state_probabilities_accessible() {
        let mut ensemble = EnsembleRegimeDetector::default_config();
        drive(&mut ensemble, 100.0, 100, 1.01, 0.99, |_| 1.001);

        let probs = ensemble.hmm_state_probabilities();
        assert_eq!(probs.len(), 3, "Should have 3 HMM states");

        let sum: f64 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "HMM state probs should sum to 1.0: {sum}"
        );
    }

    #[test]
    fn test_expected_regime_duration() {
        let duration = EnsembleRegimeDetector::default_config().expected_regime_duration();
        assert!(duration > 0.0, "Duration should be > 0: {duration}");
    }

    #[test]
    fn test_detector_accessors() {
        let ensemble = EnsembleRegimeDetector::default_config();
        assert!(!ensemble.indicator_detector().is_ready());
        assert!(!ensemble.hmm_detector().is_ready());
    }

    #[test]
    fn test_combine_results_agreement_boosts_confidence() {
        let ensemble = EnsembleRegimeDetector::default_config();

        let (_, conf_agree) = ensemble.combine_results(bullish(), 0.7, bullish(), 0.7, true);
        let (_, conf_disagree) =
            ensemble.combine_results(bullish(), 0.7, MarketRegime::MeanReverting, 0.7, false);

        assert!(
            conf_agree > conf_disagree,
            "Agreement should boost confidence: agree={conf_agree} vs disagree={conf_disagree}"
        );
    }

    #[test]
    fn test_combine_results_disagreement_returns_uncertain_at_low_conf() {
        let strict = EnsembleConfig {
            agreement_threshold: 0.8,
            disagreement_confidence_penalty: 0.5,
            ..Default::default()
        };
        let ensemble = EnsembleRegimeDetector::new(strict, RegimeConfig::default());

        let (regime, _) =
            ensemble.combine_results(bullish(), 0.4, MarketRegime::MeanReverting, 0.4, false);

        assert_eq!(
            regime,
            MarketRegime::Uncertain,
            "Low confidence + disagreement should produce Uncertain"
        );
    }
}