use crate::traits::Next;
use crate::regimes::MarketRegime;
use crate::regimes::volatility_clustering::VolatilityClusterer;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
/// Regime detector that summarises several assets' returns into a single
/// feature tuple for an inner [`VolatilityClusterer`], producing a
/// [`MarketRegime`] per observation.
///
/// Each call to `next` appends one return per asset to a rolling per-asset
/// history; those histories drive the average pairwise correlation feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiAssetClusterer {
    /// Number of returns expected in every `next` call (one per asset).
    n_assets: usize,
    /// Maximum number of observations retained in each asset's history.
    window_size: usize,
    /// Clusterer fed with the cross-sectional summary statistics.
    inner: VolatilityClusterer,
    /// One rolling window of returns per asset, indexed by asset position.
    history: Vec<VecDeque<f64>>,
}
impl MultiAssetClusterer {
    /// Creates a clusterer for `n_assets` return streams, each kept in a rolling
    /// window of at most `window_size` observations.
    ///
    /// `window_size` and `k` are forwarded to [`VolatilityClusterer::new`], whose
    /// first argument is fixed at `14` here.
    // NOTE(review): the meaning of the hard-coded `14` is defined by
    // `VolatilityClusterer` (presumably a lookback period) — confirm before
    // exposing it as a parameter.
    pub fn new(n_assets: usize, window_size: usize, k: usize) -> Self {
        // Build each deque separately: `vec![deque; n]` clones the prototype, and
        // a cloned empty `VecDeque` does not keep the requested capacity. The
        // extra slot covers the transient `window_size + 1` length reached when a
        // new return is pushed before the oldest one is popped.
        let history = (0..n_assets)
            .map(|_| VecDeque::with_capacity(window_size.saturating_add(1)))
            .collect();
        Self {
            n_assets,
            window_size,
            inner: VolatilityClusterer::new(14, window_size, k),
            history,
        }
    }

    /// Mean Pearson correlation over all distinct asset pairs.
    ///
    /// Returns `1.0` when there are fewer than two assets (no pairs exist) or
    /// while the histories do not yet hold a full window.
    fn calculate_average_correlation(&self) -> f64 {
        // `n_assets < 2` is checked first so `history[0]` is never indexed on an
        // empty history (`n_assets == 0`). Histories grow in lock-step, so the
        // first one stands in for all of them.
        if self.n_assets < 2 || self.history[0].len() < self.window_size {
            return 1.0;
        }
        let n = self.n_assets;
        let total: f64 = (0..n)
            .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
            .map(|(i, j)| self.correlation(i, j))
            .sum();
        let pairs = n * (n - 1) / 2;
        total / pairs as f64
    }

    /// Pearson correlation coefficient between the histories of assets `i` and `j`.
    ///
    /// Returns `1.0` when the coefficient is undefined because a series has zero
    /// variance (e.g. a flat series) or the histories are empty. The result is
    /// clamped to `[-1.0, 1.0]` to absorb floating-point rounding.
    ///
    /// # Panics
    /// Panics if `i` or `j` is not a valid asset index.
    fn correlation(&self, i: usize, j: usize) -> f64 {
        let (x, y) = (&self.history[i], &self.history[j]);
        // Histories are filled in lock-step, so lengths normally match; only the
        // common prefix is used so a mismatch cannot cause an index panic.
        let len = x.len().min(y.len());
        if len == 0 {
            return 1.0;
        }
        let n = len as f64;
        let mean_x = x.iter().take(len).sum::<f64>() / n;
        let mean_y = y.iter().take(len).sum::<f64>() / n;
        let (cov, var_x, var_y) = x.iter().zip(y.iter()).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(cov, vx, vy), (&a, &b)| {
                let (dx, dy) = (a - mean_x, b - mean_y);
                (cov + dx * dy, vx + dx * dx, vy + dy * dy)
            },
        );
        let den = (var_x * var_y).sqrt();
        if den == 0.0 {
            1.0
        } else {
            (cov / den).clamp(-1.0, 1.0)
        }
    }
}
impl Next<&[f64]> for MultiAssetClusterer {
    type Output = MarketRegime;

    /// Ingests one cross-section of per-asset returns and classifies the regime.
    ///
    /// `returns` must hold exactly one value per asset (`n_assets`). Any other
    /// length, and an empty slice, is ignored: history is left untouched and
    /// `MarketRegime::Steady` is returned. The empty-slice guard also prevents
    /// NaN statistics (0/0) when `n_assets == 0`.
    ///
    /// The inner clusterer receives the tuple
    /// `(mean |r|, mean |r| * (1 - cross-sectional std of r), average pairwise correlation)`.
    fn next(&mut self, returns: &[f64]) -> Self::Output {
        if returns.is_empty() || returns.len() != self.n_assets {
            return MarketRegime::Steady;
        }

        // Append to each asset's rolling window, evicting the oldest value once
        // the window is exceeded.
        let window = self.window_size;
        for (hist, &r) in self.history.iter_mut().zip(returns) {
            hist.push_back(r);
            if hist.len() > window {
                hist.pop_front();
            }
        }

        let n = returns.len() as f64;
        let mean_abs_ret = returns.iter().map(|r| r.abs()).sum::<f64>() / n;
        let mean_ret = returns.iter().sum::<f64>() / n;
        // Population standard deviation of this step's returns across assets.
        let cross_std = (returns
            .iter()
            .map(|r| (r - mean_ret).powi(2))
            .sum::<f64>()
            / n)
            .sqrt();
        let avg_corr = self.calculate_average_correlation();
        self.inner
            .next((mean_abs_ret, mean_abs_ret * (1.0 - cross_std), avg_corr))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_asset_clusterer_basic() {
        let mut clusterer = MultiAssetClusterer::new(2, 5, 2);
        for _ in 0..10 {
            clusterer.next(&[0.01, 0.01]);
        }
        let r1 = clusterer.next(&[0.01, 0.01]);
        for _ in 0..10 {
            clusterer.next(&[0.05, 0.05]);
        }
        let r2 = clusterer.next(&[0.05, 0.05]);
        assert!(matches!(r1, MarketRegime::Steady | MarketRegime::Cluster(_)));
        assert!(matches!(r2, MarketRegime::Steady | MarketRegime::Cluster(_)));
    }

    /// A cross-section of the wrong width is rejected without touching history.
    #[test]
    fn mismatched_input_length_returns_steady_without_recording() {
        let mut clusterer = MultiAssetClusterer::new(2, 5, 2);
        let regime = clusterer.next(&[0.01]);
        assert!(matches!(regime, MarketRegime::Steady));
        assert!(clusterer.history.iter().all(|h| h.is_empty()));
    }

    /// Until a full window is available the correlation feature is neutral (1.0).
    #[test]
    fn average_correlation_is_one_until_window_full() {
        let mut clusterer = MultiAssetClusterer::new(2, 5, 2);
        clusterer.history[0].push_back(0.01);
        clusterer.history[1].push_back(-0.01);
        assert_eq!(clusterer.calculate_average_correlation(), 1.0);
    }

    /// Mirror-image series are perfectly anti-correlated.
    #[test]
    fn average_correlation_detects_anticorrelation() {
        let mut clusterer = MultiAssetClusterer::new(2, 5, 2);
        for step in 1..=5 {
            let r = step as f64 * 0.01;
            clusterer.history[0].push_back(r);
            clusterer.history[1].push_back(-r);
        }
        assert!((clusterer.calculate_average_correlation() + 1.0).abs() < 1e-9);
    }

    /// Zero-variance series have an undefined coefficient, reported as 1.0.
    #[test]
    fn flat_series_correlation_is_one() {
        let mut clusterer = MultiAssetClusterer::new(2, 5, 2);
        for _ in 0..5 {
            clusterer.history[0].push_back(0.01);
            clusterer.history[1].push_back(0.01);
        }
        assert_eq!(clusterer.calculate_average_correlation(), 1.0);
    }
}