use std::collections::VecDeque;
use crate::indicators::volatility::ATR;
use crate::traits::Next;
use crate::regimes::MarketRegime;
use serde::{Deserialize, Serialize};
/// Streaming one-dimensional clusterer over ATR values.
///
/// Each input is passed through the inner `ATR` indicator; the resulting value
/// is stored in a bounded window and assigned to the nearest of `k` running
/// centroids. The rank of the assigned centroid is what the `Next`
/// implementation turns into a `MarketRegime`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VolatilityClusterer {
/// Indicator whose output is the value being clustered.
atr: ATR,
/// Most recent ATR values, oldest at the front.
window: VecDeque<f64>,
/// Capacity limit applied to `window` before each push.
window_size: usize,
/// Number of clusters; `centroids` and `counts` are created with this length.
k: usize,
/// Current centroid value of each cluster (all `0.0` on construction).
centroids: Vec<f64>,
/// Per-cluster sample counts; the centroid update step uses `1 / sqrt(count)`.
counts: Vec<usize>,
/// Warm-up threshold, initialised to `max(window_size, k * 10)`; the window
/// length is checked against it before a non-`Steady` regime is reported.
min_observations: usize,
}
impl VolatilityClusterer {
    /// Creates a clusterer with `k` centroids over an ATR of `atr_period`.
    ///
    /// * `atr_period` - period handed to the inner `ATR` indicator.
    /// * `window_size` - maximum number of recent ATR values retained.
    /// * `k` - number of clusters. `k == 0` is tolerated: the clusterer then
    ///   never learns and always reports cluster index `0`.
    ///
    /// The warm-up threshold is `max(window_size, k * 10)`.
    pub fn new(atr_period: usize, window_size: usize, k: usize) -> Self {
        Self {
            atr: ATR::new(atr_period),
            window: VecDeque::with_capacity(window_size),
            window_size,
            k,
            centroids: vec![0.0; k],
            counts: vec![0; k],
            min_observations: window_size.max(k.saturating_mul(10)),
        }
    }

    /// Seeds the centroids from evenly spaced quantiles of the finite values
    /// currently in the window and resets every count to `1`.
    ///
    /// Returns `false` (leaving all state untouched) while there is not enough
    /// data: fewer than `k` finite samples in a window that is not yet full,
    /// or no finite sample at all.
    fn try_seed_centroids(&mut self) -> bool {
        let window_full = self.window.len() >= self.window_size;
        if self.window.len() < self.k && !window_full {
            return false;
        }

        let mut sorted: Vec<f64> = self
            .window
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .collect();
        if sorted.is_empty() || (sorted.len() < self.k && !window_full) {
            return false;
        }
        sorted.sort_unstable_by(f64::total_cmp);

        let last = sorted.len() - 1;
        // With a single cluster there is no spread to interpolate over; a
        // divisor of 1 avoids the division by zero and selects the minimum.
        let denom = self.k.saturating_sub(1).max(1);
        for (i, (centroid, count)) in self
            .centroids
            .iter_mut()
            .zip(self.counts.iter_mut())
            .enumerate()
        {
            *centroid = sorted[(i * last / denom).min(last)];
            *count = 1;
        }
        true
    }

    /// Assigns `val` to its nearest centroid, moves that centroid towards
    /// `val`, and returns the centroid's index.
    ///
    /// The step size is `1 / sqrt(n)`, where `n` is the number of samples the
    /// centroid has absorbed, so centroids settle as evidence accumulates.
    ///
    /// Returns `0` without touching any state when there are no centroids,
    /// when `val` is not finite, or while the centroids cannot be seeded yet.
    fn update_clusters(&mut self, val: f64) -> usize {
        if self.centroids.is_empty() || !val.is_finite() {
            return 0;
        }

        // No online update may run before seeding, otherwise the first sample
        // would be written into centroid 0 and the quantile seeding would never
        // trigger. All-zero centroids carry no information (e.g. a flat warm-up
        // stretch), so seeding is retried until the window shows a non-zero value.
        let seeded = self.counts.iter().any(|&n| n > 0)
            && self.centroids.iter().any(|&c| c != 0.0);
        if !seeded && !self.try_seed_centroids() {
            return 0;
        }

        // `min_by` returns the first of several equal minima, i.e. ties go to
        // the lowest index; `total_cmp` keeps a NaN centroid from panicking.
        let closest_idx = self
            .centroids
            .iter()
            .enumerate()
            .map(|(i, &centroid)| (i, (val - centroid).abs()))
            .min_by(|a, b| a.1.total_cmp(&b.1))
            .map_or(0, |(i, _)| i);

        self.counts[closest_idx] += 1;
        let lr = 1.0 / (self.counts[closest_idx] as f64).sqrt();
        self.centroids[closest_idx] += lr * (val - self.centroids[closest_idx]);
        closest_idx
    }
}
impl Next<(f64, f64, f64)> for VolatilityClusterer {
    type Output = MarketRegime;

    /// Consumes one bar and classifies the current volatility regime.
    ///
    /// `input` is forwarded unchanged to the inner ATR indicator (presumably
    /// `(high, low, close)` - confirm against `ATR`). The resulting ATR value
    /// is appended to the rolling window, assigned to a cluster, and the
    /// cluster's rank among all centroids (ascending) selects the regime:
    ///
    /// * lowest centroid -> `MarketRegime::Steady`
    /// * highest centroid -> `MarketRegime::Crisis`
    /// * anything between -> `MarketRegime::Cluster(rank)`, saturating at
    ///   `u8::MAX`
    ///
    /// `MarketRegime::Steady` is also returned while the window is warming up.
    fn next(&mut self, input: (f64, f64, f64)) -> Self::Output {
        let atr_val = self.atr.next(input);

        if self.window.len() >= self.window_size {
            self.window.pop_front();
        }
        self.window.push_back(atr_val);

        let cluster_idx = self.update_clusters(atr_val);

        // The window never holds more than `window_size` samples, so the
        // warm-up threshold is capped there; otherwise a threshold above the
        // capacity would keep this method returning `Steady` forever.
        let warm_up = self.min_observations.min(self.window_size);
        if self.window.len() < warm_up {
            return MarketRegime::Steady;
        }

        // Rank = number of centroids ordered before the assigned one, with
        // ties broken by index (what a stable sort would yield). `total_cmp`
        // keeps a NaN centroid from panicking.
        let rank = self.centroids.get(cluster_idx).map_or(0, |&own| {
            self.centroids
                .iter()
                .enumerate()
                .filter(|&(j, c)| {
                    c.total_cmp(&own)
                        .then_with(|| j.cmp(&cluster_idx))
                        .is_lt()
                })
                .count()
        });

        match rank {
            0 => MarketRegime::Steady,
            r if r + 1 == self.k => MarketRegime::Crisis,
            // Saturate instead of wrapping for ranks beyond `u8::MAX`.
            r => MarketRegime::Cluster(r.min(usize::from(u8::MAX)) as u8),
        }
    }
}