// quantwave_core/features/regime_probs.rs
use crate::regimes::MarketRegime;
use crate::regimes::hmm::HMM;
use crate::traits::Next;

15#[derive(Debug, Clone, Copy, PartialEq)]
16pub struct RegimeProbFeatures {
17 pub probs: [f64; 5],
19 pub hard_label: MarketRegime,
20}
21
22pub fn regime_to_prob_features(regime: MarketRegime) -> RegimeProbFeatures {
23 let mut probs = [0.05; 5];
24 let idx = match regime {
25 MarketRegime::Bull => 0,
26 MarketRegime::Bear => 1,
27 MarketRegime::Crisis => 2,
28 MarketRegime::Steady => 3,
29 MarketRegime::Cluster(c) => 4.min(c as usize),
30 };
31 probs[idx] = 0.80;
32 let sum: f64 = probs.iter().sum();
33 for p in &mut probs {
34 *p /= sum;
35 }
36 RegimeProbFeatures {
37 probs,
38 hard_label: regime,
39 }
40}
41
42#[derive(Debug, Clone)]
44pub struct RegimeProbFeatureExtractor {
45 hmm: HMM,
46}
47
48impl RegimeProbFeatureExtractor {
49 pub fn bull_bear() -> Self {
50 Self {
51 hmm: HMM::bull_bear(),
52 }
53 }
54}
55
56impl Next<f64> for RegimeProbFeatureExtractor {
57 type Output = RegimeProbFeatures;
58
59 fn next(&mut self, input: f64) -> Self::Output {
60 let label = self.hmm.next(input);
61 let state_probs = self.hmm.state_probabilities();
62 let mut probs = [0.05; 5];
63 if state_probs.len() >= 2 {
64 probs[0] = state_probs[0];
65 probs[1] = state_probs[1];
66 }
67 let sum: f64 = probs.iter().sum();
68 if sum > 0.0 {
69 for p in &mut probs {
70 *p /= sum;
71 }
72 }
73 RegimeProbFeatures {
74 probs,
75 hard_label: label,
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing normalised probabilities.
    const EPS: f64 = 1e-9;

    #[test]
    fn test_regime_prob_features_bull() {
        let f = regime_to_prob_features(MarketRegime::Bull);
        assert!(f.probs[0] > 0.7);
        assert_eq!(f.hard_label, MarketRegime::Bull);
    }

    /// Every named regime lights up its own slot and the vector sums to one.
    #[test]
    fn test_regime_prob_features_named_slots() {
        let cases = [
            (MarketRegime::Bull, 0),
            (MarketRegime::Bear, 1),
            (MarketRegime::Crisis, 2),
            (MarketRegime::Steady, 3),
        ];
        for &(regime, slot) in cases.iter() {
            let f = regime_to_prob_features(regime);
            let total: f64 = f.probs.iter().sum();
            assert!((total - 1.0).abs() < EPS, "{:?}: probs sum to {}", regime, total);
            assert!(
                (f.probs[slot] - 0.80).abs() < EPS,
                "{:?}: slot {} = {}",
                regime,
                slot,
                f.probs[slot]
            );
            assert_eq!(f.hard_label, regime);
        }
    }

    /// Cluster indices past the end of the vector are clamped to the last slot.
    #[test]
    fn test_regime_prob_features_large_cluster_uses_last_slot() {
        let f = regime_to_prob_features(MarketRegime::Cluster(9));
        assert!((f.probs[4] - 0.80).abs() < EPS);
        assert_eq!(f.hard_label, MarketRegime::Cluster(9));
    }

    /// The HMM-backed extractor always emits a normalised vector.
    #[test]
    fn test_extractor_output_is_normalised() {
        let mut extractor = RegimeProbFeatureExtractor::bull_bear();
        for &ret in [0.01, -0.02, 0.005, -0.03].iter() {
            let f = extractor.next(ret);
            let total: f64 = f.probs.iter().sum();
            assert!((total - 1.0).abs() < EPS, "probs sum to {}", total);
        }
    }
}