quantwave_core/regimes/
ms_garch.rs1use crate::traits::Next;
9use crate::regimes::MarketRegime;
10
11#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
13pub struct GarchParams {
14 pub omega: f64,
15 pub alpha: f64,
16 pub beta: f64,
17}
18
19impl GarchParams {
20 pub fn new(omega: f64, alpha: f64, beta: f64) -> Self {
21 Self { omega, alpha, beta }
22 }
23}
24
25#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
27pub struct MSGarch {
28 pub n_states: usize,
29 pub a: Vec<Vec<f64>>,
31 pub params: Vec<GarchParams>,
33 last_variances: Vec<f64>,
35 last_probs: Vec<f64>,
37 initialized: bool,
38}
39
40impl MSGarch {
41 pub fn new(a: Vec<Vec<f64>>, params: Vec<GarchParams>, initial_probs: Vec<f64>) -> Self {
42 let n_states = a.len();
43 Self {
44 n_states,
45 a,
46 params,
47 last_variances: vec![0.0001; n_states], last_probs: initial_probs,
49 initialized: false,
50 }
51 }
52
53 pub fn low_high_vol() -> Self {
55 Self::new(
56 vec![
57 vec![0.98, 0.02], vec![0.05, 0.95], ],
60 vec![
61 GarchParams::new(1e-6, 0.05, 0.90), GarchParams::new(1e-4, 0.15, 0.80), ],
64 vec![0.9, 0.1],
65 )
66 }
67
68 fn gaussian_pdf(x: f64, sigma: f64) -> f64 {
70 let variance = sigma * sigma;
71 let denom = (2.0 * std::f64::consts::PI * variance).sqrt();
72 let exponent = -(x.powi(2)) / (2.0 * variance);
73 exponent.exp() / denom
74 }
75}
76
77impl Next<f64> for MSGarch {
78 type Output = (MarketRegime, f64); fn next(&mut self, returns: f64) -> Self::Output {
81 if !self.initialized {
82 self.initialized = true;
83 return (MarketRegime::Steady, self.last_variances[0].sqrt());
84 }
85
86 let mut next_probs = vec![0.0; self.n_states];
87 let mut likelihoods = vec![0.0; self.n_states];
88 let mut total_likelihood = 0.0;
89
90 for j in 0..self.n_states {
92 let mut prob_j = 0.0;
93 for i in 0..self.n_states {
94 prob_j += self.last_probs[i] * self.a[i][j];
95 }
96 let emission = Self::gaussian_pdf(returns, self.last_variances[j].sqrt());
97 likelihoods[j] = prob_j * emission;
98 total_likelihood += likelihoods[j];
99 }
100
101 if total_likelihood > 0.0 {
102 for j in 0..self.n_states {
103 next_probs[j] = likelihoods[j] / total_likelihood;
104 }
105 } else {
106 next_probs = self.last_probs.clone();
107 }
108
109 let epsilon_sq = returns.powi(2);
111 for j in 0..self.n_states {
112 let p = &self.params[j];
113 self.last_variances[j] = p.omega + p.alpha * epsilon_sq + p.beta * self.last_variances[j];
115 }
116
117 self.last_probs = next_probs;
118
119 let mut max_p = -1.0;
121 let mut best_state = 0;
122 let mut combined_var = 0.0;
123 for j in 0..self.n_states {
124 if self.last_probs[j] > max_p {
125 max_p = self.last_probs[j];
126 best_state = j;
127 }
128 combined_var += self.last_probs[j] * self.last_variances[j];
129 }
130
131 let regime = match best_state {
132 0 => MarketRegime::Steady,
133 1 => MarketRegime::Crisis,
134 _ => MarketRegime::Cluster(best_state as u8),
135 };
136
137 (regime, combined_var.sqrt())
138 }
139}