1use std::collections::{HashMap, VecDeque};
12
13use serde::{Deserialize, Serialize};
14
15use super::types::{MarketRegime, RegimeConfidence, TrendDirection};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
34pub struct HmmIndicator {
35 pub config: HMMConfig,
36}
37
38impl HmmIndicator {
39 pub fn new(config: HMMConfig) -> Self {
40 Self { config }
41 }
42
43 pub fn with_defaults() -> Self {
44 Self::new(HMMConfig::default())
45 }
46}
47
48fn hmm_regime_id(r: MarketRegime) -> f64 {
49 match r {
50 MarketRegime::MeanReverting => 1.0,
51 MarketRegime::Volatile => 2.0,
52 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
53 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
54 MarketRegime::Uncertain => 0.0,
55 }
56}
57
58impl Indicator for HmmIndicator {
59 fn name(&self) -> &'static str {
60 "HMMRegime"
61 }
62
63 fn required_len(&self) -> usize {
68 self.config.min_observations + 1
69 }
70
71 fn required_columns(&self) -> &[&'static str] {
72 &["close"]
73 }
74
75 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
76 self.check_len(candles)?;
77 let mut det = HMMRegimeDetector::new(self.config.clone());
78 let n = candles.len();
79 let mut conf = vec![f64::NAN; n];
80 let mut regime = vec![f64::NAN; n];
81 for (i, c) in candles.iter().enumerate() {
82 let rc = det.update(c.close);
83 conf[i] = rc.confidence;
84 regime[i] = hmm_regime_id(rc.regime);
85 }
86 Ok(IndicatorOutput::from_pairs([
87 ("hmm_conf", conf),
88 ("hmm_regime_id", regime),
89 ]))
90 }
91}
92
93pub fn factory<S: ::std::hash::BuildHasher>(
96 params: &HashMap<String, String, S>,
97) -> Result<Box<dyn Indicator>, IndicatorError> {
98 let min_observations = param_usize(params, "min_observations", 100)?;
99 let n_states = param_usize(params, "n_states", 3)?;
100 let config = HMMConfig {
101 n_states,
102 min_observations,
103 ..HMMConfig::default()
104 };
105 Ok(Box::new(HmmIndicator::new(config)))
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct HMMConfig {
111 pub n_states: usize,
113 pub min_observations: usize,
115 pub learning_rate: f64,
117 pub transition_smoothing: f64,
119 pub lookback_window: usize,
121 pub min_confidence: f64,
123}
124
125impl Default for HMMConfig {
126 fn default() -> Self {
127 Self {
128 n_states: 3, min_observations: 100,
130 learning_rate: 0.01,
131 transition_smoothing: 0.1,
132 lookback_window: 252, min_confidence: 0.6,
134 }
135 }
136}
137
138impl HMMConfig {
139 pub fn crypto_optimized() -> Self {
141 Self {
142 n_states: 3,
143 min_observations: 50,
144 learning_rate: 0.02, transition_smoothing: 0.05,
146 lookback_window: 100,
147 min_confidence: 0.5,
148 }
149 }
150
151 pub fn conservative() -> Self {
153 Self {
154 n_states: 2, min_observations: 150,
156 learning_rate: 0.005,
157 transition_smoothing: 0.15,
158 lookback_window: 500,
159 min_confidence: 0.7,
160 }
161 }
162}
163
/// Univariate Gaussian emission model for one hidden state.
///
/// `mean`/`variance` drive [`GaussianState::pdf`]; `sum`, `sum_sq` and `count` are
/// weighted running totals maintained by [`GaussianState::update`].
#[derive(Debug, Clone)]
struct GaussianState {
    /// Mean of the emitted value.
    mean: f64,
    /// Variance of the emitted value (floored at 1e-8 whenever `update` adapts it).
    variance: f64,
    /// Running total of `weight * x` over all `update` calls.
    sum: f64,
    /// Running total of `x * x * weight` over all `update` calls.
    sum_sq: f64,
    /// Number of `update` calls.
    count: usize,
}

impl GaussianState {
    /// State with the given parameters and zeroed running totals.
    fn new(mean: f64, variance: f64) -> Self {
        GaussianState {
            mean,
            variance,
            sum: 0.0,
            sum_sq: 0.0,
            count: 0,
        }
    }

    /// Normal probability density of `x` under N(`mean`, `variance`).
    fn pdf(&self, x: f64) -> f64 {
        let dev = x - self.mean;
        let scale = (2.0 * std::f64::consts::PI * self.variance).sqrt();
        (-0.5 * dev * dev / self.variance).exp() / scale
    }

    /// Moves the parameters toward observation `x` with an exponential-smoothing step
    /// of size `learning_rate * weight`. The mean moves first; the variance is then
    /// blended toward the squared deviation from the *updated* mean and floored at 1e-8.
    /// A non-positive `learning_rate` leaves the parameters untouched. The running
    /// totals are updated unconditionally.
    fn update(&mut self, x: f64, weight: f64, learning_rate: f64) {
        if learning_rate > 0.0 {
            let step = learning_rate * weight;
            let retain = 1.0 - step;
            self.mean = retain * self.mean + step * x;
            let sq_dev = (x - self.mean).powi(2);
            self.variance = (retain * self.variance + step * sq_dev).max(1e-8);
        }

        self.sum += x * weight;
        self.sum_sq += x * x * weight;
        self.count += 1;
    }
}
211
212#[derive(Debug)]
239pub struct HMMRegimeDetector {
240 config: HMMConfig,
241
242 states: Vec<GaussianState>,
244
245 transition_matrix: Vec<Vec<f64>>,
247
248 initial_probs: Vec<f64>,
250
251 state_probs: Vec<f64>,
253
254 returns_history: VecDeque<f64>,
256
257 prices: VecDeque<f64>,
259
260 current_state: usize,
262
263 current_confidence: f64,
265
266 n_observations: usize,
268
269 last_regime: MarketRegime,
271}
272
273impl HMMRegimeDetector {
274 pub fn new(config: HMMConfig) -> Self {
276 let n = config.n_states;
277
278 let states = match n {
283 2 => vec![
284 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0004), ],
287 3 => vec![
288 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0002), GaussianState::new(0.0, 0.0009), ],
292 _ => (0..n)
293 .map(|i| {
294 let mean = (i as f64 - n as f64 / 2.0) * 0.001;
295 let var = 0.0001 * (1.0 + i as f64);
296 GaussianState::new(mean, var)
297 })
298 .collect(),
299 };
300
301 let mut transition_matrix = vec![vec![0.0; n]; n];
304 for (i, row) in transition_matrix.iter_mut().enumerate().take(n) {
305 for (j, cell) in row.iter_mut().enumerate().take(n) {
306 if i == j {
307 *cell = 0.9; } else {
309 *cell = 0.1 / (n - 1) as f64;
310 }
311 }
312 }
313
314 let initial_probs = vec![1.0 / n as f64; n];
316 let state_probs = initial_probs.clone();
317
318 Self {
319 config: config.clone(),
320 states,
321 transition_matrix,
322 initial_probs,
323 state_probs,
324 returns_history: VecDeque::with_capacity(config.lookback_window),
325 prices: VecDeque::with_capacity(10),
326 current_state: 0,
327 current_confidence: 0.0,
328 n_observations: 0,
329 last_regime: MarketRegime::Uncertain,
330 }
331 }
332
333 pub fn default_config() -> Self {
335 Self::new(HMMConfig::default())
336 }
337
338 pub fn crypto_optimized() -> Self {
340 Self::new(HMMConfig::crypto_optimized())
341 }
342
343 pub fn conservative() -> Self {
345 Self::new(HMMConfig::conservative())
346 }
347
348 pub fn update(&mut self, close: f64) -> RegimeConfidence {
353 if let Some(&prev_close) = self.prices.back()
355 && prev_close > 0.0
356 {
357 let log_return = (close / prev_close).ln();
358 self.process_return(log_return);
359 }
360
361 self.prices.push_back(close);
363 if self.prices.len() > 10 {
364 self.prices.pop_front();
365 }
366
367 let confidence = self.get_regime_confidence();
369 self.last_regime = confidence.regime;
370 confidence
371 }
372
373 pub fn update_ohlc(&mut self, _high: f64, _low: f64, close: f64) -> RegimeConfidence {
375 self.update(close)
376 }
377
378 fn process_return(&mut self, ret: f64) {
380 self.n_observations += 1;
381
382 self.returns_history.push_back(ret);
384 if self.returns_history.len() > self.config.lookback_window {
385 self.returns_history.pop_front();
386 }
387
388 self.forward_step(ret);
390
391 if self.n_observations > self.config.min_observations && self.config.learning_rate > 0.0 {
393 self.online_parameter_update(ret);
394 }
395
396 let reestimate_interval = self.config.lookback_window / 2;
398 if self.n_observations > 0
399 && reestimate_interval > 0
400 && self.n_observations.is_multiple_of(reestimate_interval)
401 && self.returns_history.len() >= self.config.min_observations
402 {
403 self.baum_welch_update();
404 }
405 }
406
407 fn forward_step(&mut self, ret: f64) {
409 let n = self.config.n_states;
410 let mut new_probs = vec![0.0; n];
411
412 let emissions: Vec<f64> = self.states.iter().map(|s| s.pdf(ret)).collect();
414
415 for j in 0..n {
417 let mut sum = 0.0;
418 for i in 0..n {
419 sum += self.transition_matrix[i][j] * self.state_probs[i];
420 }
421 new_probs[j] = emissions[j] * sum;
422 }
423
424 let total: f64 = new_probs.iter().sum();
426 if total > 1e-300 {
427 for p in &mut new_probs {
428 *p /= total;
429 }
430 } else {
431 new_probs = vec![1.0 / n as f64; n];
433 }
434
435 self.state_probs = new_probs;
436
437 let (max_idx, max_prob) = self
439 .state_probs
440 .iter()
441 .enumerate()
442 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
443 .unwrap();
444
445 self.current_state = max_idx;
446 self.current_confidence = *max_prob;
447 }
448
449 fn online_parameter_update(&mut self, ret: f64) {
451 let lr = self.config.learning_rate;
452
453 for (i, state) in self.states.iter_mut().enumerate() {
454 let weight = self.state_probs[i];
455 state.update(ret, weight, lr);
456 }
457
458 let smoothing = self.config.transition_smoothing;
461 for i in 0..self.config.n_states {
462 for j in 0..self.config.n_states {
463 let target = if i == j {
464 0.9
465 } else {
466 0.1 / (self.config.n_states - 1) as f64
467 };
468 self.transition_matrix[i][j] =
469 (1.0 - smoothing) * self.transition_matrix[i][j] + smoothing * target;
470 }
471 }
472 }
473
474 fn baum_welch_update(&mut self) {
480 let returns: Vec<f64> = self.returns_history.iter().copied().collect();
481 if returns.len() < self.config.min_observations {
482 return;
483 }
484
485 let n = self.config.n_states;
486 let t = returns.len();
487
488 let mut alpha = vec![vec![0.0; n]; t];
490
491 for (j, alpha_val) in alpha[0].iter_mut().enumerate().take(n) {
493 *alpha_val = self.initial_probs[j] * self.states[j].pdf(returns[0]);
494 }
495 Self::normalize_vec(&mut alpha[0]);
496
497 for time in 1..t {
499 for j in 0..n {
500 let mut sum = 0.0;
501 for (i, alpha_prev) in alpha[time - 1].iter().enumerate().take(n) {
502 sum += alpha_prev * self.transition_matrix[i][j];
503 }
504 alpha[time][j] = sum * self.states[j].pdf(returns[time]);
505 }
506 Self::normalize_vec(&mut alpha[time]);
507 }
508
509 let mut beta = vec![vec![1.0; n]; t];
511
512 for time in (0..t - 1).rev() {
513 for i in 0..n {
514 let mut sum = 0.0;
515 for (j, beta_next) in beta[time + 1].iter().enumerate().take(n) {
516 sum += self.transition_matrix[i][j]
517 * self.states[j].pdf(returns[time + 1])
518 * beta_next;
519 }
520 beta[time][i] = sum;
521 }
522 Self::normalize_vec(&mut beta[time]);
523 }
524
525 let mut gamma = vec![vec![0.0; n]; t];
527 for time in 0..t {
528 let mut sum = 0.0;
529 for (j, gamma_val) in gamma[time].iter_mut().enumerate().take(n) {
530 *gamma_val = alpha[time][j] * beta[time][j];
531 sum += *gamma_val;
532 }
533 if sum > 1e-300 {
534 for gamma_val in gamma[time].iter_mut().take(n) {
535 *gamma_val /= sum;
536 }
537 }
538 }
539
540 for (j, state) in self.states.iter_mut().enumerate().take(n) {
542 let mut weight_sum = 0.0;
543 let mut mean_sum = 0.0;
544 let mut var_sum = 0.0;
545
546 for time in 0..t {
547 let w = gamma[time][j];
548 weight_sum += w;
549 mean_sum += w * returns[time];
550 }
551
552 if weight_sum > 1e-8 {
553 let new_mean = mean_sum / weight_sum;
554
555 for time in 0..t {
556 let w = gamma[time][j];
557 var_sum += w * (returns[time] - new_mean).powi(2);
558 }
559
560 let new_var = (var_sum / weight_sum).max(1e-8);
561
562 let blend = 0.3;
564 state.mean = (1.0 - blend) * state.mean + blend * new_mean;
565 state.variance = (1.0 - blend) * state.variance + blend * new_var;
566 }
567 }
568 }
569
570 fn normalize_vec(vec: &mut [f64]) {
572 let sum: f64 = vec.iter().sum();
573 if sum > 1e-300 {
574 for v in vec.iter_mut() {
575 *v /= sum;
576 }
577 }
578 }
579
580 pub fn get_regime_confidence(&self) -> RegimeConfidence {
582 if self.n_observations < self.config.min_observations {
583 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
584 }
585
586 let regime = self.state_to_regime(self.current_state);
587 let confidence = self.current_confidence;
588
589 RegimeConfidence::with_metrics(
590 regime,
591 confidence,
592 self.states[self.current_state].mean * 100.0 * 252.0, self.states[self.current_state].variance.sqrt() * 100.0 * 252.0_f64.sqrt(), 0.0, )
596 }
597
598 fn state_to_regime(&self, state: usize) -> MarketRegime {
606 let state_params = &self.states[state];
607 let mean = state_params.mean;
608 let vol = state_params.variance.sqrt();
609
610 let is_high_vol = vol > 0.02; let is_positive = mean > 0.0005; let is_negative = mean < -0.0005;
614
615 if is_high_vol {
616 MarketRegime::Volatile
617 } else if is_positive {
618 MarketRegime::Trending(TrendDirection::Bullish)
619 } else if is_negative {
620 MarketRegime::Trending(TrendDirection::Bearish)
621 } else {
622 MarketRegime::MeanReverting }
624 }
625
626 pub fn state_probabilities(&self) -> &[f64] {
632 &self.state_probs
633 }
634
635 pub fn state_parameters(&self) -> Vec<(f64, f64)> {
637 self.states.iter().map(|s| (s.mean, s.variance)).collect()
638 }
639
640 pub fn transition_matrix(&self) -> &[Vec<f64>] {
642 &self.transition_matrix
643 }
644
645 pub fn current_state_index(&self) -> usize {
647 self.current_state
648 }
649
650 pub fn is_ready(&self) -> bool {
652 self.n_observations >= self.config.min_observations
653 }
654
655 pub fn expected_regime_duration(&self, state: usize) -> f64 {
659 if state < self.config.n_states {
660 1.0 / (1.0 - self.transition_matrix[state][state])
661 } else {
662 0.0
663 }
664 }
665
666 pub fn predict_next_state(&self) -> (usize, f64) {
668 let mut next_probs = vec![0.0; self.config.n_states];
669
670 for (j, next_prob) in next_probs.iter_mut().enumerate().take(self.config.n_states) {
671 for i in 0..self.config.n_states {
672 *next_prob += self.transition_matrix[i][j] * self.state_probs[i];
673 }
674 }
675
676 let (max_idx, max_prob) = next_probs
677 .iter()
678 .enumerate()
679 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
680 .unwrap();
681
682 (max_idx, *max_prob)
683 }
684
685 pub fn n_observations(&self) -> usize {
687 self.n_observations
688 }
689
690 pub fn current_confidence(&self) -> f64 {
692 self.current_confidence
693 }
694
695 pub fn config(&self) -> &HMMConfig {
697 &self.config
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hmm_initialization() {
        let detector = HMMRegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.state_probabilities().len(), 3);
    }

    #[test]
    fn test_hmm_crypto_config() {
        let detector = HMMRegimeDetector::crypto_optimized();
        assert_eq!(detector.config().n_states, 3);
        assert_eq!(detector.config().min_observations, 50);
    }

    #[test]
    fn test_hmm_conservative_config() {
        let detector = HMMRegimeDetector::conservative();
        assert_eq!(detector.config().n_states, 2);
        assert_eq!(detector.config().min_observations, 150);
        assert_eq!(detector.state_probabilities().len(), 2);
    }

    #[test]
    fn test_hmm_warmup() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // 49 closes yield only 48 returns, below the 50-return warm-up.
        for i in 0..49 {
            let price = 100.0 + (i as f64) * 0.01;
            let result = detector.update(price);
            assert_eq!(
                result.regime,
                MarketRegime::Uncertain,
                "Should be Uncertain during warmup at step {i}"
            );
        }

        assert!(!detector.is_ready());
    }

    #[test]
    fn test_hmm_becomes_ready() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        for i in 0..60 {
            let price = 100.0 + (i as f64) * 0.01;
            detector.update(price);
        }

        assert!(detector.is_ready(), "Should be ready after 60 observations");
    }

    #[test]
    fn test_bull_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // Steady +0.5% per bar.
        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.005;
            let result = detector.update(price);
            if detector.is_ready() {
                assert_ne!(result.regime, MarketRegime::Uncertain);
            }
        }

        let final_result = detector.get_regime_confidence();
        assert!(
            matches!(
                final_result.regime,
                MarketRegime::Trending(TrendDirection::Bullish)
            ),
            "Expected Bullish trend, got: {:?}",
            final_result.regime
        );
    }

    #[test]
    fn test_volatile_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // Alternating +5% / -5% moves: large volatility, no drift.
        let mut price = 100.0;
        for i in 0..200 {
            if i % 2 == 0 {
                price *= 1.05;
            } else {
                price *= 0.95;
            }
            detector.update(price);
        }

        let result = detector.get_regime_confidence();
        assert!(
            matches!(
                result.regime,
                MarketRegime::Volatile | MarketRegime::MeanReverting
            ),
            "Expected Volatile or MeanReverting for choppy data, got: {:?}",
            result.regime
        );
    }

    #[test]
    fn test_state_probabilities_sum_to_one() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.001;
            detector.update(price);

            let probs = detector.state_probabilities();
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "State probabilities should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_transition_matrix_rows_sum_to_one() {
        let detector = HMMRegimeDetector::default_config();
        let tm = detector.transition_matrix();

        for (i, row) in tm.iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Transition matrix row {i} should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_expected_regime_duration() {
        let detector = HMMRegimeDetector::default_config();

        // Self-transition prior of 0.9 gives an expected duration of 1 / (1 - 0.9) = 10.
        let duration = detector.expected_regime_duration(0);
        assert!(
            (duration - 10.0).abs() < 1e-6,
            "Expected duration should be ~10 with 0.9 persistence, got: {duration}"
        );
    }

    #[test]
    fn test_predict_next_state() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.002;
            detector.update(price);
        }

        let (next_state, prob) = detector.predict_next_state();
        assert!(next_state < detector.config().n_states);
        assert!(
            (0.0..=1.0).contains(&prob),
            "Predicted probability should be in [0, 1]: {prob}"
        );
    }

    #[test]
    fn test_state_parameters() {
        let detector = HMMRegimeDetector::default_config();
        let params = detector.state_parameters();

        assert_eq!(params.len(), 3, "Should have 3 state parameters");

        for &(mean, variance) in &params {
            assert!(variance > 0.0, "Variance should be positive: {variance}");
            assert!(mean.is_finite(), "Mean should be finite: {mean}");
        }
    }

    #[test]
    fn test_update_ohlc_uses_close() {
        let mut det1 = HMMRegimeDetector::crypto_optimized();
        let mut det2 = HMMRegimeDetector::crypto_optimized();

        // High/low differ from close; only close should matter.
        for i in 0..100 {
            let close = 100.0 + i as f64 * 0.1;
            let r1 = det1.update(close);
            let r2 = det2.update_ohlc(close * 1.01, close * 0.99, close);

            assert_eq!(
                r1.regime, r2.regime,
                "update and update_ohlc should produce same regime"
            );
        }
    }

    #[test]
    fn test_n_observations_tracking() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        assert_eq!(detector.n_observations(), 0);

        for i in 0..50 {
            detector.update(100.0 + i as f64);
        }

        // 50 closes produce 49 close-to-close returns.
        assert_eq!(detector.n_observations(), 49);
    }

    #[test]
    fn test_confidence_range() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.002;
            detector.update(price);
        }

        let confidence = detector.current_confidence();
        assert!(
            (0.0..=1.0).contains(&confidence),
            "Confidence should be in [0, 1]: {confidence}"
        );
    }
}