Skip to main content

quantwave_core/features/
regime_probs.rs

1//! Regime probability features (soft labels) for ML pipelines.
2//!
3//! Provides a simple way to turn hard regime labels or basic clustering into
4//! probability-like vectors that are excellent stationary features.
5//!
6//! This is a thin wrapper; richer soft outputs (HMM forward probs, GMM responsibilities)
7//! can be wired here later when the regimes module exposes them cleanly.
8//!
9//! Source: regimes/mod.rs (MarketRegime) + quantwave-4ub research notes (regime probs as meta-features).
10
11use crate::regimes::MarketRegime;
12use crate::regimes::hmm::HMM;
13use crate::traits::Next;
14
15#[derive(Debug, Clone, Copy, PartialEq)]
16pub struct RegimeProbFeatures {
17    /// 5-dimensional soft vector (Bull, Bear, Crisis, Steady, Other/Cluster)
18    pub probs: [f64; 5],
19    pub hard_label: MarketRegime,
20}
21
22pub fn regime_to_prob_features(regime: MarketRegime) -> RegimeProbFeatures {
23    let mut probs = [0.05; 5];
24    let idx = match regime {
25        MarketRegime::Bull => 0,
26        MarketRegime::Bear => 1,
27        MarketRegime::Crisis => 2,
28        MarketRegime::Steady => 3,
29        MarketRegime::Cluster(c) => 4.min(c as usize),
30    };
31    probs[idx] = 0.80;
32    let sum: f64 = probs.iter().sum();
33    for p in &mut probs {
34        *p /= sum;
35    }
36    RegimeProbFeatures {
37        probs,
38        hard_label: regime,
39    }
40}
41
42/// Streaming extractor: HMM hard label + forward state probabilities as soft features.
43#[derive(Debug, Clone)]
44pub struct RegimeProbFeatureExtractor {
45    hmm: HMM,
46}
47
48impl RegimeProbFeatureExtractor {
49    pub fn bull_bear() -> Self {
50        Self {
51            hmm: HMM::bull_bear(),
52        }
53    }
54}
55
56impl Next<f64> for RegimeProbFeatureExtractor {
57    type Output = RegimeProbFeatures;
58
59    fn next(&mut self, input: f64) -> Self::Output {
60        let label = self.hmm.next(input);
61        let state_probs = self.hmm.state_probabilities();
62        let mut probs = [0.05; 5];
63        if state_probs.len() >= 2 {
64            probs[0] = state_probs[0];
65            probs[1] = state_probs[1];
66        }
67        let sum: f64 = probs.iter().sum();
68        if sum > 0.0 {
69            for p in &mut probs {
70                *p /= sum;
71            }
72        }
73        RegimeProbFeatures {
74            probs,
75            hard_label: label,
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons on normalized probabilities.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_regime_prob_features_bull() {
        let f = regime_to_prob_features(MarketRegime::Bull);
        assert!(f.probs[0] > 0.7);
        assert_eq!(f.hard_label, MarketRegime::Bull);
    }

    /// Every named regime must peak on its own slot, keep the smoothing floor elsewhere,
    /// preserve the hard label, and sum to 1.
    #[test]
    fn named_regimes_peak_on_their_own_slot() {
        let cases = [
            (MarketRegime::Bull, 0usize),
            (MarketRegime::Bear, 1),
            (MarketRegime::Crisis, 2),
            (MarketRegime::Steady, 3),
        ];
        for (regime, slot) in cases {
            let f = regime_to_prob_features(regime);
            assert_eq!(f.hard_label, regime);

            let total: f64 = f.probs.iter().sum();
            assert!((total - 1.0).abs() < EPS, "{:?}: sum {}", regime, total);

            for (i, &p) in f.probs.iter().enumerate() {
                let expected = if i == slot { 0.80 } else { 0.05 };
                assert!(
                    (p - expected).abs() < EPS,
                    "{:?}: slot {} = {}, expected {}",
                    regime,
                    i,
                    p,
                    expected
                );
            }
        }
    }
}