// quantwave_core/features/regime_probs.rs
use crate::regimes::MarketRegime;

13#[derive(Debug, Clone, Copy, PartialEq)]
14pub struct RegimeProbFeatures {
15 pub probs: [f64; 5],
17 pub hard_label: MarketRegime,
18}
19
20pub fn regime_to_prob_features(regime: MarketRegime) -> RegimeProbFeatures {
21 let mut probs = [0.05; 5]; let idx = match regime {
23 MarketRegime::Bull => 0,
24 MarketRegime::Bear => 1,
25 MarketRegime::Crisis => 2,
26 MarketRegime::Steady => 3,
27 MarketRegime::Cluster(c) => 4.min(c as usize),
28 };
29 probs[idx] = 0.80;
30 let sum: f64 = probs.iter().sum();
32 for p in &mut probs {
33 *p /= sum;
34 }
35 RegimeProbFeatures { probs, hard_label: regime }
36}
37
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing normalised probabilities.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_regime_prob_features_bull() {
        let f = regime_to_prob_features(MarketRegime::Bull);
        assert!(f.probs[0] > 0.7);
        assert_eq!(f.hard_label, MarketRegime::Bull);
    }

    #[test]
    fn test_named_regimes_select_their_slot() {
        let cases = [
            (MarketRegime::Bull, 0usize),
            (MarketRegime::Bear, 1),
            (MarketRegime::Crisis, 2),
            (MarketRegime::Steady, 3),
        ];
        for &(regime, slot) in cases.iter() {
            let f = regime_to_prob_features(regime);
            assert_eq!(f.hard_label, regime);
            assert!(
                (f.probs[slot] - 0.80).abs() < EPS,
                "{:?}: {:?}",
                regime,
                f.probs
            );
            // Every slot other than the selected one carries the background weight.
            for (i, &p) in f.probs.iter().enumerate() {
                if i != slot {
                    assert!((p - 0.05).abs() < EPS, "{:?}: {:?}", regime, f.probs);
                }
            }
        }
    }

    #[test]
    fn test_probs_are_positive_and_sum_to_one() {
        for &regime in [
            MarketRegime::Bull,
            MarketRegime::Bear,
            MarketRegime::Crisis,
            MarketRegime::Steady,
        ]
        .iter()
        {
            let f = regime_to_prob_features(regime);
            assert!(f.probs.iter().all(|&p| p > 0.0), "{:?}", f.probs);
            let total: f64 = f.probs.iter().sum();
            assert!((total - 1.0).abs() < EPS, "{:?} sums to {}", regime, total);
        }
    }
}