quantwave_core/regimes/
hmm.rs1use crate::traits::Next;
11use crate::regimes::MarketRegime;
12use serde::{Deserialize, Serialize};
13
/// Hidden Markov model with one univariate Gaussian emission per state,
/// decoded one observation at a time through the `Next<f64>` implementation.
///
/// All per-state vectors are indexed by state number `0..n_states`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HMM {
    /// Number of hidden states; `HMM::new` sets it to `a.len()`.
    n_states: usize,
    /// Transition matrix: `a[i][j]` weights a move from state `i` to state `j`
    /// (its natural log is added to the score of state `i`).
    a: Vec<Vec<f64>>,
    /// Mean of the Gaussian emission of each state.
    means: Vec<f64>,
    /// Standard deviation of the Gaussian emission of each state.
    stds: Vec<f64>,
    /// Initial state weights, used only for the first observation.
    pi: Vec<f64>,
    /// Per-state log-domain scores carried over from the previous
    /// observation; all zeros until the first observation is processed.
    last_delta: Vec<f64>,
    /// `false` until the first observation has been processed.
    initialized: bool,
}
30
31impl HMM {
32 pub fn new(
34 a: Vec<Vec<f64>>,
35 means: Vec<f64>,
36 stds: Vec<f64>,
37 pi: Vec<f64>,
38 ) -> Self {
39 let n_states = a.len();
40 Self {
41 n_states,
42 a,
43 means,
44 stds,
45 pi,
46 last_delta: vec![0.0; n_states],
47 initialized: false,
48 }
49 }
50
51 pub fn bull_bear() -> Self {
53 Self::new(
54 vec![
55 vec![0.95, 0.05], vec![0.10, 0.90], ],
58 vec![0.001, -0.002], vec![0.01, 0.02], vec![0.5, 0.5],
61 )
62 }
63
64 fn gaussian_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
66 let variance = sigma * sigma;
67 let denom = (2.0 * std::f64::consts::PI * variance).sqrt();
68 let exponent = -((x - mu).powi(2)) / (2.0 * variance);
69 exponent.exp() / denom
70 }
71}
72
73impl Next<f64> for HMM {
74 type Output = MarketRegime;
75
76 fn next(&mut self, x: f64) -> Self::Output {
77 let mut next_delta = vec![0.0; self.n_states];
78 let mut best_state = 0;
79 let mut max_prob = -f64::INFINITY;
80
81 if !self.initialized {
82 for i in 0..self.n_states {
83 let emission = Self::gaussian_pdf(x, self.means[i], self.stds[i]);
84 next_delta[i] = (self.pi[i] * emission).ln();
85 if next_delta[i] > max_prob {
86 max_prob = next_delta[i];
87 best_state = i;
88 }
89 }
90 self.initialized = true;
91 } else {
92 for j in 0..self.n_states {
93 let mut max_prev = -f64::INFINITY;
94 for i in 0..self.n_states {
95 let prob = self.last_delta[i] + self.a[i][j].ln();
97 if prob > max_prev {
98 max_prev = prob;
99 }
100 }
101 let emission = Self::gaussian_pdf(x, self.means[j], self.stds[j]);
102 next_delta[j] = max_prev + emission.ln();
103
104 if next_delta[j] > max_prob {
105 max_prob = next_delta[j];
106 best_state = j;
107 }
108 }
109 }
110
111 self.last_delta = next_delta;
112
113 match best_state {
114 0 => MarketRegime::Bull,
115 1 => MarketRegime::Bear,
116 _ => MarketRegime::Cluster(best_state as u8),
117 }
118 }
119}