1use std::collections::{HashMap, VecDeque};
12
13use serde::{Deserialize, Serialize};
14
15use super::types::{MarketRegime, RegimeConfidence, TrendDirection};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
34pub struct HmmIndicator {
35 pub config: HMMConfig,
36}
37
38impl HmmIndicator {
39 pub fn new(config: HMMConfig) -> Self {
40 Self { config }
41 }
42
43 pub fn with_defaults() -> Self {
44 Self::new(HMMConfig::default())
45 }
46}
47
48fn hmm_regime_id(r: MarketRegime) -> f64 {
49 match r {
50 MarketRegime::MeanReverting => 1.0,
51 MarketRegime::Volatile => 2.0,
52 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
53 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
54 MarketRegime::Uncertain => 0.0,
55 }
56}
57
58impl Indicator for HmmIndicator {
59 fn name(&self) -> &'static str {
60 "HMMRegime"
61 }
62
63 fn required_len(&self) -> usize {
68 self.config.min_observations + 1
69 }
70
71 fn required_columns(&self) -> &[&'static str] {
72 &["close"]
73 }
74
75 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
76 self.check_len(candles)?;
77 let mut det = HMMRegimeDetector::new(self.config.clone());
78 let n = candles.len();
79 let mut conf = vec![f64::NAN; n];
80 let mut regime = vec![f64::NAN; n];
81 for (i, c) in candles.iter().enumerate() {
82 let rc = det.update(c.close);
83 conf[i] = rc.confidence;
84 regime[i] = hmm_regime_id(rc.regime);
85 }
86 Ok(IndicatorOutput::from_pairs([
87 ("hmm_conf", conf),
88 ("hmm_regime_id", regime),
89 ]))
90 }
91}
92
93pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
96 let min_observations = param_usize(params, "min_observations", 100)?;
97 let n_states = param_usize(params, "n_states", 3)?;
98 let config = HMMConfig {
99 n_states,
100 min_observations,
101 ..HMMConfig::default()
102 };
103 Ok(Box::new(HmmIndicator::new(config)))
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct HMMConfig {
109 pub n_states: usize,
111 pub min_observations: usize,
113 pub learning_rate: f64,
115 pub transition_smoothing: f64,
117 pub lookback_window: usize,
119 pub min_confidence: f64,
121}
122
123impl Default for HMMConfig {
124 fn default() -> Self {
125 Self {
126 n_states: 3, min_observations: 100,
128 learning_rate: 0.01,
129 transition_smoothing: 0.1,
130 lookback_window: 252, min_confidence: 0.6,
132 }
133 }
134}
135
136impl HMMConfig {
137 pub fn crypto_optimized() -> Self {
139 Self {
140 n_states: 3,
141 min_observations: 50,
142 learning_rate: 0.02, transition_smoothing: 0.05,
144 lookback_window: 100,
145 min_confidence: 0.5,
146 }
147 }
148
149 pub fn conservative() -> Self {
151 Self {
152 n_states: 2, min_observations: 150,
154 learning_rate: 0.005,
155 transition_smoothing: 0.15,
156 lookback_window: 500,
157 min_confidence: 0.7,
158 }
159 }
160}
161
/// One-dimensional Gaussian emission distribution attached to a hidden state.
#[derive(Clone, Debug)]
struct GaussianState {
    /// Mean of the modelled log-return distribution.
    mean: f64,
    /// Variance of the modelled distribution. `GaussianState::update` clamps
    /// it to at least `1e-8`.
    variance: f64,
    /// Weighted running sum of observed values. It is accumulated but not
    /// read anywhere in the visible code.
    sum: f64,
    /// Weighted running sum of squared observed values. It is accumulated but
    /// not read anywhere in the visible code.
    sum_sq: f64,
    /// Number of `update` calls applied.
    count: usize,
}
172
173impl GaussianState {
174 fn new(mean: f64, variance: f64) -> Self {
175 Self {
176 mean,
177 variance,
178 sum: 0.0,
179 sum_sq: 0.0,
180 count: 0,
181 }
182 }
183
184 fn pdf(&self, x: f64) -> f64 {
186 let diff = x - self.mean;
187 let exponent = -0.5 * diff * diff / self.variance;
188 let normalizer = (2.0 * std::f64::consts::PI * self.variance).sqrt();
189 exponent.exp() / normalizer
190 }
191
192 fn update(&mut self, x: f64, weight: f64, learning_rate: f64) {
194 if learning_rate > 0.0 {
195 self.mean = (1.0 - learning_rate * weight) * self.mean + learning_rate * weight * x;
197 let new_var = (x - self.mean).powi(2);
198 self.variance =
199 (1.0 - learning_rate * weight) * self.variance + learning_rate * weight * new_var;
200 self.variance = self.variance.max(1e-8); }
202
203 self.sum += x * weight;
205 self.sum_sq += x * x * weight;
206 self.count += 1;
207 }
208}
209
210#[derive(Debug)]
237pub struct HMMRegimeDetector {
238 config: HMMConfig,
239
240 states: Vec<GaussianState>,
242
243 transition_matrix: Vec<Vec<f64>>,
245
246 initial_probs: Vec<f64>,
248
249 state_probs: Vec<f64>,
251
252 returns_history: VecDeque<f64>,
254
255 prices: VecDeque<f64>,
257
258 current_state: usize,
260
261 current_confidence: f64,
263
264 n_observations: usize,
266
267 last_regime: MarketRegime,
269}
270
271impl HMMRegimeDetector {
272 pub fn new(config: HMMConfig) -> Self {
274 let n = config.n_states;
275
276 let states = match n {
281 2 => vec![
282 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0004), ],
285 3 => vec![
286 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0002), GaussianState::new(0.0, 0.0009), ],
290 _ => (0..n)
291 .map(|i| {
292 let mean = (i as f64 - n as f64 / 2.0) * 0.001;
293 let var = 0.0001 * (1.0 + i as f64);
294 GaussianState::new(mean, var)
295 })
296 .collect(),
297 };
298
299 let mut transition_matrix = vec![vec![0.0; n]; n];
302 for (i, row) in transition_matrix.iter_mut().enumerate().take(n) {
303 for (j, cell) in row.iter_mut().enumerate().take(n) {
304 if i == j {
305 *cell = 0.9; } else {
307 *cell = 0.1 / (n - 1) as f64;
308 }
309 }
310 }
311
312 let initial_probs = vec![1.0 / n as f64; n];
314 let state_probs = initial_probs.clone();
315
316 Self {
317 config: config.clone(),
318 states,
319 transition_matrix,
320 initial_probs,
321 state_probs,
322 returns_history: VecDeque::with_capacity(config.lookback_window),
323 prices: VecDeque::with_capacity(10),
324 current_state: 0,
325 current_confidence: 0.0,
326 n_observations: 0,
327 last_regime: MarketRegime::Uncertain,
328 }
329 }
330
331 pub fn default_config() -> Self {
333 Self::new(HMMConfig::default())
334 }
335
336 pub fn crypto_optimized() -> Self {
338 Self::new(HMMConfig::crypto_optimized())
339 }
340
341 pub fn conservative() -> Self {
343 Self::new(HMMConfig::conservative())
344 }
345
346 pub fn update(&mut self, close: f64) -> RegimeConfidence {
351 if let Some(&prev_close) = self.prices.back()
353 && prev_close > 0.0
354 {
355 let log_return = (close / prev_close).ln();
356 self.process_return(log_return);
357 }
358
359 self.prices.push_back(close);
361 if self.prices.len() > 10 {
362 self.prices.pop_front();
363 }
364
365 let confidence = self.get_regime_confidence();
367 self.last_regime = confidence.regime;
368 confidence
369 }
370
371 pub fn update_ohlc(&mut self, _high: f64, _low: f64, close: f64) -> RegimeConfidence {
373 self.update(close)
374 }
375
376 fn process_return(&mut self, ret: f64) {
378 self.n_observations += 1;
379
380 self.returns_history.push_back(ret);
382 if self.returns_history.len() > self.config.lookback_window {
383 self.returns_history.pop_front();
384 }
385
386 self.forward_step(ret);
388
389 if self.n_observations > self.config.min_observations && self.config.learning_rate > 0.0 {
391 self.online_parameter_update(ret);
392 }
393
394 let reestimate_interval = self.config.lookback_window / 2;
396 if self.n_observations > 0
397 && reestimate_interval > 0
398 && self.n_observations.is_multiple_of(reestimate_interval)
399 && self.returns_history.len() >= self.config.min_observations
400 {
401 self.baum_welch_update();
402 }
403 }
404
405 fn forward_step(&mut self, ret: f64) {
407 let n = self.config.n_states;
408 let mut new_probs = vec![0.0; n];
409
410 let emissions: Vec<f64> = self.states.iter().map(|s| s.pdf(ret)).collect();
412
413 for j in 0..n {
415 let mut sum = 0.0;
416 for i in 0..n {
417 sum += self.transition_matrix[i][j] * self.state_probs[i];
418 }
419 new_probs[j] = emissions[j] * sum;
420 }
421
422 let total: f64 = new_probs.iter().sum();
424 if total > 1e-300 {
425 for p in &mut new_probs {
426 *p /= total;
427 }
428 } else {
429 new_probs = vec![1.0 / n as f64; n];
431 }
432
433 self.state_probs = new_probs;
434
435 let (max_idx, max_prob) = self
437 .state_probs
438 .iter()
439 .enumerate()
440 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
441 .unwrap();
442
443 self.current_state = max_idx;
444 self.current_confidence = *max_prob;
445 }
446
447 fn online_parameter_update(&mut self, ret: f64) {
449 let lr = self.config.learning_rate;
450
451 for (i, state) in self.states.iter_mut().enumerate() {
452 let weight = self.state_probs[i];
453 state.update(ret, weight, lr);
454 }
455
456 let smoothing = self.config.transition_smoothing;
459 for i in 0..self.config.n_states {
460 for j in 0..self.config.n_states {
461 let target = if i == j {
462 0.9
463 } else {
464 0.1 / (self.config.n_states - 1) as f64
465 };
466 self.transition_matrix[i][j] =
467 (1.0 - smoothing) * self.transition_matrix[i][j] + smoothing * target;
468 }
469 }
470 }
471
472 fn baum_welch_update(&mut self) {
478 let returns: Vec<f64> = self.returns_history.iter().copied().collect();
479 if returns.len() < self.config.min_observations {
480 return;
481 }
482
483 let n = self.config.n_states;
484 let t = returns.len();
485
486 let mut alpha = vec![vec![0.0; n]; t];
488
489 for (j, alpha_val) in alpha[0].iter_mut().enumerate().take(n) {
491 *alpha_val = self.initial_probs[j] * self.states[j].pdf(returns[0]);
492 }
493 Self::normalize_vec(&mut alpha[0]);
494
495 for time in 1..t {
497 for j in 0..n {
498 let mut sum = 0.0;
499 for (i, alpha_prev) in alpha[time - 1].iter().enumerate().take(n) {
500 sum += alpha_prev * self.transition_matrix[i][j];
501 }
502 alpha[time][j] = sum * self.states[j].pdf(returns[time]);
503 }
504 Self::normalize_vec(&mut alpha[time]);
505 }
506
507 let mut beta = vec![vec![1.0; n]; t];
509
510 for time in (0..t - 1).rev() {
511 for i in 0..n {
512 let mut sum = 0.0;
513 for (j, beta_next) in beta[time + 1].iter().enumerate().take(n) {
514 sum += self.transition_matrix[i][j]
515 * self.states[j].pdf(returns[time + 1])
516 * beta_next;
517 }
518 beta[time][i] = sum;
519 }
520 Self::normalize_vec(&mut beta[time]);
521 }
522
523 let mut gamma = vec![vec![0.0; n]; t];
525 for time in 0..t {
526 let mut sum = 0.0;
527 for (j, gamma_val) in gamma[time].iter_mut().enumerate().take(n) {
528 *gamma_val = alpha[time][j] * beta[time][j];
529 sum += *gamma_val;
530 }
531 if sum > 1e-300 {
532 for gamma_val in gamma[time].iter_mut().take(n) {
533 *gamma_val /= sum;
534 }
535 }
536 }
537
538 for (j, state) in self.states.iter_mut().enumerate().take(n) {
540 let mut weight_sum = 0.0;
541 let mut mean_sum = 0.0;
542 let mut var_sum = 0.0;
543
544 for time in 0..t {
545 let w = gamma[time][j];
546 weight_sum += w;
547 mean_sum += w * returns[time];
548 }
549
550 if weight_sum > 1e-8 {
551 let new_mean = mean_sum / weight_sum;
552
553 for time in 0..t {
554 let w = gamma[time][j];
555 var_sum += w * (returns[time] - new_mean).powi(2);
556 }
557
558 let new_var = (var_sum / weight_sum).max(1e-8);
559
560 let blend = 0.3;
562 state.mean = (1.0 - blend) * state.mean + blend * new_mean;
563 state.variance = (1.0 - blend) * state.variance + blend * new_var;
564 }
565 }
566 }
567
568 fn normalize_vec(vec: &mut [f64]) {
570 let sum: f64 = vec.iter().sum();
571 if sum > 1e-300 {
572 for v in vec.iter_mut() {
573 *v /= sum;
574 }
575 }
576 }
577
578 pub fn get_regime_confidence(&self) -> RegimeConfidence {
580 if self.n_observations < self.config.min_observations {
581 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
582 }
583
584 let regime = self.state_to_regime(self.current_state);
585 let confidence = self.current_confidence;
586
587 RegimeConfidence::with_metrics(
588 regime,
589 confidence,
590 self.states[self.current_state].mean * 100.0 * 252.0, self.states[self.current_state].variance.sqrt() * 100.0 * 252.0_f64.sqrt(), 0.0, )
594 }
595
596 fn state_to_regime(&self, state: usize) -> MarketRegime {
604 let state_params = &self.states[state];
605 let mean = state_params.mean;
606 let vol = state_params.variance.sqrt();
607
608 let is_high_vol = vol > 0.02; let is_positive = mean > 0.0005; let is_negative = mean < -0.0005;
612
613 if is_high_vol {
614 MarketRegime::Volatile
615 } else if is_positive {
616 MarketRegime::Trending(TrendDirection::Bullish)
617 } else if is_negative {
618 MarketRegime::Trending(TrendDirection::Bearish)
619 } else {
620 MarketRegime::MeanReverting }
622 }
623
624 pub fn state_probabilities(&self) -> &[f64] {
630 &self.state_probs
631 }
632
633 pub fn state_parameters(&self) -> Vec<(f64, f64)> {
635 self.states.iter().map(|s| (s.mean, s.variance)).collect()
636 }
637
638 pub fn transition_matrix(&self) -> &[Vec<f64>] {
640 &self.transition_matrix
641 }
642
643 pub fn current_state_index(&self) -> usize {
645 self.current_state
646 }
647
648 pub fn is_ready(&self) -> bool {
650 self.n_observations >= self.config.min_observations
651 }
652
653 pub fn expected_regime_duration(&self, state: usize) -> f64 {
657 if state < self.config.n_states {
658 1.0 / (1.0 - self.transition_matrix[state][state])
659 } else {
660 0.0
661 }
662 }
663
664 pub fn predict_next_state(&self) -> (usize, f64) {
666 let mut next_probs = vec![0.0; self.config.n_states];
667
668 for (j, next_prob) in next_probs.iter_mut().enumerate().take(self.config.n_states) {
669 for i in 0..self.config.n_states {
670 *next_prob += self.transition_matrix[i][j] * self.state_probs[i];
671 }
672 }
673
674 let (max_idx, max_prob) = next_probs
675 .iter()
676 .enumerate()
677 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
678 .unwrap();
679
680 (max_idx, *max_prob)
681 }
682
683 pub fn n_observations(&self) -> usize {
685 self.n_observations
686 }
687
688 pub fn current_confidence(&self) -> f64 {
690 self.current_confidence
691 }
692
693 pub fn config(&self) -> &HMMConfig {
695 &self.config
696 }
697}
698
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hmm_initialization() {
        let detector = HMMRegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.state_probabilities().len(), 3);
    }

    #[test]
    fn test_hmm_crypto_config() {
        let detector = HMMRegimeDetector::crypto_optimized();
        assert_eq!(detector.config().n_states, 3);
        assert_eq!(detector.config().min_observations, 50);
    }

    #[test]
    fn test_hmm_conservative_config() {
        let detector = HMMRegimeDetector::conservative();
        assert_eq!(detector.config().n_states, 2);
        assert_eq!(detector.config().min_observations, 150);
        assert_eq!(detector.state_probabilities().len(), 2);
    }

    #[test]
    fn test_hmm_warmup() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // 49 closes produce only 48 returns, below the warm-up of 50.
        for i in 0..49 {
            let price = 100.0 + (i as f64) * 0.01;
            let result = detector.update(price);
            assert_eq!(
                result.regime,
                MarketRegime::Uncertain,
                "Should be Uncertain during warmup at step {i}"
            );
        }

        assert!(!detector.is_ready());
    }

    #[test]
    fn test_hmm_becomes_ready() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        for i in 0..60 {
            let price = 100.0 + (i as f64) * 0.01;
            detector.update(price);
        }

        assert!(detector.is_ready(), "Should be ready after 60 observations");
    }

    #[test]
    fn test_bull_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // Steady +0.5% per step.
        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.005;
            let result = detector.update(price);
            if detector.is_ready() {
                assert_ne!(result.regime, MarketRegime::Uncertain);
            }
        }

        let final_result = detector.get_regime_confidence();
        assert!(
            matches!(
                final_result.regime,
                MarketRegime::Trending(TrendDirection::Bullish)
            ),
            "Expected Bullish trend, got: {:?}",
            final_result.regime
        );
    }

    #[test]
    fn test_volatile_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // Alternating +5% / -5% moves.
        let mut price = 100.0;
        for i in 0..200 {
            if i % 2 == 0 {
                price *= 1.05;
            } else {
                price *= 0.95;
            }
            detector.update(price);
        }

        let result = detector.get_regime_confidence();
        assert!(
            matches!(
                result.regime,
                MarketRegime::Volatile | MarketRegime::MeanReverting
            ),
            "Expected Volatile or MeanReverting for choppy data, got: {:?}",
            result.regime
        );
    }

    #[test]
    fn test_state_probabilities_sum_to_one() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.001;
            detector.update(price);

            let probs = detector.state_probabilities();
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "State probabilities should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_transition_matrix_rows_sum_to_one() {
        // Cover the 3-state default as well as the 2-state preset.
        for detector in [
            HMMRegimeDetector::default_config(),
            HMMRegimeDetector::conservative(),
        ] {
            let tm = detector.transition_matrix();

            for (i, row) in tm.iter().enumerate() {
                let sum: f64 = row.iter().sum();
                assert!(
                    (sum - 1.0).abs() < 1e-6,
                    "Transition matrix row {i} should sum to 1.0, got: {sum}"
                );
            }
        }
    }

    #[test]
    fn test_expected_regime_duration() {
        let detector = HMMRegimeDetector::default_config();

        // 1 / (1 - 0.9) = 10.
        let duration = detector.expected_regime_duration(0);
        assert!(
            (duration - 10.0).abs() < 1e-6,
            "Expected duration should be ~10 with 0.9 persistence, got: {duration}"
        );
    }

    #[test]
    fn test_predict_next_state() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.002;
            detector.update(price);
        }

        let (next_state, prob) = detector.predict_next_state();
        assert!(next_state < detector.config().n_states);
        assert!(
            (0.0..=1.0).contains(&prob),
            "Predicted probability should be in [0, 1]: {prob}"
        );
    }

    #[test]
    fn test_state_parameters() {
        let detector = HMMRegimeDetector::default_config();
        let params = detector.state_parameters();

        assert_eq!(params.len(), 3, "Should have 3 state parameters");

        for (mean, variance) in &params {
            assert!(*variance > 0.0, "Variance should be positive: {variance}");
            assert!(mean.is_finite(), "Mean should be finite: {mean}");
        }
    }

    #[test]
    fn test_update_ohlc_uses_close() {
        let mut det1 = HMMRegimeDetector::crypto_optimized();
        let mut det2 = HMMRegimeDetector::crypto_optimized();

        // High/low differ from close; only close should matter.
        for i in 0..100 {
            let close = 100.0 + i as f64 * 0.1;
            let r1 = det1.update(close);
            let r2 = det2.update_ohlc(close * 1.01, close * 0.99, close);

            assert_eq!(
                r1.regime, r2.regime,
                "update and update_ohlc should produce same regime"
            );
        }
    }

    #[test]
    fn test_n_observations_tracking() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        assert_eq!(detector.n_observations(), 0);

        for i in 0..50 {
            detector.update(100.0 + i as f64);
        }

        // 50 closes yield 49 returns.
        assert_eq!(detector.n_observations(), 49);
    }

    #[test]
    fn test_confidence_range() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.002;
            detector.update(price);
        }

        let confidence = detector.current_confidence();
        assert!(
            (0.0..=1.0).contains(&confidence),
            "Confidence should be in [0, 1]: {confidence}"
        );
    }
}
936}