1use std::collections::{HashMap, VecDeque};
12
13use serde::{Deserialize, Serialize};
14
15use super::types::{MarketRegime, RegimeConfidence, TrendDirection};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22#[derive(Debug, Clone)]
34pub struct HmmIndicator {
35 pub config: HMMConfig,
36}
37
38impl HmmIndicator {
39 pub fn new(config: HMMConfig) -> Self {
40 Self { config }
41 }
42
43 pub fn with_defaults() -> Self {
44 Self::new(HMMConfig::default())
45 }
46}
47
48fn hmm_regime_id(r: MarketRegime) -> f64 {
49 match r {
50 MarketRegime::MeanReverting => 1.0,
51 MarketRegime::Volatile => 2.0,
52 MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
53 MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
54 MarketRegime::Uncertain => 0.0,
55 }
56}
57
58impl Indicator for HmmIndicator {
59 fn name(&self) -> &'static str {
60 "HMMRegime"
61 }
62
63 fn required_len(&self) -> usize {
68 self.config.min_observations + 1
69 }
70
71 fn required_columns(&self) -> &[&'static str] {
72 &["close"]
73 }
74
75 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
76 self.check_len(candles)?;
77 let mut det = HMMRegimeDetector::new(self.config.clone());
78 let n = candles.len();
79 let mut conf = vec![f64::NAN; n];
80 let mut regime = vec![f64::NAN; n];
81 for (i, c) in candles.iter().enumerate() {
82 let rc = det.update(c.close);
83 conf[i] = rc.confidence;
84 regime[i] = hmm_regime_id(rc.regime);
85 }
86 Ok(IndicatorOutput::from_pairs([
87 ("hmm_conf", conf),
88 ("hmm_regime_id", regime),
89 ]))
90 }
91}
92
93pub fn factory<S: ::std::hash::BuildHasher>(
96 params: &HashMap<String, String, S>,
97) -> Result<Box<dyn Indicator>, IndicatorError> {
98 let min_observations = param_usize(params, "min_observations", 100)?;
99 let n_states = param_usize(params, "n_states", 3)?;
100 let config = HMMConfig {
101 n_states,
102 min_observations,
103 ..HMMConfig::default()
104 };
105 Ok(Box::new(HmmIndicator::new(config)))
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct HMMConfig {
111 pub n_states: usize,
113 pub min_observations: usize,
115 pub learning_rate: f64,
117 pub transition_smoothing: f64,
119 pub lookback_window: usize,
121 pub min_confidence: f64,
123}
124
125impl Default for HMMConfig {
126 fn default() -> Self {
127 Self {
128 n_states: 3, min_observations: 100,
130 learning_rate: 0.01,
131 transition_smoothing: 0.1,
132 lookback_window: 252, min_confidence: 0.6,
134 }
135 }
136}
137
138impl HMMConfig {
139 pub fn crypto_optimized() -> Self {
141 Self {
142 n_states: 3,
143 min_observations: 50,
144 learning_rate: 0.02, transition_smoothing: 0.05,
146 lookback_window: 100,
147 min_confidence: 0.5,
148 }
149 }
150
151 pub fn conservative() -> Self {
153 Self {
154 n_states: 2, min_observations: 150,
156 learning_rate: 0.005,
157 transition_smoothing: 0.15,
158 lookback_window: 500,
159 min_confidence: 0.7,
160 }
161 }
162}
163
/// Gaussian emission model of one hidden state, plus running accumulators.
#[derive(Debug, Clone)]
struct GaussianState {
    /// Current mean estimate.
    mean: f64,
    /// Current variance estimate.
    variance: f64,
    /// Weighted sum of all observations seen by `update`.
    sum: f64,
    /// Weighted sum of squared observations seen by `update`.
    sum_sq: f64,
    /// Number of `update` calls.
    count: usize,
}

impl GaussianState {
    /// Creates a state with the given parameters and zeroed accumulators.
    fn new(mean: f64, variance: f64) -> Self {
        GaussianState {
            mean,
            variance,
            sum: 0.0,
            sum_sq: 0.0,
            count: 0,
        }
    }

    /// Normal density at `x` for the current mean and variance.
    fn pdf(&self, x: f64) -> f64 {
        let delta = x - self.mean;
        // TAU is exactly 2π, so this is sqrt(2π·σ²).
        let scale = (std::f64::consts::TAU * self.variance).sqrt();
        (-0.5 * delta * delta / self.variance).exp() / scale
    }

    /// Folds observation `x` into the state.
    ///
    /// The accumulators are always updated. Mean and variance move toward `x`
    /// by `learning_rate * weight` only when `learning_rate` is positive; the
    /// variance is floored at `1e-8`.
    fn update(&mut self, x: f64, weight: f64, learning_rate: f64) {
        self.sum += x * weight;
        self.sum_sq += x * x * weight;
        self.count += 1;

        if learning_rate > 0.0 {
            let step = learning_rate * weight;
            let keep = 1.0 - step;

            self.mean = keep * self.mean + step * x;

            // Squared deviation is measured against the freshly updated mean.
            let sq_dev = (x - self.mean).powi(2);
            self.variance = (keep * self.variance + step * sq_dev).max(1e-8);
        }
    }
}
211
212#[derive(Debug)]
239pub struct HMMRegimeDetector {
240 config: HMMConfig,
241
242 states: Vec<GaussianState>,
244
245 transition_matrix: Vec<Vec<f64>>,
247
248 initial_probs: Vec<f64>,
250
251 state_probs: Vec<f64>,
253
254 returns_history: VecDeque<f64>,
256
257 prices: VecDeque<f64>,
259
260 current_state: usize,
262
263 current_confidence: f64,
265
266 n_observations: usize,
268
269 last_regime: MarketRegime,
271}
272
273impl HMMRegimeDetector {
274 pub fn new(config: HMMConfig) -> Self {
276 let n = config.n_states;
277
278 let states = match n {
283 2 => vec![
284 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0004), ],
287 3 => vec![
288 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0002), GaussianState::new(0.0, 0.0009), ],
292 _ => (0..n)
293 .map(|i| {
294 let mean = (i as f64 - n as f64 / 2.0) * 0.001;
295 let var = 0.0001 * (1.0 + i as f64);
296 GaussianState::new(mean, var)
297 })
298 .collect(),
299 };
300
301 let mut transition_matrix = vec![vec![0.0; n]; n];
304 for (i, row) in transition_matrix.iter_mut().enumerate().take(n) {
305 for (j, cell) in row.iter_mut().enumerate().take(n) {
306 if i == j {
307 *cell = 0.9; } else {
309 *cell = 0.1 / (n - 1) as f64;
310 }
311 }
312 }
313
314 let initial_probs = vec![1.0 / n as f64; n];
316 let state_probs = initial_probs.clone();
317
318 Self {
319 config: config.clone(),
320 states,
321 transition_matrix,
322 initial_probs,
323 state_probs,
324 returns_history: VecDeque::with_capacity(config.lookback_window),
325 prices: VecDeque::with_capacity(10),
326 current_state: 0,
327 current_confidence: 0.0,
328 n_observations: 0,
329 last_regime: MarketRegime::Uncertain,
330 }
331 }
332
333 pub fn default_config() -> Self {
335 Self::new(HMMConfig::default())
336 }
337
338 pub fn crypto_optimized() -> Self {
340 Self::new(HMMConfig::crypto_optimized())
341 }
342
343 pub fn conservative() -> Self {
345 Self::new(HMMConfig::conservative())
346 }
347
348 pub fn update(&mut self, close: f64) -> RegimeConfidence {
353 if let Some(&prev_close) = self.prices.back()
355 && prev_close > 0.0
356 {
357 let log_return = (close / prev_close).ln();
358 self.process_return(log_return);
359 }
360
361 self.prices.push_back(close);
363 if self.prices.len() > 10 {
364 self.prices.pop_front();
365 }
366
367 let confidence = self.get_regime_confidence();
369 self.last_regime = confidence.regime;
370 confidence
371 }
372
373 pub fn update_ohlc(&mut self, _high: f64, _low: f64, close: f64) -> RegimeConfidence {
375 self.update(close)
376 }
377
378 fn process_return(&mut self, ret: f64) {
380 self.n_observations += 1;
381
382 self.returns_history.push_back(ret);
384 if self.returns_history.len() > self.config.lookback_window {
385 self.returns_history.pop_front();
386 }
387
388 self.forward_step(ret);
390
391 if self.n_observations > self.config.min_observations && self.config.learning_rate > 0.0 {
393 self.online_parameter_update(ret);
394 }
395
396 let reestimate_interval = self.config.lookback_window / 2;
398 if self.n_observations > 0
399 && reestimate_interval > 0
400 && self.n_observations.is_multiple_of(reestimate_interval)
401 && self.returns_history.len() >= self.config.min_observations
402 {
403 self.baum_welch_update();
404 }
405 }
406
407 fn forward_step(&mut self, ret: f64) {
409 let n = self.config.n_states;
410 let mut new_probs = vec![0.0; n];
411
412 let emissions: Vec<f64> = self.states.iter().map(|s| s.pdf(ret)).collect();
414
415 for j in 0..n {
417 let mut sum = 0.0;
418 for i in 0..n {
419 sum += self.transition_matrix[i][j] * self.state_probs[i];
420 }
421 new_probs[j] = emissions[j] * sum;
422 }
423
424 let total: f64 = new_probs.iter().sum();
426 if total > 1e-300 {
427 for p in &mut new_probs {
428 *p /= total;
429 }
430 } else {
431 new_probs = vec![1.0 / n as f64; n];
433 }
434
435 self.state_probs = new_probs;
436
437 let (max_idx, max_prob) = self
441 .state_probs
442 .iter()
443 .enumerate()
444 .max_by(|(_, a), (_, b)| a.total_cmp(b))
445 .map_or((0, 1.0 / n as f64), |(i, p)| (i, *p));
446
447 self.current_state = max_idx;
448 self.current_confidence = max_prob;
449 }
450
451 fn online_parameter_update(&mut self, ret: f64) {
453 let lr = self.config.learning_rate;
454
455 for (i, state) in self.states.iter_mut().enumerate() {
456 let weight = self.state_probs[i];
457 state.update(ret, weight, lr);
458 }
459
460 let smoothing = self.config.transition_smoothing;
463 for i in 0..self.config.n_states {
464 for j in 0..self.config.n_states {
465 let target = if i == j {
466 0.9
467 } else {
468 0.1 / (self.config.n_states - 1) as f64
469 };
470 self.transition_matrix[i][j] =
471 (1.0 - smoothing) * self.transition_matrix[i][j] + smoothing * target;
472 }
473 }
474 }
475
476 fn baum_welch_update(&mut self) {
482 let returns: Vec<f64> = self.returns_history.iter().copied().collect();
483 if returns.len() < self.config.min_observations {
484 return;
485 }
486
487 let n = self.config.n_states;
488 let t = returns.len();
489
490 let mut alpha = vec![vec![0.0; n]; t];
492
493 for (j, alpha_val) in alpha[0].iter_mut().enumerate().take(n) {
495 *alpha_val = self.initial_probs[j] * self.states[j].pdf(returns[0]);
496 }
497 Self::normalize_vec(&mut alpha[0]);
498
499 for time in 1..t {
501 for j in 0..n {
502 let mut sum = 0.0;
503 for (i, alpha_prev) in alpha[time - 1].iter().enumerate().take(n) {
504 sum += alpha_prev * self.transition_matrix[i][j];
505 }
506 alpha[time][j] = sum * self.states[j].pdf(returns[time]);
507 }
508 Self::normalize_vec(&mut alpha[time]);
509 }
510
511 let mut beta = vec![vec![1.0; n]; t];
513
514 for time in (0..t - 1).rev() {
515 for i in 0..n {
516 let mut sum = 0.0;
517 for (j, beta_next) in beta[time + 1].iter().enumerate().take(n) {
518 sum += self.transition_matrix[i][j]
519 * self.states[j].pdf(returns[time + 1])
520 * beta_next;
521 }
522 beta[time][i] = sum;
523 }
524 Self::normalize_vec(&mut beta[time]);
525 }
526
527 let mut gamma = vec![vec![0.0; n]; t];
529 for time in 0..t {
530 let mut sum = 0.0;
531 for (j, gamma_val) in gamma[time].iter_mut().enumerate().take(n) {
532 *gamma_val = alpha[time][j] * beta[time][j];
533 sum += *gamma_val;
534 }
535 if sum > 1e-300 {
536 for gamma_val in gamma[time].iter_mut().take(n) {
537 *gamma_val /= sum;
538 }
539 }
540 }
541
542 for (j, state) in self.states.iter_mut().enumerate().take(n) {
544 let mut weight_sum = 0.0;
545 let mut mean_sum = 0.0;
546 let mut var_sum = 0.0;
547
548 for time in 0..t {
549 let w = gamma[time][j];
550 weight_sum += w;
551 mean_sum += w * returns[time];
552 }
553
554 if weight_sum > 1e-8 {
555 let new_mean = mean_sum / weight_sum;
556
557 for time in 0..t {
558 let w = gamma[time][j];
559 var_sum += w * (returns[time] - new_mean).powi(2);
560 }
561
562 let new_var = (var_sum / weight_sum).max(1e-8);
563
564 let blend = 0.3;
566 state.mean = (1.0 - blend) * state.mean + blend * new_mean;
567 state.variance = (1.0 - blend) * state.variance + blend * new_var;
568 }
569 }
570 }
571
572 fn normalize_vec(vec: &mut [f64]) {
574 let sum: f64 = vec.iter().sum();
575 if sum > 1e-300 {
576 for v in vec.iter_mut() {
577 *v /= sum;
578 }
579 }
580 }
581
582 pub fn get_regime_confidence(&self) -> RegimeConfidence {
584 if self.n_observations < self.config.min_observations {
585 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
586 }
587
588 let regime = self.state_to_regime(self.current_state);
589 let confidence = self.current_confidence;
590
591 RegimeConfidence::with_metrics(
592 regime,
593 confidence,
594 self.states[self.current_state].mean * 100.0 * 252.0, self.states[self.current_state].variance.sqrt() * 100.0 * 252.0_f64.sqrt(), 0.0, )
598 }
599
600 fn state_to_regime(&self, state: usize) -> MarketRegime {
608 let state_params = &self.states[state];
609 let mean = state_params.mean;
610 let vol = state_params.variance.sqrt();
611
612 let is_high_vol = vol > 0.02; let is_positive = mean > 0.0005; let is_negative = mean < -0.0005;
616
617 if is_high_vol {
618 MarketRegime::Volatile
619 } else if is_positive {
620 MarketRegime::Trending(TrendDirection::Bullish)
621 } else if is_negative {
622 MarketRegime::Trending(TrendDirection::Bearish)
623 } else {
624 MarketRegime::MeanReverting }
626 }
627
628 pub fn state_probabilities(&self) -> &[f64] {
634 &self.state_probs
635 }
636
637 pub fn state_parameters(&self) -> Vec<(f64, f64)> {
639 self.states.iter().map(|s| (s.mean, s.variance)).collect()
640 }
641
642 pub fn transition_matrix(&self) -> &[Vec<f64>] {
644 &self.transition_matrix
645 }
646
647 pub fn current_state_index(&self) -> usize {
649 self.current_state
650 }
651
652 pub fn is_ready(&self) -> bool {
654 self.n_observations >= self.config.min_observations
655 }
656
657 pub fn expected_regime_duration(&self, state: usize) -> f64 {
661 if state < self.config.n_states {
662 1.0 / (1.0 - self.transition_matrix[state][state])
663 } else {
664 0.0
665 }
666 }
667
668 pub fn predict_next_state(&self) -> (usize, f64) {
670 let mut next_probs = vec![0.0; self.config.n_states];
671
672 for (j, next_prob) in next_probs.iter_mut().enumerate().take(self.config.n_states) {
673 for i in 0..self.config.n_states {
674 *next_prob += self.transition_matrix[i][j] * self.state_probs[i];
675 }
676 }
677
678 let (max_idx, max_prob) = next_probs
679 .iter()
680 .enumerate()
681 .max_by(|(_, a), (_, b)| a.total_cmp(b))
682 .map_or((0, 0.0), |(i, p)| (i, *p));
683
684 (max_idx, max_prob)
685 }
686
687 pub fn n_observations(&self) -> usize {
689 self.n_observations
690 }
691
692 pub fn current_confidence(&self) -> f64 {
694 self.current_confidence
695 }
696
697 pub fn config(&self) -> &HMMConfig {
699 &self.config
700 }
701}
702
/// Unit tests for [`HMMRegimeDetector`] and its configuration presets.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh default detector is not ready and tracks three states.
    #[test]
    fn test_hmm_initialization() {
        let detector = HMMRegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.state_probabilities().len(), 3);
    }

    /// The crypto preset carries its own state count and warm-up length.
    #[test]
    fn test_hmm_crypto_config() {
        let detector = HMMRegimeDetector::crypto_optimized();
        assert_eq!(detector.config().n_states, 3);
        assert_eq!(detector.config().min_observations, 50);
    }

    /// The conservative preset uses two states and a longer warm-up.
    #[test]
    fn test_hmm_conservative_config() {
        let detector = HMMRegimeDetector::conservative();
        assert_eq!(detector.config().n_states, 2);
        assert_eq!(detector.config().min_observations, 150);
        assert_eq!(detector.state_probabilities().len(), 2);
    }

    /// Every update during warm-up reports `Uncertain`.
    #[test]
    fn test_hmm_warmup() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // 49 closes produce 48 returns, still short of the 50 required.
        for i in 0..49 {
            let price = 100.0 + (i as f64) * 0.01;
            let result = detector.update(price);
            assert_eq!(
                result.regime,
                MarketRegime::Uncertain,
                "Should be Uncertain during warmup at step {i}"
            );
        }

        assert!(!detector.is_ready());
    }

    /// The detector reports ready once enough returns have been seen.
    #[test]
    fn test_hmm_becomes_ready() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        for i in 0..60 {
            let price = 100.0 + (i as f64) * 0.01;
            detector.update(price);
        }

        assert!(detector.is_ready(), "Should be ready after 60 observations");
    }

    /// A steady 0.5% per-step rise ends up classified as a bullish trend.
    #[test]
    fn test_bull_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.005;
            let result = detector.update(price);
            if detector.is_ready() {
                assert_ne!(result.regime, MarketRegime::Uncertain);
            }
        }

        let final_result = detector.get_regime_confidence();
        assert!(
            matches!(
                final_result.regime,
                MarketRegime::Trending(TrendDirection::Bullish)
            ),
            "Expected Bullish trend, got: {:?}",
            final_result.regime
        );
    }

    /// Alternating ±5% moves are classified as volatile or mean-reverting.
    #[test]
    fn test_volatile_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for i in 0..200 {
            if i % 2 == 0 {
                price *= 1.05;
            } else {
                price *= 0.95;
            }
            detector.update(price);
        }

        let result = detector.get_regime_confidence();
        assert!(
            matches!(
                result.regime,
                MarketRegime::Volatile | MarketRegime::MeanReverting
            ),
            "Expected Volatile or MeanReverting for choppy data, got: {:?}",
            result.regime
        );
    }

    /// The filtered state distribution stays normalised after every update.
    #[test]
    fn test_state_probabilities_sum_to_one() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.001;
            detector.update(price);

            let probs = detector.state_probabilities();
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "State probabilities should sum to 1.0, got: {sum}"
            );
        }
    }

    /// Each row of the initial transition matrix is a probability distribution.
    #[test]
    fn test_transition_matrix_rows_sum_to_one() {
        let detector = HMMRegimeDetector::default_config();
        let tm = detector.transition_matrix();

        for (i, row) in tm.iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Transition matrix row {i} should sum to 1.0, got: {sum}"
            );
        }
    }

    /// A 0.9 self-transition probability implies an expected duration of 10.
    #[test]
    fn test_expected_regime_duration() {
        let detector = HMMRegimeDetector::default_config();

        let duration = detector.expected_regime_duration(0);
        assert!(
            (duration - 10.0).abs() < 1e-6,
            "Expected duration should be ~10 with 0.9 persistence, got: {duration}"
        );
    }

    /// The one-step-ahead prediction is a valid state with a valid probability.
    #[test]
    fn test_predict_next_state() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.002;
            detector.update(price);
        }

        let (next_state, prob) = detector.predict_next_state();
        assert!(next_state < detector.config().n_states);
        assert!(
            (0.0..=1.0).contains(&prob),
            "Predicted probability should be in [0, 1]: {prob}"
        );
    }

    /// Initial emission parameters are well-formed for every state.
    #[test]
    fn test_state_parameters() {
        let detector = HMMRegimeDetector::default_config();
        let params = detector.state_parameters();

        assert_eq!(params.len(), 3, "Should have 3 state parameters");

        for (mean, variance) in &params {
            assert!(*variance > 0.0, "Variance should be positive: {variance}");
            assert!(mean.is_finite(), "Mean should be finite: {mean}");
        }
    }

    /// `update_ohlc` ignores high/low and matches `update` on the same closes.
    #[test]
    fn test_update_ohlc_uses_close() {
        let mut det1 = HMMRegimeDetector::crypto_optimized();
        let mut det2 = HMMRegimeDetector::crypto_optimized();

        for i in 0..100 {
            let close = 100.0 + i as f64 * 0.1;
            let r1 = det1.update(close);
            let r2 = det2.update_ohlc(close * 1.01, close * 0.99, close);

            assert_eq!(
                r1.regime, r2.regime,
                "update and update_ohlc should produce same regime"
            );
        }
    }

    /// The observation counter counts returns, so it lags the close count by one.
    #[test]
    fn test_n_observations_tracking() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        assert_eq!(detector.n_observations(), 0);

        for i in 0..50 {
            detector.update(100.0 + i as f64);
        }

        // 50 closes yield 49 returns.
        assert_eq!(detector.n_observations(), 49);
    }

    /// Confidence is a probability.
    #[test]
    fn test_confidence_range() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.002;
            detector.update(price);
        }

        let confidence = detector.current_confidence();
        assert!(
            (0.0..=1.0).contains(&confidence),
            "Confidence should be in [0, 1]: {confidence}"
        );
    }
}