use crate::regimes::MarketRegime;
/// Soft market-regime features.
///
/// Holds a five-slot probability vector together with the discrete regime label it was
/// derived from.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RegimeProbFeatures {
    /// Per-slot probabilities.
    ///
    /// When produced by [`regime_to_prob_features`], slot 0..=3 correspond to
    /// `Bull`, `Bear`, `Crisis`, `Steady`. Slot 4 is the last slot, onto which large
    /// `Cluster` ids are clamped. The entries sum to one up to floating-point rounding.
    pub probs: [f64; 5],
    /// The hard (discrete) regime these probabilities were generated from.
    pub hard_label: MarketRegime,
}
/// Converts a hard [`MarketRegime`] label into a soft probability vector.
///
/// The regime's own slot gets a weight of `0.80` and every other slot gets `0.05`. The
/// vector is then divided by its total, so the entries sum to one up to floating-point
/// rounding.
///
/// Slot layout: 0 = `Bull`, 1 = `Bear`, 2 = `Crisis`, 3 = `Steady`. `Cluster(c)` uses slot
/// `c as usize`, clamped to at most 4. Consequently `Cluster(0)`..`Cluster(3)` produce the
/// same `probs` as the corresponding named regimes; only `hard_label` tells them apart.
pub fn regime_to_prob_features(regime: MarketRegime) -> RegimeProbFeatures {
    // Raw (pre-normalisation) weights for the active slot and for every other slot.
    const ACTIVE_WEIGHT: f64 = 0.80;
    const BACKGROUND_WEIGHT: f64 = 0.05;
    // Highest valid slot index in the 5-element probability array.
    const LAST_SLOT: usize = 4;

    let slot = match regime {
        MarketRegime::Bull => 0,
        MarketRegime::Bear => 1,
        MarketRegime::Crisis => 2,
        MarketRegime::Steady => 3,
        // Cluster ids index slots directly; ids beyond the last slot collapse onto it.
        // NOTE(review): ids 0..=3 overlap the named regimes' slots — confirm intended.
        MarketRegime::Cluster(c) => (c as usize).min(LAST_SLOT),
    };

    let mut weights = [BACKGROUND_WEIGHT; 5];
    weights[slot] = ACTIVE_WEIGHT;

    // Renormalise so the vector sums to one even if the weights above are retuned.
    let total: f64 = weights.iter().sum();
    RegimeProbFeatures {
        probs: weights.map(|w| w / total),
        hard_label: regime,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a probability vector sums to one within a tight tolerance.
    fn assert_normalised(probs: &[f64; 5]) {
        let total: f64 = probs.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-12,
            "probabilities sum to {total}, expected 1.0"
        );
    }

    /// Asserts that `slot` is the dominant entry and every other entry is small.
    fn assert_dominant_slot(probs: &[f64; 5], slot: usize) {
        for (i, &p) in probs.iter().enumerate() {
            if i == slot {
                assert!(p > 0.7, "slot {i} should dominate, got {p}");
            } else {
                assert!(p < 0.1, "slot {i} should be small, got {p}");
            }
        }
    }

    #[test]
    fn test_regime_prob_features_bull() {
        let f = regime_to_prob_features(MarketRegime::Bull);
        assert!(f.probs[0] > 0.7);
        assert_eq!(f.hard_label, MarketRegime::Bull);
    }

    #[test]
    fn named_regimes_map_to_fixed_slots_and_normalise() {
        let cases = [
            (MarketRegime::Bull, 0),
            (MarketRegime::Bear, 1),
            (MarketRegime::Crisis, 2),
            (MarketRegime::Steady, 3),
        ];
        for (regime, slot) in cases {
            let f = regime_to_prob_features(regime);
            assert_eq!(f.hard_label, regime);
            assert_normalised(&f.probs);
            assert_dominant_slot(&f.probs, slot);
        }
    }

    #[test]
    fn large_cluster_ids_clamp_to_last_slot() {
        let regime = MarketRegime::Cluster(9);
        let f = regime_to_prob_features(regime);
        assert_eq!(f.hard_label, regime);
        assert_normalised(&f.probs);
        assert_dominant_slot(&f.probs, 4);
    }
}