1use super::types::{MarketRegime, RegimeConfidence, TrendDirection};
12use serde::{Deserialize, Serialize};
13use std::collections::VecDeque;
14
/// Tuning parameters for [`HMMRegimeDetector`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HMMConfig {
    /// Number of hidden states (regimes) in the model.
    pub n_states: usize,
    /// Number of processed log returns required before a regime other than
    /// `Uncertain` is reported; also the minimum history length required for
    /// a Baum-Welch re-estimation.
    pub min_observations: usize,
    /// Step size of the per-observation online update of each state's mean
    /// and variance; `0.0` disables online updates.
    pub learning_rate: f64,
    /// Fraction by which the transition matrix is blended toward its default
    /// values on each online update.
    pub transition_smoothing: f64,
    /// Maximum number of log returns kept for Baum-Welch re-estimation, which
    /// runs every `lookback_window / 2` observations.
    pub lookback_window: usize,
    /// NOTE(review): not read by the detector code in this file; presumably
    /// intended as a confidence floor for reporting a regime — confirm intent.
    pub min_confidence: f64,
}
31
32impl Default for HMMConfig {
33 fn default() -> Self {
34 Self {
35 n_states: 3, min_observations: 100,
37 learning_rate: 0.01,
38 transition_smoothing: 0.1,
39 lookback_window: 252, min_confidence: 0.6,
41 }
42 }
43}
44
45impl HMMConfig {
46 pub fn crypto_optimized() -> Self {
48 Self {
49 n_states: 3,
50 min_observations: 50,
51 learning_rate: 0.02, transition_smoothing: 0.05,
53 lookback_window: 100,
54 min_confidence: 0.5,
55 }
56 }
57
58 pub fn conservative() -> Self {
60 Self {
61 n_states: 2, min_observations: 150,
63 learning_rate: 0.005,
64 transition_smoothing: 0.15,
65 lookback_window: 500,
66 min_confidence: 0.7,
67 }
68 }
69}
70
/// Univariate Gaussian emission model for one hidden state.
#[derive(Debug, Clone)]
struct GaussianState {
    /// Mean of the emitted values (the detector feeds it log returns).
    mean: f64,
    /// Variance of the emitted values; the online `update` floors it at `1e-8`.
    variance: f64,
    /// Weighted running sum of the values passed to `update`.
    /// NOTE(review): not read anywhere in this file.
    sum: f64,
    /// Weighted running sum of squared values passed to `update`.
    /// NOTE(review): not read anywhere in this file.
    sum_sq: f64,
    /// Number of `update` calls.
    /// NOTE(review): not read anywhere in this file.
    count: usize,
}
81
82impl GaussianState {
83 fn new(mean: f64, variance: f64) -> Self {
84 Self {
85 mean,
86 variance,
87 sum: 0.0,
88 sum_sq: 0.0,
89 count: 0,
90 }
91 }
92
93 fn pdf(&self, x: f64) -> f64 {
95 let diff = x - self.mean;
96 let exponent = -0.5 * diff * diff / self.variance;
97 let normalizer = (2.0 * std::f64::consts::PI * self.variance).sqrt();
98 exponent.exp() / normalizer
99 }
100
101 fn update(&mut self, x: f64, weight: f64, learning_rate: f64) {
103 if learning_rate > 0.0 {
104 self.mean = (1.0 - learning_rate * weight) * self.mean + learning_rate * weight * x;
106 let new_var = (x - self.mean).powi(2);
107 self.variance =
108 (1.0 - learning_rate * weight) * self.variance + learning_rate * weight * new_var;
109 self.variance = self.variance.max(1e-8); }
111
112 self.sum += x * weight;
114 self.sum_sq += x * x * weight;
115 self.count += 1;
116 }
117}
118
/// Online hidden-Markov-model regime detector driven by log returns of
/// closing prices.
///
/// Each hidden state emits returns from a univariate Gaussian. Every new
/// return runs one normalized forward-filter step. After warm-up, emission
/// parameters are also nudged online, and they are periodically re-estimated
/// by a Baum-Welch pass that is blended into the current values.
#[derive(Debug)]
pub struct HMMRegimeDetector {
    /// Configuration the detector was built with.
    config: HMMConfig,

    /// Gaussian emission model of each hidden state (`config.n_states` entries).
    states: Vec<GaussianState>,

    /// `transition_matrix[i][j]` is the probability of moving from state `i`
    /// to state `j`.
    transition_matrix: Vec<Vec<f64>>,

    /// Prior state distribution used at the start of each Baum-Welch window;
    /// uniform at construction.
    initial_probs: Vec<f64>,

    /// Current filtered state distribution, renormalized after each observation.
    state_probs: Vec<f64>,

    /// Most recent log returns, capped at `config.lookback_window` entries.
    returns_history: VecDeque<f64>,

    /// Last few closing prices (at most 10). Only the newest is used, to form
    /// the next log return.
    prices: VecDeque<f64>,

    /// Index of the most probable state after the latest forward step.
    current_state: usize,

    /// Filtered probability of `current_state`.
    current_confidence: f64,

    /// Number of log returns processed so far.
    n_observations: usize,

    /// Set to `MarketRegime::Uncertain` at construction; not read or updated
    /// elsewhere in this file.
    #[allow(dead_code)]
    last_regime: MarketRegime,
}
180
181impl HMMRegimeDetector {
182 pub fn new(config: HMMConfig) -> Self {
184 let n = config.n_states;
185
186 let states = match n {
191 2 => vec![
192 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0004), ],
195 3 => vec![
196 GaussianState::new(0.001, 0.0001), GaussianState::new(-0.001, 0.0002), GaussianState::new(0.0, 0.0009), ],
200 _ => (0..n)
201 .map(|i| {
202 let mean = (i as f64 - n as f64 / 2.0) * 0.001;
203 let var = 0.0001 * (1.0 + i as f64);
204 GaussianState::new(mean, var)
205 })
206 .collect(),
207 };
208
209 let mut transition_matrix = vec![vec![0.0; n]; n];
212 for (i, row) in transition_matrix.iter_mut().enumerate().take(n) {
213 for (j, cell) in row.iter_mut().enumerate().take(n) {
214 if i == j {
215 *cell = 0.9; } else {
217 *cell = 0.1 / (n - 1) as f64;
218 }
219 }
220 }
221
222 let initial_probs = vec![1.0 / n as f64; n];
224 let state_probs = initial_probs.clone();
225
226 Self {
227 config: config.clone(),
228 states,
229 transition_matrix,
230 initial_probs,
231 state_probs,
232 returns_history: VecDeque::with_capacity(config.lookback_window),
233 prices: VecDeque::with_capacity(10),
234 current_state: 0,
235 current_confidence: 0.0,
236 n_observations: 0,
237 last_regime: MarketRegime::Uncertain,
238 }
239 }
240
241 pub fn default_config() -> Self {
243 Self::new(HMMConfig::default())
244 }
245
246 pub fn crypto_optimized() -> Self {
248 Self::new(HMMConfig::crypto_optimized())
249 }
250
251 pub fn conservative() -> Self {
253 Self::new(HMMConfig::conservative())
254 }
255
256 pub fn update(&mut self, close: f64) -> RegimeConfidence {
261 if let Some(&prev_close) = self.prices.back()
263 && prev_close > 0.0
264 {
265 let log_return = (close / prev_close).ln();
266 self.process_return(log_return);
267 }
268
269 self.prices.push_back(close);
271 if self.prices.len() > 10 {
272 self.prices.pop_front();
273 }
274
275 self.get_regime_confidence()
277 }
278
279 pub fn update_ohlc(&mut self, _high: f64, _low: f64, close: f64) -> RegimeConfidence {
281 self.update(close)
282 }
283
284 fn process_return(&mut self, ret: f64) {
286 self.n_observations += 1;
287
288 self.returns_history.push_back(ret);
290 if self.returns_history.len() > self.config.lookback_window {
291 self.returns_history.pop_front();
292 }
293
294 self.forward_step(ret);
296
297 if self.n_observations > self.config.min_observations && self.config.learning_rate > 0.0 {
299 self.online_parameter_update(ret);
300 }
301
302 let reestimate_interval = self.config.lookback_window / 2;
304 if self.n_observations > 0
305 && reestimate_interval > 0
306 && self.n_observations.is_multiple_of(reestimate_interval)
307 && self.returns_history.len() >= self.config.min_observations
308 {
309 self.baum_welch_update();
310 }
311 }
312
313 fn forward_step(&mut self, ret: f64) {
315 let n = self.config.n_states;
316 let mut new_probs = vec![0.0; n];
317
318 let emissions: Vec<f64> = self.states.iter().map(|s| s.pdf(ret)).collect();
320
321 for j in 0..n {
323 let mut sum = 0.0;
324 for i in 0..n {
325 sum += self.transition_matrix[i][j] * self.state_probs[i];
326 }
327 new_probs[j] = emissions[j] * sum;
328 }
329
330 let total: f64 = new_probs.iter().sum();
332 if total > 1e-300 {
333 for p in &mut new_probs {
334 *p /= total;
335 }
336 } else {
337 new_probs = vec![1.0 / n as f64; n];
339 }
340
341 self.state_probs = new_probs;
342
343 let (max_idx, max_prob) = self
345 .state_probs
346 .iter()
347 .enumerate()
348 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
349 .unwrap();
350
351 self.current_state = max_idx;
352 self.current_confidence = *max_prob;
353 }
354
355 fn online_parameter_update(&mut self, ret: f64) {
357 let lr = self.config.learning_rate;
358
359 for (i, state) in self.states.iter_mut().enumerate() {
360 let weight = self.state_probs[i];
361 state.update(ret, weight, lr);
362 }
363
364 let smoothing = self.config.transition_smoothing;
367 for i in 0..self.config.n_states {
368 for j in 0..self.config.n_states {
369 let target = if i == j {
370 0.9
371 } else {
372 0.1 / (self.config.n_states - 1) as f64
373 };
374 self.transition_matrix[i][j] =
375 (1.0 - smoothing) * self.transition_matrix[i][j] + smoothing * target;
376 }
377 }
378 }
379
380 fn baum_welch_update(&mut self) {
386 let returns: Vec<f64> = self.returns_history.iter().copied().collect();
387 if returns.len() < self.config.min_observations {
388 return;
389 }
390
391 let n = self.config.n_states;
392 let t = returns.len();
393
394 let mut alpha = vec![vec![0.0; n]; t];
396
397 for (j, alpha_val) in alpha[0].iter_mut().enumerate().take(n) {
399 *alpha_val = self.initial_probs[j] * self.states[j].pdf(returns[0]);
400 }
401 Self::normalize_vec(&mut alpha[0]);
402
403 for time in 1..t {
405 for j in 0..n {
406 let mut sum = 0.0;
407 for (i, alpha_prev) in alpha[time - 1].iter().enumerate().take(n) {
408 sum += alpha_prev * self.transition_matrix[i][j];
409 }
410 alpha[time][j] = sum * self.states[j].pdf(returns[time]);
411 }
412 Self::normalize_vec(&mut alpha[time]);
413 }
414
415 let mut beta = vec![vec![1.0; n]; t];
417
418 for time in (0..t - 1).rev() {
419 for i in 0..n {
420 let mut sum = 0.0;
421 for (j, beta_next) in beta[time + 1].iter().enumerate().take(n) {
422 sum += self.transition_matrix[i][j]
423 * self.states[j].pdf(returns[time + 1])
424 * beta_next;
425 }
426 beta[time][i] = sum;
427 }
428 Self::normalize_vec(&mut beta[time]);
429 }
430
431 let mut gamma = vec![vec![0.0; n]; t];
433 for time in 0..t {
434 let mut sum = 0.0;
435 for (j, gamma_val) in gamma[time].iter_mut().enumerate().take(n) {
436 *gamma_val = alpha[time][j] * beta[time][j];
437 sum += *gamma_val;
438 }
439 if sum > 1e-300 {
440 for gamma_val in gamma[time].iter_mut().take(n) {
441 *gamma_val /= sum;
442 }
443 }
444 }
445
446 for (j, state) in self.states.iter_mut().enumerate().take(n) {
448 let mut weight_sum = 0.0;
449 let mut mean_sum = 0.0;
450 let mut var_sum = 0.0;
451
452 for time in 0..t {
453 let w = gamma[time][j];
454 weight_sum += w;
455 mean_sum += w * returns[time];
456 }
457
458 if weight_sum > 1e-8 {
459 let new_mean = mean_sum / weight_sum;
460
461 for time in 0..t {
462 let w = gamma[time][j];
463 var_sum += w * (returns[time] - new_mean).powi(2);
464 }
465
466 let new_var = (var_sum / weight_sum).max(1e-8);
467
468 let blend = 0.3;
470 state.mean = (1.0 - blend) * state.mean + blend * new_mean;
471 state.variance = (1.0 - blend) * state.variance + blend * new_var;
472 }
473 }
474 }
475
476 fn normalize_vec(vec: &mut [f64]) {
478 let sum: f64 = vec.iter().sum();
479 if sum > 1e-300 {
480 for v in vec.iter_mut() {
481 *v /= sum;
482 }
483 }
484 }
485
486 pub fn get_regime_confidence(&self) -> RegimeConfidence {
488 if self.n_observations < self.config.min_observations {
489 return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
490 }
491
492 let regime = self.state_to_regime(self.current_state);
493 let confidence = self.current_confidence;
494
495 RegimeConfidence::with_metrics(
496 regime,
497 confidence,
498 self.states[self.current_state].mean * 100.0 * 252.0, self.states[self.current_state].variance.sqrt() * 100.0 * 252.0_f64.sqrt(), 0.0, )
502 }
503
504 fn state_to_regime(&self, state: usize) -> MarketRegime {
512 let state_params = &self.states[state];
513 let mean = state_params.mean;
514 let vol = state_params.variance.sqrt();
515
516 let is_high_vol = vol > 0.02; let is_positive = mean > 0.0005; let is_negative = mean < -0.0005;
520
521 if is_high_vol {
522 MarketRegime::Volatile
523 } else if is_positive {
524 MarketRegime::Trending(TrendDirection::Bullish)
525 } else if is_negative {
526 MarketRegime::Trending(TrendDirection::Bearish)
527 } else {
528 MarketRegime::MeanReverting }
530 }
531
532 pub fn state_probabilities(&self) -> &[f64] {
538 &self.state_probs
539 }
540
541 pub fn state_parameters(&self) -> Vec<(f64, f64)> {
543 self.states.iter().map(|s| (s.mean, s.variance)).collect()
544 }
545
546 pub fn transition_matrix(&self) -> &[Vec<f64>] {
548 &self.transition_matrix
549 }
550
551 pub fn current_state_index(&self) -> usize {
553 self.current_state
554 }
555
556 pub fn is_ready(&self) -> bool {
558 self.n_observations >= self.config.min_observations
559 }
560
561 pub fn expected_regime_duration(&self, state: usize) -> f64 {
565 if state < self.config.n_states {
566 1.0 / (1.0 - self.transition_matrix[state][state])
567 } else {
568 0.0
569 }
570 }
571
572 pub fn predict_next_state(&self) -> (usize, f64) {
574 let mut next_probs = vec![0.0; self.config.n_states];
575
576 for (j, next_prob) in next_probs.iter_mut().enumerate().take(self.config.n_states) {
577 for i in 0..self.config.n_states {
578 *next_prob += self.transition_matrix[i][j] * self.state_probs[i];
579 }
580 }
581
582 let (max_idx, max_prob) = next_probs
583 .iter()
584 .enumerate()
585 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
586 .unwrap();
587
588 (max_idx, *max_prob)
589 }
590
591 pub fn n_observations(&self) -> usize {
593 self.n_observations
594 }
595
596 pub fn current_confidence(&self) -> f64 {
598 self.current_confidence
599 }
600
601 pub fn config(&self) -> &HMMConfig {
603 &self.config
604 }
605}
606
/// Unit tests for [`HMMRegimeDetector`] construction, warm-up, regime
/// classification and probability invariants.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hmm_initialization() {
        let detector = HMMRegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.state_probabilities().len(), 3);
    }

    #[test]
    fn test_hmm_crypto_config() {
        let detector = HMMRegimeDetector::crypto_optimized();
        assert_eq!(detector.config().n_states, 3);
        assert_eq!(detector.config().min_observations, 50);
    }

    #[test]
    fn test_hmm_conservative_config() {
        let detector = HMMRegimeDetector::conservative();
        assert_eq!(detector.config().n_states, 2);
        assert_eq!(detector.config().min_observations, 150);
        assert_eq!(detector.state_probabilities().len(), 2);
    }

    #[test]
    fn test_hmm_warmup() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // 49 prices yield 48 returns, still below the 50-return warm-up.
        for i in 0..49 {
            let price = 100.0 + (i as f64) * 0.01;
            let result = detector.update(price);
            assert_eq!(
                result.regime,
                MarketRegime::Uncertain,
                "Should be Uncertain during warmup at step {i}"
            );
        }

        assert!(!detector.is_ready());
    }

    #[test]
    fn test_hmm_becomes_ready() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        for i in 0..60 {
            let price = 100.0 + (i as f64) * 0.01;
            detector.update(price);
        }

        assert!(detector.is_ready(), "Should be ready after 60 observations");
    }

    #[test]
    fn test_bull_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..200 {
            // Steady +0.5% per step.
            price *= 1.005;
            let result = detector.update(price);
            if detector.is_ready() {
                assert_ne!(result.regime, MarketRegime::Uncertain);
            }
        }

        let final_result = detector.get_regime_confidence();
        assert!(
            matches!(
                final_result.regime,
                MarketRegime::Trending(TrendDirection::Bullish)
            ),
            "Expected Bullish trend, got: {:?}",
            final_result.regime
        );
    }

    #[test]
    fn test_volatile_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for i in 0..200 {
            // Alternating +5% / -5% moves.
            if i % 2 == 0 {
                price *= 1.05;
            } else {
                price *= 0.95;
            }
            detector.update(price);
        }

        let result = detector.get_regime_confidence();
        assert!(
            matches!(
                result.regime,
                MarketRegime::Volatile | MarketRegime::MeanReverting
            ),
            "Expected Volatile or MeanReverting for choppy data, got: {:?}",
            result.regime
        );
    }

    #[test]
    fn test_state_probabilities_sum_to_one() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.001;
            detector.update(price);

            let probs = detector.state_probabilities();
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "State probabilities should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_transition_matrix_rows_sum_to_one() {
        let detector = HMMRegimeDetector::default_config();
        let tm = detector.transition_matrix();

        for (i, row) in tm.iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Transition matrix row {i} should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_expected_regime_duration() {
        let detector = HMMRegimeDetector::default_config();

        // Self-transition probability 0.9 gives an expected dwell of 1 / 0.1.
        let duration = detector.expected_regime_duration(0);
        assert!(
            (duration - 10.0).abs() < 1e-6,
            "Expected duration should be ~10 with 0.9 persistence, got: {duration}"
        );
    }

    #[test]
    fn test_predict_next_state() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.002;
            detector.update(price);
        }

        let (next_state, prob) = detector.predict_next_state();
        assert!(next_state < detector.config().n_states);
        assert!(
            (0.0..=1.0).contains(&prob),
            "Predicted probability should be in [0, 1]: {prob}"
        );
    }

    #[test]
    fn test_state_parameters() {
        let detector = HMMRegimeDetector::default_config();
        let params = detector.state_parameters();

        assert_eq!(params.len(), 3, "Should have 3 state parameters");

        for (mean, variance) in &params {
            assert!(variance > &0.0, "Variance should be positive: {variance}");
            assert!(mean.is_finite(), "Mean should be finite: {mean}");
        }
    }

    #[test]
    fn test_update_ohlc_uses_close() {
        let mut det1 = HMMRegimeDetector::crypto_optimized();
        let mut det2 = HMMRegimeDetector::crypto_optimized();

        for i in 0..100 {
            // High/low differ from close; only the close should matter.
            let close = 100.0 + i as f64 * 0.1;
            let r1 = det1.update(close);
            let r2 = det2.update_ohlc(close * 1.01, close * 0.99, close);

            assert_eq!(
                r1.regime, r2.regime,
                "update and update_ohlc should produce same regime"
            );
        }
    }

    #[test]
    fn test_n_observations_tracking() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        assert_eq!(detector.n_observations(), 0);

        for i in 0..50 {
            detector.update(100.0 + i as f64);
        }

        // The first price only seeds the buffer, so 50 prices give 49 returns.
        assert_eq!(detector.n_observations(), 49);
    }

    #[test]
    fn test_confidence_range() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.002;
            detector.update(price);
        }

        let confidence = detector.current_confidence();
        assert!(
            (0.0..=1.0).contains(&confidence),
            "Confidence should be in [0, 1]: {confidence}"
        );
    }
}