quantwave_core/regimes/
hmm.rs1use crate::traits::Next;
11use crate::regimes::MarketRegime;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct HMM {
17 n_states: usize,
18 a: Vec<Vec<f64>>,
20 means: Vec<f64>,
22 stds: Vec<f64>,
24 pi: Vec<f64>,
26 last_delta: Vec<f64>,
28 initialized: bool,
29}
30
31impl HMM {
32 pub fn new(
34 a: Vec<Vec<f64>>,
35 means: Vec<f64>,
36 stds: Vec<f64>,
37 pi: Vec<f64>,
38 ) -> Self {
39 let n_states = a.len();
40 Self {
41 n_states,
42 a,
43 means,
44 stds,
45 pi,
46 last_delta: vec![0.0; n_states],
47 initialized: false,
48 }
49 }
50
51 pub fn bull_bear() -> Self {
53 Self::new(
54 vec![
55 vec![0.95, 0.05], vec![0.10, 0.90], ],
58 vec![0.001, -0.002], vec![0.01, 0.02], vec![0.5, 0.5],
61 )
62 }
63
64 fn gaussian_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
66 let variance = sigma * sigma;
67 let denom = (2.0 * std::f64::consts::PI * variance).sqrt();
68 let exponent = -((x - mu).powi(2)) / (2.0 * variance);
69 exponent.exp() / denom
70 }
71
72 pub fn state_probabilities(&self) -> Vec<f64> {
74 if !self.initialized {
75 return self.pi.clone();
76 }
77 let max_log = self
78 .last_delta
79 .iter()
80 .cloned()
81 .fold(f64::NEG_INFINITY, f64::max);
82 let mut probs: Vec<f64> = self
83 .last_delta
84 .iter()
85 .map(|&d| (d - max_log).exp())
86 .collect();
87 let sum: f64 = probs.iter().sum();
88 if sum > 0.0 {
89 for p in &mut probs {
90 *p /= sum;
91 }
92 }
93 probs
94 }
95}
96
97impl Next<f64> for HMM {
98 type Output = MarketRegime;
99
100 fn next(&mut self, x: f64) -> Self::Output {
101 let mut next_delta = vec![0.0; self.n_states];
102 let mut best_state = 0;
103 let mut max_prob = -f64::INFINITY;
104
105 if !self.initialized {
106 for i in 0..self.n_states {
107 let emission = Self::gaussian_pdf(x, self.means[i], self.stds[i]);
108 next_delta[i] = (self.pi[i] * emission).ln();
109 if next_delta[i] > max_prob {
110 max_prob = next_delta[i];
111 best_state = i;
112 }
113 }
114 self.initialized = true;
115 } else {
116 for j in 0..self.n_states {
117 let mut max_prev = -f64::INFINITY;
118 for i in 0..self.n_states {
119 let prob = self.last_delta[i] + self.a[i][j].ln();
121 if prob > max_prev {
122 max_prev = prob;
123 }
124 }
125 let emission = Self::gaussian_pdf(x, self.means[j], self.stds[j]);
126 next_delta[j] = max_prev + emission.ln();
127
128 if next_delta[j] > max_prob {
129 max_prob = next_delta[j];
130 best_state = j;
131 }
132 }
133 }
134
135 self.last_delta = next_delta;
136
137 match best_state {
138 0 => MarketRegime::Bull,
139 1 => MarketRegime::Bear,
140 _ => MarketRegime::Cluster(best_state as u8),
141 }
142 }
143}