Skip to main content

indicators/regime/
hmm.rs

//! Hidden Markov Model Regime Detection
//!
//! Implements HMM-based regime detection as described in:
//! - Hamilton, J.D. (1989) "A New Approach to the Economic Analysis of Nonstationary
//!   Time Series and the Business Cycle", *Econometrica* 57(2), 357–384.
//!
//! The HMM approach learns regime distributions directly from returns data,
//! making no assumptions about which indicators define each regime.
//!
//! Uses stable Rust only (no nightly features).

11use std::collections::{HashMap, VecDeque};
12
13use serde::{Deserialize, Serialize};
14
15use super::types::{MarketRegime, RegimeConfidence, TrendDirection};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22// ── Indicator wrapper ─────────────────────────────────────────────────────────
23
24/// Batch `Indicator` adapter for [`HMMRegimeDetector`].
25///
26/// Replays candles through the streaming HMM detector and emits per-bar
27/// `hmm_conf` (0–1) and `hmm_regime_id`:
28/// - 0 = Uncertain
29/// - 1 = MeanReverting
30/// - 2 = Volatile
31/// - 3 = Trending(Bullish)
32/// - 4 = Trending(Bearish)
33#[derive(Debug, Clone)]
34pub struct HmmIndicator {
35    pub config: HMMConfig,
36}
37
38impl HmmIndicator {
39    pub fn new(config: HMMConfig) -> Self {
40        Self { config }
41    }
42
43    pub fn with_defaults() -> Self {
44        Self::new(HMMConfig::default())
45    }
46}
47
48fn hmm_regime_id(r: MarketRegime) -> f64 {
49    match r {
50        MarketRegime::MeanReverting => 1.0,
51        MarketRegime::Volatile => 2.0,
52        MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
53        MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
54        MarketRegime::Uncertain => 0.0,
55    }
56}
57
58impl Indicator for HmmIndicator {
59    fn name(&self) -> &'static str {
60        "HMMRegime"
61    }
62
63    /// Minimum candles before meaningful output.
64    ///
65    /// The HMM requires `min_observations` returns, which means
66    /// `min_observations + 1` close prices (one extra to form the first return).
67    fn required_len(&self) -> usize {
68        self.config.min_observations + 1
69    }
70
71    fn required_columns(&self) -> &[&'static str] {
72        &["close"]
73    }
74
75    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
76        self.check_len(candles)?;
77        let mut det = HMMRegimeDetector::new(self.config.clone());
78        let n = candles.len();
79        let mut conf = vec![f64::NAN; n];
80        let mut regime = vec![f64::NAN; n];
81        for (i, c) in candles.iter().enumerate() {
82            let rc = det.update(c.close);
83            conf[i] = rc.confidence;
84            regime[i] = hmm_regime_id(rc.regime);
85        }
86        Ok(IndicatorOutput::from_pairs([
87            ("hmm_conf", conf),
88            ("hmm_regime_id", regime),
89        ]))
90    }
91}
92
93// ── Registry factory ──────────────────────────────────────────────────────────
94
95pub fn factory<S: ::std::hash::BuildHasher>(
96    params: &HashMap<String, String, S>,
97) -> Result<Box<dyn Indicator>, IndicatorError> {
98    let min_observations = param_usize(params, "min_observations", 100)?;
99    let n_states = param_usize(params, "n_states", 3)?;
100    let config = HMMConfig {
101        n_states,
102        min_observations,
103        ..HMMConfig::default()
104    };
105    Ok(Box::new(HmmIndicator::new(config)))
106}
107
108/// Configuration for HMM regime detector
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct HMMConfig {
111    /// Number of hidden states (regimes)
112    pub n_states: usize,
113    /// Minimum observations before making predictions
114    pub min_observations: usize,
115    /// Learning rate for online updates (0 = no online learning)
116    pub learning_rate: f64,
117    /// Smoothing factor for transition probabilities
118    pub transition_smoothing: f64,
119    /// Window size for return calculations
120    pub lookback_window: usize,
121    /// Confidence threshold for regime classification
122    pub min_confidence: f64,
123}
124
125impl Default for HMMConfig {
126    fn default() -> Self {
127        Self {
128            n_states: 3, // Bull, Bear, High-Vol
129            min_observations: 100,
130            learning_rate: 0.01,
131            transition_smoothing: 0.1,
132            lookback_window: 252, // ~1 year of daily data
133            min_confidence: 0.6,
134        }
135    }
136}
137
138impl HMMConfig {
139    /// Config optimized for crypto (faster regime changes)
140    pub fn crypto_optimized() -> Self {
141        Self {
142            n_states: 3,
143            min_observations: 50,
144            learning_rate: 0.02, // Faster adaptation
145            transition_smoothing: 0.05,
146            lookback_window: 100,
147            min_confidence: 0.5,
148        }
149    }
150
151    /// Conservative config (more stable regimes)
152    pub fn conservative() -> Self {
153        Self {
154            n_states: 2, // Just bull/bear
155            min_observations: 150,
156            learning_rate: 0.005,
157            transition_smoothing: 0.15,
158            lookback_window: 500,
159            min_confidence: 0.7,
160        }
161    }
162}
163
/// One-dimensional Gaussian emission model for a single hidden state.
#[derive(Clone, Debug)]
struct GaussianState {
    /// Mean of the emitted log return.
    mean: f64,
    /// Variance of the emitted log return.
    variance: f64,
    /// Weighted sum of the observations passed to `update`.
    /// Accumulated only; not otherwise read in this module.
    sum: f64,
    /// Weighted sum of the squared observations passed to `update`.
    /// Accumulated only; not otherwise read in this module.
    sum_sq: f64,
    /// Number of `update` calls.
    count: usize,
}

175impl GaussianState {
176    fn new(mean: f64, variance: f64) -> Self {
177        Self {
178            mean,
179            variance,
180            sum: 0.0,
181            sum_sq: 0.0,
182            count: 0,
183        }
184    }
185
186    /// Probability density function
187    fn pdf(&self, x: f64) -> f64 {
188        let diff = x - self.mean;
189        let exponent = -0.5 * diff * diff / self.variance;
190        let normalizer = (2.0 * std::f64::consts::PI * self.variance).sqrt();
191        exponent.exp() / normalizer
192    }
193
194    /// Update statistics with new observation
195    fn update(&mut self, x: f64, weight: f64, learning_rate: f64) {
196        if learning_rate > 0.0 {
197            // Online update using exponential moving average
198            self.mean = (1.0 - learning_rate * weight) * self.mean + learning_rate * weight * x;
199            let new_var = (x - self.mean).powi(2);
200            self.variance =
201                (1.0 - learning_rate * weight) * self.variance + learning_rate * weight * new_var;
202            self.variance = self.variance.max(1e-8); // Prevent zero variance
203        }
204
205        // Also track running stats
206        self.sum += x * weight;
207        self.sum_sq += x * x * weight;
208        self.count += 1;
209    }
210}
211
212/// Hidden Markov Model for regime detection.
213///
214/// Uses a 3-state HMM (by default) to model market regimes:
215/// - State 0: Bull market (positive returns, low volatility)
216/// - State 1: Bear market (negative returns, medium volatility)
217/// - State 2: High volatility (any direction, high volatility)
218///
219/// The model uses the forward algorithm for online filtering and periodically
220/// re-estimates parameters using the Baum-Welch algorithm.
221///
222/// # Example
223///
224/// ```rust
225/// use indicators::{HMMRegimeDetector, HMMConfig, MarketRegime};
226///
227/// let mut detector = HMMRegimeDetector::crypto_optimized();
228///
229/// // Feed close prices
230/// for i in 0..200 {
231///     let price = 100.0 * (1.0 + 0.001 * i as f64); // gentle uptrend
232///     let result = detector.update(price);
233///     if detector.is_ready() {
234///         println!("HMM regime: {} (conf: {:.0}%)", result.regime, result.confidence * 100.0);
235///     }
236/// }
237/// ```
238#[derive(Debug)]
239pub struct HMMRegimeDetector {
240    config: HMMConfig,
241
242    /// Gaussian emission distributions for each state
243    states: Vec<GaussianState>,
244
245    /// Transition probability matrix A[i][j] = P(state_j | state_i)
246    transition_matrix: Vec<Vec<f64>>,
247
248    /// Initial state probabilities
249    initial_probs: Vec<f64>,
250
251    /// Current state probabilities (filtered)
252    state_probs: Vec<f64>,
253
254    /// History of returns for batch updates
255    returns_history: VecDeque<f64>,
256
257    /// History of prices for return calculation
258    prices: VecDeque<f64>,
259
260    /// Current most likely state
261    current_state: usize,
262
263    /// Confidence in current state
264    current_confidence: f64,
265
266    /// Total observations processed
267    n_observations: usize,
268
269    /// Last detected regime
270    last_regime: MarketRegime,
271}
272
273impl HMMRegimeDetector {
274    /// Create a new HMM detector with the given configuration
275    pub fn new(config: HMMConfig) -> Self {
276        let n = config.n_states;
277
278        // Initialize states with reasonable priors for financial returns
279        // State 0: Bull (positive returns, low vol)
280        // State 1: Bear (negative returns, higher vol)
281        // State 2: High Vol (any direction, high vol)
282        let states = match n {
283            2 => vec![
284                GaussianState::new(0.001, 0.0001),  // Bull: ~0.1% daily, low vol
285                GaussianState::new(-0.001, 0.0004), // Bear: -0.1% daily, higher vol
286            ],
287            3 => vec![
288                GaussianState::new(0.001, 0.0001),  // Bull: positive, low vol
289                GaussianState::new(-0.001, 0.0002), // Bear: negative, medium vol
290                GaussianState::new(0.0, 0.0009),    // High Vol: neutral, high vol
291            ],
292            _ => (0..n)
293                .map(|i| {
294                    let mean = (i as f64 - n as f64 / 2.0) * 0.001;
295                    let var = 0.0001 * (1.0 + i as f64);
296                    GaussianState::new(mean, var)
297                })
298                .collect(),
299        };
300
301        // Initialize transition matrix with slight persistence
302        // Higher diagonal = states tend to persist
303        let mut transition_matrix = vec![vec![0.0; n]; n];
304        for (i, row) in transition_matrix.iter_mut().enumerate().take(n) {
305            for (j, cell) in row.iter_mut().enumerate().take(n) {
306                if i == j {
307                    *cell = 0.9; // 90% stay in same state
308                } else {
309                    *cell = 0.1 / (n - 1) as f64;
310                }
311            }
312        }
313
314        // Equal initial probabilities
315        let initial_probs = vec![1.0 / n as f64; n];
316        let state_probs = initial_probs.clone();
317
318        Self {
319            config: config.clone(),
320            states,
321            transition_matrix,
322            initial_probs,
323            state_probs,
324            returns_history: VecDeque::with_capacity(config.lookback_window),
325            prices: VecDeque::with_capacity(10),
326            current_state: 0,
327            current_confidence: 0.0,
328            n_observations: 0,
329            last_regime: MarketRegime::Uncertain,
330        }
331    }
332
333    /// Create with default config
334    pub fn default_config() -> Self {
335        Self::new(HMMConfig::default())
336    }
337
338    /// Create optimized for crypto
339    pub fn crypto_optimized() -> Self {
340        Self::new(HMMConfig::crypto_optimized())
341    }
342
343    /// Create with conservative config
344    pub fn conservative() -> Self {
345        Self::new(HMMConfig::conservative())
346    }
347
348    /// Update with new close price and get regime.
349    ///
350    /// Calculates log return from the previous close, then runs the forward
351    /// algorithm step and optional parameter updates.
352    pub fn update(&mut self, close: f64) -> RegimeConfidence {
353        // Calculate log return
354        if let Some(&prev_close) = self.prices.back()
355            && prev_close > 0.0
356        {
357            let log_return = (close / prev_close).ln();
358            self.process_return(log_return);
359        }
360
361        // Store price
362        self.prices.push_back(close);
363        if self.prices.len() > 10 {
364            self.prices.pop_front();
365        }
366
367        // Return current regime
368        let confidence = self.get_regime_confidence();
369        self.last_regime = confidence.regime;
370        confidence
371    }
372
373    /// Update with OHLC data (uses close price for HMM)
374    pub fn update_ohlc(&mut self, _high: f64, _low: f64, close: f64) -> RegimeConfidence {
375        self.update(close)
376    }
377
378    /// Process a single return observation
379    fn process_return(&mut self, ret: f64) {
380        self.n_observations += 1;
381
382        // Store return
383        self.returns_history.push_back(ret);
384        if self.returns_history.len() > self.config.lookback_window {
385            self.returns_history.pop_front();
386        }
387
388        // Forward algorithm step (filtering)
389        self.forward_step(ret);
390
391        // Update state parameters if we have enough data
392        if self.n_observations > self.config.min_observations && self.config.learning_rate > 0.0 {
393            self.online_parameter_update(ret);
394        }
395
396        // Periodically re-estimate with Baum-Welch if we have enough data
397        let reestimate_interval = self.config.lookback_window / 2;
398        if self.n_observations > 0
399            && reestimate_interval > 0
400            && self.n_observations.is_multiple_of(reestimate_interval)
401            && self.returns_history.len() >= self.config.min_observations
402        {
403            self.baum_welch_update();
404        }
405    }
406
407    /// Forward algorithm step - update state probabilities given new observation
408    fn forward_step(&mut self, ret: f64) {
409        let n = self.config.n_states;
410        let mut new_probs = vec![0.0; n];
411
412        // Calculate emission probabilities
413        let emissions: Vec<f64> = self.states.iter().map(|s| s.pdf(ret)).collect();
414
415        // Forward step: P(state_j | obs) ∝ P(obs | state_j) * Σᵢ P(state_j | state_i) * P(state_i)
416        for j in 0..n {
417            let mut sum = 0.0;
418            for i in 0..n {
419                sum += self.transition_matrix[i][j] * self.state_probs[i];
420            }
421            new_probs[j] = emissions[j] * sum;
422        }
423
424        // Normalize
425        let total: f64 = new_probs.iter().sum();
426        if total > 1e-300 {
427            for p in &mut new_probs {
428                *p /= total;
429            }
430        } else {
431            // Reset to uniform if probabilities collapse
432            new_probs = vec![1.0 / n as f64; n];
433        }
434
435        self.state_probs = new_probs;
436
437        // Update current state and confidence. total_cmp is NaN-safe: a NaN
438        // probability (numerical blow-up) sorts below all reals instead of
439        // panicking mid-update.
440        let (max_idx, max_prob) = self
441            .state_probs
442            .iter()
443            .enumerate()
444            .max_by(|(_, a), (_, b)| a.total_cmp(b))
445            .map_or((0, 1.0 / n as f64), |(i, p)| (i, *p));
446
447        self.current_state = max_idx;
448        self.current_confidence = max_prob;
449    }
450
451    /// Online parameter update using soft assignments
452    fn online_parameter_update(&mut self, ret: f64) {
453        let lr = self.config.learning_rate;
454
455        for (i, state) in self.states.iter_mut().enumerate() {
456            let weight = self.state_probs[i];
457            state.update(ret, weight, lr);
458        }
459
460        // Update transition matrix (soft transitions)
461        // This is a simplified online update
462        let smoothing = self.config.transition_smoothing;
463        for i in 0..self.config.n_states {
464            for j in 0..self.config.n_states {
465                let target = if i == j {
466                    0.9
467                } else {
468                    0.1 / (self.config.n_states - 1) as f64
469                };
470                self.transition_matrix[i][j] =
471                    (1.0 - smoothing) * self.transition_matrix[i][j] + smoothing * target;
472            }
473        }
474    }
475
476    /// Baum-Welch algorithm for batch parameter re-estimation.
477    ///
478    /// Runs the full forward-backward algorithm on the returns history
479    /// to re-estimate emission parameters. Uses blending with existing
480    /// parameters to prevent sudden jumps.
481    fn baum_welch_update(&mut self) {
482        let returns: Vec<f64> = self.returns_history.iter().copied().collect();
483        if returns.len() < self.config.min_observations {
484            return;
485        }
486
487        let n = self.config.n_states;
488        let t = returns.len();
489
490        // Forward pass
491        let mut alpha = vec![vec![0.0; n]; t];
492
493        // Initialize
494        for (j, alpha_val) in alpha[0].iter_mut().enumerate().take(n) {
495            *alpha_val = self.initial_probs[j] * self.states[j].pdf(returns[0]);
496        }
497        Self::normalize_vec(&mut alpha[0]);
498
499        // Forward
500        for time in 1..t {
501            for j in 0..n {
502                let mut sum = 0.0;
503                for (i, alpha_prev) in alpha[time - 1].iter().enumerate().take(n) {
504                    sum += alpha_prev * self.transition_matrix[i][j];
505                }
506                alpha[time][j] = sum * self.states[j].pdf(returns[time]);
507            }
508            Self::normalize_vec(&mut alpha[time]);
509        }
510
511        // Backward pass
512        let mut beta = vec![vec![1.0; n]; t];
513
514        for time in (0..t - 1).rev() {
515            for i in 0..n {
516                let mut sum = 0.0;
517                for (j, beta_next) in beta[time + 1].iter().enumerate().take(n) {
518                    sum += self.transition_matrix[i][j]
519                        * self.states[j].pdf(returns[time + 1])
520                        * beta_next;
521                }
522                beta[time][i] = sum;
523            }
524            Self::normalize_vec(&mut beta[time]);
525        }
526
527        // Compute gamma (state occupancy probabilities)
528        let mut gamma = vec![vec![0.0; n]; t];
529        for time in 0..t {
530            let mut sum = 0.0;
531            for (j, gamma_val) in gamma[time].iter_mut().enumerate().take(n) {
532                *gamma_val = alpha[time][j] * beta[time][j];
533                sum += *gamma_val;
534            }
535            if sum > 1e-300 {
536                for gamma_val in gamma[time].iter_mut().take(n) {
537                    *gamma_val /= sum;
538                }
539            }
540        }
541
542        // Re-estimate emission parameters
543        for (j, state) in self.states.iter_mut().enumerate().take(n) {
544            let mut weight_sum = 0.0;
545            let mut mean_sum = 0.0;
546            let mut var_sum = 0.0;
547
548            for time in 0..t {
549                let w = gamma[time][j];
550                weight_sum += w;
551                mean_sum += w * returns[time];
552            }
553
554            if weight_sum > 1e-8 {
555                let new_mean = mean_sum / weight_sum;
556
557                for time in 0..t {
558                    let w = gamma[time][j];
559                    var_sum += w * (returns[time] - new_mean).powi(2);
560                }
561
562                let new_var = (var_sum / weight_sum).max(1e-8);
563
564                // Blend with existing parameters (prevents sudden jumps)
565                let blend = 0.3;
566                state.mean = (1.0 - blend) * state.mean + blend * new_mean;
567                state.variance = (1.0 - blend) * state.variance + blend * new_var;
568            }
569        }
570    }
571
572    /// Helper to normalize a probability vector
573    fn normalize_vec(vec: &mut [f64]) {
574        let sum: f64 = vec.iter().sum();
575        if sum > 1e-300 {
576            for v in vec.iter_mut() {
577                *v /= sum;
578            }
579        }
580    }
581
582    /// Get current regime with confidence
583    pub fn get_regime_confidence(&self) -> RegimeConfidence {
584        if self.n_observations < self.config.min_observations {
585            return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
586        }
587
588        let regime = self.state_to_regime(self.current_state);
589        let confidence = self.current_confidence;
590
591        RegimeConfidence::with_metrics(
592            regime,
593            confidence,
594            self.states[self.current_state].mean * 100.0 * 252.0, // Annualized return %
595            self.states[self.current_state].variance.sqrt() * 100.0 * 252.0_f64.sqrt(), // Annualized vol %
596            0.0, // No trend strength in HMM
597        )
598    }
599
600    /// Map state index to `MarketRegime` based on learned parameters.
601    ///
602    /// Classification is based on the Gaussian emission parameters:
603    /// - High variance → Volatile
604    /// - Positive mean → Trending(Bullish)
605    /// - Negative mean → Trending(Bearish)
606    /// - Low variance, neutral mean → MeanReverting
607    fn state_to_regime(&self, state: usize) -> MarketRegime {
608        let state_params = &self.states[state];
609        let mean = state_params.mean;
610        let vol = state_params.variance.sqrt();
611
612        // Classify based on learned parameters
613        let is_high_vol = vol > 0.02; // > 2% daily vol
614        let is_positive = mean > 0.0005; // > 0.05% daily
615        let is_negative = mean < -0.0005;
616
617        if is_high_vol {
618            MarketRegime::Volatile
619        } else if is_positive {
620            MarketRegime::Trending(TrendDirection::Bullish)
621        } else if is_negative {
622            MarketRegime::Trending(TrendDirection::Bearish)
623        } else {
624            MarketRegime::MeanReverting // Low vol, neutral returns = ranging
625        }
626    }
627
628    // ========================================================================
629    // Public Accessors
630    // ========================================================================
631
632    /// Get state probabilities
633    pub fn state_probabilities(&self) -> &[f64] {
634        &self.state_probs
635    }
636
637    /// Get state parameters (mean, variance) for inspection
638    pub fn state_parameters(&self) -> Vec<(f64, f64)> {
639        self.states.iter().map(|s| (s.mean, s.variance)).collect()
640    }
641
642    /// Get transition matrix
643    pub fn transition_matrix(&self) -> &[Vec<f64>] {
644        &self.transition_matrix
645    }
646
647    /// Get current state index
648    pub fn current_state_index(&self) -> usize {
649        self.current_state
650    }
651
652    /// Check if model is warmed up (has enough observations)
653    pub fn is_ready(&self) -> bool {
654        self.n_observations >= self.config.min_observations
655    }
656
657    /// Get expected regime duration (from transition matrix).
658    ///
659    /// Expected duration = 1 / (1 - P(stay in state))
660    pub fn expected_regime_duration(&self, state: usize) -> f64 {
661        if state < self.config.n_states {
662            1.0 / (1.0 - self.transition_matrix[state][state])
663        } else {
664            0.0
665        }
666    }
667
668    /// Predict most likely next state
669    pub fn predict_next_state(&self) -> (usize, f64) {
670        let mut next_probs = vec![0.0; self.config.n_states];
671
672        for (j, next_prob) in next_probs.iter_mut().enumerate().take(self.config.n_states) {
673            for i in 0..self.config.n_states {
674                *next_prob += self.transition_matrix[i][j] * self.state_probs[i];
675            }
676        }
677
678        let (max_idx, max_prob) = next_probs
679            .iter()
680            .enumerate()
681            .max_by(|(_, a), (_, b)| a.total_cmp(b))
682            .map_or((0, 0.0), |(i, p)| (i, *p));
683
684        (max_idx, max_prob)
685    }
686
687    /// Get the total number of observations processed
688    pub fn n_observations(&self) -> usize {
689        self.n_observations
690    }
691
692    /// Get the current confidence score
693    pub fn current_confidence(&self) -> f64 {
694        self.current_confidence
695    }
696
697    /// Get the configuration
698    pub fn config(&self) -> &HMMConfig {
699        &self.config
700    }
701}
702
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Multiply `start` by `factor` `bars` times, recording the price after each step.
    fn compounding_path(start: f64, factor: f64, bars: usize) -> Vec<f64> {
        let mut path = Vec::with_capacity(bars);
        let mut price = start;
        for _ in 0..bars {
            price *= factor;
            path.push(price);
        }
        path
    }

    #[test]
    fn test_hmm_initialization() {
        let det = HMMRegimeDetector::default_config();
        assert!(!det.is_ready());
        assert_eq!(det.state_probabilities().len(), 3);
    }

    #[test]
    fn test_hmm_crypto_config() {
        let det = HMMRegimeDetector::crypto_optimized();
        let cfg = det.config();
        assert_eq!(cfg.n_states, 3);
        assert_eq!(cfg.min_observations, 50);
    }

    #[test]
    fn test_hmm_conservative_config() {
        let det = HMMRegimeDetector::conservative();
        let cfg = det.config();
        assert_eq!(cfg.n_states, 2);
        assert_eq!(cfg.min_observations, 150);
        assert_eq!(det.state_probabilities().len(), 2);
    }

    #[test]
    fn test_hmm_warmup() {
        let mut det = HMMRegimeDetector::crypto_optimized();

        // 49 prices -> 48 returns, short of min_observations.
        for i in 0..49 {
            let outcome = det.update(100.0 + f64::from(i) * 0.01);
            assert_eq!(
                outcome.regime,
                MarketRegime::Uncertain,
                "Should be Uncertain during warmup at step {i}"
            );
        }

        assert!(!det.is_ready());
    }

    #[test]
    fn test_hmm_becomes_ready() {
        let mut det = HMMRegimeDetector::crypto_optimized();
        (0..60)
            .map(|i| 100.0 + f64::from(i) * 0.01)
            .for_each(|price| {
                det.update(price);
            });
        assert!(det.is_ready(), "Should be ready after 60 observations");
    }

    #[test]
    fn test_bull_market_detection() {
        let mut det = HMMRegimeDetector::crypto_optimized();

        // Steady +0.5 % per bar.
        for price in compounding_path(100.0, 1.005, 200) {
            let outcome = det.update(price);
            if det.is_ready() {
                assert_ne!(outcome.regime, MarketRegime::Uncertain);
            }
        }

        let last = det.get_regime_confidence();
        assert!(
            matches!(last.regime, MarketRegime::Trending(TrendDirection::Bullish)),
            "Expected Bullish trend, got: {:?}",
            last.regime
        );
    }

    #[test]
    fn test_volatile_market_detection() {
        let mut det = HMMRegimeDetector::crypto_optimized();

        // Alternating +5 % / -5 % swings.
        let mut price = 100.0;
        for i in 0..200 {
            price *= if i % 2 == 0 { 1.05 } else { 0.95 };
            det.update(price);
        }

        let outcome = det.get_regime_confidence();
        assert!(
            matches!(
                outcome.regime,
                MarketRegime::Volatile | MarketRegime::MeanReverting
            ),
            "Expected Volatile or MeanReverting for choppy data, got: {:?}",
            outcome.regime
        );
    }

    #[test]
    fn test_state_probabilities_sum_to_one() {
        let mut det = HMMRegimeDetector::crypto_optimized();

        for price in compounding_path(100.0, 1.001, 100) {
            det.update(price);
            let sum: f64 = det.state_probabilities().iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "State probabilities should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_transition_matrix_rows_sum_to_one() {
        let det = HMMRegimeDetector::default_config();

        for (i, row) in det.transition_matrix().iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Transition matrix row {i} should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_expected_regime_duration() {
        let det = HMMRegimeDetector::default_config();

        // Persistence 0.9 -> 1 / (1 - 0.9) = 10 bars.
        let duration = det.expected_regime_duration(0);
        assert!(
            (duration - 10.0).abs() < 1e-6,
            "Expected duration should be ~10 with 0.9 persistence, got: {duration}"
        );
    }

    #[test]
    fn test_predict_next_state() {
        let mut det = HMMRegimeDetector::crypto_optimized();
        for price in compounding_path(100.0, 1.002, 100) {
            det.update(price);
        }

        let (next_state, prob) = det.predict_next_state();
        assert!(next_state < det.config().n_states);
        assert!(
            (0.0..=1.0).contains(&prob),
            "Predicted probability should be in [0, 1]: {prob}"
        );
    }

    #[test]
    fn test_state_parameters() {
        let params = HMMRegimeDetector::default_config().state_parameters();

        assert_eq!(params.len(), 3, "Should have 3 state parameters");

        for &(mean, variance) in &params {
            assert!(variance > 0.0, "Variance should be positive: {variance}");
            assert!(mean.is_finite(), "Mean should be finite: {mean}");
        }
    }

    #[test]
    fn test_update_ohlc_uses_close() {
        let mut via_close = HMMRegimeDetector::crypto_optimized();
        let mut via_ohlc = HMMRegimeDetector::crypto_optimized();

        // update_ohlc only looks at the close, so both must agree bar by bar.
        for i in 0..100 {
            let close = 100.0 + f64::from(i) * 0.1;
            let a = via_close.update(close);
            let b = via_ohlc.update_ohlc(close * 1.01, close * 0.99, close);
            assert_eq!(
                a.regime, b.regime,
                "update and update_ohlc should produce same regime"
            );
        }
    }

    #[test]
    fn test_n_observations_tracking() {
        let mut det = HMMRegimeDetector::crypto_optimized();
        assert_eq!(det.n_observations(), 0);

        (0..50).for_each(|i| {
            det.update(100.0 + f64::from(i));
        });

        // Observations are returns, so 50 prices give 49.
        assert_eq!(det.n_observations(), 49);
    }

    #[test]
    fn test_confidence_range() {
        let mut det = HMMRegimeDetector::crypto_optimized();
        for price in compounding_path(100.0, 1.002, 200) {
            det.update(price);
        }

        let confidence = det.current_confidence();
        assert!(
            (0.0..=1.0).contains(&confidence),
            "Confidence should be in [0, 1]: {confidence}"
        );
    }
}