use crate::traits::Next;
use crate::regimes::MarketRegime;
use serde::{Deserialize, Serialize};
/// Gaussian mixture model with diagonal covariances: each component's density is the
/// product of independent per-dimension normal densities.
///
/// Deserialization fills the fields directly, so `k` and `dims` are stored rather than
/// recomputed, and they are not checked against the lengths of `means`, `vars` and `weights`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GMM {
    /// Number of mixture components (`means.len()` when built via `new`).
    k: usize,
    /// Dimensionality of each observation (`means[0].len()` when built via `new`).
    dims: usize,
    /// Per-component mean vectors, indexed `[component][dimension]`.
    means: Vec<Vec<f64>>,
    /// Per-component variances, indexed `[component][dimension]`; floored at `1e-9` when evaluated.
    vars: Vec<Vec<f64>>,
    /// Mixing weight of each component (presumably non-negative and summing to 1 — not enforced).
    weights: Vec<f64>,
}
impl GMM {
    /// Lower bound applied to every variance so that a degenerate dimension cannot cause
    /// a division by zero or a logarithm of a non-positive number.
    const MIN_VAR: f64 = 1e-9;

    /// Relative change in total log-likelihood below which [`GMM::fit`] stops early.
    const LL_TOL: f64 = 1e-10;

    /// Builds a mixture from per-component means, variances and mixing weights.
    ///
    /// `means[j]` and `vars[j]` are the per-dimension means and variances of component `j`,
    /// and `weights[j]` is its mixing weight. Weights are stored as given (not normalised).
    ///
    /// # Panics
    ///
    /// Panics if `means` is empty, if `vars` or `weights` do not have exactly one entry per
    /// component, or if any mean or variance vector differs in length from `means[0]`.
    pub fn new(means: Vec<Vec<f64>>, vars: Vec<Vec<f64>>, weights: Vec<f64>) -> Self {
        assert!(!means.is_empty(), "GMM::new: at least one component is required");
        let k = means.len();
        let dims = means[0].len();
        assert!(
            vars.len() == k && weights.len() == k,
            "GMM::new: means, vars and weights must have one entry per component"
        );
        assert!(
            means.iter().chain(&vars).all(|row| row.len() == dims),
            "GMM::new: all mean and variance vectors must have the same length"
        );
        Self { k, dims, means, vars, weights }
    }

    /// Log-density of `x` under component `k_idx` alone (diagonal Gaussian).
    ///
    /// Summing per-dimension log-densities instead of multiplying densities keeps the
    /// result finite for high-dimensional or outlying inputs. For those inputs the product
    /// of densities underflows to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `x` has fewer than `self.dims` elements.
    fn log_pdf(&self, x: &[f64], k_idx: usize) -> f64 {
        let ln_two_pi = (2.0 * std::f64::consts::PI).ln();
        // Slicing up front keeps the original "panic on short input" contract explicit,
        // instead of letting `zip` silently truncate.
        let x = &x[..self.dims];
        self.means[k_idx]
            .iter()
            .zip(&self.vars[k_idx])
            .zip(x)
            .map(|((&mu, &var), &xi)| {
                let var = var.max(Self::MIN_VAR);
                // ln N(xi | mu, var) = -1/2 * (ln(2*pi) + ln(var) + (xi - mu)^2 / var)
                -0.5 * (ln_two_pi + var.ln() + (xi - mu).powi(2) / var)
            })
            .sum()
    }

    /// Log of `weights[k_idx] * N(x | component k_idx)`; `-inf` when the weight is zero.
    fn log_weighted_density(&self, x: &[f64], k_idx: usize) -> f64 {
        self.weights[k_idx].ln() + self.log_pdf(x, k_idx)
    }

    /// Numerically stable `ln(sum(exp(v)))`.
    ///
    /// Returns `-inf` for an empty slice. Returns a non-finite value when every entry is
    /// `-inf`/NaN or when any entry is NaN.
    fn log_sum_exp(values: &[f64]) -> f64 {
        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        if !max.is_finite() {
            return max;
        }
        max + values.iter().map(|&v| (v - max).exp()).sum::<f64>().ln()
    }

    /// Refines the parameters in place with expectation–maximisation (EM), starting from
    /// the current means, variances and weights.
    ///
    /// Runs at most `max_iter` EM iterations. It stops early once the change in total
    /// log-likelihood becomes negligible. Component order is preserved, so component `j`
    /// after fitting is the refinement of component `j` before it.
    ///
    /// Observations with zero or undefined likelihood under every component (for example,
    /// ones containing NaN) are ignored. Empty `data` or `max_iter == 0` leaves the model
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics if an observation has fewer than `dims` elements.
    pub fn fit(&mut self, data: &[Vec<f64>], max_iter: usize) {
        let mut prev_log_lik = f64::NEG_INFINITY;
        for _ in 0..max_iter {
            let (resp, log_lik) = self.e_step(data);
            if resp.is_empty() {
                // Nothing usable to learn from.
                return;
            }
            if (log_lik - prev_log_lik).abs() <= Self::LL_TOL * log_lik.abs().max(1.0) {
                break;
            }
            prev_log_lik = log_lik;
            self.m_step(&resp);
        }
    }

    /// E-step. For every usable observation, returns the posterior probability of each
    /// component (responsibilities, summing to 1). Also returns the total log-likelihood
    /// of those observations.
    fn e_step<'a>(&self, data: &'a [Vec<f64>]) -> (Vec<(&'a [f64], Vec<f64>)>, f64) {
        let mut resp = Vec::with_capacity(data.len());
        let mut log_lik = 0.0;
        for x in data {
            let mut r: Vec<f64> = (0..self.k).map(|j| self.log_weighted_density(x, j)).collect();
            let log_norm = Self::log_sum_exp(&r);
            if !log_norm.is_finite() {
                // Zero (or NaN) likelihood under every component: the point carries no
                // information and would turn the responsibilities into NaN.
                continue;
            }
            for v in &mut r {
                *v = (*v - log_norm).exp();
            }
            log_lik += log_norm;
            resp.push((x.as_slice(), r));
        }
        (resp, log_lik)
    }

    /// M-step: re-estimates weights, means and variances from the responsibilities.
    fn m_step(&mut self, resp: &[(&[f64], Vec<f64>)]) {
        let total = resp.len() as f64;
        for j in 0..self.k {
            let nk: f64 = resp.iter().map(|(_, r)| r[j]).sum();
            self.weights[j] = nk / total;
            if nk <= f64::EPSILON {
                // No observation claims this component. Keep its mean/variance rather than
                // dividing by (almost) zero; its weight is now ~0, so it effectively drops out.
                continue;
            }
            for d in 0..self.dims {
                let mean = resp.iter().map(|(x, r)| r[j] * x[d]).sum::<f64>() / nk;
                let var = resp
                    .iter()
                    .map(|(x, r)| r[j] * (x[d] - mean).powi(2))
                    .sum::<f64>()
                    / nk;
                self.means[j][d] = mean;
                self.vars[j][d] = var.max(Self::MIN_VAR);
            }
        }
    }
}

impl Next<&[f64]> for GMM {
    type Output = MarketRegime;

    /// Classifies `x` into the regime of its most probable component.
    ///
    /// Picks the component that maximises `ln(weight) + ln N(x | component)`. Ties go to the
    /// lowest index and NaN scores never win. If every score is `-inf` or NaN, the result is
    /// component 0.
    ///
    /// Mapping to regimes:
    /// - component 0 → [`MarketRegime::Steady`];
    /// - the last component (when there are at least two) → [`MarketRegime::Crisis`];
    /// - any other index `j` → `MarketRegime::Cluster(j)`.
    ///
    /// NOTE(review): this mapping assumes components are ordered from calmest (index 0) to
    /// most stressed (last), in the order they were supplied. Confirm this with the model's
    /// producer.
    ///
    /// # Panics
    ///
    /// Panics if `x` has fewer than `dims` elements.
    fn next(&mut self, x: &[f64]) -> Self::Output {
        // Scoring in log space avoids the underflow-to-zero that made every distant or
        // high-dimensional point fall back to component 0.
        let (best_k, _) = (0..self.k)
            .map(|j| (j, self.log_weighted_density(x, j)))
            .fold((0, f64::NEG_INFINITY), |best, cand| {
                if cand.1 > best.1 { cand } else { best }
            });
        match best_k {
            0 => MarketRegime::Steady,
            j if j + 1 == self.k => MarketRegime::Crisis,
            // Clamp instead of truncating: with more than 256 components a plain `as u8`
            // would wrap around and alias unrelated low-numbered clusters.
            j => MarketRegime::Cluster(j.min(usize::from(u8::MAX)) as u8),
        }
    }
}