Skip to main content

indicators/regime/
hmm.rs

//! Hidden Markov Model Regime Detection
//!
//! Implements HMM-based regime detection as described in:
//! - Hamilton, J.D. (1989) "A New Approach to the Economic Analysis of Nonstationary
//!   Time Series and the Business Cycle", *Econometrica* 57(2), 357–384.
//!
//! The HMM approach learns regime distributions directly from returns data,
//! making no assumptions about what indicators define each regime. Each hidden
//! state is then labelled as a market regime from its learned mean and volatility.
//!
//! Uses stable Rust only (no nightly features).
10
11use std::collections::{HashMap, VecDeque};
12
13use serde::{Deserialize, Serialize};
14
15use super::types::{MarketRegime, RegimeConfidence, TrendDirection};
16
17use crate::error::IndicatorError;
18use crate::indicator::{Indicator, IndicatorOutput};
19use crate::registry::param_usize;
20use crate::types::Candle;
21
22// ── Indicator wrapper ─────────────────────────────────────────────────────────
23
/// Batch `Indicator` adapter for [`HMMRegimeDetector`].
///
/// Replays candles through the streaming HMM detector and emits per-bar
/// `hmm_conf` (0–1) and `hmm_regime_id`:
/// - 0 = Uncertain
/// - 1 = MeanReverting
/// - 2 = Volatile
/// - 3 = Trending(Bullish)
/// - 4 = Trending(Bearish)
///
/// Only the `close` column is used. A fresh detector is built on every
/// `calculate` call, so the output depends only on the candles passed in.
/// While the detector is still warming up it reports `Uncertain` with
/// confidence 0, so the leading bars emit `hmm_conf = 0` and
/// `hmm_regime_id = 0`.
#[derive(Debug, Clone)]
pub struct HmmIndicator {
    /// Detector settings used for every replay.
    pub config: HMMConfig,
}
37
38impl HmmIndicator {
39    pub fn new(config: HMMConfig) -> Self {
40        Self { config }
41    }
42
43    pub fn with_defaults() -> Self {
44        Self::new(HMMConfig::default())
45    }
46}
47
48fn hmm_regime_id(r: MarketRegime) -> f64 {
49    match r {
50        MarketRegime::MeanReverting => 1.0,
51        MarketRegime::Volatile => 2.0,
52        MarketRegime::Trending(TrendDirection::Bullish) => 3.0,
53        MarketRegime::Trending(TrendDirection::Bearish) => 4.0,
54        MarketRegime::Uncertain => 0.0,
55    }
56}
57
58impl Indicator for HmmIndicator {
59    fn name(&self) -> &'static str {
60        "HMMRegime"
61    }
62
63    /// Minimum candles before meaningful output.
64    ///
65    /// The HMM requires `min_observations` returns, which means
66    /// `min_observations + 1` close prices (one extra to form the first return).
67    fn required_len(&self) -> usize {
68        self.config.min_observations + 1
69    }
70
71    fn required_columns(&self) -> &[&'static str] {
72        &["close"]
73    }
74
75    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
76        self.check_len(candles)?;
77        let mut det = HMMRegimeDetector::new(self.config.clone());
78        let n = candles.len();
79        let mut conf = vec![f64::NAN; n];
80        let mut regime = vec![f64::NAN; n];
81        for (i, c) in candles.iter().enumerate() {
82            let rc = det.update(c.close);
83            conf[i] = rc.confidence;
84            regime[i] = hmm_regime_id(rc.regime);
85        }
86        Ok(IndicatorOutput::from_pairs([
87            ("hmm_conf", conf),
88            ("hmm_regime_id", regime),
89        ]))
90    }
91}
92
93// ── Registry factory ──────────────────────────────────────────────────────────
94
95pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
96    let min_observations = param_usize(params, "min_observations", 100)?;
97    let n_states = param_usize(params, "n_states", 3)?;
98    let config = HMMConfig {
99        n_states,
100        min_observations,
101        ..HMMConfig::default()
102    };
103    Ok(Box::new(HmmIndicator::new(config)))
104}
105
/// Configuration for [`HMMRegimeDetector`].
///
/// Presets: [`HMMConfig::default`], [`HMMConfig::crypto_optimized`] and
/// [`HMMConfig::conservative`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HMMConfig {
    /// Number of hidden states (regimes). The detector seeds dedicated priors
    /// for 2 and 3 states and a generic spread of priors otherwise.
    pub n_states: usize,
    /// Minimum number of returns (not prices) before the detector reports a
    /// regime other than `Uncertain`. Online parameter learning starts only
    /// after more than this many returns.
    pub min_observations: usize,
    /// Learning rate for online updates (0 = no online learning)
    pub learning_rate: f64,
    /// Smoothing factor used when relaxing transition probabilities toward
    /// their persistence prior during online updates.
    pub transition_smoothing: f64,
    /// Number of most recent returns kept for Baum-Welch re-estimation; a
    /// re-estimate is attempted every `lookback_window / 2` returns.
    pub lookback_window: usize,
    /// Confidence threshold for regime classification.
    ///
    /// NOTE(review): this field is not read by `HMMRegimeDetector` in this
    /// file — confirm whether low-confidence states are meant to be reported
    /// as `Uncertain`.
    pub min_confidence: f64,
}
122
123impl Default for HMMConfig {
124    fn default() -> Self {
125        Self {
126            n_states: 3, // Bull, Bear, High-Vol
127            min_observations: 100,
128            learning_rate: 0.01,
129            transition_smoothing: 0.1,
130            lookback_window: 252, // ~1 year of daily data
131            min_confidence: 0.6,
132        }
133    }
134}
135
136impl HMMConfig {
137    /// Config optimized for crypto (faster regime changes)
138    pub fn crypto_optimized() -> Self {
139        Self {
140            n_states: 3,
141            min_observations: 50,
142            learning_rate: 0.02, // Faster adaptation
143            transition_smoothing: 0.05,
144            lookback_window: 100,
145            min_confidence: 0.5,
146        }
147    }
148
149    /// Conservative config (more stable regimes)
150    pub fn conservative() -> Self {
151        Self {
152            n_states: 2, // Just bull/bear
153            min_observations: 150,
154            learning_rate: 0.005,
155            transition_smoothing: 0.15,
156            lookback_window: 500,
157            min_confidence: 0.7,
158        }
159    }
160}
161
/// Gaussian emission distribution `N(mean, variance)` for a single hidden state.
#[derive(Debug, Clone)]
struct GaussianState {
    /// Mean of the modelled log return.
    mean: f64,
    /// Variance of the modelled log return (floored at 1e-8 by the online and
    /// Baum-Welch update paths).
    variance: f64,
    /// Running statistics for online updates: weighted sum of observations.
    ///
    /// NOTE(review): `sum`, `sum_sq` and `count` are accumulated in
    /// `GaussianState::update` but never read in this file — confirm whether
    /// they are still needed.
    sum: f64,
    /// Weighted sum of squared observations.
    sum_sq: f64,
    /// Number of observations folded in (unweighted).
    count: usize,
}
172
173impl GaussianState {
174    fn new(mean: f64, variance: f64) -> Self {
175        Self {
176            mean,
177            variance,
178            sum: 0.0,
179            sum_sq: 0.0,
180            count: 0,
181        }
182    }
183
184    /// Probability density function
185    fn pdf(&self, x: f64) -> f64 {
186        let diff = x - self.mean;
187        let exponent = -0.5 * diff * diff / self.variance;
188        let normalizer = (2.0 * std::f64::consts::PI * self.variance).sqrt();
189        exponent.exp() / normalizer
190    }
191
192    /// Update statistics with new observation
193    fn update(&mut self, x: f64, weight: f64, learning_rate: f64) {
194        if learning_rate > 0.0 {
195            // Online update using exponential moving average
196            self.mean = (1.0 - learning_rate * weight) * self.mean + learning_rate * weight * x;
197            let new_var = (x - self.mean).powi(2);
198            self.variance =
199                (1.0 - learning_rate * weight) * self.variance + learning_rate * weight * new_var;
200            self.variance = self.variance.max(1e-8); // Prevent zero variance
201        }
202
203        // Also track running stats
204        self.sum += x * weight;
205        self.sum_sq += x * x * weight;
206        self.count += 1;
207    }
208}
209
/// Hidden Markov Model for regime detection.
///
/// Uses a 3-state HMM (by default) to model market regimes. The emission
/// priors seed the states as:
/// - State 0: Bull market (positive returns, low volatility)
/// - State 1: Bear market (negative returns, medium volatility)
/// - State 2: High volatility (any direction, high volatility)
///
/// These labels describe only the starting priors. As the Gaussian parameters
/// are re-learned, the regime reported for a state is re-derived from its
/// current mean and volatility (see `state_to_regime`).
///
/// The model uses the forward algorithm for online filtering and periodically
/// re-estimates the emission (mean/variance) parameters with a Baum-Welch
/// forward-backward pass over the recent returns window. The transition
/// matrix is not re-estimated from data.
///
/// # Example
///
/// ```rust
/// use indicators::{HMMRegimeDetector, HMMConfig, MarketRegime};
///
/// let mut detector = HMMRegimeDetector::crypto_optimized();
///
/// // Feed close prices
/// for i in 0..200 {
///     let price = 100.0 * (1.0 + 0.001 * i as f64); // gentle uptrend
///     let result = detector.update(price);
///     if detector.is_ready() {
///         println!("HMM regime: {} (conf: {:.0}%)", result.regime, result.confidence * 100.0);
///     }
/// }
/// ```
#[derive(Debug)]
pub struct HMMRegimeDetector {
    config: HMMConfig,

    /// Gaussian emission distributions for each state
    states: Vec<GaussianState>,

    /// Transition probability matrix A[i][j] = P(state_j | state_i)
    /// (row = current state, column = next state).
    transition_matrix: Vec<Vec<f64>>,

    /// Initial state probabilities (uniform). Seeds `state_probs` and serves
    /// as the t = 0 prior of the Baum-Welch forward pass.
    initial_probs: Vec<f64>,

    /// Current state probabilities (filtered)
    state_probs: Vec<f64>,

    /// Most recent log returns, capped at `lookback_window`; the data used by
    /// Baum-Welch re-estimation.
    returns_history: VecDeque<f64>,

    /// Recent close prices (at most 10 are kept); only the latest is read,
    /// to form the next log return.
    prices: VecDeque<f64>,

    /// Current most likely state
    current_state: usize,

    /// Confidence in current state (its filtered probability)
    current_confidence: f64,

    /// Total observations processed (returns, not prices)
    n_observations: usize,

    /// Last detected regime (written on every `update`)
    last_regime: MarketRegime,
}
270
271impl HMMRegimeDetector {
272    /// Create a new HMM detector with the given configuration
273    pub fn new(config: HMMConfig) -> Self {
274        let n = config.n_states;
275
276        // Initialize states with reasonable priors for financial returns
277        // State 0: Bull (positive returns, low vol)
278        // State 1: Bear (negative returns, higher vol)
279        // State 2: High Vol (any direction, high vol)
280        let states = match n {
281            2 => vec![
282                GaussianState::new(0.001, 0.0001),  // Bull: ~0.1% daily, low vol
283                GaussianState::new(-0.001, 0.0004), // Bear: -0.1% daily, higher vol
284            ],
285            3 => vec![
286                GaussianState::new(0.001, 0.0001),  // Bull: positive, low vol
287                GaussianState::new(-0.001, 0.0002), // Bear: negative, medium vol
288                GaussianState::new(0.0, 0.0009),    // High Vol: neutral, high vol
289            ],
290            _ => (0..n)
291                .map(|i| {
292                    let mean = (i as f64 - n as f64 / 2.0) * 0.001;
293                    let var = 0.0001 * (1.0 + i as f64);
294                    GaussianState::new(mean, var)
295                })
296                .collect(),
297        };
298
299        // Initialize transition matrix with slight persistence
300        // Higher diagonal = states tend to persist
301        let mut transition_matrix = vec![vec![0.0; n]; n];
302        for (i, row) in transition_matrix.iter_mut().enumerate().take(n) {
303            for (j, cell) in row.iter_mut().enumerate().take(n) {
304                if i == j {
305                    *cell = 0.9; // 90% stay in same state
306                } else {
307                    *cell = 0.1 / (n - 1) as f64;
308                }
309            }
310        }
311
312        // Equal initial probabilities
313        let initial_probs = vec![1.0 / n as f64; n];
314        let state_probs = initial_probs.clone();
315
316        Self {
317            config: config.clone(),
318            states,
319            transition_matrix,
320            initial_probs,
321            state_probs,
322            returns_history: VecDeque::with_capacity(config.lookback_window),
323            prices: VecDeque::with_capacity(10),
324            current_state: 0,
325            current_confidence: 0.0,
326            n_observations: 0,
327            last_regime: MarketRegime::Uncertain,
328        }
329    }
330
331    /// Create with default config
332    pub fn default_config() -> Self {
333        Self::new(HMMConfig::default())
334    }
335
336    /// Create optimized for crypto
337    pub fn crypto_optimized() -> Self {
338        Self::new(HMMConfig::crypto_optimized())
339    }
340
341    /// Create with conservative config
342    pub fn conservative() -> Self {
343        Self::new(HMMConfig::conservative())
344    }
345
346    /// Update with new close price and get regime.
347    ///
348    /// Calculates log return from the previous close, then runs the forward
349    /// algorithm step and optional parameter updates.
350    pub fn update(&mut self, close: f64) -> RegimeConfidence {
351        // Calculate log return
352        if let Some(&prev_close) = self.prices.back()
353            && prev_close > 0.0
354        {
355            let log_return = (close / prev_close).ln();
356            self.process_return(log_return);
357        }
358
359        // Store price
360        self.prices.push_back(close);
361        if self.prices.len() > 10 {
362            self.prices.pop_front();
363        }
364
365        // Return current regime
366        let confidence = self.get_regime_confidence();
367        self.last_regime = confidence.regime;
368        confidence
369    }
370
371    /// Update with OHLC data (uses close price for HMM)
372    pub fn update_ohlc(&mut self, _high: f64, _low: f64, close: f64) -> RegimeConfidence {
373        self.update(close)
374    }
375
376    /// Process a single return observation
377    fn process_return(&mut self, ret: f64) {
378        self.n_observations += 1;
379
380        // Store return
381        self.returns_history.push_back(ret);
382        if self.returns_history.len() > self.config.lookback_window {
383            self.returns_history.pop_front();
384        }
385
386        // Forward algorithm step (filtering)
387        self.forward_step(ret);
388
389        // Update state parameters if we have enough data
390        if self.n_observations > self.config.min_observations && self.config.learning_rate > 0.0 {
391            self.online_parameter_update(ret);
392        }
393
394        // Periodically re-estimate with Baum-Welch if we have enough data
395        let reestimate_interval = self.config.lookback_window / 2;
396        if self.n_observations > 0
397            && reestimate_interval > 0
398            && self.n_observations.is_multiple_of(reestimate_interval)
399            && self.returns_history.len() >= self.config.min_observations
400        {
401            self.baum_welch_update();
402        }
403    }
404
405    /// Forward algorithm step - update state probabilities given new observation
406    fn forward_step(&mut self, ret: f64) {
407        let n = self.config.n_states;
408        let mut new_probs = vec![0.0; n];
409
410        // Calculate emission probabilities
411        let emissions: Vec<f64> = self.states.iter().map(|s| s.pdf(ret)).collect();
412
413        // Forward step: P(state_j | obs) ∝ P(obs | state_j) * Σᵢ P(state_j | state_i) * P(state_i)
414        for j in 0..n {
415            let mut sum = 0.0;
416            for i in 0..n {
417                sum += self.transition_matrix[i][j] * self.state_probs[i];
418            }
419            new_probs[j] = emissions[j] * sum;
420        }
421
422        // Normalize
423        let total: f64 = new_probs.iter().sum();
424        if total > 1e-300 {
425            for p in &mut new_probs {
426                *p /= total;
427            }
428        } else {
429            // Reset to uniform if probabilities collapse
430            new_probs = vec![1.0 / n as f64; n];
431        }
432
433        self.state_probs = new_probs;
434
435        // Update current state and confidence
436        let (max_idx, max_prob) = self
437            .state_probs
438            .iter()
439            .enumerate()
440            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
441            .unwrap();
442
443        self.current_state = max_idx;
444        self.current_confidence = *max_prob;
445    }
446
447    /// Online parameter update using soft assignments
448    fn online_parameter_update(&mut self, ret: f64) {
449        let lr = self.config.learning_rate;
450
451        for (i, state) in self.states.iter_mut().enumerate() {
452            let weight = self.state_probs[i];
453            state.update(ret, weight, lr);
454        }
455
456        // Update transition matrix (soft transitions)
457        // This is a simplified online update
458        let smoothing = self.config.transition_smoothing;
459        for i in 0..self.config.n_states {
460            for j in 0..self.config.n_states {
461                let target = if i == j {
462                    0.9
463                } else {
464                    0.1 / (self.config.n_states - 1) as f64
465                };
466                self.transition_matrix[i][j] =
467                    (1.0 - smoothing) * self.transition_matrix[i][j] + smoothing * target;
468            }
469        }
470    }
471
472    /// Baum-Welch algorithm for batch parameter re-estimation.
473    ///
474    /// Runs the full forward-backward algorithm on the returns history
475    /// to re-estimate emission parameters. Uses blending with existing
476    /// parameters to prevent sudden jumps.
477    fn baum_welch_update(&mut self) {
478        let returns: Vec<f64> = self.returns_history.iter().copied().collect();
479        if returns.len() < self.config.min_observations {
480            return;
481        }
482
483        let n = self.config.n_states;
484        let t = returns.len();
485
486        // Forward pass
487        let mut alpha = vec![vec![0.0; n]; t];
488
489        // Initialize
490        for (j, alpha_val) in alpha[0].iter_mut().enumerate().take(n) {
491            *alpha_val = self.initial_probs[j] * self.states[j].pdf(returns[0]);
492        }
493        Self::normalize_vec(&mut alpha[0]);
494
495        // Forward
496        for time in 1..t {
497            for j in 0..n {
498                let mut sum = 0.0;
499                for (i, alpha_prev) in alpha[time - 1].iter().enumerate().take(n) {
500                    sum += alpha_prev * self.transition_matrix[i][j];
501                }
502                alpha[time][j] = sum * self.states[j].pdf(returns[time]);
503            }
504            Self::normalize_vec(&mut alpha[time]);
505        }
506
507        // Backward pass
508        let mut beta = vec![vec![1.0; n]; t];
509
510        for time in (0..t - 1).rev() {
511            for i in 0..n {
512                let mut sum = 0.0;
513                for (j, beta_next) in beta[time + 1].iter().enumerate().take(n) {
514                    sum += self.transition_matrix[i][j]
515                        * self.states[j].pdf(returns[time + 1])
516                        * beta_next;
517                }
518                beta[time][i] = sum;
519            }
520            Self::normalize_vec(&mut beta[time]);
521        }
522
523        // Compute gamma (state occupancy probabilities)
524        let mut gamma = vec![vec![0.0; n]; t];
525        for time in 0..t {
526            let mut sum = 0.0;
527            for (j, gamma_val) in gamma[time].iter_mut().enumerate().take(n) {
528                *gamma_val = alpha[time][j] * beta[time][j];
529                sum += *gamma_val;
530            }
531            if sum > 1e-300 {
532                for gamma_val in gamma[time].iter_mut().take(n) {
533                    *gamma_val /= sum;
534                }
535            }
536        }
537
538        // Re-estimate emission parameters
539        for (j, state) in self.states.iter_mut().enumerate().take(n) {
540            let mut weight_sum = 0.0;
541            let mut mean_sum = 0.0;
542            let mut var_sum = 0.0;
543
544            for time in 0..t {
545                let w = gamma[time][j];
546                weight_sum += w;
547                mean_sum += w * returns[time];
548            }
549
550            if weight_sum > 1e-8 {
551                let new_mean = mean_sum / weight_sum;
552
553                for time in 0..t {
554                    let w = gamma[time][j];
555                    var_sum += w * (returns[time] - new_mean).powi(2);
556                }
557
558                let new_var = (var_sum / weight_sum).max(1e-8);
559
560                // Blend with existing parameters (prevents sudden jumps)
561                let blend = 0.3;
562                state.mean = (1.0 - blend) * state.mean + blend * new_mean;
563                state.variance = (1.0 - blend) * state.variance + blend * new_var;
564            }
565        }
566    }
567
568    /// Helper to normalize a probability vector
569    fn normalize_vec(vec: &mut [f64]) {
570        let sum: f64 = vec.iter().sum();
571        if sum > 1e-300 {
572            for v in vec.iter_mut() {
573                *v /= sum;
574            }
575        }
576    }
577
578    /// Get current regime with confidence
579    pub fn get_regime_confidence(&self) -> RegimeConfidence {
580        if self.n_observations < self.config.min_observations {
581            return RegimeConfidence::new(MarketRegime::Uncertain, 0.0);
582        }
583
584        let regime = self.state_to_regime(self.current_state);
585        let confidence = self.current_confidence;
586
587        RegimeConfidence::with_metrics(
588            regime,
589            confidence,
590            self.states[self.current_state].mean * 100.0 * 252.0, // Annualized return %
591            self.states[self.current_state].variance.sqrt() * 100.0 * 252.0_f64.sqrt(), // Annualized vol %
592            0.0, // No trend strength in HMM
593        )
594    }
595
596    /// Map state index to `MarketRegime` based on learned parameters.
597    ///
598    /// Classification is based on the Gaussian emission parameters:
599    /// - High variance → Volatile
600    /// - Positive mean → Trending(Bullish)
601    /// - Negative mean → Trending(Bearish)
602    /// - Low variance, neutral mean → MeanReverting
603    fn state_to_regime(&self, state: usize) -> MarketRegime {
604        let state_params = &self.states[state];
605        let mean = state_params.mean;
606        let vol = state_params.variance.sqrt();
607
608        // Classify based on learned parameters
609        let is_high_vol = vol > 0.02; // > 2% daily vol
610        let is_positive = mean > 0.0005; // > 0.05% daily
611        let is_negative = mean < -0.0005;
612
613        if is_high_vol {
614            MarketRegime::Volatile
615        } else if is_positive {
616            MarketRegime::Trending(TrendDirection::Bullish)
617        } else if is_negative {
618            MarketRegime::Trending(TrendDirection::Bearish)
619        } else {
620            MarketRegime::MeanReverting // Low vol, neutral returns = ranging
621        }
622    }
623
624    // ========================================================================
625    // Public Accessors
626    // ========================================================================
627
628    /// Get state probabilities
629    pub fn state_probabilities(&self) -> &[f64] {
630        &self.state_probs
631    }
632
633    /// Get state parameters (mean, variance) for inspection
634    pub fn state_parameters(&self) -> Vec<(f64, f64)> {
635        self.states.iter().map(|s| (s.mean, s.variance)).collect()
636    }
637
638    /// Get transition matrix
639    pub fn transition_matrix(&self) -> &[Vec<f64>] {
640        &self.transition_matrix
641    }
642
643    /// Get current state index
644    pub fn current_state_index(&self) -> usize {
645        self.current_state
646    }
647
648    /// Check if model is warmed up (has enough observations)
649    pub fn is_ready(&self) -> bool {
650        self.n_observations >= self.config.min_observations
651    }
652
653    /// Get expected regime duration (from transition matrix).
654    ///
655    /// Expected duration = 1 / (1 - P(stay in state))
656    pub fn expected_regime_duration(&self, state: usize) -> f64 {
657        if state < self.config.n_states {
658            1.0 / (1.0 - self.transition_matrix[state][state])
659        } else {
660            0.0
661        }
662    }
663
664    /// Predict most likely next state
665    pub fn predict_next_state(&self) -> (usize, f64) {
666        let mut next_probs = vec![0.0; self.config.n_states];
667
668        for (j, next_prob) in next_probs.iter_mut().enumerate().take(self.config.n_states) {
669            for i in 0..self.config.n_states {
670                *next_prob += self.transition_matrix[i][j] * self.state_probs[i];
671            }
672        }
673
674        let (max_idx, max_prob) = next_probs
675            .iter()
676            .enumerate()
677            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
678            .unwrap();
679
680        (max_idx, *max_prob)
681    }
682
683    /// Get the total number of observations processed
684    pub fn n_observations(&self) -> usize {
685        self.n_observations
686    }
687
688    /// Get the current confidence score
689    pub fn current_confidence(&self) -> f64 {
690        self.current_confidence
691    }
692
693    /// Get the configuration
694    pub fn config(&self) -> &HMMConfig {
695        &self.config
696    }
697}
698
699// ============================================================================
700// Tests
701// ============================================================================
702
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh default detector is not warmed up and tracks one probability
    // per state (3 states by default).
    #[test]
    fn test_hmm_initialization() {
        let detector = HMMRegimeDetector::default_config();
        assert!(!detector.is_ready());
        assert_eq!(detector.state_probabilities().len(), 3);
    }

    // The crypto preset keeps 3 states but shortens warm-up to 50 returns.
    #[test]
    fn test_hmm_crypto_config() {
        let detector = HMMRegimeDetector::crypto_optimized();
        assert_eq!(detector.config().n_states, 3);
        assert_eq!(detector.config().min_observations, 50);
    }

    // The conservative preset is a 2-state model with a 150-return warm-up.
    #[test]
    fn test_hmm_conservative_config() {
        let detector = HMMRegimeDetector::conservative();
        assert_eq!(detector.config().n_states, 2);
        assert_eq!(detector.config().min_observations, 150);
        assert_eq!(detector.state_probabilities().len(), 2);
    }

    // 49 prices yield only 48 returns, below the crypto preset's 50-return
    // warm-up, so every bar must report `Uncertain`.
    #[test]
    fn test_hmm_warmup() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // Feed fewer than min_observations
        for i in 0..49 {
            let price = 100.0 + (i as f64) * 0.01;
            let result = detector.update(price);
            assert_eq!(
                result.regime,
                MarketRegime::Uncertain,
                "Should be Uncertain during warmup at step {i}"
            );
        }

        assert!(!detector.is_ready());
    }

    // 60 prices yield 59 returns, past the 50-return warm-up.
    #[test]
    fn test_hmm_becomes_ready() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        for i in 0..60 {
            let price = 100.0 + (i as f64) * 0.01;
            detector.update(price);
        }

        assert!(detector.is_ready(), "Should be ready after 60 observations");
    }

    // A steady +0.5%-per-bar trend must never be Uncertain after warm-up and
    // must end classified as bullish.
    #[test]
    fn test_bull_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // Strong consistent uptrend
        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.005; // 0.5% daily gain
            let result = detector.update(price);
            if detector.is_ready() {
                // After warmup, regime should be trending or at least not uncertain
                assert_ne!(result.regime, MarketRegime::Uncertain);
            }
        }

        let final_result = detector.get_regime_confidence();
        // In a strong bull market, we expect bullish trending
        assert!(
            matches!(
                final_result.regime,
                MarketRegime::Trending(TrendDirection::Bullish)
            ),
            "Expected Bullish trend, got: {:?}",
            final_result.regime
        );
    }

    // Alternating ±5% swings: accept Volatile or MeanReverting, but not a trend.
    #[test]
    fn test_volatile_market_detection() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        // High volatility: large alternating swings
        let mut price = 100.0;
        for i in 0..200 {
            if i % 2 == 0 {
                price *= 1.05; // 5% up
            } else {
                price *= 0.95; // 5% down
            }
            detector.update(price);
        }

        let result = detector.get_regime_confidence();
        // With large swings, should detect volatile or at least not a clean trend
        assert!(
            matches!(
                result.regime,
                MarketRegime::Volatile | MarketRegime::MeanReverting
            ),
            "Expected Volatile or MeanReverting for choppy data, got: {:?}",
            result.regime
        );
    }

    // Filtered probabilities must stay normalized after every update.
    #[test]
    fn test_state_probabilities_sum_to_one() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.001;
            detector.update(price);

            let probs = detector.state_probabilities();
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "State probabilities should sum to 1.0, got: {sum}"
            );
        }
    }

    // The initial transition matrix must be row-stochastic.
    #[test]
    fn test_transition_matrix_rows_sum_to_one() {
        let detector = HMMRegimeDetector::default_config();
        let tm = detector.transition_matrix();

        for (i, row) in tm.iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "Transition matrix row {i} should sum to 1.0, got: {sum}"
            );
        }
    }

    #[test]
    fn test_expected_regime_duration() {
        let detector = HMMRegimeDetector::default_config();

        // With 0.9 persistence, expected duration = 1 / (1 - 0.9) = 10
        let duration = detector.expected_regime_duration(0);
        assert!(
            (duration - 10.0).abs() < 1e-6,
            "Expected duration should be ~10 with 0.9 persistence, got: {duration}"
        );
    }

    // The one-step-ahead prediction returns a valid state index and probability.
    #[test]
    fn test_predict_next_state() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..100 {
            price *= 1.002;
            detector.update(price);
        }

        let (next_state, prob) = detector.predict_next_state();
        assert!(next_state < detector.config().n_states);
        assert!(
            (0.0..=1.0).contains(&prob),
            "Predicted probability should be in [0, 1]: {prob}"
        );
    }

    // The default model exposes 3 (finite mean, positive variance) pairs.
    #[test]
    fn test_state_parameters() {
        let detector = HMMRegimeDetector::default_config();
        let params = detector.state_parameters();

        assert_eq!(params.len(), 3, "Should have 3 state parameters");

        for (mean, variance) in &params {
            assert!(variance > &0.0, "Variance should be positive: {variance}");
            assert!(mean.is_finite(), "Mean should be finite: {mean}");
        }
    }

    // `update_ohlc` ignores high/low, so it must track `update` bar for bar.
    #[test]
    fn test_update_ohlc_uses_close() {
        let mut det1 = HMMRegimeDetector::crypto_optimized();
        let mut det2 = HMMRegimeDetector::crypto_optimized();

        // Both should produce identical results since OHLC just uses close
        for i in 0..100 {
            let close = 100.0 + i as f64 * 0.1;
            let r1 = det1.update(close);
            let r2 = det2.update_ohlc(close * 1.01, close * 0.99, close);

            assert_eq!(
                r1.regime, r2.regime,
                "update and update_ohlc should produce same regime"
            );
        }
    }

    #[test]
    fn test_n_observations_tracking() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        assert_eq!(detector.n_observations(), 0);

        for i in 0..50 {
            detector.update(100.0 + i as f64);
        }

        // n_observations counts returns, so it's prices - 1
        assert_eq!(detector.n_observations(), 49);
    }

    // The reported confidence is a probability.
    #[test]
    fn test_confidence_range() {
        let mut detector = HMMRegimeDetector::crypto_optimized();

        let mut price = 100.0;
        for _ in 0..200 {
            price *= 1.002;
            detector.update(price);
        }

        let confidence = detector.current_confidence();
        assert!(
            (0.0..=1.0).contains(&confidence),
            "Confidence should be in [0, 1]: {confidence}"
        );
    }
}