use crate::traits::Next;
use crate::regimes::MarketRegime;
use serde::{Deserialize, Serialize};
use statrs::distribution::{Discrete, Poisson};
/// Distribution of segment lengths (counted in observations) for one hidden
/// state of an [`HSMM`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DurationDistribution {
    /// Poisson-distributed segment length with rate/mean `lambda`.
    Poisson { lambda: f64 },
    /// Segment length that is always exactly `duration`.
    Fixed { duration: usize },
}
impl DurationDistribution {
    /// Probability mass `P(D = d)`: the probability that a segment lasts exactly
    /// `d` observations under this distribution.
    ///
    /// # Panics
    ///
    /// Panics if a `Poisson` variant holds a `lambda` rejected by
    /// `statrs::distribution::Poisson::new` (non-positive or NaN), e.g. a bad
    /// value deserialized from a configuration file.
    pub fn p(&self, d: usize) -> f64 {
        match self {
            Self::Poisson { lambda } => Poisson::new(*lambda)
                .expect("DurationDistribution::Poisson requires a finite lambda > 0")
                .pmf(d as u64),
            Self::Fixed { duration } => {
                if d == *duration {
                    1.0
                } else {
                    0.0
                }
            }
        }
    }
}
/// Hidden semi-Markov model that assigns each incoming scalar observation to a
/// hidden state and reports it as a [`MarketRegime`].
///
/// Each state has a Gaussian emission density (`means`, `stds`) and a segment
/// length distribution (`durations`); `a` weights jumps between states.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HSMM {
    /// Number of hidden states (`new` sets it to the number of rows of `a`).
    pub n_states: usize,
    /// Transition weights: `a[i][j]` scores a jump from state `i` to state `j`.
    pub a: Vec<Vec<f64>>,
    /// Per-state mean of the Gaussian emission density.
    pub means: Vec<f64>,
    /// Per-state standard deviation of the Gaussian emission density.
    pub stds: Vec<f64>,
    /// Per-state segment length distribution.
    pub durations: Vec<DurationDistribution>,
    /// Counter of how long the model has stayed in `last_state`; reset to 1 on a switch.
    current_duration: usize,
    /// Index of the state the model currently occupies.
    last_state: usize,
    /// Set once the first observation has been consumed.
    initialized: bool,
}
impl HSMM {
    /// Builds a model from a transition matrix, per-state Gaussian emission
    /// parameters and per-state segment length distributions.
    ///
    /// `n_states` is taken from the number of rows of `a`. The model starts
    /// uninitialized in state 0.
    ///
    /// # Panics
    ///
    /// Panics if there are no states, if `a` is not `n_states x n_states`, if
    /// `means`, `stds` or `durations` do not have exactly one entry per state,
    /// or if any standard deviation is zero or non-finite (the Gaussian density
    /// would be NaN, which silently freezes the decoder in `next`).
    pub fn new(
        a: Vec<Vec<f64>>,
        means: Vec<f64>,
        stds: Vec<f64>,
        durations: Vec<DurationDistribution>,
    ) -> Self {
        let n_states = a.len();
        assert!(n_states > 0, "HSMM::new: the model needs at least one state");
        assert!(
            a.iter().all(|row| row.len() == n_states),
            "HSMM::new: transition matrix must be {}x{}",
            n_states,
            n_states
        );
        assert_eq!(means.len(), n_states, "HSMM::new: need one mean per state");
        assert_eq!(stds.len(), n_states, "HSMM::new: need one std per state");
        assert_eq!(
            durations.len(),
            n_states,
            "HSMM::new: need one duration distribution per state"
        );
        assert!(
            stds.iter().all(|s| s.is_finite() && *s != 0.0),
            "HSMM::new: standard deviations must be finite and non-zero"
        );

        Self {
            n_states,
            a,
            means,
            stds,
            durations,
            current_duration: 0,
            last_state: 0,
            initialized: false,
        }
    }

    /// Density of the normal distribution `N(mu, sigma^2)` evaluated at `x`.
    fn gaussian_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
        let variance = sigma * sigma;
        let exponent = -(x - mu).powi(2) / (2.0 * variance);
        exponent.exp() / (2.0 * std::f64::consts::PI * variance).sqrt()
    }
}
impl HSMM {
    /// Probability that the current segment of `state`, which has already
    /// lasted `elapsed` observations, continues for at least one more:
    /// `P(D > elapsed | D >= elapsed)` with `D` drawn from the state's duration
    /// distribution, i.e. one minus the duration hazard.
    ///
    /// The raw pmf `P(D = elapsed)` is the probability that the segment ends
    /// exactly now, so it must not be used as a "stay" probability.
    ///
    /// Cost is O(elapsed) pmf evaluations, cut short once almost all of the
    /// probability mass lies before `elapsed`.
    fn continuation_probability(&self, state: usize, elapsed: usize) -> f64 {
        // Beyond this point 1 - P(D < elapsed) is dominated by rounding error;
        // treat the segment as certainly over.
        const TAIL_EPS: f64 = 1e-12;

        let dist = &self.durations[state];
        let mut mass_before = 0.0; // P(D < elapsed)
        for k in 0..elapsed {
            mass_before += dist.p(k);
            if mass_before >= 1.0 - TAIL_EPS {
                return 0.0;
            }
        }
        let at_least = 1.0 - mass_before; // P(D >= elapsed)
        let ends_now = dist.p(elapsed); // P(D == elapsed)
        ((at_least - ends_now) / at_least).clamp(0.0, 1.0)
    }

    /// Maps a state index to the regime reported to callers.
    fn regime_for(state: usize) -> MarketRegime {
        match state {
            0 => MarketRegime::Steady,
            1 => MarketRegime::Crisis,
            // Indices above 255 cannot be represented; saturate instead of wrapping.
            _ => MarketRegime::Cluster(state.min(u8::MAX as usize) as u8),
        }
    }
}

impl Next<f64> for HSMM {
    type Output = MarketRegime;

    /// Consumes one observation and returns the regime the model is in after it.
    ///
    /// Decoding is greedy and online. With current state `s` that has lasted
    /// `d` observations and `c = P(D_s > d | D_s >= d)`, the scores are:
    /// * stay in `s`: `c * N(x; mu_s, sigma_s)`
    /// * jump to `j != s`: `(1 - c) * a[s][j] * N(x; mu_j, sigma_j)`
    ///
    /// The highest score wins; ties keep the current state. The first
    /// observation is assigned to the starting state without scoring.
    fn next(&mut self, x: f64) -> Self::Output {
        if !self.initialized {
            self.initialized = true;
            // The first observation opens a segment of length 1.
            self.current_duration = 1;
            return Self::regime_for(self.last_state);
        }

        let current = self.last_state;
        // `max(1)` tolerates state saved before the first observation was counted.
        let elapsed = self.current_duration.max(1);
        let p_continue = self.continuation_probability(current, elapsed);
        let p_leave = 1.0 - p_continue;

        let mut best_state = current;
        let mut best_score =
            p_continue * Self::gaussian_pdf(x, self.means[current], self.stds[current]);

        for j in (0..self.n_states).filter(|&j| j != current) {
            // A new segment starts at length 1; its duration is accounted for by
            // the hazard on later steps, so no duration factor is applied here.
            let score =
                p_leave * self.a[current][j] * Self::gaussian_pdf(x, self.means[j], self.stds[j]);
            if score > best_score {
                best_score = score;
                best_state = j;
            }
        }

        if best_state == current {
            self.current_duration = elapsed + 1;
        } else {
            self.last_state = best_state;
            self.current_duration = 1;
        }
        Self::regime_for(best_state)
    }
}