quantwave_core/regimes/
hsmm.rs1use crate::traits::Next;
9use crate::regimes::MarketRegime;
10use serde::{Deserialize, Serialize};
11use statrs::distribution::{Discrete, Poisson};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum DurationDistribution {
16 Poisson { lambda: f64 },
18 Fixed { duration: usize },
20}
21
22impl DurationDistribution {
23 pub fn p(&self, d: usize) -> f64 {
24 match self {
25 Self::Poisson { lambda } => {
26 let dist = Poisson::new(*lambda).unwrap();
27 dist.pmf(d as u64)
28 }
29 Self::Fixed { duration } => {
30 if d == *duration { 1.0 } else { 0.0 }
31 }
32 }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct HSMM {
39 pub n_states: usize,
40 pub a: Vec<Vec<f64>>,
42 pub means: Vec<f64>,
43 pub stds: Vec<f64>,
44 pub durations: Vec<DurationDistribution>,
45 current_duration: usize,
47 last_state: usize,
48 initialized: bool,
49}
50
51impl HSMM {
52 pub fn new(
53 a: Vec<Vec<f64>>,
54 means: Vec<f64>,
55 stds: Vec<f64>,
56 durations: Vec<DurationDistribution>,
57 ) -> Self {
58 Self {
59 n_states: a.len(),
60 a,
61 means,
62 stds,
63 durations,
64 current_duration: 0,
65 last_state: 0,
66 initialized: false,
67 }
68 }
69
70 fn gaussian_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
71 let variance = sigma * sigma;
72 let denom = (2.0 * std::f64::consts::PI * variance).sqrt();
73 let exponent = -((x - mu).powi(2)) / (2.0 * variance);
74 exponent.exp() / denom
75 }
76}
77
78impl Next<f64> for HSMM {
79 type Output = MarketRegime;
80
81 fn next(&mut self, x: f64) -> Self::Output {
82 if !self.initialized {
83 self.initialized = true;
84 return MarketRegime::Steady;
86 }
87
88 self.current_duration += 1;
89
90 let prob_stay = self.durations[self.last_state].p(self.current_duration);
92
93 let mut max_prob;
94 let mut best_state = self.last_state;
95
96 let emission_stay = Self::gaussian_pdf(x, self.means[self.last_state], self.stds[self.last_state]);
98 max_prob = prob_stay * emission_stay;
99
100 for j in 0..self.n_states {
102 if j == self.last_state { continue; }
103
104 let transition_prob = self.a[self.last_state][j];
105 let emission_j = Self::gaussian_pdf(x, self.means[j], self.stds[j]);
106 let prob_j = (1.0 - prob_stay) * transition_prob * self.durations[j].p(1) * emission_j;
108
109 if prob_j > max_prob {
110 max_prob = prob_j;
111 best_state = j;
112 }
113 }
114
115 if best_state != self.last_state {
116 self.last_state = best_state;
117 self.current_duration = 1;
118 }
119
120 match best_state {
121 0 => MarketRegime::Steady,
122 1 => MarketRegime::Crisis,
123 _ => MarketRegime::Cluster(best_state as u8),
124 }
125 }
126}