use crate::traits::Next;
use crate::regimes::MarketRegime;
/// Coefficients of one regime's GARCH(1,1) variance recursion.
///
/// [`MSGarch`] advances each regime's conditional variance after every
/// observation as `sigma2_next = omega + alpha * r^2 + beta * sigma2`, where
/// `r` is the latest return. That recursion only mean-reverts when
/// `alpha + beta < 1`; no such check is performed here.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct GarchParams {
    /// Constant term added to the variance on every step.
    pub omega: f64,
    /// Weight applied to the squared latest return.
    pub alpha: f64,
    /// Weight applied to the previous variance.
    pub beta: f64,
}
impl GarchParams {
pub fn new(omega: f64, alpha: f64, beta: f64) -> Self {
Self { omega, alpha, beta }
}
}
/// Markov-switching GARCH(1,1) filter.
///
/// Each regime runs its own GARCH(1,1) variance recursion. A Hamilton-style
/// filter tracks the probability of being in each regime, using a zero-mean
/// Gaussian likelihood of every observed return.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MSGarch {
    /// Number of regimes (the row count of `a`).
    pub n_states: usize,
    /// Transition matrix: `a[i][j]` is the probability of moving from regime
    /// `i` to regime `j` between two observations.
    pub a: Vec<Vec<f64>>,
    /// GARCH(1,1) coefficients, one entry per regime.
    pub params: Vec<GarchParams>,
    /// Current conditional variance of every regime.
    last_variances: Vec<f64>,
    /// Current filtered probability of every regime.
    last_probs: Vec<f64>,
    /// `false` until the first (warm-up) call to `next` has happened.
    initialized: bool,
}
impl MSGarch {
    /// Starting conditional variance that [`MSGarch::new`] gives every regime
    /// (a volatility of 0.01).
    pub const DEFAULT_INITIAL_VARIANCE: f64 = 0.0001;

    /// Builds a Markov-switching GARCH filter whose regimes all start at
    /// [`Self::DEFAULT_INITIAL_VARIANCE`].
    ///
    /// * `a` – regime transition matrix; `a[i][j]` is the probability of
    ///   moving from regime `i` to regime `j`. Its row count sets `n_states`.
    /// * `params` – one GARCH(1,1) parameter set per regime.
    /// * `initial_probs` – regime probabilities before the first observation.
    ///
    /// # Panics
    /// If `a` is empty or not square, or if `params` / `initial_probs` do not
    /// have exactly one entry per regime.
    pub fn new(a: Vec<Vec<f64>>, params: Vec<GarchParams>, initial_probs: Vec<f64>) -> Self {
        Self::with_initial_variance(a, params, initial_probs, Self::DEFAULT_INITIAL_VARIANCE)
    }

    /// Same as [`MSGarch::new`], but every regime starts at the caller-chosen
    /// `initial_variance` (e.g. the sample variance of a calibration window).
    ///
    /// # Panics
    /// In the same cases as [`MSGarch::new`], and when `initial_variance` is
    /// not finite and strictly positive (a zero variance makes the Gaussian
    /// likelihood undefined).
    pub fn with_initial_variance(
        a: Vec<Vec<f64>>,
        params: Vec<GarchParams>,
        initial_probs: Vec<f64>,
        initial_variance: f64,
    ) -> Self {
        let n_states = a.len();
        // Reject shape mistakes here, with a clear message, rather than as an
        // index-out-of-bounds panic (or silently ignored extra entries) deep
        // inside the filter update.
        assert!(n_states > 0, "MSGarch: at least one regime is required");
        assert!(
            a.iter().all(|row| row.len() == n_states),
            "MSGarch: transition matrix must be {}x{}",
            n_states,
            n_states
        );
        assert_eq!(
            params.len(),
            n_states,
            "MSGarch: need one GarchParams per regime"
        );
        assert_eq!(
            initial_probs.len(),
            n_states,
            "MSGarch: need one initial probability per regime"
        );
        assert!(
            initial_variance.is_finite() && initial_variance > 0.0,
            "MSGarch: initial variance must be finite and positive, got {}",
            initial_variance
        );
        Self {
            n_states,
            a,
            params,
            last_variances: vec![initial_variance; n_states],
            last_probs: initial_probs,
            initialized: false,
        }
    }

    /// Two-regime preset.
    ///
    /// Regime 0 has a small constant term (`omega = 1e-6`, `alpha = 0.05`,
    /// `beta = 0.90`). Regime 1 has a larger constant and reacts more strongly
    /// to shocks (`omega = 1e-4`, `alpha = 0.15`, `beta = 0.80`). The model
    /// stays in regime 0 with probability 0.98 and in regime 1 with
    /// probability 0.95, and starts with 90 % weight on regime 0.
    pub fn low_high_vol() -> Self {
        Self::new(
            vec![vec![0.98, 0.02], vec![0.05, 0.95]],
            vec![
                GarchParams::new(1e-6, 0.05, 0.90),
                GarchParams::new(1e-4, 0.15, 0.80),
            ],
            vec![0.9, 0.1],
        )
    }

    /// Density of a zero-mean normal distribution with standard deviation
    /// `sigma`, evaluated at `x`.
    ///
    /// The result is `NaN` when `sigma == 0`, and it underflows to `0.0` when
    /// `x` lies far out in the tail.
    fn gaussian_pdf(x: f64, sigma: f64) -> f64 {
        let variance = sigma * sigma;
        let denom = (2.0 * std::f64::consts::PI * variance).sqrt();
        let exponent = -(x.powi(2)) / (2.0 * variance);
        exponent.exp() / denom
    }
}
impl Next<f64> for MSGarch {
    /// Most probable regime after the update, plus the probability-weighted
    /// volatility `sqrt(sum_j p_j * sigma2_j)` over all regimes.
    type Output = (MarketRegime, f64);

    /// Feeds one return observation through the filter.
    ///
    /// The first call is a warm-up: it only marks the model as initialized and
    /// returns `(MarketRegime::Steady, sqrt(variance of regime 0))` without
    /// using `returns`. Every later call does three things:
    /// 1. It predicts regime probabilities through the transition matrix `a`.
    /// 2. It reweights them by the zero-mean Gaussian likelihood of `returns`
    ///    under each regime's current variance.
    /// 3. It advances every regime's GARCH(1,1) variance using `returns^2`.
    ///
    /// Non-finite `returns` (NaN or ±inf) are skipped: the state is left
    /// untouched and the current estimate is reported. Feeding them into the
    /// variance recursion would corrupt every regime permanently.
    fn next(&mut self, returns: f64) -> Self::Output {
        if !self.initialized {
            self.initialized = true;
            return (MarketRegime::Steady, self.last_variances[0].sqrt());
        }

        if returns.is_finite() {
            let n = self.n_states;

            // Prediction step: P(s_t = j) = sum_i P(s_{t-1} = i) * a[i][j].
            let priors: Vec<f64> = (0..n)
                .map(|j| {
                    (0..n)
                        .map(|i| self.last_probs[i] * self.a[i][j])
                        .sum::<f64>()
                })
                .collect();

            // Update step in ordinary probability space (the common case).
            let likelihoods: Vec<f64> = priors
                .iter()
                .enumerate()
                .map(|(j, &prior)| {
                    prior * Self::gaussian_pdf(returns, self.last_variances[j].sqrt())
                })
                .collect();
            let total: f64 = likelihoods.iter().sum();

            let posterior: Vec<f64> = if total > 0.0 {
                likelihoods.iter().map(|l| l / total).collect()
            } else {
                // Every density underflowed to 0 (a move many sigmas out in
                // all regimes) or is undefined. Redo the update in log space
                // so the regime that explains the move best still wins. Keep
                // the previous probabilities only if that also fails.
                log_space_posterior(&priors, &self.last_variances, returns)
                    .unwrap_or_else(|| self.last_probs.clone())
            };

            // GARCH(1,1) step for each regime, built from the pre-update variance.
            let eps_sq = returns.powi(2);
            for (var, p) in self.last_variances[..n]
                .iter_mut()
                .zip(&self.params[..n])
            {
                *var = p.omega + p.alpha * eps_sq + p.beta * *var;
            }
            self.last_probs = posterior;
        }

        summarize_regimes(&self.last_probs, &self.last_variances, self.n_states)
    }
}

/// Hamilton-filter update done in log space, for the case where the plain
/// densities underflow. Returns `None` when no regime has a finite log-weight.
fn log_space_posterior(priors: &[f64], variances: &[f64], x: f64) -> Option<Vec<f64>> {
    let log_weights: Vec<f64> = priors
        .iter()
        .zip(variances)
        .map(|(&prior, &var)| {
            prior.ln() - 0.5 * (2.0 * std::f64::consts::PI * var).ln() - x * x / (2.0 * var)
        })
        .collect();
    // Shift by the largest log-weight so the best regime contributes exp(0) = 1.
    let max = log_weights
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);
    if !max.is_finite() {
        return None;
    }
    let weights: Vec<f64> = log_weights.iter().map(|lw| (lw - max).exp()).collect();
    let total: f64 = weights.iter().sum();
    if total.is_finite() && total > 0.0 {
        Some(weights.iter().map(|w| w / total).collect())
    } else {
        None
    }
}

/// Picks the first most-probable regime and returns it with the square root
/// of the probability-weighted variance across the first `n_states` regimes.
fn summarize_regimes(probs: &[f64], variances: &[f64], n_states: usize) -> (MarketRegime, f64) {
    let mut best_p = -1.0;
    let mut best_state = 0;
    let mut combined_var = 0.0;
    for (j, (&p, &v)) in probs[..n_states]
        .iter()
        .zip(&variances[..n_states])
        .enumerate()
    {
        if p > best_p {
            best_p = p;
            best_state = j;
        }
        combined_var += p * v;
    }
    let regime = match best_state {
        0 => MarketRegime::Steady,
        1 => MarketRegime::Crisis,
        // Saturate instead of wrapping if there are ever more than 256 regimes.
        _ => MarketRegime::Cluster(best_state.min(u8::MAX as usize) as u8),
    };
    (regime, combined_var.sqrt())
}