// quantwave_core/regimes/hmm_gas.rs
use std::f64::consts::PI;

use serde::{Deserialize, Serialize};

use crate::regimes::MarketRegime;
use crate::traits::Next;

13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct HMMGAS {
16 pub p11_params: [f64; 3],
18 pub p22_params: [f64; 3],
20 f11: f64,
22 f22: f64,
23 pub means: [f64; 2],
24 pub stds: [f64; 2],
25 last_probs: [f64; 2],
26 initialized: bool,
27}
28
29impl HMMGAS {
30 pub fn new(
31 p11_params: [f64; 3],
32 p22_params: [f64; 3],
33 means: [f64; 2],
34 stds: [f64; 2],
35 ) -> Self {
36 Self {
37 p11_params,
38 p22_params,
39 f11: 2.0, f22: 2.0,
41 means,
42 stds,
43 last_probs: [0.5, 0.5],
44 initialized: false,
45 }
46 }
47
48 fn logit_inv(f: f64) -> f64 {
49 1.0 / (1.0 + (-f).exp())
50 }
51
52 fn gaussian_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
53 let variance = sigma * sigma;
54 let denom = (2.0 * std::f64::consts::PI * variance).sqrt();
55 let exponent = -((x - mu).powi(2)) / (2.0 * variance);
56 exponent.exp() / denom
57 }
58}
59
60impl Next<f64> for HMMGAS {
61 type Output = MarketRegime;
62
63 fn next(&mut self, x: f64) -> Self::Output {
64 let p11 = Self::logit_inv(self.f11);
66 let p22 = Self::logit_inv(self.f22);
67
68 let a = [[p11, 1.0 - p11], [1.0 - p22, p22]];
69
70 let mut likelihoods = [0.0; 2];
71 let mut total_likelihood = 0.0;
72
73 for j in 0..2 {
75 let mut prob_j = 0.0;
76 for i in 0..2 {
77 prob_j += self.last_probs[i] * a[i][j];
78 }
79 let emission = Self::gaussian_pdf(x, self.means[j], self.stds[j]);
80 likelihoods[j] = prob_j * emission;
81 total_likelihood += likelihoods[j];
82 }
83
84 let next_probs = if total_likelihood > 0.0 {
85 [likelihoods[0] / total_likelihood, likelihoods[1] / total_likelihood]
86 } else {
87 self.last_probs
88 };
89
90 let score11 = next_probs[0] - self.last_probs[0];
94 let score22 = next_probs[1] - self.last_probs[1];
95
96 self.f11 = self.p11_params[0] + self.p11_params[1] * score11 + self.p11_params[2] * self.f11;
97 self.f22 = self.p22_params[0] + self.p22_params[1] * score22 + self.p22_params[2] * self.f22;
98
99 self.last_probs = next_probs;
100 self.initialized = true;
101
102 if next_probs[0] > next_probs[1] {
103 MarketRegime::Steady
104 } else {
105 MarketRegime::Crisis
106 }
107 }
108}